# Colab

In [None]:
!pip install albumentations einops pytorch_lightning wandb

# Imports

In [1]:
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

# Report whether a CUDA device is visible to torch in this runtime.
cuda_available = torch.cuda.is_available()
print(f'CUDA: {cuda_available}')

CUDA: False


# DataModule

In [2]:
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import CIFAR10

In [3]:
class CIFAR10Dataset(Dataset):
    """Dataset wrapper that applies a keyword-style transform to each image.

    Args:
        subset: Indexable collection yielding ``(image, label)`` pairs.
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``"image"`` entry.
    """

    def __init__(
        self,
        subset: Subset,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.transform = transform
        self.subset = subset

    def __len__(self) -> int:
        """Number of samples in the wrapped collection."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transformed if configured."""
        sample, target = self.subset[idx]

        if self.transform is None:
            # No transform configured: the sample is handed back untouched,
            # i.e. it is NOT converted to a numpy array on this path.
            return sample, target

        # The transform is called with a numpy array passed by keyword.
        augmented = self.transform(image=np.array(sample))
        return augmented["image"], target

In [4]:
class CIFAR10DataModule(LightningDataModule):
    """LightningDataModule that serves CIFAR-10 train / val / test loaders.

    The official training set is split into train and validation parts with
    ``train_test_split``; the official test set is used as-is. Every split is
    wrapped in ``CIFAR10Dataset`` so that transforms are always invoked as
    ``transform(image=array)``.

    Args:
        data_dir: Directory the dataset is downloaded to / read from.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker count used by all three dataloaders.
        shuffle: Whether the *training* dataloader shuffles. Note that the
            default is ``False``, so training order is fixed unless the
            caller opts in.
        train_transforms: Transform for the training split; falls back to
            ``default_transforms()`` when ``None``.
        val_transforms: Transform for the validation split; same fallback.
        test_transforms: Transform for the test split; same fallback.
        val_size: Passed to ``train_test_split`` as ``test_size``.
        split_seed: ``random_state`` for the train/val split. A fixed value
            keeps the split identical across repeated ``setup`` calls;
            pass ``None`` to get a fresh random split every time.
    """

    def __init__(
        self,
        data_dir: str = './data/',
        batch_size: int = 8,
        num_workers: int = 4,
        shuffle: bool = False,
        train_transforms: Optional[Callable] = None,
        val_transforms: Optional[Callable] = None,
        test_transforms: Optional[Callable] = None,
        val_size: float = 0.25,
        split_seed: Optional[int] = 42,
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms
        self.val_size = val_size
        self.split_seed = split_seed

    def prepare_data(self) -> None:
        """Create ``data_dir`` if needed and download both CIFAR-10 splits."""
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
        CIFAR10(root=self.data_dir, train=True, download=True)
        CIFAR10(root=self.data_dir, train=False, download=True)

    def _transform_or_default(self, transform: Optional[Callable]) -> Callable:
        """Return ``transform`` unless it is ``None``, else the default one."""
        return self.default_transforms() if transform is None else transform

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets required for ``stage``.

        ``'fit'`` and ``'validate'`` build the train and validation datasets,
        ``'test'`` builds the test dataset, and ``None`` builds all of them.
        """
        if stage in ('fit', 'validate', None):
            trainval_dataset = CIFAR10(root=self.data_dir, train=True)
            train_indices, val_indices = train_test_split(
                np.arange(len(trainval_dataset)),
                test_size=self.val_size,
                # Seeded so train and val never swap samples between calls.
                random_state=self.split_seed,
            )
            train_subset = Subset(trainval_dataset, train_indices)
            val_subset = Subset(trainval_dataset, val_indices)

            self.train_dataset = CIFAR10Dataset(
                train_subset,
                transform=self._transform_or_default(self.train_transforms),
            )
            self.val_dataset = CIFAR10Dataset(
                val_subset,
                transform=self._transform_or_default(self.val_transforms),
            )

        if stage in ('test', None):
            # Do NOT hand the transform to torchvision's CIFAR10 directly:
            # it would be called positionally with a PIL image, whereas the
            # transforms used here expect ``transform(image=ndarray)``.
            # Wrapping in CIFAR10Dataset gives the same calling convention
            # as the train / val splits.
            self.test_dataset = CIFAR10Dataset(
                CIFAR10(root=self.data_dir, train=False),
                transform=self._transform_or_default(self.test_transforms),
            )

    def train_dataloader(self) -> DataLoader:
        """Dataloader over the training split (shuffling per ``self.shuffle``)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle,
        )

    def val_dataloader(self) -> DataLoader:
        """Dataloader over the validation split (never shuffled)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """Dataloader over the test split (never shuffled)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def default_transforms(self) -> Callable:
        """Normalize with the configured mean/std, then convert to a tensor."""
        return A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

# Models

In [5]:
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange
from torchvision.models import resnet18, resnet50, ResNet50_Weights

### Teacher

In [6]:
class ResNet50Model(nn.Module):
    """ResNet-50 trunk with a 10-way linear classification head (teacher).

    The trunk is every child of torchvision's ``resnet50`` except the last
    two; the head pools the feature map to 1x1, flattens it and applies a
    ``Linear(2048, 10)``.

    Args:
        pretrained: When true, initialise the trunk from
            ``ResNet50_Weights.IMAGENET1K_V1``; otherwise use random weights.
    """

    def __init__(self, pretrained: bool = False):
        super().__init__()
        # Load ImageNet weights only when the caller asks for them
        # (the flag used to be inverted).
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = resnet50(weights=weights)
        self.feature_extractor = nn.Sequential(*(list(model.children())[:-2]))
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, inputs: str = 'images'):
        """Classify ``x``.

        Args:
            x: Either a batch of images or a batch of trunk feature maps.
            inputs: ``'images'`` to run the trunk first, ``'features'`` to
                feed ``x`` straight into the head.

        Raises:
            ValueError: If ``inputs`` is neither ``'images'`` nor ``'features'``.
        """
        if inputs == 'images':
            features = self.extract_features(x)
        elif inputs == 'features':
            features = x
        else:
            # Fail loudly instead of passing ``None`` into the head.
            raise ValueError(
                f"inputs must be 'images' or 'features', got {inputs!r}"
            )

        return self.head(features)

    def extract_features(self, x):
        """Run only the convolutional trunk and return its feature map."""
        return self.feature_extractor(x)

### Student

In [7]:
class ResNet18Model(nn.Module):
    """ResNet-18 trunk, 1x1-conv neck and 10-way linear head (student).

    The trunk is every child of an un-pretrained torchvision ``resnet18``
    except the last two. The neck maps its 512 channels to 2048 (1x1
    convolution followed by batch norm), and the head pools to 1x1, flattens
    and applies a ``Linear(2048, 10)``.
    """

    def __init__(self):
        super().__init__()
        model = resnet18(weights=None)
        self.feature_extractor = nn.Sequential(*(list(model.children())[:-2]))
        self.neck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=2048, kernel_size=1),
            nn.BatchNorm2d(2048),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, inputs: str = 'images'):
        """Classify ``x``.

        Args:
            x: Either a batch of images or a batch of post-neck feature maps.
            inputs: ``'images'`` to run trunk + neck first, ``'features'`` to
                feed ``x`` straight into the head.

        Raises:
            ValueError: If ``inputs`` is neither ``'images'`` nor ``'features'``.
        """
        if inputs == 'images':
            features = self.extract_features(x)
        elif inputs == 'features':
            features = x
        else:
            # Fail loudly instead of passing ``None`` into the head.
            raise ValueError(
                f"inputs must be 'images' or 'features', got {inputs!r}"
            )

        return self.head(features)

    def extract_features(self, x):
        """Run the trunk followed by the neck and return the feature map."""
        return self.neck(self.feature_extractor(x))

In [8]:
# Smoke check: student and teacher must agree on both the feature-map shape
# and the output shape for the same input batch.
student = ResNet18Model()
teacher = ResNet50Model()
x = torch.randn(8, 3, 224, 224)

student_feature_shape = student.extract_features(x).shape
teacher_feature_shape = teacher.extract_features(x).shape
assert student_feature_shape == teacher_feature_shape

teacher_output_shape = teacher(x).shape
student_output_shape = student(x).shape
assert teacher_output_shape == student_output_shape

print('Done!')

Done!


# Modules

In [9]:
import torch.optim as optim
from pytorch_lightning import LightningModule
from torch import Tensor
from torchmetrics import Accuracy

### Classification Module

In [10]:
class ClassificationModule(LightningModule):
    """LightningModule that trains ``model`` with cross-entropy.

    Args:
        model: Network mapping a batch of images to per-class scores.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, model: nn.Module, learning_rate: float = 3e-4) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        # Submodules are registered in this order: model, criterion, metric.
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_metric = Accuracy()

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to the wrapped model."""
        return self.model(x)

    def _step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Shared train/val/test step: returns ``loss`` and ``accuracy``."""
        images, labels = batch
        predicts = self.model(images)
        return {
            'loss': self.criterion(predicts, labels),
            'accuracy': self.accuracy_metric(predicts, labels),
        }

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Run the shared step and log its result under ``'train'``."""
        metrics = self._step(batch, batch_idx)
        self.log('train', metrics)
        return metrics

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Run the shared step and log its result under ``'val'``."""
        metrics = self._step(batch, batch_idx)
        self.log('val', metrics)
        return metrics

    def test_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Run the shared step and log its result under ``'test'``."""
        metrics = self._step(batch, batch_idx)
        self.log('test', metrics)
        return metrics

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters."""
        model_params = self.model.parameters()
        return optim.Adam(params=model_params, lr=self.learning_rate)

### Distillation Module

In [11]:
class DistillationModule(LightningModule):
    """Container for a teacher / student pair of networks.

    NOTE(review): this class currently only stores the two networks; it
    defines no step methods and its ``configure_optimizers`` yields nothing.

    Args:
        teacher: Teacher network.
        student: Student network.
    """

    def __init__(self, teacher: nn.Module, student: nn.Module) -> None:
        super().__init__()
        self.teacher = teacher
        self.student = student

    def configure_optimizers(self):
        """Placeholder: no optimizer is configured (returns ``None``)."""
        return None

# Train loops

In [12]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

In [15]:
import wandb

ModuleNotFoundError: No module named 'wandb'

### Classification Teacher

In [13]:
# Train the ResNet-50 teacher as a plain classifier with default data settings.
datamodule = CIFAR10DataModule()
model = ResNet50Model()
module = ClassificationModule(model=model)

logger = WandbLogger(project='samogonka')
trainer = Trainer(
    accelerator='cpu',
    logger=logger,
    max_epochs=30,
)
trainer.fit(module, datamodule=datamodule)

ModuleNotFoundError: You want to use `wandb` logger which is not installed yet, install it with `pip install wandb`.

### Classification Student

In [13]:
# Train the ResNet-18 student as a plain classifier.
model = ResNet18Model()
module = ClassificationModule(model=model)
datamodule = CIFAR10DataModule(batch_size=4096, num_workers=4)
logger = WandbLogger(project='samogonka')
trainer = Trainer(
    # Use the GPU when one is visible, otherwise fall back to the CPU so the
    # cell still runs on a CPU-only runtime instead of failing at start-up.
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    logger=logger,
    log_every_n_steps=2,
    max_epochs=30,
)
trainer.fit(module, datamodule=datamodule)

NameError: name 'logger' is not defined

### Distillation

In [None]:
# Distil the teacher into the student.
# The logger is created here rather than reused from an earlier cell, so this
# cell also works when run on its own after a kernel restart.
logger = WandbLogger(project='samogonka')
trainer = Trainer(
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    logger=logger,
    log_every_n_steps=2,
    max_epochs=30,
)
teacher = ResNet50Model()
student = ResNet18Model()
module = DistillationModule(teacher=teacher, student=student)
datamodule = CIFAR10DataModule(batch_size=4096, num_workers=4)
trainer.fit(module, datamodule=datamodule)