# <div align="center"> PyTorch Lightning: Letter Classifier </div> 

In [1]:
%reload_ext watermark
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%watermark -v -p numpy,matplotlib,torch,torchvision,pytorch_lightning,PIL

CPython 3.6.9
IPython 7.16.1

numpy 1.18.5
matplotlib 3.2.1
torch 1.6.0.dev20200609+cu101
torchvision 0.7.0.dev20200609+cu101
pytorch_lightning 1.1.2
PIL 7.1.2


In [2]:
import warnings

import os
import json
import struct
import numpy as np
import torch
import shutil
import torchvision
import pytorch_lightning as pl

from PIL import Image

warnings.filterwarnings('ignore')

## 数据集处理

- [数据集: emnist_letters: 8j58](https://pan.baidu.com/s/1COMbe7nuW7gS-hDLCT03Hw "提取码: 8j58")
    
   >原始图片是水平翻转并旋转90°的(相当于行列转置), 读取时需要转置还原后使用
    
- 解压到`/data/datasets/cv/EMNIST_Letters`

In [3]:
def decode_idx3_ubyte(idx3_ubyte_file):
    """Decode an IDX3 image file into an array of transposed images.

    The file is read as a big-endian header of four 32-bit integers
    (magic number, image count, row count, column count) followed by one
    unsigned byte per pixel. The header values are printed.

    Every image is transposed before being stored, so the returned array has
    shape ``(num_images, num_cols, num_rows)``. For square images this is the
    same shape as ``(num_images, num_rows, num_cols)``.

    Args:
        idx3_ubyte_file: Path of the IDX3 file to read.

    Returns:
        A C-contiguous ``numpy.ndarray`` of dtype float64 holding the
        transposed images.

    Raises:
        struct.error: If the file is too short to contain the header.
        ValueError: If the file holds fewer pixel bytes than the header
            announces.
    """
    with open(idx3_ubyte_file, 'rb') as fr:
        bin_data = fr.read()

    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, 0)
    print('magic:%d, total: %d, size: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # Pixel data starts right after the header; view it in one shot instead of
    # unpacking every image into a tuple of Python ints.
    offset = struct.calcsize(fmt_header)
    pixels = np.frombuffer(
        bin_data,
        dtype=np.uint8,
        count=num_images * num_rows * num_cols,
        offset=offset)

    # Swap the row/column axes of every image (per-image transpose). The
    # result is (n, cols, rows), which also works for non-square images.
    transposed = pixels.reshape((num_images, num_rows, num_cols)).transpose(0, 2, 1)

    # float64 + C order matches what the np.empty-based buffer used to return.
    return transposed.astype(np.float64, order='C')

def decode_idx1_ubyte(idx1_ubyte_file):
    """Decode an IDX1 label file into an integer array.

    The file is read as a big-endian header of two 32-bit integers (magic
    number, item count) followed by one unsigned byte per label. The header
    values are printed.

    Args:
        idx1_ubyte_file: Path of the IDX1 file to read.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype ``int`` with one entry per label, in
        file order.

    Raises:
        struct.error: If the file is too short to contain the header.
        ValueError: If the file holds fewer label bytes than the header
            announces.
    """
    with open(idx1_ubyte_file, 'rb') as fr:
        bin_data = fr.read()

    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, 0)
    print('magic:%d, total: %d' % (magic_number, num_images))

    # Labels are single unsigned bytes stored back to back after the header,
    # so the whole payload can be viewed at once instead of byte by byte.
    offset = struct.calcsize(fmt_header)
    raw_labels = np.frombuffer(bin_data, dtype=np.uint8, count=num_images, offset=offset)

    # Widen to the platform int, as the original np.empty(..., dtype=int) did.
    return raw_labels.astype(int)

## 封装分类器

In [4]:
class EasyaiClassifier(pl.LightningModule):
    """Lightning module wrapping a classification model.

    The module is handed a ready-made trainer, model, optimizer and scheduler
    and exposes ``fit`` / ``test`` / ``predict`` convenience methods that
    delegate to the stored trainer.

    Every batch is expected to be a 3-tuple ``(inputs, targets, paths)``.
    The third element is only used by ``predict``.
    """

    def __init__(self, trainer, model, optimizer, scheduler):
        """Store the collaborators; nothing is built or moved to a device here.

        Args:
            trainer: Trainer used by ``fit`` / ``test`` / ``predict``.
            model: Module mapping an input batch to per-class logits.
            optimizer: Optimizer returned from ``configure_optimizers``.
            scheduler: LR scheduler returned from ``configure_optimizers``.
        """
        super().__init__()
        self.model = model
        self.trainer = trainer
        self.optimizer = optimizer
        self.scheduler = scheduler
        # No explicit .cuda() call: device placement is presumably left to
        # the trainer -- confirm if this class is used outside a Trainer.

    def _get_lr(self):
        """Return the current learning rate of every optimizer param group."""
        return [group['lr'] for group in self.optimizer.param_groups]

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def criterion(self, inputs, targets):
        """Cross-entropy loss between logits ``inputs`` and class ``targets``."""
        return torch.nn.functional.cross_entropy(inputs, targets)

    def configure_optimizers(self):
        """Return the optimizer/scheduler pair, with ``val_loss`` as monitor."""
        return {'monitor': 'val_loss', 'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

    def training_step(self, batch, batch_idx):
        """Compute loss and batch accuracy for one training batch.

        Returns:
            Dict with ``'loss'`` and ``'acc'`` tensors.
        """
        inputs, y_trues, paths = batch
        y_preds = self(inputs)
        loss = self.criterion(y_preds, y_trues)
        acc = (torch.argmax(y_preds, dim=1) == y_trues).float().mean()
        log = {'loss': loss, 'acc': acc}
        return log

    def training_epoch_end(self, outputs):
        """Log learning rate plus mean training loss/accuracy of the epoch."""
        log = {
            # NOTE(review): this is a list with one entry per param group,
            # logged as-is under a single key.
            'lr': self._get_lr(),
            'train_loss': torch.stack([x['loss'] for x in outputs]).mean()
        }
        if 'acc' in outputs[0]:
            log['train_acc'] = torch.stack([x['acc'] for x in outputs]).mean()
        self.log_dict(log, prog_bar=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        """Compute loss and batch accuracy for one validation batch.

        Returns:
            Dict with ``'val_loss'`` and ``'val_acc'`` tensors.
        """
        inputs, y_trues, paths = batch
        y_preds = self(inputs)
        loss = self.criterion(y_preds, y_trues)
        acc = (torch.argmax(y_preds, dim=1) == y_trues).float().mean()
        log = {'val_loss': loss, 'val_acc': acc}
        return log

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss/accuracy over all batches."""
        log = {
            'val_loss': torch.stack([x['val_loss'] for x in outputs]).mean()
        }
        if 'val_acc' in outputs[0]:
            log['val_acc'] = torch.stack([x['val_acc'] for x in outputs]).mean()
        self.log_dict(log, prog_bar=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Compute batch accuracy for one test batch.

        Returns:
            Dict with a ``'test_acc'`` tensor.
        """
        inputs, y_trues, paths = batch
        y_preds = self(inputs)
        acc = (torch.argmax(y_preds, dim=1) == y_trues).float().mean()
        log = {'test_acc': acc}
        return log

    def test_epoch_end(self, outputs):
        """Log the mean test accuracy over all batches as a NumPy value."""
        log = {}
        if 'test_acc' in outputs[0]:
            mean_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
            # A tensor living on a CUDA device cannot be converted with
            # .numpy() directly; bring it to host memory first.
            log['test_acc'] = mean_acc.detach().cpu().numpy()
        self.log_dict(log)

    def get_progress_bar_dict(self):
        """Return the default progress-bar items without the ``v_num`` entry."""
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def fit(self, train_loader, valid_loader):
        """Train this module with the stored trainer."""
        return self.trainer.fit(self, train_loader, valid_loader)

    def test(self, test_loader):
        """Evaluate this module on ``test_loader`` with the stored trainer."""
        return self.trainer.test(self, test_loader)

    def predict(self, test_loader):
        """Run the trainer's test loop with prediction hooks swapped in.

        ``test_step`` / ``test_epoch_end`` are temporarily replaced on the
        class (so every instance of the class is affected while this call
        runs) and restored afterwards, even if the test loop raises.

        Returns:
            Whatever ``self.trainer.test`` returns.
        """
        def predict_step(self, batch, batch_idx):
            inputs, tags, paths = batch
            y_preds = self(inputs)
            # Fully qualified: no `F` alias for torch.nn.functional is
            # imported in this file.
            return list(zip(paths, tags, torch.nn.functional.softmax(y_preds, dim=1)))

        def predict_epoch_end(self, outputs):
            result = {'output': []}
            for item in outputs:
                for path, tag, preds in item:
                    probs = preds.cpu().numpy().astype(float).tolist()
                    result['output'].append({'image_path': path, 'image_id': tag, 'probs': probs})
            self.log_dict(result)

        # Remember the real hooks before entering the try block, so the
        # finally clause never touches an unbound name.
        _test_step = getattr(self.__class__, 'test_step', None)
        _test_epoch_end = getattr(self.__class__, 'test_epoch_end', None)
        try:
            setattr(self.__class__, 'test_step', predict_step)
            setattr(self.__class__, 'test_epoch_end', predict_epoch_end)
            return self.trainer.test(self, test_loader, verbose=True)
        finally:
            if _test_step is not None:
                setattr(self.__class__, 'test_step', _test_step)
            if _test_epoch_end is not None:
                setattr(self.__class__, 'test_epoch_end', _test_epoch_end)

class EasyaiTrainer(pl.Trainer):
    """``pl.Trainer`` subclass that keeps a single checkpoint at a fixed path.

    All artifacts live under ``default_root_dir``. Every checkpoint save is
    redirected to ``<default_root_dir>/checkpoints/best.ckpt``, and that same
    file is used to resume when ``pretrained`` is requested and it exists.
    """
    default_root_dir = '/tmp/emnist_letters'

    def __init__(self, logger, callbacks, pretrained=False, *args, **kwargs):
        """Create the checkpoint directory and configure the base trainer.

        Args:
            logger: Forwarded to ``pl.Trainer`` as ``logger``.
            callbacks: Forwarded to ``pl.Trainer`` as ``callbacks``.
            pretrained: When truthy and the checkpoint file exists, resume
                from it; otherwise start without a checkpoint.
            *args, **kwargs: Forwarded to ``pl.Trainer`` unchanged.
        """
        self.resume_path = f'{self.default_root_dir}/checkpoints/best.ckpt'
        # Resuming is only honoured when the checkpoint is actually on disk.
        can_resume = pretrained and os.path.exists(self.resume_path)
        os.makedirs(f'{self.default_root_dir}/checkpoints', exist_ok=True)
        resume_from = self.resume_path if can_resume else None
        super().__init__(
            logger=logger,
            callbacks=callbacks,
            num_sanity_val_steps=0,
            resume_from_checkpoint=resume_from,
            default_root_dir=self.default_root_dir,
            *args, **kwargs)

    def on_save_checkpoint(self):
        """Return the parent's checkpoint state unchanged.

        An earlier variant that moved the callback's ``best_model_path`` to
        ``resume_path`` is disabled; this hook is a plain pass-through.
        """
        return super().on_save_checkpoint()

    def save_checkpoint(self, filepath, weights_only: bool = False):
        """Save a checkpoint to ``resume_path``; ``filepath`` is ignored."""
        return super().save_checkpoint(self.resume_path, weights_only)

## 数据加载器

In [5]:
class UbyteReaderDataset(torch.utils.data.Dataset):
    """Dataset backed by a pair of IDX files (``*-images-idx3-ubyte`` and
    ``*-labels-idx1-ubyte``) that are decoded fully into memory on creation.
    """

    def __init__(self, source, imgtrans, augtrans=None):
        """Load images/labels and build the transform pipelines.

        Args:
            source: Common path prefix of the two IDX files (everything
                before ``-images-idx3-ubyte`` / ``-labels-idx1-ubyte``).
            imgtrans: List of transforms, applied in order via ``Compose``.
            augtrans: Optional list of augmentation transforms, applied in
                random order via ``RandomOrder``. ``None`` means no
                augmentation.
        """
        self.images, self.labels = self.data_reader(source)
        self.imgtrans = torchvision.transforms.Compose(imgtrans)
        # A fresh list per instance: RandomOrder keeps a reference to the
        # list it is given, so a shared default list would be shared by
        # every dataset created without augmentations.
        self.augtrans = torchvision.transforms.RandomOrder([] if augtrans is None else augtrans)

    def data_reader(self, path):
        """Decode the image and label files that share the prefix ``path``.

        Returns:
            Tuple ``(images, labels)`` as returned by ``decode_idx3_ubyte``
            and ``decode_idx1_ubyte``.
        """
        image_list = decode_idx3_ubyte(f'{path}-images-idx3-ubyte')
        label_list = decode_idx1_ubyte(f'{path}-labels-idx1-ubyte')
        return image_list, label_list

    def __getitem__(self, index):
        """Return ``(transformed image, label - 1, index)`` for one sample.

        The stored label is shifted down by one -- presumably because the
        file uses 1-based class ids; verify against the dataset. The sample
        index is returned in the slot that the classifier unpacks as the
        third batch element.
        """
        img = Image.fromarray(self.images[index]).convert('RGB')
        img = self.augtrans(img)
        return self.imgtrans(img), self.labels[index] - 1, index

    def __len__(self):
        """Number of decoded images."""
        return len(self.images)

In [6]:
# class JsonReaderDataset(torch.utils.data.Dataset):
#     def __init__(self, source, imgtrans, augtrans=[]):
#         self.images, self.labels = self.data_reader(source)
#         self.imgtrans = torchvision.transforms.Compose(imgtrans)
#         self.augtrans = torchvision.transforms.RandomOrder(augtrans)
#         
#     def data_reader(self, path):
#         image_list = []
#         label_list = []
#         root = os.path.dirname(path)
#         with open(path, 'r') as f:
#             items = json.load(f)
#             for item in items:
#                 image_list.append(os.path.join(root, item['image_path']))
#                 label_list.append(item['label'])
#         return image_list, label_list
# 
#     def __getitem__(self, index):
#         img = Image.open(self.images[index]).convert('RGB')
#         img = self.augtrans(img)
#         return self.imgtrans(img), self.labels[index], self.images[index]
# 
#     def __len__(self):
#         return len(self.images)

In [7]:
# Augmentations. Defined here but currently not passed to any dataset below.
augtrans = [
    torchvision.transforms.RandomRotation(degrees=75)
]

# Transforms applied to every image, in order: resize, tensor conversion,
# normalization.
imgtrans = [
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1362,), std=(0.2893,)),
]

IMG_ROOT='/data/datasets/cv/EMNIST_Letters'

# Disabled alternative: JSON-backed datasets.
# train_dataset = JsonReaderDataset(
#     f'{IMG_ROOT}/emnist-letters-train', # f'{IMG_ROOT}/train.json',
#     imgtrans=imgtrans,
#     # augtrans=augtrans
# )

# val_dataset = JsonReaderDataset(
#     f'{IMG_ROOT}/val.json',
#     imgtrans=imgtrans
# )

train_dataset = UbyteReaderDataset(
    f'{IMG_ROOT}/emnist-letters-train',
    imgtrans=imgtrans
)

# The "test" split of the archive is used as the validation set.
val_dataset = UbyteReaderDataset(
    f'{IMG_ROOT}/emnist-letters-test',
    imgtrans=imgtrans
)


# Shuffle the training samples each epoch so batches are not always composed
# of the same samples in the same order; validation keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=250, shuffle=True, drop_last=True, num_workers=8,
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=250, num_workers=8
)

magic:2051, total: 124800, size: 28*28
magic:2049, total: 124800
magic:2051, total: 20800, size: 28*28
magic:2049, total: 20800


## 参数配置

In [8]:
class BBResnet18(torch.nn.Module):
    """torchvision ResNet-18 (pretrained weights) with a replaced final layer.

    Args:
        num_classes: Output size of the new fully connected layer.
        use_gpu: Accepted but not used; the module performs no device
            placement itself.
    """

    def __init__(self, num_classes, use_gpu=True):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # Swap the stock classification head for one sized to num_classes.
        head_inputs = backbone.fc.in_features
        backbone.fc = torch.nn.Linear(head_inputs, num_classes)
        self.bbmodel = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        logits = self.bbmodel(x)
        return logits

# 26 output classes; the dataset shifts its labels down by one to match.
model = BBResnet18(26)
# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# Cut the learning rate by 10x when the monitored value plateaus for more
# than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=3,
    min_lr=1e-6,
    threshold=1e-4,
    verbose=False)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=23, gamma=0.1)

# Stop once val_loss has not improved by at least min_delta for 5 checks.
cb_early_stop = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=True,
    min_delta=0.0001)

# Keep only the best checkpoint by val_acc, evaluated every 2 epochs.
cb_model_checkpoint = pl.callbacks.ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    period=2,
    save_weights_only=False,
    filename="best")

trainer = EasyaiTrainer(
    logger=False,
    callbacks=[cb_early_stop, cb_model_checkpoint],
    pretrained=True, 
    max_epochs=20,
    # Use the first GPU when one is present; otherwise fall back to CPU
    # instead of failing on a machine without CUDA.
    gpus=[0] if torch.cuda.is_available() else None,
    )

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


## 训练

In [9]:
# Bundle the trainer, model, optimizer and scheduler into the Lightning module.
classifer = EasyaiClassifier(trainer, model, optimizer, scheduler)

# Train on train_dataloader and validate on val_dataloader via the stored trainer.
classifer.fit(train_dataloader, val_dataloader)


  | Name  | Type       | Params
-------------------------------------
0 | model | BBResnet18 | 11.2 M
-------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1