<a href="https://colab.research.google.com/github/pvrancx/torch_isr/blob/master/colabs/bsd300.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
! pip install pytorch-lightning

In [0]:
# Mount Google Drive so later cells can read/write under /content/gdrive
# (repo clone, checkpoints, TensorBoard logs).
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

In [0]:
# Create the Drive workspace (parents included, no-op if it already exists).
# Pure Python so a failure raises and stops Run-All instead of only printing.
from pathlib import Path
Path("/content/gdrive/My Drive/pytorch/isr/bsd300").mkdir(parents=True, exist_ok=True)

In [0]:
% cd "/content/gdrive/My Drive/pytorch/isr/bsd300"
! git clone "https://github.com/pvrancx/torch_isr.git"

In [0]:
% cd /content/gdrive/My Drive/pytorch/isr/bsd300/torch_isr/
! git pull

In [0]:
# Start TensorBoard inline, pointed at the directory main()'s TensorBoardLogger
# writes to.
%load_ext tensorboard
%tensorboard --logdir "/content/gdrive/My Drive/pytorch/isr/bsd300/logs/"

In [0]:
from pytorch_lightning import Trainer
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import transforms
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from isr.datasets.bsd import load_bsd300
from isr.datasets.isr import IsrDataset
from isr.models.srcnn import SrCnn

from argparse import ArgumentParser
import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
def main(root_dir="/content/gdrive/My Drive/pytorch/isr/bsd300",
         data_dir='../data',
         seed=0):
    """Tune the learning rate, train SrCnn on BSD300 and evaluate on the test split.

    Pipeline:
      1. Build hyper-parameters from the argparse defaults of ``Trainer`` and
         ``SrCnn`` (no command-line arguments are read).
      2. Load BSD300; hold out 20% of the training split for validation.
      3. Run the LR finder, log its plot to TensorBoard and use the suggested
         learning rate (when one is returned) for training.
      4. Fit for at most 2000 epochs, keeping the 3 lowest-``val_loss`` checkpoints.
      5. Run the test loop on the BSD300 test split.

    Args:
        root_dir: folder that receives the checkpoints and TensorBoard logs.
        data_dir: directory handed to ``load_bsd300`` (relative paths resolve
            against the current working directory).
        seed: seed for torch's global RNG, making the train/validation split
            (and weight initialisation / shuffling) reproducible across runs.
    """
    # Local imports keep this cell self-contained.
    import os
    import torch

    # Parse an explicit empty argv so only the declared defaults are used.
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = SrCnn.add_model_specific_args(parser)
    hparams = parser.parse_args([])

    # Seed before random_split so the validation subset is identical on every
    # run; otherwise re-runs would validate (and select checkpoints) on
    # different images.
    torch.manual_seed(seed)

    def make_isr_dataset(images, deterministic):
        """Wrap ``images`` in an IsrDataset of 32px outputs at ``hparams.scale_factor``."""
        return IsrDataset(
            images,
            output_size=32,
            scale_factor=hparams.scale_factor,
            deterministic=deterministic,
            transform=transforms.ToTensor(),
            target_transform=transforms.ToTensor()
        )

    def make_loader(dataset, shuffle):
        """Batch ``dataset`` with the loader settings shared by every split."""
        return DataLoader(dataset, shuffle=shuffle, batch_size=32, num_workers=2)

    bsd300_train = load_bsd300(data_dir, split='train')
    bsd300_test = load_bsd300(data_dir, split='test')

    # NOTE(review): the validation subset is drawn from the deterministic=False
    # dataset, so it presumably sees different random crops every epoch and
    # val_loss (which drives checkpoint selection) is noisy. Consider a
    # deterministic copy of the training images for the held-out indices.
    full_train = make_isr_dataset(bsd300_train, deterministic=False)
    n_train = int(len(full_train) * 0.8)
    train_data, val_data = random_split(full_train, [n_train, len(full_train) - n_train])
    test_data = make_isr_dataset(bsd300_test, deterministic=True)

    # NOTE(review): only the directory part of `filepath` appears to matter. The
    # checkpoint loaded later in this notebook is `_ckpt_epoch_1800.ckpt` inside
    # root_dir, not `weights.ckpt`. Confirm against the installed Lightning version.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(root_dir, "weights.ckpt"),
        save_top_k=3,
        verbose=False,
        monitor='val_loss',
        mode='min'
    )
    logger = TensorBoardLogger(os.path.join(root_dir, "logs"), name="isr")

    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=False)
    test_loader = make_loader(test_data, shuffle=False)

    model = SrCnn(hparams)
    trainer = Trainer(
        max_epochs=2000,
        logger=logger,
        log_gpu_memory='min_max',
        gpus=1,
        checkpoint_callback=checkpoint_callback
    )

    lr_finder = trainer.lr_find(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader
    )
    fig = lr_finder.plot(suggest=True)
    logger.experiment.add_figure('learning_rate', fig)

    new_lr = lr_finder.suggestion()
    if new_lr is None:
        # Keep the default rather than crashing on the format string below or
        # overwriting learning_rate with None.
        print('lr finder gave no suggestion; keeping the default learning rate')
    else:
        # Scientific notation keeps small rates readable (%1.5f shows 1e-6 as 0.00000).
        print('lr found: %.3e' % new_lr)
        model.hparams.learning_rate = new_lr

    trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )
    # NOTE(review): this evaluates the final-epoch weights, not the best
    # val_loss checkpoint kept by checkpoint_callback.
    trainer.test(model, test_dataloaders=test_loader)

In [0]:
# Long-running: LR search, up to 2000 training epochs, then the test loop.
main()

In [0]:

# Restore a trained SrCnn for the visual comparison below.
# NOTE(review): the epoch (1800) is hard-coded. With save_top_k=3 this file may
# have been pruned by a later training run, so confirm it exists.
CHECKPOINT_PATH = '/content/gdrive/My Drive/pytorch/isr/bsd300/_ckpt_epoch_1800.ckpt'
model = SrCnn.load_from_checkpoint(CHECKPOINT_PATH)


In [0]:
  bsd300_test = load_bsd300('../data', split='test')
  test_data = IsrDataset(
        bsd300_test,
        output_size=200,
        scale_factor=2,
        deterministic=True,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor()
    )

In [0]:
# Side by side: low-res input, bicubic up-sampling and the SrCnn reconstruction
# for one test image.
import torch
import torch.nn.functional as F


def to_hwc_image(tensor):
    """Convert a (C, H, W) or (1, C, H, W) tensor to an (H, W, C) array for imshow.

    Values are clamped to [0, 1]: imshow clips float RGB data outside that range
    anyway (with a warning), and both the network and bicubic interpolation can
    overshoot.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    return tensor.clamp(0, 1).permute(1, 2, 0).cpu().numpy()


sample = test_data[2][0]
model.eval()  # inference mode for any dropout / batch-norm layers
with torch.no_grad():
    batch = sample.unsqueeze(0)  # add the batch dimension
    output = model(batch)
    # Upsample to exactly the network's output size so the panels line up.
    interpolate = F.interpolate(batch, size=tuple(output.shape[-2:]),
                                mode='bicubic', align_corners=False)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))
for ax, image, title in zip((ax1, ax2, ax3),
                            (sample, interpolate, output),
                            ('input', 'bicubic', 'SrCnn')):
    ax.imshow(to_hwc_image(image))
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [0]:
# Range of the raw SrCnn output. Values outside [0, 1] on either side are what
# imshow has to clip, so report both ends rather than just the maximum.
import torch
(torch.min(output).item(), torch.max(output).item())