# Imports

In [6]:
import os
import torch
import pandas as pd
from PIL import Image
from ast import literal_eval
from tqdm import tqdm

from torchvision import models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Config

In [None]:
# TODO: use the ./image folder 
DATA_FOLDER = 
SUBSET_DIR = f"../data/subsets"
os.makedirs(SUBSET_DIR, exist_ok=True)
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

train_csv = f"{DATA_FOLDER}/train.csv"
val_csv   = f"{DATA_FOLDER}/val.csv"

# fraction - percentage of the data used
fraction = 0.01

# frac - percentage of the data used
def sample_csv(src_csv, dst_csv, frac=fraction, seed=42):
    df = pd.read_csv(src_csv)
    df_sample = df.sample(frac=frac, random_state=seed).reset_index(drop=True)
    df_sample.to_csv(dst_csv, index=False)
    print(f"Saved {len(df_sample)} rows → {dst_csv}")

sample_csv(train_csv, f"{SUBSET_DIR}/train_subset_{fraction}.csv")
sample_csv(val_csv, f"{SUBSET_DIR}/val_subset_{fraction}.csv")

Saved 600 rows → ../data/subsets/train_subset_0.01.csv
Saved 200 rows → ../data/subsets/val_subset_0.01.csv


# Load Datasets

In [8]:
class_map = {
    'VenusExpress': 1, 'Cheops': 2, 'LisaPathfinder': 3, 'ObservationSat1': 4,
    'Proba2': 5, 'Proba3': 6, 'Proba3ocs': 7, 'Smart1': 8, 'Soho': 9, 'XMM Newton': 10
}

class PyTorchSPARKDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, class_map, root_dir=DATA_FOLDER, split="train"):
        self.df = pd.read_csv(csv_path)
        self.class_map = class_map
        self.root_dir = root_dir
        self.split = split

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sat = row['Class']
        img_name = row['Image name']
        mask_name = row['Mask name']
        bbox = literal_eval(row['Bounding box'])

        # Load image
        img_path = f"{self.root_dir}/images/{sat}/{self.split}/{img_name}"
        pil_image = Image.open(img_path).convert("RGB")
        w, h = pil_image.size
        image = torch.ByteTensor(torch.ByteStorage.from_buffer(pil_image.tobytes()))
        image = image.view(h, w, 3).permute(2, 0, 1).float() / 255.0

        # Load mask
        mask_path = f"{self.root_dir}/mask/{sat}/{self.split}/{mask_name}"
        pil_mask = Image.open(mask_path).convert("L")
        w2, h2 = pil_mask.size
        mask = torch.ByteTensor(torch.ByteStorage.from_buffer(pil_mask.tobytes()))
        mask = mask.view(h2, w2)[None].float() / 255.0

        x1, y1, x2, y2 = bbox
        boxes = torch.tensor([[x1, y1, x2, y2]], dtype=torch.float32)
        labels = torch.tensor([self.class_map[sat]], dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": mask,
            "image_id": torch.tensor([idx]),
            "area": torch.tensor([(x2 - x1)*(y2 - y1)], dtype=torch.float32),
            "iscrowd": torch.tensor([0])
        }
        return image, target


# Load Model

In [9]:
def create_model(num_classes=11):
    model = models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Update classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Update mask head
    in_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_mask, 256, num_classes)
    return model


# Data Loaders

In [10]:
from torch.utils.data import DataLoader

train_ds = PyTorchSPARKDataset(f"{SUBSET_DIR}/train_subset_{fraction}.csv", class_map, split="train")
val_ds   = PyTorchSPARKDataset(f"{SUBSET_DIR}/val_subset_{fraction}.csv", class_map, split="val")

def collate_fn(batch):
    return tuple(zip(*batch))

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn)

len(train_ds), len(val_ds)

(600, 200)

# Debug  Single Training Step

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = create_model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# run 1 batch
images, targets = next(iter(train_loader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

model.train()
loss_dict = model(images, targets)
loss = sum(loss_dict.values())

optimizer.zero_grad()
loss.backward()
optimizer.step()

loss.item(), loss_dict


  image = torch.ByteTensor(torch.ByteStorage.from_buffer(pil_image.tobytes()))


(4.218687534332275,
 {'loss_classifier': tensor(2.4092, grad_fn=<NllLossBackward0>),
  'loss_box_reg': tensor(0.1432, grad_fn=<DivBackward0>),
  'loss_mask': tensor(1.6195, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  'loss_objectness': tensor(0.0436, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  'loss_rpn_box_reg': tensor(0.0031, grad_fn=<DivBackward0>)})

# Small Batch Training

In [12]:
EPOCHS = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

model = create_model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss = 0

    for images, targets in tqdm(train_loader):
        images = [i.to(device) for i in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch} Train Loss: {total_loss/len(train_loader):.4f}")

# Save state dict
torch.save(model.state_dict(), "final_state_dict.pth")
print("Saved → final_state_dict.pth")


  8%|▊         | 12/150 [01:24<16:13,  7.06s/it]


KeyboardInterrupt: 