# Preparing the Enviorment


## Check out Colab Instance Region

In [1]:
!curl ipinfo.io

{
  "ip": "34.143.181.73",
  "hostname": "73.181.143.34.bc.googleusercontent.com",
  "city": "Singapore",
  "region": "Singapore",
  "country": "SG",
  "loc": "1.2897,103.8501",
  "org": "AS396982 Google LLC",
  "postal": "018989",
  "timezone": "Asia/Singapore",
  "readme": "https://ipinfo.io/missingauth"
}

In [None]:
!nvidia-smi

Sat Aug  9 18:48:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   77C    P0             30W /   70W |    4382MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Intsall Essential Packaeges

`super-image` library is built on top of **Hugging Face**'s `transformers` and `datasets`

In [2]:
!pip install super-image -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.9/95.9 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m89.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m71.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m841.7 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Imports

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, lr_scheduler


from super_image import Trainer, PreTrainedModel, TrainingArguments
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter

In [4]:
import os
from pathlib import Path
from typing import List, Union, Dict
from dataclasses import dataclass

from tqdm.auto import tqdm
import numpy as np

## Paths and Directories

In [5]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
BASE_DIR = Path('/content/drive/MyDrive/datasets/sen2venus')

TACO_RAW_DIR = BASE_DIR / 'TACO_raw_data'
os.makedirs(TACO_RAW_DIR, exist_ok=True)
print(f"Data will be saved to: {TACO_RAW_DIR}")

SELECTED_SUBSETS = [
    "SUDOUE-4",
    "SUDOUE-5",
    "SUDOUE-6"
]
TACO_FILE_PATHS = [TACO_RAW_DIR / f"{site_name}.taco" for site_name in SELECTED_SUBSETS]


NORMALIZED_SETS_DIR = BASE_DIR / 'normalized_sets'
os.makedirs(NORMALIZED_SETS_DIR, exist_ok=True)
print(f"Normalaized sets are retrieved from:\n\t {NORMALIZED_SETS_DIR}")

# TRAIN_SAVE_DIR = NORMALIZED_SETS_DIR / 'train'
# os.makedirs(TRAIN_SAVE_DIR, exist_ok=True)
# print(f"Train data will be saved to:\n\t {TRAIN_SAVE_DIR}")

# VAL_SAVE_DIR = NORMALIZED_SETS_DIR / 'val'
# os.makedirs(VAL_SAVE_DIR, exist_ok=True)
# print(f"Validation data will be saved to:\n\t {VAL_SAVE_DIR}")

# TEST_SAVE_DIR = NORMALIZED_SETS_DIR / 'test'
# os.makedirs(TEST_SAVE_DIR, exist_ok=True)
# print(f"Test data will be saved to:\n\t {TEST_SAVE_DIR}")

# essential for resuming training and saving final model.
FINETUNR_SAVE_DIR = BASE_DIR / 'edsr_finetune'
os.makedirs(FINETUNR_SAVE_DIR, exist_ok=True)
print(f"Finetuning data including checkpoints and logs will be saved to:\n\t{FINETUNR_SAVE_DIR}")

Data will be saved to: /content/drive/MyDrive/datasets/sen2venus/TACO_raw_data
Normalaized sets are retrieved from:
	 /content/drive/MyDrive/datasets/sen2venus/normalized_sets
Finetuning data including checkpoints and logs will be saved to:
	/content/drive/MyDrive/datasets/sen2venus/edsr_finetune


In [7]:
VM_DIR = '/content/TACO_Normalized'

print("--- Starting Data Transfer ---")
# Create the local directory if it doesn't exist
if not os.path.exists(VM_DIR):
    print(f"Copying data from {NORMALIZED_SETS_DIR} to {VM_DIR}...")
    # Use the -q flag for a quiet copy to avoid flooding your output
    !cp -r "{NORMALIZED_SETS_DIR}" "{VM_DIR}"
    print("Data transfer complete.")
else:
    print("Data already exists locally.")


VM_DIR = Path(VM_DIR)
TRAIN_VM_DIR = VM_DIR / 'train'
VAL_VM_DIR = VM_DIR / 'val'
TEST_VM_DIR = VM_DIR / 'test'

--- Starting Data Transfer ---
Copying data from /content/drive/MyDrive/datasets/sen2venus/normalized_sets to /content/TACO_Normalized...
Data transfer complete.


# Step 1: Define PyTorch Datasets & Dataloaders

In [8]:
class PreNormalizedDataset(Dataset):
    """
    Efficiently reads pre-processed, sharded tensor files from disk.
    """
    def __init__(self, shard_dir: Union[str, Path]):
        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # To calculate length, we check the size of the first shard and assume
        # all but the last are the same size.
        first_shard = torch.load(self.shard_paths[0])
        self.shard_size = len(first_shard)
        last_shard = torch.load(self.shard_paths[-1])
        self.length = (len(self.shard_paths) - 1) * self.shard_size + len(last_shard)

        # Simple cache to avoid re-loading the same shard consecutively
        self._cache = {}
        self._cached_shard_index = -1
        print(f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards.")


    def __len__(self):
        return self.length

    def __getitem__(self, idx)->Dict[str, np.ndarray]:
        shard_index = idx // self.shard_size
        index_in_shard = idx % self.shard_size

        if shard_index != self._cached_shard_index:
            self._cache = torch.load(self.shard_paths[shard_index])
            self._cached_shard_index = shard_index

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = self._cache[index_in_shard]
        return squeezed_sample['lr'], squeezed_sample['hr']


## Dataloader Instantiation

In [9]:
train_dataset = PreNormalizedDataset(TRAIN_VM_DIR)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Initialized dataset from /content/TACO_Normalized/train with 4436 samples across 5 shards.


In [None]:
print("---Verifying  batch shape:")

lr_batch, hr_batch = next(iter(train_loader))

print(f"Verification successful!")
print(f"LR batch shape: {lr_batch.shape}")
print(f"HR batch shape: {hr_batch.shape}")
print(f"LR batch dtype: {lr_batch.dtype}")
print(f"HR batch dtype: {hr_batch.dtype}")

---Verifying  dataset output format:
dict_keys(['pixel_values', 'labels'])
LR shape: torch.Size([3, 128, 128])
HR shape: torch.Size([3, 256, 256])
---Verifying  batch shape:
Verification successful!
LR batch shape: torch.Size([16, 3, 128, 128])
HR batch shape: torch.Size([16, 3, 256, 256])
LR batch dtype: torch.float32
HR batch dtype: torch.float32


In [10]:
val_dataset = PreNormalizedDataset(VAL_VM_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/TACO_Normalized/val with 554 samples across 1 shards.


In [None]:
test_dataset = PreNormalizedDataset(TEST_SAVE_DIR)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/drive/MyDrive/datasets/sen2venus/normalized_sets/test with 556 samples across 1 shards.
Loaded 556 test samples.


# Step 2: Load the Pre-trained EDSR Model

**Objectives**:



1.   Loading a well-known, **pre-trained** architecture (edsr-base) specifically configured for **2x super-resolution**.
2.   Confirming that the model accepts data batches and produces outputs of the correct shape ([16, 3, 256, 256]).

## 2.1 Instantiate and Inspect the pre-trained EDSR model

In [11]:
# The 'from_pretrained' method downloads the model configuration and weights.
# We must specify our desired scale factor.
# (LR: 128x128, HR: 256x256), -> scale is 2.
scale = 2
model_id = 'eugenesiow/edsr-base'
model = EdsrModel.from_pretrained(model_id, scale=scale)

# Inspect the model architecture
print("Model architecture loaded successfully:")
print(model)

config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

pytorch_model_2x.pt:   0%|          | 0.00/5.51M [00:00<?, ?B/s]

https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_2x.pt
Model architecture loaded successfully:
DataParallel(
  (module): EdsrModel(
    (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (head): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (body): Sequential(
      (0): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResBlock(
        (body): Sequential(
       

## 2.2 Sanity Check: Pass one batch of data through the model

In [None]:
# a crucial test to ensure the input/output dimensions are compatible.
print("\nPerforming a forward pass sanity check...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set to evaluation mode for this check

with torch.no_grad():
    # Get a single batch from our dataloader
    lr_batch, hr_batch = next(iter(train_loader))

    # Move the batch to the same device as the model
    lr_batch = lr_batch.to(device)

    # Perform a forward pass
    predictions = model(lr_batch)

    print(f"Sanity check successful!")
    print(f"Running on device: {device}")
    print(f"Model Input Shape (LR): {lr_batch.shape}")
    print(f"Model Output Shape (Predictions): {predictions.shape}")
    print(f"Target Shape (HR): {hr_batch.shape}")

# Compare output shape with the target High-Resolution shape
assert predictions.shape == hr_batch.shape, "Model output shape does not match target HR shape!"
print("Output shape matches target shape. Ready for training.")


Performing a forward pass sanity check...
Sanity check successful!
Running on device: cuda
Model Input Shape (LR): torch.Size([16, 3, 128, 128])
Model Output Shape (Predictions): torch.Size([16, 3, 256, 256])
Target Shape (HR): torch.Size([16, 3, 256, 256])
Output shape matches target shape. Ready for training.


# Step 3: Configure and Launch the Trainer

## Custum Trainer with Checkpoints

In [12]:
import copy
from super_image.file_utils import WEIGHTS_NAME, WEIGHTS_NAME_SCALE


class CustomResumableTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.optimizer = None
        self.scheduler = None

    def save_checkpoint(self, epoch, global_step, is_best=False):
        output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = 'best_model_checkpoint.pt' if is_best else 'latest_step_checkpoint.pt'
        checkpoint_path = os.path.join(output_dir, filename)
        state = {
            'epoch': epoch, 'global_step': global_step, 'best_metric': self.best_metric,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }
        torch.save(state, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path} (step {global_step})")

    def load_checkpoint(self):
        output_dir = self.args.output_dir
        checkpoint_path = os.path.join(output_dir, 'latest_step_checkpoint.pt')
        start_epoch, global_step = 0, 0
        if not os.path.exists(checkpoint_path):
            logger.warning("No checkpoint found. Starting from scratch.")
            return start_epoch, global_step
        try:
            state = torch.load(checkpoint_path, map_location=self.args.device)
            self.model.load_state_dict(state['model_state_dict'])
            if self.optimizer is None: self._create_optimizer_and_scheduler()
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.scheduler.load_state_dict(state['scheduler_state_dict'])
            self.best_metric = state.get('best_metric', 0.0)
            start_epoch = state['epoch']
            global_step = state['global_step']
            logger.info(f"Successfully loaded checkpoint. Resuming from epoch {start_epoch}, step {global_step}.")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}. Starting from scratch.")
            start_epoch, global_step = 0, 0
        return start_epoch, global_step

    def save_model(self, output_dir: str = None):
      """
      Overrides the faulty base save_model method to correctly handle
      models wrapped in nn.DataParallel.
      """
      output_dir = output_dir if output_dir is not None else self.args.output_dir
      os.makedirs(output_dir, exist_ok=True)

      # Determine if the model is the raw model or a wrapped one
      model_to_save = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model

      # Now, we can safely call save_pretrained on the actual model
      if isinstance(model_to_save, PreTrainedModel):
          model_to_save.save_pretrained(output_dir)
      else:
          # Fallback for non-PreTrainedModel, though EDSR is one.
          # This part is for full compatibility with the original's logic.
          logger.warning("Saving a model that is not a PreTrainedModel.")
          scale = model_to_save.config.scale
          if scale is not None:
              weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
          else:
              weights_name = WEIGHTS_NAME
          weights = copy.deepcopy(model_to_save.state_dict())
          torch.save(weights, os.path.join(output_dir, weights_name))

    def load_checkpoint(self):
        """Loads the most recent checkpoint available."""
        output_dir = self.args.output_dir
        step_checkpoint_path = os.path.join(output_dir, 'latest_step_checkpoint.pt')
        best_checkpoint_path = os.path.join(output_dir, 'best_model_checkpoint.pt')
        start_epoch = 0

        # Prioritize the very latest step-based checkpoint for resumption
        checkpoint_path = None
        if os.path.exists(step_checkpoint_path):
            checkpoint_path = step_checkpoint_path
        elif os.path.exists(best_checkpoint_path):
            checkpoint_path = best_checkpoint_path

        if checkpoint_path is None:
            logger.warning("No checkpoint found. Starting from scratch.")
            return start_epoch

        try:
            state = torch.load(checkpoint_path, map_location=self.args.device)
            self.model.load_state_dict(state['model_state_dict'])
            if self.optimizer is None: self._create_optimizer_and_scheduler()
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.scheduler.load_state_dict(state['scheduler_state_dict'])
            self.best_metric = state.get('best_metric', 0.0)
            start_epoch = state['epoch'] # Resume from the same epoch
            logger.info(f"Successfully loaded checkpoint '{os.path.basename(checkpoint_path)}'. Resuming from epoch {start_epoch}.")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}. Starting from scratch.")
            start_epoch = 0
        return start_epoch

    def _create_optimizer_and_scheduler(self):
        self.optimizer = Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=999999, gamma=1.0) # Dummy scheduler, we control LR manually

    def train(self, **kwargs):
        """Complete, resumable training loop with warm-up and step-based saving."""
        self._create_optimizer_and_scheduler()
        start_epoch = self.load_checkpoint()
        train_dataloader = self.get_train_dataloader()

        # Correctly initialize global_step based on the loaded epoch
        global_step = start_epoch * len(train_dataloader)

        for epoch in range(start_epoch, self.args.num_train_epochs):
            self.model.train()
            epoch_losses = AverageMeter()
            # Reset the tqdm progress bar for each epoch
            with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch}/{self.args.num_train_epochs - 1}') as t:
                for data in train_dataloader:
                    # --- Learning Rate Scheduling (Warm-up + Decay) ---
                    if global_step < self.args.warmup_steps:
                        # Warm-up phase
                        lr_scale = float(global_step) / float(self.args.warmup_steps) if self.args.warmup_steps > 0 else 1.0
                        new_lr = self.args.learning_rate * lr_scale
                    else:
                        # Post-warm-up decay phase
                        # --- THIS IS THE CORRECTED LINE ---
                        divisor = max(1, int(self.args.num_train_epochs * 0.8))
                        new_lr = self.args.learning_rate * (0.1 ** (epoch // divisor))

                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = new_lr

                    inputs, labels = data
                    inputs, labels = inputs.to(self.args.device), labels.to(self.args.device)
                    preds = self.model(inputs)
                    loss = torch.nn.L1Loss()(preds, labels)
                    epoch_losses.update(loss.item(), len(inputs))
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    global_step += 1
                    t.set_postfix(loss=f'{epoch_losses.avg:.6f}', lr=f'{new_lr:.2e}')
                    t.update(1)

                    # --- Step-Based Checkpointing ---
                    if global_step > 0 and global_step % self.args.save_steps == 0:
                        self.save_checkpoint(epoch, is_best=False)

            # --- Epoch-Based Evaluation and Best Model Saving ---
            self.eval(epoch)


## Trainer Config

In [13]:
@dataclass
class CustomTrainingArguments(TrainingArguments):
    warmup_steps: int = 0
    save_steps: int = 100 # Default to saving every 100 steps


## Dry Run Test

**What this accomplishes:**

*   The training loop will run on only 32 images.
*   The `save_checkpoint` method will be called multiple times (for `is_best=False`).
*   The `eval` loop will run after the tiny epoch.
*   The overridden `save_model` will be called.
*   The entire process should take **less than a minute**, giving you immediate feedback on whether your fix worked.

In [None]:
from torch.utils.data import Subset

print("--- Performing a quick Dry Run to verify the fix ---")

# 1. Create a tiny subset of your training and validation data
#    This will only use the first 32 samples.
dry_run_train_dataset = Subset(train_dataset, range(32))
dry_run_val_dataset = Subset(val_dataset, range(32))

# 2. Create special "dry run" training arguments
dry_run_args = CustomTrainingArguments(
    output_dir=BASE_DIR /'edsr_dry_run_test', # Use a separate test directory
    learning_rate=1e-4,
    num_train_epochs=1,          # We only need one epoch to test the save logic
    per_device_train_batch_size=8, # Use a small batch size
    save_steps=2,                # Save a checkpoint every 2 steps to test this logic
    warmup_steps=1               # Test the warm-up logic
)

# 3. Instantiate your trainer with the dry run data and args
dry_run_trainer = CustomResumableTrainer(
    model=model,  # Use the same model
    args=dry_run_args,
    train_dataset=dry_run_train_dataset,
    eval_dataset=dry_run_val_dataset
)

# 4. Run the training. This should be very fast!
try:
    dry_run_trainer.train()
    print("\n✅ Dry Run Successful! Proceed with full training.")
except Exception as e:
    print(f"\n❌ Dry Run Failed: {e}")

# --- END OF DRY RUN ---```


--- Performing a quick Dry Run to verify the fix ---


Epoch 0/0:   0%|          | 0/4 [00:00<?, ?it/s]

scale:2      eval psnr: 44.31     ssim: 0.9851
best epoch: 0, psnr: 44.310272, ssim: 0.985123

✅ Dry Run Successful! Proceed with full training.


## Start Training

In [14]:
training_args = CustomTrainingArguments(
    output_dir=FINETUNR_SAVE_DIR,

    # --- Core parameters that are fully functional ---
    num_train_epochs=15,          # Controls training length and the hardcoded LR decay
    learning_rate=1e-4,           # Sets the initial learning rate
    per_device_train_batch_size=24, # Controls training batch size

    # --- Technical parameters that are functional ---
    seed=42,                      # For reproducibility
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=2,
    save_steps=100,      # Save a recovery checkpoint every 100 steps
    warmup_steps=500     # Use a 500-step learning rate warm-up
)

In [15]:
trainer = CustomResumableTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Start Training with full confidence in the underlying process.
print("Starting model fine-tuning ")
trainer.train()

print(f"\nTraining complete. The best model was found at epoch {trainer.best_epoch} "
      f"with a PSNR of {trainer.best_metric:.2f}.")
print(f"The best model has been saved in: {FINETUNR_SAVE_DIR}")

Starting model fine-tuning 


Epoch 1/14:   0%|          | 0/185 [00:00<?, ?it/s]

scale:2      eval psnr: 46.36     ssim: 0.9892
best epoch: 1, psnr: 46.361565, ssim: 0.989188


Epoch 2/14:   0%|          | 0/185 [00:00<?, ?it/s]

scale:2      eval psnr: 46.31     ssim: 0.9892


Epoch 3/14:   0%|          | 0/185 [00:00<?, ?it/s]