In [11]:
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import jaccard_score
from sklearn.preprocessing import StandardScaler
from skimage.segmentation import slic
from sklearn.neighbors import kneighbors_graph
from torchvision import transforms
from PIL import Image, ImageOps

In [12]:
class CellSegmentation():
    def __init__(self, root, filenames, k_neighbors=5):

        self.filenames = filenames
        # Extract image paths and labels from the CSV
        self.cellpaths = [os.path.join(f'{root}/Tissue Images', f'{filename}.tif') for filename in filenames]
        self.maskpaths = [os.path.join(f'{root}/Masks', f'{filename}.npz') for filename in filenames]
        self.k_neighbors = k_neighbors

    def preprocess_image(self, image_path):

        # Load and Resize Image
        image = Image.open(image_path)

        image = np.array(image)

        # Flatten the Image (N_pixels, C)
        img_flattened = image.reshape(-1, image.shape[-1])  # (50176, Channels)

        # Standardize by removing mean and scaling to unit variance 
        scaler = StandardScaler()
        embedding_standardized = scaler.fit_transform(img_flattened)

        return embedding_standardized.reshape(image.shape[0], image.shape[1], -1)

    # Load ground truth mask
    def preprocess_mask(self, mask_path):
        loaded_data = np.load(mask_path)
        loaded_color_mask = loaded_data['color_mask']

        loaded_color_mask = np.array(loaded_color_mask)
        return loaded_color_mask
    
    # Generate a graph from the image features and the provided mask.
    def generate_graph(self, features, mask):        
        # Superpixel segmentation
        segments = slic(features, n_segments=1000, compactness=15, start_label=0)
        nodes = np.unique(segments)  # Get unique segment labels
        node_features = []
        node_labels = []

        for node in nodes:
            mask_node = segments == node
            mean_features = features[mask_node].mean(axis=0)
            node_features.append(mean_features)

            superpixel_mask_values = mask[mask_node]
            unique, counts = np.unique(superpixel_mask_values, return_counts=True)
            node_label = unique[np.argmax(counts)]  # Assign most frequent class in superpixel
            node_labels.append(node_label)

        node_features = np.array(node_features)
        node_labels = np.array(node_labels)

        # Construct adjacency matrix (k-NN or spatial would require understanding the form)
        adj_matrix = kneighbors_graph(node_features, n_neighbors=self.k_neighbors).toarray()

        # PyTorch conversion for the graph
        edge_indices = np.array(np.nonzero(adj_matrix))
        edge_indices = torch.tensor(edge_indices, dtype=torch.long)
        x = torch.tensor(node_features, dtype=torch.float)
        y = torch.tensor(node_labels, dtype=torch.long)

        return Data(x=x, edge_index=edge_indices, y=y)
    
    # Generate a PyG dataset from image and mask paths.
    def create_dataset(self):
        dataset = []
        for i in range(len(self.filenames)):
            img_path = self.cellpaths[i]
            mask_path = self.maskpaths[i]
            features = self.preprocess_image(img_path)
            mask = self.preprocess_mask(mask_path)

            graph = self.generate_graph(features, mask)
            dataset.append(graph)
        return dataset

In [13]:
class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        GNN Model with both node-level and graph-level classification.

        Args:
            input_dim (int): Input feature dimension for each node.
            hidden_dim (int): Hidden layer dimension.
            output_dim (int): Output dimension for node-level classification..
        """
        super(GNNModel, self).__init__()

        # GNN layers: GCN and GAT for feature propagation
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(p=0.5)
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=2, concat=False)
        self.dropout2 = nn.Dropout(p=0.6)

        # Node-level classification branch
        self.node_classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """
        Forward pass for GNN model.

        Args:
            data: A PyTorch Geometric Data object containing:
                - data.x: Node features (N_nodes x input_dim)
                - data.edge_index: Edge list (2 x N_edges)
        Returns:
            node_predictions (torch.Tensor): Node-level predictions (N_nodes x output_dim)
        """
        x, edge_index = data.x, data.edge_index
        
        x = self.dropout1(F.relu(self.gcn1(x, edge_index)))
        x = self.dropout2(F.relu(self.gat1(x, edge_index)))
        
        node_predictions = self.node_classifier(x)
        return node_predictions

In [14]:
def evaluate(model, val_loader, device):
    """
    Evaluate the model on the validation set and calculate node-level and graph-level accuracy and Jaccard score.

    Args:
        model: The GNN model.
        val_loader: DataLoader for the validation set.
        device: The device to run the model on (e.g., "cuda" or "cpu").
    
    Returns:
        node_jaccard: Jaccard accuracy score for node-level classification.
        graph_jaccard: Jaccard accuracy score for graph-level classification.
    """
    model.eval()  # Set the model to evaluation mode
    node_preds_all = []
    node_labels_all = []

    with torch.no_grad():  # No gradient computation during evaluation
        for data in val_loader:
            data = data.to(device)  # Move data to the device (GPU/CPU)

            # Forward pass
            node_predictions = model(data)

            # Collect predictions and true labels
            node_preds_all.append(node_predictions.cpu().numpy())
            node_labels_all.append(data.y.cpu().numpy())

    # Flatten the lists for evaluation
    node_preds_all = np.concatenate(node_preds_all, axis=0)
    node_labels_all = np.concatenate(node_labels_all, axis=0)

    # Calculate Jaccard score for node-level and graph-level classification
    jaccard = jaccard_score(node_labels_all, node_preds_all.argmax(axis=1), average='macro')  # Macro for multi-class

    return jaccard


In [15]:
def train(model, train_loader, val_loader, optimizer, device, epochs=5, lr=1e-3, patience=5):
    """
    Train and evaluate the GNN model for a given number of epochs. Includes early stopping and learning rate adjustment.

    Args:
        model: The GNN model.
        train_loader: DataLoader for the training set.
        val_loader: DataLoader for the validation set.
        device: The device to run the model on (e.g., "cuda" or "cpu").
        epochs (int): Number of training epochs.
        lr (float): Learning rate for the optimizer.
        patience (int): Number of epochs to wait for performance improvement before stopping.
    """
    model.to(device)  # Move model to GPU or CPU
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss() 
    best_node_accuracy = 0.0  # To track the best node-level accuracy
    patience_counter = 0  # To track the number of epochs without improvement
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        model.train()  # Set the model to training mode
        total_loss = 0

        # Training loop
        for data in train_loader:
            data = data.to(device)  # Move data to GPU/CPU
            optimizer.zero_grad()  # Zero the gradients
            # Forward pass
            node_predictions = model(data)
            # Loss calculation
 
            node_loss = criterion(node_predictions, data.y)  # Node-level loss
            # Total loss
            total_loss = node_loss
            # Backpropagation
            total_loss.backward()
            optimizer.step()  # Update the weights

        print(f"Training Loss: {total_loss.item():.4f}")

        # Evaluate on validation set
        node_accuracy = evaluate(model, val_loader, device)
        print(f"Node-level Accuracy: {node_accuracy:.4f}")

        # Early stopping check and learning rate adjustment
        if node_accuracy > best_node_accuracy:
            torch.save(model, 'best_model.pth')
            best_node_accuracy = node_accuracy
            patience_counter = 0  # Reset patience counter if performance improves
        else:
            patience_counter += 1

        # Early stopping condition
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}. No improvement in accuracy for {patience} epochs.")
            break

        print(f"Best Node Accuracy so far: {best_node_accuracy:.4f}")

    print("Training completed.")

In [16]:
filenames = []
# Specify the directory
directory = Path("Dataset/MoNuSeg 2018 Training Data/Annotations")

# Loop through every file in the folder
for file_path in directory.iterdir():
    # Check if it's a file (not a directory)
    if file_path.is_file():
        # Get the file name without the extension
        file_name_without_extension = file_path.stem
        filenames.append(file_name_without_extension)

# Initialize class and preprocess data
segmentation = CellSegmentation(root="Dataset/MoNuSeg 2018 Training Data", filenames=filenames)
dataset = segmentation.create_dataset()
print(len(dataset))
# Split into train/test sets
train_loader = DataLoader(dataset[:int(0.6 * len(dataset))], batch_size=6, shuffle=True)
test_loader = DataLoader(dataset[int(0.4 * len(dataset)):], batch_size=6)

if torch.cuda.is_available():
    print("CUDA enabled")
else:
    print("CUDA not found")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


37
CUDA enabled


In [17]:
input_dim = 3  # Node feature dimension is 1
hidden_dim = 64
output_dim = 3  # 3 classes for node-level classification

In [18]:
#Init model
model = GNNModel(input_dim, hidden_dim, output_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-3)
train(model, train_loader, test_loader, optimizer, device, epochs = 15)

Epoch 1/15
Training Loss: 1.0780
Node-level Accuracy: 0.0732
Best Node Accuracy so far: 0.0732
Epoch 2/15
Training Loss: 1.0115
Node-level Accuracy: 0.2931
Best Node Accuracy so far: 0.2931
Epoch 3/15
Training Loss: 0.9528
Node-level Accuracy: 0.4254
Best Node Accuracy so far: 0.4254
Epoch 4/15
Training Loss: 0.8765
Node-level Accuracy: 0.4977
Best Node Accuracy so far: 0.4977
Epoch 5/15
Training Loss: 0.7844
Node-level Accuracy: 1.0000
Best Node Accuracy so far: 1.0000
Epoch 6/15
Training Loss: 0.6887
Node-level Accuracy: 1.0000
Best Node Accuracy so far: 1.0000
Epoch 7/15
Training Loss: 0.5830
Node-level Accuracy: 1.0000
Best Node Accuracy so far: 1.0000
Epoch 8/15
Training Loss: 0.5009
Node-level Accuracy: 1.0000
Best Node Accuracy so far: 1.0000
Epoch 9/15
Training Loss: 0.4086
Node-level Accuracy: 1.0000
Best Node Accuracy so far: 1.0000
Epoch 10/15
Training Loss: 0.3074
Node-level Accuracy: 1.0000
Early stopping at epoch 10. No improvement in accuracy for 5 epochs.
Training compl

In [19]:
testfiles = []
# Specify the directory
directory = Path("Dataset/MoNuSegTestData/Annotations")

# Loop through every file in the folder
for file_path in directory.iterdir():
    # Check if it's a file (not a directory)
    if file_path.is_file():
        # Get the file name without the extension
        file_name_without_extension = file_path.stem
        testfiles.append(file_name_without_extension)

test = CellSegmentation(root="Dataset/MoNuSegTestData", filenames=testfiles)
dataset = test.create_dataset()
print(len(dataset))
# Split into train/test sets
testset_loader = DataLoader(dataset, batch_size=6, shuffle=True)

14


In [20]:
best_model = torch.load('best_model.pth')
results = evaluate(best_model, testset_loader, device)
print(f"Accuracy: {results:.4f}")

Accuracy: 1.0000

  best_model = torch.load('best_model.pth')



