# Env Prep

In [None]:
import warnings
# Suppress every Python warning for the rest of the session. This keeps the
# notebook output short, but it also hides deprecation and numerical warnings
# raised by any library imported below.
warnings.filterwarnings('ignore')

In [None]:
# IPython shell escape (not plain Python): run `nvidia-smi` to show which GPU,
# if any, is visible to this kernel.
!nvidia-smi

## Config

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Literal, Union, Dict
import os
import sys
import subprocess
import shutil

# The execution environments this notebook knows how to configure.
EnvModeType = Literal["colab", "remote", "local"]


@dataclass(frozen=True)
class Config:
    """Immutable set of dataset locations for one execution environment.

    Only ``env_mode`` and ``selected_subsets`` are supplied by the caller; every
    path attribute is derived from them in ``__post_init__``.

    Attributes:
        env_mode: Which environment the notebook runs in; selects ``data_root``.
        selected_subsets: Subset names used to build ``taco_file_paths``.
        data_root: Environment-specific root directory.
        base_dir: ``data_root / "datasets/sen2venus"``.
        taco_raw_dir: ``base_dir / "TACO_raw_data"``.
        taco_file_paths: One ``<subset>.taco`` path per selected subset.
        normalized_sets_dir: ``base_dir / "normalized_sets"``.
        finetune_dir: ``base_dir / "finetune"``.
        train_dir, val_dir, test_dir: Split directories under
            ``normalized_sets_dir``.
    """

    env_mode: EnvModeType
    selected_subsets: List[str] = field(
        default_factory=lambda: ["SUDOKU_4", "SUDOKU_5", "SUDOKU_6"]
    )

    # Derived fields (init=False): all of them are filled in by __post_init__.
    data_root: Path = field(init=False)
    base_dir: Path = field(init=False)
    taco_raw_dir: Path = field(init=False)
    taco_file_paths: List[Path] = field(init=False)
    normalized_sets_dir: Path = field(init=False)
    finetune_dir: Path = field(init=False)
    train_dir: Path = field(init=False)
    val_dir: Path = field(init=False)
    test_dir: Path = field(init=False)

    def _set(self, name: str, value: object) -> None:
        """Assign a derived field, bypassing the frozen-dataclass guard."""
        object.__setattr__(self, name, value)

    def __post_init__(self) -> None:
        """Derive every path attribute from ``env_mode`` and ``selected_subsets``."""
        # Default root per environment. Earlier alternatives kept for reference:
        # ~/mnt/shared for "local" and /mnt/data for "remote".
        data_root_map = {
            "local": Path("/mnt/shared"),
            "remote": Path.home(),
            "colab": Path("/content/drive/MyDrive"),
        }
        # An unrecognised mode falls back to the current working directory.
        self._set("data_root", data_root_map.get(self.env_mode, Path.cwd()))
        self._set("base_dir", self.data_root / "datasets/sen2venus")

        # These two fields are declared on the class, so they must always be
        # assigned: the dataclass-generated __repr__ and __eq__ read every
        # field and would raise AttributeError otherwise.
        self._set("taco_raw_dir", self.base_dir / "TACO_raw_data")
        self._set(
            "taco_file_paths",
            [self.taco_raw_dir / f"{subset}.taco" for subset in self.selected_subsets],
        )

        self._set("normalized_sets_dir", self.base_dir / "normalized_sets")
        self._set("finetune_dir", self.base_dir / "finetune")
        self._set("train_dir", self.normalized_sets_dir / "train")
        self._set("val_dir", self.normalized_sets_dir / "val")
        self._set("test_dir", self.normalized_sets_dir / "test")

    def validate(self) -> None:
        """Check that the directories this notebook reads from exist.

        The raw TACO directory and files are deliberately not checked; only the
        normalized and fine-tune locations are required.

        Raises:
            ValueError: If any required directory is missing. The message lists
                every missing path.
        """
        required_attrs = (
            "data_root",
            "base_dir",
            "normalized_sets_dir",
            "finetune_dir",
            "train_dir",
            "val_dir",
            "test_dir",
        )
        required_paths = [getattr(self, attr) for attr in required_attrs]
        missing_paths = [str(path) for path in required_paths if not path.exists()]
        if missing_paths:
            raise ValueError(f"Missing paths: {', '.join(missing_paths)}")


def setup_environment(env_mode: EnvModeType) -> None:
    """Perform environment-specific setup (side effects isolated here).

    For ``"colab"`` this installs ``super-image`` if it cannot be imported,
    mounts Google Drive at ``/content/drive`` and, if it does not exist yet,
    copies the normalized dataset to ``/content/taco_normalized``. For
    ``"remote"`` and ``"local"`` it only prints a status line.

    Args:
        env_mode: The environment to prepare.

    Raises:
        RuntimeError: If installing ``super-image`` fails, if the Colab module
            is unavailable, or if mounting Google Drive fails. The original
            exception is attached as ``__cause__``.
    """
    if env_mode == "colab":
        try:
            import super_image  # noqa: F401 -- imported only to probe availability
        except ImportError:
            print("Installing 'super-image'...")
            try:
                # Invoke pip through the running interpreter so the package is
                # installed into the environment this kernel imports from,
                # rather than whichever `pip` happens to be first on PATH.
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "--quiet", "super-image"],
                    check=True,
                )
            except (subprocess.CalledProcessError, OSError) as e:
                raise RuntimeError(f"Failed to install super-image: {e}") from e

        try:
            from google.colab import drive

            drive.mount("/content/drive", force_remount=True)
        except ImportError as e:
            raise RuntimeError("Google Colab module not found. Are you in Colab?") from e
        except Exception as e:
            raise RuntimeError(f"Failed to mount Google Drive: {e}") from e

        # Optional: copy data to local /content for faster I/O in Colab.
        # NOTE(review): Config derives its paths from /content/drive/MyDrive,
        # so nothing in Config points at this local copy -- confirm whether the
        # copy is still wanted or whether Config should use it.
        # NOTE(review): an interrupted copy leaves a partial directory behind
        # and the exists() check will then skip the copy on the next run.
        colab_vm_dir = Path("/content/taco_normalized")
        if not colab_vm_dir.exists():
            print("Copying normalized data to local Colab storage for performance...")
            shutil.copytree(
                Path(
                    "/content/drive/MyDrive/datasets/sen2venus/normalized_sets"
                ),
                colab_vm_dir,
            )
            print("Copy complete.")
        # Avoid os.chdir; let users handle working dir if needed

    elif env_mode == "remote":
        print("Remote environment detected. No specific setup needed.")

    elif env_mode == "local":
        # Only announces the mode; no dependency check is performed here.
        print("Local environment detected. Ensuring dependencies...")


def create_config(env_mode: EnvModeType | None = None) -> Config:
    """Build a validated ``Config``, auto-detecting the environment if needed.

    When ``env_mode`` is ``None`` the mode is inferred: ``"colab"`` if the
    ``google.colab`` module is already loaded, ``"remote"`` if the
    ``REMOTE_ENV_VAR`` environment variable is set, otherwise ``"local"``.
    The environment is then prepared via ``setup_environment`` before the
    configuration is created and validated.

    Args:
        env_mode: Explicit environment, or ``None`` to auto-detect.

    Returns:
        A ``Config`` whose required directories have been checked.
    """
    resolved_mode = env_mode
    if resolved_mode is None:
        running_in_colab = "google.colab" in sys.modules
        has_remote_marker = "REMOTE_ENV_VAR" in os.environ  # example remote marker
        if running_in_colab:
            resolved_mode = "colab"
        else:
            resolved_mode = "remote" if has_remote_marker else "local"

    # Side effects (installs, mounts, copies) happen before the paths are checked.
    setup_environment(resolved_mode)

    cfg = Config(env_mode=resolved_mode)
    cfg.validate()  # fail early if any required directory is missing
    return cfg


In [None]:
# Detect the environment, run its setup and build the validated configuration.
# `config` is a notebook-level global that the later cells read their paths from.
config = create_config()  
# Show which data root was selected for this environment.
print(config.data_root)

## Project imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

from super_image import Trainer, PreTrainedModel, TrainingArguments
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter, compute_metrics

import numpy as np
from tqdm.auto import tqdm
from collections import OrderedDict

# Data Loader

In [None]:
class PreNormalizedDataset(Dataset):
    """Map-style dataset over pre-processed samples stored in ``*.pt`` shards.

    Every shard is loaded with ``torch.load`` and indexed like a sequence of
    samples, each sample exposing ``'lr'`` and ``'hr'`` entries. All shards
    except the last are assumed to hold as many samples as the first one.

    Only one shard is kept in memory at a time, so sequential access is cheap
    while random access reloads a shard on almost every lookup.

    NOTE(review): shards are ordered by a plain lexicographic sort of their
    paths. If file names carry non-zero-padded numbers, the short shard may not
    sort last, which would break the length calculation -- confirm the naming.
    NOTE(review): ``torch.load`` unpickles data; only point this at shard
    directories you trust.
    """

    def __init__(self, shard_dir: Union[str, Path]):
        """Index the shards found in ``shard_dir``.

        Args:
            shard_dir: Directory containing the ``*.pt`` shard files.

        Raises:
            ValueError: If the directory holds no ``*.pt`` file, or if the
                first shard is empty (the shard size cannot be determined).
        """
        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # The first shard defines the nominal shard size.
        self.shard_size = len(torch.load(self.shard_paths[0]))
        if self.shard_size == 0:
            raise ValueError(f"First shard is empty: {self.shard_paths[0]}")

        # Only the last shard may be shorter; with a single shard it is the
        # first one, so there is no need to read the file a second time.
        if len(self.shard_paths) == 1:
            last_shard_len = self.shard_size
        else:
            last_shard_len = len(torch.load(self.shard_paths[-1]))
        self.length = (len(self.shard_paths) - 1) * self.shard_size + last_shard_len

        # Simple cache to avoid re-loading the same shard consecutively
        self._cache = {}
        self._cached_shard_index = -1
        print(f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards.")

    def __len__(self) -> int:
        """Return the total number of samples across all shards."""
        return self.length

    def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
        """Return the ``(lr, hr)`` pair stored at position ``idx``.

        Args:
            idx: Sample position; negative values count from the end.

        Returns:
            The ``'lr'`` and ``'hr'`` entries of the sample, in that order.
            Their concrete type is whatever the shard files store (annotated
            as arrays by the original author -- TODO confirm).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Normalise negative indices against the true length so they do not
        # land on a slot beyond the end of a shorter last shard.
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(
                f"Index {idx} is out of range for dataset of length {self.length}"
            )

        shard_index, index_in_shard = divmod(idx, self.shard_size)

        if shard_index != self._cached_shard_index:
            self._cache = torch.load(self.shard_paths[shard_index])
            self._cached_shard_index = shard_index

        # Coupled with the TACORGBDataset class that wrote the shards:
        # each item in a shard is a squeezed dictionary with keys lr and hr.
        squeezed_sample = self._cache[index_in_shard]
        return squeezed_sample['lr'], squeezed_sample['hr']


# Step 1: Quantitative Evaluation on the Test Set

### 1.1 Load the Fine-Tuned Model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_data_dir = config.test_dir
checkpoint_path = config.finetune_dir / 'edsr_base/best_model_checkpoint.pt'

# --- Load the Fine-Tuned Model ---
# First, instantiate the base architecture (x2 EDSR-base with pretrained weights).
model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)

# Now, load the state dictionary from our best checkpoint
print(f"Loading best model checkpoint from: {checkpoint_path}")
try:
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # This logic handles cases where the model was saved with DataParallel wrapper,
    # which prefixes every parameter name with `module.`.
    model_state_dict = checkpoint['model_state_dict']
    first_key = next(iter(model_state_dict), "")
    if first_key.startswith('module.'):
        print("DataParallel wrapper detected. Unwrapping state dictionary...")
        # removeprefix only strips keys that really carry the prefix, instead of
        # blindly cutting the first seven characters off every key.
        model_state_dict = OrderedDict(
            (key.removeprefix('module.'), value)
            for key, value in model_state_dict.items()
        )
    model.load_state_dict(model_state_dict)

    print(f"Successfully loaded model from epoch {checkpoint['epoch']} with best validation PSNR {checkpoint['best_metric']:.4f}")

except FileNotFoundError:
    print(f"❌ ERROR: Checkpoint file not found at '{checkpoint_path}'. Please verify the path.")
    # Stop here: continuing would evaluate the plain pretrained weights and
    # report them as if they were the fine-tuned model.
    raise
except Exception as e:
    print(f"❌ ERROR: An error occurred while loading the checkpoint: {e}")
    raise

model.to(device)
model.eval() # Set the model to evaluation mode

### 1.2 Prepare the Test DataLoader

In [None]:
try:
    test_dataset = PreNormalizedDataset(test_data_dir)
    # For evaluation, batch size is typically 1 to measure per-image metrics.
    # The default (non-shuffled) order also reads the shards sequentially,
    # which lets the dataset reuse its single cached shard.
    test_dataloader = DataLoader(test_dataset, batch_size=1)
    print(f"Loaded test dataset with {len(test_dataset)} samples.")
except Exception as e:
    print(f"❌ ERROR: Failed to load the test dataset from '{test_data_dir}': {e}")
    # Re-raise so the evaluation cell cannot run against a missing dataloader
    # or a stale one left over from an earlier notebook run.
    raise

### 1.3 Run Evaluation Loop and Report

In [None]:
# EvalPrediction is the container super_image's own Trainer hands to
# compute_metrics; it exposes the `.predictions` and `.labels` attributes that
# compute_metrics reads, which a plain dict does not provide.
from super_image.trainer_utils import EvalPrediction

# Initialize meters to accumulate the scores
test_psnr = AverageMeter()
test_ssim = AverageMeter()

# Set the scale from the model's configuration
scale = model.config.scale

with torch.no_grad(): # Disable gradient calculations for efficiency
    for inputs, labels in tqdm(test_dataloader, desc="Evaluating on Test Set"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        preds = model(inputs)

        metrics = compute_metrics(
            EvalPrediction(predictions=preds, labels=labels), scale=scale
        )

        # Weight each batch by its size so the running average is per image.
        batch_size = len(inputs)
        test_psnr.update(metrics['psnr'], batch_size)
        test_ssim.update(metrics['ssim'], batch_size)

# --- Report Final Scores ---
print("\n--- Final Test Set Evaluation Complete ---")
print(f"Total images evaluated: {len(test_dataset)}")
print(f"Final Test Set PSNR: {test_psnr.avg:.4f}")
print(f"Final Test Set SSIM: {test_ssim.avg:.4f}")