# Setup

In [None]:
%%capture
#!pip install pytorch-lightning Augmentor
!pip install Augmentor
!pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

# The dataset

In [None]:
import os
import shutil
import random
from random import Random
import numpy as np
from PIL import Image
import torch
import Augmentor
import torchvision.datasets as dset
from torch.utils.data.dataset import Dataset
from torchvision import transforms


def copy_alphabets(write_dir, alphabets):
    """Copy Omniglot alphabets into a flat ``ImageFolder``-style layout.

    Every character directory ``<alphabet>/<char>/`` is copied to
    ``<write_dir>/<alphabet-name>_<char>/``, so that each character becomes
    one class folder directly under ``write_dir``.

    Args:
        write_dir: Destination root; created (with parents) as needed.
        alphabets: Iterable of paths to alphabet directories.

    Re-running over a partially populated ``write_dir`` is safe: existing
    class folders are reused and their files overwritten.
    """
    for alphabet in alphabets:
        # normpath strips a trailing separator, which would make basename() ''.
        alpha_prefix = os.path.basename(os.path.normpath(alphabet)) + '_'
        for char in os.listdir(alphabet):
            char = os.fsdecode(char)
            class_dir = os.path.join(write_dir, alpha_prefix + char)
            # exist_ok: an interrupted earlier run may already have created it.
            os.makedirs(class_dir, exist_ok=True)

            char_path = os.path.join(alphabet, char)
            for drawer in os.listdir(char_path):
                shutil.copyfile(os.path.join(char_path, drawer),
                                os.path.join(class_dir, drawer))


# adapted from https://github.com/kevinzakka/one-shot-siamese
class Omniglot(dset.ImageFolder):
    """Omniglot exposed as torchvision ``ImageFolder`` splits.

    On first use the two official archives are downloaded into
    ``<data_path>/raw`` and the first 10 (sorted) evaluation alphabets are
    moved into the background set. The background alphabets are then split
    into 30 ``train`` alphabets and the rest ``valid`` using a seeded RNG; the
    remaining evaluation alphabets become ``test``. Each character is written
    as one class folder under ``<data_path>/processed/<mode>``.

    Args:
        data_path: Root directory holding ``raw/`` and ``processed/``.
        mode: Split to load: ``'train'``, ``'valid'`` or ``'test'``.
        seed: Seed for the train/valid alphabet split.
    """
    resources = [
        ("https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip", "68d2efa1b9178cc56df9314c21c6e718"),
        ("https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip", "6b91aef0f799c5bb55b94e3f2daec811")
    ]

    def __init__(self, data_path, mode, seed=0):
        self.raw_path = os.path.join(data_path, 'raw')
        self.processed_path = os.path.join(data_path, 'processed')
        self._rng_seed = seed
        # Materialise the splits on disk before ImageFolder scans the folder.
        self._download()
        super().__init__(root=os.path.join(self.processed_path, mode))

    def _processed_check_exists(self):
        """Return True if all three processed split directories exist."""
        return all(os.path.exists(os.path.join(self.processed_path, split))
                   for split in ('train', 'valid', 'test'))

    def _raw_check_exists(self):
        """Return True if both extracted raw archive directories exist."""
        return all(os.path.exists(os.path.join(self.raw_path, name))
                   for name in ('images_background', 'images_evaluation'))

    def _download(self):
        """Ensure processed splits exist, downloading raw archives if needed."""
        if self._processed_check_exists():
            return

        if self._raw_check_exists():
            self._process()
            return

        os.makedirs(self.raw_path, exist_ok=True)
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            dset.utils.download_and_extract_archive(url, download_root=self.raw_path, filename=filename, md5=md5)

        # Move the first 10 evaluation alphabets into the background set.
        # This only happens right after a fresh download.
        bg_path = os.path.join(self.raw_path, 'images_background')
        eval_path = os.path.join(self.raw_path, 'images_evaluation')
        for d in sorted(next(os.walk(eval_path))[1])[:10]:
            shutil.move(os.path.join(eval_path, d), bg_path)

        self._process()

    def _process(self):
        """Split the raw alphabets into train/valid/test and copy them out."""
        # A private RandomState yields exactly the draws that
        # np.random.seed(seed) + np.random.choice(...) would, without
        # re-seeding NumPy's global RNG for the rest of the program.
        rng = np.random.RandomState(self._rng_seed)
        os.makedirs(self.processed_path, exist_ok=True)

        back_dir = os.path.join(self.raw_path, 'images_background')
        eval_dir = os.path.join(self.raw_path, 'images_evaluation')
        write_dir = self.processed_path

        # Sorted so the seeded choice below is independent of listing order.
        background_alphabets = sorted(os.path.join(back_dir, x) for x in next(os.walk(back_dir))[1])
        print("There are {} alphabets.".format(len(background_alphabets)))

        # from 40 alphabets, randomly select 30
        train_alphabets = sorted(rng.choice(background_alphabets, size=30, replace=False))
        train_set = set(train_alphabets)
        valid_alphabets = [x for x in background_alphabets if x not in train_set]
        test_alphabets = sorted(os.path.join(eval_dir, x) for x in next(os.walk(eval_dir))[1])

        copy_alphabets(os.path.join(write_dir, 'train'), train_alphabets)
        copy_alphabets(os.path.join(write_dir, 'valid'), valid_alphabets)
        copy_alphabets(os.path.join(write_dir, 'test'), test_alphabets)


# from https://github.com/kevinzakka/one-shot-siamese
class OmniglotTrain(Dataset):
    """Randomly sampled image pairs for twin-network training.

    Odd indices yield a same-class pair (label 1.0) and even indices a
    different-class pair (label 0.0). The images are drawn with the
    module-level ``random`` generator, so ``index`` only fixes the label.

    Args:
        dataset: Object exposing ``imgs`` as ``(path, class_index)`` items,
            e.g. an ``Omniglot`` / ``ImageFolder`` instance.
        num_train: Number of pairs reported by ``len()``.
        augment: If True, randomly rotate/distort each image of the pair.
    """

    def __init__(self, dataset, num_train, augment=False):
        super().__init__()
        self.dataset = dataset
        self.num_train = num_train
        self.augment = augment

    def __len__(self):
        return self.num_train

    @staticmethod
    def _load(path):
        """Read ``path`` as a grayscale PIL image and close the file handle."""
        with Image.open(path) as img:
            return img.convert('L')

    def _build_transform(self):
        """Return the PIL -> tensor transform, with augmentation if enabled."""
        if not self.augment:
            return transforms.ToTensor()
        p = Augmentor.Pipeline()
        p.rotate(probability=0.5, max_left_rotation=15, max_right_rotation=15)
        p.random_distortion(
            probability=0.5, grid_width=6, grid_height=6, magnitude=10,
        )
        return transforms.Compose([
            p.torch_transform(),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        imgs = self.dataset.imgs
        image1 = random.choice(imgs)

        same_class = index % 2 == 1
        label = 1.0 if same_class else 0.0
        # Rejection-sample a partner whose class relation matches the label.
        while True:
            image2 = random.choice(imgs)
            if (image1[1] == image2[1]) == same_class:
                break

        # Built per item (as before) rather than stored on self.
        # NOTE(review): storing the Augmentor transform on the dataset may not
        # be picklable for DataLoader workers -- confirm before hoisting.
        trans = self._build_transform()
        # Both images get the (independently applied) transform; previously
        # image2 always bypassed augmentation.
        x1 = trans(self._load(image1[0]))
        x2 = trans(self._load(image2[0]))
        y = torch.from_numpy(np.array([label], dtype=np.float32))
        return (x1, x2, y)


# from https://github.com/kevinzakka/one-shot-siamese
class OmniglotTest(Dataset):
    """Deterministic N-way one-shot evaluation pairs.

    The items form ``trials`` consecutive groups of ``way`` pairs. Every pair
    in a group shares one anchor image. Position 0 pairs it with a
    same-class image (label 1.0); positions 1..way-1 pair it with
    different-class images (label 0.0). All sampling is seeded from
    ``seed`` and the item index, so each item is a pure function of its
    index and may be fetched in any order.

    Args:
        dataset: Object exposing ``imgs`` as ``(path, class_index)`` items.
        trials: Number of N-way groups.
        way: Pairs per group (N).
        seed: Base seed for sampling.
    """

    def __init__(self, dataset, trials, way, seed=0):
        super().__init__()
        self.dataset = dataset
        self.trials = trials
        self.way = way
        self.transform = transforms.ToTensor()
        self.seed = seed
        # Anchor entry of the most recently fetched item (informational only).
        self.img1 = None

    def __len__(self):
        return (self.trials * self.way)

    def _sample_pair(self, index):
        """Return ``(anchor, partner, label)`` entries of ``dataset.imgs``."""
        imgs = self.dataset.imgs
        idx = index % self.way
        group_start = index - idx
        # The anchor is the first draw of the RNG seeded for the group's first
        # index -- exactly what the idx == 0 item draws -- so every index can
        # recover it without depending on earlier __getitem__ calls.
        anchor_rng = Random(self.seed + group_start)
        img1 = anchor_rng.choice(imgs)
        if idx == 0:
            # Same-class partner: continue the anchor's RNG stream.
            rng, same_class = anchor_rng, True
        else:
            rng, same_class = Random(self.seed + index), False
        while True:
            img2 = rng.choice(imgs)
            if (img1[1] == img2[1]) == same_class:
                break
        return img1, img2, (1.0 if same_class else 0.0)

    @staticmethod
    def _load(path):
        """Read ``path`` as a grayscale PIL image and close the file handle."""
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, index):
        img1, img2, label = self._sample_pair(index)
        self.img1 = img1
        x1 = self.transform(self._load(img1[0]))
        x2 = self.transform(self._load(img2[0]))
        y = torch.from_numpy(np.array([label], dtype=np.float32))
        return (x1, x2, y)


# The Model

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from typing_extensions import Final
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from argparse import ArgumentParser


class CNNLayer(nn.Module):
    """Convolutional feature extractor shared by both twin branches.

    Four convolutions over a single-channel image. The first three are each
    followed by 2x2 max-pooling, and every convolution by an in-place ReLU.
    With default strides a 105x105 input yields a (N, 256, 6, 6) feature map.
    """

    # (in_channels, out_channels, kernel_size, followed_by_max_pool)
    _CONV_SPECS = (
        (1, 64, 10, True),
        (64, 128, 7, True),
        (128, 128, 4, True),
        (128, 256, 4, False),
    )

    def __init__(self):
        super().__init__()
        stack = []
        for in_ch, out_ch, k, pooled in self._CONV_SPECS:
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=k))
            if pooled:
                stack.append(nn.MaxPool2d(2))
            stack.append(nn.ReLU(inplace=True))
        # Module order matches an explicitly listed Sequential, so state_dict
        # keys (cnn.0/3/6/9) and parameter-initialisation order are the same.
        self.cnn = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the convolutional feature map of ``x``."""
        features = self.cnn(x)
        return features


class TwinNet(pl.LightningModule):
    """Siamese ("twin") network for Omniglot one-shot verification.

    Both images of a pair pass through one shared ``CNNLayer`` and one shared
    sigmoid-activated fully connected layer. The element-wise absolute
    difference of the two embeddings is mapped to a single logit
    (same class vs. different class).
    """

    @staticmethod
    def add_model_specific_args(parser):
        """Return a child parser that adds this model's CLI options."""
        parser = ArgumentParser(parents=[parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=1e-3, help='Initial learning rate used by auto_lr_find')
        parser.add_argument('--batch_size', type=int, default=128)
        parser.add_argument('--num_workers', type=int, default=1, help='number of workers used by DataLoader')
        parser.add_argument('--trials', type=int, default=320)
        parser.add_argument('--train_classes', type=int, default=20)
        parser.add_argument('--num_train', type=int, default=50000)
        parser.add_argument('--rng_seed', type=int, default=1)
        parser.add_argument('--data_path', type=str, default='./data/')
        return parser

    def __init__(self, learning_rate, batch_size,
                 num_workers, data_path, rng_seed,
                 train_classes, trials, num_train, **kwargs):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        # Save every constructor argument: load_from_checkpoint() needs all
        # required __init__ args, not only learning_rate and batch_size.
        self.save_hyperparameters()

        # Evaluation protocol (OmniglotTest): each trial is `way` pairs -- one
        # same-class pair plus way-1 different-class pairs -- and an
        # evaluation epoch contains `trials` trials.
        self._way: Final = train_classes
        self._trials: Final = trials
        # Number of randomly sampled training pairs per epoch (OmniglotTrain).
        self._num_train: Final = num_train

        self._rng_seed: Final = rng_seed
        self._data_path: Final = data_path

        self._num_workers: Final = num_workers
        # Pinned host memory only speeds up host->CUDA copies.
        # NOTE(review): behaviour on TPUs has not been checked.
        self._pin_memory: Final = torch.cuda.is_available()

        self.cnn: nn.Module = CNNLayer()
        # 256*6*6 = 9216 conv features for a 105x105 input.
        self.fcl: nn.Module = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid())
        self.out: nn.Module = nn.Linear(4096, 1)

        # Example (x1, x2) input: lets Lightning report layer in/out sizes and
        # lets loggers that support it draw the model graph. This replaces the
        # broken manual log_graph call.
        self.example_input_array = (torch.zeros(1, 1, 105, 105),
                                    torch.zeros(1, 1, 105, 105))

        self.train_accuracy = pl.metrics.Accuracy()
        self.val_accuracy = pl.metrics.Accuracy(compute_on_step=False)
        self.test_accuracy = pl.metrics.Accuracy(compute_on_step=False)

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to flattened, sigmoid-activated embeddings."""
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        return self.fcl(x)

    # prediction/inference
    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return one same-class logit per pair, shape (batch, 1)."""
        dist = torch.abs(self._embed(x1) - self._embed(x2))
        return self.out(dist)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
        return [optimizer], [{'scheduler': scheduler, 'monitor': 'val_loss', 'interval': 'epoch'}]

    @staticmethod
    def loss(x: torch.Tensor, y: torch.Tensor):
        """Binary cross-entropy on raw logits ``x`` against 0/1 targets ``y``."""
        return F.binary_cross_entropy_with_logits(x, y)

    @staticmethod
    def _accuracy_inputs(logits: torch.Tensor, y: torch.Tensor):
        """Convert logits/float targets into what ``Accuracy`` expects.

        Accuracy thresholds predictions at 0.5 and expects integer targets.
        Raw logits would effectively be thresholded at sigmoid(0.5) ~= 0.62.
        """
        return torch.sigmoid(logits), y.long()

    def _shared_step(self, batch):
        """Run one (x1, x2, y) batch; return (logits, y, loss)."""
        x1, x2, y = batch
        logits = self(x1, x2)
        return logits, y, self.loss(logits, y)

    def training_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)

        acc = self.train_accuracy(*self._accuracy_inputs(logits, y))
        self.log('train_acc', acc, on_step=True, on_epoch=False)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_epoch_end(self, outputs):
        """Log the epoch's training accuracy and the optimizer's current LR."""
        # ReduceLROnPlateau may have lowered the LR, so read the live value.
        self.log('learning_rate_epoch', self.trainer.optimizers[0].param_groups[0]['lr'])
        self.log('train_acc_epoch', self.train_accuracy.compute())
        self.train_accuracy.reset()

    def train_epoch_end(self):
        """Backward-compatible alias; Lightning only calls training_epoch_end."""
        self.training_epoch_end(None)

    def validation_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)

        self.val_accuracy(*self._accuracy_inputs(logits, y))
        self.log('val_loss', loss, on_step=False, on_epoch=True)

    def validation_epoch_end(self, outputs):
        self.log('val_acc_epoch', self.val_accuracy.compute())
        # Reset so each epoch (and the sanity check) is measured separately.
        self.val_accuracy.reset()

    def test_step(self, batch, batch_idx):
        logits, y, loss = self._shared_step(batch)

        self.test_accuracy(*self._accuracy_inputs(logits, y))
        self.log('test_loss', loss, on_step=False, on_epoch=True)

    def test_epoch_end(self, outputs):
        self.log('test_acc_epoch', self.test_accuracy.compute())
        self.test_accuracy.reset()

    def prepare_data(self):
        """Download and split Omniglot on disk (a no-op once processed).

        Lightning may run this on a single process only, so it assigns no
        state; the datasets are created in ``setup``.
        """
        # Constructing any split runs Omniglot._download(), which fetches and
        # processes all three splits when they are missing.
        Omniglot(data_path=self._data_path, mode='train')

    def setup(self, stage=None):
        """Build the pair datasets for ``stage`` (``None`` builds all)."""
        if stage in (None, 'fit'):
            self._train_dataset = Omniglot(data_path=self._data_path, mode='train')
            self._val_dataset = Omniglot(data_path=self._data_path, mode='valid')
            self.training_set = OmniglotTrain(self._train_dataset, augment=True,
                                              num_train=self._num_train)
            self.validation_set = OmniglotTest(self._val_dataset, seed=self._rng_seed,
                                               trials=self._trials, way=self._way)
        if stage in (None, 'test'):
            self._test_dataset = Omniglot(data_path=self._data_path, mode='test')
            self.test_set = OmniglotTest(self._test_dataset, seed=self._rng_seed,
                                         trials=self._trials, way=self._way)

    def train_dataloader(self):
        return DataLoader(
            self.training_set, batch_size=self.batch_size,
            shuffle=True, num_workers=self._num_workers, pin_memory=self._pin_memory,
        )

    def val_dataloader(self):
        # One unshuffled batch == one N-way trial (`way` consecutive pairs).
        return DataLoader(
            self.validation_set, batch_size=self._way, shuffle=False,
            num_workers=self._num_workers, pin_memory=self._pin_memory,
        )

    def test_dataloader(self):
        # One unshuffled batch == one N-way trial (`way` consecutive pairs).
        return DataLoader(
            self.test_set, batch_size=self._way, shuffle=False,
            num_workers=self._num_workers, pin_memory=self._pin_memory,
        )


# Train

In [None]:
    # One config drives the model, the Trainer and the logged hyperparameters,
    # so TensorBoard records the values that were actually used.
    model_config = dict(batch_size=128, learning_rate=1e-3, rng_seed=0,
                        train_classes=20, trials=320, num_train=50000,
                        num_workers=4, data_path='./data/')
    trainer_config = dict(gpus=1, max_epochs=50)

    pl.seed_everything(model_config['rng_seed'])
    logger = TensorBoardLogger('./lightning_logs/', name='snn')
    logger.log_hyperparams({**model_config, **trainer_config})

    model = TwinNet(**model_config)

    # early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=0.05, patience=7, verbose=False, mode='min')
    # TODO: Is val_loss a good choice here?.
    # dirpath + filename replace the deprecated `filepath`; same output path.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath='./',
                                          filename='snn-omniglot-{epoch}',
                                          save_top_k=3, mode='min')
    # The checkpoint callback goes in `callbacks`; passing it via
    # `checkpoint_callback=` is deprecated in Lightning.
    trainer = pl.Trainer(progress_bar_refresh_rate=20, deterministic=True,
                         logger=logger, callbacks=[checkpoint_callback],
                         auto_lr_find=True, **trainer_config)

    # Tune learning rate.
    trainer.tune(model)
    # Train model.
    trainer.fit(model)
    print('Best model saved to: ', checkpoint_callback.best_model_path)

# Results

In [None]:
# Test using best checkpoint.
# Called without a model, trainer.test() evaluates the model trained by
# trainer.fit() above; Lightning's default is to load the best checkpoint
# tracked by the ModelCheckpoint callback before testing.
trainer.test()

In [None]:
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/