In [1]:
import torch
import sys
import types
from torchvision.transforms.functional import to_tensor, to_pil_image, rgb_to_grayscale
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets as dset
import os
import kagglehub
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.spectral_norm as spectral_norm
from torchvision.models import vgg19, VGG19_Weights
import time
from torch.utils.data import DataLoader, Dataset
import glob
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

# Compatibility shim: build a stand-in module object named
# `torchvision.transforms.functional_tensor` that exposes only `rgb_to_grayscale`.
# NOTE(review): presumably needed because basicsr imports this module path while the
# installed torchvision no longer provides it — confirm against the installed versions.
functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
functional_tensor.rgb_to_grayscale = rgb_to_grayscale

# Register the stand-in in sys.modules so later imports of that module path resolve to it.
# The basicsr import below is placed after the registration so it can find the shim.
sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
from basicsr.archs.rrdbnet_arch import RRDBNet


In [2]:
# --- Training hyper-parameters -------------------------------------------------
num_epochs = 50          # number of passes of the training loop over the dataloader
batch_size = 16          # batch size handed to the DataLoader
upscale_factor = 2  # Scale factor in generator
lr_img_size = 64         # side length (pixels) of the low-resolution inputs
hr_img_size = lr_img_size * upscale_factor  # side length (pixels) of the high-resolution targets
dataset_lenght = 10000   # how many source images are sampled when the dataset is built
lr = 0.0001              # learning rate for both Adam optimizers (not "low resolution")
gen_num_block = 6        # number of RRDB blocks in the generator body

# === Step 1: Set Paths ===
# Everything lives under the user's home directory; the HR folder is named
# "hr+_images" when upscale_factor is 4 and "hr_images" otherwise.
base_path = os.path.expanduser("~/real-esrgan-imitation-training-data") 
lr_path = os.path.join(base_path, "lr_images")
hr_path = os.path.join(base_path, f"hr{'+' if upscale_factor==4 else ''}_images")

# Training Dataset Preparation

In [3]:
# Build the paired LR/HR training set once; later runs reuse the files on disk.
# The set counts as missing when either folder is absent or contains no entries.
dataset_missing = any(
    not os.path.exists(folder) or len(os.listdir(folder)) == 0
    for folder in (lr_path, hr_path)
)

if dataset_missing:
    print("🚀 Creating dataset, as it's missing or incomplete...")

    # Create directories if they don't exist
    os.makedirs(lr_path, exist_ok=True)
    os.makedirs(hr_path, exist_ok=True)

    # === Step 2: Get Device ===
    def get_device():
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    device = get_device()

    # === Step 3: Load Pretrained Real-ESRGAN Model ===
    # The teacher network is built with scale=4: its output is 4x the input side length.
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4).to(device)

    checkpoint = torch.load("../Real-ESRGAN/weights/RealESRGAN_x4plus_anime_6B.pth", map_location=device)
    # Use the weights stored under 'params_ema' when that key is present,
    # otherwise treat the checkpoint itself as the state dict.
    state_dict = checkpoint["params_ema"] if 'params_ema' in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # === Step 4: Load Kaggle Dataset ===
    dataset_name = "soumikrakshit/anime-faces"
    path = kagglehub.dataset_download(dataset_name)
    # Source images are resized/cropped to the configured LR size instead of a
    # hard-coded 64 so this cell stays in sync with the configuration cell.
    dataset = dset.ImageFolder(root=path,
                               transform=transforms.Compose([
                                   transforms.Resize(lr_img_size),
                                   transforms.CenterCrop(lr_img_size),
                               ]))

    # === Step 5: Select `dataset_lenght` Random Images (without replacement) ===
    image_indices = np.random.choice(len(dataset), size=dataset_lenght, replace=False)

    # === Step 6: Define Resizing Transformations ===
    # The 4x teacher output is resized to the configured HR size
    # (hr_img_size = lr_img_size * upscale_factor). Previously this was fixed at
    # 128x128, which would have been wrong for the "hr+_images" (upscale_factor == 4) folder.
    hr_resize = transforms.Resize((hr_img_size, hr_img_size))

    # === Step 7: Process and Save Images ===
    # Files are numbered from 1 so LR/HR pairs share the same file name.
    for i, idx in enumerate(tqdm(image_indices, desc="Processing Images", unit="image"), start=1):
        # Load the (already resized and cropped) low-resolution PIL image
        lr_image = dataset[int(idx)][0]

        # Convert to a batched tensor & send to device
        img_tensor = to_tensor(lr_image).unsqueeze(0).to(device)

        # Generate the teacher's super-resolved image (4x the LR side length)
        with torch.no_grad():
            output_tensor = model(img_tensor).clamp(0, 1)

        # Convert to PIL image and bring it to the configured HR size
        hr_image = hr_resize(to_pil_image(output_tensor.squeeze(0)))

        # Save the pair under lr_path / hr_path with matching file names
        lr_image.save(os.path.join(lr_path, f"{i}.png"))
        hr_image.save(os.path.join(hr_path, f"{i}.png"))

    print(f"✅ Low-resolution images saved in: {lr_path}")
    print(f"✅ High-resolution images saved in: {hr_path}")


else:
    print("✅ Dataset already exists. Skipping dataset creation.")


🚀 Creating dataset, as it's missing or incomplete...


Processing Images: 100%|██████████| 10000/10000 [03:38<00:00, 45.82image/s]

✅ Low-resolution images saved in: /Users/yusuf/real-esrgan-imitation-training-data/lr_images
✅ High-resolution images saved in: /Users/yusuf/real-esrgan-imitation-training-data/hr_images





# Dataset & Dataloader

In [4]:
# Custom Dataset for Anime Faces
class AnimeFaceDataset(Dataset):
    """Paired low-/high-resolution image dataset backed by two folders of PNG files.

    The `*.png` files of both folders are sorted by path, and the files at the
    same position form one (LR, HR) training pair.

    Args:
        lr_folder: directory holding the low-resolution PNGs.
        hr_folder: directory holding the high-resolution PNGs.
        lr_size: (height, width) every LR image is resized to.
        hr_size: (height, width) every HR image is resized to.

    Raises:
        ValueError: if the two folders do not contain the same number of PNGs,
            because positional pairing would then mix up unrelated images.
    """

    def __init__(self, lr_folder, hr_folder, lr_size=(64, 64), hr_size=(128, 128)):
        self.lr_images = sorted(glob.glob(os.path.join(lr_folder, "*.png")))
        self.hr_images = sorted(glob.glob(os.path.join(hr_folder, "*.png")))

        if len(self.lr_images) != len(self.hr_images):
            raise ValueError(
                f"LR/HR image count mismatch: {len(self.lr_images)} files in "
                f"{lr_folder!r} vs {len(self.hr_images)} files in {hr_folder!r}"
            )

        # Basic transforms: resize to the requested size, then convert to a tensor
        self.lr_transform = transforms.Compose([
            transforms.Resize(lr_size),
            transforms.ToTensor()
        ])
        self.hr_transform = transforms.Compose([
            transforms.Resize(hr_size),
            transforms.ToTensor()
        ])

    @staticmethod
    def _load_rgb(path):
        """Open `path`, convert it to RGB and release the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        """Return the `(lr_tensor, hr_tensor)` pair stored at position `idx`."""
        lr_img = self._load_rgb(self.lr_images[idx])
        hr_img = self._load_rgb(self.hr_images[idx])

        # Feed the stored LR image to the LR transform. The previous code passed
        # the HR image here, which silently discarded the LR file just loaded.
        lr_tensor = self.lr_transform(lr_img)
        hr_tensor = self.hr_transform(hr_img)

        return lr_tensor, hr_tensor

In [5]:
# Paired LR/HR dataset read from the folders configured above; both sizes are square.
dataset = AnimeFaceDataset(
    lr_folder=lr_path,
    hr_folder=hr_path,
    lr_size=(lr_img_size, lr_img_size),
    hr_size=(hr_img_size, hr_img_size),
)
# Reshuffle the pairs every epoch.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Generator

In [6]:
def get_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    return torch.device(backend)

# Device used for the models, the loss modules and every training batch.
device = get_device()

In [7]:
# Helper function to initialize weights
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialise the weights of every Conv2d / Linear / BatchNorm layer found.

    Args:
        module_list: a single ``nn.Module`` or a list/tuple of modules; each one
            is walked recursively with ``module.modules()``.
        scale: factor the Kaiming-initialised Conv2d/Linear weights are multiplied by.
        bias_fill: constant written into every existing bias.
        **kwargs: forwarded to ``nn.init.kaiming_normal_``.
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            # `_BatchNorm` and `init` were referenced without ever being imported,
            # so reaching this branch raised NameError; resolve both through `nn`.
            elif isinstance(m, nn.modules.batchnorm._BatchNorm):
                # weight/bias are None for BatchNorm layers built with affine=False
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)

# Pixel Unshuffle for downsampling
def pixel_unshuffle(x, scale):
    """Fold each `scale` x `scale` spatial patch of a 4-D tensor into channels.

    Input of shape (b, c, H, W) becomes (b, c * scale**2, H // scale, W // scale).
    Both H and W must be divisible by `scale` (checked with an assert).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the patch) ...
    patches = x.view(batch, channels, out_h, scale, out_w, scale)
    # ... then move the two in-patch offsets next to the channel axis and merge them.
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, channels * (scale ** 2), out_h, out_w)



In [8]:
# RRDBNet (Residual in Residual Dense Block Network) architecture
class ResidualDenseBlock(nn.Module):
    """Five densely connected 3x3 convolutions with a scaled residual connection.

    Each convolution receives the block input concatenated with the outputs of
    all earlier convolutions; the last one maps back to `channels`, and its
    output is scaled by `beta` before being added to the block input.

    Args:
        channels: number of input and output feature channels.
        growth_channels: channels produced by each of the first four convolutions.
    """

    def __init__(self, channels=64, growth_channels=32):
        super().__init__()

        def dense_conv(num_previous, out_channels):
            # Input width grows by `growth_channels` for every earlier convolution.
            return nn.Conv2d(channels + num_previous * growth_channels, out_channels,
                             kernel_size=3, stride=1, padding=1)

        self.conv1 = dense_conv(0, growth_channels)
        self.conv2 = dense_conv(1, growth_channels)
        self.conv3 = dense_conv(2, growth_channels)
        self.conv4 = dense_conv(3, growth_channels)
        self.conv5 = dense_conv(4, channels)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.beta = 0.2  # Scaling factor applied to the residual branch

        # Re-initialise all five convolutions with weights scaled by 0.1
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # `features` accumulates the block input followed by every intermediate output.
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        residual = self.conv5(torch.cat(features, 1))
        return residual * self.beta + x

In [9]:
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three dense blocks plus a scaled skip.

    Args:
        channels: number of feature channels entering and leaving the block.
        growth_channels: growth width handed to each ResidualDenseBlock.
            Defaults to 32, the dense block's own default, so existing
            `RRDB(channels)` calls build exactly the same network as before.
    """

    def __init__(self, channels=64, growth_channels=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth_channels)
        self.rdb2 = ResidualDenseBlock(channels, growth_channels)
        self.rdb3 = ResidualDenseBlock(channels, growth_channels)
        self.beta = 0.2  # scaling applied to the stacked dense-block output

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Scaled residual: the three dense blocks refine x, the skip keeps the input.
        return out * self.beta + x

In [10]:
# Generator Network
class Generator(nn.Module):
    """RRDB-based super-resolution generator.

    The trunk always upsamples by 4 (two nearest-neighbour x2 stages). For
    `scale` 2 or 1 the input is first pixel-unshuffled by 2 or 4 respectively,
    so the net output/input size ratio equals `scale`; any other `scale` value
    feeds the input to the trunk unchanged.

    Args:
        num_in_ch: channels of the input image.
        num_out_ch: channels of the output image.
        scale: overall upscaling factor (see above).
        num_feat: feature channels used throughout the trunk.
        num_block: number of RRDB blocks in the body.
        num_grow_ch: accepted in the signature but not forwarded to the RRDB
            blocks by this constructor.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, scale=2, num_feat=64, num_block=6, num_grow_ch=32):
        super().__init__()
        self.scale = scale

        # Pixel-unshuffle multiplies the channel count by the square of its factor.
        if scale == 2:
            first_in_ch = num_in_ch * 4
        elif scale == 1:
            first_in_ch = num_in_ch * 16
        else:
            first_in_ch = num_in_ch

        def conv3x3(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, 3, 1, 1)

        self.conv_first = conv3x3(first_in_ch, num_feat)
        rrdb_blocks = [RRDB(num_feat) for _ in range(num_block)]
        self.body = nn.Sequential(*rrdb_blocks)
        self.conv_body = conv3x3(num_feat, num_feat)

        # upsample
        self.conv_up1 = conv3x3(num_feat, num_feat)
        self.conv_up2 = conv3x3(num_feat, num_feat)
        self.conv_hr = conv3x3(num_feat, num_feat)
        self.conv_last = conv3x3(num_feat, num_out_ch)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Trade spatial resolution for channels when the overall scale is below 4.
        if self.scale in (1, 2):
            feat = pixel_unshuffle(x, scale=2 if self.scale == 2 else 4)
        else:
            feat = x

        shallow = self.conv_first(feat)
        # Long skip connection around the RRDB body
        feat = shallow + self.conv_body(self.body(shallow))

        # upsample: two nearest-neighbour x2 stages, each followed by conv + LeakyReLU
        for up_conv in (self.conv_up1, self.conv_up2):
            feat = self.lrelu(up_conv(F.interpolate(feat, scale_factor=2, mode='nearest')))

        return self.conv_last(self.lrelu(self.conv_hr(feat)))

# Discriminator

In [11]:
class Discriminator(nn.Module):
    """U-Net style discriminator with spectral normalisation.

    Three stride-2 convolutions downsample the input by 8, three bilinear x2
    upsampling stages bring it back to the input resolution, and the output is
    one logit per input pixel with shape (batch, 1, H, W).

    Args:
        num_input: channels of the input image.
        num_feat: base number of feature channels.
        skip_connection: add the encoder features to the decoder features of
            matching resolution when True.
    """

    def __init__(self, num_input=3, num_feat=64, skip_connection=True):
        super().__init__()
        self.skip_connection = skip_connection

        def sn_conv(in_ch, out_ch, kernel, stride):
            # Bias-free convolution (padding 1) wrapped in spectral normalisation
            return spectral_norm(nn.Conv2d(in_ch, out_ch, kernel, stride, 1, bias=False))

        # Stem: the only convolution with a bias and without spectral norm
        self.conv0 = nn.Conv2d(num_input, num_feat, 3, 1, 1)
        # Encoder: every 4x4 stride-2 convolution halves the resolution
        self.conv1 = sn_conv(num_feat, num_feat * 2, 4, 2)
        self.conv2 = sn_conv(num_feat * 2, num_feat * 4, 4, 2)
        self.conv3 = sn_conv(num_feat * 4, num_feat * 8, 4, 2)
        # Decoder: 3x3 stride-1 convolutions applied after each upsampling step
        self.conv4 = sn_conv(num_feat * 8, num_feat * 4, 3, 1)
        self.conv5 = sn_conv(num_feat * 4, num_feat * 2, 3, 1)
        self.conv6 = sn_conv(num_feat * 2, num_feat, 3, 1)
        # Head: two refinement convolutions and the single-channel logit map
        self.conv7 = sn_conv(num_feat, num_feat, 3, 1)
        self.conv8 = sn_conv(num_feat, num_feat, 3, 1)
        self.conv9 = sn_conv(num_feat, 1, 3, 1)

    def forward(self, x):
        def act(t):
            return F.leaky_relu(t, 0.2, inplace=True)

        def upsample(t):
            return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=False)

        # Encoder path: full, 1/2, 1/4 and 1/8 resolution
        enc_full = act(self.conv0(x))
        enc_half = act(self.conv1(enc_full))
        enc_quarter = act(self.conv2(enc_half))
        bottleneck = act(self.conv3(enc_quarter))

        # Decoder path: upsample, convolve, then optionally add the matching encoder feature
        dec = act(self.conv4(upsample(bottleneck)))
        if self.skip_connection:
            dec = dec + enc_quarter
        dec = act(self.conv5(upsample(dec)))
        if self.skip_connection:
            dec = dec + enc_half
        dec = act(self.conv6(upsample(dec)))
        if self.skip_connection:
            dec = dec + enc_full

        head = act(self.conv7(dec))
        head = act(self.conv8(head))
        return self.conv9(head)

# Loss Function

In [12]:
# Adversarial term: binary cross-entropy on raw discriminator logits.
adversarial_loss = nn.BCEWithLogitsLoss().to(device)
# Content term: pixel-wise mean squared error between generated and target images.
content_loss = nn.MSELoss().to(device)

In [13]:
def relativistic_loss(real_pred, fake_pred):
    """Relativistic average discriminator loss.

    Each prediction is judged relative to the mean prediction of the opposite
    class: real logits minus the mean fake logit are pushed towards label 1,
    fake logits minus the mean real logit towards label 0. Returns the mean of
    the two `adversarial_loss` terms.
    """
    real_relative = real_pred - torch.mean(fake_pred)
    fake_relative = fake_pred - torch.mean(real_pred)
    real_loss = adversarial_loss(real_relative, torch.ones_like(real_pred))
    fake_loss = adversarial_loss(fake_relative, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

In [14]:
class perceptual_loss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    The extractor is the first 36 layers of torchvision's ImageNet-pretrained
    VGG19 `features` stack, kept in eval mode with gradients disabled for its
    parameters.

    Args:
        use_input_norm: when True (default), inputs are normalised with the
            ImageNet channel mean/std the pretrained weights were trained with.
            Inputs are then expected to be RGB in [0, 1]. Pass False to feed
            the images to VGG unchanged (the previous behaviour).
    """

    def __init__(self, use_input_norm=True):
        super(perceptual_loss, self).__init__()
        vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        self.feature_extractor = nn.Sequential(*list(vgg.features[:36])).eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.use_input_norm = use_input_norm
        # Non-persistent buffers: they follow .to(device) but stay out of state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def _features(self, images):
        """Return the VGG feature maps of `images`, normalising first if enabled."""
        if self.use_input_norm:
            images = (images - self.mean) / self.std
        return self.feature_extractor(images)

    def forward(self, x, target):
        """Return the mean absolute difference between the features of `x` and `target`."""
        x_features = self._features(x)
        target_features = self._features(target)
        return nn.functional.l1_loss(x_features, target_features)

In [15]:
# Build the generator from the notebook-level configuration (upscale_factor, gen_num_block)
# and move it to the selected device; the bare name below displays its layer summary.
generator = Generator(num_in_ch=3, num_out_ch=3, scale=upscale_factor, num_feat=64, num_block=gen_num_block, num_grow_ch=32).to(device)
generator

Generator(
  (conv_first): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (body): Sequential(
    (0): RRDB(
      (rdb1): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv5): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (rdb2): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1

In [16]:
# Discriminator with its default arguments, moved to the selected device;
# the bare name below displays its layer summary.
discriminator = Discriminator().to(device)
discriminator

Discriminator(
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv5): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

In [17]:
# VGG19-based perceptual loss module, moved to the selected device. Despite its name,
# calling `feature_extractor(x, target)` returns a loss value, not feature maps.
feature_extractor = perceptual_loss().to(device)
feature_extractor

perceptual_loss(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3)

In [18]:
# One Adam optimizer per network, both starting from the configured learning rate `lr`.
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(),lr=lr)

In [19]:
# Halve each learning rate (gamma=0.5) every 10 scheduler steps;
# the training loop steps both schedulers once per epoch.
scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=10, gamma=0.5)
scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size = 10, gamma=0.5)

In [20]:
def imshow(img, title=None):
    """Display a channel-first image tensor in its own wide matplotlib figure.

    Args:
        img: tensor laid out as (channels, height, width); it is copied to the
            CPU and converted to a NumPy array before plotting.
        title: optional figure title; skipped when falsy.
    """
    pixels = img.cpu().numpy()
    plt.figure(figsize=(12, 3))  # wide, short canvas for a row of images

    # matplotlib expects channel-last (height, width, channels) data
    plt.imshow(pixels.transpose(1, 2, 0))
    if title:
        plt.title(title)
    plt.show()

# Train

We cleared this cell's output because it made the notebook about 180 MB.

In [None]:
# Checkpoints are written under ./models at the end of every epoch. Create the
# folder up front: without it the first torch.save raises FileNotFoundError
# after a full epoch of training has already been spent.
os.makedirs("models", exist_ok=True)

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    generator.train()
    discriminator.train()

    for batch_idx, (lr_images, hr_images) in enumerate(dataloader):
        batch_start_time = time.time()
        # Move images to device
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        # Per-pixel "real" labels for the generator's adversarial term, created
        # directly on the device (one label per HR pixel).
        valid_labels = torch.ones((lr_images.size(0), 1, hr_images.size(2), hr_images.size(3)),
                                  device=device)

        # Generate super-resolved images from low-resolution images
        sr_images = generator(lr_images)  # output side length = input side length * upscale_factor

        # Clamp images to valid range [0, 1]
        sr_images = torch.clamp(sr_images, 0, 1)
        lr_images = torch.clamp(lr_images, 0, 1)

        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_d.zero_grad()  # Reset gradients for discriminator
        real_preds = discriminator(hr_images)
        # detach(): the discriminator step must not back-propagate into the generator
        fake_preds = discriminator(sr_images.detach())
        d_loss = relativistic_loss(real_preds, fake_preds)
        d_loss.backward()
        optimizer_d.step()

        # ---------------------
        # Train Generator
        # ---------------------
        optimizer_g.zero_grad()  # Reset gradients for generator
        fake_preds = discriminator(sr_images)
        fake_preds = fake_preds.view_as(valid_labels)
        g_adv_loss = adversarial_loss(fake_preds, valid_labels)
        g_content_loss = content_loss(sr_images, hr_images)
        g_perceptual_loss = feature_extractor(sr_images, hr_images)
        # Weighted sum: pixel loss dominates, adversarial and perceptual terms are scaled down
        g_loss = g_content_loss + g_adv_loss * 1e-3 + g_perceptual_loss * 1e-2
        g_loss.backward()
        optimizer_g.step()

        batch_end_time = time.time()

        if batch_idx % 5 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}] ', end="=> ")
            print(f'Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}', end=", ")
            print(f"Batch Time: {(batch_end_time - batch_start_time):.4f}")

        # Display images every 100 batches
        if batch_idx % 100 == 0:
            with torch.no_grad():
                original_grid = torchvision.utils.make_grid(lr_images[:4].cpu(), nrow=4)
                enhanced_grid = torchvision.utils.make_grid(sr_images[:4].cpu(), nrow=4)
                imshow(original_grid, title='Original Images')
                imshow(enhanced_grid, title='Enhanced Images')

        # Cleanup to free memory
        del fake_preds, g_adv_loss, g_content_loss, g_perceptual_loss, g_loss, sr_images, lr_images, hr_images
        if device == torch.device("cuda"):
            torch.cuda.empty_cache()
        elif device == torch.device("mps"):
            torch.mps.empty_cache()

    # Update learning rate schedulers
    scheduler_g.step()
    scheduler_d.step()
    epoch_end_time = time.time()
    print("Epoch Time:", epoch_end_time - epoch_start_time)

    # Save model checkpoint. The scheduler states are stored as additional keys
    # so a resumed run can continue the learning-rate schedule.
    torch.save({
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizerG_state_dict': optimizer_g.state_dict(),
        'optimizerD_state_dict': optimizer_d.state_dict(),
        'schedulerG_state_dict': scheduler_g.state_dict(),
        'schedulerD_state_dict': scheduler_d.state_dict(),
        'epoch': epoch,
    }, f'models/model_checkpoint_{upscale_factor}x_epoch_{epoch + 1}.pth')


In [22]:
  # Export only the generator weights for inference. Make sure the target folder
  # exists first so this cell also works when run on its own.
  os.makedirs('models', exist_ok=True)
  torch.save({
    'generator_state_dict': generator.state_dict(),
    'epoch': num_epochs,
  }, f'models/final_generator_{upscale_factor}x.pth')