In [1]:
import torch
import torchvision
import torchmetrics
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from utils import Fashion_MNIST_ResNet, Fashion_MNIST_MobileNet
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
print("Torch version is: ", torch.__version__) # Should see something like 2.1.0+cu118
print("Is CUDA available - ", torch.cuda.is_available())
# Just as a fail-safe, switch to CPU if CUDA not available, but training will be very slow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pl.seed_everything(42)

Seed set to 42


Torch version is:  2.1.0
Is CUDA available -  True


42

In [3]:
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                  transforms.RandomVerticalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize((72.9404/255,), (90.0212/255,))
])

test_transform = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((72.9404/255,), (90.0212/255,))
])

In [4]:
# Here we load the download the train dataset
train_set = FashionMNIST('./data', download=True, train=True, transform=train_transform)
print(train_set.data.float().mean())
print(train_set.data.float().std())
# Here we load the download the test dataset
test_set = FashionMNIST('./data', download=True, train=False, transform=test_transform)
# Dictionary of the classes in the dataset
classes_dict = dict(enumerate(train_set.classes))
print(classes_dict)

tensor(72.9404)
tensor(90.0212)
{0: 'T-shirt/top', 1: 'Trouser', 2: 'Pullover', 3: 'Dress', 4: 'Coat', 5: 'Sandal', 6: 'Shirt', 7: 'Sneaker', 8: 'Bag', 9: 'Ankle boot'}


In [5]:
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=8)

In [6]:
teacher_model = Fashion_MNIST_ResNet.load_from_checkpoint("lightning_logs/resnet50/checkpoints/epoch=27-val_loss=0.250-val_acc=0.913.ckpt")

In [7]:
# PyTorch Lightning Module for Knowledge Distillation
class KnowledgeDistillationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.teacher = teacher_model
        self.student = Fashion_MNIST_MobileNet()
        self.distillation_temperature = 10.0
        self.alpha = 0.0
        self.criterion = nn.CrossEntropyLoss()
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        self.acc = MulticlassAccuracy(num_classes=10)
        self.f1 = MulticlassF1Score(num_classes=10)

    def forward(self, x):
        return self.student(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        teacher_outputs = self.teacher(images)
        student_outputs = self.student(images)

        loss = self.distill_loss(student_outputs, teacher_outputs, labels)
        self.log('train_loss', loss)
        return loss

    def distill_loss(self, student_logits, teacher_logits, labels):
        # Calculate the soft targets for the KL Divergence loss
        soft_labels = torch.softmax(teacher_logits / self.distillation_temperature, dim=1)
        student_log_probs = torch.log_softmax(student_logits / self.distillation_temperature, dim=1)
        distillation_loss = self.kl_div_loss(student_log_probs, soft_labels.detach())

        # Calculate the student's standard loss
        student_loss = self.criterion(student_logits, labels)

        # Combine the losses with the distillation alpha weight
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss * (self.distillation_temperature ** 2)
        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        self.acc.update(outputs, targets)
        self.f1.update(outputs, targets)
        self.log("val_loss", loss)
        return loss

    def on_validation_epoch_end(self):
        e_loss = self.trainer.callback_metrics.get('val_loss')
        t_loss = self.trainer.callback_metrics.get('train_loss')
        e_acc = self.acc.compute()
        e_f1 = self.f1.compute()
        self.log("val_acc", e_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_f1", e_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        print(f"\n\nEpoch: {self.current_epoch} - Metrics: ")
        print(f"Training loss: {t_loss}, Validation loss:{e_loss:.4f}, Validation accuracy: {e_acc:.4f}, Validation F1: {e_f1:.4f}\n")
        self.acc.reset()
        self.f1.reset()

    def configure_optimizers(self):
        return optim.AdamW(self.student.parameters(), lr=1e-3)

In [8]:
model = KnowledgeDistillationModule()

checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode='max', filename='{epoch}-{val_loss:.3f}-{val_acc:.3f}', auto_insert_metric_name=True)

trainer = pl.Trainer(accelerator='gpu', max_epochs=30, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


/home/youxiang/anaconda3/envs/dl_proj/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                    

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  9.67it/s]

Epoch: 0 - Metrics: 
Training loss: None, Validation loss:2.3026, Validation accuracy: 0.0634, Validation F1: 0.0420

Epoch 0: 100%|██████████| 938/938 [00:19<00:00, 47.76it/s, v_num=0]        

Epoch: 0 - Metrics: 
Training loss: 4.602952480316162, Validation loss:0.7631, Validation accuracy: 0.7870, Validation F1: 0.7798

Epoch 1: 100%|██████████| 938/938 [00:19<00:00, 48.44it/s, v_num=0, val_acc=0.787, val_f1=0.780]

Epoch: 1 - Metrics: 
Training loss: 3.7599284648895264, Validation loss:0.6593, Validation accuracy: 0.8139, Validation F1: 0.8084

Epoch 2: 100%|██████████| 938/938 [00:19<00:00, 46.93it/s, v_num=0, val_acc=0.814, val_f1=0.808]

Epoch: 2 - Metrics: 
Training loss: 1.469132900238037, Validation loss:0.5319, Validation accuracy: 0.8388, Validation F1: 0.8378

Epoch 3: 100%|██████████| 938/938 [00:20<00:00, 46.46it/s, v_num=0, val_acc=0.839, val_f1=0.838]

Epoch: 3 - Metrics: 
Training loss: 1.42

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 938/938 [00:20<00:00, 46.63it/s, v_num=0, val_acc=0.880, val_f1=0.879]
