## Self Supervised Learning (SSL)

In [12]:
!pip install torchgeo --quiet
!pip install lightning --quiet
!pip install prettytable



In [13]:
import os
SEED = 42
# Environment variables
os.environ["PYTHONHASHSEED"] = str(SEED)
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [14]:
import torch
torch.cuda.empty_cache()
# os.cpu_count()
torch.cuda.is_available()

False

In [15]:
import torch
from torch.utils.data import Dataset
import rasterio
import numpy as np
from rasterio.enums import Resampling
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from torchvision import transforms  
from torchgeo.trainers.moco import MoCoTask
from torchgeo.models import ResNet18_Weights
import kornia.augmentation as K
import torch.nn.functional as F
import torchgeo.transforms as T
from lightning.pytorch.loggers import CSVLogger
import glob
import shutil
import random
import time
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import os
import pandas as pd

In [16]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED, workers=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

Seed set to 42


In [17]:
target_size = 224
target_batch_size= 8 #128 #prefer 256 or 128
target_num_workers=4
target_max_epoch=6
use_peft = True  
from datetime import datetime

In [30]:
def split_dataset(root_dir,csv_output)
    # --- COLLECT IMAGE INFO ---
    data=[]
    # Iterate over each class folder
    for class_name in os.listdir(root_dir):
        class_path = os.path.join(root_dir, class_name)
        # print(class_path)
        if not os.path.isdir(class_path):
            continue  # skip files in root_dir
    
        # Iterate over images in the class folder
        for fname in os.listdir(class_path):
            # if fname.lower().endswith((".jpg", ".jpeg", ".png")):
            if fname.lower().endswith((".tif", ".tiff")):
                # path = os.path.join(class_path, fname)
                rel_path = os.path.join(class_name, fname)
                data.append({
                    "id": os.path.splitext(fname)[0].split("_")[-1],
                     "fname": fname,
                    "rel_path": rel_path,
                    "label": class_name
                })
    # --- CREATE DATAFRAME ---
    df = pd.DataFrame(data)
    # print(df.columns)
    # --- STRATIFIED SPLIT: 80% SSL, 20% Downstream ---
    ssl_df, downstream_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['label'],
        random_state=42
    )
    
    ssl_df['task'] = 'ssl'
    downstream_df['task'] = 'downstream'
    
    df = pd.concat([ssl_df, downstream_df]).reset_index(drop=True)
    
    ssl_df = df[df['task'] == 'ssl'].copy()
    
    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)
    
    ssl_splits = np.empty(len(ssl_df), dtype=object)
    split_names = ['subset1', 'subset2', 'subset3', 'subset4']
    
    for fold_idx, (_, val_idx) in enumerate(skf.split(ssl_df, ssl_df['label'])):
        ssl_splits[val_idx] = split_names[fold_idx]
    
    df.loc[ssl_df.index, 'split'] = ssl_splits
    
    down_df = df[df['task'] == 'downstream'].copy()
    # Step 1: Train (70%) vs Temp (30%)
    train_df, temp_df = train_test_split(
        down_df,
        test_size=0.30,
        stratify=down_df['label'],
        random_state=42
    )
    # Step 2: Temp â†’ Val (15%) + Test (15%)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=0.50,  # half of 30% = 15%
        stratify=temp_df['label'],
        random_state=42
    )
    # Assign splits back
    df.loc[train_df.index, 'split'] = 'train'
    df.loc[val_df.index, 'split'] = 'val'
    df.loc[test_df.index, 'split'] = 'test'
    
    # --- FINAL CHECK ---
    print(df['task'].value_counts())
    print(df['split'].value_counts())
    # print(df.head(10))
    
    # --- SAVE CSV ---
    df.to_csv(csv_output, index=False)
    print(f"CSV saved to {csv_output}")
    # print(df.head())
    # print(df[df['split'] == 'train']['label'].value_counts())

class SSLDataset(Dataset):
    def __init__(self, data_dir, split_path,split, transforms=None):
        """
        Args:
            data_dir (str): Eurosat folder paths.
            split_path (str): CSV file path containing splits metadata
            split (str): all, full_ssl, subset1, subset2, subset3, subset4, train, test, val
            transforms (callable, optional): Optional transform to apply to patches.
        """
        self.data_dir = data_dir
        self.split_path = split_path
        self.transforms = transforms
        # self.target_h= None
        # self.target_w = None
        
        # Precompute all paths based on split
        self.samples = []
        df=pd.read_csv(split_path)
        if split=="all":
            pass
        elif split=="full_ssl":
            df=df[df['task'] == "ssl"]
        else:
            df=df[df['split'] == split]
        self.samples = df['rel_path'].tolist()
        

    def __len__(self):
        return len(self.samples)
    

    # def __getitem__(self, idx):
    #     sample_path = os.path.join(self.data_dir, self.samples[idx])
        
    #     with Image.open(sample_path) as img:
    #         patch_tensor = img.convert("RGB")
    #     if self.transforms:
    #         patch_tensor = self.transforms(patch_tensor)

    #     return {"image": patch_tensor}
    def __getitem__(self, idx):
        sample_path = os.path.join(self.data_dir, self.samples[idx])
    
        # --- Read TIFF with rasterio ---
        with rasterio.open(sample_path) as src:
            # Read all bands as float32
            bands = [src.read(b).astype(np.float32) for b in range(1, src.count + 1)]
        
        # Stack bands to shape [C, H, W]
        img_array = np.stack(bands, axis=0)
    
        # --- Convert to torch tensor ---
        patch_tensor = torch.tensor(img_array, dtype=torch.float32)
    
        # --- Apply transforms if provided ---
        if self.transforms:
            patch_tensor = self.transforms(patch_tensor)
    
        return {"image": patch_tensor}


def calculate_stats(dataset, n_samples=500):
    mean = 0
    std = 0
    total = len(dataset)
    n = min(total, n_samples)

    # Randomly choose n indices
    np.random.seed(42)
    indices = np.random.choice(total, size=n, replace=False)
    # count=0
    for i in indices:
        # count=count+1
        # print(count)
        sample = dataset[i]
        # print(sample)
        img = sample["image"]   # TorchGeo-style dictionary

        mean += img.mean(dim=(1, 2))
        std += img.std(dim=(1, 2))
    mean /= n
    std /= n
    return mean, std

Index(['id', 'fname', 'rel_path', 'label'], dtype='object')
task
ssl           22077
downstream     5520
Name: count, dtype: int64
split
subset1    5520
subset4    5519
subset2    5519
subset3    5519
train      3864
test        828
val         828
Name: count, dtype: int64
CSV saved to /kaggle/working/eurosat_all_bands_split.csv
     id                          fname  \
0  2352            Industrial_2352.tif   
1  1808           Residential_1808.tif   
2   102                 Forest_102.tif   
3  1838  HerbaceousVegetation_1838.tif   
4  1217                Forest_1217.tif   

                                            rel_path                 label  \
0                     Industrial/Industrial_2352.tif            Industrial   
1                   Residential/Residential_1808.tif           Residential   
2                              Forest/Forest_102.tif                Forest   
3  HerbaceousVegetation/HerbaceousVegetation_1838...  HerbaceousVegetation   
4                        

In [31]:
# --- CONFIG ---
root_dir = "/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSATallBands"  # folder with all images
csv_output = "/kaggle/working/eurosat_all_bands_split.csv"
split_dataset(root_dir, csv_output)

Unnamed: 0,id,fname,rel_path,label,task,split
0,2352,Industrial_2352.tif,Industrial/Industrial_2352.tif,Industrial,ssl,subset4
1,1808,Residential_1808.tif,Residential/Residential_1808.tif,Residential,ssl,subset1
2,102,Forest_102.tif,Forest/Forest_102.tif,Forest,ssl,subset1
3,1838,HerbaceousVegetation_1838.tif,HerbaceousVegetation/HerbaceousVegetation_1838...,HerbaceousVegetation,ssl,subset1
4,1217,Forest_1217.tif,Forest/Forest_1217.tif,Forest,ssl,subset4


### Settings

In [36]:
target_size = 224
target_batch_size= 8 #128 #prefer 256 or 128
target_num_workers=4
target_max_epoch=6
use_peft = True  
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logger = CSVLogger("logs", name=f"metrics_{timestamp}")

aug = K.AugmentationSequential(
    K.RandomResizedCrop(size=(target_size, target_size), scale=(0.4, 1.0)),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomGaussianBlur(kernel_size=(7,7), sigma=(0.1, 1.5), p=0.3),
    K.RandomBrightness(brightness=(0.85, 1.15), p=0.5),
    data_keys=['input'],
)

In [37]:
temp_dataset = SSLDataset(
    data_dir = "/kaggle/input/datasets/apollo2506/eurosat-dataset/EuroSATallBands", 
    split_path= "/kaggle/working/eurosat_all_bands_split.csv",
    split = "all",
)

In [48]:
mean, std = calculate_stats(temp_dataset, n_samples=50)
print(mean, std )

tensor([1306.3080, 1078.1479, 1025.4133,  911.3115, 1162.5461, 1997.2366,
        2391.7368, 2310.8821,  720.3613,   13.1632, 1779.5774, 1090.2094,
        2603.8066]) tensor([ 48.8467, 118.6814, 150.8174, 222.9257, 186.8235, 335.2391, 437.8371,
        509.6657,  97.3510,   1.2455, 332.6506, 254.2236, 484.9431])


In [None]:
# based on 10k samples
mean= [1333.8029, 1488.1448, 1745.9066, 1985.6210, 2322.0129, 2837.1787,
        3065.8462, 3192.4492, 3225.1826, 3344.8479,    0.0000, 2683.2991,
        2116.8357]
std = [384.9683, 472.5244, 497.7275, 590.9384, 578.0192, 641.7764, 699.6282,
        752.0769, 709.3992, 752.4539,   0.0000, 568.3574, 542.2833]


# to avoid 0 std
std = [max(s, 1e-5) for s in std]   

# define transform
transform = transforms.Compose([
    transforms.Resize((target_size, target_size)),
    transforms.Normalize(mean=mean, std=std)
])

In [None]:
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

def run_ssl(
    data_dir,
    split_path,
    split="all",
    model="resnet18",
    weights=None,
    in_channels=13,
    transform=None,
    batch_size=64,
    num_workers=4,
    target_size=224,
    lr=1e-4,
    memory_bank_size=2048,
    temperature=0.15,
    use_peft=False,
    augmentation1=None,
    augmentation2=None,
    max_epochs=10,
    # logger=None,
    experiment_name= None
):
    """
    Run SSL experiment.

    Args:
        split: "all" for full dataset or "subst1"/"subst2"/... for sequential streaming
        weights: ResNet18_Weights object
        transform: preprocessing/normalization transform
        augmentation1/2: MoCo augmentations
        use_peft: bool, freeze backbone except last block
        Other args: training hyperparameters
    """

    # -----------------------------
    # Dataset / DataLoader
    # -----------------------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logger = CSVLogger("logs", name=f"{experiment_name}/metrics_{timestamp}")
    dataset = SSLDataset(
        data_dir=data_dir,
        split_path=split_path,
        split=split,
        transforms=transform
    )
    print(f"Dataset split '{split}' size:", len(dataset))

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
        worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id)
    )

    num_batches = len(data_loader)
    print("Number of batches:", num_batches)

    # -----------------------------
    # Initialize MoCo task
    # -----------------------------
    task = MoCoTask(
        model=model,
        weights=weights,
        in_channels= in_channels,#weights.meta['in_chans'] if weights else 3,
        version=2,
        size=target_size,
        augmentation1=augmentation1,
        augmentation2=augmentation2,
        lr=lr,
        memory_bank_size=memory_bank_size,
        temperature=temperature
    )

    # -----------------------------
    # PEFT / Full Fine-Tuning Logic
    # -----------------------------
    if use_peft:
        print("Using PEFT: freezing backbone except last block, training projection head...")
        for name, param in task.backbone.named_parameters():
            param.requires_grad = "layer4" in name
    else:
        print("Full fine-tuning: backbone and projection head trainable...")
        for param in task.backbone.parameters():
            param.requires_grad = True

    # Momentum backbone always frozen
    for param in task.backbone_momentum.parameters():
        param.requires_grad = False

    # Projection head always trainable
    for param in task.projection_head.parameters():
        param.requires_grad = True

    summary_trainable(task)

    # -----------------------------
    # Trainer
    # -----------------------------
    trainer = Trainer(
        max_epochs=max_epochs,
        enable_progress_bar=True,
        log_every_n_steps=num_batches,
        precision=32,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        deterministic=True,
        logger=logger
    )

    # -----------------------------
    # Training
    # -----------------------------
    start_time = time.time()
    trainer.fit(task, data_loader)
    end_time = time.time()
    print(f"Training time: {(end_time-start_time)/60:.2f} min")

    print(task.trainer.logged_metrics)

    # -----------------------------
    # Save checkpoint / weights
    # -----------------------------
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(task.backbone.state_dict(), f"{experiment_name}/checkpoints/ssl_backbone_{timestamp}.pth")
    torch.save(task.projection_head.state_dict(), f"{experiment_name}/checkpoints/projection_head_{timestamp}.pth")
    trainer.save_checkpoint(f"{experiment_name}/checkpoints/ssl_ckpt_{timestamp}.ckpt")
    