Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import average_precision_score, roc_curve, auc, classification_report
import cv2
import torch.optim as optim
from tqdm import tqdm
import math

Models

In [None]:
class FocalLoss(nn.Module):
    """Focal loss evaluated on raw logits.

    Every element's binary cross-entropy (BCE) is rescaled by
    ``alpha * (1 - exp(-BCE)) ** gamma`` before the optional reduction.

    Args:
        alpha: Constant multiplier applied to every element's loss.
        gamma: Exponent of the ``(1 - exp(-BCE))`` modulating factor.
        reduction: ``'mean'`` or ``'sum'`` to aggregate the loss; any other
            value returns the unreduced per-element tensor.
    """

    def __init__(self, alpha=0.8, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss between logits ``inputs`` and ``targets``."""
        per_elem_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) is large for well-fit elements, so (1 - exp(-BCE)) shrinks
        # their contribution and leaves poorly-fit elements dominant.
        confidence = torch.exp(-per_elem_bce)
        focal = self.alpha * (1 - confidence) ** self.gamma * per_elem_bce

        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        return focal

class CarotidDataset(Dataset):
    """Paired ultrasound-image / mask dataset read from two directories.

    The entries of each directory are turned into paths, sorted, and paired
    by position: the i-th image goes with the i-th mask. Both files are
    loaded as single-channel (``'L'``) PIL images.

    Args:
        us_images_dir: Directory containing the ultrasound images.
        mask_images_dir: Directory containing the mask images.
        transform: Optional callable applied to both the image and the mask.

    Raises:
        ValueError: If the two directories hold a different number of
            entries, since positional pairing would then be wrong.
    """

    def __init__(self, us_images_dir, mask_images_dir, transform=None):
        self.us_images = sorted(
            os.path.join(us_images_dir, fname) for fname in os.listdir(us_images_dir)
        )
        self.mask_images = sorted(
            os.path.join(mask_images_dir, fname) for fname in os.listdir(mask_images_dir)
        )
        if len(self.us_images) != len(self.mask_images):
            raise ValueError(
                f"Found {len(self.us_images)} images in {us_images_dir!r} but "
                f"{len(self.mask_images)} masks in {mask_images_dir!r}; "
                "images and masks are paired by sorted position and must match."
            )
        self.transform = transform

    def __len__(self):
        return len(self.us_images)

    def _apply_paired_transform(self, us_image, mask):
        """Apply ``self.transform`` to both inputs with identical randomness.

        The torch global RNG state is captured before transforming the image
        and restored before transforming the mask, so a random transform that
        draws from torch's RNG makes the same flip/rotation decisions for
        both. Without this the image and its mask end up geometrically
        misaligned.
        """
        if not self.transform:
            return us_image, mask

        # NOTE(review): this only synchronises transforms that sample from
        # torch's global CPU RNG; a transform using Python's `random` or
        # NumPy would need its own state handling.
        rng_state = torch.get_rng_state()
        us_image = self.transform(us_image)
        torch.set_rng_state(rng_state)
        mask = self.transform(mask)
        return us_image, mask

    def __getitem__(self, idx):
        us_image = Image.open(self.us_images[idx]).convert('L')
        mask = Image.open(self.mask_images[idx]).convert('L')
        return self._apply_paired_transform(us_image, mask)

class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution, ReLU, then a 1x1 pointwise convolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConv, self).__init__()
        # groups == in_channels gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``pointwise(relu(depthwise(x)))``."""
        return self.pointwise(self.relu(self.depthwise(x)))

class SimpleUNet(nn.Module):
    """Minimal encoder/decoder: one conv, one 2x max-pool, one 2x transposed conv.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced logit map.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(SimpleUNet, self).__init__()
        self.encoder = nn.Conv2d(in_channels, 2, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Stride-2 transposed conv undoes the 2x spatial reduction of the pool.
        self.decoder = nn.ConvTranspose2d(2, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Encode, halve the resolution, then upsample back to a logit map."""
        downsampled = self.pool(F.relu(self.encoder(x)))
        return self.decoder(downsampled)

class SimpleViT(nn.Module):
    """Single-layer vision transformer producing a dense logit map.

    The image is split into non-overlapping patches by a strided convolution,
    the patch tokens pass through one ``nn.TransformerEncoderLayer``, and a
    transposed convolution with the same kernel/stride maps the tokens back
    to the input resolution.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced logit map.
        patch_size: Side length of each square patch; also the stride of the
            embedding and reconstruction convolutions.
    """

    def __init__(self, in_channels=1, out_channels=1, patch_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = 32
        self.patch_embed = nn.Conv2d(in_channels, self.embed_dim,
                                    kernel_size=patch_size, stride=patch_size)
        # batch_first=True is required because forward() feeds tokens as
        # (batch, num_patches, embed_dim). With the default (False) the layer
        # would read dim 0 as the sequence and attend across the batch
        # instead of across patches.
        self.transformer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=2,
            dim_feedforward=64,
            batch_first=True,
        )
        self.reconstruct = nn.ConvTranspose2d(
            self.embed_dim, out_channels,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Return a logit map for image batch ``x`` of shape (B, C, H, W)."""
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per patch.
        x = x.flatten(2).transpose(1, 2)
        x = self.transformer(x)
        # Back to a feature map; reshape copes with the transposed layout.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.reconstruct(x)
        return x

class Patches(nn.Module):
    """Split an image batch into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length of each patch; H and W of the input must be
            multiples of it.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        """Return patches of ``x``.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, num_patches, C * patch_size * patch_size).
            Patches are in row-major order over the patch grid and each patch
            is flattened channel-first.
        """
        B, C, H, W = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"

        p = self.patch_size
        num_patches = (H // p) * (W // p)

        # (B, C, H, W) -> (B, C, grid_h, grid_w, p, p)
        x = x.unfold(2, p, p).unfold(3, p, p)
        # Move the grid axes ahead of the channel axis so each patch is
        # contiguous along the last dims: (B, grid_h, grid_w, C, p, p).
        x = x.permute(0, 2, 3, 1, 4, 5)
        # reshape (not view): the permuted tensor is non-contiguous, and a
        # view here fails whenever C > 1 with more than one patch.
        return x.reshape(B, num_patches, C * p * p)

class PatchEncoder(nn.Module):
    """Linearly project flattened patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; fixes the length of the
            position embedding.
        projection_dim: Size of each projected patch token.
        patch_size: Side length of the square patches.
        in_channels: Channels per patch. Defaults to 1, which matches the
            previous hard-coded single-channel behaviour.
    """

    def __init__(self, num_patches, projection_dim, patch_size, in_channels=1):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.in_channels = in_channels
        # Input width equals the length of one flattened patch.
        self.projection = nn.Linear(in_channels * patch_size * patch_size, projection_dim)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, projection_dim))

    def forward(self, x):
        """Project ``x`` of shape (B, num_patches, patch_len) to tokens.

        The position embedding has a leading dimension of 1 and is broadcast
        over the batch.
        """
        x = self.projection(x)
        x = x + self.position_embedding
        return x

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each with a residual.

    Args:
        projection_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``projection_dim``.
    """

    def __init__(self, projection_dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(projection_dim)
        # batch_first=True: tokens arrive as (batch, num_patches, dim). With
        # the default (False) dim 0 would be read as the sequence axis and
        # attention would mix samples of the batch instead of patches.
        self.attn = nn.MultiheadAttention(projection_dim, num_heads, dropout=0.1,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(projection_dim)
        self.mlp = nn.Sequential(
            nn.Linear(projection_dim, projection_dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim * mlp_ratio, projection_dim),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        """Transform tokens ``x`` of shape (B, num_patches, projection_dim)."""
        # Normalise once and reuse the result as query, key and value.
        normed = self.norm1(x)
        # Only the attention output is used, so skip computing the weights.
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class FusionViTNet(nn.Module):
    """Two-branch segmentation network fusing CNN and ViT features.

    A convolutional encoder downsamples the image by 4x while a ViT branch
    encodes 16x16 patches through a stack of transformer blocks. The ViT
    token grid is bilinearly resized to the CNN feature resolution, the two
    are concatenated along channels, mixed by a 1x1 convolution, and decoded
    back to the input resolution with two 2x upsampling stages.

    Args:
        in_channels: Channels of the input image. The ViT branch's
            ``PatchEncoder`` is constructed here without a channel count, so
            it projects patches of length ``patch_size * patch_size``.
        out_channels: Channels of the produced logit map.
        image_size: Side length of the (square) input; must be a multiple of
            the patch size (16). ``forward`` rejects other input sizes.
    """

    def __init__(self, in_channels=1, out_channels=1, image_size=256):
        super().__init__()
        self.patch_size = 16
        self.projection_dim = 64
        self.num_heads = 4
        self.transformer_layers = 8

        assert image_size % self.patch_size == 0, "Image size must be divisible by patch size"
        # Remember the configured size so forward() validates against it
        # rather than against a hard-coded 256.
        self.image_size = image_size
        self.grid_size = image_size // self.patch_size
        self.num_patches = self.grid_size ** 2

        # CNN Encoder: two stride-2 stages -> features at image_size / 4.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1, stride=2),
            nn.ReLU(),
            DepthwiseSeparableConv(16, 32, stride=2),
            nn.ReLU(),
            DepthwiseSeparableConv(32, 32),
            nn.ReLU()
        )

        # ViT components
        self.patches = Patches(self.patch_size)
        self.patch_encoder = PatchEncoder(
            num_patches=self.num_patches,
            projection_dim=self.projection_dim,
            patch_size=self.patch_size
        )
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(self.projection_dim, self.num_heads)
            for _ in range(self.transformer_layers)
        ])
        self.vit_norm = nn.LayerNorm(self.projection_dim)

        # Fusion components: 1x1 conv over the concatenated CNN + ViT channels.
        self.fusion_conv = nn.Conv2d(32 + self.projection_dim, 32, kernel_size=1)

        # Decoder: two 2x upsamples undo the encoder's 4x reduction.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(16, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Return a logit map for ``x`` of shape (B, C, image_size, image_size)."""
        B, C, H, W = x.shape
        size = self.image_size
        assert H == W == size, f"Input size must be {size}x{size}, got {H}x{W}"

        # CNN pathway
        cnn_features = self.encoder(x)

        # ViT pathway
        patches = self.patches(x)
        encoded_patches = self.patch_encoder(patches)

        for block in self.transformer_blocks:
            encoded_patches = block(encoded_patches)
        vit_features = self.vit_norm(encoded_patches)

        # Tokens (B, num_patches, dim) -> feature map (B, dim, grid, grid),
        # then resize to the CNN feature resolution for concatenation.
        grid = self.grid_size
        vit_features = vit_features.transpose(1, 2).reshape(B, self.projection_dim, grid, grid)
        vit_features = F.interpolate(vit_features, size=cnn_features.shape[2:], mode='bilinear')

        fused_features = torch.cat([cnn_features, vit_features], dim=1)
        fused_features = self.fusion_conv(fused_features)

        out = self.decoder(fused_features)
        return out


Saving

In [None]:

def create_run_folder(base_dir='checkpoints'):
    """Create and return the next ``run<N>`` directory under ``base_dir``.

    ``N`` is one more than the highest existing ``run<number>`` directory, or
    1 when none exist. Only directories named exactly ``run`` followed by
    decimal digits are counted; files and other names are ignored.

    Args:
        base_dir: Parent directory for run folders; created if missing.

    Returns:
        Path of the newly created run folder.
    """
    os.makedirs(base_dir, exist_ok=True)

    prefix = 'run'
    run_numbers = []
    for entry in os.listdir(base_dir):
        if not entry.startswith(prefix):
            continue
        # Strip only the leading prefix; replacing every 'run' would make
        # a name such as 'runrun9' count as run 9.
        suffix = entry[len(prefix):]
        # isdecimal() accepts exactly the characters int() can parse.
        if suffix.isdecimal() and os.path.isdir(os.path.join(base_dir, entry)):
            run_numbers.append(int(suffix))

    next_run_number = max(run_numbers, default=0) + 1
    run_folder = os.path.join(base_dir, f'{prefix}{next_run_number}')
    os.makedirs(run_folder, exist_ok=True)
    return run_folder

def save_checkpoint(model, optimizer, epoch, val_loss, filename):
    """Serialise model/optimizer state plus bookkeeping to ``filename``.

    The saved dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'val_loss'``.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_loss=val_loss,
    )
    torch.save(checkpoint, filename)



Calculate metrics

In [None]:
def calculate_metrics(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return pixel-level metrics.

    Logits are passed through a sigmoid, pooled over the whole loader, and
    thresholded at 0.5 to build a confusion matrix.

    Returns:
        Dict with keys ``'IoU'``, ``'Dice'``, ``'Precision'``,
        ``'Sensitivity'``, ``'Specificity'``, ``'MSE'`` and ``'mAP'``.
        ``'mAP'`` is NaN when the thresholded targets contain a single class.
    """
    model.eval()
    pred_batches = []
    target_batches = []

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)
            spatial_size = targets.shape[2:]
            # Bring the prediction to the target resolution when they differ.
            if logits.shape[2:] != spatial_size:
                logits = F.interpolate(logits, size=spatial_size, mode='bilinear')

            probs = torch.sigmoid(logits).cpu().numpy()
            target_np = targets.cpu().numpy()

            # Targets outside [0, 1] are binarised at 0.5 for this batch.
            out_of_range = np.max(target_np) > 1 or np.min(target_np) < 0
            if out_of_range:
                target_np = (target_np > 0.5).astype(np.float32)

            pred_batches.append(probs)
            target_batches.append(target_np)

    # One flat vector of probabilities and one of targets for all pixels.
    preds_flat = np.concatenate(pred_batches).flatten()
    targets_flat = np.concatenate(target_batches).flatten()

    # Boolean masks of positive predictions / positive ground truth.
    pred_pos = preds_flat > 0.5
    target_pos = targets_flat > 0.5
    binary_targets = target_pos.astype(np.float32)

    tp = np.sum(pred_pos & target_pos)
    tn = np.sum(~pred_pos & ~target_pos)
    fp = np.sum(pred_pos & ~target_pos)
    fn = np.sum(~pred_pos & target_pos)

    # epsilon keeps every ratio finite when a denominator count is zero.
    epsilon = 1e-7
    metrics = {}
    metrics['IoU'] = tp / (tp + fp + fn + epsilon)
    metrics['Dice'] = 2 * tp / (2 * tp + fp + fn + epsilon)
    metrics['Precision'] = tp / (tp + fp + epsilon)
    metrics['Sensitivity'] = tp / (tp + fn + epsilon)
    metrics['Specificity'] = tn / (tn + fp + epsilon)
    metrics['MSE'] = np.mean((preds_flat - targets_flat) ** 2)

    # Average precision is only computed when both classes are present.
    has_both_classes = len(np.unique(binary_targets)) >= 2
    if has_both_classes:
        metrics['mAP'] = average_precision_score(binary_targets, preds_flat)
    else:
        metrics['mAP'] = float('nan')

    return metrics


Visualization

In [None]:
def plot_training_results(model, epoch, train_losses, val_losses, run_folder):
    """Plot train/val loss for epochs 1..``epoch`` and save it as a PNG.

    The figure is written to ``training_loss.png`` inside ``run_folder``
    (replacing any earlier file of that name) and then closed.
    """
    epoch_axis = range(1, epoch + 1)
    curves = ((train_losses, 'Train Loss'), (val_losses, 'Val Loss'))
    for losses, label in curves:
        plt.plot(epoch_axis, losses, label=label)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    model_name = model.__class__.__name__
    plt.title(f'Training and Validation Loss for {model_name}')
    plt.legend()

    output_path = os.path.join(run_folder, 'training_loss.png')
    plt.savefig(output_path)
    # Close the figure so curves do not accumulate across calls.
    plt.close()



Training

In [None]:
# Training configuration
batch_size = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pinned memory and gradient scaling only help when running on CUDA.
use_cuda = device.type == 'cuda'

# Training preprocessing: resize to the 256x256 input that FusionViTNet's
# forward() requires, then apply random flip/rotation augmentation.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Validation preprocessing: same resize but no random augmentation, so
# validation loss and metrics are deterministic and comparable across epochs.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Initialize datasets and loaders
train_dataset = CarotidDataset(
    us_images_dir='Common Carotid Artery Ultrasound Images/US images/train',
    mask_images_dir='Common Carotid Artery Ultrasound Images/Expert mask images/train',
    transform=transform,
)
val_dataset = CarotidDataset(
    us_images_dir='Common Carotid Artery Ultrasound Images/US images/val',
    mask_images_dir='Common Carotid Artery Ultrasound Images/Expert mask images/val',
    transform=val_transform,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=use_cuda)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_cuda)

# Training parameters
epochs = 100
learning_rate_simpleunet = 1e-3
learning_rate_vit = 1e-4
learning_rate_fusionvitnet = 1e-4
criterion = FocalLoss(alpha=0.8, gamma=2.0)
# Explicitly disabled on CPU so scale()/step() become plain pass-throughs
# without relying on the scaler's own CUDA-availability fallback.
scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)


def _single_channel_masks(masks):
    """Average a 3-channel mask batch down to one channel; pass others through."""
    if masks.ndim == 4 and masks.shape[1] == 3:
        return masks.mean(dim=1, keepdim=True)
    return masks


def _batch_loss(model, images, masks, amp_enabled):
    """Run ``model`` on one batch and return the criterion loss vs ``masks``."""
    with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
        outputs = model(images)
        # Resize the logits to the mask resolution; the masks then need no
        # resampling of their own.
        outputs = F.interpolate(outputs, size=masks.shape[2:], mode='bilinear')
        return criterion(outputs, masks)


def train_model(model, optimizer, train_loader, val_loader, num_epochs=epochs, run_folder=None):
    """Train ``model``, logging per-epoch metrics and keeping the best checkpoint.

    Uses the module-level ``device``, ``criterion`` and ``scaler``. For each
    epoch it trains on ``train_loader``, computes the mean validation loss
    and ``calculate_metrics`` on ``val_loader``, appends a row to
    ``epoch_metrics.csv``, saves a checkpoint whenever the validation loss
    improves, and refreshes the loss plot.

    Args:
        model: Network to train; its class name determines the output
            sub-folder and checkpoint file name.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        num_epochs: Number of epochs to run.
        run_folder: Directory that receives the per-model sub-folder. When
            ``None`` a fresh folder is obtained from ``create_run_folder()``.

    Returns:
        Tuple ``(train_losses, val_losses)`` with one mean loss per epoch.
    """
    if run_folder is None:
        # The previous behaviour was a TypeError inside os.path.join.
        run_folder = create_run_folder()

    # Mixed precision is only meaningful on CUDA; elsewhere run in full precision.
    amp_enabled = device.type == 'cuda'

    best_val_loss = float('inf')
    train_losses, val_losses = [], []
    model_name = model.__class__.__name__
    model_folder = os.path.join(run_folder, model_name.lower())
    os.makedirs(model_folder, exist_ok=True)

    # Create CSV file to store epoch metrics
    metrics_file = os.path.join(model_folder, 'epoch_metrics.csv')
    metric_names = ['IoU', 'Dice', 'Precision', 'Sensitivity', 'Specificity', 'MSE', 'mAP']
    metrics_header = ['Epoch', 'Train_Loss', 'Val_Loss'] + metric_names
    with open(metrics_file, 'w') as f:
        f.write(','.join(metrics_header) + '\n')

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for images, masks in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            masks = _single_channel_masks(masks.to(device).float())

            optimizer.zero_grad()
            loss = _batch_loss(model, images, masks, amp_enabled)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = _single_channel_masks(masks.to(device).float())
                val_loss += _batch_loss(model, images, masks, amp_enabled).item()

        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        # Calculate metrics for this epoch
        epoch_metrics = calculate_metrics(model, val_loader, device)

        # Save metrics to CSV
        metrics_row = [epoch + 1, train_loss, val_loss]
        metrics_row += [epoch_metrics[name] for name in metric_names]
        with open(metrics_file, 'a') as f:
            f.write(','.join(map(str, metrics_row)) + '\n')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, best_val_loss,
                            os.path.join(model_folder, f'{model_name}_best.pth'))

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Metrics: {epoch_metrics}")
        plot_training_results(model, epoch + 1, train_losses, val_losses, model_folder)

    return train_losses, val_losses



Main

In [None]:
def main():
    """Train SimpleUNet, SimpleViT and FusionViTNet in turn and save final metrics.

    All models are built first, each gets its own Adam optimizer, and after
    training each one its validation metrics are collected into
    ``final_metrics.csv`` inside a fresh run folder.
    """
    run_folder = create_run_folder()

    # Initialize models (dict keys are the class names).
    model_classes = (SimpleUNet, SimpleViT, FusionViTNet)
    models = {cls.__name__: cls().to(device) for cls in model_classes}

    # Per-model Adam settings; only FusionViTNet uses weight decay.
    adam_settings = {
        'SimpleUNet': {'lr': learning_rate_simpleunet},
        'SimpleViT': {'lr': learning_rate_vit},
        'FusionViTNet': {'lr': learning_rate_fusionvitnet, 'weight_decay': 1e-5},
    }
    optimizers = {
        name: optim.Adam(net.parameters(), **adam_settings[name])
        for name, net in models.items()
    }

    # Train and evaluate each model
    results = {}
    for name, net in models.items():
        print(f"\nTraining {name} Model...")
        train_model(net, optimizers[name], train_loader, val_loader, epochs, run_folder)

        final_metrics = calculate_metrics(net, val_loader, device)
        results[name] = final_metrics
        print(f"{name} Metrics: {final_metrics}")

    # Save final results: one row per model, rounded to 4 decimals.
    summary = pd.DataFrame(results).T.round(4)
    summary.to_csv(os.path.join(run_folder, 'final_metrics.csv'))

    print("\nTraining and evaluation complete!")


if __name__ == "__main__":
    main()