<a href="https://colab.research.google.com/github/tnshq/esrgan/blob/main/ESRGAN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
from torchvision.utils import save_image
from google.colab import files

# Define dataset class for loading and preprocessing images
class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

# Define transforms (resize images for simplicity)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Create Dataset and DataLoader
dataset = ImageDataset(root_dir="BSDS300/images/train", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

class RRDB(nn.Module):
    def __init__(self, in_channels):
        super(RRDB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return x + out  # Residual connection

class Generator(nn.Module):
    def __init__(self, in_channels=3, num_rrdb=23):
        super(Generator, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.rrdb_blocks = nn.Sequential(*[RRDB(64) for _ in range(num_rrdb)])
        self.final_conv = nn.Conv2d(64, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        initial_feature = self.initial_conv(x)
        out = self.rrdb_blocks(initial_feature)
        out = self.final_conv(out)
        return out

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 3, stride=1, padding=1)
        )

    def forward(self, img):
        return self.model(img)

class ContentLoss(nn.Module):
    def __init__(self):
        super(ContentLoss, self).__init__()

    def forward(self, sr, hr):
        return F.mse_loss(sr, hr)

class PerceptualLoss(nn.Module):
    def __init__(self, vgg_model):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_model.features[:36]  # Use pre-trained VGG features
        self.vgg.eval()

    def forward(self, sr, hr):
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        return F.mse_loss(sr_features, hr_features)

def save_checkpoint(epoch, generator, optimizer_G, discriminator, optimizer_D, path='./esrgan_checkpoint.pth'):
    """Saves the model and optimizer states to a checkpoint file."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
    }, path)

def load_checkpoint(path='./esrgan_checkpoint.pth'):
    """Loads the model and optimizer states from a checkpoint file."""
    checkpoint = torch.load(path)
    return checkpoint['epoch'], checkpoint['generator_state_dict'], checkpoint['optimizer_G_state_dict'], checkpoint['discriminator_state_dict'], checkpoint['optimizer_D_state_dict']

def train(generator, discriminator, dataloader, num_epochs, optimizer_G, optimizer_D, criterion_content, criterion_perceptual, device, start_epoch=0):
    """Trains the ESRGAN model."""
    generator.to(device)
    discriminator.to(device)

    for epoch in range(start_epoch, num_epochs):
        for i, img in enumerate(dataloader):
            img = img.to(device)

            # Generate super-resolved image
            sr_image = generator(img)

            # Train Generator
            optimizer_G.zero_grad()
            content_loss = criterion_content(sr_image, img)
            perceptual_loss = criterion_perceptual(sr_image, img)
            g_loss = content_loss + perceptual_loss
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            real_output = discriminator(img)
            fake_output = discriminator(sr_image.detach())
            d_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output)) + \
                     F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            d_loss.backward()
            optimizer_D.step()

            if i % 10 == 0:
                print(f"Epoch {epoch}/{num_epochs}, Step {i}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")

        # Save checkpoint after each epoch
        save_checkpoint(epoch, generator, optimizer_G, discriminator, optimizer_D)

def enhance_image(image_path, generator, device):
    """Enhances the resolution of an image using the ESRGAN model."""
    # Load the image
    image = Image.open(image_path).convert("RGB")

    # Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Resize to match training data
        transforms.ToTensor(),
    ])
    image = transform(image).unsqueeze(0).to(device)

    # Enhance the image using the generator
    generator.eval()
    with torch.no_grad():
        enhanced_image = generator(image)

    # Postprocess the enhanced image
    enhanced_image = enhanced_image.squeeze().cpu().numpy()
    enhanced_image = np.transpose(enhanced_image, (1, 2, 0))
    enhanced_image = (enhanced_image * 255).astype(np.uint8)  # Convert to uint8

    return enhanced_image

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models and optimizers
generator = Generator()
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Load pre-trained VGG model for Perceptual Loss
vgg = models.vgg19(pretrained=True).to(device)
criterion_content = ContentLoss()
criterion_perceptual = PerceptualLoss(vgg)

# Load checkpoint if available
start_epoch = 0
if os.path.exists('./esrgan_checkpoint.pth'):
    start_epoch, generator_state_dict, optimizer_G_state_dict, discriminator_state_dict, optimizer_D_state_dict = load_checkpoint()
    generator.load_state_dict(generator_state_dict)
    optimizer_G.load_state_dict(optimizer_G_state_dict)
    discriminator.load_state_dict(discriminator_state_dict)
    optimizer_D.load_state_dict(optimizer_D_state_dict)
    print(f"Resuming training from epoch {start_epoch}")

# Train ESRGAN
train(generator, discriminator, dataloader, num_epochs=2, optimizer_G=optimizer_G, optimizer_D=optimizer_D,
      criterion_content=criterion_content, criterion_perceptual=criterion_perceptual, device=device, start_epoch=start_epoch)

# Get input image from the user
uploaded = files.upload()
image_path = list(uploaded.keys())[0]

# Enhance the image
enhanced_image = enhance_image(image_path, generator, device)

# Display the enhanced image
plt.imshow(enhanced_image)
plt.title("Enhanced Image")
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'BSDS300/images/train'