# Method 1: U-Net Microplastic Segmentation - Kaggle Notebook

**Complete pipeline for pixel-level microplastic segmentation**

- **Task**: Binary segmentation (microplastic vs background)
- **Model**: U-Net with Dice + BCE loss
- **Dataset**: Microplastics from marine/ocean environments
- **Framework**: PyTorch

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -q albumentations opencv-python-headless scikit-learn matplotlib seaborn tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
from tqdm import tqdm
import warnings
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score

# Silence library warnings so notebook output stays readable.
warnings.filterwarnings('ignore')

# Seed both RNGs used by the notebook (torch drives weight init and
# random_split; NumPy is seeded for any NumPy-based randomness).
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 2. Dataset Setup

**You need to add your dataset to this Kaggle notebook:**

1. Click "+ Add data" button on the right panel
2. Search for: `microplastic` or your uploaded dataset name
3. Click "Add"

**Recommended datasets on Kaggle:**
- `imtkaggleteam/microplastic-dataset-for-computer-vision`
- Or upload your own DeepParticle dataset

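**The next cell auto-detects the dataset. If detection fails, set `DATASET_PATH`, `IMAGE_DIR` and `MASK_DIR` manually in that cell's "Manual override" section.**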

In [None]:
# ============================================================
# DATASET CONFIGURATION - AUTO-DETECTION
# ============================================================

import os
import glob

# Image file patterns used for both counting and listing below. glob is
# case-sensitive on Linux (Kaggle), so upper-case variants are listed
# explicitly; matches are de-duplicated in case the filesystem is not.
IMAGE_EXTENSIONS = ['*.jpg', '*.jpeg', '*.JPG', '*.JPEG', '*.png', '*.PNG',
                    '*.tif', '*.tiff', '*.TIF', '*.TIFF']

# Auto-detect dataset path
print("Searching for dataset...")
print("="*60)

# List all available datasets (sorted so detection is deterministic;
# os.listdir returns entries in arbitrary order).
available_datasets = sorted(os.listdir('/kaggle/input/')) if os.path.exists('/kaggle/input/') else []
print(f"Available datasets in /kaggle/input/:")
for ds in available_datasets:
    print(f"  - {ds}")

# Try to find the dataset automatically
DATASET_PATH = None
IMAGE_DIR = None
MASK_DIR = None

# Check for common dataset structures. For each attached dataset the first
# matching layout wins, in this order: MACRO, MICRO, MESO subfolders, then
# images/ and raw_img/ at the dataset root.
for dataset_name in available_datasets:
    base_path = f"/kaggle/input/{dataset_name}"
    
    # Check for MACRO structure (your uploaded dataset)
    if os.path.exists(f"{base_path}/MACRO/raw_img"):
        DATASET_PATH = f"{base_path}/MACRO"
        IMAGE_DIR = f"{DATASET_PATH}/raw_img"
        MASK_DIR = None  # MACRO has TSV annotations, not masks
        print(f"\nDetected MACRO dataset!")
        break
    
    # Check for MICRO structure
    elif os.path.exists(f"{base_path}/MICRO/raw_img"):
        DATASET_PATH = f"{base_path}/MICRO"
        IMAGE_DIR = f"{DATASET_PATH}/raw_img"
        MASK_DIR = f"{DATASET_PATH}/annotation"
        print(f"\nDetected MICRO dataset!")
        break
    
    # Check for MESO structure
    elif os.path.exists(f"{base_path}/MESO/raw_img"):
        DATASET_PATH = f"{base_path}/MESO"
        IMAGE_DIR = f"{DATASET_PATH}/raw_img"
        MASK_DIR = f"{DATASET_PATH}/annotation"
        print(f"\nDetected MESO dataset!")
        break
    
    # Check for simple images/masks structure
    elif os.path.exists(f"{base_path}/images"):
        DATASET_PATH = base_path
        IMAGE_DIR = f"{DATASET_PATH}/images"
        MASK_DIR = f"{DATASET_PATH}/masks" if os.path.exists(f"{DATASET_PATH}/masks") else None
        print(f"\nDetected standard images/masks dataset!")
        break
    
    # Check for raw_img in root
    elif os.path.exists(f"{base_path}/raw_img"):
        DATASET_PATH = base_path
        IMAGE_DIR = f"{DATASET_PATH}/raw_img"
        MASK_DIR = None
        print(f"\nDetected raw_img dataset!")
        break

# Manual override option (uncomment and modify if auto-detection fails)
# DATASET_PATH = "/kaggle/input/your-dataset-name"
# IMAGE_DIR = f"{DATASET_PATH}/MACRO/raw_img"
# MASK_DIR = None

# ============================================================

# Create output directories
os.makedirs('/kaggle/working/models', exist_ok=True)
os.makedirs('/kaggle/working/results', exist_ok=True)

# Verify paths
print("\n" + "="*60)
print("DATASET CONFIGURATION")
print("="*60)

if DATASET_PATH is None or IMAGE_DIR is None:
    print("\n⚠️  ERROR: Could not auto-detect dataset!")
    print("\nPlease set paths manually by uncommenting these lines above:")
    print("  DATASET_PATH = '/kaggle/input/your-dataset-name'")
    print("  IMAGE_DIR = f'{DATASET_PATH}/MACRO/raw_img'")
    print("\nAvailable datasets:", available_datasets)
else:
    print(f"\n✓ Dataset path: {DATASET_PATH}")
    print(f"✓ Image directory: {IMAGE_DIR}")
    print(f"✓ Image directory exists: {os.path.exists(IMAGE_DIR)}")
    
    if IMAGE_DIR and os.path.exists(IMAGE_DIR):
        # Collect every image exactly once, using the same pattern list for
        # the count and the preview so the two always agree.
        all_images = sorted({
            path
            for ext in IMAGE_EXTENSIONS
            for path in glob.glob(os.path.join(IMAGE_DIR, ext))
        })
        image_count = len(all_images)
        
        print(f"✓ Total images found: {image_count}")
        
        # Show first few images
        if all_images:
            print(f"\nFirst 5 images:")
            for img in all_images[:5]:
                print(f"  - {os.path.basename(img)}")
        
        # Check for masks
        if MASK_DIR and os.path.exists(MASK_DIR):
            print(f"\n✓ Mask directory: {MASK_DIR}")
            print(f"✓ Masks exist: True")
        else:
            print(f"\n⚠️  No mask directory found")
            print("  Will use dummy masks for training (demo mode)")
            print("  For real training, you need segmentation masks")
    else:
        print(f"\n⚠️  ERROR: Image directory does not exist!")
        print(f"  Path: {IMAGE_DIR}")

print("="*60)

## 3. U-Net Model Definition

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 and padding 1 keep the spatial size unchanged; the first
    convolution maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            # No conv bias: the following BatchNorm provides the shift.
            layers.append(nn.Conv2d(stage_in, out_channels, 3, 1, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        # A single Sequential named `conv` keeps the state_dict keys
        # (conv.0.weight, conv.1.weight, ...) stable for saved checkpoints.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """U-Net producing per-pixel logits.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (raw logits, no activation).
        features: encoder widths, shallowest first. The bottleneck has
            ``2 * features[-1]`` channels and the decoder mirrors the encoder.
            Any sequence of ints is accepted; the default is a tuple so it
            cannot be mutated across instances.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: per level, a 2x transposed-conv upsample followed by a
        # DoubleConv over [skip, upsampled] (hence feature * 2 input channels).
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # self.ups alternates (upsample, conv); pair each level with its skip,
        # deepest first.
        for up, conv, skip in zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections)):
            x = up(x)
            # Max-pooling floors odd sizes, so the upsampled map can be one
            # pixel smaller than its skip; resize it to match before concat.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:], mode='bilinear', align_corners=False
                )
            x = conv(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

# Initialize model
# Instantiate the U-Net (RGB in, one logit channel out) on the chosen device.
model = UNet(in_channels=3, out_channels=1)
model.to(device)  # nn.Module.to moves the parameters in place and returns self
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

## 4. Dataset Class and Data Loading

In [None]:
class MicroplasticsDataset(Dataset):
    """Image/mask pairs for binary microplastic segmentation.

    Images are the files in ``image_dir`` whose extension (case-insensitive)
    is in ``IMAGE_EXTENSIONS``. When ``mask_dir`` is given, the mask for
    ``<stem>.<ext>`` is looked up first under the identical file name, then as
    ``<stem>`` with one of ``MASK_EXTENSIONS``. A missing mask, or no
    ``mask_dir`` at all, yields an all-zero mask. Any non-zero mask pixel is
    foreground.

    Args:
        image_dir: directory containing the images.
        mask_dir: optional directory containing grayscale masks.
        transform: optional albumentations-style callable taking ``image=``
            and ``mask=`` keywords and returning a dict with ``'image'`` and
            ``'mask'``.
        image_size: ``(width, height)`` as expected by ``cv2.resize``.

    ``__getitem__`` returns ``(image, mask)`` with ``mask`` of shape
    ``(1, H, W)``. Without a transform, the image is returned as a float
    ``(3, H, W)`` tensor scaled to [0, 1].

    Raises (from ``__getitem__``):
        OSError: if an image or mask file cannot be decoded by OpenCV.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
    MASK_EXTENSIONS = ('.png', '.tif', '.tiff', '.bmp', '.jpg', '.jpeg')

    def __init__(self, image_dir, mask_dir=None, transform=None, image_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_size = image_size

        # Match extensions case-insensitively: glob is case-sensitive on
        # Linux, so '*.jpg' alone silently skips files like 'IMG_01.JPG'.
        # Hidden files are skipped, as glob's '*' pattern would.
        names = os.listdir(image_dir) if os.path.isdir(image_dir) else []
        self.images = sorted(
            os.path.join(image_dir, name)
            for name in names
            if not name.startswith('.')
            and os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
            and os.path.isfile(os.path.join(image_dir, name))
        )
        print(f"Found {len(self.images)} images")

    def __len__(self):
        return len(self.images)

    def _find_mask(self, img_path):
        """Return the mask file for ``img_path``, or None if none exists."""
        img_name = os.path.basename(img_path)
        stem = os.path.splitext(img_name)[0]
        candidates = [os.path.join(self.mask_dir, img_name)]
        candidates += [os.path.join(self.mask_dir, stem + ext) for ext in self.MASK_EXTENSIONS]
        return next((path for path in candidates if os.path.exists(path)), None)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = None
        if self.mask_dir:
            mask_path = self._find_mask(img_path)
            if mask_path is not None:
                raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if raw_mask is None:
                    raise OSError(f"Could not read mask: {mask_path}")
                mask = (raw_mask > 0).astype(np.float32)
        if mask is None:
            # No mask available: all-background placeholder (demo mode).
            mask = np.zeros(image.shape[:2], dtype=np.float32)

        # Nearest-neighbour keeps the mask strictly 0/1; the default bilinear
        # interpolation would create fractional labels along object borders.
        image = cv2.resize(image, self.image_size)
        mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            # Without a transform, still return tensors so DataLoader
            # batching and the .unsqueeze below work.
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))).float() / 255.0

        mask = torch.as_tensor(mask, dtype=torch.float32)
        return image, mask.unsqueeze(0)

from torch.utils.data import Subset

# Data transforms
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=45, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

# Fail early with a clear message if the dataset cell did not find any data.
if IMAGE_DIR is None:
    raise RuntimeError("IMAGE_DIR is not set - configure the dataset paths in section 2 first.")

# MASK_DIR is legitimately None for mask-less layouts (e.g. MACRO), and
# os.path.exists(None) raises TypeError, so test for None first.
active_mask_dir = MASK_DIR if MASK_DIR and os.path.exists(MASK_DIR) else None

# Two datasets over the same files. Subsets produced by random_split share
# their parent, so assigning a transform through one subset would change it
# for all of them (training would silently lose its augmentation).
full_dataset = MicroplasticsDataset(IMAGE_DIR, active_mask_dir, transform=train_transform)
eval_dataset = MicroplasticsDataset(IMAGE_DIR, active_mask_dir, transform=val_transform)

if len(full_dataset) == 0:
    raise RuntimeError(f"No images found in {IMAGE_DIR}")

# Split dataset
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_split, test_split = random_split(
    full_dataset, [train_size, val_size, test_size]
)

# Validation/test keep the split indices but read through the un-augmented
# dataset (both datasets list the same sorted files, so indices line up).
val_dataset = Subset(eval_dataset, val_split.indices)
test_dataset = Subset(eval_dataset, test_split.indices)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Data loaders
BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

## 5. Loss Functions and Metrics

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities.

    The whole batch is flattened into one vector, so a single Dice
    coefficient is computed over all pixels of all samples.

    Args:
        smooth: additive term in numerator and denominator that keeps the
            ratio defined when prediction and target are (nearly) empty.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for raw logits ``pred`` against ``target``."""
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted as well.
        pred = torch.sigmoid(pred).reshape(-1)
        target = target.reshape(-1).to(pred.dtype)

        intersection = (pred * target).sum()
        dice = (2 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)

        return 1 - dice

class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    ``loss = alpha * BCE + (1 - alpha) * Dice``; both terms take raw logits.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        bce_weight = self.alpha
        dice_weight = 1 - self.alpha
        return bce_weight * self.bce(pred, target) + dice_weight * self.dice(pred, target)

def calculate_metrics(pred, target, threshold=0.5):
    """Compute binary segmentation metrics for one batch.

    Args:
        pred: raw logits; a pixel is predicted foreground when
            ``sigmoid(pred) > threshold``.
        target: ground-truth mask with the same number of elements; values
            above 0.5 count as foreground, so every metric below sees the same
            binary target.
        threshold: probability threshold applied to the predictions.

    Returns:
        dict of floats: ``iou``, ``dice``, ``f1``, ``precision``, ``recall``.
        A ratio whose denominator is zero is reported as 0.0 (the same
        convention as scikit-learn's ``zero_division=0``).

    Counts are computed with tensor ops on the inputs' device, avoiding a
    device-to-host copy of the full batch on every call.
    """
    with torch.no_grad():
        pred_fg = (torch.sigmoid(pred) > threshold).reshape(-1)
        target_fg = (target > 0.5).reshape(-1)

        tp = (pred_fg & target_fg).sum().item()
        fp = (pred_fg & ~target_fg).sum().item()
        fn = (~pred_fg & target_fg).sum().item()

    def _ratio(num, den):
        return num / den if den else 0.0

    f1 = _ratio(2 * tp, 2 * tp + fp + fn)
    return {
        'iou': _ratio(tp, tp + fp + fn),
        # For binary masks the Dice coefficient and F1 score coincide.
        'dice': f1,
        'f1': f1,
        'precision': _ratio(tp, tp + fp),
        'recall': _ratio(tp, tp + fn),
    }

# Equal-weight BCE + Dice objective; Adam at lr 1e-4, with the learning rate
# halved whenever the monitored (validation) loss stalls for 5 epochs.
criterion = CombinedLoss(alpha=0.5)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

## 6. Training

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train (put into train mode here).
        dataloader: yields ``(images, masks)`` batches.
        criterion: loss taking ``(logits, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, mean_metrics)``, both averaged over batches (every batch
        weighs equally, including a smaller final batch).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader; the training split has no samples")

    total_loss = 0.0
    total_metrics = {'iou': 0, 'dice': 0, 'f1': 0, 'precision': 0, 'recall': 0}

    for images, masks in tqdm(dataloader, desc="Training"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Metrics are for monitoring only: detach so no autograd graph is
        # built for them.
        metrics = calculate_metrics(outputs.detach(), masks)
        for key in total_metrics:
            total_metrics[key] += metrics[key]

    avg_loss = total_loss / num_batches
    avg_metrics = {key: val / num_batches for key, val in total_metrics.items()}

    return avg_loss, avg_metrics

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: network to evaluate (put into eval mode here).
        dataloader: yields ``(images, masks)`` batches.
        criterion: loss taking ``(logits, masks)``.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, mean_metrics)``, both averaged over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches. This can happen for
            very small datasets, where ``int(0.2 * N)`` rounds the validation
            split down to zero.
    """
    model.eval()
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("validate_epoch received an empty dataloader; the split being evaluated has no samples")

    total_loss = 0.0
    total_metrics = {'iou': 0, 'dice': 0, 'f1': 0, 'precision': 0, 'recall': 0}

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Validation"):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            total_loss += loss.item()
            metrics = calculate_metrics(outputs, masks)
            for key in total_metrics:
                total_metrics[key] += metrics[key]

    avg_loss = total_loss / num_batches
    avg_metrics = {key: val / num_batches for key, val in total_metrics.items()}

    return avg_loss, avg_metrics

# Training loop
# Trains for NUM_EPOCHS, validating after every epoch. Per-epoch losses and
# metric dicts are appended to the *_losses / *_metrics_history lists (these
# are plotted in section 7). The weights with the lowest validation loss seen
# so far are written to /kaggle/working/models/best_unet_model.pth.
NUM_EPOCHS = 20
best_val_loss = float('inf')
train_losses, val_losses = [], []
train_metrics_history, val_metrics_history = [], []

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)
    
    train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_metrics_history.append(train_metrics)
    
    val_loss, val_metrics = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_metrics_history.append(val_metrics)
    
    # The plateau scheduler (mode='min') reacts to the validation loss.
    scheduler.step(val_loss)
    
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Train IoU: {train_metrics['iou']:.4f} | Val IoU: {val_metrics['iou']:.4f}")
    print(f"Train Dice: {train_metrics['dice']:.4f} | Val Dice: {val_metrics['dice']:.4f}")
    
    # NOTE(review): a NaN val_loss never compares below best_val_loss. If every
    # epoch is NaN, no checkpoint is written and loading it in section 8 fails.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), '/kaggle/working/models/best_unet_model.pth')
        print("Best model saved!")

print("\nTraining completed!")

## 7. Visualize Results

In [None]:
# Plot training curves: one panel each for loss, IoU and Dice (train vs val).
train_ious = [m['iou'] for m in train_metrics_history]
val_ious = [m['iou'] for m in val_metrics_history]
train_dice = [m['dice'] for m in train_metrics_history]
val_dice = [m['dice'] for m in val_metrics_history]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
panels = (
    ('Loss', train_losses, val_losses),
    ('IoU', train_ious, val_ious),
    ('Dice', train_dice, val_dice),
)
for panel_ax, (panel_title, train_curve, val_curve) in zip(axes, panels):
    panel_ax.plot(train_curve, label='Train')
    panel_ax.plot(val_curve, label='Val')
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.savefig('/kaggle/working/results/training_curves.png')
plt.show()

# Visualize predictions
def visualize_predictions(model, dataloader, num_samples=4, device=None,
                          save_path='/kaggle/working/results/predictions.png'):
    """Plot image / ground-truth mask / predicted probability for one batch.

    Args:
        model: segmentation model returning logits.
        dataloader: yields ``(images, masks)``, with images normalised by the
            same mean/std used in the A.Normalize transforms; only the first
            batch is drawn.
        num_samples: maximum number of columns, capped at the batch size.
        device: device used for inference; defaults to the device holding
            the model's parameters.
        save_path: file the figure is saved to before it is shown.

    Raises:
        ValueError: if fewer than one column would be drawn.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    images, masks = next(iter(dataloader))
    n_cols = min(num_samples, len(images))
    if n_cols < 1:
        raise ValueError("num_samples must be >= 1 and the first batch must be non-empty")

    with torch.no_grad():
        predictions = torch.sigmoid(model(images.to(device))).cpu()

    # squeeze=False keeps `axes` 2-D even for a single column, so the
    # axes[row, col] indexing below always works.
    fig, axes = plt.subplots(3, n_cols, figsize=(20, 12), squeeze=False)

    # Undo the A.Normalize mean/std so the images display with true colours.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i in range(n_cols):
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = np.clip(img * std + mean, 0, 1)

        mask = masks[i].squeeze().cpu().numpy()
        pred = predictions[i].squeeze().numpy()

        axes[0, i].imshow(img)
        axes[0, i].set_title(f'Image {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(mask, cmap='gray')
        axes[1, i].set_title(f'Ground Truth {i+1}')
        axes[1, i].axis('off')

        axes[2, i].imshow(pred, cmap='gray')
        axes[2, i].set_title(f'Prediction {i+1}')
        axes[2, i].axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

# Show the first validation batch: input, ground truth and prediction.
visualize_predictions(model, val_loader, num_samples=4)

## 8. Final Results

In [None]:
# Load best model. map_location loads the tensors onto the current device
# regardless of the device they were saved from.
checkpoint_path = '/kaggle/working/models/best_unet_model.pth'
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    # The checkpoint is written only when the validation loss improves, so it
    # can be missing (e.g. if every validation loss was NaN).
    print(f"\n⚠️  No checkpoint found at {checkpoint_path}; evaluating the current weights.")

# Final evaluation
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
if len(test_loader) == 0:
    # Averaging over zero batches is undefined; very small datasets can
    # leave the test split empty.
    test_loss, test_metrics = None, None
    print("\n⚠️  Test split is empty; skipping final evaluation.")
else:
    test_loss, test_metrics = validate_epoch(model, test_loader, criterion, device)

    print("\nFinal Test Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test IoU: {test_metrics['iou']:.4f}")
    print(f"Test Dice: {test_metrics['dice']:.4f}")
    print(f"Test F1: {test_metrics['f1']:.4f}")
    print(f"Test Precision: {test_metrics['precision']:.4f}")
    print(f"Test Recall: {test_metrics['recall']:.4f}")

print("\nModel saved to: /kaggle/working/models/best_unet_model.pth")
print("Results saved to: /kaggle/working/results/")