<a href="https://colab.research.google.com/github/tanyavijj/Tanya-project/blob/main/SRGAN_ERROR_Notebook_Code_(2).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [2]:
# STEP 0: INSTALL DEPENDENCIES
%pip install -q torchvision scikit-image matplotlib opencv-python



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
# STEP 1: DOWNLOAD + TRIM DATASET
import os
import glob
import shutil
import urllib.request
from zipfile import ZipFile

os.makedirs("data", exist_ok=True)
os.makedirs("models", exist_ok=True)

url = "https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
zip_path = "data/DIV2K_train_HR.zip"

if not os.path.exists(zip_path):
    urllib.request.urlretrieve(url, zip_path)

with ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall("data/")

# Trim to first 100 images
all_images = sorted(glob.glob("data/DIV2K_train_HR/*.png"))
for img in all_images[100:]:
    os.remove(img)



In [4]:
# STEP 2: CONFIG
batch_size = 2   # Reduce to fit GPU memory
crop_size = 64   # Smaller crop to reduce resolution and memory
upscale_factor = 4
num_epochs_pretrain = 2
num_epochs_gan = 5
lr = 1e-4
beta1 = 0.9
beta2 = 0.999

train_hr_path = 'data/DIV2K_train_HR'
model_save_path = 'models'



In [5]:
# STEP 3: DATASET LOADER
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class SRDataset(Dataset):
    def __init__(self, image_dir, crop_size=64, upscale_factor=4):
        self.image_filenames = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.hr_crop_size = crop_size * upscale_factor
        self.lr_size = crop_size
        self.upscale_factor = upscale_factor

        self.hr_transform = transforms.Compose([
            transforms.RandomCrop(self.hr_crop_size),
            transforms.ToTensor()
        ])
        self.lr_downscale = transforms.Resize(self.lr_size, interpolation=Image.BICUBIC)

    def __getitem__(self, index):
        hr = Image.open(self.image_filenames[index]).convert("RGB")
        hr = self.hr_transform(hr)
        lr = self.lr_downscale(transforms.ToPILImage()(hr))
        lr = transforms.ToTensor()(lr)
        return lr, hr

    def __len__(self):
        return len(self.image_filenames)


train_dataset = SRDataset(train_hr_path, crop_size, upscale_factor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)




In [6]:
# STEP 4: GENERATOR
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, num_res_blocks=16):
        super().__init__()
        self.input_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_res_blocks)])
        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.output_conv = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        x1 = self.input_conv(x)
        x2 = self.res_blocks(x1)
        x3 = self.mid_conv(x2)
        x4 = self.upsample(x1 + x3)
        return self.output_conv(x4)



In [7]:
# STEP 5: DISCRIMINATOR + VGG
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        def block(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )
        self.net = nn.Sequential(
            block(3, 64, 1),
            block(64, 64, 2),
            block(64, 128, 1),
            block(128, 128, 2),
            block(128, 256, 1),
            block(256, 256, 2),
            block(256, 512, 1),
            block(512, 512, 2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        return self.net(x)

class VGGContentLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True).features
        self.feature_extractor = nn.Sequential(*list(vgg[:36])).eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, sr, hr):
        return nn.functional.mse_loss(self.feature_extractor(sr), self.feature_extractor(hr))



In [8]:
# STEP 6: PRETRAIN GENERATOR
generator = Generator().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

for epoch in range(num_epochs_pretrain):
    generator.train()
    for lr_img, hr_img in train_loader:
        lr_img, hr_img = lr_img.cuda(), hr_img.cuda()
        sr_img = generator(lr_img)
        loss = criterion(sr_img, hr_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Pretrain Epoch [{epoch+1}/{num_epochs_pretrain}] Loss: {loss.item():.4f}")

torch.save(generator.state_dict(), f"{model_save_path}/srresnet_pretrained.pth")



Pretrain Epoch [1/2] Loss: 0.0232
Pretrain Epoch [2/2] Loss: 0.0485


In [9]:
import torchvision


In [10]:
# STEP 7: LOAD PRETRAINED + INIT GAN
generator = Generator().cuda()
generator.load_state_dict(torch.load(f"{model_save_path}/srresnet_pretrained.pth"))
discriminator = Discriminator().cuda()
vgg_loss = VGGContentLoss().cpu()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
bce_loss = nn.BCELoss()
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)



  generator.load_state_dict(torch.load(f"{model_save_path}/srresnet_pretrained.pth"))
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:04<00:00, 142MB/s]


In [16]:
#step 8
#step 8
import os
import torch
import torch.nn as nn
from torchvision import transforms

# Enable debugging for CUDA kernel errors
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# STEP 8: SRGAN TRAINING
for epoch in range(num_epochs_gan):
    generator.train()
    discriminator.train()
    for lr_img, hr_img in train_loader:
        # Check and print shapes of the images
        print(f"lr_img shape: {lr_img.shape}, hr_img shape: {hr_img.shape}")

        # Ensure tensors are float type
        lr_img, hr_img = lr_img.float(), hr_img.float()

        # Move to GPU before resizing
        lr_img, hr_img = lr_img.cuda(), hr_img.cuda() # Move tensors to GPU first

        # Resize lr_img to match the size of hr_img (256x256) using interpolate
        import torch.nn.functional as F
        lr_img = F.interpolate(lr_img, size=(hr_img.shape[2], hr_img.shape[3]), mode='bicubic', align_corners=False) # Perform resizing on the GPU

        # Ensure labels are the correct size and move them to GPU
        # real_labels = torch.ones(hr_img.size(0), 1).cuda() # This could lead to shape issues
        # fake_labels = torch.zeros(hr_img.size(0), 1).cuda() # This could lead to shape issues

        # Discriminator
        sr_img = generator(lr_img)
        real_out = discriminator(hr_img)
        fake_out = discriminator(sr_img.detach())

        # Match label shapes and types to the outputs
        real_labels = torch.ones_like(real_out) # corrected here
        fake_labels = torch.zeros_like(fake_out) # corrected here

        d_loss = bce_loss(real_out, real_labels) + bce_loss(fake_out, fake_labels)
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Generator
        fake_out = discriminator(sr_img)
        adv_loss = bce_loss(fake_out, real_labels)
        # sr_crop = transforms.CenterCrop(hr_img.shape[2:])(sr_img)
        import torch.nn.functional as F
        sr_crop = F.center_crop(sr_img, hr_img.shape[2:]) # Ensure center crop is done on the GPU
        content_loss = vgg_loss(sr_crop.cpu(), hr_img.cpu())
        g_loss = content_loss + 1e-3 * adv_loss.cpu()
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    print(f"Epoch [{epoch+1}/{num_epochs_gan}] G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")
    torch.save(generator.state_dict(), f"{model_save_path}/generator_epoch_{epoch+1}.pth")
    torch.save(discriminator.state_dict(), f"{model_save_path}/discriminator_epoch_{epoch+1}.pth")

torch.save(generator.state_dict(), f"{model_save_path}/srgan_final.pth")



lr_img shape: torch.Size([2, 3, 64, 64]), hr_img shape: torch.Size([2, 3, 256, 256])


RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# STEP 9: VISUALIZE
import torchvision.utils as vutils
generator.eval()
with torch.no_grad():
    for i, (lr_img, hr_img) in enumerate(train_loader):
        lr_img = lr_img.cuda()
        sr_img = generator(lr_img).cpu()
        vutils.save_image(sr_img, f"srgan_output_{i}.png", normalize=True)
        vutils.save_image(hr_img, f"hr_output_{i}.png", normalize=True)
        vutils.save_image(lr_img, f"lr_input_{i}.png", normalize=True)
        break



In [None]:
# STEP 10: METRICS
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np

def evaluate(sr_img, hr_img):
    sr = sr_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    hr = hr_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return psnr(hr, sr), ssim(hr, sr, channel_axis=2)

with torch.no_grad():
    for lr_img, hr_img in train_loader:
        lr_img = lr_img.cuda()
        sr_img = generator(lr_img).cpu()
        p, s = evaluate(sr_img, hr_img)
        print(f"PSNR: {p:.2f}, SSIM: {s:.4f}")
        break
