In [1]:
# Necessary imports
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchinfo import summary
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
import cv2
from PIL import Image
from sklearn.model_selection import train_test_split

In [2]:
# For reproducibility
torch.manual_seed(0)
torch.backends.cudnn.benchmark = True

In [3]:
def get_df(img_path):
    image, label = [], []

    i = 0

    for get_img_name in os.listdir(img_path):
        image += [os.path.join(img_path, get_img_name)]
        label += [get_img_name.split('.')[0]]
        
        i = i+1

    PathDF = pd.DataFrame({'image': image, 'label': label})
    print(i)
    PathDF.head()
    return PathDF

In [4]:
transform = transforms.Compose([
            transforms.Resize([224,224]),
            transforms.ToTensor()
            ])

In [5]:
class mod_Dataset(Dataset):
    def __init__(self, path_df, transform=None):
        self.path_df = path_df
        self.transform = transform

    def __len__(self):
        return self.path_df.shape[0]
    
    def __getitem__(self, idx):
        if self.transform is not None:
            image = self.transform(Image.open(self.path_df.iloc[idx]['image']))

            if self.path_df.iloc[idx]['label'] == 'cat':
                label = torch.tensor([1, 0], dtype=torch.float32)
            else:
                label = torch.tensor([0, 1], dtype=torch.float32)

        return image, label

In [6]:
train_df = get_df('/home/yasaisen/Desktop/13_research/research_main/lab_02/dogs-vs-cats/train')

validation_fraction = 0.15
test_fraction = 0.10

train2rest = validation_fraction + test_fraction
test2valid = validation_fraction / train2rest


train_df, rest = train_test_split(train_df, random_state=42,
                                 test_size = train2rest)

test_df, valid_df = train_test_split(rest, random_state=42,
                                    test_size = test2valid)

train_data = mod_Dataset(train_df, transform)
valid_data = mod_Dataset(valid_df, transform)
test_data = mod_Dataset(test_df, transform)

train_loader = DataLoader(train_data, batch_size=8, shuffle=True , num_workers=0, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_data, batch_size=8, shuffle=False, num_workers=0)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False, num_workers=0)

25000


In [7]:
class TeacherModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base_model = torchvision.models.resnet34(weights=None)
        self.fc1 = nn.Linear(1000, num_classes)

    def forward(self, input):
        output = self.base_model(input)
        output = self.fc1(output)
        return output

In [8]:
class StudentModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base_model = torchvision.models.resnet18(weights=None)
        self.fc1 = nn.Linear(1000, num_classes)

    def forward(self, input):
        output = self.base_model(input)
        output = self.fc1(output)
        return output

In [9]:
# model = StudentModel(2)
# # print(model)
# t = torch.randn((32, 3, 224, 224))
# print(t.shape)
# get = model(t)
# print(get.shape)

In [10]:
def check_accuracy(loader, model, device):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            scores = model(x)
            _, predictions = scores.max(1)
            _, y = y.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    model.train()
    return (num_correct/num_samples).item()
  

def train_teacher(epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher_model = TeacherModel(2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(teacher_model.parameters(), lr=1e-4)


    for epoch in range(epochs):
        teacher_model.train()
        losses = []

        pbar = tqdm(train_loader, total=len(train_loader), position=0, leave=True, desc=f"Epoch {epoch}")
        for data, targets in pbar:
            data = data.to(device)
            targets = targets.to(device)

            # forward
            scores = teacher_model(data)
            loss = criterion(scores, targets)
            losses.append(loss.item())
            # backward
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        
        avg_loss = sum(losses) / len(losses)
        acc = check_accuracy(test_loader, teacher_model, device)
        print(f"Loss:{avg_loss:.8f}\tAccuracy:{acc:.8f}")

    return teacher_model

In [11]:
get_teacher = train_teacher(3)

Epoch 0: 100%|██████████| 2343/2343 [01:41<00:00, 23.00it/s]


Loss:0.59953313	Accuracy:0.71719998


Epoch 1: 100%|██████████| 2343/2343 [01:37<00:00, 24.07it/s]


Loss:0.45504716	Accuracy:0.78679997


Epoch 2: 100%|██████████| 2343/2343 [01:39<00:00, 23.51it/s]


Loss:0.31818571	Accuracy:0.88720000


In [12]:
def train_step(
    teacher_model,
    student_model,
    optimizer,
    student_loss_fn,
    divergence_loss_fn,
    temp,
    alpha,
    epoch,
    device
):
    losses = []
    pbar = tqdm(train_loader, total=len(train_loader), position=0, leave=True, desc=f"Epoch {epoch}")
    for data, targets in pbar:
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # forward
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        student_preds = student_model(data)
        student_loss = student_loss_fn(student_preds, targets)
        
        ditillation_loss = divergence_loss_fn(
            F.log_softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        loss = alpha * student_loss + (1 - alpha) * ditillation_loss
        losses.append(loss.item())

        # backward
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
    
    avg_loss = sum(losses) / len(losses)
    return avg_loss
  
def main(epochs, teacher, student, temp=7, alpha=0.3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher = teacher.to(device)
    student = student.to(device)
    student_loss_fn = nn.CrossEntropyLoss()
    divergence_loss_fn = nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

    teacher.eval()
    student.train()
    for epoch in range(epochs):
        loss = train_step(
            teacher,
            student,
            optimizer,
            student_loss_fn,
            divergence_loss_fn,
            temp,
            alpha,
            epoch,
            device
        )
        acc = check_accuracy(test_loader, student, device)
        print(f"Loss:{loss:.8f}\tAccuracy:{acc:.8f}")

In [13]:
get_student = StudentModel(2)

In [14]:
main(3, get_teacher, get_student)

Epoch 0: 100%|██████████| 2343/2343 [01:29<00:00, 26.20it/s]


Loss:0.18874079	Accuracy:0.79759997


Epoch 1: 100%|██████████| 2343/2343 [02:51<00:00, 13.64it/s]


Loss:0.12273631	Accuracy:0.87279999


Epoch 2: 100%|██████████| 2343/2343 [03:08<00:00, 12.41it/s]


Loss:0.08168724	Accuracy:0.89559996


In [15]:
def train_student(epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student_model = StudentModel(2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)


    for epoch in range(epochs):
        student_model.train()
        losses = []

        pbar = tqdm(train_loader, total=len(train_loader), position=0, leave=True, desc=f"Epoch {epoch}")
        for data, targets in pbar:
            data = data.to(device)
            targets = targets.to(device)

            # forward
            scores = student_model(data)
            loss = criterion(scores, targets)
            losses.append(loss.item())
            # backward
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
        
        avg_loss = sum(losses) / len(losses)
        acc = check_accuracy(test_loader, student_model, device)
        print(f"Loss:{avg_loss:.8f}\tAccuracy:{acc:.8f}")

    return student_model

In [16]:
student_model = train_student(3)

Epoch 0: 100%|██████████| 2343/2343 [02:16<00:00, 17.13it/s]


Loss:0.58247208	Accuracy:0.78880000


Epoch 1: 100%|██████████| 2343/2343 [02:29<00:00, 15.63it/s]


Loss:0.40024273	Accuracy:0.85600001


Epoch 2: 100%|██████████| 2343/2343 [01:14<00:00, 31.41it/s]


Loss:0.27931961	Accuracy:0.89239997
