In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import torch
import torchvision.transforms as T
from shutil import rmtree
from sklearn.model_selection import train_test_split
from tqdm.autonotebook import tqdm, trange
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.datasets import ImageFolder

from predictors.alexnet import Alexnet
from predictors.half_alexnet_distil import HalfAlexnetDistil

from datasets import CIFAR10, ProxyDataset

%matplotlib inline

In [3]:
LR = 0.001
EPOCHS = 200
BATCH_SIZE = 32
CONFIDENCE_TH = 0.8

DATASET_PATH = '.'
TEMP_BEST_MODEL = 'distil_experiments'

# Set random seed for replicating testing results
RANDOM_SEED = 0
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
# Teacher
teacher_model = Alexnet(name=None, n_outputs=10)

ckpt_path = 'checkpoints/teacher_alexnet_for_cifar10_state_dict'
teacher_model.load_state_dict(torch.load(ckpt_path, map_location=device))
teacher_model.eval()
teacher_model = teacher_model.to(device)

In [5]:
# Define dataset
label_mapper = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
label_mapper_inv = {v:k for k,v in label_mapper.items()}

# Get images paths and labels
images = []
labels = []

folders_path = os.path.join(DATASET_PATH,'images_generated')
for folder in os.listdir(folders_path):
    if 'imagenet' in folder:
      continue
    class_path = os.path.join(folders_path,folder)
    images_names = os.listdir(class_path)
    
    for image_name in images_names:
        images.append(os.path.join(class_path,image_name))
        labels.append(label_mapper[folder])

In [6]:
proxy_transforms = T.Compose([
    T.Resize((32,32)),
    T.Normalize((0.5,), (0.5,))
])

proxy_dataset = ProxyDataset(images, labels, proxy_transforms, True)
proxy_dataloader  = DataLoader(proxy_dataset,  batch_size=BATCH_SIZE)

# Obtain teacher predictions
filtered_images = []
filtered_labels = []
filtered_soft_labels = []

with torch.no_grad():
    for img,label,path,_ in tqdm(proxy_dataloader):
        img = img.to(device=device)
        label = label.to(device=device)

        logits = teacher_model(img)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)

        filtered_images.extend(list(path))
        filtered_labels.extend(y_hat.tolist())
        filtered_soft_labels.extend(pred)
        
# Display results of filtering
print(f'A total of {len(filtered_images)} remained out of {len(proxy_dataset)}')
print()
counter_per_class = {v:0 for k,v in label_mapper.items()}
for label in filtered_labels:
    counter_per_class[label] += 1
for clasa in counter_per_class:
    print(f'Class {clasa}({label_mapper_inv[clasa]}) has {counter_per_class[clasa]} entries')

  0%|          | 0/1600 [00:00<?, ?it/s]

A total of 51200 remained out of 51200

Class 0(airplane) has 6384 entries
Class 1(automobile) has 5145 entries
Class 2(bird) has 4796 entries
Class 3(cat) has 5329 entries
Class 4(deer) has 5571 entries
Class 5(dog) has 4287 entries
Class 6(frog) has 4553 entries
Class 7(horse) has 5025 entries
Class 8(ship) has 5178 entries
Class 9(truck) has 4932 entries


In [8]:
# Define the student model
student_model = HalfAlexnetDistil(name=None, n_outputs=10)

path_to_save = 'pretrained_student_distil.pt'
if torch.cuda.is_available():
    student_model.load_state_dict(torch.load(path_to_save))
else:
    student_model.load_state_dict(torch.load(path_to_save, map_location ='cpu'))
student_model.to(device)

# Define optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=LR)

# Define loss function
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [9]:
try:
  # raise
  # Do a stratified split of the data
  NO_IMGS_TO_USE = 60
  
  filtered_images_subset, filtered_images_unused, filtered_labels_subset, filtered_labels_unused, filtered_soft_labels_subset, filtered_soft_labels_unused = \
      train_test_split(filtered_images, filtered_labels, filtered_soft_labels, train_size=NO_IMGS_TO_USE, stratify=filtered_labels, random_state=RANDOM_SEED)  

  train_images, validation_images, train_labels, validation_labels, train_soft_labels, validation_soft_labels = \
      train_test_split(filtered_images_subset, filtered_labels_subset, filtered_soft_labels_subset, train_size=0.8, stratify=filtered_labels_subset, random_state=RANDOM_SEED)
  valid_images, test_images, valid_labels, test_labels, valid_soft_labels, test_soft_labels = \
      train_test_split(validation_images, validation_labels, validation_soft_labels, test_size=0.5, stratify=validation_labels, random_state=RANDOM_SEED)
except:
  # For small datasets 
  print('Small dataset')
  NO_IMGS_TO_USE = 30
  NO_TRAIN_IMGS = 10

  """ DEV - train_test_split fails, need to manually introduce 1 sample
  """
  for i in range(len(filtered_labels)-1,-1,-1):
    if filtered_labels[i] == 5:
      img_dev = filtered_images[i]
      soft_label_dev = filtered_soft_labels[i]
      break

  filtered_images_subset, filtered_images_unused, filtered_labels_subset, filtered_labels_unused, filtered_soft_labels_subset, filtered_soft_labels_unused = \
    train_test_split(filtered_images, filtered_labels, filtered_soft_labels, train_size=NO_IMGS_TO_USE, stratify=filtered_labels, random_state=RANDOM_SEED)  
  train_images, validation_images, train_labels, validation_labels, train_soft_labels, validation_soft_labels = \
    train_test_split(filtered_images_subset, filtered_labels_subset, filtered_soft_labels_subset, train_size=NO_TRAIN_IMGS, stratify=filtered_labels_subset, random_state=RANDOM_SEED)
  
  """ DEV - interchange an image
  """
  for i in range(len(validation_images)):
    if validation_labels[i] == 0:
      validation_images[i] = img_dev
      validation_labels[i] = 5
      validation_soft_labels[i] = soft_label_dev
      break

  valid_images, test_images, valid_labels, test_labels, valid_soft_labels, test_soft_labels = \
      train_test_split(validation_images, validation_labels, validation_soft_labels, test_size=0.5, stratify=validation_labels)

Small dataset


In [10]:
# Define the transformations
train_transforms = T.Compose([
    T.Resize((32,32)),
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(p=0.5),
    T.Normalize((0.5,), (0.5,))
])
valid_transforms = T.Compose([
    T.Resize((32,32)),
    T.Normalize((0.5,), (0.5,))
])

# Define the proxy datasets
proxy_train_dataset = ProxyDataset(train_images, train_labels, train_transforms, False, train_soft_labels)
proxy_valid_dataset = ProxyDataset(valid_images, valid_labels, valid_transforms, False, valid_soft_labels)
proxy_test_dataset  = ProxyDataset(test_images,  test_labels,  valid_transforms, False, test_soft_labels)

# Define the proxy dataloaders
train_dataloader = DataLoader(proxy_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(proxy_valid_dataset, batch_size=BATCH_SIZE)
test_dataloader  = DataLoader(proxy_test_dataset,  batch_size=BATCH_SIZE)

# Define true dataset
true_dataset = CIFAR10(input_size = 32)

validate_on_trueds = True
if validate_on_trueds:
    true_valid_ds = true_dataset.test_dataloader()

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Class for early stopping
class EarlyStopping():
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.min_valid_loss = np.inf
        self.early_stop = False

    def __call__(self, validation_loss):
        if validation_loss < self.min_valid_loss:
          self.min_valid_loss = validation_loss
          self.counter = 0
        elif validation_loss > self.min_valid_loss + self.min_delta:
          self.counter +=1
          
          if self.counter >= self.tolerance:  
              self.early_stop = True

In [14]:
early_stopping = EarlyStopping(tolerance=5, min_delta=0.05)
# early_stopping = EarlyStopping(tolerance=5, min_delta=0.001)

best_accuracy = 0.0
# Training the student
for epoch in range(1):
    # Define progress bar
    loop = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    
    # Training loop
    student_model.train()
    training_loss_epoch = []
    for batch_idx, (x,y,soft_y) in loop:
        optimizer.zero_grad()
        
        x = x.to(device=device)
        y = y.to(device=device)
        soft_y = soft_y.to(device=device)
        
        # Forward pass
        logits_bot, logits_top = student_model(x)
        # Backward pass
        loss = criterion(input=logits_bot, target=soft_y) + criterion(input=logits_top, target=soft_y)
        training_loss_epoch.append(loss.item())
        # Prob trb inlocuit criterionul, sa adaugi soft-labels
        loss.backward()
        
        # Optimize
        optimizer.step()
        
        # Update progress bar
        loop.set_description(f'Epoch {epoch+1}/{EPOCHS}')
        loop.set_postfix(training_loss=loss.item())
    
    # Validation loop on proxy validation dataset
    student_model.eval()
    validation_loss_epoch = []  
    acc_bot = 0
    acc_top = 0
    with torch.no_grad():
        for x,y,_ in valid_dataloader:
            x = x.to(device=device)
            y = y.to(device=device)
        
            logits_bot, logits_top = student_model(x)
            pred_bot = softmax(logits_bot, dim=1)
            pred_top = softmax(logits_top, dim=1)
            
            confidence_bot,y_hat_bot = torch.max(pred_bot, dim=1)
            confidence_top,y_hat_top = torch.max(pred_top, dim=1)
            
            loss = criterion(input=logits_bot, target=y) # + criterion(input=logits_top, target=y)
            validation_loss_epoch.append(loss.item())
            
            acc_bot += torch.sum(y_hat_bot==y).item()
            acc_top += torch.sum(y_hat_top==y).item()
        
    loop.write(f'validation_loss on proxy = {sum(validation_loss_epoch)/len(validation_loss_epoch):.4f}')
    loop.write(f'validation_accuracy_bot on proxy = {100*acc_bot/len(proxy_valid_dataset):.2f}%')
    loop.write(f'validation_accuracy_top on proxy = {100*acc_top/len(proxy_valid_dataset):.2f}%')

    # Save best model
    valid_proxy_acc = acc_bot / len(proxy_valid_dataset)
    if valid_proxy_acc > best_accuracy:
        best_accuracy = valid_proxy_acc
        torch.save(student_model.state_dict(), f'temp_models/{TEMP_BEST_MODEL}.pt')
        print(f'Saved at epoch {epoch}')

    if validate_on_trueds:
        # Validation loop on proxy validation dataset
        student_model.eval()
        with torch.no_grad():
            val_loss = []
            acc = 0
            for x,y in true_valid_ds:
                x = x.to(device=device)
                y = y.to(device=device)
            
                logits, _ = student_model(x)
                pred = softmax(logits, dim=1)
                
                confidence,y_hat = torch.max(pred, dim=1)
                
                loss = criterion(input=logits, target=y)
                val_loss.append(loss.item())
                
                acc += torch.sum(y_hat==y).item()
            
        loop.write(f'validation_loss on true ds = {sum(val_loss)/len(val_loss):.4f}')
        loop.write(f'validation_accuracy on true ds = {100*acc/len(true_dataset.test_dataset):.2f}%')

        
    early_stopping(sum(validation_loss_epoch)/len(validation_loss_epoch))
    if early_stopping.early_stop:
      print(f"We are at epoch {epoch}")
      break

student_model.load_state_dict(torch.load(f'temp_models/{TEMP_BEST_MODEL}.pt'))

  0%|          | 0/1 [00:00<?, ?it/s]

validation_loss on proxy = 2.2148
validation_accuracy_bot on proxy = 10.00%
validation_accuracy_top on proxy = 20.00%
Saved at epoch 0
validation_loss on true ds = 2.2973
validation_accuracy on true ds = 13.54%


<All keys matched successfully>

In [15]:
# Testing on CIFAR10 ground truth
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        y = y.to(device=device)
    
        logits, _ = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            if a == p:
                acc_per_class[a][0] += 1 # correct predictions
            acc_per_class[a][1] += 1     # total number 
            acc_per_class[p][2] += 1     # predictions of class
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: correct_pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 2.2973
test_accuracy = 13.54%

Class airplane: correct_pred=87, actual=1000 => acc=8.70%, total_pred=396
Class automobile: correct_pred=169, actual=1000 => acc=16.90%, total_pred=1049
Class bird: correct_pred=692, actual=1000 => acc=69.20%, total_pred=4483
Class cat: correct_pred=182, actual=1000 => acc=18.20%, total_pred=1945
Class deer: correct_pred=0, actual=1000 => acc=0.00%, total_pred=46
Class dog: correct_pred=4, actual=1000 => acc=0.40%, total_pred=758
Class frog: correct_pred=4, actual=1000 => acc=0.40%, total_pred=82
Class horse: correct_pred=6, actual=1000 => acc=0.60%, total_pred=111
Class ship: correct_pred=181, actual=1000 => acc=18.10%, total_pred=691
Class truck: correct_pred=29, actual=1000 => acc=2.90%, total_pred=439


In [16]:
# Testing using labels predicted with teacher
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
teacher_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        # y = y.to(device=device)

        teacher_pred = softmax(teacher_model(x), dim=1)
        _, y = torch.max(teacher_pred, dim=1)
    
        logits, _ = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            if a == p:
                acc_per_class[a][0] += 1
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: correct_pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 2.3007
test_accuracy = 13.30%

Class airplane: correct_pred=99, actual=1038 => acc=9.54%, total_pred=396
Class automobile: correct_pred=166, actual=1002 => acc=16.57%, total_pred=1049
Class bird: correct_pred=675, actual=966 => acc=69.88%, total_pred=4483
Class cat: correct_pred=168, actual=938 => acc=17.91%, total_pred=1945
Class deer: correct_pred=1, actual=1109 => acc=0.09%, total_pred=46
Class dog: correct_pred=7, actual=913 => acc=0.77%, total_pred=758
Class frog: correct_pred=4, actual=995 => acc=0.40%, total_pred=82
Class horse: correct_pred=6, actual=1063 => acc=0.56%, total_pred=111
Class ship: correct_pred=169, actual=939 => acc=18.00%, total_pred=691
Class truck: correct_pred=35, actual=1037 => acc=3.38%, total_pred=439


In [19]:
# Create new dataset
proxy_unused_dataset  = ProxyDataset(filtered_images_unused,  filtered_labels_unused,  train_transforms, False, filtered_soft_labels_unused)
train_unused_dataloader = DataLoader(proxy_unused_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Reset optimzer
optimizer = torch.optim.Adam(student_model.parameters(), lr=LR)

In [20]:
# early_stopping = EarlyStopping(tolerance=5, min_delta=0.05)

# best_accuracy = 0.0
# Training the student
for epoch in range(1):
    # Define progress bar
    loop = tqdm(enumerate(train_unused_dataloader), total=len(train_unused_dataloader))
    
    # Training loop
    student_model.train()
    training_loss_epoch = []
    for batch_idx, (x,y,soft_y) in loop:
        optimizer.zero_grad()
        
        x = x.to(device=device)
        # y = y.to(device=device)
        # soft_y = soft_y.to(device=device)
        
        # Forward pass
        logits_bot, logits_top = student_model(x)
        # Backward pass
        loss = criterion(input=logits_top, target=logits_bot)
        training_loss_epoch.append(loss.item())
        # Prob trb inlocuit criterionul, sa adaugi soft-labels
        loss.backward()
        
        # Optimize
        optimizer.step()
        
        # Update progress bar
        loop.set_description(f'Epoch {epoch+1}/{EPOCHS}')
        loop.set_postfix(training_loss=loss.item())
        
    # early_stopping(sum(validation_loss_epoch)/len(validation_loss_epoch))
    # if early_stopping.early_stop:
    #   print(f"We are at epoch {epoch}")
    #   break

# student_model.load_state_dict(torch.load(f'temp_models/{TEMP_BEST_MODEL}.pt'))

  0%|          | 0/1600 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
# Testing on CIFAR10 ground truth
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        y = y.to(device=device)
    
        logits, _ = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
            if a == p:
                acc_per_class[a][0] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 2.4607
test_accuracy = 15.43%

Class airplane: pred=21, actual=1000 => acc=2.10%, total_pred=246
Class automobile: pred=101, actual=1000 => acc=10.10%, total_pred=1183
Class bird: pred=584, actual=1000 => acc=58.40%, total_pred=3738
Class cat: pred=90, actual=1000 => acc=9.00%, total_pred=1551
Class deer: pred=0, actual=1000 => acc=0.00%, total_pred=0
Class dog: pred=6, actual=1000 => acc=0.60%, total_pred=499
Class frog: pred=0, actual=1000 => acc=0.00%, total_pred=0
Class horse: pred=0, actual=1000 => acc=0.00%, total_pred=61
Class ship: pred=729, actual=1000 => acc=72.90%, total_pred=2481
Class truck: pred=12, actual=1000 => acc=1.20%, total_pred=241


In [23]:
# Testing using labels predicted with teacher
true_dataloader = true_dataset.test_dataloader()
acc_per_class = {k:[0,0,0] for k,v in label_mapper_inv.items()}

student_model.eval()
teacher_model.eval()
with torch.no_grad():
    test_loss = []
    acc = 0
    for x,y in true_dataloader:
        x = x.to(device=device)
        # y = y.to(device=device)

        teacher_pred = softmax(teacher_model(x), dim=1)
        _, y = torch.max(teacher_pred, dim=1)
    
        logits, _ = student_model(x)
        pred = softmax(logits, dim=1)
        
        confidence,y_hat = torch.max(pred, dim=1)
        
        loss = criterion(input=logits, target=y)
        test_loss.append(loss.item())
        
        acc += torch.sum(y_hat==y)
        
        for i in range(len(y)):
            a = y[i].item()
            p = y_hat[i].item()
            
            
            acc_per_class[a][1] += 1
            acc_per_class[p][2] += 1
            if a == p:
                acc_per_class[a][0] += 1
        
    print('Student with true dataset:')
    print(f'test_loss = {sum(test_loss)/len(test_loss):.4f}')                  # 1.7439;  1.6199
    print(f'test_accuracy = {100*acc/len(true_dataset.test_dataset):.2f}%')    # 46.09%;  48.57%
    print()
    
for k,v in acc_per_class.items():
    print(f'Class {label_mapper_inv[k]}: pred={v[0]}, actual={v[1]} => acc={v[0]*100/v[1]:.2f}%, total_pred={v[2]}')

Student with true dataset:
test_loss = 2.4644
test_accuracy = 15.10%

Class airplane: pred=23, actual=1038 => acc=2.22%, total_pred=246
Class automobile: pred=102, actual=1002 => acc=10.18%, total_pred=1183
Class bird: pred=576, actual=966 => acc=59.63%, total_pred=3738
Class cat: pred=86, actual=938 => acc=9.17%, total_pred=1551
Class deer: pred=0, actual=1109 => acc=0.00%, total_pred=0
Class dog: pred=4, actual=913 => acc=0.44%, total_pred=499
Class frog: pred=0, actual=995 => acc=0.00%, total_pred=0
Class horse: pred=0, actual=1063 => acc=0.00%, total_pred=61
Class ship: pred=706, actual=939 => acc=75.19%, total_pred=2481
Class truck: pred=13, actual=1037 => acc=1.25%, total_pred=241
