In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively list the Kaggle input tree, printing one full file path per line.
for dirname, _, filenames in os.walk('/kaggle/input'):
    full_paths = [os.path.join(dirname, filename) for filename in filenames]
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision.transforms import ToTensor, ToPILImage, Compose, Normalize, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip
from PIL import Image
import os
import glob
import numpy as np
import pandas as pd 
import math
from tqdm import tqdm
import cv2 
import sys
import time
import subprocess 

In [2]:
# Everything this notebook downloads is stored under the writable /kaggle/working/ tree.
WORKING_DIR = "/kaggle/working/"

# One sub-directory for the pre-trained checkpoint, one for the SwinIR source file
WEIGHTS_DIR = os.path.join(WORKING_DIR, "weights")
SWINIR_CODE_DIR = os.path.join(WORKING_DIR, "swinir_code")

# Full paths of the two files fetched by the download cells
WEIGHTS_FILE = os.path.join(WEIGHTS_DIR, "001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth")
SWINIR_MODEL_FILE = os.path.join(SWINIR_CODE_DIR, "network_swinir.py")

In [3]:
# Make sure both download targets exist; exist_ok makes re-running this cell harmless.
for download_dir in (SWINIR_CODE_DIR, WEIGHTS_DIR):
    os.makedirs(download_dir, exist_ok=True)

In [4]:
# Download sources, both from the JingyunLiang/SwinIR GitHub repository:
#   swinir_code_url - the network definition (network_swinir.py) on the main branch
#   weights_url     - the x4 classical-SR checkpoint from the v0.0 release assets
swinir_code_url = "https://raw.githubusercontent.com/JingyunLiang/SwinIR/main/models/network_swinir.py"
weights_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth"


In [5]:
# Download SwinIR model code if not already present.
# wget's -O option creates the output file before any data arrives, so a failed
# download would leave an empty or partial file that the existence check below
# would accept on the next run. The download therefore goes to a temporary
# ".part" path that is moved into place only after wget exits successfully.
if not os.path.exists(SWINIR_MODEL_FILE):
    print(f"Downloading SwinIR code from {swinir_code_url}...")
    swinir_code_tmp = SWINIR_MODEL_FILE + ".part"
    try:
        # Using subprocess for better control and error handling than !wget
        subprocess.run(['wget', '-q', '-O', swinir_code_tmp, swinir_code_url], check=True)
    except subprocess.CalledProcessError as e:
        # Remove whatever wget wrote so it can never be mistaken for a good file
        if os.path.exists(swinir_code_tmp):
            os.remove(swinir_code_tmp)
        print(f"Error downloading SwinIR code: {e}")
        sys.exit(1) # Exit if download fails
    except FileNotFoundError:
        print("Error: 'wget' command not found. Please ensure wget is installed.")
        sys.exit(1)
    else:
        # Atomic rename: the final path only ever holds a complete download
        os.replace(swinir_code_tmp, SWINIR_MODEL_FILE)
        print("SwinIR code downloaded successfully.")
else:
    print("SwinIR code already exists.")

SwinIR code already exists.


In [6]:
# Download pre-trained weights if not already present.
# wget's -O option creates the output file before any data arrives, so an
# interrupted download would leave a truncated checkpoint that the existence
# check below would accept on the next run. The download therefore goes to a
# temporary ".part" path that is moved into place only after wget succeeds.
if not os.path.exists(WEIGHTS_FILE):
    print(f"Downloading SwinIR weights from {weights_url}...")
    weights_tmp = WEIGHTS_FILE + ".part"
    try:
        subprocess.run(['wget', '-q', '-O', weights_tmp, weights_url], check=True)
    except subprocess.CalledProcessError as e:
        # Remove whatever wget wrote so it can never be mistaken for a good file
        if os.path.exists(weights_tmp):
            os.remove(weights_tmp)
        print(f"Error downloading SwinIR weights: {e}")
        sys.exit(1) # Exit if download fails
    except FileNotFoundError:
        print("Error: 'wget' command not found. Please ensure wget is installed.")
        sys.exit(1)
    else:
        # Atomic rename: the final path only ever holds a complete download
        os.replace(weights_tmp, WEIGHTS_FILE)
        print("SwinIR weights downloaded successfully.")
else:
    print("SwinIR weights already exists.")

SwinIR weights already exists.


In [7]:
# Make the downloaded network_swinir.py importable: put its folder on sys.path,
# but only once so re-running the cell does not add duplicates.
code_dir_on_path = SWINIR_CODE_DIR in sys.path
if not code_dir_on_path:
    sys.path.append(SWINIR_CODE_DIR)
    print(f"Added {SWINIR_CODE_DIR} to sys.path")

Added /kaggle/working/swinir_code to sys.path


In [8]:
# --- Import the SwinIR network class from the downloaded source file ---
# The module is found through the sys.path entry added for SWINIR_CODE_DIR.
# Any failure stops the notebook, since nothing below works without the class.
try:
    from network_swinir import SwinIR as ActualSwinIRNet
except ImportError as e:
    print(f"Error importing SwinIR even after download: {e}")
    print("Please check the downloaded file structure and potential dependencies.")
    sys.exit(1)
except Exception as e:
    print(f"An unexpected error occurred during SwinIR import: {e}")
    sys.exit(1)
else:
    print("Successfully imported SwinIR.")


Successfully imported SwinIR.




In [9]:

# Sanity check: compare the resolution of one ground-truth image with its
# low-resolution training counterpart. The files are opened with context
# managers so the file handles are released as soon as the size has been read.
with Image.open('/kaggle/input/dlp-jan-2025-nppe-3/archive/train/gt/gt_00001.png') as image:
    width, height = image.size
print(f"The ground truth image resolution is: {width}x{height}")

with Image.open('/kaggle/input/dlp-jan-2025-nppe-3/archive/train/train/gt_00001.png') as image:
    width, height = image.size
print(f"The train image resolution is: {width}x{height}")


The ground truth image resolution is: 1024x640
The train image resolution is: 256x160


In [11]:
# Central run configuration; every later cell reads its settings from this dict.
CONFIG = {
    "data_root": "/kaggle/input/dlp-jan-2025-nppe-3/archive", # Input data root; the train/, val/ and test/ folders are read from here
    "output_dir": os.path.join(WORKING_DIR, "output"), # Model checkpoints are saved here
    "test_output_dir": os.path.join(WORKING_DIR, "submission_images"), # Super-resolved test images are written here
    # Pre-trained checkpoint location (the file fetched by the download cell)
    "pretrained_weights_path": WEIGHTS_FILE,
    "model_save_name": "swinir_finetuned_model.pth", # File name of the best checkpoint inside output_dir
    "epochs": 30, # Number of fine-tuning epochs; also the cosine schedule's T_max
    "batch_size": 4, # Training batch size (validation and test loaders use batch size 1)
    "learning_rate": 5e-5, # Initial Adam learning rate
    "upscale_factor": 4, # LR -> HR scale; also used as the border cropped before computing PSNR
    "window_size": 8, # Passed to SwinIR as window_size
    "train_image_size": 64, # Side length of the square LR training crop; also SwinIR's img_size
    "num_workers": 2, # Worker processes per DataLoader
    "seed": 42, # Seed applied to torch, numpy and CUDA RNGs
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), # GPU when available, otherwise CPU
}

In [13]:
# Create the checkpoint and submission-image folders (no-op when they already exist)
for dir_key in ("output_dir", "test_output_dir"):
    os.makedirs(CONFIG[dir_key], exist_ok=True)

# Seed every RNG the notebook relies on so runs are repeatable
seed_value = CONFIG["seed"]
torch.manual_seed(seed_value)
np.random.seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [14]:
def calculate_psnr(img1, img2, border=0):
    """Calculates PSNR (in dB) between two images with pixel range 0-255.

    Args:
        img1, img2: Images as numpy arrays or anything ``np.array`` accepts
            (e.g. PIL images). Inputs that are not uint8 are clipped to
            [0, 255] and cast to uint8 before comparison.
        border: Number of pixels stripped from every edge of the first two
            axes before comparing. Values <= 0 disable cropping.

    Returns:
        The PSNR in decibels, or ``float('inf')`` when the compared regions
        are identical.

    Raises:
        ValueError: If the two images differ in shape, or if ``border``
            removes every pixel.
    """
    if not isinstance(img1, np.ndarray): img1 = np.array(img1)
    if not isinstance(img2, np.ndarray): img2 = np.array(img2)
    # Without this check numpy would silently broadcast e.g. HxWx3 against
    # HxWx1 and return a number that does not describe either image.
    if img1.shape != img2.shape:
        raise ValueError(
            f"calculate_psnr: image shapes differ: {img1.shape} vs {img2.shape}"
        )
    if img1.dtype != np.uint8: img1 = img1.clip(0, 255).astype(np.uint8)
    if img2.dtype != np.uint8: img2 = img2.clip(0, 255).astype(np.uint8)
    if border > 0:
        img1 = img1[border:-border, border:-border, ...]
        img2 = img2[border:-border, border:-border, ...]
    # An empty region would make np.mean return NaN (with a warning); fail loudly instead.
    if img1.size == 0:
        raise ValueError(
            f"calculate_psnr: no pixels left to compare (border={border})"
        )
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0: return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

In [15]:
class ImageDataset(Dataset):
    """Paired low-resolution (LR) / high-resolution (HR) image dataset.

    LR and HR files are paired by file name after removing everything up to
    and including the first underscore (see ``_make_key``). LR files without
    an HR partner are ignored.

    With ``is_train=True`` each sample is a random ``image_size`` x
    ``image_size`` LR crop, the aligned ``image_size * upscale_factor`` HR
    crop, and a shared random horizontal flip. With ``is_train=False`` the
    full images are returned unchanged.
    """

    @staticmethod
    def _make_key(fname):
        """Return ``fname`` with anything before the first underscore stripped.

        Names without an underscore are returned unchanged.
        """
        return fname.split('_', 1)[-1]

    def __init__(self, lr_dir, hr_dir, image_size, upscale_factor, is_train=True):
        """Index the LR/HR pairs found in the two directories.

        Args:
            lr_dir: Directory holding the low-resolution images.
            hr_dir: Directory holding the high-resolution images.
            image_size: Side length of the square LR training crop.
            upscale_factor: Integer scale between LR and HR.
            is_train: Enables the random crop and flip when True.
        """
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.image_size = image_size
        self.upscale_factor = upscale_factor
        self.hr_image_size = image_size * upscale_factor
        self.is_train = is_train

        # Build HR lookup: key -> full path
        self.hr_image_files = {}
        for hr_path in glob.glob(os.path.join(hr_dir, '*.*')):
            key = self._make_key(os.path.basename(hr_path))
            self.hr_image_files[key] = hr_path

        # Collect LR files whose key exists in HR (sorted for a stable index order)
        all_lr = sorted(glob.glob(os.path.join(lr_dir, '*.*')))
        self.lr_image_files = []
        for lr_path in all_lr:
            key = self._make_key(os.path.basename(lr_path))
            if key in self.hr_image_files:
                self.lr_image_files.append(lr_path)

        self.transform = ToTensor()

    def __len__(self):
        """Number of LR images that have a matching HR image."""
        return len(self.lr_image_files)

    def _random_crop_and_flip(self, lr, hr):
        """Apply the training augmentation to one (C, H, W) LR/HR tensor pair.

        Args:
            lr: LR image tensor of shape (C, H, W).
            hr: HR image tensor of shape (C, H', W').

        Returns:
            Tuple ``(lr, hr)`` of shapes (C, image_size, image_size) and
            (C, hr_image_size, hr_image_size).

        Raises:
            ValueError: If the HR image is too small to provide the HR crop
                that corresponds to the chosen LR crop.
        """
        lh, lw = lr.shape[1:]
        if lh < self.image_size or lw < self.image_size:
            # Image smaller than the crop: resize both to the target sizes.
            # Bicubic interpolation can overshoot the input range, so clamp
            # back to [0, 1] to keep the same value range as cropped samples.
            lr = F.interpolate(lr.unsqueeze(0),
                               size=(self.image_size, self.image_size),
                               mode='bicubic', align_corners=False
                              ).squeeze(0).clamp(0.0, 1.0)
            hr = F.interpolate(hr.unsqueeze(0),
                               size=(self.hr_image_size, self.hr_image_size),
                               mode='bicubic', align_corners=False
                              ).squeeze(0).clamp(0.0, 1.0)
        else:
            # Random LR crop; the HR crop is the same window scaled by upscale_factor
            top  = torch.randint(0, lh - self.image_size + 1, (1,)).item()
            left = torch.randint(0, lw - self.image_size + 1, (1,)).item()
            lr   = lr[:, top:top+self.image_size, left:left+self.image_size]
            top_hr, left_hr = top * self.upscale_factor, left * self.upscale_factor
            hr   = hr[:, top_hr:top_hr+self.hr_image_size,
                         left_hr:left_hr+self.hr_image_size]
            # Slicing past the edge silently yields a smaller tensor, which would
            # only surface later as an obscure batching error.
            if tuple(hr.shape[1:]) != (self.hr_image_size, self.hr_image_size):
                raise ValueError(
                    f"HR crop has shape {tuple(hr.shape[1:])}, expected "
                    f"{(self.hr_image_size, self.hr_image_size)}; the HR image is "
                    f"smaller than {self.upscale_factor}x the LR image"
                )

        # Random horizontal flip, applied to both images so they stay aligned
        if torch.rand(1) < 0.5:
            lr = torch.flip(lr, dims=[2])
            hr = torch.flip(hr, dims=[2])

        return lr, hr

    def __getitem__(self, idx):
        """Load pair ``idx`` as float tensors, augmented when ``is_train`` is set."""
        lr_path = self.lr_image_files[idx]
        key     = self._make_key(os.path.basename(lr_path))
        hr_path = self.hr_image_files[key]

        # Load both images as RGB and convert them to tensors
        lr_img = Image.open(lr_path).convert('RGB')
        hr_img = Image.open(hr_path).convert('RGB')
        lr = self.transform(lr_img)
        hr = self.transform(hr_img)

        if self.is_train:
            lr, hr = self._random_crop_and_flip(lr, hr)

        return lr, hr

In [16]:
class TestImageDataset(Dataset):
    """Unpaired dataset over the files of a single test directory.

    Each item is ``(image_tensor, file_name)``: the image loaded as RGB and
    converted with ``ToTensor``, plus the base name of the file it came from.
    """

    def __init__(self, test_dir):
        """Collect, in sorted order, every ``*.*`` file inside ``test_dir``."""
        pattern = os.path.join(test_dir, '*.*')
        self.test_dir = test_dir
        self.transform = ToTensor()
        self.image_files = sorted(glob.glob(pattern))

    def __len__(self):
        """Number of files found in the test directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the tensor for file ``idx`` together with its base name."""
        path = self.image_files[idx]
        rgb_image = Image.open(path).convert('RGB')
        return self.transform(rgb_image), os.path.basename(path)


In [17]:
print("Defining SwinIR model...")

# Architecture arguments, gathered in one place before the network is built.
# upscale, img_size and window_size come from CONFIG; the rest are fixed here.
swinir_arch = dict(
    upscale=CONFIG["upscale_factor"],
    in_chans=3,
    img_size=CONFIG["train_image_size"],
    window_size=CONFIG["window_size"],
    img_range=1.0,
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6],
    mlp_ratio=2,
    upsampler='pixelshuffle',
    resi_connection='1conv',
)
model = ActualSwinIRNet(**swinir_arch)

Defining SwinIR model...


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [18]:
# Load the pre-trained SwinIR checkpoint (when present) into `model`, then move
# the model to the configured device. Missing file or missing key only warns,
# in which case the model keeps its freshly initialised weights.
if os.path.exists(CONFIG["pretrained_weights_path"]):
    print(f"Loading pre-trained weights from: {CONFIG['pretrained_weights_path']}")
    # The checkpoint was downloaded from the internet. weights_only=True limits
    # unpickling to tensors and plain containers, so a tampered file cannot run
    # arbitrary code (and it silences torch.load's FutureWarning).
    pretrained_dict = torch.load(CONFIG["pretrained_weights_path"], map_location='cpu', weights_only=True)
    # Prefer the 'params_ema' entry when the checkpoint has one, else use 'params'
    load_key = 'params_ema' if 'params_ema' in pretrained_dict else 'params'
    if load_key in pretrained_dict:
        model_dict = model.state_dict()
        # Filter state dict - allow loading even if some keys mismatch (e.g., different head)
        pretrained_dict_filtered = {k: v for k, v in pretrained_dict[load_key].items() if k in model_dict and v.shape == model_dict[k].shape}
        # strict=False: parameters filtered out above are reported as missing keys instead of raising
        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict_filtered, strict=False)
        print(f"Loaded {len(pretrained_dict_filtered)} keys.")
        if missing_keys: print(f"Missing keys: {missing_keys}")
        if unexpected_keys: print(f"Unexpected keys: {unexpected_keys}")
        del pretrained_dict # Free memory
    else:
        print(f"Warning: Key '{load_key}' not found in weights file. Training from scratch or check file.")
else:
    print(f"Warning: Pre-trained weights {CONFIG['pretrained_weights_path']} not found. Training from scratch.")

model = model.to(CONFIG["device"])
print("Model definition and weight loading complete.")

Loading pre-trained weights from: /kaggle/working/weights/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth
Loaded 550 keys.


  pretrained_dict = torch.load(CONFIG["pretrained_weights_path"], map_location='cpu')


Model definition and weight loading complete.


In [19]:

# Training objective: L1 loss between the network output and the HR target
criterion = nn.L1Loss()

# Adam over all model parameters, starting from the configured learning rate
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])

# Cosine annealing spanning the whole run (stepped once per epoch in the
# training loop), decaying the learning rate towards eta_min
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'],
    eta_min=1e-7,
)


In [20]:
print("Setting up DataLoaders...")

# Split roots below the data root: <root>/train, <root>/val and <root>/test
train_root = os.path.join(CONFIG["data_root"], "train")
val_root = os.path.join(CONFIG["data_root"], "val")

# Training pairs: LR images in train/train, HR targets in train/gt (random crops enabled)
train_dataset = ImageDataset(
    lr_dir=os.path.join(train_root, "train"),
    hr_dir=os.path.join(train_root, "gt"),
    image_size=CONFIG["train_image_size"],
    upscale_factor=CONFIG["upscale_factor"],
    is_train=True,
)

# Validation pairs: LR images in val/val, HR targets in val/gt.
# With is_train=False the dataset skips cropping and returns full images.
val_dataset = ImageDataset(
    lr_dir=os.path.join(val_root, "val"),
    hr_dir=os.path.join(val_root, "gt"),
    image_size=CONFIG["train_image_size"],
    upscale_factor=CONFIG["upscale_factor"],
    is_train=False,
)

# Test images: LR only, read from <root>/test
test_dataset = TestImageDataset(test_dir=os.path.join(CONFIG["data_root"], "test"))

Setting up DataLoaders...


In [21]:
# Report how many samples each split contributes
split_sizes = [
    ("  ▶  # train images:", len(train_dataset)),
    ("  ▶  # val   images:", len(val_dataset)),
    ("  ▶  # test   images:", len(test_dataset)),
]
for caption, n_images in split_sizes:
    print(caption, n_images)

  ▶  # train images: 1105
  ▶  # val   images: 267
  ▶  # test   images: 60


In [22]:
# Define collate_fn to filter out None samples from dataset errors
def collate_fn(batch):
    """Collate a batch while dropping samples that failed to load.

    A sample counts as failed when it is ``None`` itself or when its first
    element is ``None``.

    Args:
        batch: List of samples as produced by the dataset.

    Returns:
        The default-collated batch of the surviving samples, or a pair of
        empty tensors when no sample survives (callers skip such batches by
        checking ``nelement() == 0``).
    """
    # Test `sample is not None` first: indexing a bare None would raise TypeError
    valid_samples = [
        sample for sample in batch
        if sample is not None and sample[0] is not None
    ]
    if not valid_samples: # If all samples in batch failed
        return torch.Tensor(), torch.Tensor() # Return empty tensors
    return torch.utils.data.dataloader.default_collate(valid_samples)

# Settings shared by the train and validation loaders
paired_loader_kwargs = dict(
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
    collate_fn=collate_fn,
)

# Training: shuffled mini-batches, dropping the final incomplete batch
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    drop_last=True,
    **paired_loader_kwargs,
)
# Validation: one image per batch, in dataset order
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, **paired_loader_kwargs)
# Test: one image per batch with the default collate function
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=CONFIG["num_workers"],
)
print("DataLoaders ready.")

DataLoaders ready.


In [24]:
# Fine-tuning loop. Each epoch trains on random crops, validates on full images
# (L1 loss and PSNR), steps the LR scheduler, saves the model whenever the
# validation PSNR improves, and writes a checkpoint every 10th epoch.
best_val_psnr = 0.0
to_pil = ToPILImage()

print(f"\nStarting fine-tuning on {CONFIG['device']}...")


for epoch in range(CONFIG["epochs"]):
    start_time = time.time()
    model.train()
    train_loss = 0.0
    train_batches = 0  # batches actually processed (empty batches are skipped)
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]", leave=False)

    for lr_imgs, hr_imgs in progress_bar:
        if lr_imgs.nelement() == 0: continue # Skip empty batch from collate_fn
        lr_imgs, hr_imgs = lr_imgs.to(CONFIG["device"]), hr_imgs.to(CONFIG["device"])

        optimizer.zero_grad()
        sr_imgs = model(lr_imgs)
        loss = criterion(sr_imgs, hr_imgs)
        loss.backward()
        optimizer.step()
        # Read the scalar once; every .item() call forces a device synchronisation
        batch_loss = loss.item()
        train_loss += batch_loss
        train_batches += 1
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

    # Average over processed batches so skipped (empty) batches do not dilute the mean
    avg_train_loss = train_loss / train_batches if train_batches > 0 else 0

    # --- Validation ---
    model.eval()
    val_psnr, val_loss = 0.0, 0.0
    val_batches = 0  # batches actually processed (empty batches are skipped)
    progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]", leave=False)

    with torch.no_grad():
        for lr_imgs, hr_imgs in progress_bar_val:
            if lr_imgs.nelement() == 0: continue
            lr_imgs, hr_imgs = lr_imgs.to(CONFIG["device"]), hr_imgs.to(CONFIG["device"])
            sr_imgs = model(lr_imgs)

            loss = criterion(sr_imgs, hr_imgs)
            val_loss += loss.item()

            # PSNR Calculation: the validation loader uses batch size 1, so squeeze(0)
            # yields a single CHW image, converted here to an HWC uint8 array
            sr_img_np = (sr_imgs.squeeze(0).cpu().numpy().transpose(1, 2, 0).clip(0, 1) * 255.0).astype(np.uint8)
            hr_img_np = (hr_imgs.squeeze(0).cpu().numpy().transpose(1, 2, 0).clip(0, 1) * 255.0).astype(np.uint8)
            current_psnr = calculate_psnr(sr_img_np, hr_img_np, border=CONFIG['upscale_factor'])
            val_psnr += current_psnr
            val_batches += 1
            progress_bar_val.set_postfix(psnr=f"{current_psnr:.2f}dB")

    avg_val_loss = val_loss / val_batches if val_batches > 0 else 0
    avg_val_psnr = val_psnr / val_batches if val_batches > 0 else 0
    epoch_time = time.time() - start_time

    print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Time: {epoch_time:.2f}s - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - Val PSNR: {avg_val_psnr:.2f} dB")

    scheduler.step()

    # Save best model
    if avg_val_psnr > best_val_psnr:
        best_val_psnr = avg_val_psnr
        best_model_path = os.path.join(CONFIG["output_dir"], CONFIG["model_save_name"])
        torch.save(model.state_dict(), best_model_path)
        print(f"----> Saved new best model to {best_model_path} with PSNR: {best_val_psnr:.2f} dB")
    # Save periodic checkpoint. Kept independent of the best-model branch so the
    # epoch-10/20/30 checkpoints are written even when that epoch is a new best.
    if (epoch + 1) % 10 == 0:
         chkpt_path = os.path.join(CONFIG["output_dir"], f"model_epoch_{epoch+1}.pth")
         torch.save(model.state_dict(), chkpt_path)
         print(f"Saved checkpoint to {chkpt_path}")


print("\nFine-tuning finished.")


Starting fine-tuning on cuda...


                                                                                  

Epoch 1/30 - Time: 271.25s - Train Loss: 0.0115 - Val Loss: 0.0090 - Val PSNR: 37.76 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 37.76 dB


                                                                                  

Epoch 2/30 - Time: 260.91s - Train Loss: 0.0110 - Val Loss: 0.0082 - Val PSNR: 38.20 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 38.20 dB


                                                                                  

Epoch 3/30 - Time: 289.89s - Train Loss: 0.0108 - Val Loss: 0.0091 - Val PSNR: 37.90 dB


                                                                                  

Epoch 4/30 - Time: 272.66s - Train Loss: 0.0104 - Val Loss: 0.0078 - Val PSNR: 38.52 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 38.52 dB


                                                                                  

Epoch 5/30 - Time: 289.51s - Train Loss: 0.0102 - Val Loss: 0.0077 - Val PSNR: 38.57 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 38.57 dB


                                                                                  

Epoch 6/30 - Time: 267.02s - Train Loss: 0.0100 - Val Loss: 0.0076 - Val PSNR: 38.67 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 38.67 dB


                                                                                  

Epoch 7/30 - Time: 272.94s - Train Loss: 0.0101 - Val Loss: 0.0074 - Val PSNR: 38.65 dB


                                                                                  

Epoch 8/30 - Time: 272.10s - Train Loss: 0.0098 - Val Loss: 0.0079 - Val PSNR: 38.29 dB


                                                                                  

Epoch 9/30 - Time: 270.84s - Train Loss: 0.0097 - Val Loss: 0.0072 - Val PSNR: 38.97 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 38.97 dB


                                                                                   

Epoch 10/30 - Time: 270.85s - Train Loss: 0.0096 - Val Loss: 0.0072 - Val PSNR: 38.86 dB
Saved checkpoint to /kaggle/working/output/model_epoch_10.pth


                                                                                   

Epoch 11/30 - Time: 267.69s - Train Loss: 0.0096 - Val Loss: 0.0072 - Val PSNR: 38.88 dB


                                                                                   

Epoch 12/30 - Time: 260.12s - Train Loss: 0.0097 - Val Loss: 0.0071 - Val PSNR: 39.14 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 39.14 dB


                                                                                   

Epoch 13/30 - Time: 260.16s - Train Loss: 0.0095 - Val Loss: 0.0070 - Val PSNR: 39.05 dB


                                                                                   

Epoch 14/30 - Time: 260.55s - Train Loss: 0.0096 - Val Loss: 0.0072 - Val PSNR: 38.76 dB


                                                                                   

Epoch 15/30 - Time: 262.42s - Train Loss: 0.0095 - Val Loss: 0.0071 - Val PSNR: 39.03 dB


                                                                                   

Epoch 16/30 - Time: 260.16s - Train Loss: 0.0095 - Val Loss: 0.0071 - Val PSNR: 39.02 dB


                                                                                   

Epoch 17/30 - Time: 259.34s - Train Loss: 0.0096 - Val Loss: 0.0070 - Val PSNR: 39.09 dB


                                                                                   

Epoch 18/30 - Time: 259.66s - Train Loss: 0.0096 - Val Loss: 0.0070 - Val PSNR: 39.15 dB
----> Saved new best model to /kaggle/working/output/swinir_finetuned_model.pth with PSNR: 39.15 dB


                                                                                   

Epoch 19/30 - Time: 261.21s - Train Loss: 0.0095 - Val Loss: 0.0069 - Val PSNR: 39.03 dB


                                                                                   

Epoch 20/30 - Time: 260.24s - Train Loss: 0.0093 - Val Loss: 0.0069 - Val PSNR: 39.05 dB
Saved checkpoint to /kaggle/working/output/model_epoch_20.pth


                                                                                   

Epoch 21/30 - Time: 260.90s - Train Loss: 0.0093 - Val Loss: 0.0069 - Val PSNR: 39.15 dB


                                                                                   

Epoch 22/30 - Time: 259.88s - Train Loss: 0.0093 - Val Loss: 0.0069 - Val PSNR: 39.04 dB


                                                                                   

Epoch 23/30 - Time: 262.17s - Train Loss: 0.0093 - Val Loss: 0.0069 - Val PSNR: 39.05 dB


                                                                                   

Epoch 24/30 - Time: 260.87s - Train Loss: 0.0094 - Val Loss: 0.0069 - Val PSNR: 39.14 dB


                                                                                   

Epoch 25/30 - Time: 259.69s - Train Loss: 0.0094 - Val Loss: 0.0069 - Val PSNR: 39.10 dB


                                                                                   

Epoch 26/30 - Time: 260.62s - Train Loss: 0.0094 - Val Loss: 0.0069 - Val PSNR: 39.05 dB


                                                                                   

Epoch 27/30 - Time: 259.58s - Train Loss: 0.0095 - Val Loss: 0.0069 - Val PSNR: 39.10 dB


                                                                                   

Epoch 28/30 - Time: 264.43s - Train Loss: 0.0093 - Val Loss: 0.0069 - Val PSNR: 39.07 dB


                                                                                   

Epoch 29/30 - Time: 276.08s - Train Loss: 0.0095 - Val Loss: 0.0069 - Val PSNR: 39.11 dB


                                                                                   

Epoch 30/30 - Time: 273.60s - Train Loss: 0.0095 - Val Loss: 0.0069 - Val PSNR: 39.10 dB
Saved checkpoint to /kaggle/working/output/model_epoch_30.pth

Fine-tuning finished.




In [27]:
# Inference: rebuild the network, load the best fine-tuned weights, run the
# test images through it and save each result under its original file name
# in CONFIG["test_output_dir"].
print("Starting inference on test set...")

best_model_path = os.path.join(CONFIG["output_dir"], CONFIG["model_save_name"])
if not os.path.exists(best_model_path):
    # There is no fallback to another checkpoint, so the message says only what happens
    print(f"Error: Best model not found at {best_model_path}. Run the fine-tuning cell first. Exiting.")
    sys.exit(1)

print(f"Loading best fine-tuned model from {best_model_path}")
# Must match the architecture used for fine-tuning so the state dict loads
model_inference = ActualSwinIRNet(
    upscale=CONFIG["upscale_factor"],
    in_chans=3, 
    img_size=CONFIG["train_image_size"],
    window_size=CONFIG["window_size"], 
    img_range=1.0, 
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180, 
    num_heads=[6, 6, 6, 6, 6, 6], 
    mlp_ratio=2,
    upsampler='pixelshuffle', 
    resi_connection='1conv'
)

# The file is a plain state dict written by torch.save(model.state_dict(), ...);
# weights_only=True loads it without unpickling arbitrary objects.
model_inference.load_state_dict(torch.load(best_model_path, map_location=CONFIG["device"], weights_only=True))
model_inference.to(CONFIG["device"])
model_inference.eval()

# Created here instead of reusing the training cell's converter, so this cell
# also works when the saved model is loaded without re-running training.
to_pil = ToPILImage()

progress_bar_test = tqdm(test_loader, desc="[Inference]")

with torch.no_grad():
    for lr_imgs, filenames in progress_bar_test:
        if lr_imgs.nelement() == 0: continue
        lr_imgs = lr_imgs.to(CONFIG["device"])
        sr_imgs = model_inference(lr_imgs) # Super-resolved batch; clamped to [0, 1] before saving

        for i in range(sr_imgs.shape[0]):
            sr_img_tensor = sr_imgs[i].cpu()
            sr_img_pil = to_pil(sr_img_tensor.clamp(0, 1))
            output_filename = os.path.join(CONFIG["test_output_dir"], filenames[i])
            try:
                sr_img_pil.save(output_filename)
            except Exception as e:
                # Best effort: report the failure and carry on with the remaining images
                print(f"Error saving image {output_filename}: {e}")


Starting inference on test set...
Loading best fine-tuned model from /kaggle/working/output/swinir_finetuned_model.pth


  model_inference.load_state_dict(torch.load(best_model_path, map_location=CONFIG["device"]))
[Inference]: 100%|██████████| 60/60 [00:42<00:00,  1.43it/s]


In [31]:
import os
import numpy as np
import pandas as pd
from PIL import Image

def images_to_csv(folder_path, output_csv):
    """Flatten every image in a folder into one CSV row of sub-sampled pixels.

    Each image is converted to 8-bit grayscale, flattened, and every 8th
    pixel is kept. The row ID is the file name up to its first dot, with
    'test_' replaced by 'gt_'. Files are processed in sorted name order so
    the CSV is identical from run to run.

    Args:
        folder_path: Directory containing the images to export.
        output_csv: Path of the CSV file to write.

    Raises:
        ValueError: If ``folder_path`` contains no image with a supported
            extension (nothing is written in that case).
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    data_rows = []
    # os.listdir returns entries in arbitrary order; sort for deterministic output
    for filename in sorted(os.listdir(folder_path)):
        # Case-insensitive match so files such as 'X.PNG' are not silently skipped
        if not filename.lower().endswith(image_extensions):
            continue
        image_path = os.path.join(folder_path, filename)
        # Context manager closes the file handle as soon as the pixels are read
        with Image.open(image_path) as image:
            image_array = np.array(image.convert('L')).flatten()[::8]
        # Replace 'test_' with 'gt_' in the ID
        image_id = filename.split('.')[0].replace('test_', 'gt_')
        data_rows.append([image_id, *image_array])
    if not data_rows:
        raise ValueError(f"No image files found in {folder_path}")
    # Column count is taken from the first row; all images are expected to have the same size
    column_names = ['ID'] + [f'pixel_{i}' for i in range(len(data_rows[0]) - 1)]
    df = pd.DataFrame(data_rows, columns=column_names)
    df.to_csv(output_csv, index=False)
    print(f'Successfully saved to {output_csv}')

# Build the submission from the super-resolved predictions written by the
# inference cell (not from the original low-resolution test images).
output_csv = 'submission.csv'
folder_path = CONFIG['test_output_dir']
images_to_csv(folder_path, output_csv)