## Self Supervised Learning (SSL)

In [1]:
!pip install torchgeo --quiet
!pip install lightning --quiet
!pip install prettytable

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m652.1/652.1 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m246.1/246.1 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m853.6/853.6 kB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.6/37.6 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m165.6/165.6 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K   

In [2]:
import os
SEED = 42
# Environment variables
os.environ["PYTHONHASHSEED"] = str(SEED)
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [3]:
import torch
torch.cuda.empty_cache()
# os.cpu_count()
torch.cuda.is_available()

False

In [4]:
import torch
from torch.utils.data import Dataset
import rasterio
import numpy as np
from rasterio.enums import Resampling
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from torchvision import transforms  
from torchgeo.trainers.moco import MoCoTask
from torchgeo.models import ResNet50_Weights
import kornia.augmentation as K
import torch.nn.functional as F
import torchgeo.transforms as T
from lightning.pytorch.loggers import CSVLogger
import glob
import shutil
import random
import time
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import numpy as np



In [5]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED, workers=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

Seed set to 42


In [8]:
import os
import pandas as pd

# --- CONFIG ---
root_dir = "/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT"  # folder with all images
csv_output = "/kaggle/working/eurosat_split.csv"

# --- COLLECT IMAGE INFO ---
data=[]
# Iterate over each class folder
for class_name in os.listdir(root_dir):
    class_path = os.path.join(root_dir, class_name)
    print(class_path)
    if not os.path.isdir(class_path):
        continue  # skip files in root_dir

    # Iterate over images in the class folder
    for fname in os.listdir(class_path):
        if fname.lower().endswith((".jpg", ".jpeg", ".png")):
            path = os.path.join(class_path, fname)
            
            data.append({
                "id": os.path.splitext(fname)[0].split("_")[-1],
                 "fname": fname,
                "path": path,
                "label": class_name
            })
    

# --- CREATE DATAFRAME ---
df = pd.DataFrame(data)

# --- STRATIFIED SPLIT: 80% SSL, 20% Downstream ---
ssl_df, downstream_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label'],
    random_state=42
)

ssl_df['task'] = 'ssl'
downstream_df['task'] = 'downstream'

df = pd.concat([ssl_df, downstream_df]).reset_index(drop=True)

ssl_df = df[df['task'] == 'ssl'].copy()

skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)

ssl_splits = np.empty(len(ssl_df), dtype=object)
split_names = ['subst1', 'subst2', 'subst3', 'subst4']

for fold_idx, (_, val_idx) in enumerate(skf.split(ssl_df, ssl_df['label'])):
    ssl_splits[val_idx] = split_names[fold_idx]

df.loc[ssl_df.index, 'split'] = ssl_splits




down_df = df[df['task'] == 'downstream'].copy()

# Step 1: Train (70%) vs Temp (30%)
train_df, temp_df = train_test_split(
    down_df,
    test_size=0.30,
    stratify=down_df['label'],
    random_state=42
)

# Step 2: Temp → Val (15%) + Test (15%)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,  # half of 30% = 15%
    stratify=temp_df['label'],
    random_state=42
)

# Assign splits back
df.loc[train_df.index, 'split'] = 'train'
df.loc[val_df.index, 'split'] = 'val'
df.loc[test_df.index, 'split'] = 'test'




# --- FINAL CHECK ---
print(df['task'].value_counts())
print(df['split'].value_counts())
# print(df.head(10))

# --- SAVE CSV ---
df.to_csv(csv_output, index=False)
print(f"CSV saved to {csv_output}")
print(df.head())
print(df[df['split'] == 'train']['label'].value_counts())


/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/SeaLake
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/Highway
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/River
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/Pasture
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/Industrial
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/Residential
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/PermanentCrop
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/validation.csv
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/AnnualCrop
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/train.csv
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/test.csv
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/label_map.json
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/Forest
/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSAT/HerbaceousVegetation
task
ssl           21600
downstre

label
AnnualCrop              420
HerbaceousVegetation    420
Residential             420
SeaLake                 420
Forest                  420
Industrial              350
Highway                 350
PermanentCrop           350
River                   350
Pasture                 280
Name: count, dtype: int64


In [None]:
class SSLDataset(Dataset):
    def __init__(self, scenes, bands, transforms=None):
        """
        Args:
            scenes (list): List of scene folder paths.
            bands (list): List of band names (e.g., ["B1","B2"]).
            patch_size (tuple): Size of random crop (H, W).
            transforms (callable, optional): Optional transform to apply to patches.
        """
        self.scenes = scenes
        self.bands = bands
        # self.patch_size = patch_size
        self.transforms = transforms
        self.target_h= None
        self.target_w = None
        

        # Precompute all timestamp paths to treat each timestamp as a sample
        self.samples = []
        for scene_path in scenes:
            timestamps = sorted([
                d for d in os.listdir(scene_path)
                if os.path.isdir(os.path.join(scene_path, d))
            ])
            for ts in timestamps:
                self.samples.append(os.path.join(scene_path, ts))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ts_path = self.samples[idx]

        with rasterio.open(os.path.join(ts_path, "B2.tif")) as src:
            target_h, target_w = src.height, src.width
            # print(target_h,target_w, "target width and height" )

        band_arrays = []

        for b in self.bands:
            path = os.path.join(ts_path, f"{b}.tif")
            with rasterio.open(path) as src:
                if src.height == target_h and src.width == target_w:
                    arr = src.read(1).astype(np.float32)
                else:
                    arr = src.read(
                        1,
                        out_shape=(target_h, target_w),
                        resampling=Resampling.bilinear
                    ).astype(np.float32)

            band_arrays.append(arr)

        # Insert fake B10
        insert_idx = 10
        b10_pad = np.zeros((target_h, target_w), dtype=np.float32)
        band_arrays.insert(insert_idx, b10_pad)

        img = np.stack(band_arrays, axis=0)

        # img_patch = self._random_crop(img)

        patch_tensor = torch.tensor(img, dtype=torch.float32)

        if self.transforms:
            patch_tensor = self.transforms(patch_tensor)

        return {"image": patch_tensor}

In [9]:
from torchgeo.datasets import EuroSAT

dataset= EuroSAT(root='/kaggle/input/datasets/apollo2506/eurosat-dataset/')

DatasetNotFoundError: Dataset not found in `root='/kaggle/input/datasets/apollo2506/eurosat-dataset/'` and `download=False`, either specify a different `root` or use `download=True` to automatically download the dataset.

In [None]:

Sub-sample 3k Data

# root_dir = "/Volumes/WD_Rabina/competition/extracted_data/s2a"
# # List all folders
# scenes = sorted(glob.glob(os.path.join(root_dir, "*/")))
# # print(scenes)
# print(len(scenes))

# no_of_files=3000

# # Randomly select 3000 scenes (without replacement)
# selected_scenes = random.sample(scenes, k=3000)

# print(f"Total selected scenes: {len(selected_scenes)}")
# # print(selected_scenes[:10])  # show first 10 for sanity check

# # Path to new folder where selected scenes will be copied
# destination_root = "data/s2a_3k_sample"
# os.makedirs(destination_root, exist_ok=True)  # create folder if it doesn't exist

# # Copy each selected folder
# count=0
# for scene_path in selected_scenes:
#     # Get folder name only (e.g., "000015")
#     folder_name = os.path.basename(os.path.normpath(scene_path))
    
#     # Destination path
#     dest_path = os.path.join(destination_root, folder_name)
#     count=count+1
#     print(count)
#     # Copy folder and all its contents
#     shutil.copytree(scene_path, dest_path)

# print(f"Copied {len(selected_scenes)} folders to {destination_root}")


### Settings

In [None]:
target_size = 224
target_batch_size= 8 #128 #prefer 256 or 128
target_num_workers=4
target_max_epoch=6
use_peft = True  
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logger = CSVLogger("logs", name=f"metrics_{timestamp}")

aug = K.AugmentationSequential(
    K.RandomResizedCrop(size=(target_size, target_size), scale=(0.4, 1.0)),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomGaussianBlur(kernel_size=(7,7), sigma=(0.1, 1.5), p=0.3),
    K.RandomBrightness(brightness=(0.85, 1.15), p=0.5),
    data_keys=['input'],
)

### Helper Functions

In [None]:
class SSLDataset(Dataset):
    def __init__(self, scenes, bands, transforms=None):
        """
        Args:
            scenes (list): List of scene folder paths.
            bands (list): List of band names (e.g., ["B1","B2"]).
            patch_size (tuple): Size of random crop (H, W).
            transforms (callable, optional): Optional transform to apply to patches.
        """
        self.scenes = scenes
        self.bands = bands
        # self.patch_size = patch_size
        self.transforms = transforms
        self.target_h= None
        self.target_w = None
        

        # Precompute all timestamp paths to treat each timestamp as a sample
        self.samples = []
        for scene_path in scenes:
            timestamps = sorted([
                d for d in os.listdir(scene_path)
                if os.path.isdir(os.path.join(scene_path, d))
            ])
            for ts in timestamps:
                self.samples.append(os.path.join(scene_path, ts))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ts_path = self.samples[idx]

        with rasterio.open(os.path.join(ts_path, "B2.tif")) as src:
            target_h, target_w = src.height, src.width
            # print(target_h,target_w, "target width and height" )

        band_arrays = []

        for b in self.bands:
            path = os.path.join(ts_path, f"{b}.tif")
            with rasterio.open(path) as src:
                if src.height == target_h and src.width == target_w:
                    arr = src.read(1).astype(np.float32)
                else:
                    arr = src.read(
                        1,
                        out_shape=(target_h, target_w),
                        resampling=Resampling.bilinear
                    ).astype(np.float32)

            band_arrays.append(arr)

        # Insert fake B10
        insert_idx = 10
        b10_pad = np.zeros((target_h, target_w), dtype=np.float32)
        band_arrays.insert(insert_idx, b10_pad)

        img = np.stack(band_arrays, axis=0)

        # img_patch = self._random_crop(img)

        patch_tensor = torch.tensor(img, dtype=torch.float32)

        if self.transforms:
            patch_tensor = self.transforms(patch_tensor)

        return {"image": patch_tensor}

def calculate_stats(dataset, n_samples=500):
    mean = 0
    std = 0
    total = len(dataset)
    n = min(total, n_samples)

    # Randomly choose n indices
    np.random.seed(42)
    indices = np.random.choice(total, size=n, replace=False)
    count=0
    for i in indices:
        count=count+1
        print(count)
        sample = dataset[i]
        img = sample["image"]   # TorchGeo-style dictionary

        mean += img.mean(dim=(1, 2))
        std += img.std(dim=(1, 2))
    mean /= n
    std /= n
    return mean, std


def summary_trainable(model):
    table = PrettyTable()
    table.field_names = ["Module", "Type", "Trainable Params", "Total Params"]

    for name, module in model.named_children():
        total_params = sum(p.numel() for p in module.parameters())
        trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
        table.add_row([name, type(module).__name__, f"{trainable_params:,}", f"{total_params:,}"])

    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    
    print(table)
    print(f"Total trainable parameters: {total_trainable:,} ({total_trainable / 1e6:.2f} M)")
    print(f"Total parameters: {total_params:,} ({total_params / 1e6:.2f} M)")


In [None]:
# root_dir = "data/s2a"
root_dir = "/kaggle/input/icpr-2026-competition-ssl-s2a-3k-subset/ICPR_SSL_S2A_3k_sample"
# List all folders
# scenes = sorted(glob.glob(os.path.join(root_dir, "*/")))
scenes= ["/kaggle/input/icpr-2026-competition-ssl-s2a-3k-subset/ICPR_SSL_S2A_3k_sample/000017","/kaggle/input/icpr-2026-competition-ssl-s2a-3k-subset/ICPR_SSL_S2A_3k_sample/000040","/kaggle/input/icpr-2026-competition-ssl-s2a-3k-subset/ICPR_SSL_S2A_3k_sample/000058","/kaggle/input/icpr-2026-competition-ssl-s2a-3k-subset/ICPR_SSL_S2A_3k_sample/000066"]
# scenes = ["data/s2a/000015", "data/s2a/000016"]  # list of scene folders
bands = ["B1","B2","B3","B4","B5","B6","B7","B8","B8A","B9","B11","B12"]
# print(scenes)

# import time
# # One time run to get mean and std
# temp_dataset = SSLDataset(scenes, bands)
# start_time=time.time()
# mean, std = calculate_stats(temp_dataset, n_samples=10000)
# end_time=time.time()
# print(f"calculate_stats time: {(end_time-start_time)/60} min")
# print(mean)
# print(std)

# based on 10k samples
mean= [1333.8029, 1488.1448, 1745.9066, 1985.6210, 2322.0129, 2837.1787,
        3065.8462, 3192.4492, 3225.1826, 3344.8479,    0.0000, 2683.2991,
        2116.8357]
std = [384.9683, 472.5244, 497.7275, 590.9384, 578.0192, 641.7764, 699.6282,
        752.0769, 709.3992, 752.4539,   0.0000, 568.3574, 542.2833]

# based on 500 sample
# mean= [1041.5322, 1224.2570, 1549.6492, 1815.5840, 2171.1243, 2729.5166,
#         2990.2266, 3074.2515, 3162.9661, 3260.7983,    0.0000, 2969.8357,
#         2335.3250]
# std = [328.5996, 410.0965, 443.5781, 547.2238, 519.4624, 547.1485, 606.4136,
#         649.6067, 621.7164, 687.0721,   0.0000, 574.0366, 560.9932]

# mean = [2358.7412, 2402.7629, 2580.9255, 2614.2227, 3057.6877, 3578.1008,
#         3796.8345, 3795.6868, 3947.5913, 4833.6362,    0.0000, 3379.1743,
#         2666.4465]
# std = [2994.4861, 2847.0354, 2542.9307, 2411.1196, 2399.0249, 2137.6804,
#         2036.8357, 2042.7140, 1957.9615, 3559.4121,    0.0000, 1535.7960,
#         1393.8278]

# to avoid 0 std
std = [max(s, 1e-5) for s in std]   

# define transform
transform = transforms.Compose([
    transforms.Resize((target_size, target_size)),
    transforms.Normalize(mean=mean, std=std)
])

dataset = SSLDataset(scenes, bands, transforms=transform)
print(len(dataset))
print(dataset[0]['image'].shape)

data_loader = DataLoader(
    dataset, 
    batch_size=target_batch_size, 
    shuffle=True, 
    pin_memory=True,
    num_workers=target_num_workers,
    worker_init_fn=lambda worker_id: np.random.seed(SEED + worker_id)
)
num_batches = len(data_loader)
print("Number of batches:", num_batches)

import time
task = MoCoTask(
    model="resnet50",      
    weights= ResNet50_Weights.SENTINEL2_ALL_MOCO,
    in_channels=13,       
    version=2,             # MoCo v2
    size=target_size,          
    augmentation1=aug,
    augmentation2=aug,
    lr=1e-4,
    memory_bank_size=2048,
    temperature=0.15,
)

# # Load your checkpoint to resume task
# ckpt_path = "/kaggle/working/ssl_3k_ckpt_20260206_063623.ckpt"
# task = task.load_from_checkpoint(ckpt_path)

# -----------------------------
# PEFT / Full Fine-Tuning Logic
# -----------------------------
if use_peft:
    print("Using PEFT: freezing backbone except last block, training projection head...")
    for name, param in task.backbone.named_parameters():
        if "layer4" in name:      # optionally fine-tune last residual block
            param.requires_grad = True
        else:
            param.requires_grad = False
else:
    print("Full fine-tuning: backbone and projection head trainable...")
    for param in task.backbone.parameters():
        param.requires_grad = True

# Momentum backbone always frozen
for param in task.backbone_momentum.parameters():
    param.requires_grad = False

# Projection head always trainable
for param in task.projection_head.parameters():
    param.requires_grad = True

# Example usage for your task
summary_trainable(task)



In [None]:
trainer = Trainer(
    max_epochs=target_max_epoch,
    enable_progress_bar=True, 
    log_every_n_steps=num_batches,
    precision=32,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    deterministic=True,
    logger=logger)

start_time=time.time()
trainer.fit(task, data_loader)
end_time=time.time()
print(f"Training time: {(end_time-start_time)/60} min")

print(task.trainer.logged_metrics)



In [None]:
# Save the backbone encoder only
torch.save(task.backbone.state_dict(),f"ssl_backbone_{timestamp}.pth")
torch.save(task.projection_head.state_dict(), f"projection_head_{timestamp}.pth")
trainer.save_checkpoint(f"ssl_3k_ckpt_{timestamp}.ckpt")

