In [1]:
# !pip install git+https://github.com/serge-m/pytorch-nn-tools.git@master
# !pip install pytorch-nn-tools==0.3.7
# !pip install torch_lr_finder==0.2.1

In [2]:
1

1

In [3]:
# pip install -U albumentations==0.5.1

In [4]:
import os
from argparse import ArgumentParser, Namespace

import torch
import torch.utils.data
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

import matplotlib.pyplot as plt

from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json
import time


from pytorch_nn_tools.visual import ImgShow, tfm_vis_img, UnNormalize_, imagenet_stats
from pytorch_nn_tools.train.metrics.processor import mod_name_train, mod_name_val, Marker
from pytorch_nn_tools.train.metrics.processor import MetricAggregator, TensorBoardMetricLogger
from pytorch_nn_tools.train.tensor_io.torch_xla_tensor_io import TorchXlaTensorIO
from pytorch_nn_tools.train.tensor_io.torch_tensor_io import TorchTensorIO
from pytorch_nn_tools.metrics.accuracy import topk_accuracy
from pytorch_nn_tools.train.progress import ProgressTracker
from pytorch_nn_tools.convert import map_dict
from pytorch_nn_tools.train.metrics.history_condition import HistoryCondition
from pytorch_nn_tools.devices import to_device
import ml_dataset_tools as mdt
import albumentations as A
from albumentations.pytorch import ToTensorV2, ToTensor

In [5]:
from trainer.trainer_io import TrainerIO

In [6]:
ish = ImgShow(ax=plt)


In [7]:
num_workers = 8

batch_size_train, batch_size_val, device = 2, 2, 'cpu'
# batch_size_train, batch_size_val, device = 128, 128, 'cuda'

data_root_path = Path("data/")
data_path = data_root_path.joinpath("dataset")

In [8]:
size_h_w = 224, 224

imagenet_stats = dict(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])

cifar_stats = dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**cifar_stats),
    ])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**cifar_stats),
])


ds_tr  = datasets.CIFAR10(root=data_root_path, train=True, download=True, transform=transform_train)
ds_val = datasets.CIFAR10(root=data_root_path, train=False, download=False, transform=transform_test)

train_dataloader = torch.utils.data.DataLoader(
        dataset=ds_tr,
        batch_size=batch_size_train,
        shuffle=True, 
        num_workers=num_workers,
    )

    
val_dataloader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=batch_size_val,
        shuffle=False,
        num_workers=num_workers,
    )

Files already downloaded and verified


In [9]:
def publish_images(tb_writer, images, iteration_id):
    with torch.no_grad():
        vis = images.detach().clone()
        for v in vis:
            v[:] = UnNormalize_(**cifar_stats)(v)
        grid = torchvision.utils.make_grid(vis)
        tb_writer.add_image('images', grid, iteration_id)

In [10]:
from itertools import islice

In [11]:
class Trainer:
    def __init__(self, device, trainer_io: TrainerIO,
                continue_training: bool = False):
        self.device = device
        self.continue_training = continue_training
        self.trainer_io = trainer_io
        
    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch, train_dataloader, val_dataloader):
        metric_logger = TensorBoardMetricLogger(self.trainer_io.tb_summary_writer)
        model = to_device(model, self.device)
        
        if self.continue_training:
            start_epoch = self.trainer_io.load_last(start_epoch, end_epoch, model, optimizer, scheduler)

        progr_train = ProgressTracker()
        
        for epoch in self.trainer_io.main_progress_bar(range(start_epoch, end_epoch)):
            metric_aggregator = MetricAggregator()
            self.train_epoch(
                train_dataloader, progr_train,
                model, optimizer, scheduler,  
                metric_proc=mod_name_train+metric_aggregator+metric_logger,
                report_step=100,
            )
            self.validate_epoch(
                val_dataloader,
                model,  
                metric_proc=mod_name_val+metric_aggregator+metric_logger,
            )
            
            aggregated = map_dict(metric_aggregator.aggregate(), key_fn=lambda key: f"avg.{key}")
            metric_logger({
                **aggregated, 
                **{f"lr_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())},
                Marker.EPOCH: epoch,
            })
            self.trainer_io.set_main_status_msg(f"{aggregated}")
            self.trainer_io.save_checkpoint(aggregated, model, optimizer, scheduler, epoch)
            
        metric_logger.close()
            
    def train_epoch(self, data_loader, progr, model, optimizer, scheduler, metric_proc, report_step):
        model.train()
                
        for batch in self.trainer_io.secondary_progress_bar(progr.track(data_loader)):
            batch = to_device(batch, self.device)
            images, target = batch

            optimizer.zero_grad()
                        
            output = model(images)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            if progr.cnt_total_iter % report_step == 0:
                with torch.no_grad():
                    acc1, acc5 = topk_accuracy(output, target, topk=(1, 5))

                    metric_proc({
                        'loss': loss, 
                        'acc1': acc1, 
                        'acc5': acc5, 
                        Marker.ITERATION: progr.cnt_total_iter,
                        **{f"lr_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())},
                    })

#             if batch_idx == 0 and tb_writer:
#                 publish_images(tb_writer, images, progr.cnt_total_iter)
            
#         scheduler.step()

            

    def validate_epoch(self, data_loader, model, metric_proc):
        model.eval()
        
        with torch.no_grad():
            for batch in self.trainer_io.secondary_progress_bar(data_loader):
                batch = to_device(batch, self.device)
                images, target = batch
                output = model(images)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = topk_accuracy(output, target, topk=(1, 5))
                metric_proc(dict(loss=loss, acc1=acc1, acc5=acc5))
                

In [12]:
from net import preact_resnet_from_pytorch_cifar 
# model = preact_resnet_from_pytorch_cifar.PreActResNet(preact_resnet_from_pytorch_cifar.PreActBlock, [2, 2, 2, 2])
def build_model():
    return preact_resnet_from_pytorch_cifar.PreActResNet(preact_resnet_from_pytorch_cifar.PreActBlock, [1, 1, 1, 1], num_planes=[64,64,64,64])


model = build_model()


In [13]:
from torchsummary import summary
summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       PreActBlock-6           [-1, 64, 32, 32]               0
       BatchNorm2d-7           [-1, 64, 32, 32]             128
            Conv2d-8           [-1, 64, 16, 16]           4,096
            Conv2d-9           [-1, 64, 16, 16]          36,864
      BatchNorm2d-10           [-1, 64, 16, 16]             128
           Conv2d-11           [-1, 64, 16, 16]          36,864
      PreActBlock-12           [-1, 64, 16, 16]               0
      BatchNorm2d-13           [-1, 64, 16, 16]             128
           Conv2d-14             [-1, 6

In [14]:
# from net import resnet 
# model = resnet.ResNet(depth=20, num_classes=len(ds_tr.classes), block_name='BasicBlock', inplanes=64)
# model

In [15]:
# optimizer = torch.optim.AdamW([
#     {
#         'name': 'main_model',
#         'params': model.parameters(),
#         'lr': 1e-9,
#         'weight_decay': 5e-4,
#     }
# ])
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-9,
#                       momentum=0.9, weight_decay=5e-4)

In [16]:
# from torch_lr_finder import LRFinder, TrainDataLoaderIter

# class LRFinderDL(TrainDataLoaderIter):
#     def inputs_labels_from_batch(self, batch):
#         return batch['image'], batch['target']

# class LRFinderDL(TrainDataLoaderIter):
#     def inputs_labels_from_batch(self, batch):
#         return batch[0], batch[1]
    
    
# criterion = torch.nn.CrossEntropyLoss()
# lr_finder = LRFinder(model, optimizer, criterion, device=device)
# lr_finder.range_test(LRFinderDL(train_dataloader), val_loader=None, end_lr=1, num_iter=100, step_mode="exp")
# _, recommended_lr = lr_finder.plot(log_lr=False)
# lr_finder.reset()

In [15]:
recommended_lr = 0.1

In [16]:
model = build_model()

num_epochs = 50
optimizer = torch.optim.SGD([
    {
        'params': model.parameters(), 
        'lr': recommended_lr,
        'momentum' :0.9, 
        'weight_decay': 5e-4
    }
])

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=recommended_lr,
    epochs=num_epochs,
    steps_per_epoch=len(train_dataloader),
    pct_start=0.1,
)

trainer_io = TrainerIO(
    log_dir="./logs/", experiment_name=f"cifar10_preactresnet18pc_small_lr{recommended_lr}_sgdarr_onecpct0.1", 
    checkpoint_condition=HistoryCondition(
        'avg.val.acc1', 
        lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
    ),
    tensor_io=TorchTensorIO(),
    
)

In [None]:
# trainer_io.save_checkpoint({'avg.val.acc1':1.0}, model, optimizer, scheduler, 1000)
start_epoch = trainer_io.load_last(0, 50, model, optimizer, scheduler)

try logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00050.meta.json
try logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00049.meta.json
last 49
found pretrained results for epoch 49. Loading...
loading model from logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00049.pth


In [None]:
trainer = Trainer(device=device, trainer_io=trainer_io, continue_training=False)

trainer.fit(
    model, optimizer, scheduler,
    start_epoch=0, end_epoch=num_epochs,
#     train_dataloader=list(islice(train_dataloader, 0, 5)), 
#     val_dataloader=list(islice(val_dataloader, 0, 5))
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

In [None]:
1

In [None]:
# !rm checkpoints/epoch_00001*
# !rm checkpoints/epoch_00002*
# !rm checkpoints/epoch_00003*
# !rm logs/experiment1/checkpoints/epoch_00004*

In [None]:
trainer_io_exp = TrainerIO(
    log_dir="./logs/", experiment_name=f"cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024", 
    checkpoint_condition=HistoryCondition(
        'avg.val.acc1', 
        lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
    ),
    tensor_io=TorchXlaTensorIO()
)
num_epochs = 50

start_epoch = trainer_io_exp.load_last(0, 50, model, optimizer, scheduler)

try logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00050.meta.json
try logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00049.meta.json
last 49
found pretrained results for epoch 49. Loading...
loading model from logs/cifar10_preactresnet18pc_small_lr0.8_sgdarr_onecpct0.1_tpu_batch1024/checkpoints/epoch_00049.pth


In [40]:
pp = dict(model.named_parameters())

In [41]:
pp[]

{'conv1.weight': Parameter containing:
 tensor([[[[-0.1454, -0.1528,  0.1543],
           [ 0.1444, -0.1581, -0.0173],
           [-0.1294, -0.0269,  0.0492]],
 
          [[-0.0875, -0.1876,  0.1042],
           [-0.1637, -0.1723,  0.1534],
           [ 0.0701, -0.0654,  0.0997]],
 
          [[ 0.0074,  0.1355, -0.0675],
           [-0.1063, -0.1077, -0.1177],
           [-0.0358, -0.1217, -0.0281]]],
 
 
         [[[ 0.0543, -0.1739,  0.0202],
           [ 0.0050,  0.1386,  0.0739],
           [-0.1793, -0.1760, -0.0671]],
 
          [[-0.1293, -0.0699,  0.1698],
           [-0.0998, -0.1502,  0.1716],
           [ 0.0973, -0.1236,  0.1001]],
 
          [[-0.0913,  0.1462, -0.0346],
           [ 0.0944,  0.0149,  0.0149],
           [-0.0741,  0.1264, -0.1864]]],
 
 
         [[[ 0.0068, -0.0608,  0.1859],
           [-0.0516, -0.1264, -0.1379],
           [-0.0091, -0.1339,  0.1388]],
 
          [[-0.1666,  0.0402,  0.1793],
           [-0.0935, -0.1787,  0.1479],
           [-0