In [23]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torchvision.utils import make_grid
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.utils import save_image

In [24]:
class UNetForUpsampling(nn.Module):
    def __init__(self):
        super(UNetForUpsampling, self).__init__()
        self.encoder1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )

        # Only two encoder stages due to smaller input size
        self.middle = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )

        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True)
        )

        # Adjusting the decoder to match the target upsampling size
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(128*2, 64, kernel_size=2, stride=2),  # kernel size and stride adjusted for upscaling
            nn.ReLU(inplace=True)
        )

        # Final layer to upscale from current size to desired 48x48 output
        self.upsample = nn.Upsample(size=(256, 256), mode='bicubic', align_corners=False)
        self.final = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x1 = self.encoder1(x)
        x2 = self.encoder2(x1)

        xm = self.middle(x2)

        d2 = self.decoder2(xm)
        d2 = torch.cat((x2, d2), dim=1)  # Skip connection
        d1 = self.decoder1(d2)
        d1 = torch.cat((x1, d1), dim=1)  # Skip connection

        upsampled = self.upsample(d1)  # Upsample to target size
        out = self.final(upsampled)
        return out

In [25]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torch

# -------- Dataset Definition --------
class ImagePairDataset(Dataset):
    def __init__(self, lr_folder, hr_folder, lr_transform=None, hr_transform=None):
        self.lr_folder = lr_folder
        self.hr_folder = hr_folder
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

        # Only use image names that are in both folders
        self.image_names = list(set(os.listdir(lr_folder)) & set(os.listdir(hr_folder)))
        self.image_names = [img for img in self.image_names if img.endswith('.png')]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        lr_image_path = os.path.join(self.lr_folder, image_name)
        hr_image_path = os.path.join(self.hr_folder, image_name)

        lr_image = Image.open(lr_image_path).convert('RGB')
        hr_image = Image.open(hr_image_path).convert('RGB')

        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)
        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)

        return lr_image, hr_image

In [26]:

# -------- Folder Paths --------
folder_path_train_hr = '/home/vittorio/Documenti/Upsampling_CFD/datasets/SplitDatasetOutdoorFlow/train/high_res'
folder_path_train_lr = '/home/vittorio/Documenti/Upsampling_CFD/datasets/SplitDatasetOutdoorFlow/train/low_res'
folder_path_val_hr =  '/home/vittorio/Documenti/Upsampling_CFD/datasets/SplitDatasetOutdoorFlow/test/high_res'
folder_path_val_lr ='/home/vittorio/Documenti/Upsampling_CFD/datasets/SplitDatasetOutdoorFlow/test/low_res'

# -------- Transforms --------
lr_transform = transforms.Compose([
    #transforms.Resize((16, 16)),  # Adjust as needed
    transforms.ToTensor(),
])

hr_transform = transforms.Compose([
    transforms.ToTensor(),  # Assumes HR is already at correct resolution
])


In [27]:


# -------- Create Datasets --------
full_train_dataset = ImagePairDataset(folder_path_train_lr, folder_path_train_hr,
                                      lr_transform=lr_transform, hr_transform=hr_transform)

test_dataset = ImagePairDataset(folder_path_val_lr, folder_path_val_hr,
                                 lr_transform=lr_transform, hr_transform=hr_transform)

# -------- DataLoaders --------
train_loader = DataLoader(full_train_dataset, batch_size=8, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=8, shuffle=False)

# -------- Print Stats --------
print(f"Train Dataset Size: {len(full_train_dataset)}")
print(f"Test Dataset Size: {len(test_dataset)}")


Train Dataset Size: 772
Test Dataset Size: 192


In [28]:
num_epochs = 80
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# model = UpsampleDiffusionModel().to(device)
model = UNetForUpsampling().to(device)
# model = UNetForUpsamplingGELU().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

train_losses = []
val_losses = []

for epoch in tqdm(range(num_epochs)):
    model.train()  # Set model to training mode

    # Training phase
    total_train_loss = 0
    for lr_imgs, hr_imgs in train_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        optimizer.zero_grad()
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)

    # Validation phase
    model.eval()  # Set model to evaluation mode
    total_val_loss = 0
    with torch.no_grad():  # No gradients needed for validation
        for lr_imgs, hr_imgs in test_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)
            total_val_loss += loss.item()
    avg_val_loss = total_val_loss / len(test_loader)

    # Save the average losses
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

cuda


  1%|▏         | 1/80 [00:08<11:39,  8.86s/it]

Epoch [1/80], Train Loss: 0.1209, Val Loss: 0.0150


  2%|▎         | 2/80 [00:16<10:56,  8.42s/it]

Epoch [2/80], Train Loss: 0.0077, Val Loss: 0.0043


  4%|▍         | 3/80 [00:25<10:56,  8.53s/it]

Epoch [3/80], Train Loss: 0.0032, Val Loss: 0.0024


  5%|▌         | 4/80 [00:34<10:46,  8.51s/it]

Epoch [4/80], Train Loss: 0.0021, Val Loss: 0.0017


  6%|▋         | 5/80 [00:42<10:39,  8.53s/it]

Epoch [5/80], Train Loss: 0.0015, Val Loss: 0.0013


  8%|▊         | 6/80 [00:51<10:32,  8.55s/it]

Epoch [6/80], Train Loss: 0.0012, Val Loss: 0.0011


  9%|▉         | 7/80 [00:59<10:21,  8.51s/it]

Epoch [7/80], Train Loss: 0.0010, Val Loss: 0.0009


 10%|█         | 8/80 [01:08<10:13,  8.52s/it]

Epoch [8/80], Train Loss: 0.0009, Val Loss: 0.0008


 11%|█▏        | 9/80 [01:16<10:06,  8.55s/it]

Epoch [9/80], Train Loss: 0.0008, Val Loss: 0.0008


 12%|█▎        | 10/80 [01:25<10:00,  8.57s/it]

Epoch [10/80], Train Loss: 0.0007, Val Loss: 0.0007


 14%|█▍        | 11/80 [01:34<09:57,  8.67s/it]

Epoch [11/80], Train Loss: 0.0007, Val Loss: 0.0007


 15%|█▌        | 12/80 [01:43<09:56,  8.78s/it]

Epoch [12/80], Train Loss: 0.0007, Val Loss: 0.0007


 16%|█▋        | 13/80 [01:52<10:00,  8.96s/it]

Epoch [13/80], Train Loss: 0.0006, Val Loss: 0.0006


 18%|█▊        | 14/80 [02:01<09:56,  9.03s/it]

Epoch [14/80], Train Loss: 0.0006, Val Loss: 0.0006


 19%|█▉        | 15/80 [02:10<09:46,  9.02s/it]

Epoch [15/80], Train Loss: 0.0006, Val Loss: 0.0006


 20%|██        | 16/80 [02:20<09:45,  9.15s/it]

Epoch [16/80], Train Loss: 0.0006, Val Loss: 0.0006


 21%|██▏       | 17/80 [02:29<09:34,  9.12s/it]

Epoch [17/80], Train Loss: 0.0006, Val Loss: 0.0006


 22%|██▎       | 18/80 [02:39<09:35,  9.28s/it]

Epoch [18/80], Train Loss: 0.0006, Val Loss: 0.0006


 24%|██▍       | 19/80 [02:48<09:26,  9.28s/it]

Epoch [19/80], Train Loss: 0.0006, Val Loss: 0.0005


 25%|██▌       | 20/80 [02:57<09:19,  9.32s/it]

Epoch [20/80], Train Loss: 0.0005, Val Loss: 0.0005


 26%|██▋       | 21/80 [03:07<09:14,  9.40s/it]

Epoch [21/80], Train Loss: 0.0005, Val Loss: 0.0005


 28%|██▊       | 22/80 [03:16<09:08,  9.46s/it]

Epoch [22/80], Train Loss: 0.0005, Val Loss: 0.0005


 29%|██▉       | 23/80 [03:26<08:57,  9.43s/it]

Epoch [23/80], Train Loss: 0.0005, Val Loss: 0.0005


 30%|███       | 24/80 [03:35<08:46,  9.40s/it]

Epoch [24/80], Train Loss: 0.0005, Val Loss: 0.0005


 31%|███▏      | 25/80 [03:45<08:36,  9.40s/it]

Epoch [25/80], Train Loss: 0.0005, Val Loss: 0.0005


 32%|███▎      | 26/80 [03:54<08:28,  9.42s/it]

Epoch [26/80], Train Loss: 0.0005, Val Loss: 0.0005


 34%|███▍      | 27/80 [04:03<08:19,  9.42s/it]

Epoch [27/80], Train Loss: 0.0005, Val Loss: 0.0005


 35%|███▌      | 28/80 [04:13<08:12,  9.47s/it]

Epoch [28/80], Train Loss: 0.0005, Val Loss: 0.0005


 36%|███▋      | 29/80 [04:23<08:08,  9.58s/it]

Epoch [29/80], Train Loss: 0.0005, Val Loss: 0.0005


 38%|███▊      | 30/80 [04:33<08:03,  9.66s/it]

Epoch [30/80], Train Loss: 0.0005, Val Loss: 0.0005


 39%|███▉      | 31/80 [04:42<07:53,  9.67s/it]

Epoch [31/80], Train Loss: 0.0005, Val Loss: 0.0005


 40%|████      | 32/80 [04:52<07:41,  9.61s/it]

Epoch [32/80], Train Loss: 0.0005, Val Loss: 0.0005


 41%|████▏     | 33/80 [05:01<07:30,  9.59s/it]

Epoch [33/80], Train Loss: 0.0005, Val Loss: 0.0005


 42%|████▎     | 34/80 [05:11<07:18,  9.54s/it]

Epoch [34/80], Train Loss: 0.0005, Val Loss: 0.0005


 44%|████▍     | 35/80 [05:20<07:07,  9.51s/it]

Epoch [35/80], Train Loss: 0.0005, Val Loss: 0.0004


 45%|████▌     | 36/80 [05:30<06:59,  9.53s/it]

Epoch [36/80], Train Loss: 0.0005, Val Loss: 0.0004


 46%|████▋     | 37/80 [05:39<06:50,  9.54s/it]

Epoch [37/80], Train Loss: 0.0004, Val Loss: 0.0004


 48%|████▊     | 38/80 [05:49<06:42,  9.58s/it]

Epoch [38/80], Train Loss: 0.0004, Val Loss: 0.0004


 49%|████▉     | 39/80 [05:59<06:31,  9.56s/it]

Epoch [39/80], Train Loss: 0.0004, Val Loss: 0.0004


 50%|█████     | 40/80 [06:08<06:20,  9.52s/it]

Epoch [40/80], Train Loss: 0.0004, Val Loss: 0.0004


 51%|█████▏    | 41/80 [06:18<06:12,  9.55s/it]

Epoch [41/80], Train Loss: 0.0004, Val Loss: 0.0004


 52%|█████▎    | 42/80 [06:27<06:03,  9.56s/it]

Epoch [42/80], Train Loss: 0.0004, Val Loss: 0.0004


 54%|█████▍    | 43/80 [06:37<05:53,  9.56s/it]

Epoch [43/80], Train Loss: 0.0004, Val Loss: 0.0004


 55%|█████▌    | 44/80 [06:46<05:43,  9.53s/it]

Epoch [44/80], Train Loss: 0.0004, Val Loss: 0.0004


 56%|█████▋    | 45/80 [06:56<05:33,  9.54s/it]

Epoch [45/80], Train Loss: 0.0004, Val Loss: 0.0004


 57%|█████▊    | 46/80 [07:05<05:25,  9.57s/it]

Epoch [46/80], Train Loss: 0.0004, Val Loss: 0.0004


 59%|█████▉    | 47/80 [07:15<05:16,  9.58s/it]

Epoch [47/80], Train Loss: 0.0004, Val Loss: 0.0004


 60%|██████    | 48/80 [07:25<05:06,  9.57s/it]

Epoch [48/80], Train Loss: 0.0004, Val Loss: 0.0004


 61%|██████▏   | 49/80 [07:34<04:58,  9.62s/it]

Epoch [49/80], Train Loss: 0.0004, Val Loss: 0.0004


 62%|██████▎   | 50/80 [07:44<04:47,  9.59s/it]

Epoch [50/80], Train Loss: 0.0004, Val Loss: 0.0004


 64%|██████▍   | 51/80 [07:53<04:38,  9.59s/it]

Epoch [51/80], Train Loss: 0.0004, Val Loss: 0.0004


 65%|██████▌   | 52/80 [08:03<04:27,  9.56s/it]

Epoch [52/80], Train Loss: 0.0004, Val Loss: 0.0004


 66%|██████▋   | 53/80 [08:13<04:18,  9.57s/it]

Epoch [53/80], Train Loss: 0.0004, Val Loss: 0.0004


 68%|██████▊   | 54/80 [08:22<04:09,  9.59s/it]

Epoch [54/80], Train Loss: 0.0004, Val Loss: 0.0004


 69%|██████▉   | 55/80 [08:32<04:01,  9.67s/it]

Epoch [55/80], Train Loss: 0.0004, Val Loss: 0.0004


 70%|███████   | 56/80 [08:42<03:53,  9.74s/it]

Epoch [56/80], Train Loss: 0.0004, Val Loss: 0.0004


 71%|███████▏  | 57/80 [08:52<03:46,  9.85s/it]

Epoch [57/80], Train Loss: 0.0004, Val Loss: 0.0004


 72%|███████▎  | 58/80 [09:02<03:39,  9.98s/it]

Epoch [58/80], Train Loss: 0.0004, Val Loss: 0.0004


 74%|███████▍  | 59/80 [09:12<03:29, 10.00s/it]

Epoch [59/80], Train Loss: 0.0004, Val Loss: 0.0004


 75%|███████▌  | 60/80 [09:22<03:18,  9.90s/it]

Epoch [60/80], Train Loss: 0.0004, Val Loss: 0.0004


 76%|███████▋  | 61/80 [09:32<03:08,  9.94s/it]

Epoch [61/80], Train Loss: 0.0004, Val Loss: 0.0004


 78%|███████▊  | 62/80 [09:42<02:58,  9.91s/it]

Epoch [62/80], Train Loss: 0.0004, Val Loss: 0.0004


 79%|███████▉  | 63/80 [09:52<02:47,  9.86s/it]

Epoch [63/80], Train Loss: 0.0004, Val Loss: 0.0004


 80%|████████  | 64/80 [10:01<02:36,  9.80s/it]

Epoch [64/80], Train Loss: 0.0004, Val Loss: 0.0004


 81%|████████▏ | 65/80 [10:11<02:25,  9.73s/it]

Epoch [65/80], Train Loss: 0.0004, Val Loss: 0.0004


 82%|████████▎ | 66/80 [10:21<02:15,  9.70s/it]

Epoch [66/80], Train Loss: 0.0004, Val Loss: 0.0004


 84%|████████▍ | 67/80 [10:30<02:06,  9.76s/it]

Epoch [67/80], Train Loss: 0.0004, Val Loss: 0.0004


 85%|████████▌ | 68/80 [10:40<01:56,  9.74s/it]

Epoch [68/80], Train Loss: 0.0004, Val Loss: 0.0004


 86%|████████▋ | 69/80 [10:50<01:47,  9.78s/it]

Epoch [69/80], Train Loss: 0.0004, Val Loss: 0.0004


 88%|████████▊ | 70/80 [11:00<01:37,  9.80s/it]

Epoch [70/80], Train Loss: 0.0004, Val Loss: 0.0003


 89%|████████▉ | 71/80 [11:10<01:28,  9.79s/it]

Epoch [71/80], Train Loss: 0.0003, Val Loss: 0.0003


 90%|█████████ | 72/80 [11:19<01:17,  9.74s/it]

Epoch [72/80], Train Loss: 0.0003, Val Loss: 0.0003


 91%|█████████▏| 73/80 [11:29<01:08,  9.74s/it]

Epoch [73/80], Train Loss: 0.0004, Val Loss: 0.0003


 92%|█████████▎| 74/80 [11:39<00:58,  9.76s/it]

Epoch [74/80], Train Loss: 0.0003, Val Loss: 0.0004


 94%|█████████▍| 75/80 [11:48<00:48,  9.71s/it]

Epoch [75/80], Train Loss: 0.0003, Val Loss: 0.0003


 95%|█████████▌| 76/80 [11:58<00:38,  9.68s/it]

Epoch [76/80], Train Loss: 0.0003, Val Loss: 0.0004


 96%|█████████▋| 77/80 [12:08<00:28,  9.66s/it]

Epoch [77/80], Train Loss: 0.0003, Val Loss: 0.0003


 98%|█████████▊| 78/80 [12:17<00:19,  9.62s/it]

Epoch [78/80], Train Loss: 0.0003, Val Loss: 0.0003


 99%|█████████▉| 79/80 [12:27<00:09,  9.65s/it]

Epoch [79/80], Train Loss: 0.0003, Val Loss: 0.0003


100%|██████████| 80/80 [12:37<00:00,  9.46s/it]

Epoch [80/80], Train Loss: 0.0003, Val Loss: 0.0003





In [29]:
import torch
import torch.nn.functional as F

def psnr(pred, target, max_val=1.0):
    mse = F.mse_loss(pred, target)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

def calculate_average_psnr(data_loader, model, device):
    total_psnr_model = 0.0
    total_batches = 0

    model.eval()
    with torch.no_grad():
        for lr_imgs, hr_imgs in data_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)

            # Forward pass
            sr_imgs = model(lr_imgs)

            # Optional visualization (only on the first sample in batch)
            # visualize_sr_vs_hr(sr_imgs[0].cpu(), hr_imgs[0].cpu())

            # Compute PSNR directly
            psnr_model = psnr(sr_imgs, hr_imgs).item()
            total_psnr_model += psnr_model
            total_batches += 1

    avg_psnr_model = total_psnr_model / total_batches
    return avg_psnr_model

val_psnr_model = calculate_average_psnr(test_loader, model, device)
print(f"Validation PSNR - Model Output: {val_psnr_model:.2f}")


Validation PSNR - Model Output: 34.91


In [30]:
import torch
import torch.nn.functional as F

def ssim(pred, target, C1=0.01**2, C2=0.03**2):
    """Structural Similarity Index Measure (SSIM) for image batches."""
    # Assumes pred and target are shape (B, C, H, W)
    B, C, H, W = pred.size()

    # Create a Gaussian filter manually (uniform box blur here for simplicity)
    kernel = torch.ones((C, 1, 11, 11), device=pred.device) / 121

    # Compute means
    mu_pred = F.conv2d(pred, kernel, padding=5, groups=C)
    mu_target = F.conv2d(target, kernel, padding=5, groups=C)

    # Compute variances and covariance
    sigma_pred = F.conv2d(pred * pred, kernel, padding=5, groups=C) - mu_pred ** 2
    sigma_target = F.conv2d(target * target, kernel, padding=5, groups=C) - mu_target ** 2
    sigma_cross = F.conv2d(pred * target, kernel, padding=5, groups=C) - mu_pred * mu_target

    # SSIM formula
    numerator = (2 * mu_pred * mu_target + C1) * (2 * sigma_cross + C2)
    denominator = (mu_pred ** 2 + mu_target ** 2 + C1) * (sigma_pred + sigma_target + C2)
    ssim_map = numerator / denominator

    return ssim_map.mean()

def calculate_average_ssim(data_loader, model, device):
    total_ssim_model = 0.0
    total_ssim_bicubic = 0.0
    total_batches = 0

    model.eval()
    with torch.no_grad():
        for lr_imgs, hr_imgs in data_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)

            # Get model output
            sr_imgs = model(lr_imgs)
            sr_imgs = torch.clamp(sr_imgs, 0.0, 1.0)

            # Bicubic baseline
            bicubic = F.interpolate(lr_imgs, scale_factor=4, mode='bicubic', align_corners=False)
            bicubic = torch.clamp(bicubic, 0.0, 1.0)

            # Compute SSIM
            ssim_model = ssim(sr_imgs, hr_imgs).item()
            ssim_bicubic = ssim(bicubic, hr_imgs).item()

            total_ssim_model += ssim_model
            total_ssim_bicubic += ssim_bicubic
            total_batches += 1

    avg_ssim_model = total_ssim_model / total_batches
    avg_ssim_bicubic = total_ssim_bicubic / total_batches
    return avg_ssim_model, avg_ssim_bicubic


val_ssim_model, val_ssim_bicubic = calculate_average_ssim(test_loader, model, device)
print(f"Validation SSIM - Model Output: {val_ssim_model:.4f}, Bicubic: {val_ssim_bicubic:.4f}")


RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 3

In [None]:
import torch
import torch.nn.functional as F

def compute_rmse(pred, target):
    """
    Compute Root Mean Squared Error (RMSE) between prediction and ground truth.
    Assumes both are tensors in [0, 1] and of shape (B, C, H, W).
    """
    mse = F.mse_loss(pred, target)
    return torch.sqrt(mse).item()

def calculate_average_rmse(data_loader, model, device):
    """
    Compute average RMSE for model predictions and bicubic interpolation over a dataset.
    """
    total_rmse_model = 0.0
    total_rmse_bicubic = 0.0
    total_batches = 0

    model.eval()
    with torch.no_grad():
        for lr_imgs, hr_imgs in data_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)

            # Model prediction
            sr_imgs = model(lr_imgs)
            sr_imgs = torch.clamp(sr_imgs, 0.0, 1.0)

            # Bicubic interpolation baseline
            bicubic = F.interpolate(lr_imgs, scale_factor=4, mode='bicubic', align_corners=False)
            bicubic = torch.clamp(bicubic, 0.0, 1.0)

            # Compute RMSEs
            rmse_model = compute_rmse(sr_imgs, hr_imgs)
            rmse_bicubic = compute_rmse(bicubic, hr_imgs)

            total_rmse_model += rmse_model
            total_rmse_bicubic += rmse_bicubic
            total_batches += 1

    avg_rmse_model = total_rmse_model / total_batches
    avg_rmse_bicubic = total_rmse_bicubic / total_batches

    return avg_rmse_model, avg_rmse_bicubic

# Assuming your DataLoaders and model are ready
val_rmse_model, val_rmse_bicubic = calculate_average_rmse(test_loader, model, device)

print(f"Validation RMSE - Model Output: {val_rmse_model:.4f}, Bicubic: {val_rmse_bicubic:.4f}")


Validation RMSE - Model Output: 0.0402, Bicubic: 0.0607


: 

In [None]:
def show_images_from_dataset(dataset_type='test', batch_num=0, num_images=4, first_image=0):
    # Select the appropriate DataLoader
    if dataset_type == 'train':
        selected_loader = train_loader
    elif dataset_type == 'val':
        selected_loader = val_loader
    else:
        raise ValueError("Invalid dataset type. Choose 'train', 'val', or 'test'.")

    model.eval()  # Set the model to evaluation mode

    current_batch = 0
    with torch.no_grad():  # Inference without gradient calculation
        for lr_imgs, hr_imgs in selected_loader:
            if current_batch == batch_num:
                lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
                outputs = model(lr_imgs)
                bicubic = torch.nn.functional.interpolate(lr_imgs, scale_factor=2, mode='bicubic', align_corners=False)
                break  # Exit loop after processing the specified batch
            current_batch += 1

    # Visualization
    plt.figure(figsize=(10, 10))
    for i in range(min(num_images, lr_imgs.size(0))):  # Ensure not to exceed batch size
        # Plotting code remains the same as before
        # Low-resolution images
        plt.subplot(4, num_images, i+1)
        plt.imshow(lr_imgs[i+first_image].permute(1, 2, 0).cpu().detach().numpy())
        plt.title('Low-Res')
        plt.axis('off')

        # High-resolution (Ground Truth)
        plt.subplot(4, num_images, num_images+i+1)
        plt.imshow(hr_imgs[i+first_image].permute(1, 2, 0).cpu().detach().numpy())
        plt.title('High-Res')
        plt.axis('off')

        # Model output
        plt.subplot(4, num_images, 2*num_images+i+1)
        plt.imshow(outputs[i+first_image].permute(1, 2, 0).cpu().detach().numpy())
        plt.title('Output')
        plt.axis('off')

        # Bicubic upsampled
        plt.subplot(4, num_images, 3*num_images+i+1)
        plt.imshow(bicubic[i+first_image].permute(1, 2, 0).cpu().detach().numpy())
        plt.title('Bicubic')
        plt.axis('off')
    plt.show()

# Example usage:
show_images_from_dataset(dataset_type='val', batch_num=3, num_images=4, first_image=20)
print(f"Number of batches in Training DataLoader: {len(train_loader)}")
print(f"Number of batches in Validation DataLoader: {len(val_loader)}")
print(lr_imgs.size())
print(hr_imgs.size())
print(outputs.size())
