# Preparing the Environment


## Check out Colab Instance Region

In [None]:
# Ask ipinfo.io where this runtime's public IP is located (the JSON reply includes city and region).
!curl ipinfo.io

{
  "ip": "34.82.176.201",
  "hostname": "201.176.82.34.bc.googleusercontent.com",
  "city": "The Dalles",
  "region": "Oregon",
  "country": "US",
  "loc": "45.5946,-121.1787",
  "org": "AS396982 Google LLC",
  "postal": "97058",
  "timezone": "America/Los_Angeles",
  "readme": "https://ipinfo.io/missingauth"
}

## Install Essential Packages

The `super-image` library is built on top of **Hugging Face**'s `transformers` and `datasets` libraries.

In [None]:
!pip install super-image datasets transformers -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.9/95.9 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m121.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m96.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m61.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, lr_scheduler


from super_image import Trainer, PreTrainedModel, TrainingArguments
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter

In [None]:
import os
from pathlib import Path
from typing import List, Union, Dict

from tqdm.auto import tqdm
import numpy as np

## Paths and Directories

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; ROOT_PATH defined later points into this mount.
drive.mount('/content/drive')

Mounted at /content/drive
'Colab Notebooks'   Education	        IFTTT   Neuromarketing-EHIA
 datasets	   'Google AI Studio'   manga


In [None]:
def _ensure_dir(directory: Path) -> Path:
    """Create ``directory`` (parents included) when it is missing and return it unchanged."""
    os.makedirs(directory, exist_ok=True)
    return directory


# Single root folder on the mounted Drive; every other directory hangs off it.
ROOT_PATH = Path('/content/drive/MyDrive/datasets/sen2venus')

# Raw .taco files, one path per selected subset name.
TACO_RAW_DIR = _ensure_dir(ROOT_PATH / 'TACO_raw_data')
print(f"Data will be saved to: {TACO_RAW_DIR}")

SELECTED_SUBSETS = ["SUDOUE-4", "SUDOUE-5", "SUDOUE-6"]
TACO_FILE_PATHS = [TACO_RAW_DIR / (subset + ".taco") for subset in SELECTED_SUBSETS]

# Normalized shards, split into train / val / test sub-folders.
NORMALIZED_SETS_DIR = _ensure_dir(ROOT_PATH / 'normalized_sets')
print(f"Normalaized datest will be saved to:\n\t {NORMALIZED_SETS_DIR}")

TRAIN_SAVE_DIR = _ensure_dir(NORMALIZED_SETS_DIR / 'train')
print(f"Train data will be saved to:\n\t {TRAIN_SAVE_DIR}")

VAL_SAVE_DIR = _ensure_dir(NORMALIZED_SETS_DIR / 'val')
print(f"Validation data will be saved to:\n\t {VAL_SAVE_DIR}")

TEST_SAVE_DIR = _ensure_dir(NORMALIZED_SETS_DIR / 'test')
print(f"Test data will be saved to:\n\t {TEST_SAVE_DIR}")

# Checkpoints and logs of the fine-tuning run; needed both to resume training
# and to store the final model.
FINETUNR_SAVE_DIR = _ensure_dir(ROOT_PATH / 'edsr_finetune')
print(f"Finetuning data including checkpoints and logs will be saved to:\n\t{FINETUNR_SAVE_DIR}")

Data will be saved to: /content/drive/MyDrive/datasets/sen2venus/TACO_raw_data
Normalaized datest will be saved to:
	 /content/drive/MyDrive/datasets/sen2venus/normalized_sets
Train data will be saved to:
	 /content/drive/MyDrive/datasets/sen2venus/normalized_sets/train
Validation data will be saved to:
	 /content/drive/MyDrive/datasets/sen2venus/normalized_sets/val
Test data will be saved to:
	 /content/drive/MyDrive/datasets/sen2venus/normalized_sets/test
Finetuning data including checkpoints and logs will be saved to:
	/content/drive/MyDrive/datasets/sen2venus/edsr_finetune


# Step 1: Define PyTorch Datasets & Dataloaders

In [None]:
class PreNormalizedDataset(Dataset):
    """
    Map-style dataset over pre-processed samples stored in sharded ``*.pt`` files.

    Each shard is loaded with ``torch.load`` and indexed by position; every item
    is expected to be a dict with the keys ``'lr'`` and ``'hr'``. Shards are
    ordered by sorted file name, and all shards except the last one are assumed
    to hold the same number of samples, because a global index is mapped to a
    shard by integer division with the size of the first shard.

    Args:
        shard_dir: Directory that contains the ``*.pt`` shard files.
        max_cached_shards: Number of loaded shards kept in memory
            (least-recently-used eviction). The default of 1 keeps only the
            most recently used shard. With shuffled access a larger value
            avoids re-reading shards from disk, at the cost of memory.

    Raises:
        ValueError: If ``max_cached_shards`` is smaller than 1 or no shard
            file is found in ``shard_dir``.
    """
    def __init__(self, shard_dir: Union[str, Path], max_cached_shards: int = 1):
        if max_cached_shards < 1:
            raise ValueError(f"max_cached_shards must be >= 1, got {max_cached_shards}")

        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))
        self.max_cached_shards = max_cached_shards

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # To calculate length, we check the size of the first shard and assume
        # all but the last are the same size.
        first_shard = torch.load(self.shard_paths[0])
        self.shard_size = len(first_shard)
        if len(self.shard_paths) == 1:
            # The first shard is also the last one; no need to read it twice.
            last_shard_size = self.shard_size
        else:
            last_shard_size = len(torch.load(self.shard_paths[-1]))
        self.length = (len(self.shard_paths) - 1) * self.shard_size + last_shard_size

        # shard index -> loaded shard, ordered from least to most recently used.
        # Seeded with the first shard, which had to be loaded anyway.
        self._cache = {0: first_shard}
        print(f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards.")

    def __len__(self):
        return self.length

    def _get_shard(self, shard_index):
        """Return the shard at ``shard_index``, loading it from disk on a cache miss."""
        if shard_index in self._cache:
            # Re-inserted below so it becomes the most recently used entry.
            shard = self._cache.pop(shard_index)
        else:
            # Evict before loading so at most ``max_cached_shards`` shards are alive.
            while len(self._cache) >= self.max_cached_shards:
                self._cache.pop(next(iter(self._cache)))
            shard = torch.load(self.shard_paths[shard_index])
        self._cache[shard_index] = shard
        return shard

    def __getitem__(self, idx) -> tuple:
        """
        Return the ``(lr, hr)`` pair stored at global position ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        position = idx + self.length if idx < 0 else idx
        if not 0 <= position < self.length:
            raise IndexError(f"index {idx} is out of range for a dataset of length {self.length}")

        shard_index = position // self.shard_size
        index_in_shard = position % self.shard_size

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = self._get_shard(shard_index)[index_in_shard]
        return squeezed_sample['lr'], squeezed_sample['hr']


## Dataloader Instantiation

In [None]:
# Training split served in shuffled batches of 16.
# NOTE(review): PreNormalizedDataset re-reads a shard from disk whenever the requested
# index falls outside the shard it currently holds, so shuffled access is likely to
# trigger frequent shard reloads — confirm throughput is acceptable.
train_dataset = PreNormalizedDataset(TRAIN_SAVE_DIR)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Initialized dataset from /content/drive/MyDrive/datasets/sen2venus/normalized_sets/train with 4436 samples across 5 shards.


In [None]:
print("---Verifying  dataset output format:")
# PreNormalizedDataset.__getitem__ returns an (lr, hr) tuple, not a dict,
# so the sample is unpacked positionally instead of being read by key.
sample_output = train_dataset[0]
sample_lr, sample_hr = sample_output
print("LR shape:", sample_lr.shape)
print("HR shape:", sample_hr.shape)


print("---Verifying  batch shape:")

# The default collate function stacks the tuple elements into one LR and one HR batch.
lr_batch, hr_batch = next(iter(train_loader))

print("Verification successful!")
print(f"LR batch shape: {lr_batch.shape}")
print(f"HR batch shape: {hr_batch.shape}")
print(f"LR batch dtype: {lr_batch.dtype}")
print(f"HR batch dtype: {hr_batch.dtype}")

---Verifying  dataset output format:
dict_keys(['pixel_values', 'labels'])
LR shape: torch.Size([3, 128, 128])
HR shape: torch.Size([3, 256, 256])
---Verifying  batch shape:
Verification successful!
LR batch shape: torch.Size([16, 3, 128, 128])
HR batch shape: torch.Size([16, 3, 256, 256])
LR batch dtype: torch.float32
HR batch dtype: torch.float32


In [None]:
# Validation split; iterated in a fixed order (no shuffling) in batches of 16.
val_dataset = PreNormalizedDataset(VAL_SAVE_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/drive/MyDrive/datasets/sen2venus/normalized_sets/val with 554 samples across 1 shards.


In [None]:
# Held-out test split; iterated in a fixed order (no shuffling) in batches of 16.
test_dataset = PreNormalizedDataset(TEST_SAVE_DIR)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/drive/MyDrive/datasets/sen2venus/normalized_sets/test with 556 samples across 1 shards.
Loaded 556 test samples.


# Step 2: Load the Pre-trained EDSR Model

**Objectives**:



1.   Loading a well-known, **pre-trained** architecture (edsr-base) specifically configured for **2x super-resolution**.
2.   Confirming that the model accepts data batches and produces outputs of the correct shape ([16, 3, 256, 256]).

## 2.1 Instantiate and Inspect the pre-trained EDSR model

In [None]:
# LR tiles are 128x128 and HR tiles are 256x256, so the upscaling factor is 2.
# `from_pretrained` downloads the model configuration and the weights for that factor.
model_id = 'eugenesiow/edsr-base'
scale = 2
model = EdsrModel.from_pretrained(model_id, scale=scale)

# Print the layer layout of the loaded network.
print("Model architecture loaded successfully:", model, sep="\n")

https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_2x.pt
Model architecture loaded successfully:
DataParallel(
  (module): EdsrModel(
    (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (head): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (body): Sequential(
      (0): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResBlock(
        (body): Sequential(
       

## 2.2 Sanity Check: Pass one batch of data through the model

In [None]:
# A crucial test to ensure the input/output dimensions are compatible.
print("\nPerforming a forward pass sanity check...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set to evaluation mode for this check

with torch.no_grad():
    # Get a single batch from our dataloader
    lr_batch, hr_batch = next(iter(train_loader))

    # Move the batch to the same device as the model
    lr_batch = lr_batch.to(device)

    # Perform a forward pass
    predictions = model(lr_batch)

# Check the shapes BEFORE reporting success; only shapes are compared, so the HR
# batch does not need to be moved to the model's device.
assert predictions.shape == hr_batch.shape, (
    "Model output shape does not match target HR shape! "
    f"predictions: {tuple(predictions.shape)}, target: {tuple(hr_batch.shape)}"
)

print("Sanity check successful!")
print(f"Running on device: {device}")
print(f"Model Input Shape (LR): {lr_batch.shape}")
print(f"Model Output Shape (Predictions): {predictions.shape}")
print(f"Target Shape (HR): {hr_batch.shape}")
print("Output shape matches target shape. Ready for training.")


Performing a forward pass sanity check...
Sanity check successful!
Running on device: cuda
Model Input Shape (LR): torch.Size([16, 3, 128, 128])
Model Output Shape (Predictions): torch.Size([16, 3, 256, 256])
Target Shape (HR): torch.Size([16, 3, 256, 256])
Output shape matches target shape. Ready for training.


# Step 3: Configure and Launch the Trainer

## Custom Trainer with Checkpoints

In [None]:


class CustomResumableTrainer(Trainer):
    """
    ``Trainer`` subclass with a training loop that can resume from a checkpoint.

    The checkpoint file ``training_checkpoint.pt`` lives in ``args.output_dir``
    and bundles the model, optimizer and scheduler state with the epoch number
    and the best metric/epoch recorded so far.

    NOTE(review): a checkpoint is only written for epochs that set a new best
    result, so resuming continues after the last *best* epoch, not after the
    last completed one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built by _create_optimizer_and_scheduler() when training starts.
        self.optimizer = None
        self.scheduler = None

    def save_checkpoint(self, epoch):
        """
        Save the complete training state reached at ``epoch``.

        Does nothing (apart from logging a warning) when the optimizer or the
        scheduler has not been created yet.
        """
        output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, 'training_checkpoint.pt')

        if self.optimizer is None or self.scheduler is None:
            logger.warning("Optimizer/Scheduler not initialized. Cannot save full checkpoint.")
            return

        state = {
            'epoch': epoch,
            'best_metric': self.best_metric,
            'best_epoch': self.best_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }
        # Write to a temporary file first and swap it in afterwards, so an
        # interrupted save cannot leave a truncated checkpoint behind.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        logger.info(f"Saved complete training checkpoint to {checkpoint_path}")

    def load_checkpoint(self):
        """
        Restore the training state from the checkpoint file, if there is one.

        Returns:
            The epoch to start from: the saved epoch + 1 after a successful
            load, otherwise 0 (no checkpoint found, or loading failed).
        """
        checkpoint_path = os.path.join(self.args.output_dir, 'training_checkpoint.pt')
        start_epoch = 0
        if not os.path.exists(checkpoint_path):
            logger.warning("No training checkpoint found. Starting from scratch.")
            return start_epoch
        try:
            state = torch.load(checkpoint_path, map_location=self.args.device)
            self.model.load_state_dict(state['model_state_dict'])
            if self.optimizer is None:
                self._create_optimizer_and_scheduler()
            self.optimizer.load_state_dict(state['optimizer_state_dict'])
            self.scheduler.load_state_dict(state['scheduler_state_dict'])
            self.best_metric = state.get('best_metric', 0.0)
            # Checkpoints are written for best epochs only, so files without a
            # 'best_epoch' entry were saved at their best epoch.
            self.best_epoch = state.get('best_epoch', state['epoch'])
            start_epoch = state['epoch'] + 1
            logger.info(f"Successfully loaded checkpoint. Resuming from epoch {start_epoch}.")
        except Exception as e:
            # Deliberate best effort: a broken checkpoint must not block training.
            logger.error(f"Failed to load checkpoint: {e}. Starting from scratch.")
            start_epoch = 0
        return start_epoch

    def _create_optimizer_and_scheduler(self):
        """Helper function to initialize optimizer and scheduler."""
        self.optimizer = Adam(self.model.parameters(), lr=self.args.learning_rate)
        # Number of batches in 200 epochs; clamped to 1 because StepLR needs a
        # positive step size.
        step_size = max(1, int(len(self.train_dataset) / self.args.train_batch_size * 200))
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=self.args.gamma)

    def train(self, **kwargs):
        """
        Complete, resumable training loop.

        Each epoch sets the learning rate, runs one pass over the training
        dataloader with an L1 loss, calls the base class ``eval(epoch)`` and
        saves a full checkpoint when that epoch became the best one.
        ``kwargs`` is accepted for interface compatibility and is not used.
        """
        self._create_optimizer_and_scheduler()
        start_epoch = self.load_checkpoint()
        train_dataloader = self.get_train_dataloader()

        criterion = torch.nn.L1Loss()
        # The learning rate drops by 10x every `decay_every` epochs (80% of the
        # run). Clamped to 1 so very short runs do not divide by zero.
        decay_every = max(1, int(self.args.num_train_epochs * 0.8))
        # assumes the train dataloader drops the final partial batch — TODO confirm
        samples_per_epoch = len(self.train_dataset) - len(self.train_dataset) % self.args.train_batch_size

        for epoch in range(start_epoch, self.args.num_train_epochs):
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.args.learning_rate * (0.1 ** (epoch // decay_every))

            self.model.train()
            epoch_losses = AverageMeter()
            with tqdm(total=samples_per_epoch) as t:
                t.set_description(f'epoch: {epoch}/{self.args.num_train_epochs - 1}')
                for data in train_dataloader:
                    inputs, labels = data
                    inputs = inputs.to(self.args.device)
                    labels = labels.to(self.args.device)
                    preds = self.model(inputs)
                    loss = criterion(preds, labels)
                    epoch_losses.update(loss.item(), len(inputs))
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()
                    t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                    t.update(len(inputs))

            # --- Integrated Eval and Saving ---
            # Call original eval logic, which updates self.best_epoch and self.best_metric
            super().eval(epoch)
            # If the eval run was the best so far, save full state
            if self.best_epoch == epoch:
                print(f"New best model found at epoch {epoch}. Saving full checkpoint.")
                self.save_checkpoint(epoch)

## Trainer Config

In [None]:
# Training configuration handed to CustomResumableTrainer.
# The custom train loop reads output_dir, num_train_epochs, learning_rate,
# train_batch_size, gamma and device from these arguments.
# NOTE(review): that loop contains no autocast/GradScaler call, so `fp16`
# presumably has no effect on it — confirm against the base Trainer.
training_args = TrainingArguments(
    output_dir=FINETUNR_SAVE_DIR,

    # --- Core parameters that are fully functional ---
    num_train_epochs=15,          # Controls training length and the hardcoded LR decay
    learning_rate=1e-4,           # Sets the initial learning rate
    per_device_train_batch_size=16, # Controls training batch size

    # --- Technical parameters that are functional ---
    seed=42,                      # For reproducibility
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=2,
)

## Start Training

In [None]:
# Bundle the loaded model, the training arguments and the train/validation datasets
# into the resumable trainer.
trainer = CustomResumableTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Launch fine-tuning; train() first tries to resume from an existing checkpoint.
print("Starting model fine-tuning ")
trainer.train()

print(f"\nTraining complete. The best model was found at epoch {trainer.best_epoch} "
      f"with a PSNR of {trainer.best_metric:.2f}.")
# `output_dir` is not defined at notebook level (NameError after the whole run);
# the directory is read from the training arguments instead.
print(f"The best model has been saved in: {training_args.output_dir}")

Starting model fine-tuning 


  0%|          | 0/4432 [00:00<?, ?it/s]