<a href="https://colab.research.google.com/github/pvrancx/torch_isr/blob/master/train_srresnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
! pip install pytorch-lightning

Check out the ISR GitHub repo

In [0]:
% cd "/content"
! git clone "https://github.com/pvrancx/torch_isr.git"
% cd "/content/torch_isr"
! git pull

In [0]:
import os
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import torch
from PIL import Image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Subset, random_split
from torchvision.transforms import transforms

from isr.datasets.bsd import load_bsd300
from isr.datasets.isr import IsrDataset
from isr.datasets.transforms import ChannelSelect
from isr.models.srcnn import SrCnn, SubPixelSrCnn
from isr.models.srresnet import SrResNet
from super_resolve import super_resolve_ycbcr, super_resolve_rgb
%matplotlib inline

Configure experiment

In [0]:
# Experiment settings; every later cell reads these names.
model_type = SrResNet   # model class; must provide add_model_specific_args()
img_mode = 'RGB'        # 'RGB' uses all channels, 'yCbCr' keeps channel 0 only
scale_factor = 4        # ratio between target and low-res input resolution
output_size = 200       # patch size handed to IsrDataset for train/val/test
max_epochs = 2000       # training length (also used as hparams.lr_epochs)

Download and load datasets

In [0]:
bsd300_train = load_bsd300('../data', split='train', image_mode=img_mode)
bsd300_test = load_bsd300('../data', split='test', image_mode=img_mode)

# Transforms applied to each base image (IsrDataset's base_image_transform).
# Training gets random flips/rotations; evaluation gets none beyond channel
# selection. In yCbCr mode only channel 0 is kept.
if img_mode == 'RGB':
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ])
    test_transforms = None
elif img_mode == 'yCbCr':
    train_transforms = transforms.Compose([
        ChannelSelect(0),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ])
    test_transforms = ChannelSelect(0)
else:
    # Fail here: otherwise the transforms below are undefined (or stale from
    # an earlier run of this cell) and the error surfaces far from its cause.
    raise ValueError(f"Unsupported img_mode {img_mode!r}; expected 'RGB' or 'yCbCr'")


def _make_isr_dataset(base_dataset, deterministic, base_image_transform):
    """Wrap a BSD300 split in an IsrDataset using the experiment settings.

    Args:
        base_dataset: a split returned by load_bsd300.
        deterministic: False for training; True for evaluation (presumably
            fixes the patch sampling so metrics are comparable -- confirm).
        base_image_transform: transform applied to each base image, or None.

    Returns:
        IsrDataset yielding (low-res, high-res) pairs converted to tensors.
    """
    return IsrDataset(
        base_dataset,
        output_size=output_size,
        scale_factor=scale_factor,
        deterministic=deterministic,
        base_image_transform=base_image_transform,
        transform=transforms.ToTensor(),
        target_transform=transforms.ToTensor(),
    )


train_full = _make_isr_dataset(bsd300_train, deterministic=False,
                               base_image_transform=train_transforms)

# 80/20 train/validation split of the BSD300 training images.
n_train = int(len(train_full) * 0.8)
split = [n_train, len(train_full) - n_train]
train_data, val_holdout = random_split(train_full, split)

# Validation must not see training augmentation: index the held-out images
# on a deterministic, non-augmented view of the same split. Otherwise
# val_loss, which drives checkpoint selection, is noisy.
val_data = Subset(
    _make_isr_dataset(bsd300_train, deterministic=True,
                      base_image_transform=test_transforms),
    val_holdout.indices,
)
test_data = _make_isr_dataset(bsd300_test, deterministic=True,
                              base_image_transform=test_transforms)

train_loader = DataLoader(train_data, shuffle=True, batch_size=32, num_workers=2)
val_loader = DataLoader(val_data, shuffle=False, batch_size=32, num_workers=2)
test_loader = DataLoader(test_data, shuffle=False, batch_size=32, num_workers=2)

Mount Google Drive to set up logging and checkpointing

In [0]:
# Make Google Drive available at /content/gdrive so that logs and checkpoints
# survive the end of the Colab session.
from google.colab import drive as colab_drive
colab_drive.mount('/content/gdrive')

In [0]:
# All artifacts for this model type live under a single Drive folder.
base_dir = os.path.join('/content/gdrive/My Drive/pytorch/isr/bsd300/', model_type.__name__)
chkptdir_dir = os.path.join(base_dir, 'checkpoints')

# Create the checkpoint directory in Python. A shell line such as
# `!mkdir -p chkptdir_dir` does not see Python variables unless they are
# interpolated, and would create a literal "chkptdir_dir" folder instead.
os.makedirs(chkptdir_dir, exist_ok=True)

# Keep the 3 checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    filepath=chkptdir_dir,
    save_top_k=3,
    verbose=False,
    monitor='val_loss',
    mode='min',
)

# TensorBoard event files are written under <base_dir>/tb_logs.
logger = TensorBoardLogger(base_dir, name="tb_logs")

Configure and create model

In [0]:
# Start from the default Trainer + model hyperparameters. Parse an empty argv
# so the notebook kernel's own command-line arguments are ignored.
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
parser = model_type.add_model_specific_args(parser)
hparams = parser.parse_args([])

# Experiment-specific overrides.
hparams.img_mode = img_mode
hparams.scale_factor = scale_factor
hparams.learning_rate = 1e-4
# RGB inputs have 3 channels. In yCbCr mode the data pipeline applies
# ChannelSelect(0), so the model receives a single channel (assumes
# ChannelSelect(0) yields a one-channel image -- confirm).
hparams.in_channels = 3 if img_mode == 'RGB' else 1
hparams.n_blocks = 16              # presumably the number of residual blocks -- confirm
hparams.max_epochs = max_epochs
hparams.lr_epochs = max_epochs     # NOTE(review): LR schedule length in epochs? confirm

model = model_type(hparams)


Start TensorBoard for monitoring

In [0]:
%reload_ext tensorboard
%tensorboard --logdir "$base_dir"

Train model

In [0]:
# Train on one GPU, logging to TensorBoard and checkpointing on val_loss.
trainer = Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=logger,
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

# best_k_models maps each saved checkpoint path to its monitored score. The
# callback uses mode='min', so the best checkpoint has the smallest score.
ckpt_scores = checkpoint_callback.best_k_models
best_model_path = min(ckpt_scores, key=ckpt_scores.get)
print(best_model_path)

Load and evaluate the best model checkpoint

In [0]:

# Load the best checkpoint found by the training cell. To evaluate a checkpoint
# from an earlier session without retraining, assign its path to
# best_model_path (e.g. a file under chkptdir_dir) before running this cell.
best_model = model_type.load_from_checkpoint(best_model_path)

# The Trainer argument is `gpus` (not `gpu`). Dataloaders are given to
# fit()/test(), not to the Trainer constructor.
test_trainer = Trainer(gpus=1)

# NOTE(review): these attributes are set by hand, presumably to work around
# trainer.test() expecting them on a Trainer that was never fitted in the
# pinned pytorch-lightning version -- confirm whether still needed.
test_trainer.train_dataloader = train_loader
test_trainer.val_dataloaders = [val_loader]
test_trainer.test(
    best_model,
    test_dataloaders=test_loader,
)

Load full-size test images and apply the model

In [0]:
# Full-size test images for visual inspection. output_size=None presumably
# disables patch cropping, and no tensor transforms are set, so samples stay
# PIL images (the plotting cell calls .size/.resize on them).
# Use the loaded model's own scale factor instead of a hard-coded 4, so the
# low-res inputs always match what the checkpoint was trained for.
# Note: this rebinds test_data; test_loader still wraps the patch dataset.
bsd300_test = load_bsd300('../data', split='test', image_mode=img_mode)
test_data = IsrDataset(
    bsd300_test,
    output_size=None,
    scale_factor=best_model.scale_factor,
    deterministic=True,
)

In [0]:
img_idx = 28

sample, target = test_data[img_idx]

# Use the model's factor under a separate name so the experiment-level
# `scale_factor` setting is not silently overwritten by this cell.
sr_factor = best_model.scale_factor
interpolate = sample.resize(
    (sample.size[0] * sr_factor, sample.size[1] * sr_factor), Image.BICUBIC
)
if img_mode == 'RGB':
    output = super_resolve_rgb(best_model, sample)
else:
    output = super_resolve_ycbcr(best_model, sample)

# Side-by-side comparison: input, bicubic baseline, model output, ground truth.
# The model panel is titled after model_type rather than a fixed 'srresnet'.
panels = [
    (sample, 'low res'),
    (interpolate, 'bicubic'),
    (output, model_type.__name__.lower()),
    (target, 'original'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(20, 20))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.set_axis_off()
plt.show()
