In [None]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
import random
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'You are using : {device}')

In [None]:
# Training hyper-parameters shared by the cells below.
BATCH_SIZE = 8      # samples per optimisation step
LR_RATE = 1e-4      # Adam learning rate used for both generator and discriminator
NUM_EPOCHS = 30     # full passes over the training set

In [None]:
# Edge lengths (in pixels) of the square low- and high-resolution images.
low_res_size, high_res_size = 128, 256

# Both pipelines first convert the image to a tensor and then resize it to a
# fixed square, so every sample in a batch has the same shape.
transform_low = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((low_res_size, low_res_size))]
)

transform_high = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((high_res_size, high_res_size))]
)

In [None]:
class SRGANData(Dataset):
    """Paired low-/high-resolution image dataset.

    ``root_path`` must contain two sub-directories, ``high_res`` and
    ``low_res``. A pair is formed by looking up the *same file name* in both
    directories, and the high-resolution listing defines the samples.

    Args:
        root_path: Directory holding the ``high_res`` and ``low_res`` folders.
        transform_low: Callable applied to the low-resolution PIL image, or a
            falsy value (e.g. ``None``) to return the PIL image unchanged.
        transform_high: Same as ``transform_low`` for the high-resolution image.

    Raises:
        FileNotFoundError: If a directory is missing, or if a file in
            ``high_res`` has no counterpart with the same name in ``low_res``.
    """

    def __init__(self, root_path, transform_low, transform_high) -> None:
        super().__init__()
        self.transform_low = transform_low
        self.transform_high = transform_high
        self.root_path = root_path
        self.high_res_path = os.path.join(self.root_path, 'high_res')
        self.low_res_path = os.path.join(self.root_path, 'low_res')
        # Sort the listings so the sample order is deterministic and cannot get mixed up.
        self.high_res = sorted(os.listdir(self.high_res_path))
        self.low_res = sorted(os.listdir(self.low_res_path))

        # Samples are addressed by their high-res file name, so fail right away
        # (instead of somewhere in the middle of an epoch) when a partner is missing.
        unpaired = sorted(set(self.high_res) - set(self.low_res))
        if unpaired:
            preview = ', '.join(unpaired[:5])
            raise FileNotFoundError(
                f"{len(unpaired)} file(s) in '{self.high_res_path}' have no match in "
                f"'{self.low_res_path}': {preview}"
            )

    def __len__(self) -> int:
        # One listing is enough: every high-res image has a low-res version of the same name.
        return len(self.high_res)

    def __getitem__(self, idx) -> dict:
        """Return ``{'low_res': ..., 'high_res': ...}`` for the ``idx``-th file name."""
        filename = self.high_res[idx]
        low_res_img_path = os.path.join(self.low_res_path, filename)
        high_res_img_path = os.path.join(self.high_res_path, filename)

        # convert() returns an independent RGB copy, so the source files can be
        # closed as soon as the conversion is done.
        with Image.open(low_res_img_path) as low_file:
            low_res_image = low_file.convert('RGB')
        with Image.open(high_res_img_path) as high_file:
            high_res_image = high_file.convert('RGB')

        if self.transform_low:
            low_res_image = self.transform_low(low_res_image)
        if self.transform_high:
            high_res_image = self.transform_high(high_res_image)

        return {'low_res' : low_res_image, 'high_res' : high_res_image}

In [None]:
# Training split of the paired dataset, using the resize pipelines defined above.
dataset = SRGANData(
    root_path='/kaggle/input/image-super-resolution/dataset/train',
    transform_low=transform_low,
    transform_high=transform_high,
)
dataset

In [None]:
# Preview ten randomly drawn (low-res, high-res) pairs side by side.
# Indices are drawn with replacement, so the same pair may show up twice.
rand_nums = [random.randint(0, len(dataset) - 1) for _ in range(10)]

for num in rand_nums:
    # Fetch the pair once; indexing the dataset loads and transforms both images.
    sample = dataset[num]

    fig, (ax_low, ax_high) = plt.subplots(1, 2)
    # permute(1, 2, 0) moves the channel axis last, which is what imshow expects.
    ax_low.imshow(sample['low_res'].permute(1, 2, 0))
    ax_low.set_title(f'low_res (index {num})')
    ax_high.imshow(sample['high_res'].permute(1, 2, 0))
    ax_high.set_title(f'high_res (index {num})')
    plt.show()
    # Release the figure so repeated previews do not accumulate in memory.
    plt.close(fig)

In [None]:
class ResidualBlocks(nn.Module):
    """Residual block: Conv-BN-PReLU-Conv-BN plus an identity skip connection.

    Args:
        in_channels: Channels of the input (and therefore of the output).
        out_channels: Channels of the hidden representation between the two
            convolutions.

    The second convolution projects back to ``in_channels`` so the skip
    addition is valid for any channel configuration. It was previously
    hard-coded to 64, which only worked when ``in_channels == 64``.
    Both convolutions are 3x3 / stride 1 / padding 1, so the spatial size is
    preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlocks, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                              kernel_size = 3, stride = 1, padding = 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.prelu1 = nn.PReLU()
        # Map back to the input width so that `x + identity` has matching shapes.
        self.conv2 = nn.Conv2d(out_channels, in_channels, kernel_size = 3, stride = 1, padding = 1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Return ``x + F(x)`` where ``F`` is the conv/bn/prelu/conv/bn stack."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # Out-of-place add: avoids mutating a tensor that autograd may still reference.
        return x + identity

In [None]:
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling stage: Conv -> PixelShuffle(2) -> PReLU.

    The 3x3 convolution expands the channel count by a factor of four; the
    pixel shuffle then rearranges those channels into a feature map with twice
    the height and width and the original number of channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        expanded_channels = in_channels * 4
        self.conv = nn.Conv2d(in_channels, expanded_channels, 3, 1, 1)
        self.pix_shuffler = nn.PixelShuffle(2)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Upsample ``x`` by a factor of two along both spatial axes."""
        return self.prelu(self.pix_shuffler(self.conv(x)))

In [None]:
class Generator(nn.Module):
    """SRGAN-style generator mapping a low-resolution RGB image to a larger one.

    Pipeline: 9x9 conv + PReLU -> five residual blocks -> skip addition ->
    3x3 conv + BatchNorm -> ``num_upsample_blocks`` pixel-shuffle stages ->
    9x9 conv -> Tanh.

    Args:
        num_upsample_blocks: Number of 2x upsampling stages. The default of 2
            gives an overall 4x enlargement (e.g. 128x128 -> 512x512); pass 1
            for a 2x generator.

    Returns (from ``forward``):
        Tensor with 3 channels, values in ``[-1, 1]`` (Tanh), and spatial size
        ``input_size * 2 ** num_upsample_blocks``.
    """

    def __init__(self, num_upsample_blocks: int = 2) -> None:
        super(Generator, self).__init__()
        # A 9x9 kernel needs padding 4 to keep the spatial size. The previous
        # padding of 1 silently trimmed 6 pixels from height and width, so the
        # output was not an exact multiple of the input size.
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 64,
                               kernel_size = 9, stride = 1, padding = 4)
        self.prelu1 = nn.PReLU()
        self.res_block = nn.Sequential(
            ResidualBlocks(in_channels = 64, out_channels = 64),
            ResidualBlocks(in_channels = 64, out_channels = 64),
            ResidualBlocks(in_channels = 64, out_channels = 64),
            ResidualBlocks(in_channels = 64, out_channels = 64),
            ResidualBlocks(in_channels = 64, out_channels = 64)
        )
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Each UpsampleBlock doubles height and width.
        self.upsample_block = nn.Sequential(
            *[UpsampleBlock(in_channels = 64) for _ in range(num_upsample_blocks)]
        )
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 3,
                               kernel_size = 9, stride = 1, padding = 4)
        self.tan = nn.Tanh()

    def forward(self, x) -> torch.Tensor:
        """Generate the super-resolved image for a batch of low-res images."""
        x1 = self.prelu1(self.conv1(x))
        x2 = self.res_block(x1)
        x = x1 + x2  # long skip connection around the residual trunk
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.upsample_block(x)
        x = self.conv3(x)
        x = self.tan(x)
        return x

In [None]:
class DiscBlock(nn.Module):
    """Discriminator building block: 3x3 Conv -> BatchNorm -> LeakyReLU(0.2).

    ``stride`` is forwarded to the convolution; the padding is fixed at 1.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply convolution, batch normalisation and the leaky ReLU in turn."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.lrelu(normalised)

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for RGB images.

    A 3x3 stem convolution is followed by seven ``DiscBlock`` stages, an
    adaptive average pool to an 8x8 map, and a two-layer classifier head ending
    in a sigmoid, so ``forward`` returns one probability per image with shape
    ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.lrelu1 = nn.LeakyReLU(0.2)

        # (in_channels, out_channels, stride) of each stage, in order.
        stage_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        self.res_block = nn.Sequential(
            *[DiscBlock(in_channels=c_in, out_channels=c_out, stride=step)
              for c_in, c_out, step in stage_specs]
        )

        # Pooling to a fixed 8x8 map keeps the first linear layer's input size
        # (8 * 8 * 512) independent of the input image resolution.
        self.gap = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(8 * 8 * 512, 1024)
        self.lrelu2 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the probability that each image in ``x`` is real."""
        features = self.lrelu1(self.conv1(x))
        features = self.gap(self.res_block(features))
        flat = features.view(features.size(0), -1)  # (batch, 8*8*512)
        hidden = self.lrelu2(self.fc1(flat))
        return self.sigmoid(self.fc2(hidden))

In [None]:
class VggFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG-19 feature extractor for the content loss.

    Keeps the first 36 modules of ``vgg19().features`` and disables gradients
    for all of their parameters, so the extractor is never updated.
    Gradients can still flow *through* it to whatever produced the input.
    """

    def __init__(self):
        super(VggFeatureExtractor, self).__init__()
        # `pretrained=True` is deprecated in recent torchvision releases in
        # favour of the `weights=` argument. Use the new API when it exists and
        # keep the old call as a fallback for older installations.
        # IMAGENET1K_V1 is the checkpoint `pretrained=True` refers to.
        weights_enum = getattr(torchvision.models, 'VGG19_Weights', None)
        if weights_enum is not None:
            vgg19 = torchvision.models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        else:
            vgg19 = torchvision.models.vgg19(pretrained=True).features
        self.vgg19_layers = nn.Sequential(*list(vgg19.children())[:36])
        for param in self.vgg19_layers.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the VGG feature maps for the image batch ``x``."""
        return self.vgg19_layers(x)

In [None]:
# Adversarial loss on the discriminator's probabilities.
bce_loss = nn.BCELoss()
# Content loss between VGG feature maps.
mse_loss = nn.MSELoss()
# Frozen VGG-19 used to compute those feature maps, placed on the training device.
vgg_extractor = VggFeatureExtractor()
vgg_extractor = vgg_extractor.to(device)

In [None]:
# Networks (generator first, then discriminator), moved to the training device.
G = Generator()
G = G.to(device)
D = Discriminator()
D = D.to(device)

# One Adam optimizer per network, both using the shared learning rate.
optimizer_G = torch.optim.Adam(params=G.parameters(), lr=LR_RATE)
optimizer_D = torch.optim.Adam(params=D.parameters(), lr=LR_RATE)

In [None]:
import torch.nn.functional as F

# Adversarial training: every batch updates the generator first (content +
# adversarial loss) and then the discriminator (real vs. generated images).
#
# NOTE(review): G ends in Tanh (outputs in [-1, 1]) while the targets come from
# ToTensor (values in [0, 1]), and neither is ImageNet-normalised before the
# VGG pass. Consider aligning the ranges; left unchanged here.
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

G.train()
D.train()

for epoch in range(NUM_EPOCHS):
    # Running sums so the epoch summary reports averages, not just the last batch.
    running_loss_G = 0.0
    running_loss_D = 0.0
    num_batches = 0

    for batch in dataloader:
        real_imgs = batch['high_res'].to(device)
        low_res_imgs = batch['low_res'].to(device)
        batch_size = real_imgs.size(0)
        # Discriminator targets: 1 for real images, 0 for generated ones.
        valid = torch.ones((batch_size, 1), device=device)
        fake = torch.zeros((batch_size, 1), device=device)

        # ---------------- Generator update ----------------
        optimizer_G.zero_grad()

        gen_imgs = G(low_res_imgs)
        pred_fake = D(gen_imgs)

        # VGG feature (content) loss. The target features are constants, so
        # no autograd bookkeeping is needed for them.
        with torch.no_grad():
            real_features = vgg_extractor(real_imgs)
        gen_features = vgg_extractor(gen_imgs)
        # The generated image can have a different spatial size than the target,
        # so resize its feature map to make the shapes match.
        gen_features = F.interpolate(gen_features, size=real_features.shape[2:], mode='bilinear', align_corners=False)

        loss_content = mse_loss(gen_features, real_features)

        # Adversarial loss: the generator wants D to label its output as real.
        loss_gan = bce_loss(pred_fake, valid)

        # Total generator loss
        loss_G = loss_content + 1e-3 * loss_gan
        loss_G.backward()
        optimizer_G.step()

        # -------------- Discriminator update --------------
        # zero_grad also discards the gradients that the generator's backward
        # pass accumulated in D's parameters.
        optimizer_D.zero_grad()

        pred_real = D(real_imgs)
        # detach(): the discriminator step must not back-propagate into G.
        pred_fake = D(gen_imgs.detach())

        loss_real = bce_loss(pred_real, valid)
        loss_fake = bce_loss(pred_fake, fake)
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        running_loss_G += loss_G.item()
        running_loss_D += loss_D.item()
        num_batches += 1

    # Guard against an empty loader, where no loss was ever computed.
    if num_batches:
        print(f"[Epoch {epoch}] Loss_G: {running_loss_G / num_batches:.4f} | "
              f"Loss_D: {running_loss_D / num_batches:.4f}")
    else:
        print(f"[Epoch {epoch}] no batches were produced by the dataloader")