<a href="https://colab.research.google.com/github/seonhe/React_Net/blob/Seonhy/xnor_net_cifar10_block_depthwise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-lightning
!pip install lightning-bolts



In [2]:
import os
import torch
import torchvision
from pytorch_lightning import seed_everything
from pl_bolts.datamodules import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

PATH_DATASETS = os.environ.get('PATH_DATASETS', '../datasets')
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
NUM_WORKERS = int(os.cpu_count()/2)
TEST_ONLY = False

test_transforms = [
        torchvision.transforms.ToTensor(),
        cifar10_normalization(),        
        torchvision.transforms.Lambda(lambda x: x.clamp(min = -1, max = 1)),
        torchvision.transforms.Lambda(lambda x: x * 127),   
        torchvision.transforms.Lambda(lambda x: x.floor())  
]

train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
    ] 
    + test_transforms
)

test_transforms = torchvision.transforms.Compose(test_transforms)

cifar10_dm = CIFAR10DataModule(
    data_dir = PATH_DATASETS,
    batch_size = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    train_transforms = train_transforms,
    test_transforms = test_transforms,
    val_transforms = test_transforms,
)

seed_everything(seed=1234, workers=True)

  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
  "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
  "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
Global seed set to 1234


1234

In [3]:
from xnor_net import Model

binary_structure_GAP1 = [{'in_channels':3, 'out_channels':128, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0}, #32,32
             {'in_channels':128, 'out_channels':128, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':128, 'out_channels':256, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':256, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':512, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':512, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             ##FCL OR GAP
             {'in_channels':512, 'out_channels':1024, 'stride':1, 'kernel_size':8, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':1024, 'out_channels':512, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':512, 'out_channels':10, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'none', 'dropout':0},]



binary_structure_GAP2 = [{'in_channels':3, 'out_channels':128, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':128, 'out_channels':128, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':128, 'out_channels':256, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':256, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':512, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':512, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':512, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'pointwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':1024, 'stride':1, 'kernel_size':4, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':1024, 'out_channels':1024, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':1024, 'out_channels':10, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'none', 'dropout':0},]

binary_structure = [{'in_channels':3, 'out_channels':128, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':128, 'out_channels':128, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':128, 'out_channels':256, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':256, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':256, 'out_channels':512, 'stride':1, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':512, 'stride':2, 'kernel_size':3, 'padding':1, 'conv':'depthwise', 'act_fn':'sign', 'dropout':0},
             {'in_channels':512, 'out_channels':1024, 'stride':1, 'kernel_size':4, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':1024, 'out_channels':1024, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'sign', 'dropout':0.5},
             {'in_channels':1024, 'out_channels':10, 'stride':1, 'kernel_size':1, 'padding':0, 'conv':'scaled_sign', 'act_fn':'none', 'dropout':0},]






model = Model(structure=binary_structure_GAP2, 
              adam_init_lr=0.01, 
              lr_patience=50,
              limit_conv_weight=True,
              limit_bn_weight=True)
#print(model)
#print(model.hparams)

#real_model = Model.load_from_checkpoint('/content/drive/MyDrive/checkpoint/xnor_cifar10_bgap/lightning_logs/version_0/checkpoints/epoch=424-val_loss=0.5535-val_acc=0.8853.ckpt')
#model.load_state_dict(real_model.state_dict(), strict=False)

<All keys matched successfully>

In [4]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(filename='{epoch}-{val_loss:.4f}-{val_acc:.4f}', monitor='val_acc', mode='max')
trainer = Trainer(
    max_epochs=-1,
    gpus=AVAIL_GPUS,
    logger=TensorBoardLogger('/content/drive/MyDrive/checkpoint/xnor_cifar10/', log_graph=True),
    # logger=TensorBoardLogger('lightning_logs/', name='Real', log_graph=True),
    callbacks=[LearningRateMonitor(logging_interval='step'), 
               EarlyStopping(monitor='val_acc', mode='max', patience=100),
               checkpoint_callback],
    deterministic=True,
#    gradient_clip_val = 0.5
)

trainer.fit(model, datamodule=cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)
#print(checkpoint_callback.best_model_score)
print(checkpoint_callback.best_model_path)

Validation: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8755999803543091
        test_loss           0.5886259078979492
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

