### Section 1: Imports and Setup
---

In [None]:
# SimSiam Training on SSL4EO-S12 Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt
import os
import random
from tqdm import tqdm


# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


### Section 2: Data Augmentation Pipeline
---

In [2]:
class SimSiamTransforms:
    """
    Data augmentation pipeline for SimSiam following the original paper
    """
    def __init__(self, img_size=224):
        # Strong augmentation pipeline
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),  # Common for satellite imagery
            transforms.RandomRotation(degrees=90),  # 90-degree rotations for satellite data
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])
    
    def __call__(self, x):
        return self.transform(x), self.transform(x)


### Section 3: SSL4EO-S12 Dataset Class
---

In [3]:
from dataset import SSL4EO_S12_Dataset

### Section 4: SimSiam Model Implementation
---

In [4]:
class ProjectionMLP(nn.Module):
    """Projection MLP for SimSiam"""
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim, affine=False)  # No bias/scale in final BN
        )
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

class PredictionMLP(nn.Module):
    """Prediction MLP for SimSiam"""
    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class SimSiam(nn.Module):
    """
    SimSiam model implementation
    """
    def __init__(self, backbone='resnet50', proj_dim=2048, pred_dim=512):
        super().__init__()
        
        # Backbone encoder
        if backbone == 'resnet50':
            self.encoder = models.resnet50(pretrained=False)
            self.encoder.fc = nn.Identity()  # Remove classification head
            encoder_dim = 2048
        elif backbone == 'resnet18':
            self.encoder = models.resnet18(pretrained=False)
            self.encoder.fc = nn.Identity()
            encoder_dim = 512
        else:
            raise ValueError(f"Backbone {backbone} not supported")
        
        # Projection head
        self.projector = ProjectionMLP(encoder_dim, proj_dim, proj_dim)
        
        # Prediction head
        self.predictor = PredictionMLP(proj_dim, pred_dim, proj_dim)
    
    def forward(self, x1, x2):
        # Encode both views
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        
        # Predict
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        return p1, p2, z1.detach(), z2.detach()

### Section 5: Loss Function
---

In [5]:
def simsiam_loss(p1, p2, z1, z2):
    """
    SimSiam loss function
    Negative cosine similarity
    """
    def cosine_similarity(a, b):
        a = F.normalize(a, dim=1)
        b = F.normalize(b, dim=1)
        return (a * b).sum(dim=1).mean()
    
    loss1 = -cosine_similarity(p1, z2)
    loss2 = -cosine_similarity(p2, z1)
    
    return (loss1 + loss2) * 0.5


### Section 6: Training Configuration
---

In [None]:
# Training hyperparameters
config = {
    'batch_size': 64,
    'learning_rate': 0.05,
    'weight_decay': 1e-4,
    'epochs': 10,
    'img_size': 224,
    'backbone': 'resnet18',
    'data_dir': './temp_zarr', 
    'save_dir': './checkpoints_new_backbone',
    'log_interval': 10,
    'run': 'Run_batch_size_64_epoch_10_resnet_18_new_backbone'
}

# Create save directory
os.makedirs(config['save_dir'], exist_ok=True)

### Section 7: Data Loading
---

In [None]:
# Initialize transforms and dataset
transform = SimSiamTransforms(img_size=config['img_size'])

# Update the data_dir path to your SSL4EO-S12 dataset location
dataset = SSL4EO_S12_Dataset(
    extracted_dir=config['data_dir'],
    transform=transform
)

# Create data loader
dataloader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

print(f"Dataset size: {len(dataset)}")
print(f"Number of batches: {len(dataloader)}")

### Section 8: Model Initialization
---

In [None]:
# Initialize model
model = SimSiam(backbone=config['backbone']).to(device)

# Initialize optimizer with cosine annealing
optimizer = optim.SGD(
    model.parameters(),
    lr=config['learning_rate'],
    momentum=0.9,
    weight_decay=config['weight_decay']
)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config['epochs']
)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")


### Section 9: Training Loop
---


In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer
writer = SummaryWriter(log_dir=os.path.join(config['save_dir'], config['run']))

def train_epoch(model, dataloader, optimizer, epoch, config, writer=None):
    model.train()
    total_loss = 0
    num_batches = 0

    total_steps = len(dataloader)
    pbar = tqdm(enumerate(dataloader), total=total_steps, desc=f'Epoch {epoch+1}/{config["epochs"]}')
    
    for batch_idx, (view1, view2) in pbar:
        global_step = epoch * total_steps + batch_idx

        view1, view2 = view1.to(device), view2.to(device)

        # Forward pass
        p1, p2, z1, z2 = model(view1, view2)

        # Loss
        loss = simsiam_loss(p1, p2, z1, z2)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
            

        # Update tqdm and log on tensorboard
        if global_step % config['log_interval'] == 0:

            pbar.set_postfix({
                'Iter': global_step,
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss / num_batches:.4f}',
                'LR': f'{optimizer.param_groups[0]["lr"]:.6f}'
            })

            writer.add_scalar("Train/Batch_Loss", loss.item(), global_step)

    avg_loss = total_loss / num_batches

    # Epoch-level logging
    if writer:
        writer.add_scalar("Train/Epoch_Loss", avg_loss, epoch)
        writer.add_scalar("Train/Learning_Rate", optimizer.param_groups[0]['lr'], epoch)

    return avg_loss

# Training history
train_losses = []
learning_rates = []

print("Starting training...")

for epoch in range(config['epochs']):
    avg_loss = train_epoch(model, dataloader, optimizer, epoch, config, writer)

    scheduler.step()

    train_losses.append(avg_loss)
    learning_rates.append(optimizer.param_groups[0]['lr'])

    print(f'Epoch {epoch+1}/{config["epochs"]}, Average Loss: {avg_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')

    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
            'config': config
        }
        torch.save(checkpoint, os.path.join(config['save_dir'], f'simsiam_epoch_{epoch+1}.pth'))
        print(f'Checkpoint saved at epoch {epoch+1}')

# Close TensorBoard writer
writer.close()
print("Training completed!")


Starting training...


Epoch 1/10: 100%|██████████| 1524/1524 [29:08<00:00,  1.15s/it, Iter=1520, Loss=-0.8624, Avg Loss=-0.8082, LR=0.050000]

Epoch 1/10, Average Loss: -0.8083, LR: 0.048776



Epoch 2/10: 100%|██████████| 1524/1524 [28:57<00:00,  1.14s/it, Iter=3040, Loss=-0.8860, Avg Loss=-0.8714, LR=0.048776]

Epoch 2/10, Average Loss: -0.8714, LR: 0.045225



Epoch 3/10: 100%|██████████| 1524/1524 [28:51<00:00,  1.14s/it, Iter=4570, Loss=-0.8833, Avg Loss=-0.8757, LR=0.045225]

Epoch 3/10, Average Loss: -0.8757, LR: 0.039695



Epoch 4/10: 100%|██████████| 1524/1524 [28:46<00:00,  1.13s/it, Iter=6090, Loss=-0.8899, Avg Loss=-0.8775, LR=0.039695]

Epoch 4/10, Average Loss: -0.8775, LR: 0.032725



Epoch 5/10: 100%|██████████| 1524/1524 [28:52<00:00,  1.14s/it, Iter=7610, Loss=-0.8695, Avg Loss=-0.8861, LR=0.032725]

Epoch 5/10, Average Loss: -0.8861, LR: 0.025000



Epoch 6/10: 100%|██████████| 1524/1524 [28:43<00:00,  1.13s/it, Iter=9140, Loss=-0.9087, Avg Loss=-0.8942, LR=0.025000]

Epoch 6/10, Average Loss: -0.8942, LR: 0.017275



Epoch 7/10: 100%|██████████| 1524/1524 [28:46<00:00,  1.13s/it, Iter=10660, Loss=-0.9138, Avg Loss=-0.8990, LR=0.017275]

Epoch 7/10, Average Loss: -0.8990, LR: 0.010305



Epoch 8/10: 100%|██████████| 1524/1524 [28:44<00:00,  1.13s/it, Iter=12190, Loss=-0.9266, Avg Loss=-0.9031, LR=0.010305]

Epoch 8/10, Average Loss: -0.9031, LR: 0.004775



Epoch 9/10: 100%|██████████| 1524/1524 [28:37<00:00,  1.13s/it, Iter=13710, Loss=-0.9147, Avg Loss=-0.9064, LR=0.004775]

Epoch 9/10, Average Loss: -0.9064, LR: 0.001224



Epoch 10/10: 100%|██████████| 1524/1524 [28:44<00:00,  1.13s/it, Iter=15230, Loss=-0.9133, Avg Loss=-0.9083, LR=0.001224]


Epoch 10/10, Average Loss: -0.9083, LR: 0.000000
Checkpoint saved at epoch 10
Training completed!


### Section 10: Save Final Model and Results
---

In [10]:
# Save final model
final_checkpoint = {
    'epoch': config['epochs'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_losses': train_losses,
    'learning_rates': learning_rates,
    'config': config
}

torch.save(final_checkpoint, os.path.join(config['save_dir'], 'simsiam_final.pth'))
print("Final model saved!")

# Save encoder only (for downstream tasks)
encoder_checkpoint = {
    'encoder_state_dict': model.encoder.state_dict(),
    'config': config
}
torch.save(encoder_checkpoint, os.path.join(config['save_dir'], 'simsiam_encoder.pth'))
print("Encoder saved for downstream tasks!")

Final model saved!
Encoder saved for downstream tasks!


### Section 11: Model Evaluation and Feature Extraction
---

In [12]:
def extract_features(model, dataloader, device, max_samples=1000):
    """Extract features from the trained encoder"""
    model.eval()
    features = []
    
    with torch.no_grad():
        for batch_idx, (view1, view2) in enumerate(dataloader):
            if batch_idx * dataloader.batch_size >= max_samples:
                break
                
            view1 = view1.to(device)
            # Use only first view for feature extraction
            feat = model.encoder(view1)
            features.append(feat.cpu().numpy())
    
    return np.concatenate(features, axis=0)

# Extract features for analysis
print("Extracting features for analysis...")
features = extract_features(model, dataloader, device, max_samples=1000)
print(f"Extracted features shape: {features.shape}")

# Analyze feature statistics
print(f"Feature statistics:")
print(f"Mean: {features.mean():.4f}")
print(f"Std: {features.std():.4f}")
print(f"Min: {features.min():.4f}")
print(f"Max: {features.max():.4f}")

Extracting features for analysis...
Extracted features shape: (1024, 512)
Feature statistics:
Mean: 0.5273
Std: 0.5048
Min: 0.0000
Max: 9.0637
