In [None]:
#Importing necessary libraries
import copy
import os

import albumentations as A
import numpy
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
from torchvision.utils import save_image
from tqdm import tqdm

# CONFIGURATION

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Optimisation settings.
num_epochs = 100
batch_size = 16
lr = 1e-4

# Image geometry: LR is a quarter of HR (integer division), and images have
# num_channels colour channels.
num_channels = 3
HR = 96
LR = HR // 4

# Models

### Convolution Block (Conv -> BN -> PReLU / LeakyReLU; the element-wise sum is applied by the residual block)

In [None]:
class convBlock(nn.Module):
  """Conv2d followed by an optional BatchNorm2d and an optional activation.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the convolution.
    discriminator: when True the activation is LeakyReLU(0.2), otherwise a
      per-channel PReLU.
    use_activation: when False, forward() skips the activation entirely.
    use_bn: when True a BatchNorm2d follows the convolution and the
      convolution is built without a bias term; otherwise the norm slot is an
      Identity and the convolution keeps its bias.
    **kwargs: forwarded unchanged to nn.Conv2d (kernel_size, stride, ...).
  """
  def __init__(self, in_channels, out_channels, discriminator = False, use_activation = True, use_bn = True, **kwargs):
    super().__init__()
    self.use_activation = use_activation

    # The conv bias is only enabled when no BatchNorm follows it.
    self.cnn = nn.Conv2d(in_channels, out_channels, bias = not use_bn, **kwargs)
    self.bn = nn.BatchNorm2d(out_channels) if use_bn == True else nn.Identity()

    # The activation module is always created; forward() decides whether to use it.
    self.act = nn.LeakyReLU(0.2, inplace = True) if discriminator == True else nn.PReLU(num_parameters = out_channels)

  def forward(self, x):
    """Apply conv -> norm and, if enabled, the activation."""
    normed = self.bn(self.cnn(x))
    if self.use_activation == True:
      return self.act(normed)
    return normed

### Upsample Block

In [None]:
class upsampleBlock(nn.Module):
  """Upsample a feature map with a 3x3 conv followed by PixelShuffle and PReLU.

  Args:
    in_channels: channel count of the input; the output has the same count.
    scale_factor: spatial upscaling factor handed to nn.PixelShuffle.
  """
  def __init__(self, in_channels, scale_factor):
    super().__init__()
    # The conv widens the channels by scale_factor**2 so that PixelShuffle can
    # fold them back: (C * r^2, H, W) --> (C, r*H, r*W).
    expanded_channels = in_channels * scale_factor ** 2
    self.cnn = nn.Conv2d(in_channels, expanded_channels, kernel_size = 3, stride = 1, padding = 1)
    self.ps = nn.PixelShuffle(scale_factor)
    self.act = nn.PReLU(num_parameters = in_channels)

  def forward(self, x):
    """Return the upsampled, activated feature map."""
    return self.act(self.ps(self.cnn(x)))

### Residual Block

In [None]:
class residualBlock(nn.Module):
  """Two 3x3 convBlocks with an identity skip connection around them.

  The first convBlock uses its activation, the second does not; the block
  input is added element-wise to the second block's output.

  Args:
    in_channels: channel count of both the input and the output.
  """
  def __init__(self, in_channels):
    super().__init__()
    conv_args = dict(kernel_size = 3, stride = 1, padding = 1)
    self.block1 = convBlock(in_channels, in_channels, **conv_args)
    self.block2 = convBlock(in_channels, in_channels, use_activation = False, **conv_args)

  def forward(self, x):
    """Return block2(block1(x)) + x."""
    residual = self.block2(self.block1(x))
    return residual + x

## Generator

In [None]:
class generator(nn.Module):
  """Super-resolution generator: conv stem, residual trunk, two x2 upsamplers.

  The two pixel-shuffle stages each upscale by 2, so the output is 4x the
  spatial size of the input. The final activation is tanh, so outputs lie in
  [-1, 1].

  Args:
    in_channels: channels of the input (and output) image.
    num_channels: width of the internal feature maps.
    num_blocks: number of residual blocks in the trunk (B).
  """
  def __init__(self, in_channels = 3, num_channels = 64, num_blocks = 16):  #num_blocks = B
    super().__init__()

    # Stem: 9x9 conv without BatchNorm.
    self.initial_conv = convBlock(in_channels, num_channels, kernel_size = 9, stride = 1, padding = 4, use_bn = False)

    # B residual blocks. They must live in an nn.Module container: a plain
    # Python list would hide their parameters from .parameters(), .to(device)
    # and state_dict(), so they would never be trained or moved to the GPU.
    self.residuals = nn.Sequential(*[residualBlock(num_channels) for _ in range(num_blocks)])

    # Conv layer after the residual trunk (no activation before the skip sum).
    self.later_conv = convBlock(num_channels, num_channels, kernel_size = 3, stride = 1, padding = 1, use_activation = False)

    # Two conv + PixelShuffle stages, each upscaling by 2.
    self.PS_conv1 = upsampleBlock(num_channels, scale_factor = 2)
    self.PS_conv2 = upsampleBlock(num_channels, scale_factor = 2)

    self.last_conv = nn.Conv2d(num_channels,in_channels,kernel_size = 9, stride = 1, padding = 4)

  def forward(self, x):
    """Map a low-resolution batch (N, in_channels, H, W) to (N, in_channels, 4H, 4W)."""
    initial = self.initial_conv(x)
    out = self.residuals(initial)
    # Long skip connection around the whole residual trunk.
    out = self.later_conv(out) + initial
    # Each stage consumes the previous stage's output (not the raw input x).
    out = self.PS_conv1(out)
    out = self.PS_conv2(out)
    out = self.last_conv(out)
    return torch.tanh(out)
    

## Discriminator

In [None]:
class discriminator(nn.Module):
  """Convolutional discriminator producing one raw logit per image.

  A stack of 3x3 convBlocks (LeakyReLU, stride alternating 1, 2, 1, 2, ...;
  BatchNorm on every block except the first) is followed by adaptive average
  pooling to 6x6 and two linear layers. No sigmoid is applied here.

  Args:
    in_channels: channels of the input image.
    features: output width of each successive conv block.
  """
  def __init__(self, in_channels = 3, features = (64, 64, 128, 128, 256, 256, 512, 512)):
    super().__init__()
    blocks = []
    for i, feature in enumerate(features):
      blocks.append(
        convBlock(
          in_channels,
          feature,
          kernel_size = 3,
          stride = 1 + (i % 2),   # every second block halves the resolution
          padding = 1,
          discriminator = True,
          use_activation = True,
          use_bn = (i != 0),      # no BatchNorm on the first block
        )
      )
      in_channels = feature
    # Registered container so the blocks' parameters are visible to
    # .parameters(), .to(device) and state_dict().
    self.blocks = nn.Sequential(*blocks)

    # Final Dense Layers. After the loop in_channels holds the width of the
    # last conv block, so the Linear input adapts to a custom `features`.
    self.final_layers = nn.Sequential(
      nn.AdaptiveAvgPool2d((6, 6)),
      nn.Flatten(),
      nn.Linear(in_channels * 6 * 6, 1024),
      nn.LeakyReLU(0.2, inplace = True),
      nn.Linear(1024, 1)
    )

  def forward(self, x):
    """Return a (N, 1) tensor of logits for the image batch x."""
    return self.final_layers(self.blocks(x))

## Testing Generator and Discriminator

In [None]:
# LR = 24
# with torch.cuda.amp.autocast():
#   x = torch.randn((5,3,LR,LR))
#   gen = generator()
#   gen_out = gen(x)
#   disc = discriminator()
#   disc_out = disc(gen_out)
  
#   print(gen_out.shape)
#   print(disc_out.shape)

# Loss Function

In [None]:
# phi_5,4: 5th conv layer before maxpooling but after activation
# phi_5,4: 5th conv layer before maxpooling but after activation
class VGG19Loss(nn.Module):
  """Perceptual loss: MSE between VGG19 feature maps of two image batches.

  Uses the first 36 layers of torchvision's pretrained VGG19 `features`
  stack, put in eval mode, moved to the notebook-level `device`, and frozen
  (requires_grad = False on every parameter).
  """
  def __init__(self):
    super().__init__()
    # `device` is the notebook's configuration variable.
    self.vgg = vgg19(pretrained = True).features[:36].eval().to(device)
    self.loss = nn.MSELoss()

    # The feature extractor is fixed; only the generator receives gradients.
    for param in self.vgg.parameters():
      param.requires_grad = False

  def forward(self, input, target):
    """Return the MSE between the VGG features of `input` and `target`."""
    vgg_input_features = self.vgg(input)
    vgg_target_features = self.vgg(target)
    return self.loss(vgg_input_features, vgg_target_features)

### Function to transform images

In [None]:
# High-resolution target: normalised with mean 0.5 / std 0.5 per channel, then
# converted to a tensor. This matches the generator's tanh output range.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# Low-resolution input: resized to LR x LR (LR comes from the configuration
# cell), normalised with mean 0 / std 1, then converted to a tensor.
# NOTE(review): Albumentations takes an OpenCV interpolation flag here, and the
# integer value of PIL's Image.BICUBIC is the value OpenCV uses for INTER_AREA,
# not INTER_CUBIC - confirm which filter is intended.
lowres_transform = A.Compose(
    [
        A.Resize(width=LR, height=LR, interpolation=Image.BICUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Augmentations applied once to the source image before it is split into the
# high-res / low-res pair: random HR x HR crop, random flip, random 90-degree turn.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HR, height=HR),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Inference-time transform: same normalisation as lowres_transform, no resize.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Loading Data

In [None]:
class MyImageFolder(Dataset):
  """Dataset over `root_dir/<class_name>/<image>` yielding (low_res, high_res) pairs.

  Every sub-directory of `root_dir` is treated as a class; non-directory
  entries directly under `root_dir` are ignored. Directory and file listings
  are sorted so indexing is deterministic across runs.

  Args:
    root_dir: directory containing one sub-directory per class.
  """
  def __init__(self, root_dir):
    super(MyImageFolder, self).__init__()
    self.data = []   # list of (file_name, class_index) tuples
    self.root_dir = root_dir
    # Skip stray files (e.g. .DS_Store) that would make os.listdir fail below.
    self.class_names = sorted(
      name for name in os.listdir(root_dir)
      if os.path.isdir(os.path.join(root_dir, name))
    )

    for index, name in enumerate(self.class_names):
      files = sorted(os.listdir(os.path.join(root_dir, name)))
      self.data += [(file_name, index) for file_name in files]

  def __len__(self):
    """Number of image files found across all class directories."""
    return len(self.data)

  def __getitem__(self, index):
    """Load image `index`, augment it, and return (low_res, high_res) tensors."""
    img_file, label = self.data[index]
    img_path = os.path.join(self.root_dir, self.class_names[label], img_file)

    # Close the file handle promptly and force 3 channels so grayscale / RGBA
    # files do not break the 3-channel normalisation and the model input.
    with Image.open(img_path) as img:
      image = np.array(img.convert("RGB"))

    image = both_transforms(image=image)["image"]
    high_res = highres_transform(image=image)["image"]
    low_res = lowres_transform(image=image)["image"]
    return low_res, high_res

train_dataset = MyImageFolder(root_dir="trainData/")
# Use the batch size from the configuration cell and reshuffle every epoch.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training Function

In [None]:
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
  """Run one pass over `loader`, alternating discriminator and generator updates.

  Args:
    loader: iterable of (low_res, high_res) batches.
    disc: discriminator network.
    gen: generator network.
    opt_gen: optimizer stepping the generator.
    opt_disc: optimizer stepping the discriminator.
    mse: accepted for interface compatibility; not used in the body.
    bce: criterion applied to discriminator outputs and their targets.
    vgg_loss: callable giving the content loss between fake and real images.

  Returns:
    None. The networks and optimizers are updated in place.
  """
  progress = tqdm(loader, leave=True)

  for low_res, high_res in progress:
    high_res = high_res.to(device)
    low_res = low_res.to(device)

    # --- Discriminator step: max log(D(x)) + log(1 - D(G(z))) ---
    fake = gen(low_res)
    disc_real = disc(high_res)
    # detach() keeps this step from back-propagating into the generator.
    disc_fake = disc(fake.detach())
    # Real targets are smoothed: 1 minus uniform noise in [0, 0.1).
    smoothed_real_targets = torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
    disc_loss_real = bce(disc_real, smoothed_real_targets)
    disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = disc_loss_fake + disc_loss_real

    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # --- Generator step: min log(1 - D(G(z))) <-> max log(D(G(z))) ---
    disc_fake = disc(fake)
    adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
    loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
    gen_loss = loss_for_vgg + adversarial_loss

    opt_gen.zero_grad()
    gen_loss.backward()
    opt_gen.step()

### Training Starts Here...

In [None]:
# Build both networks on the configured device.
gen = generator(in_channels=num_channels).to(device)
# The discriminator's parameter is named `in_channels`.
disc = discriminator(in_channels=num_channels).to(device)
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.9, 0.999))
mse = nn.MSELoss()
# The discriminator returns raw logits, hence the with-logits criterion.
bce = nn.BCEWithLogitsLoss()
# The perceptual-loss class defined in this notebook is VGG19Loss.
vgg_loss = VGG19Loss()

for epoch in range(num_epochs):
    train_fn(train_data_loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)

### Testing Models here..

In [None]:
def apply_SR(lr_folder, gen):
  """Super-resolve every image in `lr_folder` and write the results to `saved/`.

  Args:
    lr_folder: directory containing the low-resolution input images.
    gen: generator network; it is run in eval mode under torch.no_grad().

  Returns:
    None. One output file per input is written as saved/<input file name>.
    The generator's train/eval mode is restored to what it was on entry,
    even if processing an image raises.
  """
  files = os.listdir(lr_folder)
  # save_image does not create missing directories.
  os.makedirs("saved", exist_ok=True)

  was_training = gen.training
  gen.eval()
  try:
    for file in files:
      # Close the file handle promptly and force 3 channels so grayscale /
      # RGBA inputs match what the transform and the generator expect.
      with Image.open(os.path.join(lr_folder, file)) as image:
        lr_array = np.asarray(image.convert("RGB"))
      with torch.no_grad():
        lr_batch = test_transform(image=lr_array)["image"].unsqueeze(0).to(device)
        upscaled_img = gen(lr_batch)
      # Shift and scale the generator output by 0.5 before saving.
      save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
  finally:
    gen.train(was_training)

In [None]:
# apply_SR requires the generator as its second argument.
apply_SR("testData", gen)