In [2]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.utils import make_grid, save_image

In [None]:
class RWMAB(nn.Module):
    """
    Residual Whole Map Attention Block (RWMAB).

    Building block of the Residual Whole Map Attention Network (RWMAN),
    a modification of RCAN for extracting features from low-resolution (LR)
    images and feeding them into a generator for image upscaling.

    The block computes ``features = conv(relu(conv(x)))``, derives an
    attention map ``sigmoid(conv1x1(features))`` with one weight per channel
    and pixel, and returns ``features * attention + x``.

    Args:
        input_shape: Number of channels of the incoming feature map. The block
            always emits 64 channels, so the residual addition only works
            when this is 64 (the default).
    """

    def __init__(self, input_shape: int = 64) -> None:
        super().__init__()
        # padding=1 keeps the spatial size unchanged for the 3x3 kernels;
        # without it each convolution trims a one-pixel border and the
        # residual sum in forward() cannot be formed.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_shape, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        )
        self.attention = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, stride=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention-weighted features plus the identity shortcut."""
        features = self.conv1(x)
        attention = self.attention(features)
        # The attention map re-weights the convolved features (RCAN-style),
        # not the raw input; the input only enters through the shortcut.
        return features * attention + x

class ShortResidualConnection(nn.Module):
    """
    Residual group: a stack of RWMAB blocks followed by a 1x1 convolution,
    wrapped in a skip connection (``output = stack(x) + x``).

    Args:
        input_shape: Kept for interface compatibility. It is not used because
            every layer of the group is fixed at 64 channels.
        num_blocks: Number of RWMAB blocks in the stack (16 by default).
    """

    def __init__(self, input_shape: int = 64, num_blocks: int = 16) -> None:
        super().__init__()
        blocks = [RWMAB() for _ in range(num_blocks)]
        self.src = nn.Sequential(
            *blocks,
            nn.Conv2d(64, 64, kernel_size=1, stride=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block stack and add the input back (short skip connection)."""
        return self.src(x) + x
            
            
class Generator(nn.Module):
    """
    Super-resolution generator.

    Pipeline: a 3x3 head convolution, eight ``ShortResidualConnection`` groups
    plus a 3x3 convolution wrapped in a long skip connection, then two
    PixelShuffle(2) stages (4x upscaling in total) and a final 3x3 convolution
    that produces a single-channel image.

    Args:
        in_channels: Channel count expected by the head convolution. The
            default (64) is the value hard-coded in the original code.
            NOTE(review): the network emits a 1-channel image, so the input is
            presumably a 1-channel LR image as well -- pass ``in_channels=1``
            in that case; confirm against the data pipeline.
    """

    def __init__(self, in_channels: int = 64) -> None:
        super().__init__()

        # padding=1 on every 3x3 convolution keeps H x W unchanged, which the
        # long skip connection in forward() relies on.
        self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)

        groups = [ShortResidualConnection() for _ in range(8)]
        self.lrc = nn.Sequential(
            *groups,
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        )

        # PixelShuffle(2) trades channels for resolution: C*4 channels in,
        # C channels out at twice the height and width. Each stage therefore
        # goes 64 -> 256 (conv) -> 64 (shuffle).
        upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
        )
        to_image = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

        self.upscaler = nn.Sequential(upsample, to_image)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` by a factor of 4 in height and width."""
        shallow = self.conv_1(x)
        deep = self.lrc(shallow)
        # Long skip connection around all residual groups.
        return self.upscaler(shallow + deep)

In [None]:
class DBlock(nn.Module):
    """
    Discriminator building block: 3x3 convolution -> optional batch
    normalisation -> LeakyReLU(0.2).

    Args:
        input_shape: Number of input channels.
        output_shape: Number of output channels.
        stride: Stride of the convolution. With the padding of 1 used here,
            stride 1 keeps the spatial size and stride 2 halves an even size.
        bn: Whether to apply batch normalisation after the convolution.
    """

    def __init__(self, input_shape: int = 64, output_shape: int = 64,
                 stride: int = 2, bn: bool = True) -> None:
        super().__init__()
        self.bntrue = bn
        self.conv_1 = nn.Conv2d(input_shape, output_shape, kernel_size=3,
                                stride=stride, padding=1)
        # BatchNorm2d needs the channel count it normalises. When batch norm
        # is disabled an Identity is stored so no unused parameters exist.
        self.bn = nn.BatchNorm2d(output_shape) if bn else nn.Identity()
        self.leakyr = nn.LeakyReLU(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, (optional) batch norm and LeakyReLU."""
        return self.leakyr(self.bn(self.conv_1(x)))

class Discriminator(nn.Module):
    """
    Two-input discriminator.

    ``x`` goes through the ``*_sr`` branch (which contains two stride-2
    blocks) and ``y`` through the ``*_lr`` branch (stride 1 only). The two
    branch outputs are added element-wise, so they must have the same shape.
    The sum is processed by three further block pairs and a small dense head
    that ends in a sigmoid.

    ``forward`` returns the five intermediate feature maps followed by the
    final prediction, i.e. ``(x_1, x_2, xy_1, xy_2, xy_3, final)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.block_1_sr = nn.Sequential(DBlock(1, 64, 1, False),
                                        DBlock(64, 64, 2))
        self.block_2_sr = nn.Sequential(DBlock(64, 128, 1),
                                        DBlock(128, 128, 2))
        self.block_1_lr = nn.Sequential(DBlock(1, 64, 1, False),
                                        DBlock(64, 128, 1, True))

        self.block_1 = nn.Sequential(DBlock(128, 256, 1),
                                     DBlock(256, 256, 2))
        self.block_2 = nn.Sequential(DBlock(256, 512, 1),
                                     DBlock(512, 512, 2))
        self.block_3 = nn.Sequential(DBlock(512, 1024, 1),
                                     DBlock(1024, 1024, 2))

        # nn.Linear needs an explicit input width. Global average pooling
        # reduces the 1024-channel map to a 1024-vector regardless of the
        # image size, so the dense head has a fixed input width.
        self.final = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                   nn.Flatten(),
                                   nn.Linear(1024, 100),
                                   nn.LeakyReLU(0.2),
                                   nn.Linear(100, 1),
                                   nn.Sigmoid())

    def forward(self, x, y):
        """
        Args:
            x: Image for the ``*_sr`` branch (1 channel).
            y: Image for the ``*_lr`` branch (1 channel).

        Returns:
            Tuple ``(x_1, x_2, xy_1, xy_2, xy_3, final)`` where ``final`` has
            shape ``(batch, 1)`` with values in (0, 1).
        """
        x_1 = self.block_1_sr(x)
        x_2 = self.block_2_sr(x_1)

        y_1 = self.block_1_lr(y)

        xy = torch.add(x_2, y_1)

        # Each stage has its own block; the channel widths chain
        # 128 -> 256 -> 512 -> 1024.
        xy_1 = self.block_1(xy)
        xy_2 = self.block_2(xy_1)
        xy_3 = self.block_3(xy_2)

        final = self.final(xy_3)

        return (x_1, x_2, xy_1, xy_2, xy_3, final)

In [None]:
class GeneratorLossFunction(nn.Module):
    """
    Generator objective::

        content_loss + lambda1 * adversarial_loss + lambda2 * adversarial_feature_loss

    Args:
        device: Device the frozen VGG-19 feature extractor is moved to.
        lambda1: Weight of the adversarial term.
        lambda2: Weight of the adversarial feature-matching term.
        vgg_layers: Indices into ``vgg19().features`` whose outputs are
            compared in the content loss.
        weights: One weight per entry of ``vgg_layers``.

    The discriminator passed to the loss methods is called as
    ``discriminator(high_res_image, low_res_image)`` -- the argument order of
    ``Discriminator.forward(x, y)`` -- and must return a sequence whose last
    element is the probability and whose leading elements are feature maps.
    """

    def __init__(self, device: str = 'cuda', lambda1: float = 1e-2, lambda2: float = 1e-4,
                 vgg_layers: list[int] = [2, 7, 16, 25, 34],
                 weights: list[float] = [1/2, 1/4, 1/8, 1/64, 1/128]) -> None:

        super().__init__()

        # Frozen feature extractor: eval mode and no gradients for its weights
        # (gradients still flow through it to the generator output).
        self.vgg = models.vgg19(pretrained=True).features.to(device).eval()
        self.vgg.requires_grad_(False)

        # Copy the sequences so the shared default lists are never aliased.
        self.layers = list(vgg_layers)
        self.weights = list(weights)
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.lambda1 = lambda1
        self.lambda2 = lambda2

    def forward(self, discriminator, LR, HR, SR):
        """Return the total generator loss as a scalar tensor."""
        content = self.content_loss(HR, SR)
        adversarial = self.adversarial_loss(discriminator, LR, HR, SR)
        feature = self.adversarial_feature_loss(discriminator, LR, HR, SR)
        return content + self.lambda1 * adversarial + self.lambda2 * feature

    def vgg_extract(self, x):
        """
        Run ``x`` through the VGG feature stack and return the outputs of the
        layers listed in ``self.layers`` (in that order).
        """
        if not self.layers:
            return []

        # VGG expects 3 input channels; replicate a single-channel image.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)

        wanted = set(self.layers)
        last = max(wanted)
        captured = {}
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in wanted:
                # Clone: torchvision's VGG uses in-place ReLUs, which would
                # otherwise overwrite the captured convolution output.
                captured[idx] = x.clone()
            if idx >= last:
                break  # deeper layers are not needed

        return [captured[idx] for idx in self.layers]

    def content_loss(self, HR, SR, lambda_l1=0):
        """
        Weighted sum of MSE distances between the VGG features of ``SR`` and
        ``HR``, plus ``lambda_l1`` times the pixel-wise L1 distance.
        """
        HR_features = self.vgg_extract(HR)
        SR_features = self.vgg_extract(SR)

        loss = 0.0
        for i, (sr_feat, hr_feat) in enumerate(zip(SR_features, HR_features)):
            loss = loss + self.weights[i] * self.mse_loss(sr_feat, hr_feat)

        return lambda_l1 * self.l1_loss(HR, SR) + loss

    @staticmethod
    def adversarial_loss(discriminator, lr, hr, sr, eps: float = 1e-8):
        """
        Compute the adversarial loss for the generator:
        ``mean(-log(1 - D(hr, lr)) - log(D(sr, lr)))``.

        ``eps`` is a lower clamp applied before each logarithm so that a
        saturated discriminator output cannot produce ``inf``/``nan``.
        """
        d_real = discriminator(hr, lr)[-1]
        d_fake = discriminator(sr, lr)[-1]
        adv = -torch.log((1 - d_real).clamp_min(eps)) - torch.log(d_fake.clamp_min(eps))
        return adv.mean()

    @staticmethod
    def adversarial_feature_loss(discriminator, lr, hr, sr):
        """
        Compute the adversarial feature loss for the generator: a weighted sum
        of MSE distances between the discriminator's first five outputs for
        the real pair and for the generated pair.
        """
        weights = [1/2, 1/4, 1/8, 1/64, 1/128]

        d_real = discriminator(hr, lr)
        d_fake = discriminator(sr, lr)

        advfeat = 0
        # zip stops after the five weights, so the final probability output
        # is not included.
        for weight, real_feat, fake_feat in zip(weights, d_real, d_fake):
            advfeat = advfeat + weight * nn.functional.mse_loss(real_feat, fake_feat)

        return advfeat

In [None]:
class DiscriminatorLossFunction(nn.Module):
    """
    Discriminator objective: ``mean(-log(D(hr, lr)) - log(1 - D(sr, lr)))``.

    Args:
        device: Accepted for interface compatibility; not used by this class.
        discriminator: Callable invoked as ``discriminator(high_res, low_res)``
            -- the argument order of ``Discriminator.forward(x, y)`` -- that
            returns a sequence whose last element is the probability.
            Required, although it follows a defaulted parameter.

    Raises:
        ValueError: If no discriminator is supplied.
    """

    def __init__(self, device: str = 'cuda', discriminator: nn.Module = None) -> None:
        super().__init__()
        if discriminator is None:
            raise ValueError("DiscriminatorLossFunction requires a discriminator")
        self.d = discriminator

    def forward(self, lr, hr, sr, eps: float = 1e-8):
        """
        Return the scalar discriminator loss.

        ``sr`` is detached so this loss never back-propagates into the
        generator. ``eps`` is a lower clamp applied before each logarithm to
        avoid ``log(0)``.
        """
        # The discriminator returns (features..., probability); only the
        # probability enters this loss.
        d_real = self.d(hr, lr)[-1]
        d_fake = self.d(sr.detach(), lr)[-1]
        loss = -torch.log(d_real.clamp_min(eps)) - torch.log((1 - d_fake).clamp_min(eps))
        return loss.mean()

In [1]:
class Trainer(nn.Module):
    
    def __init__(self, generator: nn.Module = None, discriminator: nn.Module = None,
                 g_loss = nn.Module = None, d_loss: nn.Module = None,
                 batch_size: int = 4, dataloader: torch.Dataloader = None,
                 mean: tuple(float, float, float) = (0.5, 0.5, 0.5),
                 std: tuple(float, float, float) = (0.5, 0.5, 0.5),
                 device: str = 'cuda') -> None:
        super().__init__()
        
        self.generator = generator
        self.discriminator = discriminator
        self.dataloader = dataloader
        self.mean = mean
        self.fixed_latent = torch.randn(batch_size, 512, 14, 14)
        self.g_loss = g_loss
        self.d_loss = d_loss
        self.device = device
        
    def denorm(self, img_tensor: torch.Tensor) -> torch.Tensor:
        return img_tensor * self.mean[0] + self.std[0]
    
     def save_samples(self, index: int = 0) -> None:
            
        latent_tensors = torch.randn(64, self.latent_size, 1, 1, device=self.device)
        fake_images = self.generator(latent_tensors)
        fake_fname = f'generated-images-{index:04d}.png'
        save_image(fake_images, os.path.join(fake_fname), nrow=8)
        print('Saving', fake_fname)
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(self.denorm(fake_images.cpu().detach()), nrow=8).permute(1, 2, 0))
        plt.show()
        
    def train_discriminator(self, real_images, opt_d):
        opt_d.zero_grad()

        real_preds = self.discriminator(real_images)[-1]
        real_targets = torch.ones(real_images.size(0), 1, device=self.device)
        real_loss = F.binary_cross_entropy(real_preds, real_targets)
        
        real_score = torch.mean(real_preds).item()

        latent = torch.randn(real_images.size(0), self.latent_size, 1, 1, device=self.device)
        fake_images = self.generator(latent)

        fake_targets = torch.zeros(real_images.size(0), 1, device=self.device)
        fake_preds = self.discriminator(fake_images.detach())
        fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
        fake_score = torch.mean(fake_preds).item()

        loss = real_loss + fake_loss
        loss.backward()
        opt_d.step()

        return loss.item(), real_score, fake_score

    def train_generator(self, opt_g):
        opt_g.zero_grad()

        latent = torch.randn(self.train_dl.batch_size, self.latent_size, 1, 1, device=self.device)
        fake_images = self.generator(latent)

        preds = self.discriminator(fake_images)
        targets = torch.ones(self.train_dl.batch_size, 1, device=self.device)
        loss = F.binary_cross_entropy(preds, targets)

        loss.backward()
        opt_g.step()

        return loss.item()

SyntaxError: non-default argument follows default argument (3728780873.py, line 4)

In [None]:
# Use the GPU when one is available so the cell also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator = Generator().to(device)
discriminator = Discriminator().to(device)

# The loss classes defined above are GeneratorLossFunction and
# DiscriminatorLossFunction; the latter needs the discriminator it scores with.
gen_loss = GeneratorLossFunction(device=device)
disc_loss = DiscriminatorLossFunction(device=device, discriminator=discriminator)