In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
from matplotlib import image
# Notebook-wide matplotlib defaults, applied in one call.
plt.rcParams.update({
    'figure.figsize': (10, 10),  # default size of plots
    'image.cmap': 'gray',
    'font.size': 16,
})

import torch
import pytorch_lightning as pl
import numpy as np
from tqdm import tqdm
import os
from pathlib import Path

from src import *
from src.models.Tools.SuperResolver import *

# Seed python `random`, numpy and torch once, up front, so network
# initialisation and the synthetic noise in the denoising example are
# reproducible under Restart & Run All.
SEED = 0
pl.seed_everything(SEED)

# Shared keyword arguments for every `pl.Trainer` created below.
# NOTE(review): `gpus` and `checkpoint_callback` are PyTorch Lightning 1.x
# arguments; newer releases use `accelerator`/`devices` and
# `enable_checkpointing` instead — update if the Lightning version changes.
GLOBAL_TRAINER_SETTINGS = {
    'check_val_every_n_epoch': 1,
    # Use the first GPU when CUDA is available, otherwise train on CPU
    # (None) instead of failing at Trainer construction.
    'gpus': [0] if torch.cuda.is_available() else None,
    'checkpoint_callback': False
}

In [None]:
# Low-resolution (`LR`) sample image from the Set5 `image_SRF_4` folder.
# The path is relative to the notebook's working directory.
filename = 'data/Set5/image_SRF_4/img_002_SRF_4_LR.png'
if not os.path.exists(filename):
    raise FileNotFoundError(
        f"Sample image not found: {filename!r} (path is relative to the working directory)"
    )

# matplotlib reads PNG files as float arrays scaled to [0, 1].
img = image.imread(filename)
# Later cells assume an (H, W, 3) RGB array (e.g. the channel-0 mask slice),
# so reject grayscale input and drop an alpha channel if present.
if img.ndim != 3 or img.shape[2] not in (3, 4):
    raise ValueError(f"Expected an RGB or RGBA image, got an array of shape {img.shape}")
img = img[:, :, :3]

In [None]:
# Show the image every use case below starts from.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title('Input image')
ax.axis('off');

In [None]:
# Square hole of `b` x `b` pixels whose top-left corner sits at
# column `x_offset`, row `y_offset`.
b = 40
x_offset = 20
y_offset = 20

# Ones everywhere except inside the hole, which is zeroed below.
rgb_mask = np.ones(img.shape)
# Axis 0 is rows (y) and axis 1 is columns (x). A scalar assignment works for
# any channel count and for a hole clipped by the image border, unlike a
# fixed-shape (b, b, 3) array of zeros.
rgb_mask[y_offset:y_offset + b, x_offset:x_offset + b, :] = 0.0

masked_image = img * rgb_mask
# Single-channel (H, W) version of the mask, passed to the inpainting model.
mask = rgb_mask[:, :, 0]

## Use Case 1: Inpainting

In [None]:
# Masked image fed to the model (same values as `masked_image`): the hole
# region is zero.
input = np.multiply(rgb_mask, img)

# Inpainting model.
# NOTE(review): `mask` is presumably read as 1 = known pixel, 0 = pixel to
# synthesise — confirm against NeuralKnitwork.
model = NeuralKnitwork(
    input,
    mask=mask,
    antialias=False,
    kno_coef=1.0,
    nov_coef=1.0,
    lr=4e-3,
    epoch_steps=1000,
)

In [None]:
num_epochs = 4

# A fresh trainer for this experiment, built from the shared settings.
trainer = pl.Trainer(**GLOBAL_TRAINER_SETTINGS, max_epochs=num_epochs)

# Re-enable gradients on all parameters, then optimise.
model.unfreeze()
trainer.fit(model)

In [None]:
output = model.generate()

# Side-by-side comparison: masked input vs. inpainted result.
fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(input)
ax_in.set_title('Masked input')
ax_out.imshow(output)
ax_out.set_title('Inpainted output')
for ax in (ax_in, ax_out):
    ax.axis('off')

## Use Case 2: Super-Resolution

In [None]:
# Super-resolution model for the loaded image; `upscale_factor=2`
# presumably doubles each spatial dimension.
model = SuperResolver(img, upscale_factor=2)

In [None]:
num_epochs = 4

# Shared trainer settings plus this experiment's epoch budget.
trainer = pl.Trainer(**GLOBAL_TRAINER_SETTINGS, max_epochs=num_epochs)

# Make every parameter trainable before fitting.
model.unfreeze()
trainer.fit(model)

In [None]:
output = model.generate()

# Side-by-side comparison: low-resolution input vs. super-resolved result.
# Both panels are scaled to the same display size, so compare detail rather
# than on-screen size.
fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(img)
ax_in.set_title('Low-resolution input')
ax_out.imshow(output)
ax_out.set_title('Super-resolved output')
for ax in (ax_in, ax_out):
    ax.axis('off')

## Use Case 3: Denoising

In [None]:
# Noisy observation: additive Gaussian noise with standard deviation 0.05,
# drawn from numpy's global random state.
input = np.add(img, 5e-2 * np.random.standard_normal(img.shape))

# Denoising model: only anti-aliasing is overridden; every other
# NeuralKnitwork option keeps its default.
model = NeuralKnitwork(input, antialias=False)

In [None]:
num_epochs = 4

trainer = pl.Trainer(
    max_epochs = num_epochs,
    **GLOBAL_TRAINER_SETTINGS
)
# Re-enable gradients before fitting, matching the inpainting and
# super-resolution training cells, so this cell behaves the same even if the
# model was frozen earlier (e.g. when the cell is re-run).
model.unfreeze()
trainer.fit(model)

In [None]:
output = model.generate()

# Side-by-side comparison: noisy input vs. denoised result.
fig, (ax_in, ax_out) = plt.subplots(1, 2)
# The added noise pushes some values outside [0, 1]. Clip for display only:
# imshow would clip float RGB data anyway, but with a warning. `input` itself
# is left unchanged.
ax_in.imshow(np.clip(input, 0, 1))
ax_in.set_title('Noisy input')
ax_out.imshow(output)
ax_out.set_title('Denoised output')
for ax in (ax_in, ax_out):
    ax.axis('off')