In [1]:
import os
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
def create_lr_image(hr_image, scale_factor=2, degradation='bicubic'):
    """Downsample a high-resolution image to build its low-resolution counterpart.

    Args:
        hr_image: Image array whose first two axes are height and width. Both
            2-D (grayscale) and 3-D (with a channel axis) arrays are accepted.
        scale_factor: Integer downscaling factor (>= 1). The target size is
            ``(height // scale_factor, width // scale_factor)``.
        degradation: ``'bicubic'`` selects ``cv2.INTER_CUBIC``; any other value
            selects ``cv2.INTER_AREA``.

    Returns:
        The array produced by ``cv2.resize`` for the target size.

    Raises:
        ValueError: If ``scale_factor`` is smaller than 1, or the image is so
            small that the downsampled height or width would be zero.
    """
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    # Only the spatial axes matter here, so a missing channel axis is fine.
    h, w = hr_image.shape[:2]
    new_h, new_w = h // scale_factor, w // scale_factor
    if new_h == 0 or new_w == 0:
        raise ValueError(
            f"Image of size {h}x{w} is too small to downsample by a factor of {scale_factor}"
        )

    interpolation = cv2.INTER_CUBIC if degradation == 'bicubic' else cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(hr_image, (new_w, new_h), interpolation=interpolation)
def img_to_tensor(img):
    """Turn a channel-last image array into a channel-first float32 tensor.

    The last axis of ``img`` is moved to the front (axes ``(2, 0, 1)``), the
    values are cast to float32 and divided by 255.0, and the result is wrapped
    with ``torch.from_numpy``.
    """
    channel_first = np.transpose(img, (2, 0, 1))
    return torch.from_numpy(channel_first.astype(np.float32) / 255.0)
class SuperResolutionDataset(Dataset):
    """Dataset of random (low-resolution, high-resolution) patch pairs.

    Every item is built from one image file in ``hr_dir``: a random
    ``patch_size`` x ``patch_size`` crop is taken from the image and a
    low-resolution version of that crop is produced with ``create_lr_image``.

    Args:
        hr_dir: Directory containing the high-resolution images.
        patch_size: Side length of the square high-resolution crop.
        scale_factor: Downscaling factor passed to ``create_lr_image``.
        transform: Optional callable invoked as ``transform(hr_patch, lr_patch)``
            that must return the (possibly modified) ``(hr_patch, lr_patch)`` pair.
    """

    # Extensions are matched case-insensitively and must include the dot.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self,
                 hr_dir,
                 patch_size=128,
                 scale_factor=2,
                 transform=None):
        self.hr_dir = hr_dir
        self.patch_size = patch_size
        self.scale_factor = scale_factor
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs.
        self.image_files = sorted(
            os.path.join(hr_dir, f)
            for f in os.listdir(hr_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``hr_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)`` for a random crop of image ``idx``.

        Raises:
            OSError: If the image file cannot be read.
        """
        hr_path = self.image_files[idx]
        hr_bgr = cv2.imread(hr_path)  # BGR format
        if hr_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read image: {hr_path}")
        hr_rgb = cv2.cvtColor(hr_bgr, cv2.COLOR_BGR2RGB)

        h, w = hr_rgb.shape[:2]
        if h < self.patch_size or w < self.patch_size:
            # Images smaller than one patch are stretched to exactly one patch
            # (the aspect ratio is not preserved).
            hr_rgb = cv2.resize(hr_rgb, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)
            h, w = hr_rgb.shape[:2]

        # random.randint is inclusive on both ends, so the crop always fits.
        top = random.randint(0, h - self.patch_size)
        left = random.randint(0, w - self.patch_size)
        hr_patch = hr_rgb[top:top + self.patch_size, left:left + self.patch_size, :]
        lr_patch = create_lr_image(hr_patch, scale_factor=self.scale_factor)

        if self.transform:
            hr_patch, lr_patch = self.transform(hr_patch, lr_patch)

        hr_tensor = img_to_tensor(hr_patch)
        lr_tensor = img_to_tensor(lr_patch)
        return lr_tensor, hr_tensor

In [2]:
class SimpleSRNet(nn.Module):
    """Small convolutional super-resolution network using pixel shuffle.

    Pipeline: two 3x3 conv + ReLU layers, a 3x3 conv that widens the features
    to ``base_channels * scale_factor**2`` channels, ``nn.PixelShuffle`` to
    trade those channels for spatial resolution, a ReLU, and a final 3x3 conv
    back to ``num_channels``.

    Args:
        scale_factor: Upscale factor handed to ``nn.PixelShuffle``.
        num_channels: Channel count of both the input and the output.
        base_channels: Width of the intermediate feature maps.
    """

    def __init__(self, scale_factor=2, num_channels=3, base_channels=64):
        super().__init__()
        self.scale_factor = scale_factor

        # Feature extraction at the input resolution.
        feature_layers = [
            nn.Conv2d(num_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*feature_layers)

        # PixelShuffle folds scale_factor**2 channel groups into space, so the
        # preceding convolution has to produce that many times more channels.
        shuffle_channels = base_channels * (scale_factor ** 2)
        self.conv2 = nn.Conv2d(base_channels, shuffle_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu2 = nn.ReLU(inplace=True)

        # Projection back to image channels.
        self.conv3 = nn.Conv2d(base_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a batch of low-resolution images to their upscaled predictions."""
        features = self.conv1(x)
        expanded = self.conv2(features)
        shuffled = self.relu2(self.pixel_shuffle(expanded))
        return self.conv3(shuffled)

In [3]:
import argparse
import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch import nn, optim

# SuperResolutionDataset and SimpleSRNet are defined in the earlier cells of this notebook (In [1] and In [2]); run those cells first.

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss per sample.

    When ``optimizer`` is given, a gradient step is taken after every batch;
    otherwise the batches are only evaluated. The caller is responsible for
    putting the model in train/eval mode and for disabling gradients during
    evaluation.
    """
    total_loss = 0.0
    total_samples = 0
    for lr_patches, hr_patches in loader:
        lr_patches, hr_patches = lr_patches.to(device), hr_patches.to(device)

        sr_patches = model(lr_patches)
        loss = criterion(sr_patches, hr_patches)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight each batch by its size so a short final batch does not count
        # as much as a full one.
        batch_samples = hr_patches.size(0)
        total_loss += loss.item() * batch_samples
        total_samples += batch_samples
    return total_loss / total_samples


def train_super_resolution():
    """Train ``SimpleSRNet`` on random patches and save its weights.

    Options are read from the command line (unrecognised arguments are
    ignored); each has a default. The model is trained with MSE loss and Adam,
    train and validation losses are printed after every epoch, and the final
    weights are written to ``sr_model.pth`` in the current directory.

    Raises:
        ValueError: If ``patch_size`` is not a multiple of ``scale_factor``, or
            if either image directory contains no usable images.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--hr_dir',
        type=str,
        default='/kaggle/input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR',
        help='Path to high-resolution training images'
    )
    parser.add_argument(
        '--val_dir',
        type=str,
        default='/kaggle/input/div2k-high-resolution-images/DIV2K_valid_HR/DIV2K_valid_HR',
        help='Path to high-resolution validation images'
    )
    parser.add_argument('--patch_size', type=int, default=128, help='Training patch size')
    parser.add_argument('--scale_factor', type=int, default=4, help='SR scale factor')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')

    # parse_known_args skips arguments it does not recognise, e.g. the ones a
    # notebook kernel is started with.
    args, _unknown = parser.parse_known_args()

    # The LR patch is patch_size // scale_factor pixels wide and the network
    # multiplies that by scale_factor again, so the prediction only matches the
    # HR patch when the division is exact.
    if args.patch_size % args.scale_factor != 0:
        raise ValueError(
            f"patch_size ({args.patch_size}) must be a multiple of "
            f"scale_factor ({args.scale_factor})"
        )

    # Create datasets
    train_dataset = SuperResolutionDataset(
        hr_dir=args.hr_dir,
        patch_size=args.patch_size,
        scale_factor=args.scale_factor,
    )
    val_dataset = SuperResolutionDataset(
        hr_dir=args.val_dir,
        patch_size=args.patch_size,
        scale_factor=args.scale_factor,
    )
    for split_name, dataset, directory in (
        ('training', train_dataset, args.hr_dir),
        ('validation', val_dataset, args.val_dir),
    ):
        # An empty split would otherwise surface later as a division by zero
        # when the epoch loss is averaged.
        if len(dataset) == 0:
            raise ValueError(f"No {split_name} images found in {directory}")

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleSRNet(scale_factor=args.scale_factor)
    model.to(device)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.epochs):
        model.train()
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        train_loss = _run_epoch(model, progress, criterion, device, optimizer=optimizer)

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{args.epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "sr_model.pth")
    print("Model saved to sr_model.pth")


if __name__ == "__main__":
    # sys.argv is deliberately left untouched: train_super_resolution() parses
    # its options with parse_known_args(), which already skips arguments it does
    # not recognise (such as the ones a notebook kernel is launched with), so
    # real command-line options still take effect when this runs as a script.
    train_super_resolution()


Epoch 1/10: 100%|██████████| 100/100 [00:47<00:00,  2.09it/s]


Epoch [1/10] - Train Loss: 0.0617, Val Loss: 0.0190


Epoch 2/10: 100%|██████████| 100/100 [00:41<00:00,  2.44it/s]


Epoch [2/10] - Train Loss: 0.0122, Val Loss: 0.0096


Epoch 3/10: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s]


Epoch [3/10] - Train Loss: 0.0074, Val Loss: 0.0073


Epoch 4/10: 100%|██████████| 100/100 [00:40<00:00,  2.47it/s]


Epoch [4/10] - Train Loss: 0.0059, Val Loss: 0.0060


Epoch 5/10: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s]


Epoch [5/10] - Train Loss: 0.0055, Val Loss: 0.0052


Epoch 6/10: 100%|██████████| 100/100 [00:39<00:00,  2.51it/s]


Epoch [6/10] - Train Loss: 0.0052, Val Loss: 0.0055


Epoch 7/10: 100%|██████████| 100/100 [00:41<00:00,  2.43it/s]


Epoch [7/10] - Train Loss: 0.0051, Val Loss: 0.0051


Epoch 8/10: 100%|██████████| 100/100 [00:40<00:00,  2.46it/s]


Epoch [8/10] - Train Loss: 0.0049, Val Loss: 0.0056


Epoch 9/10: 100%|██████████| 100/100 [00:40<00:00,  2.48it/s]


Epoch [9/10] - Train Loss: 0.0047, Val Loss: 0.0054


Epoch 10/10: 100%|██████████| 100/100 [00:39<00:00,  2.50it/s]


Epoch [10/10] - Train Loss: 0.0048, Val Loss: 0.0050
Model saved to sr_model.pth


In [4]:
def stitch_patches(patches, patch_coords, out_h, out_w):
    """Assemble overlapping patches into one image by averaging the overlaps.

    Args:
        patches: Iterable of arrays shaped (patch_h, patch_w, channels).
        patch_coords: Iterable of ``(top, left)`` positions, one per patch,
            giving where each patch's upper-left corner goes on the canvas.
        out_h: Height of the output canvas.
        out_w: Width of the output canvas.

    Returns:
        float32 array of shape (out_h, out_w, channels). ``channels`` is taken
        from the first patch and is 3 when there are no patches. Pixels covered
        by several patches hold their mean; uncovered pixels are 0.
    """
    patches = list(patches)
    channels = patches[0].shape[2] if patches else 3

    stitched_image = np.zeros((out_h, out_w, channels), dtype=np.float32)
    # Coverage is the same for every channel, so one channel is enough; it
    # broadcasts across the channel axis in the final division.
    weight_map = np.zeros((out_h, out_w, 1), dtype=np.float32)

    for patch, (top, left) in zip(patches, patch_coords):
        # Crop whatever would fall outside the canvas instead of failing on a
        # shape mismatch.
        ph = min(patch.shape[0], out_h - top)
        pw = min(patch.shape[1], out_w - left)
        if ph <= 0 or pw <= 0:
            continue
        stitched_image[top:top + ph, left:left + pw, :] += patch[:ph, :pw, :]
        weight_map[top:top + ph, left:left + pw, :] += 1.0

    # The floor keeps the division defined for pixels no patch covered.
    stitched_image /= np.maximum(weight_map, 1e-8)
    return stitched_image

def _tile_starts(length, patch_size, step):
    """Return tile start offsets along one axis so that [0, length) is covered.

    Tiles are placed every ``step`` pixels. The last tile is shifted back so it
    ends exactly at ``length`` instead of being cut short at the border.
    """
    if length <= patch_size:
        return [0]
    starts = list(range(0, length - patch_size + 1, step))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)
    return starts


def inference_large_image(model_path, lr_image_path, scale_factor=2, patch_size=128, overlap=16,
                          out_path='sr_result.png'):
    """Super-resolve an image tile by tile and write the result to disk.

    The low-resolution image is cut into overlapping ``patch_size`` tiles, each
    tile is passed through ``SimpleSRNet``, and the upscaled tiles are merged
    with ``stitch_patches``.

    Args:
        model_path: Path to a ``SimpleSRNet`` state dict.
        lr_image_path: Path to the low-resolution input image.
        scale_factor: Upscale factor used to build the model; it has to match
            the one the weights were trained with.
        patch_size: Tile side length in low-resolution pixels.
        overlap: Number of low-resolution pixels shared by neighbouring tiles.
        out_path: Where the super-resolved image is written.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``overlap`` is not in
            ``[0, patch_size)``.
        OSError: If the input image cannot be read or the output cannot be written.
    """
    if patch_size <= 0 or not 0 <= overlap < patch_size:
        raise ValueError(
            f"Need patch_size > 0 and 0 <= overlap < patch_size, "
            f"got patch_size={patch_size}, overlap={overlap}"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SimpleSRNet(scale_factor=scale_factor)
    # The checkpoint is a plain state dict, so restrict unpickling to tensors.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    lr_bgr = cv2.imread(lr_image_path)
    if lr_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Failed to read image: {lr_image_path}")
    lr_rgb = cv2.cvtColor(lr_bgr, cv2.COLOR_BGR2RGB)
    lr_h, lr_w = lr_rgb.shape[:2]

    hr_h, hr_w = lr_h * scale_factor, lr_w * scale_factor

    step = patch_size - overlap
    tops = _tile_starts(lr_h, patch_size, step)
    lefts = _tile_starts(lr_w, patch_size, step)

    patches = []
    patch_coords = []
    with torch.no_grad():
        for top in tops:
            for left in lefts:
                # Slicing clips at the image border when the image is smaller
                # than one tile.
                lr_patch = lr_rgb[top:top + patch_size, left:left + patch_size, :]
                lr_patch_t = img_to_tensor(lr_patch).unsqueeze(0).to(device)
                sr_patch_t = model(lr_patch_t)

                sr_patch = sr_patch_t.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # [H*scale, W*scale, 3]
                sr_patch = np.clip(sr_patch, 0.0, 1.0)

                patches.append(sr_patch)
                patch_coords.append((top * scale_factor, left * scale_factor))

    sr_image = stitch_patches(patches, patch_coords, hr_h, hr_w)

    # Values are in [0, 1]; adding 0.5 before the cast rounds to the nearest
    # level instead of always truncating downwards.
    sr_image_8u = (sr_image * 255.0 + 0.5).astype(np.uint8)

    result_bgr = cv2.cvtColor(sr_image_8u, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(out_path, result_bgr):
        raise OSError(f"Failed to write image: {out_path}")
    print(f"Super-resolved image saved to {out_path}")


if __name__ == "__main__":
    # Upscale one sample image with the weights saved by the training cell.
    # scale_factor has to agree with the value used for training (default 4
    # there), because SimpleSRNet sizes its upsampling convolution from it.
    inference_large_image(
        "sr_model.pth",
        "/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X2/000001x2.png",
        overlap=8,
        patch_size=64,
        scale_factor=4,
    )


  model.load_state_dict(torch.load(model_path, map_location=device))


Super-resolved image saved to sr_result.png
