# DataLoader Setup

In [None]:
import functools
import os

import torch
import torchvision
from pl_bolts.datamodules import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import seed_everything

PATH_DATASETS = os.environ.get('PATH_DATASETS', '../datasets')
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 0 workers (load in the main process) instead of raising TypeError.
NUM_WORKERS = (os.cpu_count() or 0) // 2
TEST_ONLY = False

# Shared preprocessing for every split: to tensor, CIFAR-10 normalization, clamp
# to [-1, 1], scale by 127 and floor -> integer-valued floats in [-127, 127].
# functools.partial over torch functions (instead of lambdas) keeps these
# transforms picklable, which DataLoader worker processes require under the
# "spawn" start method (Windows, macOS) whenever NUM_WORKERS > 0.
base_transforms = [
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
    torchvision.transforms.Lambda(functools.partial(torch.clamp, min=-1, max=1)),
    torchvision.transforms.Lambda(functools.partial(torch.mul, other=127)),
    torchvision.transforms.Lambda(torch.floor),
]

# Training prepends augmentation (random 32x32 crop with 4px padding, random
# horizontal flip), applied before ToTensor, to the shared preprocessing.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
    ]
    + base_transforms
)

# Validation and test use only the shared preprocessing.
test_transforms = torchvision.transforms.Compose(base_transforms)

cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)

# Seed the global RNGs (and, with workers=True, the dataloader workers) before
# the model is built and trained in the following cells.
seed_everything(seed=1234, workers=True)

# Model

In [None]:
from model import Model

# Pretrained ResNet-44 for CIFAR-10 from torch.hub, handed to Model as its teacher.
# NOTE(review): how Model uses the teacher (eval mode, frozen weights, expected
# input scaling given the x127/floor preprocessing) is not visible here — confirm
# in model.py.
teacher_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet44", pretrained=True)

# Earlier structure variants, kept for reference:
# model_structure = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
# model_structure = [64] * 2 + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2

# Layer specs consumed by Model. The stem sets out_channels explicitly and has no
# kernel_size key (presumably Model applies a default — confirm). The remaining
# specs list only input channels; per the original shape notes, Model appears to
# double the channels on stride-2 ("reduction") layers — TODO confirm.
# Spatial sizes below are taken from those notes for 32x32 inputs.
model_structure = [
    {'in_channels': 3, 'out_channels': 128, 'stride': 2, 'padding': 1},  # -> 16x16
] + [
    {'in_channels': channels, 'kernel_size': 3, 'stride': stride, 'padding': 1}
    for channels, stride in (
        [(128, 2), (256, 2)]   # reductions: -> 8x8, -> 4x4
        + [(512, 1)] * 6       # normal layers at 4x4
        + [(512, 2)]           # reduction: -> 2x2
        + [(1024, 1)] * 3      # normal layers at 2x2
    )
]

react_model = Model(
    structure=model_structure,
    # Adam optimizer settings.
    adam_init_lr=1e-2,
    adam_weight_decay=0,
    adam_betas=(0.9, 0.999),
    # Learning-rate reduction settings (factor and patience).
    lr_reduce_factor=0.1,
    lr_patience=20,
    limit_conv_weight=True,
    limit_bn_weight=True,
    teacher_model=teacher_model,
)

print(react_model.hparams)
print(react_model)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Keep the checkpoint with the highest validation accuracy; its filename records
# the epoch, val_loss and val_acc. The last cell prints best_model_path/score.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{val_loss:.4f}-{val_acc:.4f}',
    monitor='val_acc',
    mode='max',
)

trainer = Trainer(
    # -1 removes the epoch limit; training ends via EarlyStopping once val_acc
    # has not improved for 100 validation runs.
    max_epochs=-1,
    gpus=AVAIL_GPUS,
    # Run name mirrors the channel widths of model_structure.
    # NOTE(review): log_graph=True needs an example input on the model
    # (e.g. example_input_array) — confirm Model defines one.
    logger=TensorBoardLogger('lightning_logs/', name='128-256-512-1024', log_graph=True),
    # logger=TensorBoardLogger('lightning_logs/', name='Real', log_graph=True),
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        EarlyStopping(monitor='val_acc', mode='max', patience=100),
        checkpoint_callback,
    ],
    deterministic=True,
    # gradient_clip_val=0.5,
)

trainer.fit(react_model, datamodule=cifar10_dm)

In [None]:
# Inspect the trained parameters; as the cell's last expression, the notebook
# displays the returned state dict.
react_model.state_dict()

In [None]:
# Evaluate the best checkpoint (highest val_acc, tracked by checkpoint_callback)
# instead of the final weights: with EarlyStopping(patience=100) the last-epoch
# weights can be many epochs past the best validation result.
# With no model argument, Lightning tests the module attached by trainer.fit
# (react_model) after loading the 'best' checkpoint into it.
trainer.test(datamodule=cifar10_dm, ckpt_path='best')


In [None]:
# Report where the best (highest val_acc) checkpoint was saved and its score,
# one value per line.
print(checkpoint_callback.best_model_path, checkpoint_callback.best_model_score, sep="\n")