In [1]:
import os
from argparse import ArgumentParser, Namespace

import pytorch_lightning as pl
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateLogger
import numpy as np

from model_base import ModelBase, get_main_model
import pytorch_nn_tools as pnt
from pytorch_nn_tools.visual import ImgShow
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json
import time

from fastprogress.fastprogress import master_bar, progress_bar
from time import sleep


In [2]:
ish = ImgShow(ax=plt)

In [3]:
def _train_dataset(path):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    train_dir = os.path.join(path, 'train')
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_dataset


def _val_dataset(path):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    val_dir = os.path.join(path, 'val')
    dataset = datasets.ImageFolder(val_dir, transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ]))
    return dataset

In [4]:
import tensorboard
from torch.utils.tensorboard import SummaryWriter

In [5]:

writer = SummaryWriter("logs/lg1")
for i in range(0, 1000):
    writer.add_scalar('train_loss', 1/(i+0.1), i)

In [6]:
num_workers = 8

batch_size_train = 2
batch_size_val = 2
# batch_size_train = 128
# batch_size_val = 16


data_path = "data/imagewoof2-320-cut/"

train_dataloader = torch.utils.data.DataLoader(
        dataset=_train_dataset(data_path),
        batch_size=batch_size_train,
        shuffle=True, 
        num_workers=num_workers,
    )
    
val_dataloader = torch.utils.data.DataLoader(
        dataset=_val_dataset(data_path),
        batch_size=batch_size_val,
        shuffle=False,
        num_workers=num_workers,
    )



In [7]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [8]:
from tqdm import tqdm
import datetime


In [9]:
def to_device(batch, device):
    return batch.to(device)



In [10]:
    
class HistoryCondition:
    def __init__(self, metric_name: str, history_condition: Callable, history=()):
        self.metric_name = metric_name
        self.history = list(history[:])
        self.condition = history_condition
        
    def __call__(self, metrics: Dict):
        self.history.append(metrics[self.metric_name])
        result = self.condition(self.history[:])
#         print(f'Condition {result} on {self.history}')
        return result
            
            
class CheckpointSaver:
    def __init__(self, path: Union[Path, str]):
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)
        
    
    def save(self, model, optimizer, scheduler, epoch):
        path = self.path.joinpath(f"epoch_{epoch:05d}.pth")
        print(f"saving model to {path}")
        torch.save(model.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.optimizer.pth")
        print(f"saving optimizer state to {path}")
        torch.save(optimizer.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.scheduler.pth")
        print(f"saving scheduler state to {path}")
        torch.save(scheduler.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.meta.json")
        print(f"saving meta data to {path}")
        with path.open("w") as f:
            json.dump({'epoch': epoch}, f)
            
    def load(self, model, optimizer, scheduler, epoch):
        path = self.path.joinpath(f"epoch_{epoch:05d}.pth")
        print(f"loading model from {path}")
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.optimizer.pth")
        if path.exists():
            print(f"loading optimizer state from {path}")
            optimizer_dict = torch.load(path)
            optimizer.load_state_dict(optimizer_dict)
        else:
            print("optimizer state not found")
            
        path = self.path.joinpath(f"epoch_{epoch:05d}.scheduler.pth")
        if path.exists():
            print(f"loading scheduler state from {path}")
            scheduler_dict = torch.load(path)
            scheduler.load_state_dict(scheduler_dict)
        else:
            print("scheduler state not found")

    def find_last(self, start_epoch, end_epoch):
        for epoch in range(end_epoch, start_epoch-1, -1):
            path = self.path.joinpath(f"epoch_{epoch:05d}.meta.json")
            if path.exists():
                return epoch
        return None

In [11]:
d={'a':'b', 'c': 2}
d.pop('a', None)
print(d)

{'c': 2}


In [12]:
Metrics = dict

# class MetricBuilder:
#     def __init__(self, required=()):
#         self._required = required
#         self._metrics = defaultdict(list)
        
#     def add(self, prefix, data, preproc_fn=lambda x: x.detach().item()):
#         for r in self._required:
#             assert r in data, f"Metric {r} is required, but not provided"
            
#         for k, v in data.items():
#             self._metrics[f"{prefix}_{k}"].append(preproc_fn(v))
            
#     def build(self) -> Metrics:
#         return {
#             k: sum(vs) / (len(vs) if vs else 1.)
#             for k, vs in self._metrics.items()
#         }

    
class MetricProcessor:
    def __call__(self, data, iteration=None):
        return data
    
    def __add__(self, other):
        processors = []
        for x in [self, other]:
            if isinstance(x, MetricPipeline):
                processors.extend(x._processors)
            else:
                processors.append(x)
        return MetricPipeline(*processors)

class MetricPipeline(MetricProcessor):
    def __init__(self, *processors):
        self._processors = processors
        
    def __call__(self, data):
        for p in self._processors:
            data = p(data)
        return data
    
    
class MetricAggregator(MetricProcessor):
    DEFAULT_SKIPPED = ('_iteration', '_epoch',)
    def __init__(self, skipped=DEFAULT_SKIPPED):
        self._metrics = defaultdict(list)
        self.skipped = set(skipped)
        
    def __call__(self, data):
        for k, v in data.items():
            if k not in self.skipped:
                self._metrics[k].append(v)
        return data
            
    def aggregate(self) -> Metrics:
        return {
            f"avg.{k}": sum(vs) / (len(vs) if vs else 1.)
            for k, vs in self._metrics.items()
        }
    
class MetricMod(MetricProcessor):
    DEFAULT_SKIPPED = ('_iteration', '_epoch',)
    def __init__(self, name_fn=lambda x: x, value_fn=lambda x: x.detach().item(), skipped=DEFAULT_SKIPPED):
        self.name_fn = name_fn
        self.value_fn = value_fn
        self.skipped = set(skipped)
        
        
    def __call__(self, data):
        return dict([
            (
                (self.name_fn(name), self.value_fn(value)) 
                if name not in self.skipped 
                else (name, value)
            )
            for name, value in data.items()
        ])

    

    
    
class MetricLogger(MetricProcessor):
    def __init__(self, path):
        path = Path(path)
        path.mkdir(exist_ok=True, parents=True)
        self.writer = SummaryWriter(path)
        
    def __call__(self, data):
        iteration = data.pop('_iteration', None)
        if iteration is None:
            iteration = data.pop('_epoch', None)
        if iteration is not None:
            for name, value in data.items():
                self.writer.add_scalar(name, value, iteration)

        return data

    def close(self):
        self.writer.close()

In [13]:
mod_name_train = MetricMod(
    name_fn=lambda name: f"train.{name}",
)
mod_name_val = MetricMod(
    name_fn=lambda name: f"val.{name}",
    
)

mod_name_train({'asdas': torch.tensor(123), '_iteration': 456})

{'train.asdas': 123, '_iteration': 456}

In [15]:
from fastprogress.fastprogress import master_bar, progress_bar

class ProgressTracker:
    def __init__(self, name=None):
        self.cnt_total_iter = 0
        self.name = name
        
    def track(self, dl):
        def tracked_iterator():
            for x in dl:
                self.cnt_total_iter += 1
                yield x
        return tracked_iterator()

# class ProgressTrackingDL:
#     def __init__(self, dl, progress_tracker: ProgressTracker, parent_pbar=None):
#         self.it = iter(progress_bar(dl, comment=progress_tracker.name, parent=parent_pbar))
#         self.progress_tracker = progress_tracker
        
#     def __iter__(self):
#         return self
    
#     def __next__(self):
#         result = next(self.it)
#         self.progress_tracker.cnt_total_iter += 1
#         return result
    
#     @property
#     def cnt_total_iter(self):
#         return self.progress_tracker.cnt_total_iter

In [38]:
def now_as_str():
    now = datetime.datetime.now()
    return now.strftime("%Y%m%d_%H%M%s_%f")

class PBars:
    def __init__(self):
        self._main = None
        self._second = None
        
    def main(self, it, **kwargs):
        self._main = master_bar(it, **kwargs)
        return self._main
    
    def secondary(self, it, **kwargs):
        if self._main is None:
            raise RuntimeError("Cannot instantiate secondary progress bar. The main progress bar is not set.")
        self._second = progress_bar(it, parent=self._main, **kwargs)
        return self._second
        
    def main_comment(self, comment):
        self._main.main_bar.comment = comment

class Trainer:
    def __init__(self, device, checkpoint_saver, checkpoint_condition,
                continue_training: bool = False,
                log_dir="./logs",
                name="model"):
        self.device = device
        self.checkpoint_condition = checkpoint_condition
        self.checkpoint_saver = checkpoint_saver
        self.continue_training = continue_training
        self.log_dir = Path(log_dir)
        self.name = name
        self.pbars = PBars()
        
        
    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch):
        path_logs = Path(self.log_dir).joinpath(f"{self.name}_{now_as_str()}")
        metric_logger = MetricLogger(path_logs)
        model = model.to(self.device)
        
        if self.continue_training:
            last = self.checkpoint_saver.find_last(start_epoch, end_epoch)
            if last is not None:
                print(f"found pretrained results for epoch {last}. Loading...")
                self.checkpoint_saver.load(model, optimizer, scheduler, last)
                start_epoch = last + 1

        progr_train = ProgressTracker(name="train")
        
        for epoch in self.pbars.main(range(start_epoch, end_epoch)):
            metric_aggregator = MetricAggregator()
            self.train_epoch(
                train_dataloader, progr_train,
                model, optimizer, scheduler,  
                metric_proc=mod_name_train+metric_aggregator+metric_logger,
                pbars=self.pbars,
                report_step=3
            )
            self.validate_epoch(
                val_dataloader,
                model,  
                metric_proc=mod_name_val+metric_aggregator+metric_logger,
                pbars=self.pbars,
            )
            aggregated = metric_aggregator.aggregate()
            metric_logger({**aggregated, '_epoch': epoch})
            self.pbars.main_comment(f"{aggregated}")
            
            scheduler.step()
                        
            if self.checkpoint_condition(aggregated): 
                self.checkpoint_saver.save(model, optimizer, scheduler, epoch)
            
            metric_logger.close()
        return path_logs
            
    def train_epoch(self, data_loader, progr, model, optimizer, scheduler, metric_proc, pbars, report_step=1):
        model.train()
                
        for images, target in progr.track(pbars.secondary(data_loader)):
            optimizer.zero_grad()
            
            images = to_device(images, self.device)
            target = to_device(target, self.device)
            
            output = model(images)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if progr.cnt_total_iter % report_step == 0:
                with torch.no_grad():
                    acc1, acc5 = accuracy(output, target, topk=(1, 5))

                metric_proc(dict(loss=loss, acc1=acc1, acc5=acc5, _iteration=progr.cnt_total_iter))
                    
        
        
    def validate_epoch(self, data_loader, model, metric_proc, pbars):
        model.eval()
        
        with torch.no_grad():
            for images, target in pbars.secondary(data_loader):
                images = to_device(images, self.device)
                target = to_device(target, self.device)

                output = model(images)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                
                metric_proc(dict(loss=loss, acc1=acc1, acc5=acc5))


In [39]:
!rm checkpoints/epo*

In [40]:
model = models.resnet18(pretrained=True)        
        
checkpoint_condition = HistoryCondition(
    'avg.val.acc1', 
    lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
)
checkpoint_saver = CheckpointSaver(path="checkpoints")
trainer = Trainer(
    device='cpu', checkpoint_saver=checkpoint_saver, checkpoint_condition=checkpoint_condition,
    continue_training=True
)

optimizer = torch.optim.SGD([
    {
        'name': 'main_model',
        'params': model.parameters(),
        'lr': 0.1,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    }
])

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda epoch: 0.1 ** (epoch // 30)
)

trainer.fit(
    model, optimizer, scheduler,
    start_epoch=0, end_epoch=20
)

saving model to checkpoints/epoch_00000.pth
saving optimizer state to checkpoints/epoch_00000.optimizer.pth
saving scheduler state to checkpoints/epoch_00000.scheduler.pth
saving meta data to checkpoints/epoch_00000.meta.json




saving model to checkpoints/epoch_00013.pth
saving optimizer state to checkpoints/epoch_00013.optimizer.pth
saving scheduler state to checkpoints/epoch_00013.scheduler.pth
saving meta data to checkpoints/epoch_00013.meta.json


PosixPath('logs/model_20201102_09291604305770_207814')

In [37]:
str(123)

'123'

In [None]:
# !rm checkpoints/epoch_00001*
# !rm checkpoints/epoch_00002*
# !rm checkpoints/epoch_00003*
# # !rm checkpoints/epoch_00004*

# reloaded

In [None]:
# model = models.resnet18(pretrained=True)        
        
# checkpoint_condition = HistoryCondition(
#     'acc1', 
#     lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
# )
# checkpoint_saver = CheckpointSaver(path="checkpoints")
# trainer = Trainer(
#     device='cuda', checkpoint_saver=checkpoint_saver, checkpoint_condition=checkpoint_condition,
#     continue_training=True
# )

# optimizer = torch.optim.SGD([
#     {
#         'name': 'main_model',
#         'params': model.parameters(),
#         'lr': 0.1,
#         'momentum': 0.9,
#         'weight_decay': 1e-4,
#     }
# ])

# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer,
#     lambda epoch: 0.1 ** (epoch // 30)
# )

# trainer.fit(
#     model, optimizer, scheduler,
#     start_epoch=0, end_epoch=80
# )

In [None]:
batch[1]

In [None]:
# ish.show_image(batch[0][1])

In [None]:
next(iter(train_dataloader))