In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random
import tqdm
import PIL
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
drop_rate = 0

train_val_test_split = (0.7,0.2,0.1)
batch_size = 96
num_workers=6
drop_rate = 0.23
train_dir = '../data/train'
epochs = 10
lr = 0.001

In [3]:
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.densenet121(weights='DEFAULT')
        for params in self.model.parameters():
            params.requires_grad_ = False

        num_ftrs = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 500),
            nn.Linear(500, 2)
            )
        
    def forward(self, x):
        x = self.model(x)
        return x

In [4]:
class Student(nn.Module):

    def __init__(self):
        super().__init__()

        # onvolutional layers (3,16,32)
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size=(5, 5), stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size=(5, 5), stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size=(3, 3), padding=1)

        # conected layers
        self.fc1 = nn.Linear(in_features= 64 * 3 * 3, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=2)


    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

In [5]:
train_files = random.sample(os.listdir(train_dir),k = 20000)
# Only select jpg files
train_files = [f for f in train_files if '.jpg' in f]

In [6]:
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ColorJitter(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

val_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [7]:
# Creating dataset

class DogsVsCatsDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = val_transform):
        self.file_list = file_list
        self.dir = dir
        self.mode= mode
        self.transform = transform
            
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        img = PIL.Image.open(os.path.join(self.dir, self.file_list[idx]))
        img = self.transform(img)
        img = np.array(img)
        if 'dog' in self.file_list[idx]:
            self.label = 1
        else:
            self.label = 0
        return img.astype('float32'), self.label


train_files, test_files = train_test_split(train_files, 
                                    test_size=train_val_test_split[2], 
                                    random_state=42
                                    )
train_files, valid_files = train_test_split(train_files,
                                    test_size=train_val_test_split[1]/train_val_test_split[0], 
                                    random_state=42
                                    )

TrainDataSet = DogsVsCatsDataset(train_files, dir = train_dir, mode='train', transform = train_transform)
TrainDataLoader = DataLoader(TrainDataSet, batch_size = batch_size, shuffle=True, num_workers=num_workers)

ValidDataSet = DogsVsCatsDataset(valid_files, dir = train_dir, mode='valid')
ValidDataLoader = DataLoader(ValidDataSet, batch_size = batch_size, shuffle=False, num_workers=num_workers)

TestDataSet = DogsVsCatsDataset(test_files, dir = train_dir, mode='test')
TestDataLoader = DataLoader(TestDataSet, batch_size = batch_size, shuffle=False, num_workers=num_workers)

In [8]:
student_model = Student()
teacher_model = Teacher()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
def train(model):

    train_loss_list = []
    train_acc_list = []
    valid_loss_list = []
    valid_acc_list = []
    test_acc_list = []

    for epoch in range(epochs):
        print("Epoch",epoch+1,"/",epochs)
        model.train()
        train_loss = 0
        train_acc = 0
        itr = 1
        tot_itr = len(TrainDataLoader)
        for samples, labels in tqdm.tqdm(TrainDataLoader, desc = "Training", unit = " Iterations"):
            samples, labels = samples.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(samples)
            loss = criterion(output, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            pred = torch.argmax(output, dim=1)
            correct = pred.eq(labels)
            train_acc+= torch.mean(correct.float())
            torch.cuda.empty_cache()
            itr += 1
            
        train_loss_list.append(train_loss/tot_itr)
        train_acc_list.append(train_acc.item()/tot_itr)
        print(' Total Loss: {:.4f}, Accuracy: {:.1f} %'.format(train_loss, train_acc/tot_itr*100))

        model.eval()
        valid_loss=0
        valid_acc=0
        itr=1
        tot_itr = len(ValidDataLoader)
        for samples, labels in tqdm.tqdm(ValidDataLoader, desc = "Validating", unit = " Iterations"):
            with torch.no_grad():
                samples, labels = samples.to(device), labels.to(device)
                output = model(samples)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                pred = torch.argmax(output, dim=1)
                correct = pred.eq(labels)
                valid_acc += torch.mean(correct.float())
                torch.cuda.empty_cache()
                itr += 1
                
        valid_loss_list.append(valid_loss/tot_itr)
        valid_acc_list.append(valid_acc.item()/tot_itr)
        print('-----------------------------> Validation Loss: {:.4f}, Accuracy: {:.1f} %'.format(valid_loss, valid_acc/tot_itr*100))


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
train(teacher_model)


Epoch 1 / 10
 Total Loss: 24.1403, Accuracy: 92.3 %
-----------------------------> Validation Loss: 8.5310, Accuracy: 93.4 %
Epoch 2 / 10
 Total Loss: 16.2486, Accuracy: 94.9 %
-----------------------------> Validation Loss: 7.9996, Accuracy: 94.0 %
Epoch 3 / 10
 Total Loss: 13.7326, Accuracy: 95.9 %
-----------------------------> Validation Loss: 7.7569, Accuracy: 94.3 %
Epoch 4 / 10
 Total Loss: 12.7694, Accuracy: 96.4 %
-----------------------------> Validation Loss: 7.0898, Accuracy: 94.7 %
Epoch 5 / 10
 Total Loss: 11.5515, Accuracy: 96.5 %
-----------------------------> Validation Loss: 9.3706, Accuracy: 93.5 %
Epoch 6 / 10
 Total Loss: 9.7655, Accuracy: 97.2 %
-----------------------------> Validation Loss: 9.6928, Accuracy: 93.3 %
Epoch 7 / 10
 Total Loss: 10.5929, Accuracy: 97.1 %
-----------------------------> Validation Loss: 7.6148, Accuracy: 93.6 %
Epoch 8 / 10
 Total Loss: 10.0824, Accuracy: 97.2 %
-----------------------------> Validation Loss: 6.9528, Accuracy: 94.6 %
E

Training: 100%|██████████| 134/134 [01:37<00:00,  1.38 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.45 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.44 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:34<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.47 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.44 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.43 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:35<00:00,  1.41 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.47 Iterations/s]
Training: 100%|██████████| 134/134 [01:33<00:00,  1.43 Iterations/s]
Validating: 100%|██████████| 54/54 [00:37<00:00,  1.43 Iterations/s]
Training: 100%|██████████| 134/134

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=lr, amsgrad=True)
student_model = student_model.to(device)
train(student_model)

Epoch 1 / 10
 Total Loss: 92.6143, Accuracy: 52.1 %
-----------------------------> Validation Loss: 37.1863, Accuracy: 52.8 %
Epoch 2 / 10
 Total Loss: 91.8145, Accuracy: 54.4 %
-----------------------------> Validation Loss: 36.8519, Accuracy: 55.3 %
Epoch 3 / 10
 Total Loss: 90.5905, Accuracy: 57.6 %


Training: 100%|██████████| 134/134 [01:34<00:00,  1.42 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:30<00:00,  1.49 Iterations/s]
Validating: 100%|██████████| 54/54 [00:36<00:00,  1.48 Iterations/s]
Training: 100%|██████████| 134/134 [01:30<00:00,  1.48 Iterations/s]
Validating:  41%|████      | 22/54 [00:17<00:17,  1.81 Iterations/s]

In [None]:
student_model_distilled = Student()

In [None]:
class DistillationLoss:

    def __init__(self):
        self.student_loss = nn.CrossEntropyLoss()
        self.distillation_loss = nn.KLDivLoss()
        self.temperature = 1
        self.alpha = 0.25

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1), 
                                                   F.softmax(teacher_logits/self.temperature, dim=1))
        loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
        return loss

In [None]:
criterion = nn.CrossEntropyLoss()
ds_loss = DistillationLoss()
optimizer = torch.optim.Adam(student_model_distilled.parameters(), lr=lr, amsgrad=True)
teacher_model = teacher_model.to(device)
student_model_distilled = student_model_distilled.to(device)

In [None]:
def train_with_distillation(student_model, teacher_model):

    train_loss_list = []
    train_acc_list = []
    valid_loss_list = []
    valid_acc_list = []
    test_acc_list = []

    for epoch in range(epochs):
        print("Epoch",epoch+1,"/",epochs)
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_acc = 0
        
        itr = 1
        tot_itr = len(TrainDataLoader)
        for samples, labels in tqdm.tqdm(TrainDataLoader, desc = "Training", unit = " Iterations"):
            samples, labels = samples.to(device), labels.to(device)
            optimizer.zero_grad()
            output_student = student_model(samples)
            output_teacher = teacher_model(samples)

            student_target_loss = criterion(output_student, labels)
            loss = ds_loss(output_student, student_target_loss, output_teacher)

            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            pred = torch.argmax(output_student, dim=1)
            correct = pred.eq(labels)
            train_acc+= torch.mean(correct.float())
            torch.cuda.empty_cache()
            itr += 1
            
        train_loss_list.append(train_loss/tot_itr)
        train_acc_list.append(train_acc.item()/tot_itr)
        print(' Total Loss: {:.4f}, Accuracy: {:.1f} %'.format(train_loss, train_acc/tot_itr*100))

        student_model.eval()
        valid_loss=0
        valid_acc=0
        itr=1
        tot_itr = len(ValidDataLoader)
        for samples, labels in tqdm.tqdm(ValidDataLoader, desc = "Validating", unit = " Iterations"):
            with torch.no_grad():
                samples, labels = samples.to(device), labels.to(device)
                output = student_model(samples)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                pred = torch.argmax(output, dim=1)
                correct = pred.eq(labels)
                valid_acc += torch.mean(correct.float())
                torch.cuda.empty_cache()
                itr += 1
                
        valid_loss_list.append(valid_loss/tot_itr)
        valid_acc_list.append(valid_acc.item()/tot_itr)
        print('-----------------------------> Validation Loss: {:.4f}, Accuracy: {:.1f} %'.format(valid_loss, valid_acc/tot_itr*100))


In [None]:
train_with_distillation(student_model=student_model_distilled, teacher_model=teacher_model)