In [17]:
# prevent random files being included in dataset
!rm -rf `find -type d -name .ipynb_checkpoints`

In [18]:
import papermill as pm
import mlflow
import torch
from utils import md5_dir, set_seed
from torch.utils.data import DataLoader
import tqdm.auto as tqdm

from loss_functions import kd_loss
from datasets import TrainImageNetDataset

import torch
import torch.nn.functional as F

# if using pretrained model
from torchvision.models import ResNet50_Weights

In [19]:
set_seed(42)
torch.cuda.empty_cache()

In [20]:
# Default Parameters
run_id = "5363b9a6f6954354b7bb68d535b5ea88"
train_data_path = "../data/ImageNet/ILSVRC/Data/CLS-LOC/train/"
# test_data_path = "data/ImageNet/ILSVRC/Data/CLS-LOC/val/"

train_data_labels_path = "../data/ImageNet/LOC_train_solution.csv"
# test_data_labels_path = "data/ImageNet/LOC_val_solution.csv"

label_mapping_path = "../data/ImageNet/LOC_synset_mapping.txt"

resnet50_weights = ResNet50_Weights.DEFAULT

preprocess = resnet50_weights.transforms()

device = "cuda" if torch.cuda.is_available() else "cpu"

In [21]:
train_dataset = TrainImageNetDataset(train_data_path, train_data_labels_path, label_mapping_path, preprocess)

# test_dataset = TestImageNetDataset(test_data_path, test_data_labels_path, label_mapping_path, preprocess)

train_dataloader = DataLoader(train_dataset, batch_size=80, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=80, shuffle=False)

/home/vincent/projects/MLRC_2023/data/ImageNet/LOC_synset_mapping.txt


In [22]:
# if using pretrained model
from torchvision.models import resnet50, ResNet50_Weights, resnet18

resnet50_pretrained_weights = ResNet50_Weights.DEFAULT

teacher = resnet50(weights=resnet50_pretrained_weights)

student = resnet18(weights=None)
independent_student = resnet18(weights=None)

In [23]:
def train_student(student, teacher, train_dataloader, criterion, optimizer, epochs, device):
    """
    - student: The smaller, untrained model that uses the teacher's output as an additional label
    - teacher: The pretrained model used to help the student model learn
    - train_dataloader: Dataloader for training data
    - criterion: The loss function
    - optimizer: The optimization algorithm
    - epochs: Number of training epochs
    - device: Device to run training
    """
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    
    for epoch in tqdm(range(epochs), position = 0, leave = True):
        running_loss = 0.0

        for inputs, labels in tqdm(train_dataloader, position = 0, leave = True):
            inputs, labels = inputs.to(device), labels.to(device)
            labels = F.one_hot(labels, num_classes=1000).float()

            # Zero the gradients 
            optimizer.zero_grad()

            teacher_predictions = teacher(inputs)
            student_predictions = student(inputs)

            loss = criterion(student_predictions, labels, teacher_predictions, 0.5, 0.5)

            loss.backward()

            optimizer.step()

            running_loss += loss.item()

        average_loss = running_loss / len(train_dataloader)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {average_loss:.4f}')

        # save training loss in mlflow
        with mlflow.start_run(run_id=run_id) as run:
            mlflow.log_metric("student_training_loss", average_loss)

    with mlflow.start_run(run_id=run_id) as run:
        mlflow.pytorch.log_model(
            pytorch_model=teacher.to("cpu"),
            artifact_path="teacher",
        )

        mlflow.pytorch.log_model(
            pytorch_model=student.to("cpu"),
            artifact_path="student"
        )
        
        


optimizer = torch.optim.Adam(student.parameters(), lr = 0.0001)
epochs = 1
train_student(student, teacher, train_dataloader, kd_loss, optimizer, epochs, device)

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                 | 0/16015 [00:00<?, ?it/s][A
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.36it/s]


Epoch [1/1], Loss: 0.0003


In [24]:
def train_independent_student(independent_student, train_dataloader, criterion, optimizer, epochs, device):
    """
    - teacher: The pretrained model used to help the student model learn
    - student: The smaller, untrained model that uses the teacher's output as an additional label
    - criterion: The loss function
    - optimizer: The optimization algorithm
    - epochs: Number of training epochs
    - device: Device to run training
    """
    independent_student.train()
    independent_student.to(device)
    
    for epoch in tqdm.tqdm(range(epochs)):
        running_loss = 0.0

        for inputs, labels in tqdm.tqdm(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            labels = F.one_hot(labels, num_classes=1000).float()

            # Zero the gradients 
            optimizer.zero_grad()

            independent_student_predictions = independent_student(inputs)

            loss = criterion(independent_student_predictions, labels)

            loss.backward()

            optimizer.step()

            running_loss += loss.item()
            break

        average_loss = running_loss / len(train_dataloader)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {average_loss:.4f}')


        # save training loss in mlflow
        with mlflow.start_run(run_id=run_id) as run:
            mlflow.log_metric("independent_student_training_loss", average_loss)

    with mlflow.start_run(run_id=run_id) as run:
        mlflow.pytorch.log_model(
            pytorch_model=independent_student.to("cpu"),
            artifact_path="independent_student"
        )

optimizer = torch.optim.Adam(independent_student.parameters(), lr = 0.0001)
epochs = 1
criterion = torch.nn.CrossEntropyLoss()
train_independent_student(independent_student, train_dataloader, criterion, optimizer, epochs, device)

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                 | 0/16015 [00:00<?, ?it/s][A
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  2.36it/s]


Epoch [1/1], Loss: 0.0004
