<a href="https://colab.research.google.com/github/pvrancx/torch_isr/blob/master/colabs/pretrain_discriminator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Supervised pretraining of the discriminator (as a real-vs-fake classifier) before it is used in GAN training

In [0]:
! pip install pytorch-lightning

In [0]:
% cd "/content"
! git clone "https://github.com/pvrancx/torch_isr.git"
% cd "/content/torch_isr"
! git pull

In [0]:
from pytorch_lightning import Trainer
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import transforms
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from isr.datasets.bsd import load_bsd300
from isr.datasets.transforms import ChannelSelect
from isr.datasets.isr import IsrDataset
from isr.datasets.gan import DiscriminatorDataset
from isr.models.srcnn import SrCnn, SubPixelSrCnn
from isr.models.srresnet import SrResNet
from isr.models.srgan import SrGan, Discriminator

from super_resolve import super_resolve_ycbcr, super_resolve_rgb

import os
from PIL import Image
from argparse import ArgumentParser
import matplotlib.pyplot as plt
%matplotlib inline

Create the image super-resolution (ISR) dataset

In [0]:
# BSD300 training split, loaded as RGB images.
bsd300_train = load_bsd300('../data', split='train', image_mode='RGB')

# Random augmentation passed as the dataset's base-image transform:
# a random horizontal flip plus a random rotation of up to 15 degrees.
train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomRotation(15)])

# One stateless converter serves both the input (transform) and the target (target_transform).
to_tensor = transforms.ToTensor()

# 4x super-resolution pairs with output_size=96; deterministic=False means the
# sampled pair can differ between accesses.
isr_data = IsrDataset(
    bsd300_train,
    output_size=96,
    scale_factor=4,
    deterministic=False,
    base_image_transform=train_transforms,
    transform=to_tensor,
    target_transform=to_tensor,
)


Load the trained 4x ISR generator (SRResNet) and create a dataset of real and generated (fake) images

In [0]:
# Load the pretrained 4x SRResNet generator. The Lightning docs call eval() explicitly
# after load_from_checkpoint, so don't rely on the loaded module being in inference mode.
# freeze() disables gradients and calls eval(). Layers such as BatchNorm (if SrResNet
# has any) then use their running statistics when generating fakes.
generator = SrResNet.load_from_checkpoint('/content/torch_isr/trained_models/srresnet4x.ckpt')
generator.freeze()

# NOTE(review): presumably DiscriminatorDataset pairs real HR images from isr_data with
# generator outputs as fakes -- confirm in isr/datasets/gan.py.
dataset = DiscriminatorDataset(isr_data, generator)

In [0]:
# 80/20 train/validation split. A seeded generator makes the split reproducible across
# runs; without it random_split draws from the global torch RNG.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 0
BATCH_SIZE = 32

n_train = int(len(dataset) * TRAIN_FRACTION)
split = [n_train, len(dataset) - n_train]

train_data, val_data = random_split(
    dataset, split, generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE, num_workers=2)
val_loader = DataLoader(val_data, shuffle=False, batch_size=BATCH_SIZE, num_workers=2)

In [0]:
base_dir = '/content/discriminator'
checkpoint_dir = os.path.join(base_dir, 'checkpoints')

# Create the checkpoint directory from Python. A bare name in a `!mkdir` shell line is
# passed literally (not expanded as a Python variable), which would create a stray
# folder in the working directory instead of this path.
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the three checkpoints with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_dir,
    save_top_k=3,
    verbose=False,
    monitor='val_loss',
    mode='min',
)

# TensorBoard event files are written under <base_dir>/tb_logs/.
logger = TensorBoardLogger(base_dir, name="tb_logs")

Create the discriminator, a binary real-vs-fake classifier

In [0]:
# get default arg settings
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
parser = Discriminator.add_model_specific_args(parser)
hparams = parser.parse_args([])

# customize settings
hparams.img_shape = (96, 96)
hparams.lr = 1e-4
hparams.in_channels = 3
hparams.max_epochs = 200

model = Discriminator(hparams)

In [0]:
# Launch TensorBoard inline on base_dir, the directory the TensorBoardLogger writes to.
# IPython substitutes the Python variable for $base_dir before the magic runs.
%reload_ext tensorboard
%tensorboard --logdir "$base_dir"

Train the discriminator as a standard classifier to distinguish real images from fake ones

In [0]:
# Fall back to CPU when the runtime has no GPU. Requesting gpus=1 (and GPU-memory
# logging) on a CPU-only runtime would otherwise error out.
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    max_epochs=hparams.max_epochs,  # single source of truth, set with the other hparams
    logger=logger,
    log_gpu_memory='min_max' if use_gpu else None,
    gpus=1 if use_gpu else None,
    checkpoint_callback=checkpoint_callback,
)

trainer.fit(
    model,
    train_dataloader=train_loader,
    val_dataloaders=val_loader,
)

Try the discriminator on a single sample

In [0]:
# Take one (low-res, high-res) pair. isr_data uses random flips/rotations and
# deterministic=False, so the pair can differ on every access.
sample_idx = 0
img_lr, img_hr = isr_data[sample_idx]

Real image:

In [0]:
# ToTensor yields channels-first (C, H, W); matplotlib expects channels-last (H, W, C).
fig, ax = plt.subplots()
ax.imshow(img_hr.squeeze(0).permute([1, 2, 0]))
ax.set_title('Real (high-resolution) image')
ax.axis('off');

Generated image:

In [0]:
# Super-resolve the low-res sample (with an added batch dimension). Use eval mode and
# no_grad so no autograd graph is built, then bring the result back to the CPU for plotting.
generator.eval()
gen_device = next(generator.parameters()).device
with torch.no_grad():
    fake_img = generator(img_lr.unsqueeze(0).to(gen_device)).cpu()

In [0]:
# Drop the batch dimension and move channels last. The generator's output may fall
# slightly outside [0, 1] (depending on its output layer), so clamp it explicitly
# instead of letting imshow clip it with a warning.
fig, ax = plt.subplots()
ax.imshow(fake_img.squeeze(0).permute([1, 2, 0]).clamp(0, 1).cpu())
ax.set_title('Generated (super-resolved) image')
ax.axis('off');

Test the discriminator on both images

In [0]:
# Score the real image. The model may still be in training mode after fit, which matters
# if the discriminator has BatchNorm/Dropout. It may also live on the GPU, so switch to
# eval, disable gradients, and send the input to the model's device.
# NOTE(review): whether the output is a logit or a probability depends on Discriminator.forward.
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
    real_score = model(img_hr.unsqueeze(0).to(device))
real_score

In [0]:
# Score the generated image the same way: eval mode, no gradients, on the model's device.
# fake_img already carries a batch dimension.
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
    fake_score = model(fake_img.to(device))
fake_score