# DataLoader Setup

In [1]:
import os
import torch
import torchvision
from pytorch_lightning import seed_everything
from pl_bolts.datamodules import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

# Dataset root; override with the PATH_DATASETS environment variable.
PATH_DATASETS = os.environ.get('PATH_DATASETS', '../datasets')
# Use at most one GPU, and larger batches when one is available.
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
# Half the logical CPUs become DataLoader workers. os.cpu_count() returns None when
# the count cannot be determined; fall back to 0 workers (load in the main process)
# instead of crashing on `None / 2`.
NUM_WORKERS = (os.cpu_count() or 0) // 2
# NOTE(review): not referenced in any of the visible cells.
TEST_ONLY = False

# Deterministic preprocessing shared by train and eval:
#   ToTensor -> pl_bolts CIFAR-10 normalization -> clamp to [-1, 1] -> x * 127 -> floor,
# so every pixel ends up as an integer value in [-127, 127].
# This list is kept under its own name (not `test_transforms`) so the cell can be
# re-run: rebinding `test_transforms` to a Compose would make a second run fail on
# `list + Compose`.
# NOTE(review): lambdas cannot be pickled; if DataLoader workers are started with the
# "spawn" method (num_workers > 0), these transforms may fail to load — confirm on
# the target platform.
_shared_transform_steps = [
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
    torchvision.transforms.Lambda(lambda x: x.clamp(min=-1, max=1)),
    torchvision.transforms.Lambda(lambda x: x * 127),
    torchvision.transforms.Lambda(lambda x: x.floor()),
]

# Training adds random-crop (4 px padding) and horizontal-flip augmentation on the
# raw image, before ToTensor and the shared steps.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
    ]
    + _shared_transform_steps
)

# Validation/test use only the deterministic shared steps (no augmentation).
test_transforms = torchvision.transforms.Compose(_shared_transform_steps)

# CIFAR-10 DataModule reading from PATH_DATASETS. Training uses the augmented
# pipeline; validation and test both share the deterministic `test_transforms`.
cifar10_dm = CIFAR10DataModule(
    data_dir=PATH_DATASETS,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_transforms=train_transforms,
    val_transforms=test_transforms,
    test_transforms=test_transforms,
)

# Global seed for reproducibility; workers=True also seeds DataLoader workers.
# The returned seed is the cell's displayed output.
SEED = 1234
seed_everything(seed=SEED, workers=True)

  rank_zero_deprecation(
  rank_zero_deprecation(
  rank_zero_deprecation(
Global seed set to 1234


1234

# Model

In [2]:
from model import ReactModel

# Layer specification passed to ReactModel as `structure`: one dict per block.
# - First entry (stride 1, no 'block' key) is the stem; it prints as `firstconv3x3`.
# - Every later entry uses block type 'baseline' (the author notes 'react' as the
#   other option); the first of these prints as `Reduction_Block`.
# NOTE(review): the trailing numbers are the author's notes, presumably the spatial
# size after each block — confirm against model.py. The final 10 output channels
# presumably correspond to the 10 CIFAR-10 classes.
_baseline_channel_pairs = [
    (32, 64),    # 16
    (64, 128),   # 8
    (128, 256),  # 4
    (256, 256),  # 4
    (256, 512),  # 1
    (512, 10),
]
base_model = [{'in_channels': 3, 'out_channels': 32, 'stride': 1}] + [
    {'in_channels': c_in, 'out_channels': c_out, 'block': 'baseline'}
    for c_in, c_out in _baseline_channel_pairs
]

baselinemodel = ReactModel(
    structure=base_model,
    adam_init_lr=0.01,
    lr_patience=50,
    limit_conv_weight=True,
    limit_bn_weight=True,
)

print(baselinemodel)
print(baselinemodel.hparams)

ReactModel(
  (blocks): ModuleList(
    (0): firstconv3x3(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Reduction_Block(
      (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (layer1): Sequential(
        (0): Sign(in_channels=32)
        (1): GeneralConv2d(32, 32, kernel_size=3, stride=2, padding=1, conv=real)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (layer2_1): Sequential(
        (0): Sign(in_channels=32)
        (1): GeneralConv2d(32, 32, kernel_size=1, stride=1, padding=0, conv=real)
        (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (layer2_2): Sequential(
        (0): Sign(in_channels=32)
        (1): GeneralConv2d(32, 32, kernel_size=1, stride=1, padding=0, conv=real)
        (2): BatchNorm2

In [3]:
# Smoke test: one forward pass on the model's example input.
# Runs in eval mode under no_grad, so it builds no autograd graph and does not update
# BatchNorm running statistics before training. The previous train/eval mode is
# restored afterwards, even if the forward pass raises.
# NOTE(review): the last recorded run of this forward pass raised
# "AttributeError: 'Reduction_Block' object has no attribute 'in_rprelu'" with the
# 'baseline' block configuration — fix in model.py before training.
_was_training = baselinemodel.training
baselinemodel.eval()
try:
    with torch.no_grad():
        smoke_out = baselinemodel(baselinemodel.example_input_array)
finally:
    baselinemodel.train(_was_training)
smoke_out


AttributeError: 'Reduction_Block' object has no attribute 'in_rprelu'

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Keep the checkpoint with the highest validation accuracy; the filename embeds the
# epoch and the logged val_loss / val_acc.
# NOTE(review): assumes the model logs 'val_loss' and 'val_acc' — confirm in model.py.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{val_loss:.4f}-{val_acc:.4f}',
    monitor='val_acc',
    mode='max',
)

# TensorBoard run name (an alternative name used in this notebook was 'Real').
RUN_NAME = 'ReActNet'

trainer = Trainer(
    max_epochs=-1,  # no epoch cap; training ends via EarlyStopping below
    gpus=AVAIL_GPUS,
    logger=TensorBoardLogger('lightning_logs/', name=RUN_NAME, log_graph=True),
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        # Stop after 100 validation checks without a val_acc improvement.
        EarlyStopping(monitor='val_acc', mode='max', patience=100),
        checkpoint_callback,
    ],
    deterministic=True,
    # gradient_clip_val=0.5,  # optional; disabled for this run
)

# Fit the LightningModule `baselinemodel` — not `base_model`, which is only the list
# of layer-spec dicts.
trainer.fit(baselinemodel, datamodule=cifar10_dm)