In [1]:
%matplotlib inline
import torch
import torchvision
import torch.optim as optim
from tqdm.auto import tqdm
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import datetime
from pytz import timezone
import os
import matplotlib.pyplot as plt

from dataset import CustomDataset, CustomTransform
from model import CustomModel
from configs.config import Config
from util import Trainer, evaluate, inference

In [2]:
# Experiment configuration. Config wraps the raw dict so later cells can use
# attribute access (cfg.data.*, cfg.train.*).
cfg_dict = {
    'data': {
        'dataset': 'custom_dataset.py',
        # Set the DATA_ROOT environment variable instead of editing this path in the notebook.
        'data_root': os.environ.get('DATA_ROOT', '/path/to/data/where/dataset/root/folder/locates'),
        'train_dir': 'train',
        'test_dir': 'test',
        'img_resize': 480,
        # Passed as CustomTransform's second argument in the data-check cell, but was
        # previously missing from this config. Presumably a center-crop size;
        # 480 (= img_resize) is a placeholder. NOTE(review): confirm the intended value.
        'img_resize_center': 480,
    },
    'train': {
        'device': 0,
        # Lists presumably describe a hyperparameter sweep consumed by Trainer — TODO confirm.
        'epochs': [3, 5, 10],
        'batch_size': [16, 32],
        'learning_rate': 0.00001,
        'optim': 'adam',
    },
}

cfg = Config(dic=cfg_dict)

In [None]:
# check data
def imshow_grid(imgs, tf_inv=None, meta=None, nrows=2, ncols=3):
    """Show up to ``nrows * ncols`` images in a grid and render the figure.

    Args:
        imgs: iterable of images. Each one is passed through ``tf_inv`` (when
            truthy) and then drawn with ``imshow`` using a gray colormap.
            Images beyond ``nrows * ncols`` are ignored instead of raising.
        tf_inv: optional callable that converts one item of ``imgs`` into
            something ``imshow`` accepts. Presumably the inverse of the dataset
            transform — confirm against ``CustomTransform``.
        meta: optional mapping. When truthy, the title of image ``i`` is the
            last 10 characters of ``meta['img_path'][i]``.
        nrows, ncols: grid shape. The defaults reproduce the original 2x3
            layout with a 12x6 figure.
    """
    # squeeze=False guarantees a 2-D axes array, so flatten() also works for a 1x1 grid.
    figure, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows), squeeze=False)
    ax_list = axes.flatten()
    # Hide every axis first, so slots without an image stay blank.
    for ax in ax_list:
        ax.axis('off')
    # zip() stops at the shorter sequence, so surplus images cannot cause an IndexError.
    for i, (ax, im) in enumerate(zip(ax_list, imgs)):
        ax.imshow(tf_inv(im) if tf_inv else im, cmap='gray')
        if meta:
            ax.set_title(meta['img_path'][i][-10:])
    figure.tight_layout()
    plt.show()

# Batch size for this preview only. The old code used `b`, which nothing in the
# notebook defines (NameError on a fresh kernel). cfg.train.batch_size is a list
# in the config cell, so take its first entry.
preview_batch_size = cfg.train.batch_size[0]

# The original config cell does not declare `img_resize_center`; fall back to
# `img_resize` when it is missing. NOTE(review): the fallback only triggers if
# Config raises AttributeError for unknown keys — confirm, and prefer declaring
# the key in the config cell.
img_resize_center = getattr(cfg.data, 'img_resize_center', cfg.data.img_resize)
tf, tf_inv = CustomTransform(cfg.data.img_resize, img_resize_center).get()

# Keyword arguments shared by the train and test splits.
dataset_kwargs = dict(
    root=cfg.data.data_root,
    transform=tf,
    train_dir=cfg.data.train_dir,
    test_dir=cfg.data.test_dir,
)

trainset = CustomDataset(train=True, **dataset_kwargs)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=preview_batch_size, shuffle=True, num_workers=2)

testset = CustomDataset(train=False, **dataset_kwargs)
testloader = torch.utils.data.DataLoader(testset, batch_size=preview_batch_size, shuffle=True, num_workers=2)

# One batch from each split: train images, and test images (labelled as anomalies by the original code).
sample, label, meta = next(iter(trainloader))
anomaly_sample, _, anomaly_meta = next(iter(testloader))
print(f'train samples: {len(trainset)}')
print(f'test samples: {len(testset)}')

imshow_grid(sample[0:6], tf_inv, meta)
imshow_grid(anomaly_sample[0:6], tf_inv, anomaly_meta)

In [None]:
# train: build a Trainer from the Config object and run training.
# Trainer comes from util (not shown). How it consumes the list-valued
# cfg.train.epochs / cfg.train.batch_size is defined there — TODO confirm.
trainer = Trainer(cfg)
trainer.train()

In [None]:
# infer: run util.inference with the same Config.
# NOTE(review): whether it loads a checkpoint written by the training cell is
# not visible here — confirm before running this cell on its own.
inference(cfg)

In [None]:
# evaluate: run util.evaluate with the same Config.
# NOTE(review): presumably depends on outputs of the inference cell above — confirm.
evaluate(cfg)