In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import networkx as nx

In [2]:
# Load Features
# ==============================================================

features_df = pd.read_csv("..Feature_with_FrameLevel/Feature_ResNet50Only/Feature_P01_04_EpicKitchen.csv")

# Class Mapping
action_classes = sorted(features_df["ActionLabel"].unique())
action_mapping = {action_classes[i]: i for i in range(len(action_classes))}
features_df["action_class_mapped"] = features_df["ActionLabel"].map(action_mapping)

In [3]:
action_mapping

{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 6: 6,
 7: 7,
 8: 8,
 9: 9,
 10: 10,
 11: 11,
 12: 12,
 13: 13,
 14: 14,
 15: 15,
 16: 16,
 17: 17,
 18: 18,
 19: 19,
 20: 20,
 21: 21,
 22: 22,
 23: 23,
 24: 24,
 25: 25,
 26: 26,
 27: 27,
 28: 28}

In [4]:
# Train/Val/Test Split
# -------------------------

counts = features_df['action_class_mapped'].value_counts()
if (counts < 2).any():
    print("Warning: some classes have <2 samples — using non-stratified split.")
    train_df, test_df = train_test_split(features_df, test_size=0.2, random_state=42, shuffle=True)
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, shuffle=True)
else:
    train_df, test_df = train_test_split(features_df, test_size=0.2, random_state=42, stratify=features_df["action_class_mapped"])
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, stratify=train_df["action_class_mapped"])

# Convert to Tensors
def process_data(df):
    X = torch.tensor(df.iloc[:, 3:].values, dtype=torch.float32)
    y = torch.tensor(df["action_class_mapped"].values, dtype=torch.long)
    return X, y

X_train, y_train = process_data(train_df)
X_val, y_val = process_data(val_df)
X_test, y_test = process_data(test_df)



In [5]:
#  DPT Graph Construction
# ==============================================================

def dynamic_percentile_graph(X, percentile=95, device='cuda'):
    X = X.to(device)
    N = X.shape[0]
    sim_matrix = F.cosine_similarity(X.unsqueeze(1), X.unsqueeze(0), dim=2)
    # ignore self similarity when computing quantile
    sim_no_diag = sim_matrix.clone()
    sim_no_diag.fill_diagonal_(-1e9)
    threshold = torch.quantile(sim_no_diag, percentile / 100.0)
    adj_matrix = (sim_matrix >= threshold).float()
    adj_matrix.fill_diagonal_(0)
    row_indices, col_indices = torch.nonzero(adj_matrix, as_tuple=True)
    edge_index = torch.stack([row_indices, col_indices], dim=0)
    edge_count = edge_index.shape[1]
    total_possible_edges = N * (N - 1)
    return edge_index, edge_count, total_possible_edges


def plot_dpt_graph(edge_index, num_nodes, title="DPT Graph (Circular Layout)"):
    G = nx.DiGraph()
    G.add_nodes_from(range(1, num_nodes + 1))
    edges = edge_index.cpu().numpy().T
    G.add_edges_from([(int(s)+1, int(d)+1) for s, d in edges])
    pos = nx.circular_layout(G)
    plt.figure(figsize=(8, 8))
    nx.draw_networkx_nodes(G, pos, node_color='lightblue', edgecolors='black', node_size=120)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color='black')
    nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=12)
    plt.title(title)
    plt.axis('off')
    plt.show()



# GraphSAGE Model
# ==============================================================

class GraphSAGEModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super(GraphSAGEModel, self).__init__()
        self.convs = nn.ModuleList()
        if num_layers == 1:
            self.convs.append(SAGEConv(input_dim, output_dim))
        else:
            self.convs.append(SAGEConv(input_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.convs.append(SAGEConv(hidden_dim, hidden_dim))
            self.convs.append(SAGEConv(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
        return F.log_softmax(x, dim=1)



#  Multi-task Loss (Cross-Entropy + Temporal Smoothness)
# ==============================================================

class MultiTaskLoss(nn.Module):
    def __init__(self, alpha=0.8):
        super(MultiTaskLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, outputs, targets, features):
        ce = self.ce_loss(outputs, targets)
        if features.shape[0] > 1:
            temporal_loss = torch.mean(torch.abs(features[1:] - features[:-1]))
        else:
            temporal_loss = torch.tensor(0.0, device=features.device)
        return self.alpha * ce + (1 - self.alpha) * temporal_loss



#  Training Setup
# ==============================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = X_train.shape[1]
hidden_dim = 128
output_dim = len(action_classes)

model = GraphSAGEModel(input_dim, hidden_dim, output_dim, num_layers=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
criterion = MultiTaskLoss(alpha=0.8)

num_epochs = 100
percentile_threshold = 95



#  Training Loop
# ==============================================================

for epoch in range(num_epochs):
    model.train()
    edge_index, edge_count, total_possible_edges = dynamic_percentile_graph(X_train, percentile=percentile_threshold, device=device)
    edge_index = edge_index.to(device)

    optimizer.zero_grad()
    output = model(X_train.to(device), edge_index)
    loss = criterion(output, y_train.to(device), X_train.to(device))
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        edge_index_val, _, _ = dynamic_percentile_graph(X_val, percentile=percentile_threshold, device=device)
        edge_index_val = edge_index_val.to(device)
        val_output = model(X_val.to(device), edge_index_val)
        val_loss = criterion(val_output, y_val.to(device), X_val.to(device))
        val_probs = torch.exp(val_output)
        val_preds = val_probs.argmax(dim=1).cpu().numpy()

        top1_acc = (val_preds == y_val.cpu().numpy()).mean()
        precision = precision_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        recall = recall_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        f1 = f1_score(y_val.cpu().numpy(), val_preds, average='weighted', zero_division=1)
        top5_acc = 0.0
        if output_dim >= 5:
            top5_preds = torch.topk(val_probs, 5, dim=1).indices.cpu().numpy()
            top5_correct = sum([y in top5 for y, top5 in zip(y_val.cpu().numpy(), top5_preds)])
            top5_acc = top5_correct / len(y_val)

    sparsity_ratio = edge_count / total_possible_edges * 100

    print(f"[GraphSAGE] Epoch {epoch+1}/{num_epochs} | "
          f"Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f} | "
          f"Top-1: {top1_acc:.4f} | Top-5: {top5_acc:.4f} | "
          f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | "
          f"Edges: {edge_count}/{total_possible_edges} ({sparsity_ratio:.2f}% retained)")



[GraphSAGE] Epoch 1/100 | Loss: 2.7696 | Val Loss: 2.6322 | Top-1: 0.0990 | Top-5: 0.4653 | Precision: 0.6997 | Recall: 0.0990 | F1: 0.0508 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 2/100 | Loss: 2.6350 | Val Loss: 2.5273 | Top-1: 0.2970 | Top-5: 0.5644 | Precision: 0.7023 | Recall: 0.2970 | F1: 0.2184 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 3/100 | Loss: 2.5267 | Val Loss: 2.4403 | Top-1: 0.2772 | Top-5: 0.5644 | Precision: 0.6935 | Recall: 0.2772 | F1: 0.1427 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 4/100 | Loss: 2.4369 | Val Loss: 2.3683 | Top-1: 0.2574 | Top-5: 0.6238 | Precision: 0.8088 | Recall: 0.2574 | F1: 0.1054 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 5/100 | Loss: 2.3614 | Val Loss: 2.3058 | Top-1: 0.2574 | Top-5: 0.6238 | Precision: 0.8088 | Recall: 0.2574 | F1: 0.1054 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 6/100 | Loss: 2.2962 | Val Loss: 2.2506 | Top-1: 0.2673 | Top-5: 0.6337 | Precision: 0.8095 

[GraphSAGE] Epoch 47/100 | Loss: 1.1409 | Val Loss: 1.2970 | Top-1: 0.6337 | Top-5: 0.9109 | Precision: 0.7905 | Recall: 0.6337 | F1: 0.5131 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 48/100 | Loss: 1.1249 | Val Loss: 1.2846 | Top-1: 0.6337 | Top-5: 0.9109 | Precision: 0.7905 | Recall: 0.6337 | F1: 0.5131 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 49/100 | Loss: 1.1094 | Val Loss: 1.2727 | Top-1: 0.6337 | Top-5: 0.9109 | Precision: 0.7905 | Recall: 0.6337 | F1: 0.5131 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 50/100 | Loss: 1.0943 | Val Loss: 1.2615 | Top-1: 0.6436 | Top-5: 0.9109 | Precision: 0.7934 | Recall: 0.6436 | F1: 0.5303 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 51/100 | Loss: 1.0796 | Val Loss: 1.2508 | Top-1: 0.6535 | Top-5: 0.9109 | Precision: 0.7964 | Recall: 0.6535 | F1: 0.5383 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 52/100 | Loss: 1.0653 | Val Loss: 1.2408 | Top-1: 0.6535 | Top-5: 0.9109 | Precision: 0

[GraphSAGE] Epoch 94/100 | Loss: 0.6878 | Val Loss: 0.9394 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0.8212 | Recall: 0.7129 | F1: 0.6308 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 95/100 | Loss: 0.6823 | Val Loss: 0.9348 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0.8212 | Recall: 0.7129 | F1: 0.6308 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 96/100 | Loss: 0.6769 | Val Loss: 0.9303 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0.8212 | Recall: 0.7129 | F1: 0.6308 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 97/100 | Loss: 0.6716 | Val Loss: 0.9260 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0.8212 | Recall: 0.7129 | F1: 0.6308 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 98/100 | Loss: 0.6664 | Val Loss: 0.9218 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0.8212 | Recall: 0.7129 | F1: 0.6308 | Edges: 8122/162006 (5.01% retained)
[GraphSAGE] Epoch 99/100 | Loss: 0.6612 | Val Loss: 0.9176 | Top-1: 0.7129 | Top-5: 0.9208 | Precision: 0