# In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score, f1_score, accuracy_score

class UNet(nn.Module):
    """U-Net encoder-decoder for per-pixel (semantic) segmentation.

    For an input of shape ``(N, in_channels, H, W)`` the network returns
    unnormalised class scores (logits) of shape ``(N, out_channels, H, W)``.
    That layout can be fed directly to ``nn.CrossEntropyLoss``.

    Height and width do not need to be divisible by ``2 ** len(features)``.
    Each upsampled feature map is zero-padded to the size of its skip
    connection, so the output always matches the input's spatial size.
    Height and width must still be at least ``2 ** len(features)`` so that
    the bottleneck is not empty.

    Args:
        in_channels: Number of channels in the input image (1 for greyscale).
        out_channels: Number of segmentation classes.
        features: Channel width of each encoder level, shallowest first. The
            bottleneck uses ``2 * features[-1]`` channels. The default is the
            widths used in the original U-Net paper.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path: one double-conv block per level.
        self.encoders = nn.ModuleList()
        channels = in_channels
        for width in features:
            self.encoders.append(self._double_conv(channels, width))
            channels = width

        self.bottleneck = self._double_conv(features[-1], features[-1] * 2)

        # Expansive path, mirroring the encoder from deepest to shallowest.
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        channels = features[-1] * 2
        for width in reversed(features):
            self.upsamplers.append(
                nn.ConvTranspose2d(channels, width, kernel_size=2, stride=2)
            )
            # The decoder receives the upsampled map concatenated with the
            # skip connection of the same level, so it has 2 * width inputs.
            self.decoders.append(self._double_conv(width * 2, width))
            channels = width

        # 1x1 convolution producing one score map per class.
        self.head = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return two (3x3 conv -> BatchNorm -> ReLU) stages that keep H and W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return per-class logits with the same spatial size as ``x``."""
        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for upsample, decoder, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            # Max-pooling floors odd sizes, so the upsampled map can be one
            # pixel short of the skip connection; pad it to match.
            dh = skip.size(2) - x.size(2)
            dw = skip.size(3) - x.size(3)
            if dh or dw:
                x = nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
            x = decoder(torch.cat([skip, x], dim=1))

        return self.head(x)

class CustomDataset(Dataset):
    """Image/mask pairs for semantic segmentation, read from two directories.

    Each image file in ``image_dir`` is paired with the file in ``mask_dir``
    that has the same name stem (file name without extension). For example,
    ``BMMC_1.tif`` pairs with ``BMMC_1.png``. NOTE(review): this pairing rule
    is an assumption; adjust it if the masks follow another naming scheme.

    ``__getitem__`` returns ``(image, mask)``:

    * ``image`` is the greyscale image passed through ``transform``. When
      ``transform`` is None it is a ``(1, H, W)`` float32 tensor in [0, 1].
    * ``mask`` is an ``(H, W)`` int64 tensor of class indices. This is the
      target layout ``nn.CrossEntropyLoss`` expects for ``(N, C, H, W)`` logits.

    ``transform`` is applied to the image only. ``ToTensor`` would rescale
    mask values to [0, 1] and destroy the class indices. Spatial augmentations
    (flips, crops, ...) would have to be applied jointly to image and mask,
    which this class does not do.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ground-truth masks.
        transform: Optional callable applied to the greyscale PIL image.
        mask_value_map: Optional ``{pixel_value: class_index}`` mapping applied
            to the mask pixels. An example is ``{0: 0, 128: 1, 255: 2}`` when
            classes are stored as grey levels. Without it, raw pixel values
            are used as class indices.

    Raises:
        FileNotFoundError: If an image has no matching mask.
    """

    IMAGE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp")

    def __init__(self, image_dir, mask_dir, transform=None, mask_value_map=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_value_map = dict(mask_value_map) if mask_value_map else None

        mask_by_stem = {
            os.path.splitext(name)[0]: name
            for name in sorted(os.listdir(mask_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        }

        # Sorted for a deterministic sample order across runs and platforms.
        self.samples = []
        missing = []
        for name in sorted(os.listdir(image_dir)):
            if not name.lower().endswith(self.IMAGE_EXTENSIONS):
                continue
            mask_name = mask_by_stem.get(os.path.splitext(name)[0])
            if mask_name is None:
                missing.append(name)
            else:
                self.samples.append(
                    (os.path.join(image_dir, name), os.path.join(mask_dir, mask_name))
                )
        if missing:
            raise FileNotFoundError(
                f"No mask in {mask_dir!r} for image(s): {', '.join(missing)}"
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image_path, mask_path = self.samples[index]

        # Context managers close the file handles PIL keeps open lazily.
        with Image.open(image_path) as img:
            image = img.convert("L")
        with Image.open(mask_path) as msk:
            mask = np.array(msk)
            if mask.ndim == 3:
                # Colour masks are reduced to one grey channel. Combine this
                # with mask_value_map to turn grey levels into class indices.
                mask = np.array(msk.convert("L"))

        if self.transform is not None:
            image = self.transform(image)
        else:
            image = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0).unsqueeze(0)

        labels = mask.astype(np.int64)
        if self.mask_value_map:
            # Compare against the original pixel values so one remapping
            # cannot feed into another.
            remapped = labels.copy()
            for value, label in self.mask_value_map.items():
                remapped[mask == value] = label
            labels = remapped

        return image, torch.from_numpy(labels)

# Build the segmentation network: a single greyscale input channel and three
# output channels, one score map per class.
model = UNet(in_channels=1, out_channels=3)
print(model)

# Switch to inference behaviour for the shape check below.
model.eval()

# Shape check: push one greyscale image through the (still untrained) network
# and report the tensor shapes going in and coming out.
image_path = './Datasets/Dataset 1/data/BMMC_1.tif'
transform = transforms.Compose([transforms.ToTensor()])
image = Image.open(image_path).convert('L')
input_image = transform(image)[None, ...]  # prepend a batch axis of size 1
print("Input shape:", input_image.shape)

with torch.no_grad():
    output = model(input_image)
print("Output shape:", output.shape)

# Define data directories
image_dir = './Datasets/Dataset 1/data'
mask_dir = './Datasets/Dataset 1/masks'

# Transform handed to CustomDataset for turning samples into tensors.
data_transform = transforms.Compose([transforms.ToTensor()])

# Create a custom dataset
dataset = CustomDataset(image_dir, mask_dir, transform=data_transform)
if len(dataset) == 0:
    # Fail fast. An empty loader would otherwise surface much later as a
    # ZeroDivisionError when the epoch loss is averaged.
    raise ValueError(f"No training samples found in {image_dir!r}")

# Create a dataloader
batch_size = 1  # Adjust as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Move the model to its target device *before* constructing the optimizer, as
# the torch.optim documentation advises. The optimizer is then built over the
# parameters that will actually be trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, masks in dataloader:
        inputs = inputs.to(device)
        # nn.CrossEntropyLoss with (N, C, H, W) logits needs an integer
        # class-index target of shape (N, H, W). Drop a singleton channel axis
        # if the dataset yields (N, 1, H, W) masks, and cast to long.
        # NOTE(review): masks must already hold class indices; a cast of
        # ToTensor-scaled float masks (values in [0, 1]) would collapse labels.
        if masks.dim() == 4 and masks.size(1) == 1:
            masks = masks.squeeze(1)
        masks = masks.to(device=device, dtype=torch.long)

        optimizer.zero_grad()

        # Forward pass and loss
        outputs = model(inputs)
        loss = criterion(outputs, masks)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}] Current Loss: {batch_loss}")
        running_loss += batch_loss

    print()
    print(f"Epoch [{epoch + 1}/{num_epochs}] Average Loss: {running_loss / len(dataloader)}")
    print()

print("Training finished")

# Save the trained model
torch.save(model.state_dict(), "unet_segmentation_model.pth")

# Evaluation and Testing
model.eval()

# NOTE(review): this reuses the training `dataloader`, so the number below is
# a training-set loss. Build a DataLoader over a held-out split for a real
# evaluation.
total_loss = 0.0
num_samples = 0

with torch.no_grad():
    for inputs, masks in dataloader:
        inputs = inputs.to(device)
        # Same target normalisation as in training: (N, H, W) int64 labels.
        if masks.dim() == 4 and masks.size(1) == 1:
            masks = masks.squeeze(1)
        masks = masks.to(device=device, dtype=torch.long)

        outputs = model(inputs)
        loss = criterion(outputs, masks)

        # criterion (default reduction='mean') averages over the whole batch.
        # Weight by batch size so that dividing by num_samples gives a
        # per-sample mean even when batches differ in size.
        n_in_batch = inputs.size(0)
        total_loss += loss.item() * n_in_batch
        num_samples += n_in_batch

average_loss = total_loss / num_samples if num_samples else float("nan")
print(f"Average Evaluation Loss: {average_loss}")

# Create and display sample segmentation outputs
model.eval()

# Lists to store evaluation results for each visualised sample
iou_scores = []
dice_scores = []
accuracy_scores = []

# Select a few random images from the dataset.
# NOTE(review): this samples the training set; use a held-out split for a
# real evaluation.
num_samples_to_visualize = 3
visualize_dataloader = DataLoader(dataset, batch_size=num_samples_to_visualize, shuffle=True)

# One batch is enough: it holds up to num_samples_to_visualize images.
inputs, masks = next(iter(visualize_dataloader))

# Normalise the ground truth to an (N, H, W) integer label map, matching the
# layout of predicted_masks below.
if masks.dim() == 4 and masks.size(1) == 1:
    masks = masks.squeeze(1)
true_masks = masks.long().numpy()

with torch.no_grad():
    outputs = model(inputs.to(device))
# argmax over the class axis turns per-class scores into a label map.
predicted_masks = outputs.argmax(dim=1).cpu().numpy()

# The batch may hold fewer images than requested when the dataset is small.
for i in range(inputs.size(0)):
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(inputs[i][0].numpy(), cmap='gray')

    plt.subplot(1, 3, 2)
    plt.title("Ground Truth Mask")
    plt.imshow(true_masks[i], cmap='jet')

    plt.subplot(1, 3, 3)
    plt.title("Predicted Mask")
    plt.imshow(predicted_masks[i], cmap='jet')

    plt.show()

    # Per-sample metrics over the flattened pixels. average='macro' takes the
    # unweighted mean over the classes present in either mask. The default
    # average='binary' raises ValueError once more than two labels occur.
    # Per class, Dice equals F1.
    y_true = true_masks[i].ravel()
    y_pred = predicted_masks[i].ravel()
    iou_scores.append(jaccard_score(y_true, y_pred, average='macro'))
    dice_scores.append(f1_score(y_true, y_pred, average='macro'))
    accuracy_scores.append(accuracy_score(y_true, y_pred))

# Calculate and print average metrics
average_iou = sum(iou_scores) / len(iou_scores)
average_dice = sum(dice_scores) / len(dice_scores)
average_accuracy = sum(accuracy_scores) / len(accuracy_scores)

print(f"Average IoU (Jaccard Index): {average_iou}")
print(f"Average Dice Coefficient: {average_dice}")
print(f"Average Accuracy: {average_accuracy}")

print("Evaluation and Testing Finished")
