In [1]:
# !pip install git+https://github.com/serge-m/pytorch-nn-tools.git@master
# !pip install pytorch-nn-tools==0.3.8
# !pip install torch_lr_finder==0.2.1

In [2]:
# pip install -U albumentations==0.5.1

In [3]:
import os
from argparse import ArgumentParser, Namespace

import torch
import torch.utils.data
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

import matplotlib.pyplot as plt

from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json
import time


from pytorch_nn_tools.visual import ImgShow, tfm_vis_img, UnNormalize_, imagenet_stats
from pytorch_nn_tools.train.metrics.processor import mod_name_train, mod_name_val, Marker
from pytorch_nn_tools.train.metrics.processor import MetricAggregator, TensorBoardMetricLogger
from pytorch_nn_tools.metrics.accuracy import topk_accuracy
from pytorch_nn_tools.train.progress import ProgressTracker
from pytorch_nn_tools.convert import map_dict
from pytorch_nn_tools.train.metrics.history_condition import HistoryCondition
from pytorch_nn_tools.devices import to_device
import ml_dataset_tools as mdt
import albumentations as A
from albumentations.pytorch import ToTensorV2, ToTensor
from ds_io.continuous_sequence import ContinuousSequenceSampler, RandIntSampler
import random

In [4]:
from trainer.trainer_io import TrainerIO

In [5]:
# Presumably an image-display helper drawing through the pyplot module (TODO confirm);
# `ish` is not referenced again in the visible cells.
ish = ImgShow(ax=plt)


In [10]:
# Worker count handed to both DataLoaders.
num_workers = 2

# Small CPU profile; the commented line holds the values used for a GPU run.
batch_size_train = 2
batch_size_val = 2
device = 'cpu'
# batch_size_train, batch_size_val, device = 128, 128, 'cuda'

# On-disk locations, relative to the notebook's working directory.
data_root_path = Path("data/")
data_path = data_root_path / "dataset"

In [11]:
# (height, width) pair; not used by the transforms below, whose training crop is fixed at 32.
size_h_w = (224, 224)

# Per-channel normalisation statistics.
# NOTE: this rebinds the `imagenet_stats` name imported from pytorch_nn_tools.visual.
imagenet_stats = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

cifar_stats = {
    "mean": (0.4914, 0.4822, 0.4465),
    "std": (0.2023, 0.1994, 0.2010),
}

# Training pipeline: padded random crop and horizontal flip, then tensor conversion
# and normalisation with the CIFAR statistics.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**cifar_stats),
])

# Evaluation pipeline: tensor conversion and normalisation only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**cifar_stats),
])

# Both splits live under the same root; only the training split requests a download.
ds_tr = datasets.CIFAR10(
    root=data_root_path,
    train=True,
    download=True,
    transform=transform_train,
)
ds_val = datasets.CIFAR10(
    root=data_root_path,
    train=False,
    download=False,
    transform=transform_test,
)

Files already downloaded and verified


In [12]:
ds = ds_tr

# Bucket sample indices by class group: each target is integer-divided by
# one fifth of the class count, so neighbouring class ids land in the same bucket.
# Buckets appear in the dict in the order their first sample is encountered.
group_width = len(ds.classes) // 5
subsets = defaultdict(list)
for sample_idx, label in enumerate(ds.targets):
    subsets[label // group_width].append(sample_idx)

# Preview the first five indices of every bucket.
for i, lst in subsets.items():
    print(f"{i}: {lst[:5]}...")

3: [0, 7, 11, 12, 19]...
4: [1, 2, 8, 14, 15]...
2: [3, 10, 20, 27, 28]...
0: [4, 5, 29, 30, 32]...
1: [6, 9, 13, 17, 18]...


In [13]:
# One seeded RNG, shared by the length sampler and the sequence sampler.
rs = random.Random(77)

# Index lists, one per class group, in the insertion order of `subsets`.
class_groups = list(subsets.values())

# Presumably draws subsequence lengths between 128 and 256 (TODO confirm the
# bounds' inclusivity against RandIntSampler).
subsequence_lengths = RandIntSampler(rs, 128, 256)

sampler_tr = ContinuousSequenceSampler(
    ds_tr,
    class_groups,
    subsequence_lengths,
    random_state=rs,
)


In [14]:
# for i, x in enumerate(sampler_tr):
#     if i < 10: 
#         print(x) 
#     pass
# print("asdasd")
# for i, x in enumerate(sampler_tr):
#     if i < 10: 
#         print(x) 
#     pass

In [15]:

# Training batches follow the custom sampler's ordering, so `shuffle` is left unset
# (DataLoader does not allow shuffle=True together with an explicit sampler).
train_dataloader = torch.utils.data.DataLoader(
    ds_tr,
    batch_size=batch_size_train,
    sampler=sampler_tr,
    num_workers=num_workers,
)

# Validation batches are read in dataset order.
val_dataloader = torch.utils.data.DataLoader(
    ds_val,
    batch_size=batch_size_val,
    shuffle=False,
    num_workers=num_workers,
)

In [16]:
# Sanity check: display the sampler object attached to the training loader.
train_dataloader.sampler

<ds_io.continuous_sequence.ContinuousSequenceSampler at 0x7fe8a6ae4710>

In [17]:
def publish_images(tb_writer, images, iteration_id):
    """Write a grid of un-normalised images to TensorBoard under the tag 'images'.

    Args:
        tb_writer: object exposing ``add_image(tag, img_tensor, global_step)``.
        images: batch of image tensors; assumed to have been normalised with
            ``cifar_stats`` — TODO confirm against the caller.
        iteration_id: step value forwarded to ``add_image``.
    """
    # Loop-invariant: build the inverse-normalisation transform once rather than per image.
    unnormalize = UnNormalize_(**cifar_stats)
    with torch.no_grad():
        # Operate on a detached copy so the caller's batch is left untouched.
        vis = images.detach().clone()
        for v in vis:
            v[:] = unnormalize(v)
        grid = torchvision.utils.make_grid(vis)
        tb_writer.add_image('images', grid, iteration_id)

In [18]:
from itertools import islice

In [19]:
class Trainer:
    """Epoch-based training/validation loop.

    Progress bars, status messages, checkpoint loading/saving and the
    TensorBoard writer are all obtained from the ``TrainerIO`` passed in.
    """

    def __init__(self, device, trainer_io: TrainerIO,
                 continue_training: bool = False,
                 report_step: int = 100):
        """
        Args:
            device: device the model and every batch are moved to.
            trainer_io: supplies ``tb_summary_writer``, progress bars,
                ``load_last``, ``save_checkpoint`` and ``set_main_status_msg``.
            continue_training: when True, ``fit`` replaces its ``start_epoch``
                with the value returned by ``trainer_io.load_last``.
            report_step: training metrics are reported on iterations whose
                running counter is a multiple of this value (default 100,
                the value that was previously hard-coded).
        """
        self.device = device
        self.continue_training = continue_training
        self.trainer_io = trainer_io
        self.report_step = report_step

    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch, train_dataloader, val_dataloader):
        """Train and validate once per epoch in ``range(start_epoch, end_epoch)``.

        After each epoch the aggregated metrics (keys prefixed with ``avg.``)
        and the scheduler's last learning rates are logged, shown as the main
        status message and passed to ``trainer_io.save_checkpoint``.

        The metric logger is closed in a ``finally`` clause, so it is released
        even when training is interrupted (e.g. ``KeyboardInterrupt``) or fails.
        """
        metric_logger = TensorBoardMetricLogger(self.trainer_io.tb_summary_writer)
        try:
            model = to_device(model, self.device)

            if self.continue_training:
                start_epoch = self.trainer_io.load_last(start_epoch, end_epoch, model, optimizer, scheduler)

            # A single tracker for the whole run, so the iteration counter keeps
            # growing across epochs.
            progr_train = ProgressTracker()

            for epoch in self.trainer_io.main_progress_bar(range(start_epoch, end_epoch)):
                # Fresh aggregator per epoch; it is fed by both the train and
                # the validation metric pipelines.
                metric_aggregator = MetricAggregator()
                self.train_epoch(
                    train_dataloader, progr_train,
                    model, optimizer, scheduler,
                    metric_proc=mod_name_train+metric_aggregator+metric_logger,
                    report_step=self.report_step,
                )
                self.validate_epoch(
                    val_dataloader,
                    model,
                    metric_proc=mod_name_val+metric_aggregator+metric_logger,
                )

                aggregated = map_dict(metric_aggregator.aggregate(), key_fn=lambda key: f"avg.{key}")
                metric_logger({
                    **aggregated,
                    **{f"lr_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())},
                    Marker.EPOCH: epoch,
                })
                self.trainer_io.set_main_status_msg(f"{aggregated}")
                self.trainer_io.save_checkpoint(aggregated, model, optimizer, scheduler, epoch)
        finally:
            metric_logger.close()

    def train_epoch(self, data_loader, progr, model, optimizer, scheduler, metric_proc, report_step):
        """Run one optimisation pass over ``data_loader``.

        Args:
            data_loader: iterable of ``(images, target)`` batches.
            progr: progress tracker wrapping the loader; its ``cnt_total_iter``
                decides when metrics are reported.
            model, optimizer, scheduler: objects being trained/stepped.
            metric_proc: callable receiving a dict of metrics.
            report_step: report metrics when ``cnt_total_iter % report_step == 0``.
        """
        model.train()

        for batch in self.trainer_io.secondary_progress_bar(progr.track(data_loader)):
            batch = to_device(batch, self.device)
            images, target = batch

            optimizer.zero_grad()

            output = model(images)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            # The scheduler is advanced once per batch, not once per epoch.
            scheduler.step()

            if progr.cnt_total_iter % report_step == 0:
                with torch.no_grad():
                    acc1, acc5 = topk_accuracy(output, target, topk=(1, 5))

                    metric_proc({
                        # Detached so the metric pipeline cannot keep this
                        # batch's autograd graph alive.
                        'loss': loss.detach(),
                        'acc1': acc1,
                        'acc5': acc5,
                        Marker.ITERATION: progr.cnt_total_iter,
                        **{f"lr_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())},
                    })

    def validate_epoch(self, data_loader, model, metric_proc):
        """Evaluate ``model`` on every batch of ``data_loader`` without gradients.

        Loss, top-1 and top-5 accuracy are passed to ``metric_proc`` for each batch.
        """
        model.eval()

        with torch.no_grad():
            for batch in self.trainer_io.secondary_progress_bar(data_loader):
                batch = to_device(batch, self.device)
                images, target = batch
                output = model(images)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = topk_accuracy(output, target, topk=(1, 5))
                metric_proc(dict(loss=loss, acc1=acc1, acc5=acc5))
                

In [20]:
# from net import resnet_from_pytorch_cifar 
# model = resnet_from_pytorch_cifar.ResNet18()
# model


In [21]:
from net import resnet 
# Depth-20 ResNet with one output per dataset class.
# NOTE: `model` is rebuilt with the same arguments in the training cell further down,
# which replaces this instance.
model = resnet.ResNet(depth=20, num_classes=len(ds_tr.classes), block_name='BasicBlock', inplanes=64)
# model

In [22]:
# optimizer = torch.optim.AdamW([
#     {
#         'name': 'main_model',
#         'params': model.parameters(),
#         'lr': 1e-9,
#         'weight_decay': 5e-4,
#     }
# ])
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-9,
#                       momentum=0.9, weight_decay=5e-4)

In [23]:
# from torch_lr_finder import LRFinder, TrainDataLoaderIter

# class LRFinderDL(TrainDataLoaderIter):
#     def inputs_labels_from_batch(self, batch):
#         return batch['image'], batch['target']

# class LRFinderDL(TrainDataLoaderIter):
#     def inputs_labels_from_batch(self, batch):
#         return batch[0], batch[1]
    
    
# criterion = torch.nn.CrossEntropyLoss()
# lr_finder = LRFinder(model, optimizer, criterion, device=device)
# lr_finder.range_test(LRFinderDL(train_dataloader), val_loader=None, end_lr=1, num_iter=100, step_mode="exp")
# _, recommended_lr = lr_finder.plot(log_lr=False)
# lr_finder.reset()

In [24]:
# Fixed learning rate: the LR-finder cell that would assign `recommended_lr` is commented out.
# It is used both as the optimizer's `lr` and as OneCycleLR's `max_lr` in the training cell.
recommended_lr = 0.1

In [25]:
from net import resnet

# Build the network for this run (rebinds `model`).
model = resnet.ResNet(
    depth=20,
    num_classes=len(ds_tr.classes),
    block_name='BasicBlock',
    inplanes=64,
)

num_epochs = 50

# Single parameter group carrying all SGD hyper-parameters.
# NOTE(review): OneCycleLR cycles momentum by default (cycle_momentum=True), so the
# 0.9 given here is overwritten during training — confirm this is intended.
optimizer = torch.optim.SGD([
    dict(
        params=model.parameters(),
        lr=recommended_lr,
        momentum=0.9,
        weight_decay=5e-4,
    ),
])

# The schedule length is epochs * batches-per-epoch; Trainer.train_epoch steps it per batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=recommended_lr,
    epochs=num_epochs,
    steps_per_epoch=len(train_dataloader),
    pct_start=0.1,
)

# Encode the sampling strategy in the experiment name.
train_sampler = train_dataloader.sampler
if isinstance(train_sampler, ContinuousSequenceSampler):
    ssls = train_sampler.subsequence_len_sampler
    shuffling_type = f"seq{ssls.min_len}_{ssls.max_len}"
elif isinstance(train_sampler, torch.utils.data.sampler.RandomSampler):
    shuffling_type = "rand"
else:
    raise ValueError("unsupported sampler")


def _is_new_best(hist):
    """True for the first recorded value, or when the latest value beats every earlier one."""
    return len(hist) == 1 or hist[-1] > max(hist[:-1])


experiment_name = f"cifar10_resnet20w64_lr{recommended_lr}__{shuffling_type}"
trainer_io = TrainerIO(
    log_dir="./logs/",
    experiment_name=experiment_name,
    checkpoint_condition=HistoryCondition('avg.val.acc1', _is_new_best),
)

trainer = Trainer(device=device, trainer_io=trainer_io, continue_training=False)

# For a quick smoke test, pass list(islice(train_dataloader, 0, 5)) and
# list(islice(val_dataloader, 0, 5)) instead of the full loaders.
trainer.fit(
    model,
    optimizer,
    scheduler,
    start_epoch=0,
    end_epoch=num_epochs,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

KeyboardInterrupt: 

In [None]:
# !rm checkpoints/epoch_00001*
# !rm checkpoints/epoch_00002*
# !rm checkpoints/epoch_00003*
# !rm logs/experiment1/checkpoints/epoch_00004*