In [None]:
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import time
%matplotlib inline


In [None]:
# Fetch CIFAR-10 into ./data (downloading if needed) and define the preprocessing pipelines.

# (mean, std) triples, one value per colour channel, handed to tt.Normalize.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: random 32x32 crop from a reflect-padded image, random
# horizontal flip, then tensor conversion and normalization.
train_tfms = tt.Compose([
    tt.RandomCrop(32, padding=4, padding_mode='reflect'),
    tt.RandomHorizontalFlip(),
    tt.ToTensor(),
    tt.Normalize(mean=stats[0], std=stats[1], inplace=True),
])

# Evaluation pipeline: tensor conversion and normalization only (no augmentation).
valid_tfms = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(mean=stats[0], std=stats[1], inplace=True),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=train_tfms,
    download=True,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=valid_tfms,
    download=True,
)


In [None]:
# Wrap the datasets in data loaders.

batch_size = 200

# Training batches are reshuffled on every pass over the data.
train_dl = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

# Evaluation uses batches twice as large and keeps the dataset order.
valid_dl = DataLoader(
    testset,
    batch_size=batch_size*2,
    shuffle=False,
    num_workers=3,
    pin_memory=True,
)

In [None]:
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def conv_block(in_channels, out_channels, pool=False):
    """Build a 3x3 convolution -> batch norm -> ReLU stack, optionally max-pooled.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        pool: when truthy, a ``MaxPool2d(2)`` is appended as the last layer.

    Returns:
        An ``nn.Sequential`` holding the layers in the order listed above.
    """
    # The convolution carries no bias term (bias=False).
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    downsample = (nn.MaxPool2d(2),) if pool else ()
    return nn.Sequential(conv, norm, activation, *downsample)

#Adds a randomly sampled noise from a gaussian distribution to a standardized input tensor
class AddNoise(nn.Module):
    """Additive Gaussian noise layer: ``forward(x)`` returns ``x + (randn * std + mean)``.

    A fresh noise tensor with the shape of ``x`` is sampled on every forward
    call. ``forward`` does not look at ``self.training``, so the noise is
    injected in both training and evaluation mode.

    Args:
        mean: offset added to the sampled noise (default 0).
        std: scale multiplied into the sampled noise (default 1).
    """

    def __init__(self, mean=0, std=1):
        super().__init__()
        # Stored as plain attributes (not parameters/buffers), so nn.Module.to()
        # does not move them; see to_device().
        self.mean = mean
        self.std = std

    def to_device(self, device):
        """Convert ``mean`` and ``std`` to tensors living on ``device``.

        ``torch.as_tensor`` is used so the method can be called repeatedly
        without re-wrapping values that are already tensors.
        """
        self.mean = torch.as_tensor(self.mean, device=device)
        self.std = torch.as_tensor(self.std, device=device)

    def forward(self, x):
        """Return ``x`` plus freshly sampled Gaussian noise of the same shape."""
        # Sample directly on x's device (and in x's dtype) instead of relying on
        # a module-level `device` variable, so the layer works wherever the
        # input lives and does not depend on notebook global state.
        noise = torch.randn_like(x) * self.std + self.mean
        return x + noise

#Resnet 9 architecture for teacher without noise
class Net1(nn.Module):
    """ResNet-9 style classifier used as the teacher (no noise injected).

    After every conv/batch-norm/ReLU stage the activations are standardized
    with the mean and standard deviation taken over the whole tensor.

    The classifier ends in ``Linear(512, num_classes)`` after three 2x pools
    and a final 4x pool, which corresponds to 32x32 inputs.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logit vector.
    """

    def __init__(self,in_channels,num_classes):
        super().__init__()

        def conv3x3(c_in, c_out, bias=False):
            # All convolutions are 3x3 with padding 1 (spatial size preserved).
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=bias)

        # Layers are created in a fixed order with fixed attribute names so
        # that state-dict keys stay the same.
        self.conv1 = conv3x3(in_channels, 64)
        self.batchnorm1 = nn.BatchNorm2d(64)
        # Registered but never called in forward() of this class.
        self.add_noise = AddNoise()
        self.conv2 = conv3x3(64, 128)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2)

        # `res1` is registered but forward() uses the explicit res1conv*/
        # res1batchnorm* layers below instead.
        # NOTE(review): presumably kept so checkpoint keys match - confirm.
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        self.res1conv1 = conv3x3(128, 128)
        self.res1batchnorm1 = nn.BatchNorm2d(128)
        self.res1conv2 = conv3x3(128, 128)
        self.res1batchnorm2 = nn.BatchNorm2d(128)

        # conv3 is the only convolution created with a bias term.
        self.conv3 = conv3x3(128, 256, bias=True)
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.conv4 = conv3x3(256, 512)
        self.batchnorm4 = nn.BatchNorm2d(512)

        # Same remark as for `res1`: registered, not used in forward().
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.res2conv1 = conv3x3(512, 512)
        self.res2batchnorm1 = nn.BatchNorm2d(512)
        self.res2conv2 = conv3x3(512, 512)
        self.res2batchnorm2 = nn.BatchNorm2d(512)

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _standardize(x):
        """Shift and scale ``x`` by the mean/std computed over all its elements."""
        return (x - torch.mean(x)) / (torch.std(x))

    def forward(self, x):
        """Compute class logits for a batch of images."""
        standardize = self._standardize

        # Stem: two conv stages, the second one followed by a 2x max-pool.
        x = standardize(F.relu(self.batchnorm1(self.conv1(x))))
        x = standardize(self.pool(F.relu(self.batchnorm2(self.conv2(x)))))

        # First residual block (identity skip connection).
        skip = x
        x = standardize(F.relu(self.res1batchnorm1(self.res1conv1(x))))
        x = standardize(F.relu(self.res1batchnorm2(self.res1conv2(x))))
        x = x + skip

        # Two widening conv stages, each followed by a 2x max-pool.
        x = standardize(self.pool(F.relu(self.batchnorm3(self.conv3(x)))))
        x = standardize(self.pool(F.relu(self.batchnorm4(self.conv4(x)))))

        # Second residual block (identity skip connection).
        skip = x
        x = standardize(F.relu(self.res2batchnorm1(self.res2conv1(x))))
        x = standardize(F.relu(self.res2batchnorm2(self.res2conv2(x))))
        x = x + skip

        return self.classifier(x)
        
#Resnet 9 architecture for student with noise
class Net2(nn.Module):
    """ResNet-9 style classifier used as the student (noise injected).

    After every conv/batch-norm/ReLU stage the activations are standardized
    with the mean and standard deviation taken over the whole tensor and then
    passed through ``self.add_noise``.

    The classifier ends in ``Linear(512, num_classes)`` after three 2x pools
    and a final 4x pool, which corresponds to 32x32 inputs.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logit vector.
    """

    def __init__(self,in_channels,num_classes):
        super().__init__()

        def conv3x3(c_in, c_out, bias=False):
            # All convolutions are 3x3 with padding 1 (spatial size preserved).
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=bias)

        # Layers are created in a fixed order with fixed attribute names so
        # that state-dict keys stay the same.
        self.conv1 = conv3x3(in_channels, 64)
        self.batchnorm1 = nn.BatchNorm2d(64)
        # One noise layer instance, reused after every stage in forward().
        self.add_noise = AddNoise()
        self.conv2 = conv3x3(64, 128)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2)

        # `res1` is registered but forward() uses the explicit res1conv*/
        # res1batchnorm* layers below instead.
        # NOTE(review): presumably kept so checkpoint keys match - confirm.
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        self.res1conv1 = conv3x3(128, 128)
        self.res1batchnorm1 = nn.BatchNorm2d(128)
        self.res1conv2 = conv3x3(128, 128)
        self.res1batchnorm2 = nn.BatchNorm2d(128)

        # conv3 is the only convolution created with a bias term.
        self.conv3 = conv3x3(128, 256, bias=True)
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.conv4 = conv3x3(256, 512)
        self.batchnorm4 = nn.BatchNorm2d(512)

        # Same remark as for `res1`: registered, not used in forward().
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.res2conv1 = conv3x3(512, 512)
        self.res2batchnorm1 = nn.BatchNorm2d(512)
        self.res2conv2 = conv3x3(512, 512)
        self.res2batchnorm2 = nn.BatchNorm2d(512)

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def _standardize_and_perturb(self, x):
        """Standardize ``x`` over all its elements, then add noise to it."""
        x = (x - torch.mean(x)) / (torch.std(x))
        return self.add_noise(x)

    def forward(self, x):
        """Compute class logits for a batch of images, injecting noise after each stage."""
        perturb = self._standardize_and_perturb

        # Stem: two conv stages, the second one followed by a 2x max-pool.
        x = perturb(F.relu(self.batchnorm1(self.conv1(x))))
        x = perturb(self.pool(F.relu(self.batchnorm2(self.conv2(x)))))

        # First residual block; the skip branch carries the already-noised input.
        skip = x
        x = perturb(F.relu(self.res1batchnorm1(self.res1conv1(x))))
        x = perturb(F.relu(self.res1batchnorm2(self.res1conv2(x))))
        x = x + skip

        # Two widening conv stages, each followed by a 2x max-pool.
        x = perturb(self.pool(F.relu(self.batchnorm3(self.conv3(x)))))
        x = perturb(self.pool(F.relu(self.batchnorm4(self.conv4(x)))))

        # Second residual block; the skip branch carries the already-noised input.
        skip = x
        x = perturb(F.relu(self.res2batchnorm1(self.res2conv1(x))))
        x = perturb(F.relu(self.res2batchnorm2(self.res2conv2(x))))
        x = x + skip

        return self.classifier(x)

# Instantiate teacher (net1) and student (net2) for 3-channel inputs and
# 10 output classes, and place each on the selected device.
net1 = Net1(3, 10).to(device)
net2 = Net2(3, 10).to(device)

In [None]:
#load the pretrained weights for clean baseline network
# Teacher and student both start from the same checkpoint, so the file is read
# once and the resulting state dict is copied into each network.
# map_location remaps the stored tensors onto the active device, which lets a
# checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
pretrained_state = torch.load('teacher_30epochs_pre.pth', map_location=device)
net1.load_state_dict(pretrained_state)
net2.load_state_dict(pretrained_state)

In [None]:
#make the weights fixed for the teacher network
# Switch the teacher to evaluation mode as well: otherwise its BatchNorm layers
# keep normalizing with per-batch statistics and keep updating their running
# mean/variance on every forward pass, so the teacher would not really be fixed.
net1.eval()
# Exclude every teacher parameter from gradient computation.
for param in net1.parameters():
    param.requires_grad = False

In [None]:
# Training hyperparameters.
epochs = 30            # number of passes over the training set
max_lr = 5e-3          # learning rate given to Adam and used as the OneCycleLR peak
grad_clip = 0.1        # bound passed to clip_grad_value_ in the training loop
weight_decay = 0.0001  # weight decay passed to Adam

In [None]:
import torch.optim as optim


# Two cross-entropy criteria: the training loop uses criterion1 against the
# ground-truth labels and criterion2 against the teacher's output distribution.
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss()

# Only the student (net2) is optimized.
optimizer2 = optim.Adam(
    net2.parameters(),
    lr=max_lr,
    weight_decay=weight_decay,
)

# One-cycle learning-rate schedule spanning the whole run; it is stepped once
# per mini-batch, hence steps_per_epoch = number of batches in train_dl.
scheduler2 = optim.lr_scheduler.OneCycleLR(
    optimizer2,
    max_lr=max_lr,
    epochs=epochs,
    steps_per_epoch=len(train_dl),
)

In [None]:
#Function to get learning rate during training
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Args:
        optimizer: object exposing an iterable ``param_groups`` of dict-like groups.

    Returns:
        The learning rate of the first group, or ``None`` when
        ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
    
# Knowledge-distillation training of the student (net2) against the teacher (net1).
T = 6 #temperature of distillation
training_time = 0
alpha = 1 #balance between hard and soft targets

for epoch in range(epochs):  # loop over the dataset multiple times

    # Make sure the student is in training mode, even if an earlier cell
    # switched it to evaluation mode.
    net2.train()

    start_time = time.perf_counter()

    running_loss = 0.0
    lrs = []
    correct_predictions = 0
    total_predictions = 0
    for i, data in enumerate(train_dl, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer2.zero_grad()

        # Teacher forward pass: temperature-softened class probabilities used
        # as soft targets. No gradients are needed for the teacher.
        with torch.no_grad():
            outputs1 = net1(inputs)
            outputs1_prob = F.softmax(outputs1/T, dim=1)

        # Student forward pass (raw logits).
        outputs2 = net2(inputs)

        # nn.CrossEntropyLoss applies log-softmax to its input itself, so it
        # must be given logits. Feeding it softmax probabilities would apply
        # softmax twice and flatten the gradients.
        loss_student_hard = criterion1(outputs2, labels)
        # Soft-target loss: temperature-scaled student logits vs. teacher probabilities.
        loss_student_soft = criterion2(outputs2/T, outputs1_prob)
        # The soft term is scaled by T**2 to keep its gradient magnitude
        # comparable to the hard term's.
        loss = loss_student_hard+alpha*(T**2)*(loss_student_soft)
        loss.backward()
        nn.utils.clip_grad_value_(net2.parameters(), grad_clip)
        optimizer2.step()
        lrs.append(get_lr(optimizer2))
        scheduler2.step()

        # Calculate training accuracy (cumulative over the current epoch)
        _, predicted = torch.max(outputs2, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)
        training_acc = (correct_predictions/total_predictions)*100

        # print statistics
        running_loss += loss.item()
        if i % 50 == 49:    # print every 50 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 50:.3f} learning rate: {lrs[i]} training accuracy: {training_acc:.3f}')
            running_loss = 0.0

    training_time += time.perf_counter()-start_time

print(f'Finished Training, Training time: {training_time}')


In [None]:
# Measure the student's top-1 accuracy on the held-out test set.
correct = 0
total = 0

# Evaluate in eval mode so BatchNorm uses its running statistics rather than
# per-batch statistics, and so those running statistics are not updated with
# test data. The previous mode is restored afterwards, leaving the model as
# this cell found it.
was_training = net2.training
net2.eval()
try:
    with torch.no_grad():
        for data in valid_dl:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net2(images)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    net2.train(was_training)

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

