# Factor Transfer

This notebook demonstrates an implementation of the paper [Paraphrasing Complex Network: Network Compression via Factor Transfer](https://arxiv.org/abs/1802.04977).

## Steps to transfer knowledge from a teacher to a student model

- Load dataset and create dataloaders
- Create teacher and student models and load pretrained weights of teacher model
- Load the configuration YAML file
- Create a `FactorTransfer` object and pass it the teacher model, the student model, the dataloaders and the configuration
- Transfer the knowledge to the student model by calling the `compress_model` method
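In the paper, the student is trained with its usual classification loss plus a factor-transfer term, weighted by a coefficient β, that compares the L2-normalized teacher and student factors. The following is a minimal sketch of that term for the p = 1 case. The function name `factor_transfer_loss` is illustrative and not part of trailmet, whose own implementation may differ.

```python
import torch
import torch.nn.functional as F


def factor_transfer_loss(student_factor, teacher_factor):
    """Mean absolute difference between L2-normalized factors (p = 1)."""
    # Flatten each sample to a vector and scale it to unit L2 norm,
    # so only the direction of the factor matters, not its magnitude.
    s = F.normalize(student_factor.flatten(1), p=2, dim=1)
    t = F.normalize(teacher_factor.flatten(1), p=2, dim=1)
    return F.l1_loss(s, t)


torch.manual_seed(0)
teacher_factor = torch.randn(4, 16, 2, 2)  # stand-in for paraphraser output
student_factor = torch.randn(4, 16, 2, 2)  # stand-in for translator output

ft = factor_transfer_loss(student_factor, teacher_factor)
# Rescaling a factor does not change the loss because of the normalization.
same = factor_transfer_loss(3.0 * teacher_factor, teacher_factor)
```

Under this sketch, the student objective would be `cross_entropy + beta * ft`, with `beta` presumably corresponding to the `BETA` entry of the configuration.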

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import sys
sys.path.append("../../../")

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url

from tqdm import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
import yaml

from trailmet.algorithms.distill.factor_transfer import FactorTransfer
from trailmet.datasets.classification import DatasetFactory
from trailmet.models import ModelsFactory

# Dataset

### Select device

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Augmentations

In [4]:
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

transforms1 = {
    'train': transform_train, 
    'val': transform_test, 
    'test': transform_test}

target_transforms = {
    'train': None, 
    'val': None, 
    'test': None}

### Load Dataset

In [5]:
cifar_dataset = DatasetFactory.create_dataset(name = 'CIFAR10', 
                                        root = "./data",
                                        split_types = ['train', 'val', 'test'],
                                        val_fraction = 0.1,
                                        transform = transforms1,
                                        target_transform = target_transforms,
                                        random_seed=42
                                        )

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Define data loaders

In [6]:
train_loader = torch.utils.data.DataLoader(
        cifar_dataset['train'], batch_size=128, 
        sampler=cifar_dataset['train_sampler'],
        num_workers=2
    )

val_loader = torch.utils.data.DataLoader(
        cifar_dataset['val'], batch_size=128, 
        sampler=cifar_dataset['val_sampler'],
        num_workers=2
    )

test_loader = torch.utils.data.DataLoader(
        cifar_dataset['test'], batch_size=128, 
        sampler=cifar_dataset['test_sampler'],
        num_workers=2
    )

dataloaders = {'train':train_loader, 'val':val_loader, 'test':test_loader}
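As a quick sanity check on the loaders, the number of batches per epoch can be worked out by hand. The sketch below assumes the validation split is carved out of the 50,000 official CIFAR-10 training images (so `val_fraction = 0.1` leaves 45,000 for training) and that the last partial batch is kept, which is the `DataLoader` default.

```python
import math

CIFAR10_TRAIN_IMAGES = 50_000  # official CIFAR-10 training set size
CIFAR10_TEST_IMAGES = 10_000   # official CIFAR-10 test set size
VAL_FRACTION = 0.1
BATCH_SIZE = 128

# Assumption: the validation images are taken from the training set.
val_images = round(CIFAR10_TRAIN_IMAGES * VAL_FRACTION)
train_images = CIFAR10_TRAIN_IMAGES - val_images


def steps_per_epoch(num_images, batch_size=BATCH_SIZE):
    # drop_last=False keeps the final, smaller batch.
    return math.ceil(num_images / batch_size)


train_steps = steps_per_epoch(train_images)        # 352
val_steps = steps_per_epoch(val_images)            # 40
test_steps = steps_per_epoch(CIFAR10_TEST_IMAGES)  # 79
print(train_steps, val_steps, test_steps)
```

The 352 training steps and 40 validation steps agree with the step counts shown in the training log at the end of this notebook.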

### Load configurations for training the student model

The configuration contains the distillation arguments: training parameters such as the total number of epochs, the learning rate and the milestones, as well as the names of the teacher and student layers used for Factor Transfer.

__Note:__ The student is trained for only 5 epochs here for demonstration purposes.
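One consequence of the short demo run is worth spelling out. The sketch below assumes that `MILESTONES` and `GAMMA` drive a step decay in the style of PyTorch's `MultiStepLR` (multiply the learning rate by `GAMMA` at each milestone epoch); this is an assumption about trailmet, not something the notebook confirms. The helper `multistep_lr` is hypothetical and uses the values from the configuration file loaded in the next cell.

```python
def multistep_lr(epoch, base_lr=0.1, milestones=(82, 123), gamma=0.1):
    """Learning rate at a 0-based `epoch` under step decay at each milestone."""
    # Count how many milestones have already been reached.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# 5-epoch demo run: no milestone is reached, so the rate never decays.
demo_lrs = [multistep_lr(e) for e in range(5)]

# A longer run would decay at epochs 82 and 123.
full_lrs = {e: multistep_lr(e) for e in (0, 82, 123)}
print(demo_lrs)
print(full_lrs)
```

Under that assumption, the demo trains at a constant learning rate of 0.1, and `EPOCHS` would need to exceed the milestones for the schedule to take effect.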

In [9]:
with open("./resnet50-resnet18.yaml", 'r') as stream:
    data_loaded = yaml.safe_load(stream)
data_loaded

{'DISTILL_ARGS': {'BETA': 500,
  'EPOCHS': 5,
  'LR': 0.1,
  'WEIGHT_DECAY': 0.0005,
  'MILESTONES': [82, 123],
  'GAMMA': 0.1,
  'IN_PLANES': 512,
  'RATE': 2,
  'TEACHER_LAYER_NAME': 'layer4',
  'STUDENT_LAYER_NAME': 'layer4',
  'VERBOSE': True},
 'PARAPHRASER': {'IN_PLANES': 2048,
  'RATE': 0.5,
  'EPOCHS': 1,
  'LR': 0.1,
  'WEIGHT_DECAY': 0.0005},
 'log_dir': 'ft_resnet50-resnet18',
 'DEVICE': 'cuda',
 'insize': 32,
 'wandb': True}
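The two `IN_PLANES`/`RATE` pairs in this configuration look chosen so that teacher and student factors end up with the same number of channels: 2048 × 0.5 = 1024 on the paraphraser side (ResNet-50 `layer4`) and 512 × 2 = 1024 on the student side (ResNet-18 `layer4`). The sketch below illustrates this with simplified single-layer stand-ins and channel counts scaled down by 64 to keep it light. It assumes that the factor has `in_planes * rate` channels; `TinyParaphraser` and `TinyTranslator` are hypothetical names, not trailmet classes, and the paper's modules use more layers.

```python
import torch
import torch.nn as nn


def factor_channels(in_planes, rate):
    # Assumption: the factor has in_planes * rate channels.
    return int(in_planes * rate)


class TinyParaphraser(nn.Module):
    """Simplified convolutional autoencoder over teacher features."""

    def __init__(self, in_planes, rate):
        super().__init__()
        mid = factor_channels(in_planes, rate)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_planes, mid, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(mid, in_planes, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        factor = self.encoder(x)
        # Return the factor and the reconstruction of the input features.
        return factor, self.decoder(factor)


class TinyTranslator(nn.Module):
    """Simplified encoder mapping student features to a student factor."""

    def __init__(self, in_planes, rate):
        super().__init__()
        mid = factor_channels(in_planes, rate)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_planes, mid, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        return self.encoder(x)


# Channel counts divided by 64: 2048 -> 32 (teacher), 512 -> 8 (student).
teacher_feat = torch.randn(2, 32, 4, 4)
student_feat = torch.randn(2, 8, 4, 4)

paraphraser = TinyParaphraser(in_planes=32, rate=0.5)
translator = TinyTranslator(in_planes=8, rate=2)

teacher_factor, reconstruction = paraphraser(teacher_feat)
student_factor = translator(student_feat)
print(teacher_factor.shape, student_factor.shape)
```

Because both factors share a shape, they can be compared element-wise by the factor-transfer loss, while the reconstruction is what the paraphraser itself is trained on.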

# Model

### Create the teacher and student models

In [10]:
# CIFAR-10 has 10 classes, so both classification heads use num_classes=10
teacher_model = ModelsFactory.create_model(name='resnet50', num_classes=10, version="original", pretrained=False, **data_loaded)
student_model = ModelsFactory.create_model(name='resnet18', num_classes=10, version="original", pretrained=False, **data_loaded)

teacher_model.to(device)
student_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activ): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
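The configuration selects layers by name (`TEACHER_LAYER_NAME` and `STUDENT_LAYER_NAME` are both `layer4`), and those names are the module names that appear in a printout like the one above. One common way to read the output of a named layer in PyTorch is a forward hook. Whether trailmet does it this way is an assumption; the helper `capture_layer_output` and the toy model below are purely illustrative.

```python
import torch
import torch.nn as nn


def capture_layer_output(model, layer_name, store):
    """Register a hook that saves the output of `layer_name` into `store`."""
    layer = dict(model.named_modules())[layer_name]

    def hook(module, inputs, output):
        # Keep the tensor attached to the graph so a student can backpropagate;
        # for a frozen teacher, output.detach() would be appropriate instead.
        store[layer_name] = output

    return layer.register_forward_hook(hook)


# Toy stand-in model that reuses the layer naming of the ResNet above.
toy = nn.Sequential()
toy.add_module("layer1", nn.Conv2d(3, 4, kernel_size=3, padding=1))
toy.add_module("layer4", nn.Conv2d(4, 8, kernel_size=3, padding=1))

features = {}
handle = capture_layer_output(toy, "layer4", features)
_ = toy(torch.randn(2, 3, 8, 8))
handle.remove()  # detach the hook once it is no longer needed

print(features["layer4"].shape)
```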

# Training

In [11]:
distillation_box = FactorTransfer(teacher_model, student_model, dataloaders, **data_loaded)

[34m[1mwandb[0m: Currently logged in as: [33manimesh-007[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
distillation_box.compress_model()

=====> TRAINING PARAPHRASER <=====


Training paraphraser Epoch [0] (352 / 352 Steps) (batch time=0.03953s) (data time=0.00195s) (loss=nan): 100%|| 352/352 [00:17<00:00, 20.62it/s]              


=====> TRAINING STUDENT NETWORK <=====


Training student network Epoch [0] (352 / 352 Steps) (batch time=0.03301s) (data time=0.00273s) (loss=nan): 100%|| 352/352 [00:13<00:00, 25.34it/s]
Validating student network Epoch [0] (40 / 40 Steps) (batch time=0.01934s) (loss=nan) (top1=50.00000) (top5=75.00000): 100%|| 40/40 [00:01<00:00, 32.21it/s]


 * acc@1 43.060 acc@5 90.160


Training student network Epoch [1] (352 / 352 Steps) (batch time=0.03138s) (data time=0.00216s) (loss=nan): 100%|| 352/352 [00:13<00:00, 26.63it/s]
Validating student network Epoch [1] (40 / 40 Steps) (batch time=0.02281s) (loss=nan) (top1=75.00000) (top5=100.00000): 100%|| 40/40 [00:01<00:00, 31.59it/s]


 * acc@1 53.040 acc@5 94.400


Training student network Epoch [2] (352 / 352 Steps) (batch time=0.03378s) (data time=0.00341s) (loss=nan): 100%|| 352/352 [00:13<00:00, 26.40it/s]
Validating student network Epoch [2] (40 / 40 Steps) (batch time=0.01816s) (loss=nan) (top1=75.00000) (top5=100.00000): 100%|| 40/40 [00:01<00:00, 34.03it/s]


 * acc@1 63.380 acc@5 96.460


Training student network Epoch [3] (352 / 352 Steps) (batch time=0.03521s) (data time=0.00269s) (loss=nan): 100%|| 352/352 [00:13<00:00, 26.68it/s]
Validating student network Epoch [3] (40 / 40 Steps) (batch time=0.01729s) (loss=nan) (top1=75.00000) (top5=100.00000): 100%|| 40/40 [00:01<00:00, 32.44it/s]


 * acc@1 65.140 acc@5 97.480


Training student network Epoch [4] (352 / 352 Steps) (batch time=0.03298s) (data time=0.00200s) (loss=nan): 100%|| 352/352 [00:13<00:00, 26.96it/s]
Validating student network Epoch [4] (40 / 40 Steps) (batch time=0.01989s) (loss=nan) (top1=75.00000) (top5=100.00000): 100%|| 40/40 [00:01<00:00, 30.70it/s]


 * acc@1 66.780 acc@5 96.980
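The `acc@1` and `acc@5` figures in the log are top-1 and top-5 accuracies in percent: a prediction counts as correct at k if the true label is among the k highest-scoring classes. The following is a minimal sketch of that computation; `topk_accuracy` is an illustrative helper, not the function trailmet uses internally.

```python
import torch


def topk_accuracy(logits, targets, ks=(1, 5)):
    """Return the top-k accuracy in percent for every k in `ks`."""
    max_k = max(ks)
    # Indices of the max_k highest-scoring classes per sample, best first.
    _, pred = logits.topk(max_k, dim=1)
    hits = pred.eq(targets.unsqueeze(1))  # shape (N, max_k), True at a match
    return [hits[:, :k].any(dim=1).float().mean().item() * 100.0 for k in ks]


logits = torch.tensor(
    [
        [0.10, 0.90, 0.00],  # label 1 ranked first
        [0.80, 0.15, 0.05],  # label 1 ranked second
        [0.20, 0.30, 0.50],  # label 0 ranked third
    ]
)
targets = torch.tensor([1, 1, 0])

top1, top2 = topk_accuracy(logits, targets, ks=(1, 2))
print(f"acc@1 {top1:.3f} acc@2 {top2:.3f}")
```

With only 10 classes in CIFAR-10, top-5 accuracy is a lenient metric, which is consistent with `acc@5` being above 90% after the first epoch while `acc@1` is still at 43%.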
