In [3]:
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

# Report whether a CUDA device is visible to this kernel.
cuda_available = torch.cuda.is_available()
print(f'CUDA: {cuda_available}')

CUDA: False


# DataModule

In [4]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import CIFAR10

In [5]:
class CIFAR10Dataset(Dataset):
    """Dataset wrapper that applies a keyword-style image transform.

    Items are read from ``subset`` as ``(image, label)`` pairs. When a
    transform is set, it is invoked as ``transform(image=image)`` and the
    ``"image"`` entry of its result replaces the raw image (the calling
    convention used by albumentations pipelines).

    Args:
        subset: Sized, indexable collection of ``(image, label)`` pairs.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping containing an ``"image"`` key. When ``None``
            the raw image is returned untouched.
    """

    def __init__(
        self,
        subset: Subset,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.subset = subset
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped collection."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, transformed if set."""
        image, label = self.subset[idx]

        if self.transform is not None:
            # The wrapped dataset may yield non-array images (torchvision's
            # CIFAR10 returns PIL images), while albumentations transforms
            # operate on numpy arrays, so convert before calling.
            if not isinstance(image, np.ndarray):
                image = np.asarray(image)
            image = self.transform(image=image)["image"]

        return image, label

In [6]:
class CIFAR10DataModule(LightningDataModule):
    """LightningDataModule that serves CIFAR-10 train/val/test loaders.

    The official training set is split into train and validation parts with
    ``sklearn.model_selection.train_test_split``; the official test set is
    used as-is. Every split is wrapped in ``CIFAR10Dataset`` so that
    keyword-style (albumentations) transforms are applied uniformly.

    Args:
        data_dir: Directory where CIFAR-10 is downloaded / read from.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker process count used by all three loaders.
        shuffle: Whether the *training* loader shuffles. Validation and test
            loaders never shuffle. Note that the default is ``False``.
        train_transforms: Transform for the training split;
            ``default_transforms()`` is used when ``None``.
        val_transforms: Transform for the validation split;
            ``default_transforms()`` is used when ``None``.
        test_transforms: Transform for the test split;
            ``default_transforms()`` is used when ``None``.
        val_size: Forwarded as ``test_size`` to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``. ``None`` (the
            default) keeps the previous behaviour of drawing a fresh random
            split on every ``setup`` call.
    """

    def __init__(
        self,
        data_dir: str = './data/',
        batch_size: int = 8,
        num_workers: int = 4,
        shuffle: bool = False,
        train_transforms: Optional[Callable] = None,
        val_transforms: Optional[Callable] = None,
        test_transforms: Optional[Callable] = None,
        val_size: float = 0.25,
        random_state: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms
        self.val_size = val_size
        self.random_state = random_state

    def prepare_data(self) -> None:
        """Create ``data_dir`` and download both CIFAR-10 splits into it."""
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
        CIFAR10(root=self.data_dir, train=True, download=True)
        CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets required for ``stage``.

        ``'fit'``, ``'validate'`` and ``None`` build the train/val datasets;
        ``'test'`` and ``None`` build the test dataset.
        """
        # 'validate' also needs val_dataset, otherwise val_dataloader() would
        # fail with a missing attribute when only validation is run.
        if stage in ('fit', 'validate') or stage is None:
            trainval_dataset = CIFAR10(
                root=self.data_dir,
                train=True,
            )
            train_indices, val_indices = train_test_split(
                np.arange(len(trainval_dataset)),
                test_size=self.val_size,
                random_state=self.random_state,
            )
            train_subset = Subset(trainval_dataset, train_indices)
            val_subset = Subset(trainval_dataset, val_indices)

            train_transforms = self.default_transforms() \
                if self.train_transforms is None else self.train_transforms
            val_transforms = self.default_transforms() \
                if self.val_transforms is None else self.val_transforms
            self.train_dataset = \
                CIFAR10Dataset(train_subset, transform=train_transforms)
            self.val_dataset = \
                CIFAR10Dataset(val_subset, transform=val_transforms)

        if stage == 'test' or stage is None:
            test_transforms = self.default_transforms() \
                if self.test_transforms is None else self.test_transforms
            # torchvision datasets call ``transform(img)`` positionally, but
            # the transforms used here expect ``transform(image=...)``.
            # Route the test split through CIFAR10Dataset like train/val.
            test_dataset = CIFAR10(
                root=self.data_dir,
                train=False,
            )
            self.test_dataset = \
                CIFAR10Dataset(test_dataset, transform=test_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (shuffling controlled by ``shuffle``)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (never shuffled)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def default_transforms(self) -> Callable:
        """Return the fallback pipeline: normalise, then convert to tensor.

        The mean/std values are the commonly used ImageNet channel
        statistics, not statistics computed on CIFAR-10.
        """
        return A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])


dm = CIFAR10DataModule()

  rank_zero_deprecation(
  rank_zero_deprecation(
  rank_zero_deprecation(


# Models

In [7]:
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange
from torchvision.models import resnet18, resnet50, ResNet50_Weights

### Teacher

In [8]:
class ResNet50Model(nn.Module):
    """ResNet-50 feature extractor with a 10-way classification head.

    Args:
        pretrained: When true, the backbone is initialised with
            ``ResNet50_Weights.IMAGENET1K_V1``; otherwise no weights are
            loaded (random initialisation).
    """

    def __init__(self, pretrained: bool = False):
        super().__init__()
        # Previously this condition was inverted: ``pretrained=True`` gave
        # random weights and the default ``False`` loaded ImageNet weights.
        weights = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = resnet50(weights=weights)
        # Keep everything except the last two child modules so the extractor
        # outputs a spatial feature map; pooling and classification are
        # re-created in ``self.head``.
        self.feature_extractor = nn.Sequential(*(list(model.children())[:-2]))
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, input_mode: str = 'images'):
        """Run the head on images or on already-extracted features.

        Args:
            x: Input tensor; interpreted according to ``input_mode``.
            input_mode: ``'images'`` to run the feature extractor first, or
                ``'features'`` to feed ``x`` straight into the head.

        Returns:
            The output of ``self.head``.

        Raises:
            ValueError: If ``input_mode`` is not one of the two known modes.
        """
        if input_mode == 'images':
            features = self.extract_features(x)
        elif input_mode == 'features':
            features = x
        else:
            # Fail loudly instead of passing ``None`` into the head.
            raise ValueError(
                f"Unknown input_mode {input_mode!r}; "
                "expected 'images' or 'features'."
            )

        return self.head(features)

    def extract_features(self, x):
        """Return the feature extractor's output for ``x``."""
        return self.feature_extractor(x)

### Student

In [9]:
class ResNet18Model(nn.Module):
    """ResNet-18 feature extractor with a widening neck and 10-way head.

    The neck is a 1x1 convolution plus batch norm that maps the extractor's
    512 channels to 2048 channels, which is the input width of the head.
    """

    def __init__(self):
        super().__init__()
        model = resnet18(weights=None)
        # Keep everything except the last two child modules so the extractor
        # outputs a spatial feature map.
        self.feature_extractor = nn.Sequential(*(list(model.children())[:-2]))
        self.neck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=2048, kernel_size=1),
            nn.BatchNorm2d(2048),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, input_mode: str = 'images'):
        """Run the head on images or on already-extracted features.

        Args:
            x: Input tensor; interpreted according to ``input_mode``.
            input_mode: ``'images'`` to run ``extract_features`` first, or
                ``'features'`` to feed ``x`` straight into the head.

        Returns:
            The output of ``self.head``.

        Raises:
            ValueError: If ``input_mode`` is not one of the two known modes.
        """
        if input_mode == 'images':
            features = self.extract_features(x)
        elif input_mode == 'features':
            features = x
        else:
            # Fail loudly instead of passing ``None`` into the head.
            raise ValueError(
                f"Unknown input_mode {input_mode!r}; "
                "expected 'images' or 'features'."
            )

        return self.head(features)

    def extract_features(self, x):
        """Return the neck's output on top of the feature extractor."""
        x = self.feature_extractor(x)
        x = self.neck(x)

        return x


In [10]:
# Smoke test: student and teacher must agree on both the feature-map shape
# and the output shape for the same random input batch.
student = ResNet18Model()
teacher = ResNet50Model()
x = torch.randn(8, 3, 224, 224)

student_feature_shape = student.extract_features(x).shape
teacher_feature_shape = teacher.extract_features(x).shape
assert student_feature_shape == teacher_feature_shape

teacher_output_shape = teacher(x).shape
student_output_shape = student(x).shape
assert teacher_output_shape == student_output_shape

print('Done!')

Done!


# Modules

In [11]:
import torch.optim as optim
from pytorch_lightning import LightningModule
from torch import Tensor

### Classification Module

In [12]:
class ClassificationModule(LightningModule):
    """LightningModule that trains ``model`` with cross-entropy loss.

    Args:
        model: Network mapping a batch of images to class scores.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x: Tensor) -> Tensor:
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _compute_loss(self, batch: Any) -> Tensor:
        """Unpack an ``(images, labels)`` batch and return its loss."""
        images, labels = batch
        predicts = self.model(images)
        # The target of the criterion is the labels; the previous code passed
        # the input images here by mistake.
        return self.criterion(predicts, labels)

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the training loss for one batch.

        Returns:
            A dict whose ``'loss'`` entry is the tensor Lightning
            backpropagates in automatic optimization.
        """
        loss = self._compute_loss(batch)
        self.log('train/loss', loss)

        # Lightning requires the 'loss' key when a dict is returned.
        return {
            'loss': loss,
        }

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        # Logged under its own name so it is not reported as a training loss.
        self.log('val/loss', loss)

        return {
            'val/loss': loss,
        }

    def configure_optimizers(self):
        """Return an Adam optimizer over the model parameters (lr=3e-4)."""
        return optim.Adam(params=self.model.parameters(), lr=3e-4)

### Distillation Module

In [None]:
class DistillationModule(LightningModule):
    """Placeholder LightningModule for the distillation experiment.

    Not implemented yet: it defines no model, no training step and no
    optimizer. The base class is referenced as ``LightningModule`` because
    that is the name this notebook imports; ``pl`` is never imported.
    """

    def __init__(self):
        super().__init__()

    def configure_optimizers(self):
        # TODO: return an optimizer once the module has parameters to train.
        pass

# Training loops

In [13]:
from pytorch_lightning import Trainer

### Classification Teacher

In [14]:
# Fit the ResNet-50 classifier on CIFAR-10 with default Trainer settings.
datamodule = CIFAR10DataModule()
model = ResNet50Model()
module = ClassificationModule(model=model)

trainer = Trainer()
trainer.fit(model=module, datamodule=datamodule)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
  rank_zero_deprecation(
  rank_zero_deprecation(


Files already downloaded and verified
Files already downloaded and verified


Missing logger folder: /Users/sergevkim/work/samogonka/notebooks/lightning_logs


TypeError: Singleton array array(True) cannot be considered a valid collection.

### Classification Student

### Distillation