In [1]:
import os
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: d:\workspace\iscat


In [2]:
from src.data_processing.dataset import iScatDataset
from src.data_processing.utils import Utils
data_path = os.path.join('data', 'iScat', 'Data', '2024_11_11', 'Metasurface', 'Chip_02')
image_paths,target_paths = Utils.get_data_paths(data_path)

In [3]:
train_dataset = iScatDataset(image_paths[:-1], target_paths[:-1], preload_image=True)
valid_dataset = iScatDataset([image_paths[-1]],[target_paths[-1]],preload_image=True,apply_augmentation=False)

Loading surface images to Memory: 100%|██████████| 4/4 [00:44<00:00, 11.02s/it]
Creating Masks: 100%|██████████| 4/4 [00:00<00:00, 73.58it/s]
Loading surface images to Memory: 100%|██████████| 1/1 [00:14<00:00, 14.07s/it]
Creating Masks: 100%|██████████| 1/1 [00:00<00:00, 77.13it/s]


In [4]:
from torch.utils.data import DataLoader, Dataset
def create_dataloaders(train_dataset, test_dataset, batch_size=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
train_loader, val_loader = create_dataloaders(train_dataset, valid_dataset, batch_size=4)

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np

# Modify the model to support multiclass segmentation
class MultiClassUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=3, init_features=32):
        super(MultiClassUNet, self).__init__()
        
        # Load the pretrained model and modify the final layer
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', 
                               in_channels=in_channels, 
                               out_channels=1, 
                               init_features=init_features, 
                               pretrained=True)
        
        # Replace the final convolution layer to match number of classes
        model.conv = nn.Conv2d(init_features, num_classes, kernel_size=1)
        
        self.model = model
    
    def forward(self, x):
        return self.model(x)

def train_multiclass_segmentation(model, train_loader, val_loader, num_epochs=50, learning_rate=1e-4):
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler (optional)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)
    
    # Tensorboard for logging
    writer = SummaryWriter('runs/multiclass_segmentation')
    
    # Training loop
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        
        for batch_idx, (images, masks) in enumerate(train_loader):
            # Move to device
            images = images.to(device)
            masks = masks.to(device)
            
            # Ensure masks are long tensor for CrossEntropyLoss
            # Assumes masks are one-hot, convert to class indices
            masks = torch.argmax(masks, dim=1).long()
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)
            
            # Compute loss
            loss = criterion(outputs, masks)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            # Tensorboard logging
            writer.add_scalar('Training Loss', loss.item(), 
                               epoch * len(train_loader) + batch_idx)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                
                # Ensure masks are long tensor for CrossEntropyLoss
                masks = torch.argmax(masks, dim=1).long()
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
        
        # Average losses
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # Tensorboard logging
        writer.add_scalar('Validation Loss', val_loss, epoch)
        
        # Print progress
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}')
        
        # Model checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss
            }, 'best_multiclass_unet_model.pth')
    
    writer.close()
    return model

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchmetrics
import numpy as np

# Modify the model to support multiclass segmentation
class MultiClassUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=3, init_features=32):
        super(MultiClassUNet, self).__init__()
        
        # Load the pretrained model and modify the final layer
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', 
                               in_channels=in_channels, 
                               out_channels=1, 
                               init_features=init_features, 
                               pretrained=True)
        
        # Replace the final convolution layer to match number of classes
        model.conv = nn.Conv2d(init_features, num_classes, kernel_size=1)
        
        self.model = model
    
    def forward(self, x):
        return self.model(x)

# Dice Loss function
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, outputs, targets):
        # Apply softmax to logits to get class probabilities
        outputs = torch.softmax(outputs, dim=1)  # (batch_size, num_classes, H, W)
        
        # One-hot encode the targets (if they are not one-hot already)
        targets_onehot = torch.eye(outputs.size(1), device=outputs.device)[targets].permute(0, 3, 1, 2)
        
        intersection = (outputs * targets_onehot).sum(dim=(1, 2, 3))  # sum over batch, classes, height, and width
        union = outputs.sum(dim=(1, 2, 3)) + targets_onehot.sum(dim=(1, 2, 3))  # sum over batch, classes, height, and width
        dice_score = (2. * intersection + self.smooth) / (union + self.smooth)
        
        return 1 - dice_score.mean()  # Loss is 1 - Dice score


def train_multiclass_segmentation(model, train_loader, val_loader, num_epochs=50, learning_rate=1e-4, loss_function='crossentropy'):
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Loss function choice
    if loss_function == 'crossentropy':
        criterion = nn.CrossEntropyLoss()
    elif loss_function == 'dice':
        criterion = DiceLoss()
    else:
        raise ValueError("Unsupported loss function. Choose either 'crossentropy' or 'dice'.")
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Metrics
    accuracy_metric = torchmetrics.Accuracy(task= "multiclass",num_classes=3).to(device)
    iou_metric = torchmetrics.JaccardIndex(task="multiclass",num_classes=3).to(device)
    
    # Learning rate scheduler (optional)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)
    
    # Tensorboard for logging
    writer = SummaryWriter('runs/multiclass_segmentation')
    
    # Training loop
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_accuracy = 0.0
        train_iou = 0.0
        
        for batch_idx, (images, masks) in enumerate(train_loader):
            # Move to device
            images = images.to(device)
            masks = masks.to(device)
            
            # If masks are one-hot encoded, convert them to class indices
            if masks.ndimension() == 3:  # unbatched case
                masks = torch.argmax(masks, dim=0).unsqueeze(0)  # Convert from (C, H, W) to (1, H, W)
            else:
                masks = torch.argmax(masks, dim=1)  # Convert from one-hot to class indices for batched case
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)
            
            # Compute loss
            loss = criterion(outputs, masks)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Update metrics
            accuracy_metric.update(outputs, masks)
            iou_metric.update(outputs, masks)
            
            train_loss += loss.item()
        
        # Compute average training loss, accuracy and IoU
        train_loss /= len(train_loader)
        train_accuracy = accuracy_metric.compute().item()
        train_iou = iou_metric.compute().item()
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                
                # If masks are one-hot encoded, convert them to class indices
                if masks.ndimension() == 3:  # unbatched case
                    masks = torch.argmax(masks, dim=0).unsqueeze(0)  # Convert from (C, H, W) to (1, H, W)
                else:
                    masks = torch.argmax(masks, dim=1)  # Convert from one-hot to class indices for batched case
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
        
        # Average validation loss
        val_loss /= len(val_loader)
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # Tensorboard logging
        writer.add_scalar('Training Loss', train_loss, epoch)
        writer.add_scalar('Validation Loss', val_loss, epoch)
        writer.add_scalar('Training Accuracy', train_accuracy, epoch)
        writer.add_scalar('Training IoU', train_iou, epoch)
        
        # Print progress
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, '
              f'Train Accuracy: {train_accuracy:.4f}, '
              f'Train IoU: {train_iou:.4f}, '
              f'Val Loss: {val_loss:.4f}')
        
        # Model checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss
            }, 'best_multiclass_unet_model.pth')
    
    writer.close()
    return model


In [6]:
# Example of training with CrossEntropyLoss
model = MultiClassUNet(in_channels=3, num_classes=3)

train_multiclass_segmentation(model, train_loader, val_loader, num_epochs=50, loss_function='dice')


Using cache found in C:\Users\zakar/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


Epoch [1/50], Train Loss: 0.6421, Train Accuracy: 0.5154, Train IoU: 0.1784, Val Loss: 0.5813
Epoch [2/50], Train Loss: 0.6105, Train Accuracy: 0.6677, Train IoU: 0.2295, Val Loss: 0.5688
Epoch [3/50], Train Loss: 0.5905, Train Accuracy: 0.7318, Train IoU: 0.2511, Val Loss: 0.5515
Epoch [4/50], Train Loss: 0.5725, Train Accuracy: 0.7695, Train IoU: 0.2645, Val Loss: 0.5382


KeyboardInterrupt: 

In [None]:
num_classes = 3

# Initialize the model
model = MultiClassUNet(in_channels=3, num_classes=num_classes)

# Train the model
trained_model = train_multiclass_segmentation(
    model, 
    train_loader, 
    val_loader, 
    num_epochs=50, 
    learning_rate=1e-4
)