In [2]:
from pathlib import Path
from typing import Callable, List, Optional

import numpy as np
import pytorch_lightning as pl
import torch

# Report whether PyTorch can use a CUDA device in this environment.
torch.cuda.is_available()

False

# DataModule

In [1]:
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import CIFAR10

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class CIFAR10Dataset(Dataset):
    """Wrap a dataset subset and apply an optional transform to each image.

    Two transform calling conventions are supported:

    * keyword style -- invoked as ``transform(image=image)`` and returning a
      dict that stores the result under the ``"image"`` key;
    * positional style -- invoked as ``transform(image)`` and returning the
      transformed image directly (this is what
      ``CIFAR10DataModule.default_transforms`` produces).

    Args:
        subset: Indexable collection whose items are ``(image, label)`` pairs.
        transform: Optional callable applied to the image only; the label is
            always returned unchanged.
    """

    def __init__(
        self,
        subset: Subset,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.subset = subset
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, transformed if configured."""
        image, label = self.subset[idx]

        if self.transform is not None:
            image = self._apply_transform(image)

        return image, label

    def _apply_transform(self, image):
        """Run ``self.transform`` on ``image`` using whichever convention it accepts."""
        try:
            # Keyword style is tried first to stay compatible with transforms
            # written for the original ``transform(image=image)`` call.
            transformed = self.transform(image=image)
        except TypeError:
            # Positional-only callables reject the ``image=`` keyword.
            # NOTE: a TypeError raised *inside* a keyword-style transform also
            # lands here; the retry will then fail again and the original
            # error stays visible in the chained traceback.
            transformed = self.transform(image)

        # Dict-returning pipelines keep the image under the "image" key.
        if isinstance(transformed, dict):
            return transformed["image"]
        return transformed

In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule providing CIFAR-10 train/validation/test loaders.

    The official training set is split into train and validation parts with a
    class-stratified split; the official test set is used for testing.

    Args:
        data_dir: Directory where CIFAR-10 is downloaded to / read from.
        batch_size: Batch size used by all three loaders.
        num_workers: Number of worker processes used by all three loaders.
        shuffle: Whether the *training* loader shuffles its samples. Note that
            the default is ``False``.
        train_transforms: Transform for training images; falls back to
            ``default_transforms()`` when ``None``.
        val_transforms: Transform for validation images; falls back to
            ``default_transforms()`` when ``None``.
        test_transforms: Transform for test images; falls back to
            ``default_transforms()`` when ``None``.
        val_size: Share of the training set held out for validation (passed to
            ``train_test_split`` as ``test_size``).
        split_seed: Seed for the train/validation split so that it is
            reproducible and identical in every process that calls ``setup``.
            Pass ``None`` for a different random split on every call.
    """

    def __init__(
        self,
        data_dir: str = './data/',
        batch_size: int = 8,
        num_workers: int = 4,
        shuffle: bool = False,
        train_transforms: Optional[Callable] = None,
        val_transforms: Optional[Callable] = None,
        test_transforms: Optional[Callable] = None,
        val_size: float = 0.25,
        split_seed: Optional[int] = 42,
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        # NOTE(review): the notebook prints three deprecation warnings when
        # this class is instantiated; they possibly stem from these three
        # attribute names clashing with Lightning's own transform properties
        # -- confirm against the installed pytorch_lightning version.
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms
        self.val_size = val_size
        self.split_seed = split_seed

    def prepare_data(self) -> None:
        """Create ``data_dir`` if needed and download both CIFAR-10 splits."""
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
        CIFAR10(root=self.data_dir, train=True, download=True)
        CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage``.

        ``'fit'`` creates ``train_dataset`` and ``val_dataset``; ``'test'``
        creates ``test_dataset``; ``None`` creates all three.
        """
        if stage == 'fit' or stage is None:
            trainval_dataset = CIFAR10(
                root=self.data_dir,
                train=True,
            )
            # `stratify` must be the per-sample class labels (not a boolean)
            # so that both parts keep the class proportions of the full set.
            train_indices, val_indices = train_test_split(
                np.arange(len(trainval_dataset)),
                test_size=self.val_size,
                stratify=trainval_dataset.targets,
                random_state=self.split_seed,
            )
            train_subset = Subset(trainval_dataset, train_indices)
            val_subset = Subset(trainval_dataset, val_indices)

            train_transforms = self.default_transforms() \
                if self.train_transforms is None else self.train_transforms
            val_transforms = self.default_transforms() \
                if self.val_transforms is None else self.val_transforms
            # The base dataset is created without a transform; the wrapper
            # applies a separate transform to each of the two subsets.
            self.train_dataset = \
                CIFAR10Dataset(train_subset, transform=train_transforms)
            self.val_dataset = \
                CIFAR10Dataset(val_subset, transform=val_transforms)

        if stage == 'test' or stage is None:
            test_transforms = self.default_transforms() \
                if self.test_transforms is None else self.test_transforms
            self.test_dataset = CIFAR10(
                root=self.data_dir,
                train=False,
                transform=test_transforms,
            )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader (shuffled only if ``shuffle`` is set)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test loader (never shuffled)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def default_transforms(self) -> Callable:
        """Return the fallback transform: conversion to a tensor only."""
        return transforms.Compose([
            transforms.ToTensor(),
        ])


# Instantiate the data module with every constructor argument left at its default.
dm = CIFAR10DataModule()

  rank_zero_deprecation(
  rank_zero_deprecation(
  rank_zero_deprecation(


# Models

In [33]:
import einops
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange
from torchvision.models import resnet18, resnet50, ResNet50_Weights

### Teacher

In [37]:
class ResNet50Model(nn.Module):
    """Teacher network: ImageNet-pretrained ResNet-50 backbone + 10-way head.

    The backbone keeps every child module of torchvision's ResNet-50 except
    the final two (in torchvision's ResNet these are the global average pool
    and the ``fc`` classifier), so it outputs a spatial feature map. The head
    pools that map to 1x1, flattens it to ``(bs, c)`` and applies a linear
    layer from 2048 features to 10 outputs.
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Drop the last two children so the backbone returns feature maps.
        self.feature_extractor = nn.Sequential(*(list(backbone.children())[:-2]))
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, input_mode: str = 'images'):
        """Compute the head outputs.

        Args:
            x: Input images when ``input_mode == 'images'``, or feature maps
                already produced by ``extract_features`` when
                ``input_mode == 'features'``.
            input_mode: ``'images'`` (default) or ``'features'``.

        Returns:
            The output of ``self.head``.

        Raises:
            ValueError: If ``input_mode`` is not one of the two supported
                values (previously ``None`` was silently fed to the head).
        """
        if input_mode == 'images':
            features = self.extract_features(x)
        elif input_mode == 'features':
            features = x
        else:
            raise ValueError(
                f"input_mode must be 'images' or 'features', got {input_mode!r}"
            )

        return self.head(features)

    def extract_features(self, x):
        """Return the backbone feature maps for the input images ``x``."""
        return self.feature_extractor(x)

### Student

In [38]:
class ResNet18Model(nn.Module):
    """Student network: randomly initialised ResNet-18 backbone + neck + head.

    The backbone keeps every child module of torchvision's ResNet-18 except
    the final two (in torchvision's ResNet these are the global average pool
    and the ``fc`` classifier). The neck is a 1x1 convolution followed by
    batch norm that widens the 512-channel backbone output to 2048 channels,
    the width the head's linear layer expects. The head pools to 1x1,
    flattens to ``(bs, c)`` and maps 2048 features to 10 outputs.
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = resnet18(weights=None)
        # Drop the last two children so the backbone returns feature maps.
        self.feature_extractor = nn.Sequential(*(list(backbone.children())[:-2]))
        self.neck = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=2048, kernel_size=1),
            nn.BatchNorm2d(2048),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            Rearrange('bs c 1 1 -> bs c'),
            nn.Linear(in_features=2048, out_features=10, bias=True),
        )

    def forward(self, x, input_mode: str = 'images'):
        """Compute the head outputs.

        Args:
            x: Input images when ``input_mode == 'images'``, or feature maps
                already produced by ``extract_features`` when
                ``input_mode == 'features'``.
            input_mode: ``'images'`` (default) or ``'features'``.

        Returns:
            The output of ``self.head``.

        Raises:
            ValueError: If ``input_mode`` is not one of the two supported
                values (previously ``None`` was silently fed to the head).
        """
        if input_mode == 'images':
            features = self.extract_features(x)
        elif input_mode == 'features':
            features = x
        else:
            raise ValueError(
                f"input_mode must be 'images' or 'features', got {input_mode!r}"
            )

        return self.head(features)

    def extract_features(self, x):
        """Return the neck output computed from the backbone features of ``x``."""
        backbone_features = self.feature_extractor(x)
        return self.neck(backbone_features)


In [39]:
# Smoke test: student and teacher must produce identically shaped feature maps
# and identically shaped outputs for the same batch of images.
student = ResNet18Model()
teacher = ResNet50Model()
x = torch.randn(8, 3, 224, 224)

# Run the check in eval mode and without autograd: in training mode every
# forward pass would update the BatchNorm running statistics (including those
# of the pretrained teacher) with random noise, and gradients are not needed
# for a shape check.
student.eval()
teacher.eval()
try:
    with torch.no_grad():
        student_features = student.extract_features(x)
        teacher_features = teacher.extract_features(x)
        student_outputs = student(x)
        teacher_outputs = teacher(x)
finally:
    # Leave both models in training mode, as they were right after construction.
    student.train()
    teacher.train()

assert student_features.shape == teacher_features.shape, (
    f'feature shapes differ: student {tuple(student_features.shape)} '
    f'vs teacher {tuple(teacher_features.shape)}'
)
assert teacher_outputs.shape == student_outputs.shape, (
    f'output shapes differ: teacher {tuple(teacher_outputs.shape)} '
    f'vs student {tuple(student_outputs.shape)}'
)
print('Done!')

Done!
