In [1]:
import os
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: d:\workspace\iscat


In [2]:
from src.data_processing.dataset import iScatDataset
from src.data_processing.utils import Utils
import torch
DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
data_path = os.path.join('dataset', '2024_11_11', 'Metasurface', 'Chip_02')
image_paths,target_paths = Utils.get_data_paths(data_path)

In [3]:
image_size=256
fluo_masks_indices=[0]
seg_method = "kmeans"
train_dataset = iScatDataset(image_paths[:-1], target_paths[:-1], preload_image=True,image_size = (image_size,image_size),apply_augmentation=True,normalize=False,device=DEVICE,fluo_masks_indices=fluo_masks_indices,seg_method=seg_method)
valid_dataset = iScatDataset([image_paths[-1]],[target_paths[-1]],preload_image=True,image_size = (image_size,image_size),apply_augmentation=False,normalize=False,device=DEVICE,fluo_masks_indices=fluo_masks_indices,seg_method=seg_method)

Loading TIFF images to Memory: 100%|██████████| 4/4 [00:00<00:00, 36.96it/s]
Loading TIFF images to Memory: 100%|██████████| 1/1 [00:00<00:00, 41.78it/s]


In [4]:
from torch.utils.data import DataLoader, Dataset
def create_dataloaders(train_dataset, test_dataset, batch_size=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
train_loader, val_loader = create_dataloaders(train_dataset, valid_dataset, batch_size=32)

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from monai.losses import DiceLoss
from monai.metrics import MeanIoU
import numpy as np
from monai.networks.utils import one_hot

def z_score_normalize(images: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Normalize a batch of images using z-score normalization.

    Args:
        images (torch.Tensor): Input tensor of shape (N, 3, H, W), where
                               N is the batch size, 3 is the number of channels,
                               H and W are height and width.
        eps (float): A small value to avoid division by zero (default: 1e-8).

    Returns:
        torch.Tensor: Z-score normalized tensor of the same shape as `images`.
    """
    normalized_images = images/65535.0
    # Compute mean and std for each image in the batch
    mean = normalized_images .mean(dim=(1, 2, 3), keepdim=True)  # Shape: (N, 1, 1, 1)
    std = normalized_images .std(dim=(1, 2, 3), keepdim=True)    # Shape: (N, 1, 1, 1)
    
    # Perform z-score normalization
    normalized_images = (images - mean) / (std + eps)
    return normalized_images

class MultiClassUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1, init_features=32):
        super(MultiClassUNet, self).__init__()
        
        # Load the pretrained model and modify the final layer
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', 
                               in_channels=in_channels, 
                               out_channels=1, 
                               init_features=init_features, 
                               pretrained=True)
        
        # Replace the final convolution layer to match number of classes
        model.conv = nn.Conv2d(init_features, num_classes, kernel_size=1)
        
        self.model = model
    
    def forward(self, x):
        return self.model(x)

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from monai.losses import DiceLoss
from monai.metrics import MeanIoU
import numpy as np
from enum import Enum
from typing import Dict, Any

class LossType(Enum):
    CROSSENTROPY = "crossentropy"
    DICE = "dice"
    COMBINED = "combined"

class UNetTrainer:
    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
        loss_type: LossType,
        learning_rate: float = 1e-4,
        log_dir: str = "runs/unet_training"
    ):
        self.model = model.to(device)
        self.device = device
        self.loss_type = loss_type
        
        # Initialize loss functions
        self.ce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss(
            sigmoid=True,
            smooth_nr=1.0,
            smooth_dr=1.0,
            squared_pred=False,  # Changed to False for potentially better stability
            batch=True,
            reduction="mean"
        )
        
        # Initialize metrics
        self.miou_metric = MeanIoU(include_background=True, reduction="mean")
        
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        
        # Add learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )
        self.writer = SummaryWriter(log_dir)

    def compute_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the specified loss
        predictions: (batch_size, 1, height, width)
        targets: (batch_size, height, width)
        """
        # Add channel dimension to targets if needed
        if len(targets.shape) == 3:
            targets = targets.unsqueeze(1)
            
        # Ensure targets are float
        targets = targets.float()
        
        if self.loss_type == LossType.CROSSENTROPY:
            return self.ce_loss(predictions, targets)
        elif self.loss_type == LossType.DICE:
            return self.dice_loss(predictions, targets)
        else:  # COMBINED
            ce = self.ce_loss(predictions, targets)
            dice = self.dice_loss(predictions, targets)
            return ce + dice

    def compute_metrics(self, predictions: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Compute mean IoU metric
        predictions: (batch_size, 1, height, width)
        targets: (batch_size, height, width)
        """
        # Ensure predictions are binary
        pred_masks = (torch.sigmoid(predictions) > 0.5).float()
        
        # Add channel dimension to targets if needed
        if len(targets.shape) == 3:
            targets = targets.unsqueeze(1)
        
        # Convert predictions and targets to one-hot format (required by MeanIoU)
        # Shape: (batch_size, 2, height, width)
        pred_one_hot = torch.cat([1 - pred_masks, pred_masks], dim=1)
        target_one_hot = torch.cat([1 - targets, targets], dim=1)
        
        # Compute IoU
        metric = self.miou_metric(pred_one_hot, target_one_hot)
        
        # Return mean IoU (average across classes)
        return metric.mean().item()

    def train_epoch(self, train_loader, epoch: int) -> Dict[str, float]:
        self.model.train()
        total_loss = 0.0
        total_miou = 0.0
        batches = 0

        for batch_idx, (images, masks) in enumerate(train_loader):
            # Move data to device and normalize images
            images = z_score_normalize(images).to(self.device)
            masks = masks.to(self.device)
            
            # Forward pass
            self.optimizer.zero_grad()
            predictions = self.model(images)  # Shape: (batch_size, 1, height, width)
            
            # Compute loss and backpropagate
            loss = self.compute_loss(predictions, masks)
            loss.backward()
            
            # Add gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            # Compute metrics
            miou = self.compute_metrics(predictions, masks)
            
            # Update running statistics
            total_loss += loss.item()
            total_miou += miou
            batches += 1
            
            # Log to TensorBoard
            step = epoch * len(train_loader) + batch_idx
            self.writer.add_scalar('Train/Loss', loss.item(), step)
            self.writer.add_scalar('Train/mIoU', miou, step)
            
            # Print batch progress
            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}, mIoU: {miou:.4f}')

        return {
            'loss': total_loss / batches,
            'miou': total_miou / batches
        }

    @torch.no_grad()
    def validate(self, val_loader) -> Dict[str, float]:
        self.model.eval()
        total_loss = 0.0
        total_miou = 0.0
        batches = 0

        for images, masks in val_loader:
            images = z_score_normalize(images).to(self.device)
            masks = masks.to(self.device)
            
            predictions = self.model(images)
            loss = self.compute_loss(predictions, masks)
            miou = self.compute_metrics(predictions, masks)
            
            total_loss += loss.item()
            total_miou += miou
            batches += 1

        return {
            'val_loss': total_loss / batches,
            'val_miou': total_miou / batches
        }

    def train(self, train_loader, val_loader, num_epochs: int):
            best_val_miou = 0.0
            patience = 10  # Early stopping patience
            no_improve = 0
            
            for epoch in range(num_epochs):
                # Training phase
                train_metrics = self.train_epoch(train_loader, epoch)
                
                # Validation phase
                val_metrics = self.validate(val_loader)
                
                # Update learning rate scheduler
                self.scheduler.step(val_metrics['val_loss'])
                
                # Log validation metrics
                self.writer.add_scalar('Validation/Loss', val_metrics['val_loss'], epoch)
                self.writer.add_scalar('Validation/mIoU', val_metrics['val_miou'], epoch)
                
                # Save best model
                if val_metrics['val_miou'] > best_val_miou:
                    best_val_miou = val_metrics['val_miou']
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_miou': val_metrics['val_miou'],
                    }, 'best_model.pth')
                    no_improve = 0
                else:
                    no_improve += 1
                
                # Early stopping
                if no_improve >= patience:
                    print(f'Early stopping triggered after {patience} epochs without improvement')
                    break
                
                print(f"Epoch {epoch+1}/{num_epochs}")
                print(f"Train Loss: {train_metrics['loss']:.4f}, Train mIoU: {train_metrics['miou']:.4f}")
                print(f"Val Loss: {val_metrics['val_loss']:.4f}, Val mIoU: {val_metrics['val_miou']:.4f}")
                print(f"Learning rate: {self.optimizer.param_groups[0]['lr']:.2e}")
                print("-" * 50)



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
model = MultiClassUNet(in_channels=3, num_classes=1, init_features=32)

# Initialize trainer
trainer = UNetTrainer(
    model=model,
    device=device,
    loss_type=LossType.DICE,  # or CROSSENTROPY or DICE
    learning_rate=1e-4
)


trainer.train(train_loader, val_loader, num_epochs=100)

Using cache found in C:\Users\zakar/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


Batch 0/13, Loss: 0.9900, mIoU: 0.0025


In [12]:
a= train_dataset[1][0]
print(z_score_normalize(a.unsqueeze(0)).max())

tensor(4.2531, device='cuda:0')
