In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [11]:
#The list of functions and where they can be integrated in the model are in the loss function report version 1
# Dice Loss
def dice_loss(pred, target, smooth=1e-5):
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

# Boundary Loss
def boundary_loss(pred, target):
    sobel_filter = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().unsqueeze(0).unsqueeze(0)
    if torch.cuda.is_available():
        sobel_filter = sobel_filter.cuda()

    pred_grad = F.conv2d(pred, sobel_filter, padding=1)
    target_grad = F.conv2d(target, sobel_filter, padding=1)
    loss = F.mse_loss(pred_grad, target_grad)
    return loss

# KL Divergence Loss
def kl_divergence_loss(pred_distribution, target_distribution):
    pred_distribution = F.log_softmax(pred_distribution, dim=1)
    target_distribution = F.softmax(target_distribution, dim=1)
    return F.kl_div(pred_distribution, target_distribution, reduction='batchmean')

# Hausdorff Distance Loss
def hausdorff_distance_loss(pred, target):
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    return torch.abs(pred_flat - target_flat).mean()

# Shape Alignment Loss
def shape_alignment_loss(pred, target):
    return F.mse_loss(pred, target)


In [12]:
#Using a Simple CNN just for convenience while implementing the loss function
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv3(x))
        return x

In [13]:
# IoU Metric
def iou_metric(pred, target):
    pred_binary = (pred > 0.5).float()
    target_binary = (target > 0.5).float()
    intersection = (pred_binary * target_binary).sum(dim=(1, 2, 3))
    union = (pred_binary + target_binary).sum(dim=(1, 2, 3)) - intersection
    iou = intersection / (union + 1e-5)
    return iou.mean().item()

# Dummy Data (Planning to replace this soon using cardiac data)
def generate_dummy_data(batch_size=16, image_size=64):
    images = torch.rand(batch_size, 1, image_size, image_size)
    masks = torch.randint(0, 2, (batch_size, 1, image_size, image_size)).float()
    if torch.cuda.is_available():
        images, masks = images.cuda(), masks.cuda()
    return images, masks


In [14]:
# Training Loop
def train(model, optimizer, loss_fn):
    model.train()
    images, masks = generate_dummy_data()
    optimizer.zero_grad()
    outputs = model(images)
    loss = loss_fn(outputs, masks)
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation Loop
def evaluate(model, loss_fn):
    model.eval()
    images, masks = generate_dummy_data()
    with torch.no_grad():
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        iou = iou_metric(outputs, masks)
    return loss.item(), iou

In [15]:
model = SimpleCNN()
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Loss Functions to Test
loss_functions = {
    "Dice Loss": dice_loss,
    "Boundary Loss": boundary_loss,
    "KL Divergence Loss": kl_divergence_loss,
    "Hausdorff Distance Loss": hausdorff_distance_loss,
    "Shape Alignment Loss": shape_alignment_loss
}

# Evaluation for each loss function separately, next step is to do a grid search for the optimal compound loss function
results = {}
for loss_name, loss_fn in loss_functions.items():
    print(f"\nUsing {loss_name}...")
    train_loss = train(model, optimizer, loss_fn)
    eval_loss, eval_iou = evaluate(model, loss_fn)
    results[loss_name] = {"Train Loss": train_loss, "Eval Loss": eval_loss, "Eval IoU": eval_iou}
    print(f"Training Loss: {train_loss:.4f}")
    print(f"Evaluation Loss: {eval_loss:.4f}")
    print(f"Evaluation IoU: {eval_iou:.4f}")




Using Dice Loss...
Training Loss: 0.4980
Evaluation Loss: 0.4939
Evaluation IoU: 0.4872

Using Boundary Loss...
Training Loss: 2.9291
Evaluation Loss: 2.9906
Evaluation IoU: 0.4925

Using KL Divergence Loss...
Training Loss: 0.0000
Evaluation Loss: 0.0000
Evaluation IoU: 0.4949

Using Hausdorff Distance Loss...
Training Loss: 0.5001
Evaluation Loss: 0.5000
Evaluation IoU: 0.4962

Using Shape Alignment Loss...
Training Loss: 0.2512
Evaluation Loss: 0.2514
Evaluation IoU: 0.4939
