In [2]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(os.path.join(os.path.dirname('/home/ntmduy/GraphAE/src/model')))
sys.path.append(os.path.join(os.path.dirname('/home/ntmduy/GraphAE/src/utils')))

In [48]:
from model.GAE_Projection_Att import GAE_CLS_Link_NODE_Cosine_SupCon_2
from model.resnet_big import SupCEResNet, SupConResNet, LinearClassifier, SupIncepResnet
import torch
from torch_geometric.nn import summary
from thop import profile
import numpy as np
import time
from torch_geometric.loader import DataLoader
from utils.data import load_and_split_graphs
from torch.profiler import profile, record_function, ProfilerActivity
import torch.nn.functional as F
from thop import profile as thopprofile, clever_format
from torch_geometric.utils import to_dense_adj

In [4]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

In [5]:
def measure_time_gpu(dummy_input, model, device, rep, none_gnn=False):
    model = model.to(device=device)
    # dummy_input = torch.randn(1, 1, 29, 29, dtype=torch.float).to(device)
    # INIT LOGGERS
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    repetitions = rep
    timings=np.zeros((repetitions,1))
    #GPU-WARM-UP
    for _ in range(100):
        if (none_gnn):
            _ = model(dummy_input)
        else:
            _ = model(dummy_input, device, acummulate = True, remove_random=True)
    # MEASURE PERFORMANCE
    with torch.no_grad():
        for rep in range(repetitions):
            starter.record()
            if (none_gnn):
                _ = model(dummy_input)
            else:
                _ = model(dummy_input, device, acummulate = True, remove_random=True)
            ender.record()
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    return mean_syn, std_syn

In [6]:
def measure_time_cpu(model, device, rep = 10):
    model = model.to(device=device)
    x = torch.rand((1, 1, 29, 29), device=device)
    timings=np.zeros((rep,1))
    for i in range(rep):    
        start_time = time.time()
        out = model(x)
        timings[i] = time.time() - start_time
    mean_syn = np.sum(timings) / rep
    std_syn = np.std(timings)
    return mean_syn, std_syn

In [None]:
def calculate_graph_reconstruction_loss(reconstructed, original, batch_index, reduced="mean"):
    """
    Calculate the mean reconstruction loss for each graph in a batch.
    
    Args:
        reconstructed (torch.Tensor): Reconstructed node features, shape [num_nodes, num_features].
        original (torch.Tensor): Original node features, shape [num_nodes, num_features].
        batch_index (torch.Tensor): Batch indices mapping each node to a graph, shape [num_nodes].
        
    Returns:
        torch.Tensor: Mean reconstruction loss for each graph, shape [num_graphs].
    """
    # Calculate MSE loss for each node
    node_mse_loss = torch.mean((reconstructed - original) ** 2, dim=1)  # Shape: [num_nodes]

    # Aggregate losses for each graph
    num_graphs = batch_index.max().item() + 1  # Total number of graphs in the batch
    graph_losses = torch.zeros(num_graphs, device=reconstructed.device)  # Initialize graph loss storage

    for graph_id in range(num_graphs):
        # Mask to select nodes belonging to the current graph
        graph_mask = (batch_index == graph_id)
        
        # Mean MSE loss for the current graph
        if (reduced == "mean"):
            graph_losses[graph_id] = torch.mean(node_mse_loss[graph_mask])
        elif (reduced == "none"):
            return graph_losses
        else:
            graph_losses[graph_id] = sum(node_mse_loss[graph_mask])
    return graph_losses

def open_set_inference_optimized(ae_model,
                                 batch,
                                 thresholds,
                                 class_centers=None,
                                 device='cuda',
                                 num_known_classes=None): # Add num_known_classes
    """
    Optimized forward pass for open-set inference.
    """
    ae_model.eval()
    batch = batch.to(device)

    # --- Model Forward Pass (No changes here, assumed to be optimized within ae_model) ---
    with torch.no_grad():
        M_sup = ae_model(batch, device, acummulate=True, remove_random=True)
        graph_emb_sup = ae_model.graph_embedding(M_sup, batch.batch)
        graph_emb_sup = F.normalize(graph_emb_sup, p=2, dim=1) # [B, EmbDim]

        # Calculate distances to all class centers
        dists = torch.cdist(graph_emb_sup, class_centers, p=2) # [B, C_known]
        logits = -dists
        probs = F.softmax(logits, dim=1) # [B, C_known]
        max_probs, pred_cls_known = probs.max(dim=1) # max_probs and indices for known classes

        # Reconstruction
        # Adjacency reconstruction loss (graph-level)
        adj_rec = ae_model.adj_decode(M_sup, batch.batch, use_sigmoid=False)
        adj_ori = to_dense_adj(batch.edge_index,
                               batch.batch,
                               edge_attr=batch.edge_attr[:, 0].unsqueeze(1) if batch.edge_attr is not None else None).squeeze(3)
        # Ensure adj_rec and adj_ori have compatible shapes for mse_loss
        # This might require padding or careful handling if graph sizes vary significantly
        # and ae_model.adj_decode doesn't produce consistently sized outputs per graph
        # For now, assuming they are compatible or loss handles it.
        adj_loss = F.mse_loss(adj_rec, adj_ori, reduction='none').sum(dim=(1, 2)) # [B]

        # Node reconstruction loss (graph-level)
        node_hat = ae_model.node_recon(M_sup)
        node_rec_loss = calculate_graph_reconstruction_loss(node_hat, batch.x, batch.batch, reduced="sum") # [B]

    # --- Open-Set Logic (Vectorized) ---
    num_graphs_in_batch = graph_emb_sup.size(0)
    if num_known_classes is None and class_centers is not None:
        num_known_classes = class_centers.size(0)
    elif num_known_classes is None:
        raise ValueError("num_known_classes must be provided if class_centers is None or for unknown label assignment")

    # Initialize predictions as the predicted known class
    pred_out = pred_cls_known.clone() # [B]

    # Rule 1: Distance to predicted class center > its threshold (vectorized)
    # Gather the specific distance thresholds for each predicted class
    # thresholds['distance'] is expected to be a tensor or list of length C_known
    distance_thresholds_for_pred_cls = torch.tensor(thresholds['distance'], device=device)[pred_cls_known] # [B]
    distances_to_pred_cls = dists.gather(1, pred_cls_known.unsqueeze(1)).squeeze(1) # [B]
    unknown_due_to_distance_to_pred = (distances_to_pred_cls > distance_thresholds_for_pred_cls) # [B]

    # Initialize reasons (0: benign, 1: known_attack, >1: unknown due to specific anomaly)
    reasons = torch.zeros(num_graphs_in_batch, dtype=torch.long, device=device)
    reasons[pred_cls_known > 0] = 1 # Mark known attacks (assuming class 0 is benign)

    # SVDD distance for benign class (class 0)
    d_svdd_benign = dists[:, 0] # [B]

    # Identify samples initially predicted as benign (class 0) OR
    # those that already violated their predicted class's distance threshold
    # These are candidates for being re-classified as 'unknown' based on anomaly metrics
    # or staying benign if they pass all checks.
    benign_candidates_mask = (pred_cls_known == 0) | unknown_due_to_distance_to_pred # [B]

    # --- Anomaly checks for benign_candidates_mask ---
    # These flags are only relevant for samples in benign_candidates_mask
    adj_flag_all = (adj_loss > thresholds['adj']) # [B]
    node_flag_all = (node_rec_loss > thresholds['node']) # [B]
    svdd_flag_all = (d_svdd_benign > thresholds['distance'][0]) # [B], distance to benign center threshold

    # Combine flags for reason codes (only for benign candidates)
    # We apply the benign_candidates_mask *after* calculating all flags to keep indexing simple
    adj_flag_bc = adj_flag_all[benign_candidates_mask]
    node_flag_bc = node_flag_all[benign_candidates_mask]
    svdd_flag_bc = svdd_flag_all[benign_candidates_mask]

    # Calculate reason codes for the benign candidates subset
    flag_combo_bc = (adj_flag_bc.long() * 1 +
                     node_flag_bc.long() * 2 +
                     svdd_flag_bc.long() * 4) # yields 0-7 for benign candidates

    # Map combo to final reason codes (2-8 for anomalies, 0 if still benign)
    # combo_to_reason_map_tensor should be precomputed and on the correct device
    # For example:
    # _combo_map = {0:0, 1:2, 2:3, 3:5, 4:4, 5:6, 6:7, 7:8}     0 means stays benign/known attack
    # combo_to_reason_map_tensor = torch.tensor([_combo_map[i] for i in range(8)], device=device)
    # Optimized: Create this map directly
    combo_to_reason_map_tensor = torch.tensor([0, 2, 3, 5, 4, 6, 7, 8], device=device, dtype=torch.long)
    reasons_for_benign_candidates = combo_to_reason_map_tensor[flag_combo_bc]

    # Update reasons for those initially benign candidates
    # If reasons_for_benign_candidates > 0, it means it's an anomaly type
    reasons[benign_candidates_mask] = torch.where(
        reasons_for_benign_candidates > 0, # If an anomaly reason code was generated
        reasons_for_benign_candidates,     # Use that anomaly reason
        reasons[benign_candidates_mask]    # Otherwise, keep original reason (0 for benign, 1 for known attack if it was a misclassified known that passed distance)
    )

    # Samples become 'unknown' if:
    # 1. They were benign_candidates AND any of their anomaly flags (adj, node, svdd for benign) were true
    # 2. OR they initially violated their predicted class's distance threshold (unknown_due_to_distance_to_pred)
    is_unknown_anomaly = (benign_candidates_mask & ( (adj_flag_all | node_flag_all | svdd_flag_all)[benign_candidates_mask] )) \
                       | unknown_due_to_distance_to_pred

    pred_out[is_unknown_anomaly] = num_known_classes # Assign 'unknown' label (e.g., C_known if labels are 0 to C_known-1)

    # Ensure pred_out for samples that are not unknown, but were initially marked as known attacks (>0),
    # and did NOT violate distance_to_pred, retain their known attack label.
    # This is implicitly handled as pred_out is initialized with pred_cls_known.
    # We only overwrite with 'unknown' or keep benign/known_attack.

    return pred_out.cpu().numpy(), reasons.cpu().numpy(), \
           d_svdd_benign.cpu(), adj_loss.cpu(), node_rec_loss.cpu(), \
           graph_emb_sup # graph_emb_sup is already on device, return as is or move to cpu() if needed by caller

In [20]:
def open_set_inference(ae_model,
                       batch,
                       thresholds,
                       class_centers=None,
                       device='cuda',
                       num_classes = 5):
    ae_model.eval()

    batch = batch.to(device)
    with torch.no_grad():
        M_sup = ae_model(batch, device, acummulate=True, remove_random=True)
        graph_emb_sup = ae_model.graph_embedding(M_sup, batch.batch)
        graph_emb_sup = F.normalize(graph_emb_sup, p=2, dim=1)
        
        # features = ae_model.graph_pooling(M_sup, batch.batch)
        dists = torch.cdist(graph_emb_sup, class_centers, p=2) # [B, C_known]
        logits = -dists

        probs  = F.softmax(logits, dim=1)        # [B, C_known]
        _, pred_cls = probs.max(dim=1)

        # recon
        adj_rec  = ae_model.adj_decode(M_sup, batch.batch, use_sigmoid=False)
        node_hat = ae_model.node_recon(M_sup)
        adj_ori  = to_dense_adj(batch.edge_index,
                                batch.batch,
                                edge_attr=batch.edge_attr[:, 0].unsqueeze(1)).squeeze(3)
        adj_loss = F.mse_loss(adj_rec, adj_ori, reduction='none').sum(dim=(1,2))

        rec_loss = calculate_graph_reconstruction_loss(node_hat, batch.x, batch.batch, reduced="sum")

    # 1) For each sample, find the closest class center and check against its threshold
    pred_out = pred_cls.cpu().numpy()              # 0 = benign, >0 = known attack

    unknown_distance = []
    for i in range(len(pred_out)):
        unknown_distance.append(False)
        distance = dists[i, pred_out[i]].item()  # distance to the predicted class center
        if (distance > thresholds['distance'][pred_out[i]]):
            unknown_distance[i] = True
    unknown_distance = torch.tensor(unknown_distance).to(device)
   
    pred_benign_mask = torch.logical_or(pred_cls == 0, unknown_distance)   # apply distance check
    reasons = np.zeros_like(pred_out)          # default benign
    reasons[pred_cls.cpu().numpy() > 0] = 1    # known attack

    d_svdd = dists[:, 0].cpu()

    # -------------------------------------------------------- #
    #  Build reason codes for benign‑predicted subset
    #    2  = adj only
    #    3  = node only
    #    4  = svdd only
    #    5  = adj + node
    #    6  = adj + svdd
    #    7  = node + svdd
    #    8  = adj + node + svdd
    # -------------------------------------------------------- 

    if pred_benign_mask.any():
        # --- evaluate anomaly metrics ONLY for those benign-predicted graphs ---
        idx_benign = pred_benign_mask.nonzero(as_tuple=True)[0]

        # (a) reconstruction & SVDD for that subset
        adj_flag   = adj_loss[idx_benign].cpu()  > thresholds['adj']
        node_flag  = rec_loss[idx_benign].cpu() > thresholds['node']
        svdd_flag  = dists[idx_benign, 0].cpu() > thresholds['distance'][0]

        flag_combo = adj_flag.to(torch.int) * 1 \
                + node_flag.to(torch.int) * 2  \
                + svdd_flag.to(torch.int) * 4 # yields 0‑7
        combo_to_reason = {1:2, 2:3, 3:5,
                            4:4, 5:6, 6:7,
                            7:8}
        for local_idx, combo in enumerate(flag_combo.cpu().numpy()):
            if combo == 0:
                continue
            global_idx = idx_benign[local_idx].cpu().item()
            reasons[global_idx] = combo_to_reason[combo]

        # (b) If ANY metric says “anomaly” → flip to “unknown”
        unknown_idx = idx_benign[(adj_flag | node_flag | svdd_flag)]

        # write label = NUM_CLASSES  for those graphs
        pred_out[unknown_idx.cpu().numpy()] = num_classes   # label for unknown attack

    return pred_out, reasons, torch.tensor(d_svdd), adj_loss, rec_loss, graph_emb_sup

In [23]:
def analyze_performance(model, test_loader, device='cuda', num_warmup=10, num_repeats=100, num_classes =5):
    model = model.to(device)
    model.eval()

    times = []
    with torch.no_grad():
        # Warm-up
        print(f"Warming up for {num_warmup} iterations...")
        for i, batch in enumerate(test_loader):
            batch = batch.to(device)
            open_set_inference(model, 
                                         batch, 
                                         thresholds={'distance': [0.5]*num_classes, 'adj': 0.1, 'node': 0.1}, 
                                         device=device,
                                         class_centers=torch.zeros((num_classes, 128), device=device), 
                                         num_classes=num_classes)
            if i >= num_warmup - 1:
                break

        torch.cuda.synchronize()
        print(f"Starting timed inference over {num_repeats} iterations...")
        for i, batch in enumerate(test_loader):
            if i >= num_repeats:
                break
            batch = batch.to(device)
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            open_set_inference(model, 
                                batch, 
                                thresholds={'distance': [0.5]*num_classes, 'adj': 0.1, 'node': 0.1}, 
                                device=device,
                                class_centers=torch.zeros((num_classes, 128), device=device), 
                                num_classes=num_classes)
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            times.append(end_time - start_time)

        times = np.array(times)
        print(f"Tested {len(times)} samples.")
        print(f"Mean Latency: {np.mean(times)*1000:.2f} ms")
        print(f"Median Latency: {np.median(times)*1000:.2f} ms")
        print(f"Throughput: {1/np.mean(times):.2f} samples/sec")

        if torch.cuda.is_available():
            mem = torch.cuda.max_memory_allocated() / (1024**2)
            print(f"Max GPU memory used: {mem:.2f} MB")
            torch.cuda.reset_peak_memory_stats()

In [9]:
path = f'/home/ntmduy/GraphAE/data/mas/WS_300/step_300/9/edge_features_3/normalized_except_id/raw/sort_seperated'

train_graphs, test_graphs, graphs_names = load_and_split_graphs(path, exclude=[], train_ratio=0.8, seed=2025)

In [25]:
# Load data
warmup_loader = DataLoader([train_graphs[0]], batch_size=1, shuffle=False)

for graph in test_graphs:
    graph.edge_attr = torch.tensor(graph.edge_attr[:, 0].reshape(-1, 1), dtype=torch.float32)

test_loader = DataLoader(test_graphs, batch_size=1, shuffle=False)

dummy_input = next(iter(warmup_loader)).to('cuda')
dummy_input.edge_attr = torch.tensor(dummy_input.edge_attr[:, 0].reshape(-1, 1), dtype=torch.float32)

  graph.edge_attr = torch.tensor(graph.edge_attr[:, 0].reshape(-1, 1), dtype=torch.float32)
  dummy_input.edge_attr = torch.tensor(dummy_input.edge_attr[:, 0].reshape(-1, 1), dtype=torch.float32)


In [17]:
NUM_CLASSES = len(np.unique([g.y for g in train_graphs]))
NUM_CLASSES

6

In [None]:
model = GAE_CLS_Link_NODE_Cosine_SupCon_2(num_features=9, embedding_size=32, projection_emb=128, activate='gelu', layer_type='gatv2', num_layers=2, directed=False, id_dim=1, num_classes = NUM_CLASSES, linear_node=True, num_id_embeddings=2048, attn_head=1)
measure_time_gpu(dummy_input, model, 'cuda', rep=1000)

(2.946622718334198, 0.3170323983073103)

In [28]:
incep = SupIncepResnet(num_classes=NUM_CLASSES)
measure_time_gpu(torch.randn(1, 1, 29, 29, dtype=torch.float).cuda(), incep, 'cuda', rep=1000, none_gnn=True)

(2.271770525932312, 0.13737580240076167)

In [42]:
supcon = SupCEResNet(name='resnet18', num_classes=NUM_CLASSES)
measure_time_gpu(torch.randn(1, 1, 29, 29, dtype=torch.float).cuda(), supcon, 'cuda', rep=1000, none_gnn=True)

(3.411258074402809, 0.875621319126651)

In [29]:
analyze_performance(model, test_loader, device='cuda', num_warmup=10, num_repeats=100, num_classes=NUM_CLASSES)

Warming up for 10 iterations...
Starting timed inference over 100 iterations...


  return pred_out, reasons, torch.tensor(d_svdd), adj_loss, rec_loss, graph_emb_sup


Tested 100 samples.
Mean Latency: 8.73 ms
Median Latency: 8.67 ms
Throughput: 114.49 samples/sec
Max GPU memory used: 144.56 MB


In [51]:
def profile_gnn_model(model, example_data, device=None, repeat=10):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device).eval()
    example_data = example_data.to(device)
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True,
                 profile_memory=True,
                 with_flops=True) as prof:
        with torch.no_grad():
            for _ in range(repeat):
                with record_function("model_inference"):
                    model(example_data)
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))
    print(prof.key_averages().table(sort_by="flops", row_limit=15))
    # For total FLOPs (note: double to get MACs for real-valued ops)
    total_flops = sum([item.flops for item in prof.key_averages() if hasattr(item, 'flops')])
    print(f'Estimated total FLOPs: {clever_format([total_flops], "%.3f")}')
    print(f'Estimated total MACs: {(total_flops/2)} {clever_format([total_flops/2], "%.3f")}')

In [52]:
profile_gnn_model(supcon.to(device='cpu'), torch.randn(1, 1, 29, 29, dtype=torch.float).cuda())

STAGE:2025-06-06 10:46:42 3976940:3976940 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2025-06-06 10:46:42 3976940:3976940 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-06-06 10:46:42 3976940:3976940 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::cudnn_convolution        17.58%      11.820ms        25.63%      17.234ms      86.170us       1.408ms        74.58%       1.408ms       7.040us           0 

In [53]:
dummy_input = dummy_input.to(device='cpu')
model = model.to(device='cpu')
profile_gnn_model(model, dummy_input)

STAGE:2025-06-06 10:46:49 3976940:3976940 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2025-06-06 10:46:49 3976940:3976940 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-06-06 10:46:49 3976940:3976940 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_         4.12%       9.892ms        59.50%     142.954ms     137.192us     681.000us        39.14%     681.000us       0.654us           0 

In [54]:
macs, params = thopprofile(incep, inputs=(torch.randn(1, 1, 29, 29, dtype=torch.float).cuda(),), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}") # thop reports MACs
print(f"Params: {params}")

MACs: 97.191M
Params: 1.695M


In [43]:
macs, params = thopprofile(supcon, inputs=(torch.randn(1, 1, 29, 29, dtype=torch.float).cuda(),), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}") # thop reports MACs
print(f"Params: {params}")

MACs: 32.561M
Params: 700.662K


In [39]:
macs, params = thopprofile(model, inputs=(dummy_input,), verbose=False)

print(f"MACs: {macs}") # thop reports MACs
print(f"Params: {params}")

MACs: 0.0
Params: 0.0


In [None]:
summary(model.cuda(), dummy_input.cuda(), device='cuda')

'+-----------------------------------+------------------------------+-------------------------------+----------+\n| Layer                             | Input Shape                  | Output Shape                  | #Param   |\n|-----------------------------------+------------------------------+-------------------------------+----------|\n| GAE_CLS_Link_NODE_Cosine_SupCon_2 | [72, 72]                     | [72, 32]                      | 74,715   |\n| ├─(id_embedding)Embedding         | --                           | --                            | 2,048    |\n| ├─(encoder)Graph_Encoder_Norm     | [72, 9], [2, 136], [136, 1]  | [72, 32], [72, 32], [2304, 2] | 3,154    |\n| │    └─(bn)BatchNorm1d            | --                           | --                            | 18       |\n| │    └─(convs)ModuleList          | --                           | --                            | 2,944    |\n| │    │    └─(0)GATv2Conv          | [72, 9], [2, 136], [136, 1]  | [72, 32]                  