In [1]:
from kernelgan.dataloader import CropDataModule
from kernelgan.networks import KernelGAN

from zssr.ZSSR import ZSSR

import matplotlib.pyplot as plt

from pathlib import Path
import time

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

In [3]:
class Experiment:
    """Blind super-resolution pipeline: estimate a kernel with KernelGAN, then run ZSSR with it.

    Call ``train_gan`` before ``run_zssr``: the latter passes ``self.gan.kernel``
    to ZSSR (presumably populated by ``KernelGAN.post_process_kernel`` during
    ``train_gan`` -- confirm against the kernelgan package).
    """

    def __init__(self, img_path):
        """
        Args:
            img_path: path (``str`` or ``pathlib.Path``) of the input image; it is
                used both as the KernelGAN training image and as the ZSSR input.
        """
        # Coerce to Path so a plain string works; the methods below use the
        # Path API (.absolute(), .stem).
        self.img_path = Path(img_path)
        self.gan = KernelGAN()

    def train_gan(self, max_iters, kernel_path=Path('kernel.mat'), gpus=1, logger=None):
        """Train KernelGAN on crops of the input image and save the estimated kernel.

        Args:
            max_iters: forwarded to ``CropDataModule`` (presumably the number of
                crops / training iterations in the single epoch -- confirm).
            kernel_path: file passed to ``KernelGAN.save_kernel``.
            gpus: forwarded to ``pl.Trainer``.
            logger: optional Lightning logger (e.g. ``WandbLogger``). When
                ``None`` the Trainer's own default logger is used.
        """
        self.data_dl = CropDataModule(
            img_path=self.img_path,
            d_input_shape=self.gan.d_input_shape,
            d_forward_shave=self.gan.D.forward_shave,
            max_iters=max_iters,
        )
        trainer_kwargs = dict(gpus=gpus, max_epochs=1)
        # Only override the logger when one is given, to keep Trainer defaults.
        if logger is not None:
            trainer_kwargs['logger'] = logger
        trainer = pl.Trainer(**trainer_kwargs)
        try:
            trainer.fit(self.gan, self.data_dl)
        finally:
            # Close any wandb run even if training raises, so it does not stay
            # open in the notebook kernel.
            wandb.finish()
        self.gan.post_process_kernel()
        self.gan.save_kernel(Path(kernel_path))

    def run_zssr(self, scale_factor=2, noise_scale=1.0):
        """Super-resolve the input image with ZSSR using the KernelGAN kernel.

        Saves the result as ``ZSSR_<image stem>.png`` in the current working
        directory, keeps it on ``self.sr_image`` and prints the elapsed time.

        Args:
            scale_factor: forwarded to ``ZSSR`` as its ``scale_factor``.
            noise_scale: forwarded to ``ZSSR`` as its ``noise_scale``.
        """
        start_time = time.perf_counter()  # monotonic, unlike time.time()
        print('~' * 30 + '\nRunning ZSSR X{}...'.format(scale_factor))
        sr = ZSSR(
            self.img_path.absolute(),
            scale_factor=scale_factor,
            kernels=[self.gan.kernel],
            is_real_img=True,
            noise_scale=noise_scale,
        ).run()
        self.sr_image = sr
        plt.imsave(
            'ZSSR_%s.png' % self.img_path.stem,
            sr,
            vmin=0,
            # uint8 output uses the 0..255 range; any other dtype is assumed
            # to be float in [0, 1].
            vmax=255 if sr.dtype == 'uint8' else 1.,
            dpi=1,
        )
        runtime = int(time.perf_counter() - start_time)
        print('Completed! runtime=%s\n' % self._format_runtime(runtime) + '~' * 30)

    @staticmethod
    def _format_runtime(seconds):
        """Format a whole number of seconds as ``M:SS`` with zero-padded seconds."""
        minutes, secs = divmod(int(seconds), 60)
        return '%d:%02d' % (minutes, secs)


In [4]:
# Input image shared by the KernelGAN (kernel estimation) and ZSSR stages.
kg = Experiment(Path('input.png'))

In [5]:
# Estimate the image's kernel with KernelGAN (max_iters=3000); writes kernel.mat.
kg.train_gan(3000)

  grads_comb = lm_x / lm_x.sum() + lm_y / lm_y.sum() + gmag / gmag.sum()
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type          | Params
-------------------------------------------------
0 | G              | Generator     | 150 K 
1 | D              | Discriminator | 31.0 K
2 | GAN_loss_layer | GANLoss       | 0     
3 | bicubic_loss   | DownScaleLoss | 0     
-------------------------------------------------
181 K     Trainable params
0         Non-trainable params
181 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




In [6]:
# Super-resolve input.png x2 with ZSSR using the kernel estimated above;
# writes ZSSR_input.png to the working directory.
kg.run_zssr()

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Running ZSSR X2...

ZSSR configuration is for a real image
Completed! runtime=3:48
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
