In [2]:
# Confirm which interpreter the kernel is running. Unlike `!which python`
# (first `python` on the shell PATH), `sys.executable` reports the kernel's
# own interpreter binary.
import sys

sys.executable

/home/s/work/nn/.venv/bin/python


In [52]:
import os
from argparse import ArgumentParser, Namespace

import pytorch_lightning as pl
import torch
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateLogger

from model_base import ModelBase, get_main_model
import pytorch_nn_tools as pnt
from pytorch_nn_tools.visual import ImgShow
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [28]:
# Image-display helper from pytorch_nn_tools. `ax=plt` hands it the pyplot
# module itself as the drawing target (presumably it then draws on the current
# figure — TODO confirm against pytorch_nn_tools.visual.ImgShow).
ish = ImgShow(ax=plt)

In [11]:
def _train_dataset(path):
    """Build the augmented training split as a ``datasets.ImageFolder``.

    Images are read from ``<path>/train``. Each sample goes through
    ``RandomResizedCrop(224)`` -> ``RandomHorizontalFlip()`` -> ``ToTensor()``
    -> per-channel ``Normalize`` (mean ``[0.485, 0.456, 0.406]``,
    std ``[0.229, 0.224, 0.225]`` — presumably the ImageNet channel statistics
    matching the pretrained backbone).
    """
    pipeline = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return datasets.ImageFolder(
        os.path.join(path, 'train'),
        transforms.Compose(pipeline),
    )


def _val_dataset(path):
    """Build the validation split as a ``datasets.ImageFolder``.

    Images are read from ``<path>/val`` with deterministic preprocessing:
    ``Resize(256)`` -> ``CenterCrop(224)`` -> ``ToTensor()`` -> per-channel
    ``Normalize`` (mean ``[0.485, 0.456, 0.406]``, std ``[0.229, 0.224, 0.225]``).
    """
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_root = os.path.join(path, 'val')
    return datasets.ImageFolder(val_root, val_transform)

In [4]:
num_workers = 16        # DataLoader worker processes, shared by both splits
batch_size_train = 16
batch_size_val = 4

# Root directory; the dataset builders read its `train/` and `val/` subfolders.
data_path = "data/imagewoof2-320/"

_common_loader_kwargs = dict(num_workers=num_workers)

train_dataloader = torch.utils.data.DataLoader(
    _train_dataset(data_path),
    batch_size=batch_size_train,
    shuffle=True,   # randomise sample order for training
    **_common_loader_kwargs,
)

val_dataloader = torch.utils.data.DataLoader(
    _val_dataset(data_path),
    batch_size=batch_size_val,
    shuffle=False,  # keep a fixed order for evaluation
    **_common_loader_kwargs,
)



In [53]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each ``k`` in ``topk``.

    Args:
        output: Score tensor of shape ``(batch, num_classes)``; the top-k
            search runs along dim 1.
        target: Class-index tensor with ``batch`` elements (compared against
            the predicted class indices).
        topk: Values of ``k`` to evaluate; ``max(topk)`` must not exceed the
            size of dim 1 of ``output``.

    Returns:
        A list with one 1-element float tensor per ``k`` (same order as
        ``topk``): the percentage (0-100) of samples whose target is among the
        ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # (batch, maxk) -> (maxk, batch): row i holds every sample's i-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` stems from a transposed tensor and can be non-contiguous,
            # in which case `.view(-1)` raises for k > 1 on recent PyTorch;
            # `.reshape(-1)` copies only when required.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [56]:
from tqdm import tqdm

In [57]:
# Pretrained ResNet-18 backbone used as the model to train below.
# NOTE(review): the pretrained classifier head (`model.fc`) is kept as-is, so the
# output size is whatever the checkpoint was trained with — presumably larger than
# this dataset's class count (sample labels in this notebook are 0-9). Consider
# replacing `model.fc` with a head sized to the dataset.
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of
# `weights=`; keep as-is for the torchvision version this notebook targets.
model = models.resnet18(pretrained=True)

def to_device(batch, device):
    """Move ``batch`` (or every movable object nested inside it) to ``device``.

    Args:
        batch: Anything with a ``.to`` method (e.g. a tensor), or a list,
            tuple, namedtuple or dict containing such objects — for example
            the ``[images, classes]`` pair a DataLoader yields. Containers are
            traversed recursively.
        device: Forwarded unchanged to ``.to`` (device, device string, ...).

    Returns:
        The moved object. Lists/tuples/namedtuples keep their type, dicts are
        rebuilt as plain ``dict``; values without ``.to`` are returned as-is.
    """
    if hasattr(batch, 'to'):
        return batch.to(device)
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        # namedtuple: its constructor takes fields positionally, not an iterable
        return type(batch)(*(to_device(item, device) for item in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(item, device) for item in batch)
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}
    return batch

class Trainer:
    """Minimal epoch-based training loop for a classification model.

    Args:
        device: Device (e.g. ``'cpu'`` or ``'cuda'``) that the model and every
            batch (inputs and targets) are moved to.
        train_loader: Iterable of ``(images, classes)`` batches used by
            ``train_epoch``. When omitted, the notebook-level
            ``train_dataloader`` is looked up at call time (original behaviour).
        val_loader: Optional iterable of ``(images, classes)`` batches used by
            ``validate_epoch``. When omitted, validation is skipped.
    """

    def __init__(self, device, train_loader=None, val_loader=None):
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader

    def fit(self, model, optimizer, scheduler, start_epoch, end_epoch):
        """Run epochs ``start_epoch`` .. ``end_epoch - 1``.

        Each epoch trains, then validates, then calls ``scheduler.step()``
        exactly once (so the scheduler advances per epoch, not per batch).
        """
        model = model.to(self.device)
        for _epoch in range(start_epoch, end_epoch):
            self.train_epoch(model, optimizer, scheduler)
            self.validate_epoch(model, optimizer, scheduler)
            scheduler.step()

    def train_epoch(self, model, optimizer, scheduler):
        """One optimisation pass over the training batches (cross-entropy loss).

        ``scheduler`` is not used here; ``fit`` steps it once per epoch.
        """
        model.train()
        loader = self.train_loader if self.train_loader is not None else train_dataloader
        for images, classes in self._progress_bar(loader):
            # Inputs AND targets must be on the model's device; leaving the
            # targets behind makes cross_entropy fail whenever device != 'cpu'.
            images = images.to(self.device)
            classes = classes.to(self.device)
            optimizer.zero_grad()
            output = model(images)
            loss = F.cross_entropy(output, classes)
            loss.backward()
            optimizer.step()

    def validate_epoch(self, model, optimizer, scheduler):
        """Evaluate ``model`` on ``self.val_loader``.

        Returns ``None`` (no-op) when no validation loader was given or it
        yielded no samples. Otherwise returns ``{'loss': mean cross-entropy,
        'acc1': top-1 accuracy in percent}``. ``optimizer`` and ``scheduler``
        are unused; they are accepted to mirror ``train_epoch``'s signature.
        Leaves the model in eval mode; ``train_epoch`` switches it back.
        """
        if self.val_loader is None:
            return None
        model.eval()
        total_loss, total_correct, total_seen = 0.0, 0, 0
        with torch.no_grad():
            for images, classes in self._progress_bar(self.val_loader):
                images = images.to(self.device)
                classes = classes.to(self.device)
                output = model(images)
                # Sum per-sample losses so the final mean is exact even when
                # the last batch is smaller than the others.
                total_loss += F.cross_entropy(output, classes, reduction='sum').item()
                total_correct += (output.argmax(dim=1) == classes).sum().item()
                total_seen += classes.size(0)
        if total_seen == 0:
            return None
        return {
            'loss': total_loss / total_seen,
            'acc1': 100.0 * total_correct / total_seen,
        }

    def _progress_bar(self, it):
        """Wrap ``it`` in a tqdm progress bar."""
        return tqdm(it)
            
# Optimisation hyper-parameters for this run.
BASE_LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
LR_DECAY_FACTOR = 0.1
# NOTE(review): with END_EPOCH = 10 the first decay (epoch 30) is never reached,
# so the learning rate stays at BASE_LR for the whole run.
LR_DECAY_EVERY = 30
START_EPOCH, END_EPOCH = 0, 10

trainer = Trainer(device='cpu')

main_param_group = dict(
    name='main_model',
    params=model.parameters(),
    lr=BASE_LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
optimizer = torch.optim.SGD([main_param_group])


def _step_decay(epoch):
    """LR multiplier: LR_DECAY_FACTOR applied once per LR_DECAY_EVERY epochs."""
    return LR_DECAY_FACTOR ** (epoch // LR_DECAY_EVERY)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _step_decay)

trainer.fit(
    model, optimizer, scheduler,
    start_epoch=START_EPOCH, end_epoch=END_EPOCH,
)

  3%|▎         | 17/565 [00:22<11:52,  1.30s/it]


KeyboardInterrupt: 

torch.Size([16, 3, 224, 224])

In [51]:
# Pull one (images, classes) batch from the training loader so this cell
# works on a fresh kernel (`batch` is not defined anywhere else), then show
# its class labels.
batch = next(iter(train_dataloader))
batch[1]

tensor([8, 7, 8, 7, 6, 7, 1, 9, 1, 5, 4, 1, 4, 2, 9, 4])

In [58]:
# ish.show_image(batch[0][1])