## Install dependencies

In [None]:
# Download the PyTorch/XLA environment setup script, then run it requesting
# the "nightly" build.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version "nightly"

In [None]:
# Install (or upgrade to) the latest pytorch_lightning release.
!pip install -U pytorch_lightning

In [None]:
#https://drive.google.com/file/d/1KBtB-kk3O6YTHqQOAz4weDVsSkWbXvlt/view?usp=sharing

## Define the dataset and dataloaders via a PyTorch Lightning DataModule

In [None]:
import os
from pathlib import Path

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from albumentations.pytorch import ToTensorV2
from albumentations.augmentations.transforms import Blur, RandomBrightness
from torch.utils.data import DataLoader, Dataset


class ChineseMNISTDataset(Dataset):
    """Map-style dataset that loads Chinese MNIST images listed in a DataFrame.

    Every row of ``df`` must expose ``suite_id``, ``sample_id`` and ``code``
    columns; together they identify the image file
    ``input_{suite_id}_{sample_id}_{code}.jpg`` inside ``image_root``.

    Args:
        df: Table describing the samples, one row per image.
        image_root: Directory that contains the ``.jpg`` files.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a mapping with an ``"image"`` entry (the
            albumentations calling convention). ``None`` disables it.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_root: Path,
        transform: "A.BaseCompose | A.BasicTransform | None" = None,
    ) -> None:
        super().__init__()
        self.df = df
        # Accept plain strings as well as Path objects.
        self.image_root = Path(image_root)
        self.transform = transform

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The label is ``code - 1`` so that classes start at zero.

        Raises:
            FileNotFoundError: If the expected image file does not exist.
            OSError: If the file exists but OpenCV cannot decode it.
        """
        # Samplers hand out positions in ``range(len(self))``, so index by
        # position; label-based ``.loc`` only works for a default RangeIndex.
        row = self.df.iloc[idx]
        suite_id, code, sample_id = row.suite_id, row.code, row.sample_id
        filename = self.image_root / f"input_{suite_id}_{sample_id}_{code}.jpg"
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not os.path.isfile(filename):
            raise FileNotFoundError(f"{filename} is not a file")
        image = cv2.imread(str(filename))
        # cv2.imread signals a decoding failure by returning None.
        if image is None:
            raise OSError(f"{filename} could not be decoded as an image")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # NOTE(review): this inserts the new axis in the middle, i.e. the
        # array becomes (H, 1, W) rather than (H, W, 1). The model's
        # ``forward`` permute in this file appears to rely on the resulting
        # layout, so the two must be changed together — confirm before
        # touching either.
        image = image[:, np.newaxis]
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, code - 1

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)


class ChineseMNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that serves train/val/test loaders for Chinese MNIST.

    Args:
        data_root: Dataset root; images are read from ``data_root/data/data``.
        all_df: Full sample table (one row per image).
        train_indices: Index labels of ``all_df`` (used with ``.loc``) that
            make up the training split.
        val_indices: Index labels of ``all_df`` (used with ``.loc``) that make
            up the validation split.
        batch_size: Batch size shared by every loader.
        num_workers: Number of worker processes per ``DataLoader``.
    """

    def __init__(
        self,
        data_root: Path,
        all_df: pd.DataFrame,
        train_indices: pd.Index,
        val_indices: pd.Index,
        batch_size: int = 64,
        num_workers: int = 4,
    ) -> None:
        super().__init__()
        self.data_root = data_root
        self.df = all_df
        self.image_root = self.data_root / "data" / "data"
        # reset_index() gives each split a fresh 0..n-1 index (the previous
        # index is kept as an extra "index" column).
        self.train_df = self.df.loc[train_indices, :].copy().reset_index()
        # Augmentations are applied to the training split only.
        self.train_transform = A.Compose(
            [
                Blur(),
                RandomBrightness(),
                ToTensorV2(),
            ]
        )
        self.val_df = self.df.loc[val_indices, :].copy().reset_index()
        self.val_transform = A.Compose(
            [
                ToTensorV2(),
            ]
        )
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, frame: pd.DataFrame, transform, shuffle: bool):
        """Build a DataLoader over ``frame`` with the module-wide settings."""
        ds = ChineseMNISTDataset(frame, self.image_root, transform)
        return DataLoader(
            ds,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the augmented training split."""
        return self._make_loader(self.train_df, self.train_transform, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return self._make_loader(self.val_df, self.val_transform, shuffle=False)

    def test_dataloader(self):
        """Return the validation loader; no separate test split is defined."""
        return self.val_dataloader()
    
# Sanity check: locate the dataset (Kaggle input directory when running on
# Kaggle, a local "archive" folder otherwise), load the sample table and make
# sure a data module can be constructed from a tiny 20/10 row split.
is_kaggle = os.path.isdir("/kaggle")
if is_kaggle:
    data_root = Path("/kaggle/input/chinese-mnist")
else:
    data_root = Path("archive")
assert os.path.isdir(data_root), f"{data_root} is not a dir"
df = pd.read_csv(data_root / "chinese_mnist.csv")

data_module = ChineseMNISTDataModule(
    data_root,
    df,
    df.index[:20],
    df.index[20:30],
)

## Model definition

In [None]:
import os
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.metrics import Accuracy
from sklearn.model_selection import StratifiedKFold
from torch import nn, optim
from torchvision.models import resnet18


class ChineseMNISTResnetModel(pl.LightningModule):
    """ResNet-18 classifier adapted to single-channel Chinese MNIST images.

    The pretrained torchvision ResNet-18 is reused with two layers replaced:
    the first convolution (so it accepts one input channel) and the final
    fully connected layer (so it emits ``num_classes`` logits). Both
    replacement layers are freshly initialised.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
        num_classes: Number of output classes.
    """

    def __init__(self, learning_rate=1e-3, num_classes=15):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        resnet = resnet18(pretrained=True, progress=True)
        stem = resnet.conv1
        # Mirror every hyperparameter of the original stem convolution,
        # including its padding (omitting it silently falls back to 0) and
        # whether it has a bias term (Conv2d expects a bool here).
        resnet.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            dilation=stem.dilation,
            bias=stem.bias is not None,
        )
        # Read the feature width from the layer being replaced instead of
        # hard-coding it.
        resnet.fc = nn.Linear(resnet.fc.in_features, self.num_classes)
        self.resnet = resnet
        self.accuracy = Accuracy()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, image):
        """Return class logits for a batch of images.

        The batch is permuted with ``(0, 3, 1, 2)``, i.e. its last axis is
        moved to the channel position, and cast to float before the ResNet.
        NOTE(review): this presumably matches the layout produced by the
        dataset/transform pipeline in this file — confirm before changing.
        """
        image = image.permute(0, 3, 1, 2).contiguous().float()
        return self.resnet(image)

    def training_step(self, batch, batch_idx: int):
        """Compute the training loss for one batch and log epoch-level metrics."""
        image, y = batch
        yhat = self(image)
        loss = self.criterion(yhat, y)
        acc = self.accuracy(yhat, y)
        # The accuracy used to be computed and then discarded; log it (and the
        # loss) so training progress is visible next to the validation metrics.
        self.log('train_loss', loss, on_epoch=True, on_step=False)
        self.log('train_acc', acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx: int, log: bool = True):
        """Compute loss and accuracy for one validation batch.

        Args:
            batch: ``(image, label)`` pair.
            batch_idx: Index of the batch within the epoch.
            log: When true, log ``val_loss`` and ``val_acc`` per epoch.

        Returns:
            Dict with ``val_loss`` and ``val_acc``.
        """
        image, y = batch
        yhat = self(image)
        loss = self.criterion(yhat, y)
        acc = self.accuracy(yhat, y)
        if log:
            self.log('val_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
            self.log('val_acc', acc, prog_bar=True, on_epoch=True, on_step=False)
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch by reusing ``validation_step`` without its logging."""
        metrics = self.validation_step(batch, batch_idx, log=False)
        self.log('test_loss', metrics["val_loss"], on_epoch=True, on_step=False)
        self.log('test_acc', metrics["val_acc"], on_epoch=True, on_step=False)
        return {"test_acc": metrics["val_acc"], "test_loss": metrics["val_loss"]}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


In [None]:
# Locate the dataset: the Kaggle input directory when running on Kaggle,
# otherwise a local "archive" folder.
is_kaggle = os.path.isdir("/kaggle")
data_root = Path("/kaggle/input/chinese-mnist" if is_kaggle else "archive")
all_df = pd.read_csv(data_root / "chinese_mnist.csv")

# Stratify the folds on the class code so every fold keeps the class balance.
skf = StratifiedKFold(n_splits=5, shuffle=True)

# Keep only the checkpoint with the lowest validation loss, stored in the
# current working directory. ``dirpath`` replaces the deprecated ``filepath``
# argument of ModelCheckpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.getcwd(),
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min",
)
trainer = pl.Trainer(
    # gpus=1,  # use this instead of tpu_cores when training on a GPU
    tpu_cores=8,
    max_epochs=20,
    precision=16,
    val_check_interval=1.,
    callbacks=[checkpoint_callback]
)

# Only the first of the five folds is trained (note the ``break``).
for train_positions, val_positions in skf.split(all_df, all_df.code):
    # StratifiedKFold yields positional indices, whereas the data module
    # selects rows by label (``.loc``); translate positions into labels so the
    # split stays correct even if ``all_df`` does not have a default index.
    train_indices = all_df.index[train_positions]
    val_indices = all_df.index[val_positions]
    data_module = ChineseMNISTDataModule(
        data_root=data_root,
        all_df=all_df,
        train_indices=train_indices,
        val_indices=val_indices,
        batch_size=32
    )
    model = ChineseMNISTResnetModel()
    trainer.fit(model, datamodule=data_module)
    break

In [None]:
# Restore the weights of the best checkpoint recorded by the callback, then
# evaluate on the training loader; ckpt_path=None is passed so the weights
# just loaded are the ones under test.
best_checkpoint = torch.load(checkpoint_callback.best_model_path)
model.load_state_dict(best_checkpoint["state_dict"])
train_eval_loader = data_module.train_dataloader()
trainer.test(test_dataloaders=train_eval_loader, ckpt_path=None)

In [None]:
# Evaluate on the validation loader; ckpt_path=None is passed so no checkpoint
# is reloaded by the trainer.
val_eval_loader = data_module.val_dataloader()
trainer.test(test_dataloaders=val_eval_loader, ckpt_path=None)