In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from dataset import SRDataset
from loss import *
from networks import SRGenerator, SRBayesCap
from bayescap import *

## Loading the pre-trained (frozen) SRGAN base model and instantiating BayesCap

In [2]:
device = "cuda:0"
NetG = SRGenerator()
NetG.load_state_dict(torch.load("../ckpt/srgan-ImageNet-bc347d67.pth", map_location=device))
# 
model_parameters = filter(lambda p: True, NetG.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of Parameters:", params)
# 
NetC = SRBayesCap(in_channels=3, out_channels=3)

Number of Parameters: 1547350


## Load the datasets

In [3]:
dataset_train = SRDataset(data_root="../data/SRGAN_ImageNet", image_size=(86,86), upscale_factor=4, mode="train")
dataset_val = SRDataset(data_root="../data/Set5/original", image_size=(256,256), upscale_factor=4, mode="val")
dataset_test = SRDataset(data_root="../data/Set5/original", image_size=(256,256), upscale_factor=4, mode="test")

loader_train = DataLoader(dataset_train, batch_size=2, pin_memory=True, shuffle=True)
loader_val = DataLoader(dataset_val, batch_size=1, pin_memory=True, shuffle=False)
loader_test = DataLoader(dataset_test, batch_size=1, pin_memory=True, shuffle=False)

## Training loop

In [4]:
train_BayesCap(
	NetC,
	NetG,
	loader_train,
	loader_val,
	Cri = TempCombLoss(alpha_eps=1e-5, beta_eps=1e-2),
	device=device,
	dtype=torch.cuda.FloatTensor,
	init_lr=1e-4,
	num_epochs=2000,
	eval_every=2,
	ckpt_path="../ckpt/BayesCap_SRGAN",
)

Epoch 0:   0%|          | 0/2446 [00:01<?, ?batch/s]


RuntimeError: The size of tensor a (84) must match the size of tensor b (86) at non-singleton dimension 3

## Evaluating BayesCap

In [None]:
NetG = SRGenerator()
NetG.load_state_dict(torch.load("../ckpt/srgan-ImageNet-bc347d67.pth", map_location=device))
NetG.to('cuda')
NetG.eval()

NetC = SRBayesCap(in_channels=3, out_channels=3)
NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location=device))
NetC.to('cuda')
NetC.eval()

In [None]:
eval_BayesCap(NetC, NetG, loader_test, device=device, dtype=torch.cuda.FloatTensor)

## Displaying output

In [None]:
# device = 'cuda'
# dtype=torch.cuda.FloatTensor
# num_imgs = 0
# mean_ssim = 0
# for (idx, batch) in enumerate(loader_val):
#     print('Image {} ...'.format(idx))
#     xLR, xHR = batch[0].to(device), batch[1].to(device)
#     xLR, xHR = xLR.type(dtype), xHR.type(dtype)
#     # pass them through the network
#     with torch.no_grad():
#         xSR = NetG(xLR)
#         xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
#     n_batch = xSRC_mu.shape[0]
#     for j in range(n_batch):
#         num_imgs += 1
#         mean_ssim += compute_img_ssim(xSRC_mu[j], xHR[j])
#         
#     plt.figure(figsize=(30,10))
#     plt.subplot(1,4,1)
#     plt.imshow(xLR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
#     plt.axis('off')
#     
#     plt.subplot(1,4,2)
#     plt.imshow(xSR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
#     plt.axis('off')
#     
#     plt.subplot(1,4,3)
#     a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
#     plt.imshow(a_map.transpose(0,2).transpose(0,1), cmap='inferno')
#     plt.clim(0, 0.1)
#     plt.axis('off')
#     
#     plt.subplot(1,4,4)
#     error_map = torch.mean(torch.pow(torch.abs(xSR[0]-xHR[0]),2), dim=0).to('cpu').data 
#     plt.imshow(error_map, cmap='jet')
#     plt.clim(0,0.01)
#     plt.axis('off')
#     
#     plt.subplots_adjust(wspace=0, hspace=0)
#     plt.show()
#     ########################
#     plt.figure(figsize=(30,10))
#     plt.subplot(1,4,1)
#     plt.imshow(xHR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
#     
#     plt.axis('off')
#     
#     plt.subplot(1,4,2)
#     plt.imshow((0.6*xSRC_mu[0]+0.4*xSR[0]).to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
#     plt.axis('off')
#     
#     plt.subplot(1,4,3)
#     b_map = xSRC_beta[0].to('cpu').data
#     plt.imshow(b_map.transpose(0,2).transpose(0,1), cmap='cividis')
#     plt.clim(0.45, 0.75)
#     plt.axis('off')
#     
#     plt.subplot(1,4,4)
#     u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2)))) 
#     plt.imshow((u_map).transpose(0,2).transpose(0,1), cmap='hot')
#     plt.clim(0,0.15)
#     plt.axis('off')
#     
#     plt.subplots_adjust(wspace=0, hspace=0)
#     plt.show()