In [None]:
import numpy as np
from tqdm import tqdm
import sys

# https://github.com/Ujjwal-9/Knowledge-Distillation
sys.path.append('Knowledge-Distillation/utils/')

# Custom ImageDataGenerator whose directory iterators yield (x_batch, y_batch, image_names) batches
from image_preprocessing_ver1 import ImageDataGenerator
from keras.models import Model

In [None]:
# Load data: CIFAR-10 as ((train images, train labels), (test images, test labels)).
from keras.datasets import cifar10

train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
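# # One-off export step: make one subdirectory per class label (shown for the test split only; the pipeline below also reads a matching 'train' directory, which must be exported the same way)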

# from keras.preprocessing.image import save_img, load_img, img_to_array
# import os
# from collections import defaultdict

# # Make subdir for each class
# save_path = r'C:\Users\ender\Pictures\data\cifar10\test'
# os.chdir(save_path)
# for label in np.unique(y_test):
#     os.mkdir(str(label))

In [None]:
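# # Save the CIFAR-10 test images into the per-class dirs for use with ImageDataGenerator (note: .jpg is lossy, so the saved pixels differ slightly from x_test)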

# save_path = r'C:\Users\ender\Pictures\data\cifar10\test'
# name_dict = defaultdict(int)
# for i, img in enumerate(x_test):
#     img_array = img_to_array(img, data_format='channels_last')
#     name = os.path.join(save_path, str(y_test[i][0]), str(y_test[i][0])+'('+str(name_dict[y_test[i][0]])+')'+'.jpg')
#     name_dict[y_test[i][0]] += 1
#     save_img(name, img)

In [None]:
# Setup image pipeline: one non-shuffled directory iterator per split.
# shuffle=False keeps the image order fixed from pass to pass.

import os

data_generator = ImageDataGenerator(data_format='channels_last')

# Images per batch; 1 means one predict_on_batch call per image.
batch_size = 1

# Root folder holding the per-class 'train' and 'test' image folders.
# Set the CIFAR10_DATA_DIR environment variable to point elsewhere instead of
# editing this machine-specific default path.
data_dir = os.environ.get('CIFAR10_DATA_DIR', r'C:\Users\ender\Pictures\data\cifar10')

# Iterator options shared by both splits (kept in one place so they cannot drift).
flow_options = dict(target_size=(32, 32), color_mode='rgb', shuffle=False, batch_size=batch_size)

train_generator = data_generator.flow_from_directory(os.path.join(data_dir, 'train'), **flow_options)

val_generator = data_generator.flow_from_directory(os.path.join(data_dir, 'test'), **flow_options)

In [None]:
# Load the trained teacher and expose its pre-softmax outputs (logits).

from keras.models import load_model

teacher = load_model('models/cifar10_teacher.h5')

# pop() drops the teacher's final layer. NOTE(review): this assumes that layer
# is a standalone softmax activation -- confirm; if softmax is fused into the
# last Dense layer, popping it removes the whole classifier head instead.
teacher.pop()

# Functional wrapper whose output is the (new) last layer, i.e. the logits.
model = Model(teacher.input, teacher.layers[-1].output)

In [None]:
# Collect teacher logits for every training image, keyed by image name.
#
# The generator does not stop by itself, so the loop is bounded explicitly to
# exactly ceil(N / batch_size) batches. The former `batches > N / batch_size`
# test ran one batch too many whenever batch_size divided N (e.g. 50001
# predict calls for batch_size=1). Bounding with islice also lets tqdm show
# a real total.
import itertools
import math
import warnings

n_train_batches = math.ceil(len(x_train) / batch_size)
train_logits = {}

for x_batch, _, name_batch in tqdm(itertools.islice(train_generator, n_train_batches), total=n_train_batches):
    batch_logits = model.predict_on_batch(x_batch)

    for index, name in enumerate(name_batch):
        train_logits[name] = batch_logits[index]

# Assumes the 'train' directory holds exactly one file per CIFAR-10 training
# image; report a mismatch instead of silently continuing.
if len(train_logits) != len(x_train):
    warnings.warn('Collected {} train logits, expected {}'.format(len(train_logits), len(x_train)))

In [None]:
# Collect teacher logits for every validation (test-split) image, keyed by image name.
#
# The generator does not stop by itself, so the loop is bounded explicitly to
# exactly ceil(N / batch_size) batches. The former `batches > N / batch_size`
# test ran one batch too many whenever batch_size divided N (e.g. 10001
# predict calls for batch_size=1). Bounding with islice also lets tqdm show
# a real total.
import itertools
import math
import warnings

n_val_batches = math.ceil(len(x_test) / batch_size)
val_logits = {}

for x_batch, _, name_batch in tqdm(itertools.islice(val_generator, n_val_batches), total=n_val_batches):
    batch_logits = model.predict_on_batch(x_batch)

    for index, name in enumerate(name_batch):
        val_logits[name] = batch_logits[index]

# Assumes the 'test' directory holds exactly one file per CIFAR-10 test
# image; report a mismatch instead of silently continuing.
if len(val_logits) != len(x_test):
    warnings.warn('Collected {} val logits, expected {}'.format(len(val_logits), len(x_test)))

In [None]:
# Save logits inside data_dir.
#
# Plain string concatenation (data_dir + 'cifar10_...') dropped the path
# separator and wrote 'cifar10cifar10_*.npy' into data_dir's parent folder;
# os.path.join puts the files inside data_dir as intended.
# np.save stores each dict as a pickled 0-d object array, so load it back with
# np.load(path, allow_pickle=True).item().
import os

np.save(os.path.join(data_dir, 'cifar10_train_logits.npy'), train_logits)
np.save(os.path.join(data_dir, 'cifar10_val_logits.npy'), val_logits)