<a href="https://colab.research.google.com/github/pvrancx/torch_isr/blob/master/colabs/bsd300.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Install PyTorch Lightning into the Colab runtime (used by the training cells below).
# NOTE(review): the cells below use the pre-1.0 Lightning API (ModelCheckpoint(filepath=...),
# Trainer(checkpoint_callback=<ModelCheckpoint>, log_gpu_memory=..., gpus=...),
# fit(train_dataloader=...), test(test_dataloaders=...)); later releases renamed or removed
# several of these. Pin the release this notebook was developed against
# (`pytorch-lightning==<version>`) — TODO confirm the exact version.
! pip install pytorch-lightning

In [0]:
# Mount Google Drive: every path used later (clone, data, checkpoints, logs) lives under it.
from google.colab import drive as gdrive
gdrive.mount('/content/gdrive')

In [0]:
# Create the Drive working directory that will hold the repository clone, the BSD300 data,
# checkpoints and TensorBoard logs (`-p`: no error if it already exists).
! mkdir -p "/content/gdrive/My Drive/pytorch/isr/bsd300"

In [0]:
% cd "/content/gdrive/My Drive/pytorch/isr/bsd300"

In [0]:
! git clone https://github.com/pvrancx/torch_isr.git

In [0]:
% cd /content/gdrive/My Drive/pytorch/isr/bsd300/torch_isr/
! git pull

In [0]:
# Embed TensorBoard in the notebook, pointed at the directory the training cell's
# TensorBoardLogger writes to.
%load_ext tensorboard
%tensorboard --logdir "/content/gdrive/My Drive/pytorch/isr/bsd300/logs/"

In [0]:
import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Subset, random_split
from torchvision.transforms import transforms

from isr.datasets.bsd import load_bsd300
from isr.datasets.isr import IsrDataset
from isr.lightning_model import LightningIsr
from isr.models.srcnn import SrCnn

%matplotlib inline

In [0]:
def _make_isr_dataset(base, scale_factor, deterministic, transform):
    """Wrap a BSD300 split in an ``IsrDataset`` with ``output_size=200``.

    Args:
        base: dataset returned by ``load_bsd300``.
        scale_factor: upscaling factor forwarded to ``IsrDataset``.
        deterministic: forwarded to ``IsrDataset``; ``True`` for evaluation data.
        transform: transform for the model input (presumably the low-resolution image).

    Returns:
        The ``IsrDataset``; targets always go through ``transforms.ToTensor()``.
    """
    return IsrDataset(
        base,
        output_size=200,
        scale_factor=scale_factor,
        deterministic=deterministic,
        transform=transform,
        target_transform=transforms.ToTensor(),
    )


def main(scale_factor=2,
         root_dir="/content/gdrive/My Drive/pytorch/isr/bsd300",
         max_epochs=200,
         batch_size=32,
         num_workers=2,
         train_fraction=0.8,
         seed=0):
    """Train, validate and test an SRCNN super-resolution model on BSD300.

    The BSD300 train split is divided into training and validation subsets; the
    BSD300 test split is only used by the final ``trainer.test`` call.

    Args:
        scale_factor: upscaling factor passed to both ``IsrDataset`` and ``SrCnn``.
        root_dir: directory receiving checkpoints (``weights.ckpt``) and TensorBoard
            logs (``logs/``).
        max_epochs: maximum number of training epochs.
        batch_size: batch size of the train, validation and test loaders.
        num_workers: worker processes per data loader.
        train_fraction: fraction of the BSD300 train split used for fitting; the
            remainder is held out for validation.
        seed: value for ``torch.manual_seed`` so the train/validation split is
            reproducible; ``None`` leaves the global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # '../data' is relative to the working directory (the torch_isr checkout after `%cd`).
    bsd300_train = load_bsd300('../data', split='train')
    bsd300_test = load_bsd300('../data', split='test')

    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    eval_transforms = transforms.ToTensor()

    # Wrap the same training images twice: randomly augmented for fitting, and
    # deterministically (like the test set) for validation. Splitting the augmented
    # dataset directly would make val_loss -- which selects checkpoints -- depend on
    # random flips/rotations.
    augmented_train = _make_isr_dataset(bsd300_train, scale_factor, False, train_transforms)
    plain_train = _make_isr_dataset(bsd300_train, scale_factor, True, eval_transforms)
    n_total = len(augmented_train)
    assert len(plain_train) == n_total, "train/val views must index the same images"

    n_train = int(n_total * train_fraction)
    # Split indices once and apply them to both views so train and val never overlap.
    train_split, val_split = random_split(range(n_total), [n_train, n_total - n_train])
    train_data = Subset(augmented_train, train_split.indices)
    val_data = Subset(plain_train, val_split.indices)
    test_data = _make_isr_dataset(bsd300_test, scale_factor, True, eval_transforms)

    checkpoint_callback = ModelCheckpoint(
        filepath=f"{root_dir}/weights.ckpt",
        save_top_k=3,  # keep the three checkpoints with the lowest val_loss
        verbose=False,
        monitor='val_loss',
        mode='min',
    )
    logger = TensorBoardLogger(f"{root_dir}/logs", name="isr")

    model = LightningIsr(SrCnn, {'model_params': {'scale_factor': scale_factor}})
    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        log_gpu_memory='min_max',
        gpus=1,
        checkpoint_callback=checkpoint_callback,
    )

    def loader(dataset, shuffle):
        # All three loaders share batch size and worker count.
        return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size,
                          num_workers=num_workers)

    trainer.fit(
        model,
        train_dataloader=loader(train_data, shuffle=True),
        val_dataloaders=loader(val_data, shuffle=False),
    )
    # NOTE(review): this evaluates the in-memory weights at the end of fit(), which are
    # not necessarily the best checkpoint kept by ModelCheckpoint.
    trainer.test(model, test_dataloaders=loader(test_data, shuffle=False))

In [0]:
# Full training run (by default up to 200 epochs on one GPU), followed by evaluation on
# the BSD300 test split. Keep the factor in sync with the checkpoint-loading cells below.
SCALE_FACTOR = 2
main(scale_factor=SCALE_FACTOR)

In [0]:

# Restore trained x2 weights for inference. The file name comes from a previous training
# run (presumably one of the top-k checkpoints saved by main()) — update it for your run.
checkpoint_path = '/content/gdrive/My Drive/pytorch/isr/bsd300/_ckpt_epoch_111.ckpt'
model = LightningIsr.load_from_checkpoint(
    checkpoint_path,
    model_factory=SrCnn,
    hparams={'model_params': {'scale_factor': 2}},
)
# Put the module in inference mode (disables dropout / batch-norm updates if SrCnn has any);
# trailing ';' suppresses the module repr.
model.eval();


In [0]:
  bsd300_test = load_bsd300('../data', split='test')
  test_data = IsrDataset(
        bsd300_test,
        output_size=200,
        scale_factor=2,
        deterministic=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor()
    )

In [0]:
# Upscale one test image and compare it with its high-resolution target.
sample_index = 10
# presumably each item is (model input, target) — index 0 is what the model consumes
low_res, high_res = test_data[sample_index]

device = next(model.parameters()).device
with torch.no_grad():  # inference only: no autograd graph needed
    # Add/remove the batch dimension instead of hard-coding the 3x100x100 / 3x200x200 shapes.
    upscaled = model(low_res.unsqueeze(0).to(device)).squeeze(0).cpu()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (low_res, 'input (low-res)'),
    (upscaled.clamp(0, 1), 'SRCNN output'),  # imshow expects floats in [0, 1]
    (high_res, 'target (high-res)'),
]
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image.permute(1, 2, 0).numpy())  # CHW tensor -> HWC array
    ax.set_title(title)
    ax.axis('off')