# Attention Transfer

This notebook demonstrates attention transfer, as described in the paper [Paying More Attention to Attention](https://arxiv.org/abs/1612.03928), using a ResNet-50 teacher and a ResNet-18 student on CIFAR-10.
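
For orientation, the activation-based attention transfer objective can be summarised as below. This is a paraphrase of the paper's formulation with squared activations assumed for the attention map; the `trailmet` implementation may differ in details such as the norm used or how the terms are averaged.

$$
Q^{j} = \mathrm{vec}\!\left(\sum_{c=1}^{C_j} \left| A^{j}_{c} \right|^{2}\right),
\qquad
\mathcal{L}_{AT} = \mathcal{L}_{CE}(W_S, x) + \frac{\beta}{2} \sum_{j \in \mathcal{I}} \left\lVert \frac{Q_S^{j}}{\lVert Q_S^{j} \rVert_2} - \frac{Q_T^{j}}{\lVert Q_T^{j} \rVert_2} \right\rVert_2
$$

Here $A^{j}_{c}$ is channel $c$ of the activation tensor at layer pair $j$, $\mathcal{I}$ is the set of teacher-student layer pairs, and $\beta$ weights the attention term against the student's cross-entropy loss. The `BETA` value and the layer-name lists in the configuration used later presumably play the roles of $\beta$ and $\mathcal{I}$.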

## Steps to transfer knowledge from a teacher to a student model

- Load the dataset and create the dataloaders
- Create the teacher and student models, and load pretrained weights for the teacher model
- Load the configuration YAML file
- Create an `AttentionTransfer` object, passing it the teacher model, student model, dataloaders, and configuration
- Transfer the knowledge to the student model with the `compress_model` method

In [1]:
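import os
# Order GPUs by PCI bus ID and expose only physical GPU 1 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys
# Add the directory three levels up to the import path so `trailmet` can be found.
sys.path.append("../../../")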

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

from tqdm import tqdm
import numpy as np
import yaml

from trailmet.algorithms.distill.attention_transfer import AttentionTransfer
from trailmet.datasets.classification import DatasetFactory
from trailmet.models import ModelsFactory
from trailmet.utils import accuracy



# Dataset

In [2]:
# Only one GPU is visible (see CUDA_VISIBLE_DEVICES above), so 'cuda:0' refers to it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Augmentations

In [3]:
transform_train = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Pad(4, padding_mode='reflect'),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomCrop(32),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                ]
            )

transform_test = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                ]
            )

transforms1 = {
    'train': transform_train, 
    'val': transform_test, 
    'test': transform_test}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}

### Load Dataset

In [4]:
cifar_dataset = DatasetFactory.create_dataset(name = 'CIFAR10', 
                                        root = "./data",
                                        split_types = ['train', 'val', 'test'],
                                        val_fraction = 0.1,
                                        transform = transforms1,
                                        target_transform = target_transforms,
                                        random_seed=42
                                        )

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Define data loaders

In [5]:
train_loader = torch.utils.data.DataLoader(
        cifar_dataset['train'], batch_size=128, 
        sampler=cifar_dataset['train_sampler'],
        num_workers=2
    )

val_loader = torch.utils.data.DataLoader(
        cifar_dataset['val'], batch_size=128, 
        sampler=cifar_dataset['val_sampler'],
        num_workers=2
    )

test_loader = torch.utils.data.DataLoader(
        cifar_dataset['test'], batch_size=128, 
        sampler=cifar_dataset['test_sampler'],
        num_workers=2
    )

dataloaders = {'train':train_loader, 'val':val_loader, 'test':test_loader}
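
The split sizes and batch size above determine how many steps each epoch takes. The arithmetic below is a small sketch, assuming the standard 50,000-image CIFAR-10 training set and that `val_fraction` is applied to it directly; the constant names are illustrative and not part of `trailmet`.

```python
import math

# Assumptions: standard CIFAR-10 training set size, split by val_fraction.
TRAIN_IMAGES = 50_000
VAL_FRACTION = 0.1
BATCH_SIZE = 128

n_val = round(TRAIN_IMAGES * VAL_FRACTION)   # 5,000 validation images
n_train = TRAIN_IMAGES - n_val               # 45,000 training images

# DataLoader keeps the final partial batch by default (drop_last=False),
# so steps per epoch is the ceiling of samples / batch size.
train_steps = math.ceil(n_train / BATCH_SIZE)
val_steps = math.ceil(n_val / BATCH_SIZE)

# Size of the final, partial validation batch.
last_val_batch = n_val - (val_steps - 1) * BATCH_SIZE

print(train_steps, val_steps, last_val_batch)  # 352 40 8
```

These counts agree with the `352 / 352` and `40 / 40` step counters in the training log further down. The final validation batch holds only 8 images, which is consistent with the per-batch `top1` values in the progress bars being multiples of 12.5, while the `* acc@1` lines summarise the whole validation set.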

# Training

### Load configurations for training the student model

The configuration holds the distillation arguments: training parameters such as the total number of epochs, the learning rate, and the learning-rate milestones, as well as the names of the teacher and student layers involved in attention transfer.

__Note:__ Training runs for only 5 epochs here, for demonstration purposes.

In [6]:
with open("./resnet50-resnet18.yaml", 'r') as stream:
    data_loaded = yaml.safe_load(stream)
data_loaded

{'DISTILL_ARGS': {'BETA': 500,
  'EPOCHS': 5,
  'LR': 0.1,
  'WEIGHT_DECAY': 0.0005,
  'MILESTONES': [82, 123],
  'GAMMA': 0.1,
  'TEACHER_LAYER_NAMES': ['layer1', 'layer2', 'layer3', 'layer4'],
  'STUDENT_LAYER_NAMES': ['layer1', 'layer2', 'layer3', 'layer4'],
  'VERBOSE': True},
 'log_dir': 'at_resnet50-resnet18',
 'wandb': True,
 'insize': 32,
 'DEVICE': 'cuda:0'}
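
One consequence of this configuration is worth spelling out. If `MILESTONES` and `GAMMA` drive a step schedule of the kind implemented by PyTorch's `MultiStepLR` (an assumption; the notebook does not show which scheduler `AttentionTransfer` uses), then with `EPOCHS` set to 5 the learning rate never reaches its first decay. The helper `multistep_lr` below is hypothetical and written only to illustrate this.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate at `epoch` when the rate is multiplied by `gamma` at each milestone."""
    # bisect_right counts how many milestones are <= epoch.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Values copied from the DISTILL_ARGS shown above.
cfg = {"LR": 0.1, "MILESTONES": [82, 123], "GAMMA": 0.1, "EPOCHS": 5}

schedule = [
    multistep_lr(cfg["LR"], cfg["MILESTONES"], cfg["GAMMA"], epoch)
    for epoch in range(cfg["EPOCHS"])
]
print(schedule)  # [0.1, 0.1, 0.1, 0.1, 0.1]
```

Under that assumption, the first decay would only take effect at epoch 82, so a longer run is needed before the milestones matter.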

# Model

### Create the teacher and student models

In [7]:
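# CIFAR-10 has only 10 classes; num_classes=100 still trains, but 10 would match the dataset.
# pretrained=False means no pretrained weights are requested here for either model.
teacher_model = ModelsFactory.create_model(name='resnet50', num_classes=100, version="original", pretrained=False, **data_loaded)
student_model = ModelsFactory.create_model(name='resnet18', num_classes=100, version="original", pretrained=False, **data_loaded)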

teacher_model.to(device)
student_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activ): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)

### Training student model
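
Before running the training, it may help to see what attention transfer compares at each pair of layers. The sketch below follows the paper's activation-based definition (sum of squared activations over channels, then L2 normalisation) in plain Python on nested lists, purely for illustration. `attention_map` and `at_loss` are hypothetical helpers, not `trailmet` functions; a real implementation works on batched tensors and may normalise or average differently.

```python
import math


def attention_map(feature):
    """feature: nested list [C][H][W] -> flat, L2-normalised list of length H*W."""
    channels = len(feature)
    height, width = len(feature[0]), len(feature[0][0])
    # Sum of squared activations over channels at every spatial position.
    flat = [
        sum(feature[c][i][j] ** 2 for c in range(channels))
        for i in range(height)
        for j in range(width)
    ]
    norm = math.sqrt(sum(v * v for v in flat)) or 1.0  # guard against an all-zero map
    return [v / norm for v in flat]


def at_loss(student_feature, teacher_feature, beta):
    """beta / 2 times the L2 distance between the two normalised attention maps."""
    q_s = attention_map(student_feature)
    q_t = attention_map(teacher_feature)
    return beta / 2 * math.sqrt(sum((s - t) ** 2 for s, t in zip(q_s, q_t)))


# The student has 1 channel and the teacher 2, but both are 2x2 spatially.
student = [[[1.0, 2.0], [3.0, 4.0]]]
teacher_same_focus = [[[2.0, 4.0], [6.0, 8.0]], [[0.0, 0.0], [0.0, 0.0]]]
teacher_other_focus = [[[4.0, 3.0], [2.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]

loss_same = at_loss(student, teacher_same_focus, beta=500)
loss_other = at_loss(student, teacher_other_focus, beta=500)
print(loss_same < 1e-9, loss_other > loss_same)  # True True
```

Because the channel dimension is summed away, the teacher and student layers only need matching spatial sizes, not matching channel counts. This is what makes it possible to pair `layer1` to `layer4` of ResNet-50 with the same-named layers of ResNet-18.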

In [8]:
distillation_box = AttentionTransfer(teacher_model, student_model, dataloaders, **data_loaded)

wandb: Currently logged in as: animesh-007. Use `wandb login --relogin` to force relogin


In [9]:
distillation_box.compress_model()

=====> TRAINING STUDENT NETWORK <=====


Training student network Epoch [0] (352 / 352 Steps) (batch time=0.05071s) (data time=0.00202s) (loss=2.29474): 100%|| 352/352 [00:19<00:00, 17.66it/s]
Validating student network Epoch [0] (40 / 40 Steps) (batch time=0.01823s) (loss=1.98314) (top1=25.00000) (top5=87.50000): 100%|| 40/40 [00:01<00:00, 33.49it/s]


 * acc@1 19.440 acc@5 82.340


Training student network Epoch [1] (352 / 352 Steps) (batch time=0.04935s) (data time=0.00363s) (loss=1.97412): 100%|| 352/352 [00:18<00:00, 18.70it/s]
Validating student network Epoch [1] (40 / 40 Steps) (batch time=0.01691s) (loss=1.74551) (top1=37.50000) (top5=100.00000): 100%|| 40/40 [00:01<00:00, 33.26it/s]


 * acc@1 30.880 acc@5 86.520


Training student network Epoch [2] (352 / 352 Steps) (batch time=0.04837s) (data time=0.00243s) (loss=1.65647): 100%|| 352/352 [00:18<00:00, 19.07it/s]
Validating student network Epoch [2] (40 / 40 Steps) (batch time=0.01680s) (loss=2.25188) (top1=12.50000) (top5=75.00000): 100%|| 40/40 [00:01<00:00, 35.75it/s]


 * acc@1 36.300 acc@5 85.800


Training student network Epoch [3] (352 / 352 Steps) (batch time=0.04921s) (data time=0.00337s) (loss=1.56491): 100%|| 352/352 [00:19<00:00, 18.46it/s]
Validating student network Epoch [3] (40 / 40 Steps) (batch time=0.01689s) (loss=1.94252) (top1=50.00000) (top5=87.50000): 100%|| 40/40 [00:01<00:00, 35.66it/s]


 * acc@1 43.560 acc@5 92.000


Training student network Epoch [4] (352 / 352 Steps) (batch time=0.04672s) (data time=0.00211s) (loss=1.41725): 100%|| 352/352 [00:18<00:00, 18.99it/s]
Validating student network Epoch [4] (40 / 40 Steps) (batch time=0.01688s) (loss=1.08645) (top1=75.00000) (top5=87.50000): 100%|| 40/40 [00:01<00:00, 34.26it/s]


 * acc@1 49.920 acc@5 95.520
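
The `top1`/`top5` and `acc@1`/`acc@5` figures in the log are top-k accuracies: the percentage of samples whose true label is among the k highest-scoring classes. A minimal plain-Python sketch of that metric follows. `topk_accuracy` is a hypothetical helper written for illustration; it is not the `trailmet.utils.accuracy` function imported at the top of the notebook, whose exact signature is not shown here.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Return {k: percentage of samples whose target is among the k highest scores}.

    scores: list of per-class score lists, one per sample.
    targets: list of integer class labels, one per sample.
    """
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}


# Two samples over six classes: the first is right at top-1,
# the second only has its label within the top 5.
scores = [
    [0.10, 0.50, 0.20, 0.05, 0.10, 0.05],
    [0.30, 0.25, 0.20, 0.15, 0.07, 0.03],
]
targets = [1, 2]

result = topk_accuracy(scores, targets)
print(result)  # {1: 50.0, 5: 100.0}
```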
