# SatMAE Finetuning on Google Colab

This notebook sets up and runs SatMAE finetuning on the EuroSAT dataset using Google Colab's free GPU.

## 🚀 Features:
- Automatic environment setup
- EuroSAT dataset download and preprocessing
- SatMAE model finetuning with multispectral data
- GPU acceleration (T4/V100/A100)

**Runtime**: Make sure to select **GPU** runtime (Runtime → Change runtime type → Hardware accelerator → GPU)

## 1. Environment Setup

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to it.
import torch

cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")

if not cuda_ready:
    print("⚠️ No GPU detected. Make sure to enable GPU runtime!")
else:
    # Device index 0 is the only device queried, as in a single-GPU runtime.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU memory: {total_bytes / 1e9:.1f} GB")

In [None]:
# Install required packages
# NOTE(review): every package is installed without a version pin. The
# `#==0.3.2` after `timm` is a shell comment, so pip installs the latest timm,
# not 0.3.2 — confirm that main_finetune.py works with the installed version.
print("Installing required packages...")
!pip install timm  #==0.3.2
!pip install rasterio
!pip install wandb
!pip install tensorboard
!pip install gdown  # For Google Drive downloads

# Import and check versions
# Only the timm and rasterio versions are printed below; the remaining imports
# are not used in this cell and simply fail here if a package is missing.
import timm
import rasterio
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
import requests
import zipfile
from pathlib import Path
import gdown

print(f"✅ timm version: {timm.__version__}")
print(f"✅ rasterio version: {rasterio.__version__}")
print(f"✅ gdown installed for Google Drive downloads")
print("All packages installed successfully!")

## 2. Download SatMAE Code and Data

In [None]:
# Clone the SatMAE repository
!git clone https://github.com/pvinnbru/SatMAE.git
%cd SatMAE

# List repository contents
!ls -la

In [None]:
# Copy data folder from Google Drive
from google.colab import drive
import shutil
import os

# Mount Google Drive
drive.mount('/content/drive')

# Copy data folder from Google Drive to SatMAE/data
source = '/content/drive/MyDrive/data'
target = 'data'

if os.path.exists(target):
    shutil.rmtree(target)

shutil.copytree(source, target)
print(f"✅ Data copied from {source} to {target}")

# Verify contents
print("\n📁 Data structure:")
!ls -la data/

### 📁 Google Drive Setup

**Required Google Drive Structure:**
```
MyDrive/
├── data/                           # Unzipped EuroSAT dataset folder
│   ├── eurosat_ms/                 # Multispectral dataset
│   └── eurosat_rgb/                # RGB dataset  
└── checkpoint/
    └── pretrain-vit-large-e199.pth # Pretrained model checkpoint
```

**Setup Steps:**
1. Unzip your EuroSAT dataset and upload the `data/` folder to Google Drive root
2. Upload checkpoint to `MyDrive/checkpoint/pretrain-vit-large-e199.pth`
3. Run the cells - they will copy files to the local workspace

In [None]:
# Load pretrained checkpoint from Google Drive
import os
import shutil

# Define paths
drive_checkpoint_path = '/content/drive/MyDrive/checkpoint/pretrain-vit-large-e199.pth'
local_checkpoint_dir = 'checkpoints'
local_checkpoint_path = 'checkpoints/pretrain-vit-large-e199.pth'

print("🔧 Loading pretrained checkpoint from Google Drive...")
print(f"Source: {drive_checkpoint_path}")
print(f"Target: {local_checkpoint_path}")

# Create checkpoints directory
os.makedirs(local_checkpoint_dir, exist_ok=True)
print(f"📂 Created directory: {local_checkpoint_dir}/")

# Check if checkpoint exists in Google Drive
if os.path.exists(drive_checkpoint_path):
    drive_size = os.path.getsize(drive_checkpoint_path)
    print(f"✅ Found checkpoint in Google Drive")
    print(f"📊 File size: {drive_size / 1e6:.1f} MB")

    # Copy checkpoint to local directory. Only filesystem errors are reported
    # as a failed copy; anything else is a programming error and should surface.
    try:
        shutil.copy2(drive_checkpoint_path, local_checkpoint_path)
    except OSError as e:
        print(f"❌ Copy failed: {e}")
    else:
        print(f"✅ Checkpoint copied successfully!")
        print(f"📁 Available at: {local_checkpoint_path}")

        # Verify the file by size, not just existence: after a successful
        # copy2 the file always exists, but a truncated copy would have a
        # different size.
        copied_ok = (
            os.path.exists(local_checkpoint_path)
            and os.path.getsize(local_checkpoint_path) == drive_size
        )
        if copied_ok:
            print(f"✅ Verification successful - checkpoint ready for training")
        else:
            print(f"❌ Verification failed - local copy is missing or its size differs from the Drive file")

else:
    print(f"❌ Checkpoint not found at: {drive_checkpoint_path}")
    print("Please ensure you have uploaded the checkpoint to your Google Drive")
    print("\nTo fix this:")
    print("1. Go to your Google Drive")
    print("2. Create a folder called 'checkpoint' in the root directory")
    print("3. Upload 'pretrain-vit-large-e199.pth' to MyDrive/checkpoint/")
    print("4. Run this cell again")

    # The data entry is a folder (MyDrive/data is copied with copytree by the
    # data cell), not a zip archive.
    print(f"\nExpected Google Drive structure:")
    print(f"  MyDrive/")
    print(f"  ├── data/")
    print(f"  └── checkpoint/")
    print(f"      └── pretrain-vit-large-e199.pth")

## 3. Data Preprocessing

### **3.1 Generate txt Files and Training Subsets**

The text files are used for loading the EuroSAT data stored in `data/`. Each line has the form:

```
<path_to_image> <label>
```
For example:
```
/path/to/image1.tif    0
/path/to/image2.tif    3
...
```

The `.txt` files are generated by the script below:


In [None]:
# Create train/val splits and subsets
import os
from glob import glob
import random

def generate_split_txt(root_folder, out_txt_path, split_ratio=0.8, seed=42):
    """
    Creates train/val .txt files from a root image folder organized by class.
    Supports .tif and .jpg files.

    Each immediate, non-hidden sub-directory of ``root_folder`` is one class;
    class indices are assigned in sorted order of the directory names. Every
    output line has the form ``<image path> <class index>``.

    Args:
        root_folder: Folder containing one sub-directory per class.
        out_txt_path: Base output path. ``_train`` / ``_val`` is inserted in
            front of the file extension, e.g. ``splits/ds.txt`` yields
            ``splits/ds_train.txt`` and ``splits/ds_val.txt``. The parent
            directory is created if it does not exist.
        split_ratio: Fraction of the samples written to the train file.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        None. If no image is found a warning is printed and nothing is written.
    """
    # Only real class folders count. Stray files (README, .DS_Store) or hidden
    # folders such as .ipynb_checkpoints would otherwise take a class index and
    # shift the labels of every class sorted after them.
    class_names = sorted(
        entry for entry in os.listdir(root_folder)
        if not entry.startswith(".")
        and os.path.isdir(os.path.join(root_folder, entry))
    )
    class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

    all_samples = []
    for cls in class_names:
        class_dir = os.path.join(root_folder, cls)
        # Directory listings come back in filesystem order, so sort them: the
        # seeded shuffle below is only reproducible if its input order is.
        # Hidden files are skipped, matching what a `*.tif` glob would do.
        image_names = sorted(
            name for name in os.listdir(class_dir)
            if name.endswith((".tif", ".jpg")) and not name.startswith(".")
        )
        for name in image_names:
            all_samples.append(f"{os.path.join(class_dir, name)} {class_to_idx[cls]}")

    if not all_samples:
        print(f"⚠️  No image files found in: {root_folder}")
        return

    # A private generator gives the same permutation as seeding the module-level
    # one, without resetting the global random state for the rest of the notebook.
    rng = random.Random(seed)
    rng.shuffle(all_samples)
    split_idx = int(len(all_samples) * split_ratio)
    train_samples = all_samples[:split_idx]
    val_samples = all_samples[split_idx:]

    # Insert the suffix before the extension only; str.replace(".txt", ...)
    # would also rewrite a ".txt" occurring elsewhere in the path.
    base, ext = os.path.splitext(out_txt_path)
    train_path = f"{base}_train{ext}"
    val_path = f"{base}_val{ext}"

    out_dir = os.path.dirname(train_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Files are written without a trailing newline (one sample per line).
    with open(train_path, "w") as f:
        f.write("\n".join(train_samples))
    with open(val_path, "w") as f:
        f.write("\n".join(val_samples))

    print(f"✅ Created train/val splits for: {root_folder}")
    print(f"   → Train: {len(train_samples)} samples")
    print(f"   → Val:   {len(val_samples)} samples")

# Execution
# Paths are relative to the notebook's working directory, the SatMAE repository
# root: the Drive data was copied to `data/` there, and the verification and
# training cells read the split files from `data_splits/`. The previous `../`
# prefix pointed one level above the repository, where neither folder exists.
os.makedirs("data_splits", exist_ok=True)
generate_split_txt("data/eurosat_ms", "data_splits/eurosat_ms.txt")
generate_split_txt("data/eurosat_rgb", "data_splits/eurosat_rgb.txt")


### 3.2 **Create Training Subsets (10%, 25%, 50%, 75%)**

The goal is to measure how model performance improves as the training data size increases. To ensure fair and meaningful comparisons across runs, the validation set remains fixed.

The following text files were generated above and contain the complete training split:

```
../data_splits/eurosat_ms_train.txt
../data_splits/eurosat_rgb_train.txt
```

To subsample:

* Randomly select a percentage of lines from that file
* Save them into new files like:

  ```
  ../data_splits/eurosat_ms_train_10.txt
  ../data_splits/eurosat_ms_train_25.txt
  ../data_splits/eurosat_ms_train_50.txt
  ```

Do this for both the RGB and the MS training files:

In [None]:
def subsample_txt_file(input_path, output_prefix, percentages=(10, 25, 50), seed=42):
    """
    Writes random subsets of a sample list file, one file per percentage.

    The input lines are shuffled once with a generator seeded by ``seed``; each
    subset is the first ``p`` percent of that shuffled list, so a smaller subset
    is always contained in every larger one.

    Args:
        input_path: Text file with one sample per line.
        output_prefix: Output path prefix; subset ``p`` is written to
            ``f"{output_prefix}_{p}.txt"``.
        percentages: Iterable of percentages (of the input line count) to write.
        seed: Seed for the shuffle.

    Returns:
        None. One line per written file is printed.
    """
    # Imported locally so the function does not depend on another cell having
    # imported `random` first.
    import random

    with open(input_path, 'r') as f:
        # Strip line terminators up front. The input's last line usually has no
        # trailing newline; keeping raw lines and writing them back after the
        # shuffle would glue that line onto the following sample.
        lines = [line for line in f.read().splitlines() if line.strip()]

    # Private generator: same permutation as seeding the module-level one, but
    # the notebook's global random state is left untouched.
    rng = random.Random(seed)
    rng.shuffle(lines)

    for p in percentages:
        count = int(len(lines) * (p / 100))
        subset = lines[:count]
        out_path = f"{output_prefix}_{p}.txt"
        with open(out_path, 'w') as f_out:
            # One sample per line, no trailing newline.
            f_out.write("\n".join(subset))
        print(f"Saved {p}% subset to {out_path} ({count} samples)")


#Execution
subsample_txt_file("../data_splits/eurosat_ms_train.txt", "../data_splits/eurosat_ms_train", percentages=[10, 25, 50, 75])
subsample_txt_file("../data_splits/eurosat_rgb_train.txt", "../data_splits/eurosat_rgb_train", percentages=[10, 25, 50, 75])

print("\n✅ Data preprocessing complete!")
print("\n📁 Generated files:")
!ls -la data_splits/

## 4. Model Training

In [None]:
# Pre-flight check: confirm every input of the training command is in place.
import os

# Paths are relative to the current working directory.
required_files = [
    'main_finetune.py',
    'data_splits/eurosat_ms_train_10.txt',
    'data_splits/eurosat_ms_val.txt',
    'checkpoints/pretrain-vit-large-e199.pth'
]

print("Checking required files:")
all_good = True
for path in required_files:
    found = os.path.exists(path)
    print(f"✅ {path}" if found else f"❌ {path} - MISSING")
    # Stays True only while no file has been reported missing.
    all_good = all_good and found

if not all_good:
    print("\n⚠️ Some files are missing. Please check the previous steps.")
else:
    print("\n🚀 All files ready for training!")

In [None]:
# Run SatMAE finetuning
# Adjust batch_size based on available GPU memory
# NOTE: `torch` is not imported in this cell; it comes from the GPU check cell
# at the top of the notebook, which must have been run first.

# Check GPU memory and adjust batch size accordingly
# total_memory is in bytes; dividing by 1e9 gives decimal GB. Without a CUDA
# device the value is 0, which selects the smallest tier below.
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

# Every tier keeps batch_size * accum_iter == 128, so only the per-step memory
# footprint changes between tiers, not the effective batch size.
# NOTE(review): a nominal 16 GB card reporting at least 15e9 bytes lands in the
# first tier, not the second — confirm that batch_size=16 fits on such GPUs.
if gpu_memory_gb >= 15:  # reported memory of 15 GB or more
    batch_size = 16
    accum_iter = 8
elif gpu_memory_gb >= 11:  # reported memory of 11 GB up to (not including) 15 GB
    batch_size = 8
    accum_iter = 16
else:  # Smaller GPUs (also taken when no CUDA device is available)
    batch_size = 4
    accum_iter = 32

print(f"GPU Memory: {gpu_memory_gb:.1f}GB")
print(f"Using batch_size={batch_size}, accum_iter={accum_iter}")

# Run training command
# The f-string is not raw, so each backslash directly before a newline is a
# Python line continuation: the shell receives the command as a single line
# (surrounded by one leading and one trailing newline).
training_cmd = f"""
python main_finetune.py \
  --model_type group_c \
  --model vit_large_patch16 \
  --dataset_type euro_sat \
  --train_path data_splits/eurosat_ms_train_10.txt \
  --test_path data_splits/eurosat_ms_val.txt \
  --finetune checkpoints/pretrain-vit-large-e199.pth \
  --input_size 96 --patch_size 8 \
  --batch_size {batch_size} --accum_iter {accum_iter} \
  --epochs 30 --blr 2e-4 \
  --weight_decay 0.05 \
  --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
  --dropped_bands 0 9 10 \
  --num_workers 2 \
  --output_dir results/eurosat_ms_10 \
  --log_dir results/eurosat_ms_10
"""

print("Starting training...")
print("This will take approximately 30-60 minutes depending on GPU")
print("Command:")
print(training_cmd)

# Execute training
# IPython expands {training_cmd} and hands the result to the system shell.
!{training_cmd}

## 5. Monitor Training

In [None]:
# Load TensorBoard in Colab
# The log directory is the same path the training command receives as --log_dir.
%load_ext tensorboard
%tensorboard --logdir results/eurosat_ms_10

print("TensorBoard is running above!")
print("You can monitor training progress, loss curves, and metrics.")

In [None]:
# Check training results
import os
import glob

results_dir = "results/eurosat_ms_10"
if os.path.exists(results_dir):
    print("Training results:")
    !ls -la {results_dir}
    
    # Look for log files
    log_files = glob.glob(f"{results_dir}/*.txt")
    if log_files:
        print(f"\nLatest log file: {log_files[-1]}")
        !tail -20 {log_files[-1]}
    
    # Look for checkpoints
    checkpoints = glob.glob(f"{results_dir}/*.pth")
    if checkpoints:
        print(f"\nCheckpoints created: {len(checkpoints)}")
        for cp in checkpoints[-3:]:
            print(f"  {cp}")
else:
    print("No results directory found. Training may not have started yet.")

## 6. Download Results

In [None]:
# Package results for download
import zipfile
import os
from datetime import datetime

def create_results_archive():
    """
    Packs the training outputs into a timestamped zip in the working directory.

    The archive contains every file below ``results/`` (if that folder exists),
    stored under its path relative to the working directory, plus every
    ``*.log`` file found directly in the working directory.

    Returns:
        The archive file name, ``satmae_results_<YYYYmmdd_HHMMSS>.zip``.
    """
    # Imported locally: this cell does not import `glob`, and elsewhere in the
    # notebook the name is bound both to the module (`import glob`) and to the
    # function (`from glob import glob`), so relying on the global name makes
    # the result depend on which cell ran last.
    import glob

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    archive_name = f"satmae_results_{timestamp}.zip"

    with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add results directory
        if os.path.exists('results'):
            for root, dirs, files in os.walk('results'):
                for file in files:
                    file_path = os.path.join(root, file)
                    # Store paths relative to the working directory so the
                    # archive unpacks into a `results/` folder.
                    arcname = os.path.relpath(file_path, '.')
                    zipf.write(file_path, arcname)

        # Add training logs
        for log_file in glob.glob('*.log'):
            zipf.write(log_file)

    return archive_name

# Build the archive only when there is a results folder to package.
if not os.path.exists('results'):
    print("No results to package yet.")
else:
    archive_name = create_results_archive()
    archive_size_mb = os.path.getsize(archive_name) / 1e6
    print(f"✅ Results packaged in: {archive_name}")
    print(f"File size: {archive_size_mb:.1f} MB")
    print("\nYou can download this file using Colab's file panel on the left.")

## 7. Optional: Cleanup and Additional Experiments

In [None]:
# Run experiments with different data percentages
# Each percentage needs its data_splits/eurosat_ms_train_<pct>.txt subset file.
# NOTE: `batch_size` and `accum_iter` are not defined in this cell; they are
# the values chosen by the GPU-memory check in the training cell above, which
# must have been run first.
experiments = [25, 50, 75]

for pct in experiments:
    print(f"\n=== Running experiment with {pct}% of data ===")
    
    # Adjust epochs based on data size
    # The formula yields 25, 20 and 15 epochs for 25%, 50% and 75%; the max()
    # keeps the result from dropping below 10.
    epochs = max(10, 30 - (pct // 25) * 5)  # Fewer epochs for more data
    
    # As this is a non-raw f-string, each backslash directly before a newline
    # is a Python line continuation, so the shell receives a single line.
    cmd = f"""
    python main_finetune.py \
      --model_type group_c \
      --model vit_large_patch16 \
      --dataset_type euro_sat \
      --train_path data_splits/eurosat_ms_train_{pct}.txt \
      --test_path data_splits/eurosat_ms_val.txt \
      --finetune checkpoints/pretrain-vit-large-e199.pth \
      --input_size 96 --patch_size 8 \
      --batch_size {batch_size} --accum_iter {accum_iter} \
      --epochs {epochs} --blr 2e-4 \
      --weight_decay 0.05 \
      --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
      --dropped_bands 0 9 10 \
      --num_workers 2 \
      --output_dir results/eurosat_ms_{pct} \
      --log_dir results/eurosat_ms_{pct}
    """
    
    print(f"Training with {pct}% data for {epochs} epochs...")
    # NOTE(review): the exit status of the command is not checked, so the
    # "Completed" message below is printed even if training failed.
    !{cmd}
    
    print(f"Completed {pct}% experiment")

print("\n🎉 All experiments completed!")

In [None]:
# Optional: Clean up large files to save space
# The two shell commands below only report usage; nothing is deleted unless the
# commented-out lines at the end are enabled.
print("Current disk usage:")
!df -h

print("\nLarge files and directories:")
!du -sh * | sort -hr

# Uncomment to remove dataset after training
# NOTE(review): these paths do not match the data/eurosat_ms and
# data/eurosat_rgb folders described in the Google Drive setup section —
# confirm the actual folder names before uncommenting.
# !rm -rf data/EuroSATallBands.zip
# !rm -rf data/EuroSATallBands
# print("Dataset files removed to save space")