In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.autograd import Function
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import utils
from torchvision import datasets, models, transforms

from hard_margin_loss import MarginLoss

In [None]:
# Build preprocessing pipelines, ImageFolder datasets and DataLoaders for the
# three splits found under input_path ('train', 'validation', 'test').
input_path = 'Dataset/'

# ImageNet channel statistics, shared by every split.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Every split is resized to 256x256 and cropped to 224. Only 'train' uses a
# random crop plus horizontal flip for augmentation; the evaluation splits use
# a deterministic centre crop.
data_transforms = {
    split: transforms.Compose(
        [transforms.Resize((256, 256))]
        + ([transforms.RandomCrop(224), transforms.RandomHorizontalFlip()]
           if split == 'train'
           else [transforms.CenterCrop(224)])
        + [transforms.ToTensor(), normalize]
    )
    for split in ('train', 'validation', 'test')
}

# One sub-directory per split, e.g. 'Dataset/train'.
image_datasets = {
    split: datasets.ImageFolder(input_path + split, data_transforms[split])
    for split in ('train', 'validation', 'test')
}

# Only the training loader shuffles; evaluation order stays fixed.
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split],
                                       batch_size=32,
                                       shuffle=(split == 'train'),
                                       num_workers=0)
    for split in ('train', 'validation', 'test')
}

In [None]:
#Load the weights of Resident Fellow (Teacher)
class DenseNet121(nn.Module):
    """Resident Fellow (teacher): ImageNet-pretrained DenseNet-121 with a new head.

    The stock classifier is replaced by ``nn.Sequential(nn.Linear(in_features, out_size))``,
    so the network emits ``out_size`` raw logits.

    NOTE(review): the attribute name ``densenet121`` and the single-element
    ``Sequential`` head are kept unchanged on purpose. The checkpoint loaded via
    ``torch.load`` was presumably pickled from this class, so renaming either
    would likely break loading.
    """

    def __init__(self, out_size):
        super().__init__()
        backbone = models.densenet121(pretrained=True)
        # Swap the 1000-way ImageNet classifier for an out_size-way head.
        backbone.classifier = nn.Sequential(
            nn.Linear(backbone.classifier.in_features, out_size),
        )
        self.densenet121 = backbone

    def forward(self, x):
        """Return the logits of the wrapped DenseNet-121 for input batch ``x``."""
        return self.densenet121(x)

# torch.load returns the complete pickled teacher module (the code below calls
# .to/.eval/.parameters on it). Building DenseNet121(3) first, which also
# downloads ImageNet weights, and wrapping it in DataParallel only to overwrite
# it immediately is unnecessary.
# Unpickling presumably still needs the DenseNet121 class defined above.
device = torch.device("cuda:0")
# map_location puts every tensor on `device`, even if the checkpoint was saved
# from a different GPU index.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. On PyTorch >= 2.6, loading a whole module also needs
# weights_only=False (that argument does not exist before 1.13).
model = torch.load('RF_model/Pretrain_DenseNet', map_location=device)
print("model loaded")
model.to(device);

In [None]:
# Top-1 accuracy of the Resident Fellow (teacher) on the test split.
# Inference runs under torch.no_grad(): nothing is back-propagated here, so
# recording the autograd graph would only waste GPU memory.
running_corrects = 0
model.eval()
with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)
        running_corrects += (preds == labels).sum().item()
epoch_acc = running_corrects / len(image_datasets['test'])
print(epoch_acc)

In [None]:
# Losses and schedule shared by every (a, T) configuration of the sweep.
num_classes = 3

# Hard-label softmax cross-entropy between student logits and ground truth.
criterion1 = nn.CrossEntropyLoss()

# KL divergence between temperature-softened teacher and student distributions.
# The training loop passes log-probabilities (student) and probabilities
# (teacher). reduction='batchmean' gives the true per-sample KL divergence.
# The default 'mean' also divides by the number of classes, which silently
# shrinks the distillation term by a factor of num_classes (PyTorch warns
# about this).
criterion2 = nn.KLDivLoss(reduction='batchmean')

# PC (margin) loss with margin 0.8, an optional replacement for criterion1.
# The first argument is presumably the class count (originally hard-coded 3).
criterion3 = MarginLoss(num_classes, margin=0.8)

num_epochs = 50

In [None]:
# Distil the Resident Fellow (teacher) into a MobileNetV2 Medical Student for
# every (a, T) pair of the grid. For each configuration, keep the weights with
# the best validation accuracy.
#
# Per-batch objective:
#   (1 - a) * CE(student, labels) + a * T^2 * KL(softmax(teacher/T) || softmax(student/T))
# The T^2 factor keeps soft-target gradients on the same scale as the
# hard-label term when T > 1.

SAVE_DIR = 'MobileNet_weights'
# Create the checkpoint directory up front. Otherwise torch.save fails only
# after a full epoch of training.
os.makedirs(SAVE_DIR, exist_ok=True)

# The teacher is frozen and kept in eval mode for the entire sweep.
for param in model.parameters():
    param.requires_grad = False
model.eval()

# a: weight of the distillation term; T: softmax temperature.
for a in [0.2, 0.4, 0.6, 0.8]:
    for T in [1, 5, 10]:
        best_acc = 0

        # Fresh ImageNet-pretrained student per configuration. The last Linear
        # layer is replaced so it emits num_classes logits.
        mobilenet = models.mobilenet_v2(pretrained=True)
        mobilenet.classifier[1] = nn.Linear(mobilenet.classifier[1].in_features,
                                            num_classes)
        mobilenet.to(device)

        # To start from pretrained clean weights (noisy-label experiments):
        # mobilenet = torch.load('mobilenet/clean')

        optimizer = optim.Adam(mobilenet.parameters(), lr=0.0002, betas=(0.9, 0.999))

        for epoch in range(num_epochs):
            print('a={} T={} | Epoch {}/{}'.format(a, T, epoch + 1, num_epochs))
            print('-' * 10)

            for phase in ['train', 'validation']:
                is_train = phase == 'train'
                # train(False) is equivalent to eval().
                mobilenet.train(is_train)

                running_loss = 0.0
                running_corrects = 0

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # Teacher logits are fixed targets, so no graph is needed.
                    with torch.no_grad():
                        soft_target = model(inputs)

                    # Track student gradients only while training. Validation
                    # only measures loss and accuracy.
                    with torch.set_grad_enabled(is_train):
                        outputs = mobilenet(inputs)

                        # criterion1 is softmax cross-entropy on hard labels.
                        loss1 = criterion1(outputs, labels)
                        # To use the PC (margin) loss instead:
                        # loss1 = criterion3(outputs, labels)

                        outputs_S = F.log_softmax(outputs / T, dim=1)
                        outputs_T = F.softmax(soft_target / T, dim=1)
                        loss2 = criterion2(outputs_S, outputs_T) * T * T

                        loss = (1 - a) * loss1 + a * loss2

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    preds = outputs.argmax(dim=1)
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / len(image_datasets[phase])
                epoch_acc = running_corrects / len(image_datasets[phase])

                # Keep only the best-validation weights for this (a, T).
                if phase == 'validation' and epoch_acc > best_acc:
                    torch.save(mobilenet, '{}/a={}T={}'.format(SAVE_DIR, a, T))
                    best_acc = epoch_acc
                print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                            epoch_loss,
                                                            epoch_acc))

In [None]:
# Load the trained Medical Student for final evaluation.
# The file is expected to hold a whole pickled module, which is what the
# training cell's torch.save(mobilenet, ...) produces. Building and downloading
# a fresh pretrained MobileNetV2 first, only to overwrite it, is unnecessary.
# NOTE(review): the training cell writes 'MobileNet_weights/a=<a>T=<T>';
# 'best_validation_weights' is never written by this notebook. Presumably the
# chosen configuration is copied or renamed by hand, so confirm the path.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
mobilenet = torch.load('MobileNet_weights/best_validation_weights', map_location=device)
mobilenet.to(device)
mobilenet.eval();

In [None]:
# Top-1 accuracy of the Medical Student on the test split.
# eval() is set here too so this cell is correct when re-run on its own.
# no_grad() avoids recording an autograd graph that is never used.
running_corrects = 0
mobilenet.eval()
with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = mobilenet(inputs)
        preds = outputs.argmax(dim=1)
        running_corrects += (preds == labels).sum().item()
epoch_acc = running_corrects / len(image_datasets['test'])
print(epoch_acc)