In [None]:
# !pip install git+https://github.com/serge-m/pytorch-nn-tools.git@master
# !pip install pytorch-nn-tools==0.3.7
# !pip install torch_lr_finder==0.2.1

In [None]:
# pip install -U albumentations==0.5.1

In [None]:
import os
from argparse import ArgumentParser, Namespace

import torch
import torch.utils.data
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

import matplotlib.pyplot as plt

from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json
import time


from pytorch_nn_tools.visual import ImgShow, tfm_vis_img, UnNormalize_, imagenet_stats
from pytorch_nn_tools.train.metrics.processor import mod_name_train, mod_name_val, Marker
from pytorch_nn_tools.train.metrics.processor import MetricAggregator, TensorBoardMetricLogger
from pytorch_nn_tools.metrics.accuracy import topk_accuracy
from pytorch_nn_tools.train.progress import ProgressTracker
from pytorch_nn_tools.convert import map_dict
from pytorch_nn_tools.train.metrics.history_condition import HistoryCondition
from pytorch_nn_tools.devices import to_device
import ml_dataset_tools as mdt
import albumentations as A
from albumentations.pytorch import ToTensorV2, ToTensor

In [None]:
from trainer.trainer_io import TrainerIO

In [None]:
# Image-display helper from pytorch_nn_tools, given the pyplot module as its `ax` target.
# NOTE(review): `ish` is not used by any visible cell.
ish = ImgShow(ax=plt)


In [None]:
# Data-loading / device configuration.
num_workers = 8

# Pick the GPU configuration when CUDA is present, otherwise fall back to a tiny
# CPU configuration (useful for debugging) instead of failing on `.to('cuda')`.
if torch.cuda.is_available():
    batch_size_train, batch_size_val, device = 128, 128, 'cuda'
else:
    batch_size_train, batch_size_val, device = 2, 2, 'cpu'

data_root_path = Path("data/")
data_path = data_root_path.joinpath("dataset")

In [None]:
# Nominal (height, width) input size. NOTE(review): not used by the CIFAR-10 pipeline below.
size_h_w = 224, 224

# Per-channel normalisation statistics.
# NOTE(review): this rebinds the `imagenet_stats` imported from pytorch_nn_tools.visual.
imagenet_stats = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}
cifar_stats = {
    "mean": (0.4914, 0.4822, 0.4465),
    "std": (0.2023, 0.1994, 0.2010),
}

# Training: pad-and-random-crop to 32x32 plus horizontal flips, then tensor + normalise.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**cifar_stats),
    ]
)
# Evaluation: no augmentation.
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(**cifar_stats)]
)

# Only the train split requests a download; the test split relies on that
# same download being present under `data_root_path`.
ds_tr = datasets.CIFAR10(
    root=data_root_path, train=True, download=True, transform=transform_train,
)
ds_val = datasets.CIFAR10(
    root=data_root_path, train=False, download=False, transform=transform_test,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset=ds_tr, batch_size=batch_size_train, shuffle=True, num_workers=num_workers,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=ds_val, batch_size=batch_size_val, shuffle=False, num_workers=num_workers,
)

In [None]:
def publish_images(tb_writer, images, iteration_id, tag='images', stats=None):
    """Log a batch of normalised images to TensorBoard as a single image grid.

    Args:
        tb_writer: object exposing ``add_image(tag, img_tensor, global_step)``
            (e.g. a TensorBoard SummaryWriter).
        images: batch of normalised image tensors, iterated over along dim 0
            (presumably (B, C, H, W) — TODO confirm with callers).
        iteration_id: global step recorded with the image.
        tag: TensorBoard tag; defaults to ``'images'``.
        stats: dict with ``mean``/``std`` used to undo the normalisation;
            ``None`` means the notebook's ``cifar_stats``.

    The caller's ``images`` tensor is not modified; a detached copy is un-normalised.
    """
    if stats is None:
        stats = cifar_stats
    # Build the un-normaliser once rather than once per image.
    unnormalize = UnNormalize_(**stats)
    with torch.no_grad():
        vis = images.detach().clone()
        for v in vis:
            # Assign back into the view so `vis` holds the un-normalised batch.
            v[:] = unnormalize(v)
        grid = torchvision.utils.make_grid(vis)
        tb_writer.add_image(tag, grid, iteration_id)

In [None]:
from itertools import islice

# Drafts: scratch experiments with confidence-weighted cross-entropy (all cells below are commented out)

In [None]:
# output = torch.tensor([
#     [0.1, 0.7, 0.2],
#     [0.6, 0.3, 0.1],
#     [0.05, 0.05, 0.05],
#     [5, 0.01, 0.0]
# ])
# confidence = torch.tensor([
#     0.99, 0.01, 0.99, 0.01 
# ])
# # output = torch.tensor([
# #     [0.0, 1., 0.0],
# #     [1., 0.0, 0.0]
# # ])
# target = torch.tensor([2, 2, 0, 0])
# F.cross_entropy(output, target, reduction='none'), F.cross_entropy(output, target)



In [None]:
# confidence.log()

In [None]:
# F.cross_entropy(output, target, reduction='none') * confidence

In [None]:
# idxs = torch.stack(
#     (
#         torch.arange(target.size(0)), 
#         target
#     ),
#     dim=1
# )
# idxs

In [None]:
# t_vals, t_ids = output.topk(1, dim=1)
# t_vals

In [None]:
# output.sum(axis=1)

In [None]:
# output = output.double()

In [None]:
# output.topk()

In [None]:
# output.log_softmax(axis=1).topk(k=1, dim=1)[0].exp(), 1./output.size(1)
# .exp()


In [None]:
# (output.exp() / output.exp().sum(axis=1, keepdims=True)).log()# - output.log_softmax(axis=1)


In [None]:
# i = 2
# (output[i].exp() / output[i].exp().sum()).log()

In [None]:
# output.log_softmax(axis=1)

In [None]:
# F.nll_loss(F.log_softmax(output, dim=1), target, reduction='none').mean()

# continue

In [None]:
class Trainer:
    """Jointly trains a ``'small'`` and a ``'large'`` classifier with confidence-based loss weighting.

    For every training batch the small model's top-1 softmax probability, rescaled
    to [0, 1], acts as a per-sample confidence ``c``:

    * the large model is trained on its per-sample cross-entropy weighted by ``1 - c``;
    * the small model is trained on ``c * CE_small + (1 - c) * CE_large`` with
      ``CE_large`` detached (gradient still flows through ``c``).

    ``models``, ``optimizers`` and ``schedulers`` are dicts keyed by ``'small'`` and
    ``'large'``. Both schedulers are stepped once per training batch.
    """

    def __init__(self, device, trainer_io: "TrainerIO",
                 continue_training: bool = False):
        """
        Args:
            device: target device passed to ``to_device`` (e.g. ``'cpu'`` or ``'cuda'``).
            trainer_io: supplies the TensorBoard writer, progress bars, status messages
                and checkpoint loading.
            continue_training: if True, ``fit`` replaces ``start_epoch`` with the value
                returned by ``trainer_io.load_last``.
        """
        self.device = device
        self.continue_training = continue_training
        self.trainer_io = trainer_io

    @staticmethod
    def _scaled_confidence(logits):
        """Top-1 softmax probability per sample, rescaled from [1/C, 1] to [0, 1].

        Args:
            logits: tensor of shape (B, C) with raw class scores.

        Returns:
            Tensor of shape (B,) — one value per sample, so it multiplies per-sample
            losses element-wise instead of broadcasting to (B, B). Differentiable
            w.r.t. ``logits``.
        """
        num_classes = logits.size(1)
        top_prob = F.softmax(logits, dim=1).max(dim=1).values  # (B,)
        if num_classes < 2:
            # A single class would make the rescaling 0/0; treat it as fully confident.
            return torch.ones_like(top_prob)
        # The top-1 probability can never fall below the uniform value 1/C.
        min_confidence = 1. / num_classes
        return (top_prob - min_confidence) / (1. - min_confidence)

    def fit(self, models, optimizers, schedulers, start_epoch, end_epoch, train_dataloader, val_dataloader):
        """Train for epochs ``start_epoch .. end_epoch - 1``, logging per-epoch aggregates.

        Args:
            models, optimizers, schedulers: dicts keyed by ``'small'`` and ``'large'``.
            start_epoch, end_epoch: half-open epoch range.
            train_dataloader, val_dataloader: iterables of ``(images, target)`` batches.
        """
        metric_logger = TensorBoardMetricLogger(self.trainer_io.tb_summary_writer)
        try:
            models = to_device(models, self.device)

            if self.continue_training:
                # NOTE(review): assumes load_last accepts the model/optimizer/scheduler
                # dicts used here — confirm in TrainerIO.
                start_epoch = self.trainer_io.load_last(start_epoch, end_epoch, models, optimizers, schedulers)

            progr_train = ProgressTracker()

            for epoch in self.trainer_io.main_progress_bar(range(start_epoch, end_epoch)):
                metric_aggregator = MetricAggregator()
                self.train_epoch(
                    train_dataloader, progr_train,
                    models, optimizers, schedulers,
                    metric_proc=mod_name_train+metric_aggregator+metric_logger,
                    report_step=100,
                )
                self.validate_epoch(
                    val_dataloader,
                    models,
                    metric_proc=mod_name_val+metric_aggregator+metric_logger,
                )

                aggregated = map_dict(metric_aggregator.aggregate(), key_fn=lambda key: f"avg.{key}")
                metric_logger({
                    **aggregated,
                    **{f"lr_small_{i}": lr for i, lr in enumerate(schedulers['small'].get_last_lr())},
                    **{f"lr_large_{i}": lr for i, lr in enumerate(schedulers['large'].get_last_lr())},
                    Marker.EPOCH: epoch,
                })
                self.trainer_io.set_main_status_msg(f"{aggregated}")
                # TODO(review): checkpoint saving is disabled; re-enable once
                # TrainerIO.save_checkpoint is confirmed to accept the model/optimizer/scheduler dicts.
        finally:
            # Close the logger even when training is interrupted or raises.
            metric_logger.close()

    def train_epoch(self, data_loader, progr, models, optimizers, schedulers, metric_proc, report_step):
        """Run one training epoch, updating both models and stepping both schedulers per batch.

        Every ``report_step`` iterations (by ``progr.cnt_total_iter``) the losses and
        top-1 accuracies of both models are forwarded to ``metric_proc``.
        """
        for model in models.values():
            model.train()

        for batch in self.trainer_io.secondary_progress_bar(progr.track(data_loader)):
            images, target = to_device(batch, self.device)

            optimizers['small'].zero_grad()
            optimizers['large'].zero_grad()

            output_small = models['small'](images)
            confidence = self._scaled_confidence(output_small)  # (B,), in [0, 1]

            # Large model: weight each sample by how unsure the small model is.
            output_large = models['large'](images)
            loss_large = F.cross_entropy(output_large, target, reduction='none')  # (B,)
            loss_large_with_confidence = (loss_large * (1. - confidence.detach())).mean()
            loss_large_with_confidence.backward()
            optimizers['large'].step()

            # Small model: its own loss where confident, the (detached) large-model
            # loss otherwise. Gradient also flows through `confidence`.
            # TODO: revisit how per-item weights should be applied within a batch.
            loss_small = (
                F.cross_entropy(output_small, target, reduction='none') * confidence
                + (1. - confidence) * loss_large.detach()
            ).mean()
            loss_small.backward()
            optimizers['small'].step()

            # OneCycle-style schedulers are advanced once per batch.
            schedulers['large'].step()
            schedulers['small'].step()

            if progr.cnt_total_iter % report_step == 0:
                with torch.no_grad():
                    acc_small = topk_accuracy(output_small, target, topk=(1,))[0]
                    acc_large = topk_accuracy(output_large, target, topk=(1,))[0]

                    metric_proc({
                        'loss_small': loss_small,
                        'loss_large': loss_large.mean(),
                        'loss_large_with_confidence': loss_large_with_confidence,
                        'acc_small': acc_small,
                        'acc_large': acc_large,
                        Marker.ITERATION: progr.cnt_total_iter,
                    })

    def validate_epoch(self, data_loader, models, metric_proc):
        """Evaluate the small model only, reporting per-batch loss and top-1 accuracy.

        All models are switched to eval mode; no gradients are computed.
        """
        for model in models.values():
            model.eval()

        with torch.no_grad():
            for batch in self.trainer_io.secondary_progress_bar(data_loader):
                images, target = to_device(batch, self.device)
                output_small = models['small'](images)
                loss_small = F.cross_entropy(output_small, target)
                acc_small = topk_accuracy(output_small, target, topk=(1, ))[0]
                metric_proc({'loss_small': loss_small, 'acc_small': acc_small})
                

In [None]:
from net import preact_resnet_from_pytorch_cifar 
# model = preact_resnet_from_pytorch_cifar.PreActResNet(preact_resnet_from_pytorch_cifar.PreActBlock, [2, 2, 2, 2])
def build_model():
    """Create the two PreAct-ResNet networks trained jointly by ``Trainer``.

    Returns:
        dict with ``'small'`` (one block per stage, 64 planes in every stage) and
        ``'large'`` (two blocks per stage, library-default widths). The small net is
        constructed first, then the large one.
    """
    lib = preact_resnet_from_pytorch_cifar
    small_net = lib.PreActResNet(lib.PreActBlock, [1, 1, 1, 1], num_planes=[64, 64, 64, 64])
    large_net = lib.PreActResNet(lib.PreActBlock, [2, 2, 2, 2])
    return dict(small=small_net, large=large_net)


model = build_model()


In [None]:
# m = model['small']
# m.to

In [None]:
from torchsummary import summary
# Layer-by-layer summary of the small network for 3x32x32 inputs, computed on CPU.
summary(model['small'], input_size=(3, 32, 32), device='cpu')

In [None]:
# Peak learning rate: used as the SGD lr and as OneCycleLR's max_lr for both networks.
recommended_lr = 1e-2

In [None]:
model = build_model()

num_epochs = 50

# Hyper-parameters shared by both networks' SGD + OneCycleLR pipelines.
sgd_momentum = 0.9
sgd_weight_decay = 5e-4
onecycle_pct_start = 0.1

# One identically-configured optimizer per network (keys follow `model`: 'small', 'large').
optimizers = {
    name: torch.optim.SGD(
        net.parameters(),
        lr=recommended_lr,
        momentum=sgd_momentum,
        weight_decay=sgd_weight_decay,
    )
    for name, net in model.items()
}

# Trainer.train_epoch steps the schedulers once per batch, hence steps_per_epoch.
schedulers = {
    name: torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=recommended_lr,
        epochs=num_epochs,
        steps_per_epoch=len(train_dataloader),
        pct_start=onecycle_pct_start,
    )
    for name, opt in optimizers.items()
}

trainer_io = TrainerIO(
    log_dir="./logs/",
    experiment_name=f"cifar10_multistage_1_lr{recommended_lr}_sgdarr_onecpct{onecycle_pct_start}",
    checkpoint_condition=HistoryCondition(
        # Validation reports 'acc_small' (there is no 'acc1'); assumes mod_name_val
        # prefixes keys with 'val.' — confirm against pytorch_nn_tools.
        'avg.val.acc_small',
        lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
    )
)

trainer = Trainer(device=device, trainer_io=trainer_io, continue_training=False)

# For a quick smoke test pass e.g. list(islice(train_dataloader, 0, 10)) as the loaders.
trainer.fit(
    model, optimizers, schedulers,
    start_epoch=0, end_epoch=num_epochs,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

In [None]:
# (empty cell)

In [None]:
# !rm checkpoints/epoch_00001*
# !rm checkpoints/epoch_00002*
# !rm checkpoints/epoch_00003*
# !rm logs/experiment1/checkpoints/epoch_00004*