In [1]:
import os
from argparse import ArgumentParser, Namespace

import pytorch_lightning as pl
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateLogger
import numpy as np

from model_base import ModelBase, get_main_model
import pytorch_nn_tools as pnt
from pytorch_nn_tools.visual import ImgShow
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json

In [2]:
ish = ImgShow(ax=plt)

In [3]:
def _train_dataset(path):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    train_dir = os.path.join(path, 'train')
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_dataset


def _val_dataset(path):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    val_dir = os.path.join(path, 'val')
    dataset = datasets.ImageFolder(val_dir, transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ]))
    return dataset

In [4]:
num_workers = 8

# batch_size_train = 2
# batch_size_val = 2
batch_size_train = 128
batch_size_val = 16


data_path = "data/imagewoof2-320/"

train_dataloader = torch.utils.data.DataLoader(
        dataset=_train_dataset(data_path),
        batch_size=batch_size_train,
        shuffle=True, 
        num_workers=num_workers,
    )
    
val_dataloader = torch.utils.data.DataLoader(
        dataset=_val_dataset(data_path),
        batch_size=batch_size_val,
        shuffle=False,
        num_workers=num_workers,
    )



In [5]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [6]:
from tqdm import tqdm

In [7]:
def to_device(batch, device):
    return batch.to(device)

Metrics = dict

class MetricBuilder:
    def __init__(self, required=()):
        self._required = required
        self._metrics = defaultdict(list)
        
    def add(self, metrics_dict, preproc_fn=lambda x: x.detach().item()):
        for r in self._required:
            assert r in metrics_dict, f"Metric {r} is required, but not provided"
            
        for k, v in metrics_dict.items():
            self._metrics[k].append(preproc_fn(v))
            
    def build(self) -> Metrics:
        return {
            k: sum(vs) / (len(vs) if vs else 1.)
            for k, vs in self._metrics.items()
        }
    

    
class HistoryCondition:
    def __init__(self, metric_name: Metrics, history_condition: Callable, history=()):
        self.metric_name = metric_name
        self.history = list(history[:])
        self.condition = history_condition
        
    def __call__(self, metrics: Metrics):
        self.history.append(metrics[self.metric_name])
        result = self.condition(self.history[:])
#         print(f'Condition {result} on {self.history}')
        return result
            
            
class CheckpointSaver:
    def __init__(self, path: Union[Path, str]):
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)
        
    
    def save(self, model, optimizer, scheduler, epoch):
        path = self.path.joinpath(f"epoch_{epoch:05d}.pth")
        print(f"saving model to {path}")
        torch.save(model.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.optimizer.pth")
        print(f"saving optimizer state to {path}")
        torch.save(optimizer.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.scheduler.pth")
        print(f"saving scheduler state to {path}")
        torch.save(scheduler.state_dict(), path)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.meta.json")
        print(f"saving meta data to {path}")
        with path.open("w") as f:
            json.dump({'epoch': epoch}, f)
            
    def load(self, model, optimizer, scheduler, epoch):
        path = self.path.joinpath(f"epoch_{epoch:05d}.pth")
        print(f"loading model from {path}")
        model_dict = model.state_dict()
        pretrained_dict = torch.load(path)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        
        path = self.path.joinpath(f"epoch_{epoch:05d}.optimizer.pth")
        if path.exists():
            print(f"loading optimizer state from {path}")
            optimizer_dict = torch.load(path)
            optimizer.load_state_dict(optimizer_dict)
        else:
            print("optimizer state not found")
            
        path = self.path.joinpath(f"epoch_{epoch:05d}.scheduler.pth")
        if path.exists():
            print(f"loading scheduler state from {path}")
            scheduler_dict = torch.load(path)
            scheduler.load_state_dict(scheduler_dict)
        else:
            print("scheduler state not found")

    def find_last(self, start_epoch, end_epoch):
        for epoch in range(end_epoch, start_epoch-1, -1):
            path = self.path.joinpath(f"epoch_{epoch:05d}.meta.json")
            if path.exists():
                return epoch
        return None

In [8]:
    
class Trainer:
    def __init__(self, device, checkpoint_saver, checkpoint_condition,
                continue_training: bool = False):
        self.device = device
        self.checkpoint_condition = checkpoint_condition
        self.checkpoint_saver = checkpoint_saver
        self.continue_training = continue_training
        
    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch):
        model = model.to(self.device)
        if self.continue_training:
            last = self.checkpoint_saver.find_last(start_epoch, end_epoch)
            if last is not None:
                print(f"found pretrained results for epoch {last}. Loading...")
                self.checkpoint_saver.load(model, optimizer, scheduler, last)
                start_epoch = last + 1

                
        for epoch in range(start_epoch, end_epoch):
            metrics_train = self.train_epoch(model, optimizer, scheduler)
            self.log(metrics_train)
            
            metrics_val = self.validate_epoch(model)
            self.log(metrics_val)
            
            scheduler.step()
                        
            if self.checkpoint_condition(metrics_val): 
                self.checkpoint_saver.save(model, optimizer, scheduler, epoch)
            
            
            
    def train_epoch(self, model, optimizer, scheduler):
        model.train()
        metrics = MetricBuilder()
        
        for images, target in self._progress_bar(train_dataloader):
            optimizer.zero_grad()
            
            images = to_device(images, self.device)
            target = to_device(target, self.device)
            
            output = model(images)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            metrics.add(dict(loss=loss, acc1=acc1, acc5=acc5))
            
        return metrics.build()
        
        
    def validate_epoch(self, model):
        model.eval()
        metrics = MetricBuilder()
        
        with torch.no_grad():
            for images, target in self._progress_bar(val_dataloader):
                images = to_device(images, self.device)
                target = to_device(target, self.device)

                output = model(images)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                metrics.add(dict(loss=loss, acc1=acc1, acc5=acc5))
                
        return metrics.build()
        
    
    def _progress_bar(self, it):
        return tqdm(it)
    
    def log(self, data):
        print(data)

In [9]:
!rm checkpoints/epo*

rm: cannot remove 'checkpoints/epo*': No such file or directory


In [None]:
model = models.resnet18(pretrained=True)        
        
checkpoint_condition = HistoryCondition(
    'acc1', 
    lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
)
checkpoint_saver = CheckpointSaver(path="checkpoints")
trainer = Trainer(
    device='cuda', checkpoint_saver=checkpoint_saver, checkpoint_condition=checkpoint_condition,
    continue_training=True
)

optimizer = torch.optim.SGD([
    {
        'name': 'main_model',
        'params': model.parameters(),
        'lr': 0.1,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    }
])

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda epoch: 0.1 ** (epoch // 30)
)

trainer.fit(
    model, optimizer, scheduler,
    start_epoch=0, end_epoch=80
)

 13%|█▎        | 9/71 [00:03<00:21,  2.90it/s]

In [None]:
# !rm checkpoints/epoch_00001*
# !rm checkpoints/epoch_00002*
# !rm checkpoints/epoch_00003*
# # !rm checkpoints/epoch_00004*

# reloaded

In [None]:
model = models.resnet18(pretrained=True)        
        
checkpoint_condition = HistoryCondition(
    'acc1', 
    lambda hist: len(hist) == 1 or hist[-1] > max(hist[:-1])
)
checkpoint_saver = CheckpointSaver(path="checkpoints")
trainer = Trainer(
    device='cuda', checkpoint_saver=checkpoint_saver, checkpoint_condition=checkpoint_condition,
    continue_training=True
)

optimizer = torch.optim.SGD([
    {
        'name': 'main_model',
        'params': model.parameters(),
        'lr': 0.1,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    }
])

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda epoch: 0.1 ** (epoch // 30)
)

trainer.fit(
    model, optimizer, scheduler,
    start_epoch=0, end_epoch=80
)

In [None]:
# model = models.resnet18(pretrained=True)        
        
# trainer = Trainer(device='cuda')

# optimizer = torch.optim.SGD([
#     {
#         'name': 'main_model',
#         'params': model.parameters(),
#         'lr': 0.1,
#         'momentum': 0.9,
#         'weight_decay': 1e-4,
#     }
# ])

# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer,
#     lambda epoch: 0.1 ** (epoch // 30)
# )

# trainer.fit(
#     model, optimizer, scheduler,
#     start_epoch=0, end_epoch=80
# )

In [None]:
batch[1]

In [None]:
# ish.show_image(batch[0][1])

In [None]:
next(iter(train_dataloader))