In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torchvision
import torch
from torch import nn
from torchvision import transforms
from time import time
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import time
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from IPython.display import clear_output
import argparse
import sys
from torch.optim.lr_scheduler import StepLR

In [1]:
class ValDataset(Dataset):
    """Tiny-ImageNet style validation split described by an annotation file.

    Each row of ``csv_file`` (tab-separated, no header) supplies the image
    file name in column 0 (joined onto ``root``) and its class id in column 1.
    Integer labels are the position of the class id in the sorted list of ids
    read from ``wnids_file``.

    Args:
        csv_file: path of the tab-separated annotation file.
        root: directory containing the image files.
        transform: optional callable applied to the HxWx3 uint8 numpy image.
        wnids_file: file with one class id per line (first column is used).
    """

    def __init__(self, csv_file, root, transform=None, wnids_file='tiny-imagenet-200/wnids.txt'):
        self.info = pd.read_csv(csv_file, sep='\t', header=None)
        self.root = root
        self.transform = transform
        class_ids = pd.read_csv(wnids_file, sep='\t', header=None)[0].sort_values()
        # class id -> class index; the number of classes comes from the file
        # instead of a hard-coded 200.
        self.classes_dict = {wnid: idx for idx, wnid in enumerate(class_ids)}

    def __len__(self):
        return len(self.info)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root, self.info.iloc[idx, 0])
        # The context manager closes the underlying file handle; convert()
        # returns an independent, fully loaded image.
        with Image.open(img_name) as img:
            image = np.asarray(img.convert('RGB'))
        target = self.classes_dict[self.info.iloc[idx, 1]]

        if self.transform is not None:
            image = self.transform(image)

        return image, target


class Flatten(nn.Module):
    """Collapse every dimension after the first one: (N, ...) -> (N, prod(...))."""

    def forward(self, input):
        # reshape (unlike view) also works on non-contiguous tensors, e.g. the
        # output of transpose/permute; it only copies when it has to.
        return input.reshape(input.size(0), -1)

class Accuracy(torch.nn.Module):
    """Top-1 accuracy: share of rows whose highest-scoring column equals the target.

    ``forward(outputs, targets)`` takes scores of shape (N, C) and integer
    targets of shape (N,), and returns a 0-dim float64 tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        predicted = outputs.max(dim=1).indices
        correct = (predicted == targets).double()
        return correct.mean()

NameError: name 'Dataset' is not defined

In [None]:
class ValDataset(Dataset):
    """Tiny-ImageNet style validation split described by an annotation file.

    Each row of ``csv_file`` (tab-separated, no header) supplies the image
    file name in column 0 (joined onto ``root``) and its class id in column 1.
    Integer labels are the position of the class id in the sorted list of ids
    read from ``wnids_file``.

    Args:
        csv_file: path of the tab-separated annotation file.
        root: directory containing the image files.
        transform: optional callable applied to the HxWx3 uint8 numpy image.
        wnids_file: file with one class id per line (first column is used).
    """

    def __init__(self, csv_file, root, transform=None, wnids_file='tiny-imagenet-200/wnids.txt'):
        self.info = pd.read_csv(csv_file, sep='\t', header=None)
        self.root = root
        self.transform = transform
        class_ids = pd.read_csv(wnids_file, sep='\t', header=None)[0].sort_values()
        # class id -> class index; the number of classes comes from the file
        # instead of a hard-coded 200.
        self.classes_dict = {wnid: idx for idx, wnid in enumerate(class_ids)}

    def __len__(self):
        return len(self.info)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root, self.info.iloc[idx, 0])
        # The context manager closes the underlying file handle; convert()
        # returns an independent, fully loaded image.
        with Image.open(img_name) as img:
            image = np.asarray(img.convert('RGB'))
        target = self.classes_dict[self.info.iloc[idx, 1]]

        if self.transform is not None:
            image = self.transform(image)

        return image, target


class Flatten(nn.Module):
    """Collapse every dimension after the first one: (N, ...) -> (N, prod(...))."""

    def forward(self, input):
        # reshape (unlike view) also works on non-contiguous tensors, e.g. the
        # output of transpose/permute; it only copies when it has to.
        return input.reshape(input.size(0), -1)

class Accuracy(torch.nn.Module):
    """Top-1 accuracy: share of rows whose highest-scoring column equals the target.

    ``forward(outputs, targets)`` takes scores of shape (N, C) and integer
    targets of shape (N,), and returns a 0-dim float64 tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        predicted = outputs.max(dim=1).indices
        correct = (predicted == targets).double()
        return correct.mean()



def train_epoch(model, dataloader, criterion, metric, device, optimizer, eff_batch_size=512, checkpoint_seg=3):
    """Run one training epoch with gradient accumulation and optional activation checkpointing.

    Gradients of ``eff_batch_size // dataloader.batch_size`` consecutive
    mini-batches (at least 1) are averaged before each ``optimizer.step()``,
    approximating one update on a batch of ``eff_batch_size`` samples. A
    trailing group shorter than that is still applied at the end of the epoch.

    Args:
        model: module to train. When ``checkpoint_seg > 0`` it is passed to
            ``torch.utils.checkpoint.checkpoint_sequential`` and so must be an
            ``nn.Sequential`` (or a list of modules).
        dataloader: yields ``(inputs, targets)``; its ``batch_size`` must be set.
        criterion: ``criterion(outputs, targets)``; assumed to average over the
            batch (as ``nn.CrossEntropyLoss()`` does by default).
        metric: ``metric(outputs, targets)`` returning a batch-mean score.
        device: device the batches are moved to.
        optimizer: optimizer over the model's parameters.
        eff_batch_size: target number of samples per optimizer step.
        checkpoint_seg: number of checkpoint segments; 0 disables checkpointing.

    Returns:
        ``(mean_loss, mean_metric)`` weighted by samples over the whole epoch,
        or ``(nan, nan)`` if the loader yields nothing.
    """
    # Local import: the name was previously used without ever being imported.
    from torch.utils.checkpoint import checkpoint_sequential

    n_batches = len(dataloader)
    # Integer group size. A float ratio such as 512 / 48 never satisfies
    # ``% == 0``, which meant the optimizer never stepped.
    batches_per_update = max(1, eff_batch_size // dataloader.batch_size)

    total_loss = 0.0
    total_metric = 0.0
    n_samples = 0

    model.train(True)
    # Discard gradients left over from earlier calls before accumulating.
    optimizer.zero_grad()
    for i_batch, (X_batch, y_batch) in enumerate(dataloader):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        if checkpoint_seg == 0:
            out = model(X_batch)
        else:
            # Reentrant checkpointing needs an input that requires grad,
            # otherwise no gradient reaches the checkpointed segments.
            X_batch.requires_grad = True
            out = checkpoint_sequential(model, checkpoint_seg, X_batch)

        loss = criterion(out, y_batch)

        # Divide by the size of the current accumulation group so the applied
        # gradient is the group mean rather than the sum. The last group may
        # be shorter.
        group_start = (i_batch // batches_per_update) * batches_per_update
        group_size = min(batches_per_update, n_batches - group_start)
        (loss / group_size).backward()

        batch_len = y_batch.size(0)
        total_loss += loss.item() * batch_len
        total_metric += float(metric(out, y_batch)) * batch_len
        n_samples += batch_len

        # Step at each group boundary and on the final batch, so no gradients
        # are dropped or carried into the next epoch.
        if (i_batch + 1) % batches_per_update == 0 or (i_batch + 1) == n_batches:
            optimizer.step()
            optimizer.zero_grad()

    if n_samples == 0:
        return float('nan'), float('nan')
    return total_loss / n_samples, total_metric / n_samples


@torch.no_grad()
def eval_model(model, dataloader, criterion, metric, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to eval mode (left that way).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: ``criterion(outputs, targets)``; assumed to average over the
            batch (as ``nn.CrossEntropyLoss()`` does by default).
        metric: ``metric(outputs, targets)`` returning a batch-mean score.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, mean_metric)`` as Python floats, weighted by the number
        of samples in each batch so a smaller last batch does not bias the
        result; ``(nan, nan)`` if the loader yields nothing.
    """
    total_loss = 0.0
    total_metric = 0.0
    n_samples = 0

    model.eval()
    for X_batch, y_batch in dataloader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        out = model(X_batch)
        batch_len = y_batch.size(0)
        # float() keeps the running total a Python number instead of
        # accumulating a device tensor.
        total_loss += float(criterion(out, y_batch)) * batch_len
        total_metric += float(metric(out, y_batch)) * batch_len
        n_samples += batch_len

    if n_samples == 0:
        return float('nan'), float('nan')
    return total_loss / n_samples, total_metric / n_samples


def train_model(model, dataloaders, optimizer,
                criterion=None,
                metric=None,
                device=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu'),
                epochs=50,
                max_acc=0.4,
                eff_batch_size=128,
                checkpoint_seg=3,
                scheduler=None,
                plot_fn=None):
    """Train ``model`` for up to ``epochs`` epochs, evaluating on val and test after each one.

    Args:
        model: module to train; moved to ``device``.
        dataloaders: dict with ``'train'``, ``'val'`` and ``'test'`` loaders.
        optimizer: optimizer over the model's parameters.
        criterion: loss function; defaults to a fresh ``nn.CrossEntropyLoss()``.
        metric: ``metric(outputs, targets)`` score; defaults to a fresh ``Accuracy()``.
        device: device for the model and batches.
        epochs: maximum number of epochs.
        max_acc: training stops early once test accuracy exceeds this value.
        eff_batch_size: effective batch size forwarded to ``train_epoch``.
        checkpoint_seg: checkpoint segments forwarded to ``train_epoch`` (0 disables).
        scheduler: optional LR scheduler, stepped once after each training epoch.
        plot_fn: optional callable given the accumulated log after every epoch.
            When omitted, a global ``draw_accuracy`` is used if one is defined.

    Returns:
        ``(model, val_acc, test_acc)`` from the last completed epoch. The
        accuracies are ``None`` if no epoch ran.
    """
    # Build default modules per call instead of sharing instances created at definition time.
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if metric is None:
        metric = Accuracy()
    if plot_fn is None:
        # draw_accuracy is not defined in this part of the notebook. Use it when
        # available instead of failing with NameError on the first epoch.
        plot_fn = globals().get('draw_accuracy')

    start = time.time()

    model = model.to(device)
    log_acc = []
    # Initialised up front so epochs <= 0 returns instead of raising UnboundLocalError.
    val_acc = None
    test_acc = None
    for epoch in range(epochs):
        start_time = time.time()

        train_start = time.time()
        loss, train_acc = train_epoch(model, dataloaders['train'], criterion, metric, device,
                                      optimizer, eff_batch_size, checkpoint_seg)
        train_time = time.time() - train_start

        if scheduler is not None:
            scheduler.step()

        val_loss, val_acc = eval_model(model, dataloaders['val'], criterion, metric, device)
        test_loss, test_acc = eval_model(model, dataloaders['test'], criterion, metric, device)

        # Each entry: (train_acc, val_acc, test_acc, train_loss, val_loss, test_loss)
        log_acc.append((train_acc, val_acc, test_acc, loss, val_loss, test_loss))

        if plot_fn is not None:
            plot_fn(log_acc)

        print("Epoch [{}/{}] Time: {:.2f}s; BF Time: {:.2f}s; TrainLoss: {:.4f}; TrainAccuracy: {:.4f}; ValAccuracy: {:.4f}, TestAccuracy: {:.4f}".format(
              epoch + 1, epochs, time.time() - start_time, train_time, loss, train_acc, val_acc, test_acc))

        # NOTE(review): early stopping is driven by the *test* split; selecting
        # on val_acc would avoid tuning against test data.
        if test_acc > max_acc:
            break

    print("Full_time: {:.2f}s".format(time.time() - start))
    if torch.cuda.is_available():
        print(f"Peak memory usage by Pytorch tensors: {(torch.cuda.max_memory_allocated() / 1024 / 1024):.2f} Mb")

    return model, val_acc, test_acc


def block(cin, cout, kernel_size=(3,3), padding=(1,1), stride=(1,1), pool_size=(2,2)):
    """Build one convolutional stage: Conv2d -> BatchNorm2d -> LeakyReLU -> MaxPool2d.

    Args:
        cin: number of input channels.
        cout: number of output channels (also the BatchNorm feature count).
        kernel_size, padding, stride: forwarded to ``nn.Conv2d``.
        pool_size: kernel size of the trailing ``nn.MaxPool2d``.

    Returns:
        ``nn.Sequential`` whose submodules are indexed 0..3 in the order above.
    """
    conv = nn.Conv2d(cin, cout, kernel_size, stride=stride, padding=padding)
    layers = [
        conv,
        nn.BatchNorm2d(cout),
        nn.LeakyReLU(),
        nn.MaxPool2d(pool_size),
    ]
    return nn.Sequential(*layers)

def get_model(cin=3, cout=200, base=64, drop=0.2, conv_drop=0.2, input_size=64):
    """Build a four-stage CNN with a two-layer fully connected classifier head.

    Each ``block`` keeps the spatial size through its 3x3/pad-1 convolution and
    then halves it with a 2x2 max-pool, so the feature map entering the head
    is ``input_size // 16`` pixels on a side.

    Args:
        cin: input channels.
        cout: number of output classes.
        base: channel width of the first stage (doubled at each later stage).
        drop: dropout probability in the classifier head.
        conv_drop: dropout probability between the 2nd and 3rd conv stages
            (previously hard-coded to 0.2).
        input_size: side length of the square input images (previously
            implied to be 64 by the hard-coded ``4*4`` head size).

    Returns:
        ``nn.Sequential`` with the same module layout (and state_dict keys) as before.

    Raises:
        ValueError: if ``input_size`` is too small to survive four 2x poolings.
    """
    spatial = input_size // 16
    if spatial < 1:
        raise ValueError(f"input_size must be at least 16, got {input_size}")

    return torch.nn.Sequential(
        block(cin=cin,
              cout=base),
        block(cin=base,
              cout=base*2),
        nn.Dropout(conv_drop),
        block(cin=base*2,
              cout=base*4),
        block(cin=base*4,
              cout=base*8),
        nn.Flatten(),
        nn.Dropout(drop),
        nn.Linear(in_features=base*8*spatial*spatial,
                  out_features=base*16),
        nn.ReLU(),
        nn.Dropout(drop),
        nn.Linear(in_features=base*16,
                  out_features=cout)
    )


def get_data(batch_size=64, val_fraction=0.2, seed=None):
    """Build train/val/test DataLoaders for Tiny-ImageNet-200 under ``tiny-imagenet-200/``.

    The labelled ``train`` folder is split into a training part, which uses
    random crop/flip/affine augmentation, and a validation part, which uses
    only the deterministic test-time transform. The official ``val`` folder,
    read through ``ValDataset``, serves as the test set.

    Args:
        batch_size: mini-batch size of every loader.
        val_fraction: share of the ``train`` folder held out for validation.
            The default 0.2 reproduces the previous fixed 80000/20000 split of
            the 100000 training images. Must lie strictly between 0 and 1.
        seed: optional seed making the train/val split reproducible. ``None``
            draws from the global torch RNG as before.

    Returns:
        dict with ``'train'``, ``'val'`` and ``'test'`` DataLoaders.

    Raises:
        ValueError: if ``val_fraction`` is not in (0, 1).
    """
    if not 0 < val_fraction < 1:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    # NOTE(review): these mean/std values look like the commonly used CIFAR-10
    # statistics rather than Tiny-ImageNet ones; kept unchanged to preserve behaviour.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(15),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_root = 'tiny-imagenet-200/train'
    # Two views over the same folder: random_split on a single ImageFolder
    # would make the validation images go through the training augmentation.
    train_view = torchvision.datasets.ImageFolder(train_root, transform=transform_train)
    eval_view = torchvision.datasets.ImageFolder(train_root, transform=transform_test)

    n_total = len(train_view)
    n_val = int(round(n_total * val_fraction))
    n_train = n_total - n_val
    generator = torch.Generator().manual_seed(seed) if seed is not None else torch.default_generator
    perm = torch.randperm(n_total, generator=generator).tolist()
    train_dataset = torch.utils.data.Subset(train_view, perm[:n_train])
    val_dataset = torch.utils.data.Subset(eval_view, perm[n_train:])

    test_dataset = ValDataset(csv_file='tiny-imagenet-200/val/val_annotations.txt',
                              root='tiny-imagenet-200/val/images',
                              transform=transform_test)

    # Only the training loader needs shuffling; evaluation does not need a random order.
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8),
        'val'  : DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4),
        'test' : DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=4)
    }

    return dataloaders


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` is wrong for this: ``bool('False')`` is True
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def create_parser():
    """Build the command-line parser for the training script.

    Returns:
        ``argparse.ArgumentParser`` whose parsed namespace matches the keyword
        arguments of ``main``.
    """
    pars = argparse.ArgumentParser()
    pars.add_argument('-batch_size',
                      help="choose batch size", type=int, default=64)
    pars.add_argument('-eff_batch_size',
                      help="choose effective batch size", type=int, default=512)
    pars.add_argument('-drop',
                      help="dropout in model ", type=float, default=0.2)
    pars.add_argument('-base',
                      help="base in model ", type=int, default=64)
    pars.add_argument('-checkpoint_count',
                      help="count of checkpoints", type=int, default=10)
    pars.add_argument('-epoch',
                      help="choose number of epoch", type=int, default=50)
    pars.add_argument('-max_acc',
                      help="choose what test accuracy is enought", type=float, default=0.4)
    pars.add_argument('-lr',
                      help="adam learning rate", type=float, default=0.001)
    # _str2bool so that "-save False" actually disables saving.
    pars.add_argument('-save',
                      help="save model", type=_str2bool, default=True)

    return pars

def main(batch_size, lr, epoch, max_acc, eff_batch_size, checkpoint_count, drop, base, save,
         save_path='model_state.pt'):
    """Build the model and data, train with Adam, and optionally save the weights.

    Args:
        batch_size: DataLoader mini-batch size.
        lr: Adam learning rate.
        epoch: maximum number of training epochs.
        max_acc: test accuracy at which training stops early.
        eff_batch_size: effective batch size for gradient accumulation.
        checkpoint_count: number of activation-checkpoint segments (0 disables).
        drop: dropout probability used by ``get_model``.
        base: base channel width used by ``get_model``.
        save: whether to write the trained ``state_dict`` to ``save_path``.
        save_path: destination file for the saved weights.

    Returns:
        ``(trained_model, val_accuracy, test_accuracy)`` as returned by ``train_model``.
    """
    model = get_model(base=base, drop=drop)

    dataloaders = get_data(batch_size=batch_size)

    # Parameters of a freshly constructed model already require grad, so no
    # explicit loop is needed. The StepLR scheduler previously built here was
    # never stepped or passed anywhere, so it is omitted.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trained_model, val_accuracy, test_accuracy = train_model(model=model,
                                                             dataloaders=dataloaders,
                                                             optimizer=optimizer,
                                                             epochs=epoch,
                                                             max_acc=max_acc,
                                                             eff_batch_size=eff_batch_size,
                                                             checkpoint_seg=checkpoint_count)

    if save:
        torch.save(trained_model.state_dict(), save_path)

    return trained_model, val_accuracy, test_accuracy


# if __name__ == "__main__":
#     parser = create_parser()
#     namespace = parser.parse_args(sys.argv[1:])
#     print(vars(namespace))
#     main(**vars(namespace))