<a href="https://colab.research.google.com/github/vnayakde/Dual-Image-Super-Resolution-for-High-Resolution-Optical-Satellite-Imagery/blob/main/model_buildingEDSR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
'''LR1 → Conv Block 1 → F1 ┐
                        │
                        ├─ Fusion Block (ESA, attention, or concat) → SR Backbone (EDSR, RCAN) → HR Output
                        │
LR2 → Conv Block 2 → F2 ┘
'''

'LR1 → Conv Block 1 → F1 ┐\n                        │\n                        ├─ Fusion Block (ESA, attention, or concat) → SR Backbone (EDSR, RCAN) → HR Output\n                        │\nLR2 → Conv Block 2 → F2 ┘\n'

In [None]:
pip install torch torchvision


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn

class InitialConvBlock(nn.Module):
    """Shallow feature extractor applied to one low-resolution input.

    Runs a single 3x3 convolution (padding 1, so height and width are
    unchanged) followed by an in-place ReLU.

    Args:
        in_channels: Number of channels in the incoming tensor.
        out_channels: Number of feature maps produced.
    """

    def __init__(self, in_channels=1, out_channels=64):
        super().__init__()
        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        activation = nn.ReLU(inplace=True)
        # Stored as a Sequential named ``conv`` so the parameter names remain
        # ``conv.0.weight`` / ``conv.0.bias``.
        self.conv = nn.Sequential(conv_layer, activation)

    def forward(self, x):
        """Return the activated feature map computed from ``x``."""
        features = self.conv(x)
        return features


In [None]:
class FusionBlock(nn.Module):
    """Merge two feature maps by channel concatenation and a 1x1 convolution.

    Args:
        in_channels: Total channel count of the two inputs once concatenated.
        out_channels: Channel count of the fused feature map.
    """

    def __init__(self, in_channels=128, out_channels=64):
        super().__init__()
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.fuse = nn.Sequential(pointwise, nn.ReLU(inplace=True))

    def forward(self, f1, f2):
        """Stack ``f1`` and ``f2`` along dim 1, then project and activate."""
        stacked = torch.cat((f1, f2), dim=1)
        return self.fuse(stacked)


In [None]:
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv branch added back onto its own input.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    output has the same shape as the input.

    Args:
        channels: Channel count of the input (and output) feature map.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.block = nn.Sequential(conv3x3(), nn.ReLU(inplace=True), conv3x3())

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual


In [None]:
class EDSRBackbone(nn.Module):
    """A plain stack of ``ResidualBlock`` modules applied one after another.

    Only the stacked blocks are applied; there is no additional skip
    connection around the stack as a whole.

    Args:
        channels: Channel count passed to every residual block.
        num_blocks: How many residual blocks to chain.
    """

    def __init__(self, channels=64, num_blocks=8):
        super().__init__()
        stages = []
        for _ in range(num_blocks):
            stages.append(ResidualBlock(channels))
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through every residual block in order."""
        refined = self.blocks(x)
        return refined


In [None]:
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling head that turns features into the output image.

    A 3x3 convolution expands the channels by ``scale**2``, ``PixelShuffle``
    rearranges them into a ``scale``-times larger spatial grid, and a final
    3x3 convolution projects to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming feature map.
        scale: Spatial upscaling factor applied to height and width.
        out_channels: Channel count of the produced image. Defaults to 1
            (single-channel output), which is what the block always produced
            before this parameter existed.
    """

    def __init__(self, in_channels=64, scale=4, out_channels=1):
        super().__init__()
        self.upsample = nn.Sequential(
            # Expand channels so PixelShuffle can trade them for resolution.
            nn.Conv2d(in_channels, in_channels * scale**2, kernel_size=3, padding=1),
            # (C * scale^2, H, W) -> (C, H * scale, W * scale)
            nn.PixelShuffle(scale),
            # Project the upsampled features to the image channel count.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return the upsampled image for feature map ``x``."""
        return self.upsample(x)


In [None]:
class DualImageSR(nn.Module):
    """Two-input super-resolution network.

    Each low-resolution input goes through its own ``InitialConvBlock``; the
    two feature maps are merged by ``FusionBlock``, refined by
    ``EDSRBackbone`` and upsampled by ``UpsampleBlock``.

    Args:
        in_channels: Channel count of each low-resolution input.
        scale: Upscaling factor handed to the upsampling head.
        num_features: Width of every internal feature map. Defaults to 64,
            the value that was previously hard-coded.
        num_blocks: Number of residual blocks in the backbone. Defaults to 8,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels=1, scale=4, num_features=64, num_blocks=8):
        super().__init__()
        self.conv1 = InitialConvBlock(in_channels, num_features)  # For LR1
        self.conv2 = InitialConvBlock(in_channels, num_features)  # For LR2
        # The fusion block sees both feature maps concatenated on the channel axis.
        self.fusion = FusionBlock(in_channels=2 * num_features, out_channels=num_features)
        self.backbone = EDSRBackbone(channels=num_features, num_blocks=num_blocks)
        self.upsample = UpsampleBlock(in_channels=num_features, scale=scale)

    def forward(self, lr1, lr2):
        """Return the super-resolved image computed from ``lr1`` and ``lr2``.

        Both inputs are concatenated along dim 1 after feature extraction, so
        they must agree in every other dimension.
        """
        f1 = self.conv1(lr1)            # Feature from LR1
        f2 = self.conv2(lr2)            # Feature from LR2
        fused = self.fusion(f1, f2)     # Fusion of both
        refined = self.backbone(fused)  # Deep enhancement
        return self.upsample(refined)   # Upsample to HR


In [None]:
# Shape smoke test: two 64x64 single-channel inputs must come out as one
# 256x256 single-channel image at scale 4.
model = DualImageSR(in_channels=1, scale=4)

lr1 = torch.randn(1, 1, 64, 64)  # Example LR1
lr2 = torch.randn(1, 1, 64, 64)  # Example LR2

# No gradients are needed for a shape check.
with torch.no_grad():
    output = model(lr1, lr2)

assert output.shape == (1, 1, 256, 256), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)


torch.Size([1, 1, 256, 256])


In [None]:
!unzip dual_sr_dataset.zip -d /content/dual_sr_dataset

Archive:  dual_sr_dataset.zip
   creating: /content/dual_sr_dataset/dual_sr_dataset/
   creating: /content/dual_sr_dataset/dual_sr_dataset/train/
   creating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0594_LR0.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0594_LR1.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0595_LR0.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0595_LR1.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0597_LR1.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0598_LR0.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0599_LR1.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgset0600_LR0.png  
  inflating: /content/dual_sr_dataset/dual_sr_dataset/train/low_res/imgse

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class DualImageSRDataset(Dataset):
    """Paired low-resolution / high-resolution samples for dual-image SR.

    ``root_dir`` must contain a ``low_res`` folder holding ``<id>_LR0.png`` and
    ``<id>_LR1.png`` and a ``high_res`` folder holding ``<id>_HR.png``. Only
    ids for which all three files are present are indexed; ids with a missing
    file are skipped (with a warning) instead of failing later inside
    ``__getitem__``.

    Args:
        root_dir: Directory containing the ``low_res`` and ``high_res`` folders.
        hr_image_size: ``(height, width)`` the default HR transform resizes to;
            the LR inputs are resized to this size divided by ``scale``.
        transform: Optional replacement for the HR transform. It is applied to
            the HR image only; the LR images always use the built-in
            resize + ``ToTensor`` pipeline.
        scale: Integer ratio between HR and LR sizes. Defaults to 4.
    """

    def __init__(self, root_dir, hr_image_size=(256, 256), transform=None, scale=4):
        import warnings  # local import: only needed for the skip notice below

        self.lr_dir = os.path.join(root_dir, "low_res")
        self.hr_dir = os.path.join(root_dir, "high_res")

        lr_names = os.listdir(self.lr_dir)
        self.filenames = self._find_complete_samples(lr_names, os.listdir(self.hr_dir))

        num_candidates = sum(1 for name in lr_names if name.endswith("_LR0.png"))
        num_skipped = num_candidates - len(self.filenames)
        if num_skipped:
            warnings.warn(
                f"{num_skipped} sample(s) under {root_dir} skipped: missing _LR1.png or _HR.png"
            )

        # `is None` rather than truthiness, so a caller-supplied transform is
        # never silently replaced.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(hr_image_size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor()
            ])
        self.transform = transform

        lr_image_size = (hr_image_size[0] // scale, hr_image_size[1] // scale)
        self.lr_transform = transforms.Compose([
            transforms.Resize(lr_image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()
        ])

    @staticmethod
    def _find_complete_samples(lr_names, hr_names):
        """Return the sorted ids that have LR0, LR1 and HR files."""
        lr_set, hr_set = set(lr_names), set(hr_names)
        suffix = "_LR0.png"
        candidates = (name[:-len(suffix)] for name in lr_set if name.endswith(suffix))
        return sorted(
            base for base in candidates
            if f"{base}_LR1.png" in lr_set and f"{base}_HR.png" in hr_set
        )

    @staticmethod
    def _load_gray(path):
        """Open ``path``, return a grayscale copy and close the file handle."""
        with Image.open(path) as img:
            # convert() reads the pixel data into a new image, so the result
            # stays usable after the file is closed.
            return img.convert('L')

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return a dict with tensors ``lr1``, ``lr2`` and ``hr`` for sample ``idx``."""
        base = self.filenames[idx]

        # Paths
        lr1_path = os.path.join(self.lr_dir, f"{base}_LR0.png")
        lr2_path = os.path.join(self.lr_dir, f"{base}_LR1.png")
        hr_path  = os.path.join(self.hr_dir,  f"{base}_HR.png")

        # Load and transform
        lr1 = self.lr_transform(self._load_gray(lr1_path))
        lr2 = self.lr_transform(self._load_gray(lr2_path))
        hr  = self.transform(self._load_gray(hr_path))

        return {
            "lr1": lr1,
            "lr2": lr2,
            "hr": hr
        }

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Display which device was selected.
device

device(type='cpu')

In [None]:
import torch.nn.functional as F
import math

def calculate_psnr(sr, hr, max_val=1.0):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    The mean squared error is taken over every element of the inputs, so a
    batch yields one value for the whole batch rather than one per image.

    Args:
        sr: Predicted tensor.
        hr: Reference tensor with the same shape as ``sr``.
        max_val: Largest value a pixel can take (1.0 for images scaled to [0, 1]).

    Returns:
        The PSNR as a Python float. Identical inputs (zero error) return the
        sentinel ``100.0`` rather than infinity.
    """
    # Convert to a plain float once so the comparison and the math below do
    # not operate on a tensor.
    mse = F.mse_loss(sr, hr).item()
    if mse == 0:
        return 100.0
    return 20 * math.log10(max_val / math.sqrt(mse))


In [None]:
from torch.utils.data import DataLoader


# --- Settings -----------------------------------------------------------------
DATA_ROOT = "/content/dual_sr_dataset/dual_sr_dataset"
NUM_EPOCHS = 50
TRAIN_BATCH_SIZE = 4
LEARNING_RATE = 1e-4
BEST_CKPT_PATH = "checkpoints/best_model.pt"

train_ds = DualImageSRDataset(f"{DATA_ROOT}/train")
test_ds  = DualImageSRDataset(f"{DATA_ROOT}/test")

train_loader = DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)


model = DualImageSR(scale=4).to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Metrics are recorded every epoch so the full curves remain available after
# the loop; the best weights are saved as soon as the test loss improves.
history = {"train_loss": [], "test_loss": [], "psnr": []}
best_test_loss = float("inf")
os.makedirs(os.path.dirname(BEST_CKPT_PATH), exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    # -train
    model.train()
    total_train_loss = 0

    for batch in train_loader:
        lr1, lr2, hr = batch['lr1'].to(device), batch['lr2'].to(device), batch['hr'].to(device)

        sr = model(lr1, lr2)
        loss = criterion(sr, hr)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    # -eval
    model.eval()
    total_psnr = 0
    total_test_loss = 0

    with torch.no_grad():
        for batch in test_loader:
            lr1, lr2, hr = batch['lr1'].to(device), batch['lr2'].to(device), batch['hr'].to(device)
            sr = model(lr1, lr2)
            loss = criterion(sr, hr)
            psnr = calculate_psnr(sr, hr)

            total_test_loss += loss.item()
            total_psnr += psnr

    avg_test_loss = total_test_loss / len(test_loader)
    avg_psnr = total_psnr / len(test_loader)

    history["train_loss"].append(avg_train_loss)
    history["test_loss"].append(avg_test_loss)
    history["psnr"].append(avg_psnr)

    # Keep the weights of the epoch with the lowest test loss seen so far.
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        torch.save(model.state_dict(), BEST_CKPT_PATH)

    print(f"[Epoch {epoch}] Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | PSNR: {avg_psnr:.2f} dB")

[Epoch 1] Train Loss: 0.1035 | Test Loss: 0.0815 | PSNR: 20.91 dB
[Epoch 2] Train Loss: 0.0841 | Test Loss: 0.0778 | PSNR: 21.51 dB
[Epoch 3] Train Loss: 0.0804 | Test Loss: 0.0743 | PSNR: 21.94 dB
[Epoch 4] Train Loss: 0.0789 | Test Loss: 0.0815 | PSNR: 21.37 dB
[Epoch 5] Train Loss: 0.0762 | Test Loss: 0.0713 | PSNR: 22.33 dB
[Epoch 6] Train Loss: 0.0758 | Test Loss: 0.0737 | PSNR: 22.10 dB
[Epoch 7] Train Loss: 0.0756 | Test Loss: 0.0771 | PSNR: 21.76 dB
[Epoch 8] Train Loss: 0.0745 | Test Loss: 0.0724 | PSNR: 22.24 dB
[Epoch 9] Train Loss: 0.0754 | Test Loss: 0.0729 | PSNR: 22.25 dB
[Epoch 10] Train Loss: 0.0733 | Test Loss: 0.0703 | PSNR: 22.56 dB
[Epoch 11] Train Loss: 0.0737 | Test Loss: 0.0702 | PSNR: 22.52 dB
[Epoch 12] Train Loss: 0.0744 | Test Loss: 0.0765 | PSNR: 21.91 dB
[Epoch 13] Train Loss: 0.0726 | Test Loss: 0.0757 | PSNR: 21.92 dB
[Epoch 14] Train Loss: 0.0745 | Test Loss: 0.0699 | PSNR: 22.54 dB
[Epoch 15] Train Loss: 0.0725 | Test Loss: 0.0703 | PSNR: 22.52 dB
[Epo

In [None]:
# Create the checkpoint folder; exist_ok=True makes re-running this cell a no-op.
os.makedirs("checkpoints", exist_ok=True)


In [None]:
# NOTE(review): this cell runs once, outside the training loop. Because
# `best_test_loss` is set to infinity immediately before the comparison, the
# branch is taken for any non-NaN `avg_test_loss`, so the file written holds
# the weights the model has at this point (the values of `epoch` and
# `avg_test_loss` are whatever the training loop left behind), not the best
# epoch observed during training.
best_test_loss = float('inf')


if avg_test_loss < best_test_loss:
    best_test_loss = avg_test_loss
    torch.save(model.state_dict(), f"checkpoints/best_model_epoch_{epoch}.pt")
    print(f" Model saved at epoch {epoch} with test loss {avg_test_loss:.4f}")


✅ Model saved at epoch 50 with test loss 0.0685


In [None]:
from torchvision.utils import save_image

os.makedirs("visuals", exist_ok=True)

NUM_VISUAL_SAMPLES = 5

# Save SR / HR pairs for the first few test samples. Each batch is unpacked
# inside the loop so every saved pair comes from its own sample, and the loop
# stops once enough images are written instead of running the whole test set.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        if i >= NUM_VISUAL_SAMPLES:
            break

        lr1, lr2, hr = batch['lr1'].to(device), batch['lr2'].to(device), batch['hr'].to(device)
        sr = model(lr1, lr2)

        save_image(sr, f"visuals/sr_{epoch}_sample_{i}.png")
        save_image(hr, f"visuals/hr_{epoch}_sample_{i}.png")


In [None]:
# NOTE(review): the lists are created empty and appended to exactly once, so
# each ends up with a single element: the value of `avg_train_loss`,
# `avg_test_loss` and `avg_psnr` left over from the last epoch the training
# loop executed. Per-epoch curves would require appending inside that loop.
train_losses = []
test_losses = []
psnrs = []

train_losses.append(avg_train_loss)
test_losses.append(avg_test_loss)
psnrs.append(avg_psnr)


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class BlindTestDataset(Dataset):
    """Low-resolution image pairs without ground truth, for blind inference.

    ``low_res_dir`` is scanned for ``<id>_LR0.png`` / ``<id>_LR1.png`` pairs.
    Only ids that have both files are indexed; other files in the folder and
    ids with a single frame are skipped (with a warning) instead of failing
    later inside ``__getitem__``.

    Args:
        low_res_dir: Folder containing the LR0 / LR1 images.
        transform: Callable applied to each grayscale PIL image. Defaults to
            ``transforms.ToTensor()``.
    """

    def __init__(self, low_res_dir, transform=None):
        import warnings  # local import: only needed for the skip notice below

        self.low_res_dir = low_res_dir
        self.transform = transforms.ToTensor() if transform is None else transform

        names = os.listdir(low_res_dir)
        self.filenames = self._paired_ids(names)

        # Ids that appear with at least one of the two suffixes.
        seen_ids = {
            name[:-len("_LR0.png")]
            for name in names
            if name.endswith("_LR0.png") or name.endswith("_LR1.png")
        }
        num_skipped = len(seen_ids) - len(self.filenames)
        if num_skipped:
            warnings.warn(
                f"{num_skipped} id(s) in {low_res_dir} skipped: missing _LR0.png or _LR1.png"
            )

    @staticmethod
    def _paired_ids(names):
        """Return the sorted ids for which both LR0 and LR1 files are listed."""
        available = set(names)
        suffix = "_LR0.png"
        candidates = (name[:-len(suffix)] for name in available if name.endswith(suffix))
        return sorted(base for base in candidates if f"{base}_LR1.png" in available)

    @staticmethod
    def _load_gray(path):
        """Open ``path``, return a grayscale copy and close the file handle."""
        with Image.open(path) as img:
            # convert() reads the pixel data into a new image, so the result
            # stays usable after the file is closed.
            return img.convert('L')

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return a dict with ``lr1``, ``lr2`` and the sample ``id``."""
        base = self.filenames[idx]

        lr1 = self.transform(self._load_gray(os.path.join(self.low_res_dir, f"{base}_LR0.png")))
        lr2 = self.transform(self._load_gray(os.path.join(self.low_res_dir, f"{base}_LR1.png")))

        return {
            "lr1": lr1,
            "lr2": lr2,
            "id": base
        }


In [None]:
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch
import os


# Choose the device once and use it for the model, the checkpoint and the data.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DualImageSR(scale=4).to(device)
# map_location lets a checkpoint saved on a GPU runtime load on a CPU-only one.
model.load_state_dict(torch.load("checkpoints/best_model_epoch_50.pt", map_location=device))
model.eval()

# Create output folder
os.makedirs("blind_sr_outputs", exist_ok=True)

# Dataset and loader
blind_ds = BlindTestDataset("/content/dual_sr_dataset/dual_sr_dataset/blind_test_set")
blind_loader = DataLoader(blind_ds, batch_size=1, shuffle=False)

# Inference loop
with torch.no_grad():
    for batch in blind_loader:
        lr1 = batch['lr1'].to(device)
        lr2 = batch['lr2'].to(device)
        img_id = batch['id'][0]  # batch_size=1, so the id list has one entry

        sr = model(lr1, lr2)

        # Save image
        save_image(sr, f"blind_sr_outputs/{img_id}_SR.png")

        print(f"Saved: {img_id}_SR.png")

Saved: imgset1160_SR.png
Saved: imgset1161_SR.png
Saved: imgset1162_SR.png
Saved: imgset1163_SR.png
Saved: imgset1164_SR.png
Saved: imgset1165_SR.png
Saved: imgset1166_SR.png
Saved: imgset1167_SR.png
Saved: imgset1168_SR.png
Saved: imgset1169_SR.png
Saved: imgset1170_SR.png
Saved: imgset1171_SR.png
Saved: imgset1172_SR.png
Saved: imgset1173_SR.png
Saved: imgset1174_SR.png
Saved: imgset1175_SR.png
Saved: imgset1176_SR.png
Saved: imgset1177_SR.png
Saved: imgset1178_SR.png
Saved: imgset1179_SR.png
Saved: imgset1180_SR.png
Saved: imgset1181_SR.png
Saved: imgset1182_SR.png
Saved: imgset1183_SR.png
Saved: imgset1184_SR.png
Saved: imgset1185_SR.png
Saved: imgset1186_SR.png
Saved: imgset1187_SR.png
Saved: imgset1188_SR.png
Saved: imgset1189_SR.png
Saved: imgset1190_SR.png
Saved: imgset1191_SR.png
Saved: imgset1192_SR.png
Saved: imgset1193_SR.png
Saved: imgset1194_SR.png
Saved: imgset1195_SR.png
Saved: imgset1196_SR.png
Saved: imgset1197_SR.png
Saved: imgset1198_SR.png
Saved: imgset1199_SR.png


In [None]:


import shutil

from google.colab import files


# Build the archive from scratch on every run. `zip -r` updates an existing
# archive in place, so entries added by earlier runs would otherwise remain.
# root_dir / base_dir are chosen so that entries keep the
# `content/blind_sr_outputs/...` paths the zip command produced.
archive_path = shutil.make_archive(
    "/content/blind_sr_outputs", "zip", root_dir="/", base_dir="content/blind_sr_outputs"
)


files.download(archive_path)


  adding: content/blind_sr_outputs/ (stored 0%)
  adding: content/blind_sr_outputs/imgset1306_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1262_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1287_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1328_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1427_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1174_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1268_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1393_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1264_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1297_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1448_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1278_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1374_SR.png (deflated 0%)
  adding: content/blind_sr_outputs/imgset1196_SR.png (deflated 0%)
  adding: cont

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>