In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [16]:
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod

import os, json
from collections import OrderedDict
import torch
import torchvision
import pytorch_lightning as pl
from torch import optim
import torch.nn as nn
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import (Dataset, DataLoader)

from k12libs.utils.nb_easy import K12AI_DATASETS_ROOT

import warnings

warnings.filterwarnings('ignore')

from torchvision.transforms import ( # noqa
        Compose,
        ToTensor,
        Normalize,
        RandomOrder,
        Resize,
        ColorJitter,
        RandomRotation,
        RandomGrayscale,
        RandomResizedCrop,
        RandomHorizontalFlip,
        RandomVerticalFlip)

class IDataTransforms(object):

    @staticmethod
    def compose_order(transforms: List):
        return Compose(transforms)

    @staticmethod
    def shuffle_order(transforms: List):
        return RandomOrder(transforms)
    
    @staticmethod
    def image_resize(size):
        return Resize(size=size)

    @staticmethod
    def image_to_tensor():
        return ToTensor()

    @staticmethod
    def normalize_tensor(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        return Normalize(mean, std)

    @staticmethod
    def random_brightness(factor=0.25):
        return ColorJitter(brightness=factor)

    @staticmethod
    def random_contrast(factor=0.25):
        return ColorJitter(contrast=factor)

    @staticmethod
    def random_saturation(factor=0.25):
        return ColorJitter(saturation=factor)

    @staticmethod
    def random_hue(factor=0.25):
        return ColorJitter(hue=factor)

    @staticmethod
    def random_rotation(degrees=25):
        return RandomRotation(degrees=degrees)

    @staticmethod
    def random_grayscale(p=0.5):
        return RandomGrayscale(p=p)

    @staticmethod
    def random_resized_crop(size):
        return RandomResizedCrop(size=size)

    @staticmethod
    def random_horizontal_flip(p=0.5):
        return RandomHorizontalFlip(p=p)

    @staticmethod
    def random_vertical_flip(p=0.5):
        return RandomVerticalFlip(p=p)


class EasyaiDataset(ABC, Dataset):
    def __init__(self, mean=None, std=None, **kwargs):
        self.mean, self.std = mean, std
        self.augtrans = None
        if mean and std:
            self.imgtrans = self.compose_order([
                self.image_to_tensor(),
                self.normalize_tensor(mean=mean, std=std)
            ])
        else:
            self.imgtrans = self.image_to_tensor()

        # reader
        data = self.data_reader(**kwargs)
        if isinstance(data, (tuple, list)) and len(data) == 2:
            self.images, self.labels = data
        elif isinstance(data, dict) and all([x in data.keys() for x in ('images', 'labels')]):
            self.images, self.labels = data['images'], data['labels']
        else:
            raise ValueError('The return of data_reader must be List or Dict')

    @abstractmethod
    def data_reader(self, **kwargs) -> Union[Tuple[List, List, List], Dict[str, List]]:
        """
        (M)
        """
    
    def set_aug_trans(self, transforms:Union[list, None], shuffle=False):
        if transforms:
            if any([not hasattr(x, '__call__') for x in transforms]):
                raise ValueError(f'set_aug_trans: transforms params is invalid.')
            if shuffle:
                self.augtrans = self.shuffle_order(transforms)
            else:
                self.augtrans = self.compose_order(transforms)
    
    def set_img_trans(self, input_size:Union[Tuple[int, int], int, None], normalize=True):
        trans = []
        if input_size:
            trans.append(self.image_resize(input_size))
        trans.append(self.image_to_tensor())
        if normalize and self.mean and self.std:
            trans.append(self.normalize_tensor(mean=self.mean, std=self.std))
        self.imgtrans = self.compose_order(trans)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        if self.augtrans:
            img = self.augtrans(img)
        return self.imgtrans(img), self.labels[index], self.images[index]

    def __len__(self):
        return len(self.images)


class EasyaiClassifier(pl.LightningModule, IDataTransforms):
    def __init__(self):
        super(EasyaiClassifier, self).__init__()
        self.model = self.build_model()
        self.criterion = None
        self.datasets = {'train': None, 'val': None, 'test': None}

    def setup(self, stage: str):
        pass

    def teardown(self, stage: str):
        pass

    # Data
    def load_presetting_dataset_(self, dataset_name):
        class JsonfileDataset(EasyaiDataset):
            def data_reader(self, path, phase):
                """
                Args:
                    path: the dataset root directory
                    phase: the json file name (train.json / val.json / test.json)
                """
                image_list = []
                label_list = []
                with open(os.path.join(path, f'{phase}.json')) as f:
                    items = json.load(f)
                    for item in items:
                        image_list.append(os.path.join(path, item['image_path']))
                        label_list.append(item['label'])
                return image_list, label_list

        root = f'{K12AI_DATASETS_ROOT}/cv/{dataset_name}'
        datasets = {
            'train': JsonfileDataset(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), path=root, phase='train'),
            'val': JsonfileDataset(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), path=root, phase='val'),
            'test': JsonfileDataset(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), path=root, phase='test'),
        }
        return datasets 

    def prepare_dataset(self) -> Union[EasyaiDataset, List[EasyaiDataset], Dict[str, EasyaiDataset]]:
        return self.load_presetting_dataset_('rmnist')

    @staticmethod
    def _safe_delattr(cls, method):
        try:
            delattr(cls, method)
        except Exception:
            pass

    def prepare_data(self):
        datasets = self.prepare_dataset()
        if isinstance(datasets, EasyaiDataset):
            self._safe_delattr(self.__class__, 'val_dataloader')
            self._safe_delattr(self.__class__, 'validation_step')
            self._safe_delattr(self.__class__, 'validation_epoch_end')
            self._safe_delattr(self.__class__, 'test_dataloader')
            self._safe_delattr(self.__class__, 'test_step')
            self._safe_delattr(self.__class__, 'test_epoch_end')
            self.datasets['train'] = datasets
        elif isinstance(datasets, (list, tuple)) and len(datasets) <= 3:
            self.datasets['train'] = datasets[0]
            self.datasets['val'] = datasets[1]
            if len(datasets) == 2:
                self._safe_delattr(self.__class__, 'test_dataloader')
                self._safe_delattr(self.__class__, 'test_step')
                self._safe_delattr(self.__class__, 'test_epoch_end')
            else:
                self.datasets['test'] = datasets[2]

        elif isinstance(datasets, dict) and \
                all([x in datasets.keys() for x in ('train', 'val', 'test')]):
            self.datasets = datasets
        else:
            raise ValueError('the return of prepare_dataset is invalid.')

    # Model
    def load_pretrained_model_(self, model_name, num_classes=None, pretrained=True):
        model = getattr(torchvision.models, model_name)(pretrained)
        if num_classes:
            if model_name.startswith('vgg'):
                model.classifier[6] = nn.Linear(4096, num_classes)
            elif model_name.startswith('resnet'):
                model.fc = nn.Linear(model.fc.in_features, num_classes)
            elif model_name.startswith('alexnet'):
                model.classifier[6] = nn.Linear(4096, num_classes)
            elif model_name.startswith('mobilenet_v2'):
                model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            elif model_name.startswith('squeezenet'):
                model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
            elif model_name.startswith('shufflenet'):
                model.fc = nn.Linear(model.fc.in_features, num_classes)
            elif model_name.startswith('densenet'):
                in_features = {
                    "densenet121": 1024,
                    "densenet161": 2208,
                    "densenet169": 1664,
                    "densenet201": 1920,
                }
                model.classifier = nn.Linear(in_features[model_name], num_classes)
            else:
                raise NotImplementedError(f'{model_name}')
        return model

    def build_model(self):
        return self.load_pretrained_model_('resnet18', 10)

    # Hypes
    @property
    def loss(self):
        if self.criterion is None:
            self.criterion = self.configure_criterion()
        return self.criterion

    def configure_criterion(self):
        # default
        loss = nn.CrossEntropyLoss(reduction='mean')
        return loss

    def configure_optimizer(self):
        # default
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=0.001)
        return optimizer

    def configure_scheduler(self, optimizer):
        # default
        scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        return scheduler

    def configure_optimizers(self):
        optimizer = self.configure_optimizer()
        scheduler = self.configure_scheduler(optimizer)
        return [optimizer], [scheduler]

    def forward(self, x, *args, **kwargs):
        return self.model(x)

    def calculate_acc_(self, y_pred, y_true):
        return (torch.argmax(y_pred, axis=1) == y_true).float().mean()

    def step_(self, batch):
        x, y, _ = batch
        y_hat = self.model(x)
        loss = self.loss(y_hat, y)
        return x, y, y_hat, loss

    def get_dataloader(self, phase, \
                       trans_augment=None, trans_shuffle=False, \
                       normalize=False, input_size=None, \
                       batch_size=32, num_workers=1, \
                       drop_last=False, shuffle=False):
        if phase not in self.datasets.keys():
            raise RuntimeError(f'get {phase} data loader  error.')
        dataset = self.datasets[phase]
        dataset.set_aug_trans(transforms=trans_augment, shuffle=trans_shuffle)
        dataset.set_img_trans(input_size=input_size, normalize=normalize)
        return DataLoader(
                dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                drop_last=drop_last,
                shuffle=shuffle)

    ## Train
    def train_dataloader(self) -> DataLoader:
        return self.get_dataloader(
            phase='train',
            trans_augment=[
                self.random_resized_crop(size=(128, 128)),
                self.random_brightness(factor=0.3),
                self.random_rotation(degrees=30)
            ],
            trans_shuffle=False, 
            input_size=128,
            normalize=True,
            batch_size=32,
            num_workers=1,
            drop_last=False,
            shuffle=False)

    def training_step(self, batch, batch_idx):
        x, y, y_hat, loss = self.step_(batch)
        with torch.no_grad():
            accuracy = self.calculate_acc_(y_hat, y)
        log = {'train_loss': loss, 'train_acc': accuracy}
        output = OrderedDict({
            'loss': loss,        # M
            'acc': accuracy,     # O
            'progress_bar': log, # O
            "log": log           # O
        })
        return output

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        log = {'train_loss': avg_loss, 'train_acc': avg_acc}
        return {'progress_bar': log, 'log': log}

    ## Valid
    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader('val',
                batch_size=32,
                num_workers=2,
                drop_last=False,
                shuffle=False)

    def validation_step(self, batch, batch_idx):
        x, y, y_hat, loss = self.step_(batch)
        accuracy = self.calculate_acc_(y_hat, y)
        return {'val_loss': loss, 'val_acc': accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        log = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return {'progress_bar': log, 'log': log}

    ## Test
    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader('test',
                batch_size=32,
                num_workers=1,
                drop_last=False,
                shuffle=False)

    def test_step(self, batch, batch_idx):
        x, y, y_hat, loss = self.step_(batch)
        accuracy = self.calculate_acc_(y_hat, y)
        return {'test_loss': loss, 'test_acc': accuracy}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        log = {'test_loss': avg_loss, 'test_acc': avg_acc}
        return {'progress_bar': log, 'log': log}

    
class EasyaiTrainer(pl.Trainer):
    def __init__(self, *args, **kwargs):
        super(EasyaiTrainer, self).__init__(*args, **kwargs, gpus=[0])

In [18]:
trainer = EasyaiTrainer(max_epochs=2, logger=None, weights_summary='full')
trainer.fit(EasyaiClassifier())
trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

   | Name                        | Type              | Params
-------------------------------------------------------------------
0  | model                       | ResNet            | 11 M  
1  | model.conv1                 | Conv2d            | 9 K   
2  | model.bn1                   | BatchNorm2d       | 128   
3  | model.relu                  | ReLU              | 0     
4  | model.maxpool               | MaxPool2d         | 0     
5  | model.layer1                | Sequential        | 147 K 
6  | model.layer1.0              | BasicBlock        | 73 K  
7  | model.layer1.0.conv1        | Conv2d            | 36 K  
8  | model.layer1.0.bn1          | BatchNorm2d       | 128   
9  | model.layer1.0.relu         | ReLU              | 0     
10 | model.layer1.0.conv2        | Conv2d            | 36 K  
11 | model.layer1.0.bn2          | BatchNorm2d       | 128   
12 | model.layer1.1       

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': tensor(0.1364, device='cuda:0'),
 'test_loss': tensor(2.6977, device='cuda:0')}
--------------------------------------------------------------------------------



{'test_loss': 2.6976938247680664, 'test_acc': 0.136408731341362}