In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import torch
import torchvision.transforms as T
from shutil import rmtree
from sklearn.model_selection import train_test_split
from tqdm.autonotebook import tqdm, trange
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.datasets import ImageFolder


from predictors.alexnet import Alexnet
from predictors.half_alexnet import HalfAlexnet

from datasets import CIFAR10, ProxyDataset

%matplotlib inline

In [3]:
LR = 0.001
EPOCHS = 200
BATCH_SIZE = 32
CONFIDENCE_TH = 0.8

DATASET_PATH = '.'

# Set random seed for replicating testing results
RANDOM_SEED = 0
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
# Teacher
teacher_model = Alexnet(name=None, n_outputs=10)

ckpt_path = 'checkpoints/teacher_alexnet_for_cifar10_state_dict'
teacher_model.load_state_dict(torch.load(ckpt_path, map_location=device))
teacher_model.eval()
teacher_model = teacher_model.to(device)

In [5]:
# Define dataset
label_mapper = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
label_mapper_inv = {v:k for k,v in label_mapper.items()}

# Get images paths and labels
images = []
labels = []

folders_path = os.path.join(DATASET_PATH,'images_generated')
for folder in os.listdir(folders_path):
    if 'imagenet' in folder:
      continue
    class_path = os.path.join(folders_path,folder)
    images_names = os.listdir(class_path)
    
    for image_name in images_names:
        images.append(os.path.join(class_path,image_name))
        labels.append(label_mapper[folder])

In [6]:
proxy_transforms = T.Compose([
    T.Resize((32,32)),
    T.Normalize((0.5,), (0.5,))
])

proxy_dataset = ProxyDataset(images, labels, proxy_transforms, True)
proxy_dataloader  = DataLoader(proxy_dataset,  batch_size=BATCH_SIZE)

# Obtain teacher predictions
filtered_images = []
filtered_labels = []
filtered_soft_labels = []

with torch.no_grad():
    for img,label,path,_ in tqdm(proxy_dataloader):
        img = img.to(device=device)
        label = label.to(device=device)

        logits = teacher_model(img)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)

        filtered_images.extend(list(path))
        filtered_labels.extend(y_hat.tolist())
        filtered_soft_labels.extend(pred)
        
# Display results of filtering
print(f'A total of {len(filtered_images)} remained out of {len(proxy_dataset)}')
print()
counter_per_class = {v:0 for k,v in label_mapper.items()}
for label in filtered_labels:
    counter_per_class[label] += 1
for clasa in counter_per_class:
    print(f'Class {clasa}({label_mapper_inv[clasa]}) has {counter_per_class[clasa]} entries')

  0%|          | 0/1600 [00:00<?, ?it/s]

A total of 51200 remained out of 51200

Class 0(airplane) has 6384 entries
Class 1(automobile) has 5145 entries
Class 2(bird) has 4796 entries
Class 3(cat) has 5329 entries
Class 4(deer) has 5571 entries
Class 5(dog) has 4287 entries
Class 6(frog) has 4553 entries
Class 7(horse) has 5025 entries
Class 8(ship) has 5178 entries
Class 9(truck) has 4932 entries


In [7]:
# Define the student model
student_model = HalfAlexnet(name=None, n_outputs=10)

path_to_save = 'pretrained_student.pt'
if torch.cuda.is_available():
    student_model.load_state_dict(torch.load(path_to_save))
else:
    student_model.load_state_dict(torch.load(path_to_save, map_location ='cpu'))
student_model.to(device)

# Define optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=LR)

# Define loss function
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [8]:
try:
  # raise
  # Do a stratified split of the data
  NO_IMGS_TO_USE = 1024
  
  filtered_images_subset, filtered_images_unused, filtered_labels_subset, filtered_labels_unused, filtered_soft_labels_subset, filtered_soft_labels_unused = \
      train_test_split(filtered_images, filtered_labels, filtered_soft_labels, train_size=NO_IMGS_TO_USE, stratify=filtered_labels, random_state=RANDOM_SEED)  

  train_images, validation_images, train_labels, validation_labels, train_soft_labels, validation_soft_labels = \
      train_test_split(filtered_images_subset, filtered_labels_subset, filtered_soft_labels_subset, train_size=0.8, stratify=filtered_labels_subset, random_state=RANDOM_SEED)
  valid_images, test_images, valid_labels, test_labels, valid_soft_labels, test_soft_labels = \
      train_test_split(validation_images, validation_labels, validation_soft_labels, test_size=0.5, stratify=validation_labels, random_state=RANDOM_SEED)
except:
  # For small datasets 
  print('Small dataset')
  NO_IMGS_TO_USE = 30
  NO_TRAIN_IMGS = 10

  """ DEV - train_test_split fails, need to manually introduce 1 sample
  """
  for i in range(len(filtered_labels)-1,-1,-1):
    if filtered_labels[i] == 5:
      img_dev = filtered_images[i]
      soft_label_dev = filtered_soft_labels[i]
      break

  filtered_images_subset, filtered_images_unused, filtered_labels_subset, filtered_labels_unused, filtered_soft_labels_subset, filtered_soft_labels_unused = \
    train_test_split(filtered_images, filtered_labels, filtered_soft_labels, train_size=NO_IMGS_TO_USE, stratify=filtered_labels, random_state=RANDOM_SEED)  
  train_images, validation_images, train_labels, validation_labels, train_soft_labels, validation_soft_labels = \
    train_test_split(filtered_images_subset, filtered_labels_subset, filtered_soft_labels_subset, train_size=NO_TRAIN_IMGS, stratify=filtered_labels_subset, random_state=RANDOM_SEED)
  
  """ DEV - interchange an image
  """
  for i in range(len(validation_images)):
    if validation_labels[i] == 0:
      validation_images[i] = img_dev
      validation_labels[i] = 5
      validation_soft_labels[i] = soft_label_dev
      break

  valid_images, test_images, valid_labels, test_labels, valid_soft_labels, test_soft_labels = \
      train_test_split(validation_images, validation_labels, validation_soft_labels, test_size=0.5, stratify=validation_labels)

In [9]:
# Define the transformations
train_transforms = T.Compose([
    T.Resize((32,32)),
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(p=0.5),
    T.Normalize((0.5,), (0.5,))
])
valid_transforms = T.Compose([
    T.Resize((32,32)),
    T.Normalize((0.5,), (0.5,))
])

# Define the proxy datasets
proxy_train_dataset = ProxyDataset(train_images, train_labels, train_transforms, False, train_soft_labels)
proxy_valid_dataset = ProxyDataset(valid_images, valid_labels, valid_transforms, False, valid_soft_labels)
proxy_test_dataset  = ProxyDataset(test_images,  test_labels,  valid_transforms, False, test_soft_labels)

# Define the proxy dataloaders
train_dataloader = DataLoader(proxy_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(proxy_valid_dataset, batch_size=BATCH_SIZE)
test_dataloader  = DataLoader(proxy_test_dataset,  batch_size=BATCH_SIZE)

# Define true dataset
true_dataset = CIFAR10(input_size = 32)

validate_on_trueds = True
if validate_on_trueds:
    true_valid_ds = true_dataset.test_dataloader()

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Class for early stopping
class EarlyStopping():
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, validation_loss):
        if self.best_score is None:
          self.best_score = validation_loss
        elif validation_loss - self.best_score < self.min_delta:
          self.best_score = validation_loss
        else:
          self.counter +=1
          if self.counter >= self.tolerance:  
              self.early_stop = True

In [11]:
early_stopping = EarlyStopping(tolerance=20, min_delta=0.001)
# early_stopping = EarlyStopping(tolerance=5, min_delta=0.001)

# Training the student
# for epoch in range(EPOCHS):
for epoch in range(3):
    # Define progress bar
    loop = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    
    # Training loop
    student_model.train()
    training_loss_epoch = []
    for batch_idx, (x,y,soft_y) in loop:
        optimizer.zero_grad()
        
        x = x.to(device=device)
        y = y.to(device=device)
        soft_y = soft_y.to(device=device)
        
        # Forward pass
        logits = student_model(x)
        # Backward pass
        loss = criterion(input=logits, target=soft_y)
        training_loss_epoch.append(loss.item())
        # Prob trb inlocuit criterionul, sa adaugi soft-labels
        loss.backward()
        
        # Optimize
        optimizer.step()
        
        # Update progress bar
        loop.set_description(f'Epoch {epoch+1}/{EPOCHS}')
        loop.set_postfix(training_loss=loss.item())
    
    # Validation loop on proxy validation dataset
    student_model.eval()
    validation_loss_epoch = []  
    acc = 0
    with torch.no_grad():
        for x,y,_ in valid_dataloader:
            x = x.to(device=device)
            y = y.to(device=device)
        
            logits = student_model(x)
            pred = softmax(logits, dim=1)
            
            confidence,y_hat = torch.max(pred, dim=1)
            
            loss = criterion(input=logits, target=y)
            validation_loss_epoch.append(loss.item())
            
            acc += torch.sum(y_hat==y).item()
        
    loop.write(f'validation_loss on proxy = {sum(validation_loss_epoch)/len(validation_loss_epoch):.4f}')
    loop.write(f'validation_accuracy on proxy = {100*acc/len(proxy_valid_dataset):.2f}%')

    if validate_on_trueds:
        # Validation loop on proxy validation dataset
        student_model.eval()
        with torch.no_grad():
            val_loss = []
            acc = 0
            for x,y in true_valid_ds:
                x = x.to(device=device)
                y = y.to(device=device)
            
                logits = student_model(x)
                pred = softmax(logits, dim=1)
                
                confidence,y_hat = torch.max(pred, dim=1)
                
                loss = criterion(input=logits, target=y)
                val_loss.append(loss.item())
                
                acc += torch.sum(y_hat==y).item()
            
        loop.write(f'validation_loss on true ds = {sum(val_loss)/len(val_loss):.4f}')
        loop.write(f'validation_accuracy on true ds = {100*acc/len(true_dataset.test_dataset):.2f}%')

    early_stopping(sum(validation_loss_epoch)/len(validation_loss_epoch))
    if early_stopping.early_stop:
      print(f"We are at epoch {epoch}")
      break

  0%|          | 0/26 [00:00<?, ?it/s]

validation_loss on proxy = 1.8052
validation_accuracy on proxy = 51.96%
validation_loss on true ds = 2.0238
validation_accuracy on true ds = 30.15%


  0%|          | 0/26 [00:00<?, ?it/s]

validation_loss on proxy = 1.4215
validation_accuracy on proxy = 66.67%
validation_loss on true ds = 1.8572
validation_accuracy on true ds = 36.52%


  0%|          | 0/26 [00:00<?, ?it/s]

validation_loss on proxy = 1.1609
validation_accuracy on proxy = 70.59%
validation_loss on true ds = 1.7509
validation_accuracy on true ds = 38.74%


In [12]:
# Testing on CIFAR10 ground truth
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
student_model.return_feature_domain = False
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        y = y.to(device=device)
    
        logits = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            if a == p:
                acc_per_class[a][0] += 1 # correct predictions
            acc_per_class[a][1] += 1     # total number 
            acc_per_class[p][2] += 1     # predictions of class
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: correct_pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 1.7509
test_accuracy = 38.74%

Class airplane: correct_pred=425, actual=1000 => acc=42.50%, total_pred=722
Class automobile: correct_pred=556, actual=1000 => acc=55.60%, total_pred=791
Class bird: correct_pred=59, actual=1000 => acc=5.90%, total_pred=120
Class cat: correct_pred=825, actual=1000 => acc=82.50%, total_pred=3951
Class deer: correct_pred=584, actual=1000 => acc=58.40%, total_pred=1511
Class dog: correct_pred=136, actual=1000 => acc=13.60%, total_pred=614
Class frog: correct_pred=191, actual=1000 => acc=19.10%, total_pred=372
Class horse: correct_pred=176, actual=1000 => acc=17.60%, total_pred=348
Class ship: correct_pred=517, actual=1000 => acc=51.70%, total_pred=924
Class truck: correct_pred=405, actual=1000 => acc=40.50%, total_pred=647


In [13]:
# Testing using labels predicted with teacher
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
teacher_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        # y = y.to(device=device)

        teacher_pred = softmax(teacher_model(x), dim=1)
        _, y = torch.max(teacher_pred, dim=1)
    
        logits = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            if a == p:
                acc_per_class[a][0] += 1
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: correct_pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 1.7555
test_accuracy = 38.66%

Class airplane: correct_pred=441, actual=1038 => acc=42.49%, total_pred=722
Class automobile: correct_pred=552, actual=1002 => acc=55.09%, total_pred=791
Class bird: correct_pred=57, actual=966 => acc=5.90%, total_pred=120
Class cat: correct_pred=793, actual=938 => acc=84.54%, total_pred=3951
Class deer: correct_pred=605, actual=1109 => acc=54.55%, total_pred=1511
Class dog: correct_pred=134, actual=913 => acc=14.68%, total_pred=614
Class frog: correct_pred=187, actual=995 => acc=18.79%, total_pred=372
Class horse: correct_pred=180, actual=1063 => acc=16.93%, total_pred=348
Class ship: correct_pred=497, actual=939 => acc=52.93%, total_pred=924
Class truck: correct_pred=420, actual=1037 => acc=40.50%, total_pred=647


In [14]:
# Redefine training dataset and dataloader with no augmentation
proxy_train_dataset = ProxyDataset(train_images, train_labels, valid_transforms, False, train_soft_labels)
train_dataloader = DataLoader(proxy_train_dataset, batch_size=BATCH_SIZE, shuffle=True)


student_model.return_feature_domain = True

if os.path.exists('images_db'):
  rmtree('images_db')
db_path = 'images_db'
images_path = os.path.join(db_path,'images')
labels_path = os.path.join(db_path,'labels')

images_db = []
labels_db = []

if not os.path.exists('images_db'):
  os.makedirs(db_path)
  os.makedirs(images_path)
  os.makedirs(labels_path)

  for i in range(10):
    os.makedirs(os.path.join(images_path,f'class{i}'))
    os.makedirs(os.path.join(labels_path,f'class{i}'))

# Create the database
student_model.eval()
with torch.no_grad():
  # for dataloader in [train_dataloader, valid_dataloader, test_dataloader]:
  for dataloader in [train_dataloader]:
    for x,y,soft_y in dataloader:
      x = x.to(device=device)

      _,latent_fm = student_model(x)

      for i in range(latent_fm.shape[0]):
        # label = y[i].item()
        # path_image = f'{images_path}/class{label}'
        # path_label = f'{labels_path}/class{label}'
        # idx = len(os.listdir(path_image))
        # torch.save(latent_fm[i].cpu(), f'{path_image}/t{idx}.pt')
        # torch.save(soft_y[i].cpu(), f'{path_label}/t{idx}.pt')

        images_db.append(latent_fm[i])
        labels_db.append(soft_y[i])

In [15]:
from collections import Counter

sorted(Counter([vlad.argmax().item() for vlad in labels_db]).items())

[(0, 102),
 (1, 82),
 (2, 77),
 (3, 86),
 (4, 89),
 (5, 69),
 (6, 73),
 (7, 80),
 (8, 82),
 (9, 79)]

In [16]:
class DBDataset(Dataset):
  def __init__(self, images_db, labels_db):
    self.images_db = images_db
    self.labels_db = labels_db

  def __len__(self):
    return len(self.images_db)

  def __getitem__(self, idx):
    image = self.images_db[idx]
    label = self.labels_db[idx]

    return image, label

db_dataset = DBDataset(images_db=images_db, labels_db=labels_db)
db_dataloader = DataLoader(db_dataset, batch_size=128, shuffle=False)

In [17]:
proxy_unused_dataset  = ProxyDataset(filtered_images_unused,  filtered_labels_unused,  valid_transforms, True, filtered_soft_labels_unused)

In [None]:
estimations = []
ground_truth = []

student_model.eval()
with torch.no_grad():
  for i in trange(3,len(proxy_unused_dataset)):
    image, label, image_path, soft_label = proxy_unused_dataset[i]
    image = image.to(device=device)

    _,latent_fm = student_model(image.unsqueeze(dim=0))

    norm_unkwn = torch.sqrt(torch.sum(torch.square(latent_fm), dim=(-1,-2,-3), keepdim=True))

    # Calculate distances
    distances = torch.zeros(size=(len(db_dataset),), device=device)
    for j, (img_db, softl_db) in enumerate(db_dataloader):
      img_db = img_db.to(device=device)

      norm_db = torch.sqrt(torch.sum(torch.square(img_db), dim=(-1,-2,-3), keepdim=True))

      dot_prod = torch.tensordot(img_db, latent_fm, dims=[[-1,-2,-3],[-1,-2,-3]])
      denominator = (norm_db * norm_unkwn).squeeze(dim=-1).squeeze(dim=-1)
      cosine_similarities = (dot_prod / denominator).squeeze(dim=-1)
      
      distances[j*128:(j+1)*128] = cosine_similarities
      # distances.extend(cosine_similarities.cpu().tolist())
      
    smallest_d, smallest_d_indices = torch.topk(distances, k=3, largest=True)
    closest_slabels = torch.vstack([filtered_soft_labels_unused[l] for l in smallest_d_indices])
    
    estimated_soft_label = closest_slabels.mean(0)
    estimated_hard_label = estimated_soft_label.argmax().item()
    
    # print(f'Ground truth soft_label = {soft_label}')
    # print(f'Estimated soft_label = {estimated_soft_label}')
    # print(f'Ground truth hard_label = {label}')
    # print(f'Estimated hard_label = {estimated_hard_label}')
    
    # print(f'Smallest 3 distances: {smallest_d}')
    # print(f'Closest 3 slabels = {closest_slabels}')
    
    train_images.append(image_path)
    train_labels.append(estimated_hard_label)
    train_soft_labels.append(estimated_soft_label)

    estimations.append(estimated_hard_label)
    ground_truth.append(label)

    if i % 1000 == 0 and i != 0:
      # Display results
      correct = (np.array(estimations) == np.array(ground_truth)).sum()
      print(f'{correct} / {len(estimations)}')

correct = (np.array(estimations) == np.array(ground_truth)).sum()
print(f'Final score: {correct} / {len(estimations)}')

In [78]:
correct = (np.array(estimations) == np.array(ground_truth)).sum()
print(f'Final score: {correct} / {len(estimations)}')

Final score: 305 / 2243


In [20]:
student_model.return_feature_domain = False

proxy_train_dataset = ProxyDataset(train_images, train_labels, train_transforms, False, train_soft_labels)
train_dataloader = DataLoader(proxy_train_dataset, batch_size=BATCH_SIZE, shuffle=True)

early_stopping = EarlyStopping(tolerance=20, min_delta=0.001)
# early_stopping = EarlyStopping(tolerance=5, min_delta=0.001)

# Training the student
for epoch in range(EPOCHS):
    # Define progress bar
    loop = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    
    # Training loop
    student_model.train()
    training_loss_epoch = []
    for batch_idx, (x,y,soft_y) in loop:
        optimizer.zero_grad()
        
        x = x.to(device=device)
        y = y.to(device=device)
        soft_y = soft_y.to(device=device)
        
        # Forward pass
        logits = student_model(x)
        # Backward pass
        loss = criterion(input=logits, target=y)
        # loss = criterion(input=logits, target=soft_y)
        training_loss_epoch.append(loss.item())
        # Prob trb inlocuit criterionul, sa adaugi soft-labels
        loss.backward()
        
        # Optimize
        optimizer.step()
        
        # Update progress bar
        loop.set_description(f'Epoch {epoch+1}/{EPOCHS}')
        loop.set_postfix(training_loss=loss.item())
    
    # Validation loop on proxy validation dataset
    student_model.eval()
    validation_loss_epoch = []  
    acc = 0
    with torch.no_grad():
        for x,y,_ in valid_dataloader:
            x = x.to(device=device)
            y = y.to(device=device)
        
            logits = student_model(x)
            pred = softmax(logits, dim=1)
            
            confidence,y_hat = torch.max(pred, dim=1)
            
            loss = criterion(input=logits, target=y)
            validation_loss_epoch.append(loss.item())
            
            acc += torch.sum(y_hat==y).item()
        
    loop.write(f'validation_loss on proxy = {sum(validation_loss_epoch)/len(validation_loss_epoch):.4f}')
    loop.write(f'validation_accuracy on proxy = {100*acc/len(proxy_valid_dataset):.2f}%')

    if validate_on_trueds:
        # Validation loop on proxy validation dataset
        student_model.eval()
        with torch.no_grad():
            val_loss = []
            acc = 0
            for x,y in true_valid_ds:
                x = x.to(device=device)
                y = y.to(device=device)
            
                logits = student_model(x)
                pred = softmax(logits, dim=1)
                
                confidence,y_hat = torch.max(pred, dim=1)
                
                loss = criterion(input=logits, target=y)
                val_loss.append(loss.item())
                
                acc += torch.sum(y_hat==y).item()
            
        loop.write(f'validation_loss on true ds = {sum(val_loss)/len(val_loss):.4f}')
        loop.write(f'validation_accuracy on true ds = {100*acc/len(true_dataset.test_dataset):.2f}%')

    early_stopping(sum(validation_loss_epoch)/len(validation_loss_epoch))
    if early_stopping.early_stop:
      print(f"We are at epoch {epoch}")
      break

  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.4980
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.6012
validation_accuracy on true ds = 51.77%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.6084
validation_accuracy on proxy = 82.35%
validation_loss on true ds = 1.6800
validation_accuracy on true ds = 50.04%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.4988
validation_accuracy on proxy = 82.35%
validation_loss on true ds = 1.7245
validation_accuracy on true ds = 49.85%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5373
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.6481
validation_accuracy on true ds = 51.74%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5300
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.6821
validation_accuracy on true ds = 49.92%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.4932
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.6166
validation_accuracy on true ds = 52.81%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5126
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.6941
validation_accuracy on true ds = 50.08%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5213
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.7141
validation_accuracy on true ds = 50.92%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.4800
validation_accuracy on proxy = 87.25%
validation_loss on true ds = 1.7531
validation_accuracy on true ds = 50.08%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5753
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.7276
validation_accuracy on true ds = 50.83%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5849
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.7437
validation_accuracy on true ds = 52.21%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5334
validation_accuracy on proxy = 82.35%
validation_loss on true ds = 1.8100
validation_accuracy on true ds = 49.67%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5597
validation_accuracy on proxy = 82.35%
validation_loss on true ds = 1.7276
validation_accuracy on true ds = 51.47%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5665
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.8617
validation_accuracy on true ds = 49.44%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5507
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.8770
validation_accuracy on true ds = 50.38%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5355
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.7022
validation_accuracy on true ds = 52.28%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5567
validation_accuracy on proxy = 85.29%
validation_loss on true ds = 1.8041
validation_accuracy on true ds = 51.67%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5482
validation_accuracy on proxy = 86.27%
validation_loss on true ds = 1.7243
validation_accuracy on true ds = 51.58%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5064
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.9045
validation_accuracy on true ds = 50.65%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.6494
validation_accuracy on proxy = 84.31%
validation_loss on true ds = 1.8053
validation_accuracy on true ds = 52.05%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5335
validation_accuracy on proxy = 86.27%
validation_loss on true ds = 1.9247
validation_accuracy on true ds = 50.57%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5205
validation_accuracy on proxy = 85.29%
validation_loss on true ds = 1.8493
validation_accuracy on true ds = 51.04%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5281
validation_accuracy on proxy = 83.33%
validation_loss on true ds = 1.7492
validation_accuracy on true ds = 51.95%


  0%|          | 0/1594 [00:00<?, ?it/s]

validation_loss on proxy = 0.5246
validation_accuracy on proxy = 86.27%
validation_loss on true ds = 1.9533
validation_accuracy on true ds = 50.27%
We are at epoch 23


In [21]:
# Testing on CIFAR10 ground truth
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        y = y.to(device=device)
    
        logits = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
            if a == p:
                acc_per_class[a][0] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 1.9533
test_accuracy = 50.27%

Class airplane: pred=538, actual=1000 => acc=53.80%, total_pred=960
Class automobile: pred=518, actual=1000 => acc=51.80%, total_pred=647
Class bird: pred=557, actual=1000 => acc=55.70%, total_pred=1569
Class cat: pred=491, actual=1000 => acc=49.10%, total_pred=1360
Class deer: pred=547, actual=1000 => acc=54.70%, total_pred=1207
Class dog: pred=403, actual=1000 => acc=40.30%, total_pred=1131
Class frog: pred=342, actual=1000 => acc=34.20%, total_pred=463
Class horse: pred=454, actual=1000 => acc=45.40%, total_pred=961
Class ship: pred=455, actual=1000 => acc=45.50%, total_pred=581
Class truck: pred=722, actual=1000 => acc=72.20%, total_pred=1121


In [22]:
# Testing using labels predicted with teacher
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
teacher_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        # y = y.to(device=device)

        teacher_pred = softmax(teacher_model(x), dim=1)
        _, y = torch.max(teacher_pred, dim=1)
    
        logits = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
            if a == p:
                acc_per_class[a][0] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 1.9204
test_accuracy = 50.62%

Class airplane: pred=543, actual=1038 => acc=52.31%, total_pred=960
Class automobile: pred=519, actual=1002 => acc=51.80%, total_pred=647
Class bird: pred=544, actual=965 => acc=56.37%, total_pred=1569
Class cat: pred=485, actual=938 => acc=51.71%, total_pred=1360
Class deer: pred=578, actual=1108 => acc=52.17%, total_pred=1207
Class dog: pred=402, actual=913 => acc=44.03%, total_pred=1131
Class frog: pred=337, actual=996 => acc=33.84%, total_pred=463
Class horse: pred=465, actual=1063 => acc=43.74%, total_pred=961
Class ship: pred=445, actual=939 => acc=47.39%, total_pred=581
Class truck: pred=744, actual=1038 => acc=71.68%, total_pred=1121


In [23]:
# import matplotlib.pyplot as plt


# def plot_db(X_embedded, labels_cpu):
#   # scale and move the coordinates so they fit [0; 1] range
#   def scale_to_01_range(x):
#       # compute the distribution range
#       value_range = (np.max(x) - np.min(x))
  
#       # move the distribution so that it starts from zero
#       # by extracting the minimal value from all its values
#       starts_from_zero = x - np.min(x)
  
#       # make the distribution fit [0; 1] by dividing by its range
#       return starts_from_zero / value_range
  
#   # extract x and y coordinates representing the positions of the images on T-SNE plot
#   tx = X_embedded[:, 0]
#   ty = X_embedded[:, 1]
  
#   tx = scale_to_01_range(tx)
#   ty = scale_to_01_range(ty)

#   # initialize a matplotlib plot
#   fig = plt.figure(figsize=(6, 6), dpi=144)
#   ax = fig.add_subplot(111)
  
#   # for every class, we'll add a scatter plot separately
#   for label in range(10):
#       # find the samples of the current class in the data
#       indices = [i for i, l in enumerate(labels_cpu) if l == label]
  
#       # extract the coordinates of the points of this class only
#       current_tx = np.take(tx, indices)
#       current_ty = np.take(ty, indices)
  
#       # add a scatter plot with the corresponding color and label
#       ax.scatter(current_tx, current_ty, label=label)
  
#   # build a legend using the labels we set previously
#   ax.legend(loc='best')
  
#   # finally, show the plot
#   plt.show()

In [24]:
# from sklearn.manifold import TSNE


# tsne = TSNE()
# images_cpu = [im.cpu().numpy().flatten() for im in images_db]
# labels_cpu = [l.cpu().numpy().argmax() for l in labels_db]
# X_embedded = tsne.fit_transform(np.array(images_cpu), np.array(labels_cpu))

# plot_db(X_embedded, labels_cpu)