In [None]:
!pip install torch torchvision
!pip install numpy opencv-python tqdm
!pip install tensorboardX

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms, ToTensor, ToPILImage
from torchvision.utils import save_image
from PIL import Image
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt

In [None]:
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN block with an identity skip connection.

    Both convolutions are 3x3, stride 1, padding 1, so the output has the
    same shape as the input and can be added to it.

    Args:
        channels: Number of input/output feature channels. Defaults to 64,
            which reproduces the original fixed-width block.
    """

    def __init__(self, channels=64):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x + F(x)`` where ``F`` is the conv/BN/PReLU/conv/BN branch."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Out-of-place add: same values as `out += x` without mutating a
        # tensor that autograd has already recorded.
        return out + x

class Generator(nn.Module):
    """Residual generator network.

    Pipeline: 9x9 conv + PReLU -> stack of ``ResidualBlock`` -> 3x3 conv + BN
    -> long skip connection back to the first activation -> 9x9 output conv.
    Every convolution uses stride 1 with matching padding and there is no
    upsampling layer, so the output has the same height/width as the input.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
        n_residual_blocks: How many ``ResidualBlock`` modules to chain.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super().__init__()

        # Head: wide-kernel feature extraction.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4)
        self.prelu = nn.PReLU()

        # Trunk: chained residual blocks.
        self.res_blocks = nn.Sequential(
            *[ResidualBlock() for _ in range(n_residual_blocks)]
        )

        # Post-trunk conv + batch norm, applied before the long skip add.
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Tail: project features back to image channels (no output activation).
        self.conv3 = nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map an image batch ``x`` to an output batch of the same spatial size."""
        head = self.prelu(self.conv1(x))
        trunk = self.res_blocks(head)
        trunk = self.bn2(self.conv2(trunk))
        fused = head + trunk
        return self.conv3(fused)

class Discriminator(nn.Module):
    """Fully convolutional discriminator.

    Eight 3x3 conv stages (BatchNorm on all but the first, LeakyReLU(0.2)
    after each) followed by a 3x3 conv down to a single channel. Four of the
    stages use stride 2. No activation follows the final convolution, so the
    module returns a raw single-channel score map.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # One row per stage: (output filters, stride, use batch norm).
        stage_specs = (
            (64, 1, False),
            (64, 2, True),
            (128, 1, True),
            (128, 2, True),
            (256, 1, True),
            (256, 2, True),
            (512, 1, True),
            (512, 2, True),
        )

        modules = []
        prev_filters = in_channels
        for filters, stride, use_bn in stage_specs:
            modules.append(
                nn.Conv2d(prev_filters, filters, kernel_size=3, stride=stride, padding=1)
            )
            if use_bn:
                modules.append(nn.BatchNorm2d(filters))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
            prev_filters = filters

        # Final projection to a one-channel score map.
        modules.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*modules)

    def forward(self, img):
        """Return the score map produced for the image batch ``img``."""
        return self.model(img)

In [None]:
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset backed by two directories.

    Files are paired by position after sorting each directory listing, so
    the two directories must contain the same number of entries and matching
    files must sort to the same index.

    Args:
        lr_dir: Directory holding the low-resolution images.
        hr_dir: Directory holding the high-resolution images.
        transform: Optional callable applied to each image. It is called
            separately on the LR and the HR image, so a transform with
            internal randomness would not be applied consistently to a pair.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_images = sorted(os.listdir(lr_dir))
        self.hr_images = sorted(os.listdir(hr_dir))
        # Pairing is purely positional: fail fast instead of silently
        # misaligning pairs or hitting an IndexError in the middle of an epoch.
        if len(self.lr_images) != len(self.hr_images):
            raise ValueError(
                f"Mismatched dataset: {len(self.lr_images)} files in {lr_dir!r} "
                f"but {len(self.hr_images)} files in {hr_dir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Number of (LR, HR) pairs."""
        return len(self.lr_images)

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` pair at ``index`` as RGB images.

        Both images are passed through ``self.transform`` when it is set.
        """
        lr_path = os.path.join(self.lr_dir, self.lr_images[index])
        hr_path = os.path.join(self.hr_dir, self.hr_images[index])

        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(lr_path) as lr_file:
            lr_image = lr_file.convert('RGB')
        with Image.open(hr_path) as hr_file:
            hr_image = hr_file.convert('RGB')

        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image


In [None]:
# Load and preprocess the dataset
lr_dir = '/content/drive/MyDrive/AI/Term 3/dataset/train/low_res'
hr_dir = '/content/drive/MyDrive/AI/Term 3/dataset/train/high_res'
batch_size = 16
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = ImageDataset(lr_dir, hr_dir, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Load pre-trained weights for transfer learning
pretrained_weight_path = '/content/drive/MyDrive/AI/Term 3/RealESRGAN_weights/RealESRGAN_x2.pth'
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
pretrained_state = torch.load(pretrained_weight_path, map_location=device)
# strict=False ignores every key that does not line up, so inspect the report
# to make sure the transfer actually copied something.
load_report = generator.load_state_dict(pretrained_state, strict=False)
del pretrained_state  # do not keep a second copy of the checkpoint on the device
num_generator_tensors = len(generator.state_dict())
num_loaded_tensors = num_generator_tensors - len(load_report.missing_keys)
print(
    f"Pre-trained weights: loaded {num_loaded_tensors}/{num_generator_tensors} generator tensors "
    f"({len(load_report.unexpected_keys)} checkpoint entries unused)"
)
if num_loaded_tensors == 0:
    print(
        "WARNING: no checkpoint key matched the Generator; "
        "training starts from random initialization."
    )

# Define loss functions
criterion_GAN = nn.BCEWithLogitsLoss().to(device)
criterion_content = nn.L1Loss().to(device)

# Define optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=1e-4)
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4)

# Training loop
num_epochs = 100
log_dir = 'logs'
writer = SummaryWriter(log_dir)

In [None]:
# Transfer Learning: checkpoint file path for each of the keys 'x2', 'x4', 'x8'.
pretrained_weights_path = {
    f'x{scale}': f'/content/drive/MyDrive/AI/Term 3/RealESRGAN_weights/RealESRGAN_x{scale}.pth'
    for scale in (2, 4, 8)
}

# Load pretrained weights into the generator
def load_pretrained_weights(generator, pretrained_weights_path):
    """Copy every compatible tensor from a checkpoint file into ``generator``.

    A checkpoint entry is used only when its key exists in the generator's
    state dict AND its shape matches; everything else is skipped, so a
    checkpoint from a different architecture never raises a size-mismatch
    error. The generator is updated in place.

    Args:
        generator: Module whose parameters/buffers should be overwritten.
        pretrained_weights_path: Path to a file readable by ``torch.load``
            that contains a mapping of names to tensors.

    Returns:
        List of the state-dict keys that were actually loaded (empty when
        nothing matched, in which case a ``UserWarning`` is emitted).
    """
    import warnings

    # Load on CPU; load_state_dict copies values into the generator's own
    # tensors, so this no longer depends on a global `device`.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    pretrained_dict = torch.load(pretrained_weights_path, map_location='cpu')
    model_dict = generator.state_dict()

    compatible = {
        key: value
        for key, value in pretrained_dict.items()
        if key in model_dict and getattr(value, 'shape', None) == model_dict[key].shape
    }

    if not compatible:
        # Make a silent no-op visible: the generator keeps its current weights.
        warnings.warn(
            f"No pretrained tensors from {pretrained_weights_path!r} match the "
            "generator's state dict; weights were left unchanged.",
            UserWarning,
        )
        return []

    # Every key in `compatible` exists in the model; strict=False only
    # tolerates the model keys the checkpoint does not cover.
    generator.load_state_dict(compatible, strict=False)
    return list(compatible)

# Example: Load x4 pretrained weights
load_pretrained_weights(generator, pretrained_weights_path['x4'])

# Directory that receives the per-epoch and the final checkpoints.
checkpoint_dir = '/content/drive/MyDrive/AI/Term 3/dataset'

try:
    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        for i, (lr, hr) in enumerate(tqdm(train_dataloader)):
            lr = lr.to(device)
            hr = hr.to(device)

            # Train Discriminator
            optimizer_D.zero_grad()
            fake_hr = generator(lr)
            real_out = discriminator(hr)
            # detach(): the discriminator update must not backprop into the generator.
            fake_out = discriminator(fake_hr.detach())
            # Relativistic average loss: each score is compared against the
            # mean score of the opposite class.
            real_loss = criterion_GAN(real_out - torch.mean(fake_out), torch.ones_like(real_out))
            fake_loss = criterion_GAN(fake_out - torch.mean(real_out), torch.zeros_like(fake_out))
            d_loss = (real_loss + fake_loss) / 2
            # No retain_graph needed: the generator step below runs a fresh
            # discriminator forward pass and only reads real_out detached.
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_out = discriminator(fake_hr)
            g_loss_GAN = criterion_GAN(fake_out - torch.mean(real_out.detach()), torch.ones_like(fake_out))
            g_loss_content = criterion_content(fake_hr, hr)
            # NOTE(review): the adversarial term has weight 1.0 and the L1
            # content term 1e-2 -- confirm this balance is intended.
            g_loss = g_loss_GAN + 1e-2 * g_loss_content
            g_loss.backward()
            optimizer_G.step()

            # Logging
            global_step = epoch * len(train_dataloader) + i
            writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
            writer.add_scalar('Loss/Generator', g_loss.item(), global_step)

        # Losses reported here are those of the last batch of the epoch.
        print(f"Epoch [{epoch + 1}/{num_epochs}] Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")

        # Save the current epoch's models first, so a failed write can never
        # leave the run without any checkpoint on disk.
        torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f'generator_epoch_{epoch+1}.pth'))
        torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch+1}.pth'))

        # Only then delete the previous epoch's models (if they still exist).
        if epoch > 0:
            for model_name in ('generator', 'discriminator'):
                stale_path = os.path.join(checkpoint_dir, f'{model_name}_epoch_{epoch}.pth')
                if os.path.exists(stale_path):
                    os.remove(stale_path)
finally:
    # Flush and close the event file even when training is interrupted.
    writer.close()

# Save the final models
torch.save(generator.state_dict(), os.path.join(checkpoint_dir, 'generator_final.pth'))
torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, 'discriminator_final.pth'))

100%|██████████| 43/43 [19:01<00:00, 26.54s/it]


Epoch [1/100] Discriminator Loss: 0.0219, Generator Loss: 4.4322


100%|██████████| 43/43 [15:51<00:00, 22.12s/it]


Epoch [2/100] Discriminator Loss: 0.0640, Generator Loss: 4.3022


100%|██████████| 43/43 [15:54<00:00, 22.21s/it]


Epoch [3/100] Discriminator Loss: 0.0049, Generator Loss: 5.9592


100%|██████████| 43/43 [16:01<00:00, 22.35s/it]


Epoch [4/100] Discriminator Loss: 0.0017, Generator Loss: 6.8693


 72%|███████▏  | 31/43 [11:42<04:31, 22.60s/it]

In [None]:
import torch
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
import matplotlib.pyplot as plt

# Requires the Generator class from the model-definition cell above.

# Checkpoint written by the training loop
generator_model_path = '/content/drive/MyDrive/AI/Term 3/dataset/generator_epoch_4.pth'

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the network, restore its weights and switch to inference mode
generator = Generator()
generator.to(device)
generator.load_state_dict(torch.load(generator_model_path, map_location=device))
generator.eval()

# Function to load and preprocess a low-resolution image
def load_image(image_path, transform=None):
    """Open ``image_path`` as an RGB image and optionally transform it.

    Args:
        image_path: Path of the image file to read.
        transform: Optional callable applied to the RGB image.

    Returns:
        ``transform(image)`` when a truthy ``transform`` is given, otherwise
        the RGB ``PIL.Image`` itself.
    """
    rgb = Image.open(image_path).convert('RGB')
    return transform(rgb) if transform else rgb

# Path to the low-resolution test image
test_image_path = '/content/drive/MyDrive/AI/Term 3/dataset/val/low_res/0.png'

# Load and preprocess the test image (adds a batch dimension of 1)
transform = ToTensor()
lr_image = load_image(test_image_path, transform).unsqueeze(0).to(device)

# Generate the high-resolution image
with torch.no_grad():
    sr_image = generator(lr_image).squeeze(0).cpu()

# The generator has no output activation, so values can fall outside [0, 1].
# ToPILImage scales floats by 255 and casts to uint8, which would wrap such
# values around; clamp first to get saturated pixels instead of artifacts.
sr_image = sr_image.clamp(0.0, 1.0)

# Convert the tensor to PIL image
to_pil_image = ToPILImage()
sr_image = to_pil_image(sr_image)

# Save the generated high-resolution image
output_image_path = '/content/drive/MyDrive/AI/Term 3/dataset/Test Saved images/generated_sample1.jpg'
sr_image.save(output_image_path)

# Display the low-resolution and high-resolution images
lr_image = lr_image.squeeze(0).cpu()
lr_image = to_pil_image(lr_image)

fig, (ax_lr, ax_sr) = plt.subplots(1, 2, figsize=(12, 6))

ax_lr.set_title('Low-Resolution Image')
ax_lr.imshow(lr_image)
ax_lr.axis('off')

ax_sr.set_title('Generated High-Resolution Image')
ax_sr.imshow(sr_image)
ax_sr.axis('off')

plt.show()
