In [9]:
# Step 1: Import Necessary Libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from math import log10


In [10]:
class StrongLensingDataset(Dataset):
    def __init__(self, lr_files, hr_files):
        self.lr_files = lr_files
        self.hr_files = hr_files
        
    def __len__(self):
        return len(self.lr_files)
    
    def __getitem__(self, idx):
        lr_path = self.lr_files[idx]
        hr_path = self.hr_files[idx]
        lr_img = np.load(lr_path)  # expected shape: (H, W) or (C, H, W)
        hr_img = np.load(hr_path)
        
        # If the images are single-channel (H, W), add a channel dimension.
        if len(lr_img.shape) == 2:
            lr_img = np.expand_dims(lr_img, axis=0)
        if len(hr_img.shape) == 2:
            hr_img = np.expand_dims(hr_img, axis=0)
        
        # Normalize to [0, 1]
        lr_img = lr_img.astype(np.float32) / (lr_img.max() if lr_img.max() != 0 else 1)
        hr_img = hr_img.astype(np.float32) / (hr_img.max() if hr_img.max() != 0 else 1)
        
        # Convert to torch tensors
        lr_tensor = torch.from_numpy(lr_img)
        hr_tensor = torch.from_numpy(hr_img)
        
        return lr_tensor, hr_tensor


In [11]:
# Update these paths to point to your local dataset directories.
lr_folder = "/Users/EndUser/Downloads/dataset-2/LR"
hr_folder = "/Users/EndUser/Downloads/dataset-2/HR"

# List and sort .npy files in each folder
lr_files = sorted([os.path.join(lr_folder, f) for f in os.listdir(lr_folder) if f.endswith('.npy')])
hr_files = sorted([os.path.join(hr_folder, f) for f in os.listdir(hr_folder) if f.endswith('.npy')])

assert len(lr_files) == len(hr_files), "Mismatch: Number of LR and HR files do not match."

# Split into training and validation sets
train_lr, val_lr, train_hr, val_hr = train_test_split(lr_files, hr_files, test_size=0.1, random_state=42)

# Create Dataset objects
train_dataset = StrongLensingDataset(train_lr, train_hr)
val_dataset = StrongLensingDataset(val_lr, val_hr)

# Create DataLoaders
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [12]:
class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNN(num_channels=1).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [14]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for lr_imgs, hr_imgs in train_loader:
        # Move tensors to device
        lr_imgs = lr_imgs.to(device)   # [B, 1, 75, 75]
        hr_imgs = hr_imgs.to(device)   # [B, 1, 150, 150]
        
        # Upsample LR images to 150x150 using bicubic interpolation
        lr_imgs_upsampled = F.interpolate(lr_imgs, size=(150, 150), mode='bicubic', align_corners=False)
        
        optimizer.zero_grad()
        sr_imgs = model(lr_imgs_upsampled)  # Output shape: [B, 1, 150, 150]
        loss = criterion(sr_imgs, hr_imgs)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * lr_imgs.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    
    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            lr_imgs_upsampled = F.interpolate(lr_imgs, size=(150, 150), mode='bicubic', align_corners=False)
            sr_imgs = model(lr_imgs_upsampled)
            loss = criterion(sr_imgs, hr_imgs)
            val_loss += loss.item() * lr_imgs.size(0)
    
    epoch_val_loss = val_loss / len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")


Epoch [1/10], Train Loss: 0.0007, Val Loss: 0.0001
Epoch [2/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [3/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [4/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [5/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [6/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [7/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [8/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [9/10], Train Loss: 0.0001, Val Loss: 0.0001
Epoch [10/10], Train Loss: 0.0001, Val Loss: 0.0001


In [15]:
def psnr(sr, hr):
    mse_val = ((sr - hr) ** 2).mean().item()
    if mse_val == 0:
        return 100
    return 10 * log10(1 / mse_val)

def simple_ssim(sr, hr, C1=0.01**2, C2=0.03**2):
    sr_mean = sr.mean()
    hr_mean = hr.mean()
    sr_var = sr.var()
    hr_var = hr.var()
    sr_hr_cov = ((sr - sr_mean) * (hr - hr_mean)).mean()
    
    ssim_val = (2 * sr_mean * hr_mean + C1) * (2 * sr_hr_cov + C2)
    denom = (sr_mean**2 + hr_mean**2 + C1) * (sr_var + hr_var + C2)
    return (ssim_val / denom).item()

# Evaluate the model on the validation set
model.eval()
mse_list, psnr_list, ssim_list = [], [], []
with torch.no_grad():
    for lr_imgs, hr_imgs in val_loader:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)
        lr_imgs_upsampled = F.interpolate(lr_imgs, size=(150, 150), mode='bicubic', align_corners=False)
        sr_imgs = model(lr_imgs_upsampled)
        
        batch_mse = criterion(sr_imgs, hr_imgs).item()
        batch_psnr = psnr(sr_imgs, hr_imgs)
        batch_ssim = simple_ssim(sr_imgs, hr_imgs)
        
        mse_list.append(batch_mse)
        psnr_list.append(batch_psnr)
        ssim_list.append(batch_ssim)

avg_mse = np.mean(mse_list)
avg_psnr = np.mean(psnr_list)
avg_ssim = np.mean(ssim_list)
print(f"Validation MSE: {avg_mse:.6f}")
print(f"Validation PSNR: {avg_psnr:.4f} dB")
print(f"Validation SSIM: {avg_ssim:.4f}")


Validation MSE: 0.000076
Validation PSNR: 41.2130 dB
Validation SSIM: 0.9972
