### Import Libraries

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import time
import networkx as nx

### Load Features

In [2]:
features_df = pd.read_csv("..Feature_with_FrameLevel/Feature_ResNet50Only/Feature_P01_04_EpicKitchen.csv")
action_classes = sorted(features_df["ActionLabel"].unique())
action_mapping = {action_classes[i]: i for i in range(len(action_classes))}
features_df["action_class_mapped"] = features_df["ActionLabel"].map(action_mapping)
train_df, test_df = train_test_split(features_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)
def process_data(df):
    X = torch.tensor(df.iloc[:, 3:].values, dtype=torch.float32)
    y = torch.tensor(df["action_class_mapped"].values, dtype=torch.long)
    return X, y

X_train, y_train = process_data(train_df)
X_val, y_val = process_data(val_df)
X_test, y_test = process_data(test_df)

### SVSG Graph Construction using DPT

In [23]:
def dynamic_percentile_graph(X, percentile=95, device='cuda'):
    X = X.to(device)
    N = X.shape[0]
    sim_matrix = F.cosine_similarity(X.unsqueeze(1), X.unsqueeze(0), dim=2)
    threshold = torch.quantile(sim_matrix, percentile / 100)
    adj_matrix = (sim_matrix >= threshold).float()
    adj_matrix.fill_diagonal_(0)
    row_indices, col_indices = torch.nonzero(adj_matrix, as_tuple=True)
    edge_index = torch.stack([row_indices, col_indices], dim=0)
    edge_count = edge_index.shape[1]
    total_possible_edges = N * (N - 1)
    return edge_index, edge_count, total_possible_edges

def plot_dpt_graph(edge_index, num_nodes, title="DPT Graph (Circular Layout)"):
    import matplotlib.pyplot as plt
    import networkx as nx
    G = nx.DiGraph()
    G.add_nodes_from(range(1, num_nodes + 1))  
    edges = edge_index.cpu().numpy().T
    edges_1based = [(src + 1, dst + 1) for src, dst in edges]
    G.add_edges_from(edges_1based)
    pos = nx.circular_layout(G)
    plt.figure(figsize=(8, 8))
    nx.draw_networkx_nodes(G, pos, node_color='lavender', edgecolors='black', node_size=100)
    nx.draw_networkx_labels(G, pos, font_size=12, font_color='red', font_weight='bold')
    nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=12)
    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


### GAT Model

In [24]:
class GATModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=2):
        super(GATModel, self).__init__()
        self.conv1 = GATConv(input_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, output_dim, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
    
class MultiTaskLoss(nn.Module):
    def __init__(self, alpha=0.8):
        super(MultiTaskLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, outputs, targets, features):
        ce = self.ce_loss(outputs, targets)
        temporal_loss = torch.mean(torch.abs(features[1:] - features[:-1]))
        return self.alpha * ce + (1 - self.alpha) * temporal_loss

input_dim = X_train.shape[1]
hidden_dim = 128
output_dim = features_df['action_class_mapped'].nunique()
model = GATModel(input_dim, hidden_dim, output_dim)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
criterion = MultiTaskLoss()

### Model Training

In [26]:
num_epochs = 100
percentile_threshold = 95

for epoch in range(num_epochs):
    model.train()
    edge_index, edge_count, total_possible_edges = dynamic_percentile_graph(
        X_train, percentile=percentile_threshold, device=device)
    edge_index = edge_index.to(device)

    optimizer.zero_grad()
    output = model(X_train.to(device), edge_index)
    loss = criterion(output, y_train.to(device), X_train.to(device))
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        edge_index_val, _, _ = dynamic_percentile_graph(
            X_val, percentile=percentile_threshold, device=device)
        edge_index_val = edge_index_val.to(device)
        val_output = model(X_val.to(device), edge_index_val)
        val_loss = criterion(val_output, y_val.to(device), X_val.to(device))
        val_probs = torch.softmax(val_output, dim=1)
        val_preds = val_probs.argmax(dim=1).cpu().numpy()

        top1_acc = (val_preds == y_val.cpu().numpy()).sum() / len(y_val)
        precision = precision_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        recall = recall_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        f1 = f1_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        top5_preds = torch.topk(val_probs, 5, dim=1).indices.cpu().numpy()
        top5_correct = sum([y in top5 for y, top5 in zip(y_val.cpu().numpy(), top5_preds)])
        top5_acc = top5_correct / len(y_val)

    sparsity_ratio = edge_count / total_possible_edges * 100

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Loss: {loss.item():.4f} - "
          f"Val Loss: {val_loss.item():.4f} - "
          f"Top-1 Acc: {top1_acc:.4f} - "
          f"Top-5 Acc: {top5_acc:.4f} - "
          f"Precision: {precision:.4f} - "
          f"Recall: {recall:.4f} - "
          f"F1-Score: {f1:.4f} - "
          f"Edge Retained: {edge_count}/{total_possible_edges} "
          f"({sparsity_ratio:.2f}% sparse)")

edge_index_final, _, _ = dynamic_percentile_graph(
    X_train, percentile=percentile_threshold, device=device)
plot_dpt_graph(edge_index_final, X_train.shape[0], title=f"Final DPT Graph After {num_epochs} Epochs")


Epoch 1/100 - Loss: 1.1145 - Val Loss: 0.6019 - Top-1 Acc: 0.8218 - Top-5 Acc: 1.0000 - Precision: 0.8406 - Recall: 0.8218 - F1-Score: 0.7857 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 2/100 - Loss: 1.0846 - Val Loss: 0.6030 - Top-1 Acc: 0.8218 - Top-5 Acc: 1.0000 - Precision: 0.8389 - Recall: 0.8218 - F1-Score: 0.7836 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 3/100 - Loss: 1.0464 - Val Loss: 0.6112 - Top-1 Acc: 0.8614 - Top-5 Acc: 1.0000 - Precision: 0.8764 - Recall: 0.8614 - F1-Score: 0.8334 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 4/100 - Loss: 1.0186 - Val Loss: 0.6253 - Top-1 Acc: 0.8812 - Top-5 Acc: 1.0000 - Precision: 0.9028 - Recall: 0.8812 - F1-Score: 0.8604 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 5/100 - Loss: 1.0114 - Val Loss: 0.6394 - Top-1 Acc: 0.8812 - Top-5 Acc: 1.0000 - Precision: 0.9028 - Recall: 0.8812 - F1-Score: 0.8583 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 6/100 - Loss: 1.0180 - Val Loss: 0.6455 - Top-1 Acc: 0

Epoch 47/100 - Loss: 0.7438 - Val Loss: 0.5349 - Top-1 Acc: 0.8515 - Top-5 Acc: 1.0000 - Precision: 0.8701 - Recall: 0.8515 - F1-Score: 0.8273 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 48/100 - Loss: 0.7402 - Val Loss: 0.5352 - Top-1 Acc: 0.8515 - Top-5 Acc: 1.0000 - Precision: 0.8789 - Recall: 0.8515 - F1-Score: 0.8239 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 49/100 - Loss: 0.7369 - Val Loss: 0.5354 - Top-1 Acc: 0.8614 - Top-5 Acc: 1.0000 - Precision: 0.8895 - Recall: 0.8614 - F1-Score: 0.8360 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 50/100 - Loss: 0.7337 - Val Loss: 0.5351 - Top-1 Acc: 0.8614 - Top-5 Acc: 1.0000 - Precision: 0.8895 - Recall: 0.8614 - F1-Score: 0.8360 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 51/100 - Loss: 0.7305 - Val Loss: 0.5344 - Top-1 Acc: 0.8614 - Top-5 Acc: 1.0000 - Precision: 0.8895 - Recall: 0.8614 - F1-Score: 0.8360 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 52/100 - Loss: 0.7274 - Val Loss: 0.5333 - Top-1 

Epoch 91/100 - Loss: 0.6281 - Val Loss: 0.5106 - Top-1 Acc: 0.8812 - Top-5 Acc: 1.0000 - Precision: 0.9038 - Recall: 0.8812 - F1-Score: 0.8623 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 92/100 - Loss: 0.6262 - Val Loss: 0.5100 - Top-1 Acc: 0.8812 - Top-5 Acc: 1.0000 - Precision: 0.9038 - Recall: 0.8812 - F1-Score: 0.8623 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 93/100 - Loss: 0.6244 - Val Loss: 0.5094 - Top-1 Acc: 0.8812 - Top-5 Acc: 1.0000 - Precision: 0.9038 - Recall: 0.8812 - F1-Score: 0.8623 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 94/100 - Loss: 0.6225 - Val Loss: 0.5088 - Top-1 Acc: 0.8713 - Top-5 Acc: 1.0000 - Precision: 0.9010 - Recall: 0.8713 - F1-Score: 0.8535 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 95/100 - Loss: 0.6207 - Val Loss: 0.5082 - Top-1 Acc: 0.8713 - Top-5 Acc: 1.0000 - Precision: 0.9010 - Recall: 0.8713 - F1-Score: 0.8535 - Edge Retained: 16200/162006 (10.00% sparse)
Epoch 96/100 - Loss: 0.6189 - Val Loss: 0.5075 - Top-1 