In [1]:
#from google.colab import drive
#drive.mount('/content/drive')

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as  F
import torch.nn as nn

In [3]:
DEVICE='cuda'
use_gpu=True

In [4]:
def reproducibilitySeed():
    """
    Ensure reproducibility of results; Seeds to 0
    """
    torch_init_seed = 0
    torch.manual_seed(torch_init_seed)
    numpy_init_seed = 0
    np.random.seed(numpy_init_seed)
    if use_gpu:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

reproducibilitySeed()

In [5]:
#data=pd.read_csv('/content/drive/MyDrive/cs231n/mnist_train.csv')
#X=data.iloc[:,1:]
#Y=data.iloc[:,0]
#X=np.array(X)
#Y=np.array(Y)

In [6]:
#Dataset class ,not used here
class CustomDataset:
  def __init__(self,data,targets):
    self.data=data
    self.targets=targets
  def __len__(self):
    return len(self.data)
  def __getitem__(self,idx):
    current_sample=self.data[idx,:]
    current_target=self.targets[idx]
    return{
        "sample":torch.tensor(current_sample,dtype=torch.float),
        "target":torch.tensor(current_target,dtype=torch.long)
    }

**Using MNIST and performing augmentations as mentioned in paper for the TEACHER MODEL ONLY**

In [7]:
import torchvision
import torchvision.transforms as transforms

mnist_image_shape = (28, 28)
random_pad_size = 2
# Training images augmented by randomly shifting images by at max. 2 pixels in any of 4 directions
transform_train = transforms.Compose(
                [
                    transforms.RandomCrop(mnist_image_shape, random_pad_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]
            )

transform_test = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]
            )

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True, 
                                            download=True, transform=transform_train)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False, 
                                            download=True, transform=transform_test)

num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=128, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_dataset/MNIST/raw



In [8]:
x,y=train_dataset.__getitem__(50)
print(y)

6


In [9]:
#custom_dataset=CustomDataset(X,Y)
#trainloader=torch.utils.data.DataLoader(custom_dataset,batch_size=100)

In [10]:
#data_test=pd.read_csv('/content/drive/MyDrive/cs231n/mnist_test.csv',header=None)
#X_test=np.array(data_test.iloc[:,1:])
#Y_test=np.array(data_test.iloc[:,0])
#val_dataset=CustomDataset(X_test,Y_test)
#val_loader=torch.utils.data.DataLoader(val_dataset,batch_size=32)

In [11]:
#print(X_test.shape,Y_test.shape)

**MODEL class for implementing fully connected neural networks of required architecture**

In [12]:
class model(nn.Module):
    def __init__(self,in_channels,hidden_layer):
        super(model,self).__init__()
        self.in_channels=in_channels
        self.hidden_layer=hidden_layer
        self.drop1=nn.Dropout(p=0.5)
        self.drop2=nn.Dropout(p=0.2)
        self.fcs=nn.ModuleList()
        for idx,layer_size in enumerate(hidden_layer):
            self.fcs.append(nn.Linear(in_channels,layer_size))
            nn.init.normal_(self.fcs[idx].weight, mean=0.0, std=0.01)
            self.fcs.append(nn.BatchNorm1d(layer_size))
            in_channels=layer_size
        self.final=nn.Linear(in_channels,10)
        nn.init.normal_(self.final.weight, mean=0.0, std=0.01)
    def forward(self,x):
        x = x.view(-1, 28 * 28)
        x=self.drop2(x)
        for idx in range(0,len(self.fcs),2):
            x=self.fcs[idx](x)
            x=self.fcs[idx+1](x)
            x=F.relu(x)
            x=self.drop1(x)
        x=self.final(x)
        return x

**This one uses dropout**

In [13]:
class model_without_dropout(nn.Module):
    def __init__(self,in_channels,hidden_layer):
        super(model_without_dropout,self).__init__()
        self.in_channels=in_channels
        self.hidden_layer=hidden_layer
        self.drop1=nn.Dropout(p=0.5)
        self.drop2=nn.Dropout(p=0.2)
        self.fcs=nn.ModuleList()
        for idx,layer_size in enumerate(hidden_layer):
            self.fcs.append(nn.Linear(in_channels,layer_size))
            nn.init.normal_(self.fcs[idx].weight, mean=0.0, std=0.01)
            self.fcs.append(nn.BatchNorm1d(layer_size))
            in_channels=layer_size
        self.final=nn.Linear(in_channels,10)
        nn.init.normal_(self.final.weight, mean=0.0, std=0.01)
    def forward(self,x):
        #x=self.drop2(x)
        x = x.view(-1, 28 * 28)
        for idx in range(0,len(self.fcs),2):
            x=self.fcs[idx](x)
            x=self.fcs[idx+1](x)
            x=F.relu(x)
            #x=self.drop1(x)
        x=self.final(x)
        return x

**Checker function to check accuracy on test set**

In [14]:
def check(loader,model,device="cuda"):
  model.eval()
  num_correct = 0
  num_samples = 0
  with torch.no_grad():
      for t,data in enumerate(loader,0):
        x,y=data
        x=x.to(device)
        y=y.to(device)
        scores = model(x)
        _, preds = scores.max(1)
        num_correct += (preds == y).sum()
        num_samples += preds.size(0)
      acc = float(num_correct) / num_samples
      print('Got %d wrong tests'%(num_samples-num_correct))
      print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))


**Train loop for baseline teacher and student model**

In [15]:
def train_loop(model,optimizer,epochs,scheduler):
  for epoch in range(epochs):
        for t,data in enumerate(train_val_loader,0):
           x,y=data
           x=x.to(DEVICE)
           y=y.to(DEVICE)
           model.to(DEVICE)
           model.train()
           scores=model(x)
           loss=F.cross_entropy(scores,y)
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
        scheduler.step()
        check(val_loader,model)
        

**Train loop for training a distilled model-Uses both croos entropy loss from labels and KL div between the student's and teacher's predicted probabilities**

In [16]:
def train_loop_distill(epochs,model1,model3,temperature,soft_targets_weight,label_loss_weight):
    #model1----Teacher Model
    #model3----Student Model
    for epoch in range(epochs):
        for t,data in enumerate(train_val_loader,0):
            x,y=data
            x=x.to(DEVICE)
            y=y.to(DEVICE)
            model3.to(DEVICE)
            model1.eval()
            with torch.no_grad():
                large_logits = model1(x)
            model3.train()
            scores=model3(x)
            #soft_targets = F.log_softmax(large_logits / temperature, dim=-1)
            #soft_prob = F.log_softmax(scores/ temperature, dim=-1)
            soft_targets_loss = F.kl_div(F.log_softmax(scores / temperature, dim=1), F.softmax(large_logits / temperature, dim=1), reduction='batchmean')
            label_loss = F.cross_entropy(scores, y)
            # Weighted sum of the two losses
            #label_loss_weight----------(1-alpha)
            #soft_targets_weight--------alpha*(temperature^2)
            loss = soft_targets_weight * soft_targets_loss + label_loss_weight * label_loss
            optimizer_distilled.zero_grad()
            loss.backward()
            optimizer_distilled.step()
        scheduler_distilled.step()
        check(val_loader,model3)
        

In [17]:
learning_rate=0.01


# Training Teacher Baseline

In [18]:
model_teacher = model_without_dropout(784,[1200,1200])
optimizer_teacher =torch.optim.SGD(model_teacher.parameters(), lr=learning_rate,weight_decay=1e-5,momentum=0.9)
scheduler_teacher = torch.optim.lr_scheduler.StepLR(optimizer_teacher, step_size=1, gamma=0.95)

In [19]:
train_loop(model_teacher,optimizer_teacher,60,scheduler_teacher)

Got 164 wrong tests
Got 2836 / 3000 correct (94.53)
Got 108 wrong tests
Got 2892 / 3000 correct (96.40)
Got 89 wrong tests
Got 2911 / 3000 correct (97.03)
Got 73 wrong tests
Got 2927 / 3000 correct (97.57)
Got 73 wrong tests
Got 2927 / 3000 correct (97.57)
Got 62 wrong tests
Got 2938 / 3000 correct (97.93)
Got 58 wrong tests
Got 2942 / 3000 correct (98.07)
Got 53 wrong tests
Got 2947 / 3000 correct (98.23)
Got 42 wrong tests
Got 2958 / 3000 correct (98.60)
Got 36 wrong tests
Got 2964 / 3000 correct (98.80)
Got 30 wrong tests
Got 2970 / 3000 correct (99.00)
Got 33 wrong tests
Got 2967 / 3000 correct (98.90)
Got 31 wrong tests
Got 2969 / 3000 correct (98.97)
Got 29 wrong tests
Got 2971 / 3000 correct (99.03)
Got 33 wrong tests
Got 2967 / 3000 correct (98.90)
Got 22 wrong tests
Got 2978 / 3000 correct (99.27)
Got 32 wrong tests
Got 2968 / 3000 correct (98.93)
Got 23 wrong tests
Got 2977 / 3000 correct (99.23)
Got 22 wrong tests
Got 2978 / 3000 correct (99.27)
Got 21 wrong tests
Got 2979 /

In [20]:
acc=check(test_loader,model_teacher)
print(acc)

Got 71 wrong tests
Got 9929 / 10000 correct (99.29)
None


**Loading datasets for the student and distilled student They do not use augmented train sets**

In [21]:
transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]
            )

train_val_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=True, 
                                            download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(root='./MNIST_dataset/', train=False, 
                                            download=True, transform=transform)

num_train = int(1.0 * len(train_val_dataset) * 95 / 100)
num_val = len(train_val_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=128, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [22]:
model_student=model_without_dropout(784,[800,800])
optimizer_student =torch.optim.SGD(model_student.parameters(), lr=learning_rate,weight_decay=1e-5,momentum=0.9)
scheduler_student = torch.optim.lr_scheduler.StepLR(optimizer_student, step_size=1, gamma=0.95)

In [23]:
model_distilled=model_without_dropout(784,[800,800])
optimizer_distilled =torch.optim.SGD(model_distilled.parameters(), lr=learning_rate,weight_decay=1e-5,momentum=0.9)
scheduler_distilled = torch.optim.lr_scheduler.StepLR(optimizer_distilled, step_size=1, gamma=0.95)

# Student Baseline

In [24]:
train_loop(model_student,optimizer_student,60,scheduler_student)

Got 71 wrong tests
Got 2929 / 3000 correct (97.63)
Got 27 wrong tests
Got 2973 / 3000 correct (99.10)
Got 24 wrong tests
Got 2976 / 3000 correct (99.20)
Got 12 wrong tests
Got 2988 / 3000 correct (99.60)
Got 8 wrong tests
Got 2992 / 3000 correct (99.73)
Got 2 wrong tests
Got 2998 / 3000 correct (99.93)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 1 wrong tests
Got 2999 / 3000 correct (99.97)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 1 wrong tests
Got 2999 / 3000 correct (99.97)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 correct (100.00)
Got 0 wrong tests
Got 3000 / 3000 c

In [25]:
acc=check(test_loader,model_student)
print(acc)

Got 135 wrong tests
Got 9865 / 10000 correct (98.65)
None


In [26]:
train_loop_distill(60,model_teacher,model_distilled,temperature=20,soft_targets_weight=200,label_loss_weight=0.5)

Got 31 wrong tests
Got 2969 / 3000 correct (98.97)
Got 16 wrong tests
Got 2984 / 3000 correct (99.47)
Got 18 wrong tests
Got 2982 / 3000 correct (99.40)
Got 8 wrong tests
Got 2992 / 3000 correct (99.73)
Got 5 wrong tests
Got 2995 / 3000 correct (99.83)
Got 6 wrong tests
Got 2994 / 3000 correct (99.80)
Got 5 wrong tests
Got 2995 / 3000 correct (99.83)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 4 wrong tests
Got 2996 / 3000 correct (99.87)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.90)
Got 3 wrong tests
Got 2997 / 3000 correct (99.9

In [27]:
acc=check(test_loader,model_teacher)
print(acc)

Got 71 wrong tests
Got 9929 / 10000 correct (99.29)
None


In [28]:
acc=check(test_loader,model_student)
print(acc)

Got 135 wrong tests
Got 9865 / 10000 correct (98.65)
None


In [29]:
acc=check(test_loader,model_distilled)
print(acc)

Got 77 wrong tests
Got 9923 / 10000 correct (99.23)
None
