# Env Prep

In [None]:
import warnings
# Install a blanket "ignore" filter: no message, category or module is given,
# so every warning raised after this cell is suppressed.
# NOTE(review): this also hides deprecation notices from the libraries used
# below — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

In [1]:
# Shell escape: print the NVIDIA driver / GPU status of the current runtime.
!nvidia-smi

Thu Aug 14 11:11:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   43C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Config

In [2]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Literal, Union, Dict
import os
import sys
import subprocess
import shutil

# Execution environments the notebook knows how to configure.
EnvModeType = Literal["colab", "remote", "local"]


@dataclass(frozen=True)
class Config:
    """Immutable bundle of dataset locations derived from the environment mode.

    Only ``env_mode`` and ``selected_subsets`` are accepted by the constructor.
    Every path attribute is computed in ``__post_init__`` from a per-mode data
    root; an unrecognised ``env_mode`` falls back to the current working
    directory as data root.
    """

    env_mode: EnvModeType
    selected_subsets: List[str] = field(
        default_factory=lambda: ["SUDOKU_4", "SUDOKU_5", "SUDOKU_6"]
    )

    # Derived fields (init=False) — all of them are assigned in __post_init__.
    data_root: Path = field(init=False)
    base_dir: Path = field(init=False)
    taco_raw_dir: Path = field(init=False)
    taco_file_paths: List[Path] = field(init=False)
    normalized_sets_dir: Path = field(init=False)
    finetune_dir: Path = field(init=False)
    train_dir: Path = field(init=False)
    val_dir: Path = field(init=False)
    test_dir: Path = field(init=False)

    def __post_init__(self) -> None:
        """Compute every derived path from ``env_mode`` and ``selected_subsets``."""
        data_root_map = {
            "local": Path("/mnt/shared"),
            "remote": Path.home(),
            "colab": Path("/content/drive/MyDrive"),
        }
        data_root = data_root_map.get(self.env_mode, Path.cwd())
        base_dir = data_root / "datasets/sen2venus"
        taco_raw_dir = base_dir / "TACO_raw_data"
        normalized_sets_dir = base_dir / "normalized_sets"

        derived = {
            "data_root": data_root,
            "base_dir": base_dir,
            # The raw TACO locations are always populated so that the
            # dataclass-generated __repr__/__eq__ (which read every field)
            # work; validate() deliberately does not require them to exist.
            "taco_raw_dir": taco_raw_dir,
            "taco_file_paths": [
                taco_raw_dir / f"{subset}.taco" for subset in self.selected_subsets
            ],
            "normalized_sets_dir": normalized_sets_dir,
            "finetune_dir": base_dir / "finetune",
            "train_dir": normalized_sets_dir / "train",
            "val_dir": normalized_sets_dir / "val",
            "test_dir": normalized_sets_dir / "test",
        }
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # object.__setattr__ is the standard way to fill init=False fields.
        for name, value in derived.items():
            object.__setattr__(self, name, value)

    def validate(self) -> None:
        """Check that the directories this notebook reads from exist.

        The raw TACO directory and files are not checked.

        Raises:
            ValueError: if any required directory is missing; the message
                lists every missing path.
        """
        required_attrs = (
            "data_root",
            "base_dir",
            "normalized_sets_dir",
            "finetune_dir",
            "train_dir",
            "val_dir",
            "test_dir",
        )
        missing_paths = [
            str(getattr(self, attr))
            for attr in required_attrs
            if not getattr(self, attr).exists()
        ]
        if missing_paths:
            raise ValueError(f"Missing paths: {', '.join(missing_paths)}")


def _ensure_super_image() -> None:
    """Install ``super-image`` into the running interpreter if it cannot be imported."""
    import importlib

    try:
        import super_image  # noqa: F401  (availability probe only)
        return
    except ImportError:
        pass

    print("Installing 'super-image'...")
    try:
        # Invoke pip through the kernel's own interpreter so the package is
        # installed into the environment this notebook imports from, not
        # whichever `pip` happens to be first on PATH.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet", "super-image"],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to install super-image: {e}") from e
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()


def _mount_google_drive() -> None:
    """Mount Google Drive at /content/drive, wrapping failures in RuntimeError."""
    try:
        from google.colab import drive
    except ImportError as e:
        raise RuntimeError("Google Colab module not found. Are you in Colab?") from e

    try:
        drive.mount("/content/drive", force_remount=True)
    except Exception as e:
        raise RuntimeError(f"Failed to mount Google Drive: {e}") from e


def _stage_normalized_sets(source_dir: Path, target_dir: Path) -> None:
    """Copy ``source_dir`` to ``target_dir`` once, never leaving a half-copied target."""
    if target_dir.exists():
        return

    print("Copying normalized data to local Colab storage for performance...")
    # Copy into a sibling scratch directory and rename at the end: an
    # interrupted copy then cannot be mistaken for a complete one next run.
    partial_dir = target_dir.with_name(target_dir.name + ".partial")
    if partial_dir.exists():
        shutil.rmtree(partial_dir)
    shutil.copytree(source_dir, partial_dir)
    partial_dir.rename(target_dir)
    print("Copy complete.")


def setup_environment(env_mode: EnvModeType) -> None:
    """Perform environment-specific setup (side effects isolated here).

    * ``"colab"``: make sure ``super-image`` is installed, mount Google Drive,
      and copy the normalized dataset from Drive to ``/content/taco_normalized``
      if that directory does not exist yet.
    * ``"remote"`` / ``"local"``: only print a status message.
    * any other value: no action.

    Raises:
        RuntimeError: if the package install fails, the Colab module is
            unavailable, or mounting Google Drive fails.
    """
    if env_mode == "colab":
        _ensure_super_image()
        _mount_google_drive()
        # Optional: local /content storage gives faster I/O than Drive.
        _stage_normalized_sets(
            Path("/content/drive/MyDrive/datasets/sen2venus/normalized_sets"),
            Path("/content/taco_normalized"),
        )
        # Avoid os.chdir; let users handle working dir if needed

    elif env_mode == "remote":
        print("Remote environment detected. No specific setup needed.")

    elif env_mode == "local":
        print("Local environment detected. Ensuring dependencies...")


def create_config(env_mode: EnvModeType | None = None) -> Config:
    """Build a validated ``Config`` after running the environment setup.

    When ``env_mode`` is None the mode is detected in this order:
    ``"colab"`` if ``google.colab`` is already loaded in ``sys.modules``,
    ``"remote"`` if the ``REMOTE_ENV_VAR`` environment variable is set,
    otherwise ``"local"``.

    Returns:
        The ``Config`` for the resolved mode, after ``validate()`` has passed.
    """
    if env_mode is not None:
        resolved_mode = env_mode
    elif "google.colab" in sys.modules:
        resolved_mode = "colab"
    elif "REMOTE_ENV_VAR" in os.environ:
        resolved_mode = "remote"
    else:
        resolved_mode = "local"

    # Setup runs first so that mounted/copied data exists before validation.
    setup_environment(resolved_mode)

    cfg = Config(env_mode=resolved_mode)
    cfg.validate()
    return cfg


In [3]:
# Force the Colab code path instead of relying on create_config()'s
# auto-detection. This runs the Colab-only setup (package install, Drive
# mount, local data copy) before the config paths are validated.
config = create_config('colab')
print(config.data_root)

Installing 'super-image'...
Mounted at /content/drive
Copying normalized data to local Colab storage for performance...
Copy complete.
/content/drive/MyDrive


## Project imports

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader

from super_image import Trainer, PreTrainedModel, TrainingArguments
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter, compute_metrics

import numpy as np
from tqdm.auto import tqdm
from collections import OrderedDict

# Data Loader

In [6]:
class PreNormalizedDataset(Dataset):
    """
    Map-style dataset over pre-processed samples stored in sharded ``*.pt`` files.

    Shards are ordered by sorted file name. Every shard except the last is
    assumed to hold as many samples as the first one; the total length is
    derived from that assumption. Each shard must be an indexable sequence
    whose items expose ``'lr'`` and ``'hr'`` entries.

    NOTE(review): sorting is lexicographic, so shard names presumably need
    zero-padded indices for the "last shard" assumption to hold — confirm
    against the code that writes the shards.
    """
    def __init__(self, shard_dir: Union[str, Path]):
        """Index the shards in ``shard_dir``.

        Raises:
            ValueError: if the directory contains no ``*.pt`` files.
        """
        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # torch.load unpickles the file: only point this at shards you trust.
        first_shard = torch.load(self.shard_paths[0])
        self.shard_size = len(first_shard)

        # With a single shard, first and last are the same file — do not
        # read it from disk a second time.
        if len(self.shard_paths) == 1:
            last_shard_len = self.shard_size
        else:
            last_shard_len = len(torch.load(self.shard_paths[-1]))
        self.length = (len(self.shard_paths) - 1) * self.shard_size + last_shard_len

        # Single-shard cache; seeded with the shard that is already in memory
        # so the first sequential pass does not reload it.
        self._cache = first_shard
        self._cached_shard_index = 0
        print(f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards.")


    def __len__(self):
        """Total number of samples across all shards."""
        return self.length

    def __getitem__(self, idx) -> tuple:
        """Return the ``(lr, hr)`` pair stored at global position ``idx``.

        Negative indices count from the end of the dataset.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        requested_idx = idx
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(
                f"Index {requested_idx} is out of range for dataset of length {self.length}"
            )

        shard_index, index_in_shard = divmod(idx, self.shard_size)

        if shard_index != self._cached_shard_index:
            self._cache = torch.load(self.shard_paths[shard_index])
            self._cached_shard_index = shard_index

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = self._cache[index_in_shard]
        return squeezed_sample['lr'], squeezed_sample['hr']


# Step 1: Quantitative Evaluation on the Test Set

### 1.1 Load the Fine-Tuned Model

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_data_dir = config.test_dir
checkpoint_path = config.finetune_dir / 'edsr_base/best_model_checkpoint.pt'


# Start from the public 2x EDSR-base weights; the fine-tuned weights are
# loaded on top of them below.
model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)

# If multiple GPUs are available, wrap the model in DataParallel.
# Key-prefix differences with the checkpoint are reconciled below, so the
# checkpoint loads whether or not it was saved from a wrapped model.
if torch.cuda.device_count() > 1:
  print(f"Using {torch.cuda.device_count()} GPUs. Wrapping model in DataParallel.")
  model = torch.nn.DataParallel(model)

model.to(device)

# load the state dictionary from best checkpoint
print(f"Loading best model checkpoint from: {checkpoint_path}")
try:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']

    # --- Logic to handle all DataParallel cases ---
    # DataParallel stores parameters under a 'module.' prefix; probe the
    # first key (an empty state dict counts as "not parallel").
    is_model_parallel = isinstance(model, torch.nn.DataParallel)
    is_checkpoint_parallel = next(iter(model_state_dict), '').startswith('module.')

    if is_model_parallel and not is_checkpoint_parallel:
        # If the current model is parallel but the checkpoint isn't, add "module." prefix
        print("Model is parallel, checkpoint is not. Adding 'module.' prefix to keys...")
        final_state_dict = OrderedDict(
            ('module.' + k, v) for k, v in model_state_dict.items()
        )
    elif not is_model_parallel and is_checkpoint_parallel:
        # If the checkpoint is parallel but the current model isn't, strip "module." prefix
        print("Checkpoint is parallel, model is not. Stripping 'module.' prefix from keys...")
        final_state_dict = OrderedDict(
            (k.removeprefix('module.'), v) for k, v in model_state_dict.items()
        )
    else:
        # If they are both parallel or both not parallel, the keys match already.
        print("Model and checkpoint states match. Loading directly.")
        final_state_dict = model_state_dict

    # Load the correctly formatted state dictionary
    model.load_state_dict(final_state_dict)

except FileNotFoundError:
    print(f"❌ ERROR: Checkpoint file not found at '{checkpoint_path}'. Please verify the path.")
    # Stop here: continuing would evaluate the base pretrained weights and
    # report them as if they were the fine-tuned model.
    raise
except Exception as e:
    print(f"❌ ERROR: An error occurred while loading the checkpoint: {e}")
    raise
else:
    print(f"\n✅Successfully loaded model from epoch {checkpoint['epoch']} with best validation PSNR {checkpoint['best_metric']:.4f}")

model.eval()

https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_2x.pt
Loading best model checkpoint from: /content/drive/MyDrive/datasets/sen2venus/finetune/edsr_base/best_model_checkpoint.pt
Model and checkpoint states match. Loading directly.

✅Successfully loaded model from epoch 14 with best validation PSNR 47.3143


DataParallel(
  (module): EdsrModel(
    (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (head): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (body): Sequential(
      (0): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
         

### 1.2 Prepare the Test DataLoader

In [16]:
# In Colab the environment setup copies the normalized sets to local disk
# under /content/taco_normalized; read the test split from there. Resolved
# outside the try block so the error message below can always reference it.
test_data_dir = Path("/content/taco_normalized/test") if config.env_mode == "colab" else config.test_dir

try:
    test_dataset = PreNormalizedDataset(test_data_dir)
    # For evaluation, batch size is typically 1 to measure per-image metrics.
    test_dataloader = DataLoader(test_dataset, batch_size=1)
except Exception as e:
    print(f"❌ ERROR: Failed to load the test dataset from '{test_data_dir}': {e}")
    # Re-raise: later cells need `test_dataset`, so continuing would only
    # surface as a confusing NameError further down.
    raise

print(f"Loaded test dataset with {len(test_dataset)} samples.")

Initialized dataset from /content/taco_normalized/test with 556 samples across 1 shards.
Loaded test dataset with 556 samples.


### 1.3 Run Evaluation Loop and Report

In [17]:
from super_image.data import EvalMetrics

# Run super-image's evaluator on the fine-tuned model over the test dataset.
# The dataset object is passed directly (not `test_dataloader`); the call
# prints the "scale / eval psnr / ssim" summary shown in the cell output.
EvalMetrics().evaluate(model, test_dataset)
print(f"Total images evaluated: {len(test_dataset)}")


Evaluating dataset:   0%|          | 0/556 [00:00<?, ?it/s]

scale:2      eval psnr: 47.46     ssim: 0.9910
Total images evaluated: 556
