# SatMAE Finetuning on Google Colab

This notebook sets up and runs SatMAE finetuning on the EuroSAT dataset using Google Colab's free GPU.

## 🚀 Features:
- Automatic environment setup
- EuroSAT dataset download and preprocessing
- SatMAE model finetuning with multispectral data
- GPU acceleration (T4/V100/A100)

**Runtime**: Make sure to select **GPU** runtime (Runtime → Change runtime type → Hardware accelerator → GPU)

## 1. Environment Setup

In [None]:
# Check GPU availability (without importing torch yet — the next cell reinstalls it)
import shutil
import subprocess
import sys

print("🔍 Checking Google Colab GPU availability...")


def query_gpu_names():
    """Return the GPU model names reported by ``nvidia-smi``.

    Uses nvidia-smi's machine-readable query mode instead of scraping the
    human-readable table, so every GPU model (T4, L4, V100, A100, H100, ...)
    is reported by its exact name, without the index/persistence columns.

    Returns:
        list[str]: One name per visible GPU; empty when ``nvidia-smi`` is not
        installed or exits with a non-zero status.
    """
    if shutil.which('nvidia-smi') is None:
        return []
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        return []
    return [name.strip() for name in result.stdout.splitlines() if name.strip()]


try:
    gpu_names = query_gpu_names()
except OSError as exc:  # e.g. nvidia-smi exists but cannot be executed
    gpu_names = None
    print(f"⚠️ Could not detect GPU ({exc})")

if gpu_names:
    print("✅ NVIDIA GPU detected in system")
    for gpu_name in gpu_names:
        print(f"✅ GPU: {gpu_name}")
elif gpu_names is not None:
    print("⚠️ No NVIDIA GPU detected")

print("\n⚠️ Note: PyTorch will be installed in the next cell with correct versions")
print("📋 Runtime requirement: Make sure to select **GPU** runtime")
print("   (Runtime → Change runtime type → Hardware accelerator → GPU)")

# Check Python version
print(f"\n🐍 Python version: {sys.version}")
print("✅ Ready for SatMAE package installation")

PyTorch version: 2.6.0+cu124
CUDA available: True
GPU: Tesla T4
CUDA version: 12.4
GPU memory: 15.8 GB


In [None]:

# Install required packages with optimized versions for compatibility
print("🚀 Setting up SatMAE environment...")
print("📦 Installing packages (fast pip approach)...")

# Install packages with version constraints for compatibility
!pip install 'numpy<2.0' --quiet  # Avoid NumPy 2.0 compatibility issues
!pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 --quiet
!pip install timm==0.9.12 --quiet  # Modern timm version with compatibility patches
!pip install rasterio --quiet      # For satellite image processing (.tif files)
!pip install wandb --quiet         # For experiment tracking
!pip install tensorboard --quiet   # For training monitoring
!pip install pandas --quiet        # For data handling
!pip install pillow --quiet        # For image processing
!pip install matplotlib --quiet    # For visualization
!pip install tqdm --quiet          # For progress bars
!pip install pyyaml --quiet        # For config files
!pip install scikit-learn --quiet  # For data preprocessing

print("\n� Applying compatibility patches...")

# Patch 1: Fix torch._six import error
import sys
import types
six_module = types.ModuleType('six')
six_module.PY3 = True
six_module.string_types = str
sys.modules['torch._six'] = six_module

print("✅ Environment setup complete!")
print("🚀 Using fast pip installation with modern PyTorch + compatibility patches")
print("📋 All required packages installed for SatMAE finetuning")

# Download the SatMAE repository
import os
if not os.path.exists('SatMAE'):
    print("📥 Downloading SatMAE repository...")
    !git clone https://github.com/pvinnbru/SatMAE.git
    print("✅ Repository downloaded")

# Verify critical packages and imports
print("\n🔍 Verifying installation...")
try:
    import torch
    import torchvision
    import timm
    import numpy as np
    import rasterio
    import pandas as pd
    import wandb
    
    print(f"✅ PyTorch: {torch.__version__}")
    print(f"✅ torchvision: {torchvision.__version__}")
    print(f"✅ timm: {timm.__version__}")
    print(f"✅ numpy: {np.__version__}")
    print(f"✅ rasterio: {rasterio.__version__}")
    
    # Test critical SatMAE imports with compatibility
    from timm.models.layers import trunc_normal_
    from timm.data.mixup import Mixup
    from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
    print("✅ SatMAE-critical timm imports successful")
    
    # Check GPU availability
    if torch.cuda.is_available():
        print(f"✅ CUDA: {torch.cuda.get_device_name(0)}")
        print(f"✅ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("⚠️ CUDA not available")
    
    print("\n" + "="*60)
    print("🎉 FAST PIP ENVIRONMENT READY!")
    print("="*60)
    print("✅ Modern PyTorch 2.0.1 + compatibility patches")
    print("✅ timm 0.9.12 with backward compatibility")
    print("✅ All packages for satellite image processing")
    print("⚡ Installation time: ~2 minutes (vs 10-15 for conda)")
    print("🚀 Ready for SatMAE finetuning!")
    print("="*60)
    
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Some packages may not have installed correctly.")


# Navigate to SatMAE directory for subsequent cells
%cd SatMAE

**Required Google Drive Structure:**
```
MyDrive/
├── data/                           # Unzipped EuroSAT dataset folder
│   ├── eurosat_ms/                 # Multispectral dataset
│   └── eurosat_rgb/                # RGB dataset  
└── checkpoint/
    └── pretrain-vit-large-e199.pth # Pretrained model checkpoint
```

**Setup Steps:**
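1. Unzip your EuroSAT dataset and upload the `data/` folder to the root of your Google Drive
2. Upload the checkpoint to `MyDrive/checkpoint/pretrain-vit-large-e199.pth`
3. Mount Google Drive at `/content/drive` (e.g. `from google.colab import drive; drive.mount('/content/drive')`), then run the cells below. The checkpoint is copied into the local `checkpoints/` folder, while the dataset is read directly from Google Drive.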

In [None]:
# Load pretrained checkpoint from Google Drive
import os
import shutil

# Define paths (we're already in SatMAE directory after %cd SatMAE)
drive_mount_point = '/content/drive'
drive_checkpoint_path = '/content/drive/MyDrive/checkpoint/pretrain-vit-large-e199.pth'
local_checkpoint_dir = 'checkpoints'
local_checkpoint_path = os.path.join(local_checkpoint_dir, 'pretrain-vit-large-e199.pth')

# /content/drive paths only exist once Google Drive is mounted; no earlier cell
# does this, so mount it here (Colab asks for authorization the first time).
if not os.path.exists(os.path.join(drive_mount_point, 'MyDrive')):
    try:
        from google.colab import drive
        drive.mount(drive_mount_point)
    except ImportError:
        print(f"⚠️ Not running on Colab — mount Google Drive manually at {drive_mount_point}")

print("🔧 Loading pretrained checkpoint from Google Drive...")
print(f"Source: {drive_checkpoint_path}")
print(f"Target: {local_checkpoint_path}")

# Create checkpoints directory
os.makedirs(local_checkpoint_dir, exist_ok=True)
print(f"📂 Created directory: {os.path.abspath(local_checkpoint_dir)}/")


def local_copy_matches(expected_size):
    """Return True if the local checkpoint exists and has ``expected_size`` bytes."""
    return (os.path.exists(local_checkpoint_path)
            and os.path.getsize(local_checkpoint_path) == expected_size)


# Check if checkpoint exists in Google Drive
if os.path.exists(drive_checkpoint_path):
    source_size = os.path.getsize(drive_checkpoint_path)
    print("✅ Found checkpoint in Google Drive")
    print(f"📊 File size: {source_size / 1e6:.1f} MB")

    # Skip the ~300 MB copy when an identical-size local copy is already present
    if local_copy_matches(source_size):
        print("✅ Local copy already present - skipping copy")
    else:
        try:
            shutil.copy2(drive_checkpoint_path, local_checkpoint_path)
            print("✅ Checkpoint copied successfully!")
        except OSError as e:
            print(f"❌ Copy failed: {e}")

    # Compare sizes, not just existence, so a truncated copy is detected
    if local_copy_matches(source_size):
        print(f"📁 Available at: {os.path.abspath(local_checkpoint_path)}")
        print("✅ Verification successful - checkpoint ready for training")
    else:
        print("❌ Verification failed - local checkpoint missing or incomplete")

else:
    print(f"❌ Checkpoint not found at: {drive_checkpoint_path}")
    print("Please ensure you have uploaded the checkpoint to your Google Drive")
    print("\nTo fix this:")
    print("1. Go to your Google Drive")
    print("2. Create a folder called 'checkpoint' in the root directory")
    print("3. Upload 'pretrain-vit-large-e199.pth' to MyDrive/checkpoint/")
    print("4. Run this cell again")

    print("\nExpected Google Drive structure:")
    print("  MyDrive/")
    print("  ├── data/")
    print("  │   ├── eurosat_ms/")
    print("  │   └── eurosat_rgb/")
    print("  └── checkpoint/")
    print("      └── pretrain-vit-large-e199.pth")

🔧 Loading pretrained checkpoint from Google Drive...
Source: /content/drive/MyDrive/checkpoint/pretrain-vit-large-e199.pth
Target: SatMAE/checkpoints/pretrain-vit-large-e199.pth
📂 Created directory: SatMAE/checkpoints/
✅ Found checkpoint in Google Drive
📊 File size: 298.8 MB
✅ Checkpoint copied successfully!
📁 Available at: SatMAE/checkpoints/pretrain-vit-large-e199.pth
✅ Verification successful - checkpoint ready for training


### **3.1 Generate txt Files and Training Subsets**

The text files tell the data loader which EuroSAT images to use and their class labels (the images are read from the `data/` folder on Google Drive). Each line looks like this:

```
<path_to_image> <label>
```
For example:
```
/path/to/image1.tif    0
/path/to/image2.tif    3
...
```

The .txt files are generated by the script below:



### 3.2 **Create Training Subsets (10%, 25%, 50%, 100%)**

The goal is to measure how model performance improves as the amount of training data increases. To keep comparisons across runs fair and meaningful, the validation set stays fixed.

The following text files were generated and contain the complete training set:

```
SatMAE/data_splits/eurosat_ms_train.txt
SatMAE/data_splits/eurosat_rgb_train.txt
```

To subsample:

* Randomly select a percentage of lines from that file
* Save them into new files like:

  ```
  SatMAE/data_splits/eurosat_ms_train_10.txt
  SatMAE/data_splits/eurosat_ms_train_25.txt
  SatMAE/data_splits/eurosat_ms_train_50.txt
  ```

Repeat this for both the MS and RGB splits:

In [None]:
# Verify all required files exist
import os


def image_path_from_split_line(line):
    """Return the image path from a split-file line of the form ``<path> <label>``.

    Only the trailing label is split off, so paths containing spaces (e.g. the
    legacy Colab mount ``/content/drive/My Drive/...``) stay intact. A line
    without a label returns the whole stripped line. The caller must not pass
    a blank line (``rsplit`` would return an empty list).
    """
    return line.strip().rsplit(maxsplit=1)[0]


required_files = [
    'main_finetune.py',  # We're in SatMAE directory
    'data_splits/eurosat_ms_train_10.txt',  # Local txt files
    'data_splits/eurosat_ms_val.txt',
    'checkpoints/pretrain-vit-large-e199.pth'
]

# Also check Google Drive data access
gdrive_paths = [
    '/content/drive/MyDrive/data/eurosat_ms',
    '/content/drive/MyDrive/checkpoint/pretrain-vit-large-e199.pth'
]

print("Checking required files:")
all_good = True

# Check local files
for file in required_files:
    if os.path.exists(file):
        print(f"✅ {file}")
    else:
        print(f"❌ {file} - MISSING")
        all_good = False

# Check Google Drive access
print("\nChecking Google Drive data access:")
for path in gdrive_paths:
    if os.path.exists(path):
        print(f"✅ {path}")
    else:
        print(f"❌ {path} - MISSING")
        all_good = False

# Verify txt files contain valid paths (spot-check the first entry)
if os.path.exists('data_splits/eurosat_ms_train_10.txt'):
    with open('data_splits/eurosat_ms_train_10.txt', 'r') as f:
        first_line = f.readline().strip()
    if first_line:
        image_path = image_path_from_split_line(first_line)
        if os.path.exists(image_path):
            print(f"✅ Sample image accessible: {image_path}")
        else:
            print(f"❌ Sample image not accessible: {image_path}")
            all_good = False

if all_good:
    print("\n🚀 All files ready for training!")
    print("💡 Using Google Drive data directly - fast and efficient!")
else:
    print("\n⚠️ Some files are missing. Please check the previous steps.")

## 5. Monitor Training

In [None]:
# Check training results
import os
import glob

results_dir = "SatMAE/results/eurosat_ms_10"
if os.path.exists(results_dir):
    print("Training results:")
    !ls -la {results_dir}

    # Look for log files
    log_files = glob.glob(f"{results_dir}/*.txt")
    if log_files:
        print(f"\nLatest log file: {log_files[-1]}")
        !tail -20 {log_files[-1]}

    # Look for checkpoints
    checkpoints = glob.glob(f"{results_dir}/*.pth")
    if checkpoints:
        print(f"\nCheckpoints created: {len(checkpoints)}")
        for cp in checkpoints[-3:]:
            print(f"  {cp}")
else:
    print("No results directory found. Training may not have started yet.")

In [None]:
# Package results for download
import zipfile
import os
from datetime import datetime
import glob


def create_results_archive(satmae_dir=None):
    """Zip the SatMAE ``results/`` tree and top-level ``*.log`` files.

    Args:
        satmae_dir: Path to the SatMAE repository. Defaults to the current
            directory when it is the repo root (after ``%cd SatMAE``),
            otherwise ``SatMAE``.

    Returns:
        str: Name of the timestamped archive written to the current directory.
        Entries are stored under ``SatMAE/...`` regardless of ``satmae_dir``.
    """
    if satmae_dir is None:
        satmae_dir = '.' if os.path.exists('main_finetune.py') else 'SatMAE'

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    archive_name = f"satmae_results_{timestamp}.zip"
    results_path = os.path.join(satmae_dir, 'results')

    def arcname_for(path):
        # Keep a stable "SatMAE/..." layout independent of the working directory
        return os.path.join('SatMAE', os.path.relpath(path, satmae_dir))

    with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add results directory from SatMAE
        if os.path.exists(results_path):
            for root, _dirs, files in os.walk(results_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    zipf.write(file_path, arcname_for(file_path))

        # Add training logs from SatMAE directory
        for log_file in glob.glob(os.path.join(satmae_dir, '*.log')):
            zipf.write(log_file, arcname_for(log_file))

    return archive_name


_satmae_dir = '.' if os.path.exists('main_finetune.py') else 'SatMAE'
if os.path.exists(os.path.join(_satmae_dir, 'results')):
    archive_name = create_results_archive(_satmae_dir)
    print(f"✅ Results packaged in: {archive_name}")
    print(f"File size: {os.path.getsize(archive_name) / 1e6:.1f} MB")
    print("\nYou can download this file using Colab's file panel on the left.")
else:
    print("No results to package yet.")

In [None]:
# Run experiments with different data percentages.
# Section 3.2 creates 10/25/50% subsets plus the full split (100%); the 10% run
# is handled separately, so sweep the remaining sizes here.
import os
import subprocess
import sys

experiments = [25, 50, 100]

# batch_size / accum_iter are normally defined by the earlier training cell;
# fall back to defaults so this cell also runs on a fresh kernel.
# NOTE(review): fallback values are assumptions — tune them for your GPU memory.
batch_size = globals().get('batch_size', 8)
accum_iter = globals().get('accum_iter', 16)

# Earlier cells `%cd` into the repo; only fall back to "SatMAE" from the parent dir.
satmae_dir = '.' if os.path.exists('main_finetune.py') else 'SatMAE'


def train_split_path(pct):
    """Return the repo-relative train split file for ``pct`` percent of the data."""
    if pct == 100:
        return 'data_splits/eurosat_ms_train.txt'  # the full split has no suffix
    return f'data_splits/eurosat_ms_train_{pct}.txt'


def run_finetune(pct, epochs):
    """Run main_finetune.py on the ``pct``% split, streaming its output.

    Returns:
        int: The training process's exit code (0 on success).
    """
    args = [
        sys.executable, 'main_finetune.py',
        '--model_type', 'group_c',
        '--model', 'vit_large_patch16',
        '--dataset_type', 'euro_sat',
        '--train_path', train_split_path(pct),
        '--test_path', 'data_splits/eurosat_ms_val.txt',
        '--finetune', 'checkpoints/pretrain-vit-large-e199.pth',
        '--input_size', '96', '--patch_size', '8',
        '--batch_size', str(batch_size), '--accum_iter', str(accum_iter),
        '--epochs', str(epochs), '--blr', '2e-4',
        '--weight_decay', '0.05',
        '--drop_path', '0.2', '--reprob', '0.25', '--mixup', '0.8', '--cutmix', '1.0',
        '--dropped_bands', '0', '9', '10',
        '--num_workers', '2',
        '--output_dir', f'results/eurosat_ms_{pct}',
        '--log_dir', f'results/eurosat_ms_{pct}',
    ]
    # Unbuffered child output so training logs appear live instead of in bursts
    env = dict(os.environ, PYTHONUNBUFFERED='1')
    with subprocess.Popen(args, cwd=satmae_dir, stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT, text=True, env=env) as proc:
        for line in proc.stdout:
            print(line, end='')
    return proc.returncode


failed = []
for pct in experiments:
    print(f"\n=== Running experiment with {pct}% of data ===")

    # Adjust epochs based on data size
    epochs = max(10, 30 - (pct // 25) * 5)  # Fewer epochs for more data

    split_file = os.path.join(satmae_dir, train_split_path(pct))
    if not os.path.exists(split_file):
        print(f"⚠️ Skipping {pct}%: {split_file} not found (see section 3.2)")
        failed.append(pct)
        continue

    print(f"Training with {pct}% data for {epochs} epochs...")
    returncode = run_finetune(pct, epochs)
    if returncode != 0:
        print(f"❌ {pct}% experiment failed (exit code {returncode})")
        failed.append(pct)
        continue

    print(f"Completed {pct}% experiment")

if failed:
    print(f"\n⚠️ Experiments not completed: {failed}")
else:
    print("\n🎉 All experiments completed!")