In [1]:
# !pip install git+https://github.com/serge-m/pytorch-nn-tools.git@v0.3.1

In [2]:
import os
from argparse import ArgumentParser, Namespace

import pytorch_lightning as pl
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateLogger
import numpy as np

from model_base import ModelBase, get_main_model

import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import defaultdict
from typing import Dict, List, Callable, Union
from pathlib import Path
import json
import time

from fastprogress.fastprogress import master_bar, progress_bar
from time import sleep

import pytorch_nn_tools as pnt
from pytorch_nn_tools.visual import ImgShow
from pytorch_nn_tools.train.metrics.processor import mod_name_train, mod_name_val, Marker
from pytorch_nn_tools.train.metrics.processor import MetricAggregator, MetricLogger
from pytorch_nn_tools.train.progress import ProgressTracker
from pytorch_nn_tools.convert import map_dict
from pytorch_nn_tools.train.metrics.history_condition import HistoryCondition
from pytorch_nn_tools.train.checkpoint import CheckpointSaver

In [3]:
# Image-display helper from pytorch_nn_tools, given the pyplot module as its `ax` target.
# NOTE(review): `ish` is not referenced in the visible cells — presumably kept for ad-hoc
# inspection of images/batches; remove it if unused.
ish = ImgShow(ax=plt)

In [4]:
# Standard ImageNet per-channel statistics (the values used with torchvision's
# ImageNet-pretrained models); shared by the train and val pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _imagenet_normalize():
    """Return a ``transforms.Normalize`` using the ImageNet mean/std above."""
    return transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)


def _train_dataset(path, crop_size=224):
    """Build the training ``ImageFolder`` rooted at ``<path>/train``.

    Pipeline: random resized crop to ``crop_size``, random horizontal flip,
    conversion to tensor, ImageNet normalisation.

    Args:
        path: dataset root containing a ``train`` subdirectory.
        crop_size: output side length of the random resized crop.
    """
    train_dir = os.path.join(path, 'train')
    return datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            _imagenet_normalize(),
        ]))


def _val_dataset(path, resize_size=256, crop_size=224):
    """Build the validation ``ImageFolder`` rooted at ``<path>/val``.

    Pipeline (deterministic): resize to ``resize_size``, center crop to
    ``crop_size``, conversion to tensor, ImageNet normalisation.

    Args:
        path: dataset root containing a ``val`` subdirectory.
        resize_size: size passed to ``transforms.Resize``.
        crop_size: output side length of the center crop.
    """
    val_dir = os.path.join(path, 'val')
    return datasets.ImageFolder(
        val_dir,
        transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            _imagenet_normalize(),
        ]))

In [6]:
# Loader settings: tiny batches for a quick CPU smoke run.
# Full-run values: batch_size_train = 128, batch_size_val = 16.
num_workers = 8
batch_size_train = 2
batch_size_val = 2

data_path = "data/imagewoof2-320-cut/"


def _make_loader(dataset, batch_size, shuffle, workers):
    """Wrap ``dataset`` in a ``DataLoader`` with the given batching options."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
    )


# Training order is reshuffled every epoch; validation order stays fixed.
train_dataloader = _make_loader(_train_dataset(data_path), batch_size_train, True, num_workers)
val_dataloader = _make_loader(_val_dataset(data_path), batch_size_val, False, num_workers)



In [7]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: 2-D scores of shape ``(batch, num_classes)``.
        target: class indices of shape ``(batch,)``.
        topk: the values of ``k`` to evaluate.

    Returns:
        A list with one 1-element float tensor per ``k``: the percentage of samples
        whose target is among the ``k`` highest-scoring classes. A ``k`` that is
        ``>= num_classes`` counts every sample as correct.
    """
    with torch.no_grad():
        # topk cannot request more entries than there are classes.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds each sample's i-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of `pred`, so
            # `.view(-1)` can fail; `.reshape(-1)` copies when necessary.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [8]:
from tqdm import tqdm
import datetime


In [9]:
def to_device(batch, device):
    """Return ``batch`` relocated to ``device`` by delegating to ``batch.to``."""
    relocated = batch.to(device)
    return relocated



In [10]:
from fastprogress.fastprogress import master_bar, progress_bar

In [11]:
def now_as_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS_ffffff``.

    Used as a unique, fixed-width suffix for each run's log directory.
    """
    # `%S` is seconds. The lowercase `%s` (Unix epoch seconds) is a non-portable
    # platform extension and is not supported everywhere (e.g. Windows).
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

class PBars:
    """Holds a fastprogress master bar and the child bar nested under it.

    ``main`` must be called before ``secondary`` or ``main_comment``.
    """

    def __init__(self):
        self._main = None    # current master_bar, set by main()
        self._second = None  # most recent child progress_bar, set by secondary()

    def main(self, it, **kwargs):
        """Wrap ``it`` in a new ``master_bar`` (kwargs forwarded) and return it."""
        self._main = master_bar(it, **kwargs)
        return self._main

    def secondary(self, it, **kwargs):
        """Wrap ``it`` in a ``progress_bar`` whose parent is the current master bar.

        Raises:
            RuntimeError: if ``main`` has not been called yet.
        """
        if self._main is None:
            raise RuntimeError("Cannot instantiate secondary progress bar. The main progress bar is not set.")
        self._second = progress_bar(it, parent=self._main, **kwargs)
        return self._second

    def main_comment(self, comment):
        """Set the comment text of the master bar's main bar.

        Raises:
            RuntimeError: if ``main`` has not been called yet.
        """
        if self._main is None:
            raise RuntimeError("Cannot set main progress bar comment. The main progress bar is not set.")
        self._main.main_bar.comment = comment

In [12]:
class Trainer:
    """Epoch-based training/validation loop with metric logging and checkpointing.

    Each epoch runs ``train_epoch`` then ``validate_epoch``, logs the averaged
    metrics (keys prefixed with ``avg.``), steps the LR scheduler once, and saves a
    checkpoint when ``checkpoint_condition`` accepts the averaged metrics.
    """

    def __init__(self, device, checkpoint_saver, checkpoint_condition,
                 continue_training: bool = False,
                 log_dir="./logs",
                 name="model"):
        """
        Args:
            device: device passed to ``.to()`` for the model and every batch.
            checkpoint_saver: object providing ``find_last``, ``load`` and ``save``
                (e.g. ``CheckpointSaver``).
            checkpoint_condition: callable receiving the averaged metric dict; a
                truthy result triggers ``checkpoint_saver.save``.
            continue_training: if True, ``fit`` resumes after the last checkpoint
                reported by ``checkpoint_saver.find_last``.
            log_dir: root directory; each ``fit`` call logs into a new
                ``{name}_{timestamp}`` subdirectory.
            name: prefix of that subdirectory.
        """
        self.device = device
        self.checkpoint_condition = checkpoint_condition
        self.checkpoint_saver = checkpoint_saver
        self.continue_training = continue_training
        self.log_dir = Path(log_dir)
        self.name = name
        self.pbars = PBars()

    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch,
            train_loader=None, val_loader=None):
        """Train ``model`` for epochs ``start_epoch .. end_epoch - 1``.

        Args:
            model: the ``nn.Module`` to train (moved to ``self.device``).
            optimizer: optimizer over ``model``'s parameters.
            scheduler: LR scheduler, stepped once per epoch.
            start_epoch: first epoch index (overridden when resuming).
            end_epoch: exclusive upper bound of epoch indices.
            train_loader: iterable of ``(images, target)`` training batches.
                Defaults to the notebook-level ``train_dataloader``.
            val_loader: iterable of ``(images, target)`` validation batches.
                Defaults to the notebook-level ``val_dataloader``.

        Returns:
            Path of the directory the metrics were logged to.
        """
        # Keep working with existing calls that rely on the notebook globals.
        if train_loader is None:
            train_loader = train_dataloader
        if val_loader is None:
            val_loader = val_dataloader

        path_logs = Path(self.log_dir).joinpath(f"{self.name}_{now_as_str()}")
        metric_logger = MetricLogger(path_logs)
        try:
            model = model.to(self.device)

            if self.continue_training:
                last = self.checkpoint_saver.find_last(start_epoch, end_epoch)
                if last is not None:
                    print(f"found pretrained results for epoch {last}. Loading...")
                    self.checkpoint_saver.load(model, optimizer, scheduler, last)
                    start_epoch = last + 1

            progr_train = ProgressTracker()

            for epoch in self.pbars.main(range(start_epoch, end_epoch)):
                # Fresh aggregator per epoch, so the averages cover this epoch only.
                metric_aggregator = MetricAggregator()
                self.train_epoch(
                    train_loader, progr_train,
                    model, optimizer, scheduler,
                    metric_proc=mod_name_train + metric_aggregator + metric_logger,
                    pbars=self.pbars,
                    report_step=10,
                )
                self.validate_epoch(
                    val_loader,
                    model,
                    metric_proc=mod_name_val + metric_aggregator + metric_logger,
                    pbars=self.pbars,
                )
                aggregated = map_dict(metric_aggregator.aggregate(), key_fn=lambda key: f"avg.{key}")
                metric_logger({**aggregated, Marker.EPOCH: epoch})
                self.pbars.main_comment(f"{aggregated}")

                # Stepped before saving, so the stored scheduler state matches
                # epoch + 1 — the epoch a resumed run starts from.
                scheduler.step()

                if self.checkpoint_condition(aggregated):
                    self.checkpoint_saver.save(model, optimizer, scheduler, epoch)
        finally:
            # Close exactly once, after all epochs — and also when training is
            # interrupted (e.g. KeyboardInterrupt) or raises.
            metric_logger.close()
        return path_logs

    def train_epoch(self, data_loader, progr, model, optimizer, scheduler, metric_proc, pbars, report_step=1):
        """Run one optimisation pass over ``data_loader``.

        Every ``report_step`` iterations (as counted by ``progr.cnt_total_iter``)
        the loss, top-1/top-5 accuracy and the iteration counter are sent to
        ``metric_proc``. ``scheduler`` is not used by this method; ``fit`` steps
        it once per epoch.
        """
        model.train()

        for images, target in progr.track(pbars.secondary(data_loader)):
            optimizer.zero_grad()

            images = to_device(images, self.device)
            target = to_device(target, self.device)

            output = model(images)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if progr.cnt_total_iter % report_step == 0:
                with torch.no_grad():
                    acc1, acc5 = accuracy(output, target, topk=(1, 5))
                # Detach so a consumer that stores the value does not also keep this
                # iteration's autograd graph alive.
                metric_proc({'loss': loss.detach(), 'acc1': acc1, 'acc5': acc5,
                             Marker.ITERATION: progr.cnt_total_iter})

    def validate_epoch(self, data_loader, model, metric_proc, pbars):
        """Evaluate ``model`` on ``data_loader`` without gradients.

        Sends loss and top-1/top-5 accuracy of every batch to ``metric_proc``.
        """
        model.eval()

        with torch.no_grad():
            for images, target in pbars.secondary(data_loader):
                images = to_device(images, self.device)
                target = to_device(target, self.device)

                output = model(images)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))

                metric_proc(dict(loss=loss, acc1=acc1, acc5=acc5))

In [13]:
# Start from a clean slate: delete epoch checkpoint files left by earlier runs.
# A no-op (not an error) when the directory or matching files do not exist.
# NOTE(review): this also discards what `continue_training=True` in the next cell
# would resume from — skip this cell to resume an interrupted run.
for stale_checkpoint in Path("checkpoints").glob("epo*"):
    if stale_checkpoint.is_file():
        stale_checkpoint.unlink()

In [15]:
# ImageNet-pretrained ResNet-18 as the starting point.
model = models.resnet18(pretrained=True)


def _is_new_best(hist):
    """True for the first entry, or when the newest value beats every earlier one."""
    if len(hist) == 1:
        return True
    return hist[-1] > max(hist[:-1])


# Save a checkpoint whenever the tracked 'avg.val.acc1' metric reaches a new best.
checkpoint_condition = HistoryCondition('avg.val.acc1', _is_new_best)
checkpoint_saver = CheckpointSaver(path="checkpoints")
trainer = Trainer(
    device='cpu',
    checkpoint_saver=checkpoint_saver,
    checkpoint_condition=checkpoint_condition,
    continue_training=True,
)

optimizer = torch.optim.SGD([
    dict(
        name='main_model',
        params=model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
    ),
])


def _lr_factor(epoch):
    """Multiplicative LR factor for LambdaLR: divide by 10 every 30 epochs."""
    return 0.1 ** (epoch // 30)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_factor)

trainer.fit(model, optimizer, scheduler, start_epoch=0, end_epoch=80)

saving model to checkpoints/epoch_00000.pth
saving optimizer state to checkpoints/epoch_00000.optimizer.pth
saving scheduler state to checkpoints/epoch_00000.scheduler.pth
saving meta data to checkpoints/epoch_00000.meta.json




saving model to checkpoints/epoch_00002.pth
saving optimizer state to checkpoints/epoch_00002.optimizer.pth
saving scheduler state to checkpoints/epoch_00002.scheduler.pth
saving meta data to checkpoints/epoch_00002.meta.json
saving model to checkpoints/epoch_00017.pth
saving optimizer state to checkpoints/epoch_00017.optimizer.pth
saving scheduler state to checkpoints/epoch_00017.scheduler.pth
saving meta data to checkpoints/epoch_00017.meta.json


KeyboardInterrupt: 