In [2]:
#packages to import
import os
import time
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.utils as pyg_utils
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.utils import from_networkx
from sklearn.model_selection import train_test_split
from collections import defaultdict
from tqdm import tqdm
import random
from itertools import combinations, product
from torch.utils.data import DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
import re
from collections import defaultdict

In [3]:
#Generate graph data object

def flatten_node_attributes(node_attr):
        """Flatten the node attributes to ensure all values are numeric."""
        flat_attrs = []
        for attr in node_attr.values():
            if isinstance(attr, (tuple, list)):
                flat_attrs.extend(attr)
            else:
                flat_attrs.append(attr)
        return flat_attrs
    

def load_graph_data(node_features_file, edge_features_file):
    node_features_df = pd.read_csv(node_features_file)
    edge_features_df = pd.read_csv(edge_features_file)

    node_features_df['pos'] = node_features_df['pos'].apply(eval).apply(lambda x: tuple(map(float, x)))
    node_features_df['dir'] = node_features_df['dir'].apply(eval).apply(lambda x: tuple(map(float, x)))
    node_features_df[['ends0', 'ends1']] = node_features_df['ends'].apply(eval).apply(pd.Series)
    node_features_df['ends0'] = node_features_df['ends0'].apply(lambda x: tuple(map(float, x)))
    node_features_df['ends1'] = node_features_df['ends1'].apply(lambda x: tuple(map(float, x)))

     # Select required columns and ensure they are numeric
    include_columns = ['stepNum', 'id', 'ObjectNumber', 'parent_id', 'label', 'cellType', 'divideFlag',
                       'LifeHistory', 'startVol', 'targetVol', 'radius', 'length', 'strainRate', 'strainRate_rolling',
                       'pos', 'dir', 'ends0', 'ends1']
    node_features_df = node_features_df[include_columns]

   # display(node_features_df)
    # Create the graph
    G = nx.MultiDiGraph()

    # Add nodes with their features as attributes and label them by their index
    for idx, row in node_features_df.iterrows():
        node_id = idx  # Use the index as a unique label for each node
        G.add_node(node_id, **row.to_dict())

    # Define positions for the entire graph
    pos = {node: data['pos'] for node, data in G.nodes(data=True)}

    # Create a unique mapping from (stepNum, ObjectNum) to node index
    node_mapping = {(row['stepNum'], row['ObjectNumber']): idx for idx, row in node_features_df.iterrows()}
    contact_edges = []
    # Add edges, considering contact edges
    for _, row in edge_features_df.iterrows():
      if row['Relationship'] == 'Neighbors':
        step_num_1 = row['First Image Number']
        step_num_2 = row['Second Image Number']
        node1 = row['First Object Number']
        node2 = row['Second Object Number']

        if (step_num_1, node1) in node_mapping and (step_num_2, node2) in node_mapping:
          if not G.has_edge(node_mapping[(step_num_1, node1)], node_mapping[(step_num_2, node2)]):
            node1_idx = node_mapping[(step_num_1, node1)]
            node2_idx = node_mapping[(step_num_2, node2)]
            G.add_edge(node1_idx, node2_idx, edge_type = 'contact')
            G.add_edge(node2_idx, node1_idx, edge_type = 'contact')
            contact_edges.append((node1_idx, node2_idx))
            contact_edges.append((node2_idx, node1_idx))

    # Create lineage mapping
    lineage_mapping = {(row['id'], row['parent_id'], row['stepNum']): idx for idx, row in node_features_df.iterrows()}
    lineage_edges = []
    # Add directed lineage edges based on the lineage mapping
    for key, node1_idx in lineage_mapping.items():
        id, parent_id, step_num = key
        if parent_id == 0:
            continue  # Skip if parent_id is zero
        if (id, id, step_num + 1) in lineage_mapping:
            node2_idx = lineage_mapping[(id, id, step_num + 1)]
            G.add_edge(node1_idx, node2_idx, edge_type='lineage')
            lineage_edges.append((node1_idx, node2_idx))
        elif (id, parent_id, step_num + 1) in lineage_mapping:
            node2_idx = lineage_mapping[(id, parent_id, step_num + 1)]
            G.add_edge(node1_idx, node2_idx, edge_type='lineage')
            lineage_edges.append((node1_idx, node2_idx))
        else:
            for parent_key, node2_idx in lineage_mapping.items():
                parent_id_key, _, step_num_key = parent_key
                if parent_id_key == parent_id and step_num_key == step_num - 1:
                    G.add_edge(node2_idx, node1_idx, edge_type='lineage')
                    lineage_edges.append((node2_idx, node1_idx))
                    break


    x = torch.tensor([flatten_node_attributes(G.nodes[node]) for node in G.nodes()], dtype=torch.float)
    contact_edge_index = torch.tensor(contact_edges, dtype=torch.long).t().contiguous()
    lineage_edge_index = torch.tensor(lineage_edges, dtype=torch.long).t().contiguous()

    data =  Data(x, contact_edge_index=contact_edge_index, lineage_edge_index=lineage_edge_index)
    return data

In [4]:
# Triplet Loss functions
class TripletLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = F.pairwise_distance(anchor, positive, p=2)
        distance_negative = F.pairwise_distance(anchor, negative, p=2)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

class MixedTripletLoss(nn.Module):
    def __init__(self, alpha=0.5, margin=1.0, margin_cosine=0.2):
        super(MixedTripletLoss, self).__init__()
        self.alpha = alpha
        self.margin = margin
        self.margin_cosine = margin_cosine

    def forward(self, anchor, positive, negative):
        triplet_loss = TripletLoss(self.margin)
        triplet_loss_cosine = TripletLossCosine(self.margin_cosine)

        loss_triplet = triplet_loss(anchor, positive, negative)
        loss_triplet_cosine = triplet_loss_cosine(anchor, positive, negative)

        mixed_loss = loss_triplet + self.alpha * loss_triplet_cosine

        return mixed_loss

class TripletLossCosine(nn.Module):
    def __init__(self, margin=0.2):
        super(TripletLossCosine, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = 1 - F.cosine_similarity(anchor, positive, dim=-1)
        distance_negative = 1 - F.cosine_similarity(anchor, negative, dim=-1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()


In [5]:
#folder of the data generated by sims_relationships.ipynb
def get_file_names(directory):

    # Collect files starting with "Trackrefiner"
    matching_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.startswith("Trackrefiner"):
                matching_files.append(os.path.join(root, file))

    return matching_files

#Identifying parameter combinations

def identify_parameter_combinations(matching_files):
    variations_with_instances = np.unique(list(map(lambda x: x.split('Trackrefiner.')[1].split('result')[0], matching_files)))

    # Initialize sets to collect unique values
    unique_gamma = set()
    unique_reg_param = set()
    unique_adh = set()
    unique_average = set()


    # Extract parameter values
    for var in variations_with_instances:
        # Remove the instance part
        base_var, average = var.split('-Average-')
        uni_average = average[0]
        # Parse the gamma, reg_param, and adh values
        parts = base_var.split('_')
        gamma = parts[1]
        reg_param = parts[4]
        adh = parts[6]

        # Add to respective sets
        unique_gamma.add(gamma)
        unique_reg_param.add(reg_param)
        unique_adh.add(adh)
        unique_average.add(uni_average)

    # Convert to sorted lists (optional)
    unique_gamma = list(sorted(unique_gamma))
    unique_reg_param = list(sorted(unique_reg_param))
    unique_adh = list(sorted(unique_adh))
    return unique_gamma, unique_reg_param, unique_adh

# Generate the data structure
def get_paramter_set_simulations(gamma_values, reg_param_values, adh_values):
    data = {
        f"gamma_{gamma}_reg_param_{reg_param}_adh_{adh}": [
            f"gamma_{gamma}_reg_param_{reg_param}_adh_{adh}-Average-{avg}" 
            for avg in average_values
        ]
        for gamma, reg_param, adh in product(gamma_values, reg_param_values, adh_values)
    }
    return data

#this function will have a simulation instance as key and the values as [node_features, node_relationships]
def generate_simulation_ditionary(file_list):
    # Regex pattern to extract parameter combinations and Average instance
    pattern = re.compile(r"gamma_(\d+)_reg_param_(\d+\.?\d*)_adh_(\d+\.?\d*)-Average-(\d+)")

    # Dictionary to store the key-value pairs
    parameter_dict = defaultdict(list)

    # Process each filename
    for file in file_list:
        match = pattern.search(file)
        if match:
            # Extract values
            gamma, reg_param, adh, average = match.groups()
            # Create key as parameter combination
            key = f"gamma_{gamma}_reg_param_{reg_param}_adh_{adh}-Average-{average}"
            # Append the Average instance to the key's list
            parameter_dict[key] = [file, file.split('Trackrefiner')[0] + 'ObjectRelationship.Trackrefiner' + file.split('Trackrefiner')[1]]

    # Convert defaultdict to a regular dictionary
    parameter_dict = dict(parameter_dict)
    return parameter_dict

# Generate triplets
def generate_triplets(data):
    triplets = []
    params = list(data.keys())  # List of parameter sets

    for param, instances in data.items():
        # Anchor and Positive from the same parameter set
        for anchor, positive in combinations(instances, 2):
            # Negative from a different parameter set
            negative_param = random.choice([p for p in params if p != param])
            negative = random.choice(data[negative_param])
            triplets.append((anchor, positive, negative))

    return triplets


#triplet batches
class TripletDataset(Dataset):
    def __init__(self, triplets):
        """
        Args:
            triplets: A list of triplets (anchor_data, positive_data, negative_data)
        """
        self.triplets = triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        anchor_data, positive_data, negative_data = self.triplets[idx]
        return anchor_data, positive_data, negative_data



In [6]:
# GAT Block with Attention Mechanism
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels=22, out_channels = hidden_channels, heads=2)
        self.conv2 = GATConv(hidden_channels * 2, out_channels=out_channels, heads=1, concat=False)

        self.conv3 = GATConv(out_channels, out_channels=hidden_channels, heads=2)
        self.conv4 = GATConv(hidden_channels * 2, out_channels=out_channels, heads=1, concat=False)


    def forward(self, data):
        x, c_edge_index, l_edge_index = data.x, data.contact_edge_index, data.lineage_edge_index
        x = self.conv1(x, c_edge_index)
        x = F.elu(x)
        x = self.conv2(x, c_edge_index)
        x = F.elu(x)

        x = self.conv3(x, l_edge_index)
        x = F.elu(x)
        x = self.conv4(x, l_edge_index)
        x = F.elu(x)

        x = global_mean_pool(x, data.batch)
        return x
     

In [7]:
#model training on data
def train(data_loader, device, model, criterion, optimizer, num_epochs=20):
    """
    Train the model using triplet loss.

    Args:
        data_loader: DataLoader that loads the triplet data.
        model: The model to be trained (e.g., GAT).
        criterion: The loss function (e.g., TripletMarginLoss).
        optimizer: The optimizer (e.g., Adam).
        num_epochs: The number of training epochs.
    """
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        running_loss = 0.0

        # Initialize tqdm for progress bar and display the epoch
        epoch_progress = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", ncols=100)

        for anchor_data, positive_data, negative_data in epoch_progress:
            # Transfer the data to the appropriate device (GPU or CPU)
            anchor_data, positive_data, negative_data = anchor_data.to(device), positive_data.to(device), negative_data.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass for anchor, positive, and negative data
            anchor_out = model(anchor_data)  # Process anchor data
            positive_out = model(positive_data)  # Process positive data
            negative_out = model(negative_data)  # Process negative data

            # Calculate the triplet loss
            loss = criterion(anchor_out, positive_out, negative_out)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update the loss summary and progress bar description
            running_loss += loss.item()
            avg_loss = running_loss / (epoch_progress.n + 1)
            epoch_progress.set_postfix(loss=avg_loss)

        avg_loss = running_loss / len(data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")


In [8]:
#check model performance
def validate(data_loader, model, criterion, margin):
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    correct_triplets = 0
    total_triplets = 0

    with torch.no_grad():
        for anchor_data, positive_data, negative_data in data_loader:
            # Transfer data to the device
            anchor_data = anchor_data.to(device)
            positive_data = positive_data.to(device)
            negative_data = negative_data.to(device)

            # Forward pass
            anchor_out = model(anchor_data)
            positive_out = model(positive_data)
            negative_out = model(negative_data)

            # Calculate triplet loss
            loss = criterion(anchor_out, positive_out, negative_out)
            total_loss += loss.item()

            # Evaluate triplet condition
            distance_positive = F.pairwise_distance(anchor_out, positive_out, p=2)
            distance_negative = F.pairwise_distance(anchor_out, negative_out, p=2)

            correct_triplets += (distance_positive + margin < distance_negative).sum().item()
            total_triplets += len(distance_positive)

    avg_loss = total_loss / len(data_loader)
    accuracy = correct_triplets / total_triplets * 100

    print(f"Validation Loss: {avg_loss:.4f}, Triplet Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy


In [None]:
#split training data
#the split should sum upto 1
def generate_data_splits(train_split, test_split, val_split, triplets):
    if train_split + test_split + val_split == 1:
        
        #processing of triplet file names
        triplets_n = []
        for i in triplets:
            temp = []
            for j in i:
                temp.append(j.split('-')[0])
            triplets_n.append(tuple(temp))
            
        # Step 1: Create a mapping of anchor-negative pairs
        anchor_negative_map = defaultdict(list)
        k = 0
        for anchor, positive, negative in triplets_n:
            anchor_negative_map[(anchor, negative)].append((k, anchor, positive, negative))
            k+=1

        # Step 2: Get anchor-negative pairs
        anchor_negative_pairs = list(anchor_negative_map.keys())

        # Step 3: Stratified split for training, validation, and test
        train_pairs, temp_pairs = train_test_split(
            anchor_negative_pairs, test_size=1-train_split, random_state=42
        )
        val_pairs, test_pairs = train_test_split(
            temp_pairs, test_size=val_split/(1-train_split), random_state=42
        )

        # Step 4: Collect triplets corresponding to each split
        train_triplets = [triplet for pair in train_pairs for triplet in anchor_negative_map[pair]]
        val_triplets = [triplet for pair in val_pairs for triplet in anchor_negative_map[pair]]
        test_triplets = [triplet for pair in test_pairs for triplet in anchor_negative_map[pair]]
        
        train_ix = list(map(lambda x: x[0], train_triplets))
        val_ix = list(map(lambda x: x[0], val_triplets))
        test_ix = list(map(lambda x: x[0],test_triplets))
        
        triplets = np.array(triplets)
        train_dat = triplets[train_ix]
        test_dat = triplets[test_ix]
        val_dat = triplets[val_ix]
        
        final_train_data = []
        final_test_data = []
        final_val_data = []

        for i in train_dat:
            k = []
            for j in i:
                k.append(data_tuple[j])
            final_train_data.append(k)

        for i in test_dat:
            k = []
            for j in i:
                k.append(data_tuple[j])
            final_test_data.append(k)
            
        for i in val_dat:
            k = []
            for j in i:
                k.append(data_tuple[j])
            final_val_data.append(k)
            
        return final_train_data, final_test_data, final_val_data    



In [None]:
#non-configurable parameters
n_simulation_instances = 10 #number of simulation instances for every parameter combination generated from ABM
#configurable parameters
train_split =0.7
test_split = 0.15
val_split = 0.15
batch_train_size = 32
hidden_channels_model = 4 #attention heads
feature_embedding_size = 5
model_learning_rate = 0.001 #optimizer can be changed and can also be learnt using adaptive learning rates
mixed_loss_alpha = 0.5
mixed_loss_margin = 1
mixed_loss_cosine = 0.2
model_epochs = 10

In [None]:
#Triplets generation
directory = "/Users/sushmadhamodharan/Downloads/finaldata"
matching_files = get_file_names(directory)
unique_gamma, unique_reg_param, unique_adh = identify_parameter_combinations(matching_files)
n_simulation_instances = 10 #number of simulation instances for every parameter combination generated from ABM
average_values = [str(i) for i in range(n_simulation_instances)]
data = get_paramter_set_simulations(unique_gamma, unique_reg_param, unique_adh)
# Generate triplets
triplets = generate_triplets(data)
# Example list of filenames
file_list = matching_files
parameter_dict = generate_simulation_ditionary(file_list)

#generating the graph object for every simulation
data_tuple = {}
for key in parameter_dict.keys():
    data_tuple[key] = load_graph_data(parameter_dict[key][0], parameter_dict[key][1])

final_train_data, final_test_data, final_val_data = generate_data_splits(train_split, test_split, val_split, triplets)


In [15]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

#generate batches of training data
triplet_dataset = TripletDataset(final_train_data)
triplet_train_dataloader = DataLoader(triplet_dataset, batch_size=batch_train_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(in_channels = 22,hidden_channels=hidden_channels_model, out_channels=feature_embedding_size).to(device)  # Adjust input/output as needed
optimizer = Adam(model.parameters(), lr=model_learning_rate)

# Choose the loss function (e.g., MixedTripletLoss, TripletLossCosine, or TripletLoss)
criterion = MixedTripletLoss(alpha=mixed_loss_alpha, margin=mixed_loss_margin, margin_cosine=mixed_loss_cosine).to(device)


# Train the model
train(triplet_train_dataloader, device, model, criterion, optimizer, num_epochs=model_epochs)

Epoch 1/10: 100%|█████████████████████████████████████| 98/98 [01:08<00:00,  1.44batch/s, loss=0.46]


Epoch [1/10], Loss: 0.4601


Epoch 2/10: 100%|████████████████████████████████████| 98/98 [01:08<00:00,  1.44batch/s, loss=0.322]


Epoch [2/10], Loss: 0.3220


Epoch 3/10: 100%|████████████████████████████████████| 98/98 [01:06<00:00,  1.48batch/s, loss=0.298]


Epoch [3/10], Loss: 0.2978


Epoch 4/10: 100%|████████████████████████████████████| 98/98 [01:05<00:00,  1.49batch/s, loss=0.278]


Epoch [4/10], Loss: 0.2781


Epoch 5/10: 100%|████████████████████████████████████| 98/98 [01:07<00:00,  1.45batch/s, loss=0.271]


Epoch [5/10], Loss: 0.2710


Epoch 6/10: 100%|████████████████████████████████████| 98/98 [01:08<00:00,  1.43batch/s, loss=0.258]


Epoch [6/10], Loss: 0.2582


Epoch 7/10: 100%|████████████████████████████████████| 98/98 [01:05<00:00,  1.49batch/s, loss=0.249]


Epoch [7/10], Loss: 0.2486


Epoch 8/10: 100%|████████████████████████████████████| 98/98 [01:05<00:00,  1.50batch/s, loss=0.242]


Epoch [8/10], Loss: 0.2424


Epoch 9/10: 100%|████████████████████████████████████| 98/98 [01:20<00:00,  1.21batch/s, loss=0.236]


Epoch [9/10], Loss: 0.2364


Epoch 10/10: 100%|███████████████████████████████████| 98/98 [01:09<00:00,  1.41batch/s, loss=0.237]

Epoch [10/10], Loss: 0.2366





In [16]:
validate(final_test_data, model, criterion, margin = 1)

Validation Loss: 0.2193, Triplet Accuracy: 79.39%


(0.21929494740308975, 79.39042089985486)

In [17]:
validate(final_val_data, model, criterion, margin = 1)

Validation Loss: 0.2080, Triplet Accuracy: 79.53%


(0.20796928657918334, 79.52871870397644)

In [None]:
#generate the feature embeddin using the model and save it in pickle format to use it for visualization in SIMPLE_GAT_visulaization.ipynb