In [None]:
# ============================================================
# 3. CREATE DATASET AND DATALOADERS
# ============================================================
# NOTE: this cell needs the OilSpillDataset class (section 2) to be defined
# first; running it on a fresh kernel before that cell raises NameError.
from torch.utils.data import Subset

images_dir = "/content/oil_spill_dataset/images"
masks_dir  = "/content/oil_spill_dataset/masks"

SPLIT_SEED = 42  # fixed so the train/val split is identical on every run

# Two views over the same files: augmentation is applied to training samples only.
# Splitting a single augment=True dataset would also randomly flip validation
# samples, making validation metrics (and best-checkpoint selection) noisy.
full_dataset = OilSpillDataset(images_dir, masks_dir, img_size=(256, 256), augment=True)
val_base_dataset = OilSpillDataset(images_dir, masks_dir, img_size=(256, 256), augment=False)

# train/val split
val_ratio = 0.2
val_size = int(len(full_dataset) * val_ratio)
train_size = len(full_dataset) - val_size

# One seeded permutation shared by both views, so train and val indices are
# disjoint and refer to the same (sorted) file pairs in each dataset.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_dataset = Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = Subset(val_base_dataset, shuffled_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")


NameError: name 'OilSpillDataset' is not defined

In [None]:
# ============================================================
# 2. CUSTOM DATASET CLASS FOR OIL SPILL SEGMENTATION
# ============================================================
class OilSpillDataset(Dataset):
    """Paired image/mask dataset for binary oil-spill segmentation.

    Images and masks are paired purely by position after sorting the file
    names found in each directory, so the i-th sorted image must correspond
    to the i-th sorted mask.

    Args:
        images_dir: directory of input images (every non-hidden entry is used).
        masks_dir: directory of masks (every non-hidden entry is used).
        img_size: target size passed to ``T.Resize`` for both image and mask.
        augment: if True, apply random horizontal and vertical flips (each
            with probability 0.5) jointly to the image and its mask.

    Each item is ``(img, mask)``: ``img`` is a float tensor ``[3, H, W]`` in
    [0, 1]; ``mask`` is a float tensor ``[1, H, W]`` containing 0.0 or 1.0.
    """

    def __init__(self, images_dir, masks_dir, img_size=(256, 256), augment=False):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.img_size = img_size
        self.augment = augment

        # Pairing relies on both sorted listings lining up one-to-one.
        self.image_paths = sorted(glob.glob(os.path.join(images_dir, "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(masks_dir, "*")))
        assert len(self.image_paths) == len(self.mask_paths), "Images & masks count mismatch"

        # basic transforms; masks use nearest-neighbour resizing so no
        # interpolated intermediate values are introduced into the labels
        self.to_tensor = T.ToTensor()
        self.resize_img = T.Resize(img_size, interpolation=T.InterpolationMode.BILINEAR)
        self.resize_mask = T.Resize(img_size, interpolation=T.InterpolationMode.NEAREST)

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)

    @staticmethod
    def _load(path, mode):
        """Open ``path``, convert it to PIL ``mode`` and release the file handle."""
        # Image.open is lazy and keeps the file open (some formats, e.g. TIFF,
        # are not auto-closed after loading); the context manager guarantees
        # the handle is closed once the converted in-memory copy exists.
        with Image.open(path) as src:
            return src.convert(mode)

    def __getitem__(self, idx):
        """Return the resized (and optionally flipped) ``(img, mask)`` pair at ``idx``."""
        img = self._load(self.image_paths[idx], "RGB")
        mask = self._load(self.mask_paths[idx], "L")  # grayscale mask, presumably 0/255 — verify data

        # resize
        img = self.resize_img(img)
        mask = self.resize_mask(mask)

        # optional simple augmentation, applied identically to image and mask
        if self.augment:
            if random.random() < 0.5:
                img = T.functional.hflip(img)
                mask = T.functional.hflip(mask)
            if random.random() < 0.5:
                img = T.functional.vflip(img)
                mask = T.functional.vflip(mask)

        img = self.to_tensor(img)          # [3, H, W], float32 in [0,1]
        mask = self.to_tensor(mask)        # [1, H, W], float32 in [0,1]
        mask = (mask > 0.5).float()        # binarize (0 or 1)

        return img, mask


In [None]:
# ============================================================
# 5. METRICS, TRAINING AND VALIDATION FUNCTIONS
# ============================================================
def compute_iou(preds, targets, threshold=0.5, eps=1e-6):
    """Mean per-sample IoU between thresholded predictions and binary targets.

    Args:
        preds: tensor of scores/probabilities with the batch on dim 0
            (e.g. ``[B, 1, H, W]`` or ``[B, H, W]``); entries strictly above
            ``threshold`` count as positive.
        targets: binary {0, 1} tensor with the same shape as ``preds``.
        threshold: cut-off applied to ``preds``.
        eps: smoothing added to numerator and denominator, so a sample whose
            prediction and target are both empty scores 1.0.

    Returns:
        Python float: per-sample IoU averaged over the batch.
    """
    pred_bin = (preds > threshold).float()
    # Reduce over every non-batch dimension, so any input rank >= 2 works
    # (the original hard-coded (1, 2, 3) and failed on [B, H, W] inputs).
    reduce_dims = tuple(range(1, pred_bin.dim()))
    intersection = (pred_bin * targets).sum(dim=reduce_dims)
    union = pred_bin.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) - intersection
    iou = (intersection + eps) / (union + eps)
    return iou.mean().item()

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Batches are moved to the notebook-global ``device``. Loss and IoU are
    sample-weighted averages over the samples actually yielded by ``loader``,
    which stays correct with ``drop_last=True`` or a sampler that visits only
    part of ``loader.dataset``.

    Returns:
        (epoch_loss, epoch_iou) as Python floats.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    for imgs, masks in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)
        batch_size = imgs.size(0)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        # Re-weight per-batch means by batch size so the epoch average is
        # per-sample (assumes criterion uses mean reduction — TODO confirm).
        loss_sum += loss.item() * batch_size
        with torch.no_grad():
            iou_sum += compute_iou(torch.sigmoid(logits), masks) * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("loader yielded no samples; check dataset paths and split sizes")
    return loss_sum / n_samples, iou_sum / n_samples

def validate_one_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Batches are moved to the notebook-global ``device``. Loss and IoU are
    sample-weighted averages over the samples actually yielded by ``loader``,
    which stays correct with ``drop_last=True`` or a sampler that visits only
    part of ``loader.dataset``.

    Returns:
        (epoch_loss, epoch_iou) as Python floats.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            batch_size = imgs.size(0)

            logits = model(imgs)
            loss = criterion(logits, masks)

            # Re-weight per-batch means by batch size so the epoch average is
            # per-sample (assumes criterion uses mean reduction — TODO confirm).
            loss_sum += loss.item() * batch_size
            iou_sum += compute_iou(torch.sigmoid(logits), masks) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("loader yielded no samples; check dataset paths and split sizes")
    return loss_sum / n_samples, iou_sum / n_samples


In [None]:
# ============================================================
# 6. TRAINING LOOP
# ============================================================
num_epochs = 20
CHECKPOINT_PATH = "best_unet_oil_spill.pth"  # weights of the best-val-IoU epoch
best_val_iou = 0.0

for epoch in range(1, num_epochs + 1):
    train_loss, train_iou = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_iou = validate_one_epoch(model, val_loader, criterion)

    # Keep only the checkpoint with the highest validation IoU seen so far.
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    print(f"Epoch {epoch:02d}/{num_epochs} | "
          f"Train Loss: {train_loss:.4f}, IoU: {train_iou:.4f} | "
          f"Val Loss: {val_loss:.4f}, IoU: {val_iou:.4f}")

# The literal previously held mojibake ("âœ…", UTF-8 ✅ mis-decoded as cp1252).
print("✅ Training finished. Best Val IoU:", best_val_iou)


In [None]:
# ============================================================
# 7. VISUALIZE PREDICTIONS ON VALIDATION SAMPLE
# ============================================================
def show_sample_prediction(model, dataset, idx=None):
    """Plot one sample's image, ground-truth mask and predicted mask side by side.

    Args:
        model: segmentation model producing logits; switched to eval mode here.
        dataset: indexable dataset returning ``(img, mask)`` tensor pairs.
        idx: index of the sample to show; a random index is drawn when None.

    The prediction is the sigmoid of the logits thresholded at 0.5; inference
    runs on the notebook-global ``device``.
    """
    model.eval()
    sample_idx = random.randint(0, len(dataset)-1) if idx is None else idx

    img, mask = dataset[sample_idx]
    with torch.no_grad():
        batch = img.unsqueeze(0).to(device)
        prob_map = torch.sigmoid(model(batch))
        binary = (prob_map > 0.5).float().cpu()
        pred_mask = binary.squeeze(0).squeeze(0).numpy()

    # (title, array, colormap) per panel; None keeps imshow's default for RGB.
    panels = (
        ("Original Image", img.permute(1, 2, 0).numpy(), None),
        ("Ground Truth Mask", mask.squeeze(0).numpy(), "gray"),
        ("Predicted Mask", pred_mask, "gray"),
    )

    plt.figure(figsize=(12,4))
    for position, (title, array, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(array, cmap=cmap)
        plt.axis("off")

    plt.show()


show_sample_prediction(model, val_dataset)
