In [2]:
from torch_geometric.data import HeteroData
import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv, Linear
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.loader import NeighborLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, GATConv
from torch.nn import Linear

In [3]:
# # Define the Heterogeneous GAT Model
# class HeteroGAT(torch.nn.Module):
#     def __init__(self, metadata, hidden_dim, out_dim, heads=4):
#         super().__init__()

#         self.convs = HeteroConv({
#             edge_type: GATConv((-1, -1), hidden_dim, heads=heads, concat=True, add_self_loops=False)
#             for edge_type in metadata[1]  # Edge types
#         }, aggr="sum")

#         self.final_conv = HeteroConv({
#             edge_type: GATConv((-1, -1), out_dim, heads=1, concat=False, add_self_loops=False)
#             for edge_type in metadata[1]
#         }, aggr="sum")

#         # self.edge_predictor = torch.nn.ModuleDict({
#         #     edge_type: Linear(out_dim * 2, 1) for edge_type in metadata[1]
#         # })
#         self.edge_predictor = torch.nn.ModuleDict({
#             '__'.join(edge_type): Linear(out_dim * 2, 1) for edge_type in metadata[1]
#         })
#     # def forward(self, x_dict, edge_index_dict):
#     #     x_dict = {key: F.elu(self.convs[key](x_dict[key], edge_index_dict[key])) 
#     #               for key in self.convs.keys()}
        
#     #     x_dict = {key: self.final_conv[key](x_dict[key], edge_index_dict[key])
#     #               for key in self.final_conv.keys()}
        
#     #     return x_dict  # Return final node embeddings
#     def forward(self, x_dict, edge_index_dict):
#         x_dict = self.convs(x_dict, edge_index_dict)
#         x_dict = {key: F.elu(x) for key, x in x_dict.items()}

#         x_dict = self.final_conv(x_dict, edge_index_dict)

#         return x_dict  # final node embeddings


#     # def predict_links(self, x_dict, edge_index_dict):
#     #     scores = {}
#     #     for edge_type, edge_index in edge_index_dict.items():
#     #         src, dst = edge_index
#     #         edge_feat = torch.cat([x_dict[edge_type[0]][src], x_dict[edge_type[-1]][dst]], dim=-1)
#     #         scores[edge_type] = torch.sigmoid(self.edge_predictor[edge_type](edge_feat)).squeeze()
#     #     return scores  # Return confidence scores, not boolean values
#     def predict_links(self, x_dict, edge_index_dict):
#         scores = {}
#         for edge_type, edge_index in edge_index_dict.items():
#             src, dst = edge_index
#             edge_feat = torch.cat([
#                 x_dict[edge_type[0]][src],
#                 x_dict[edge_type[-1]][dst]
#             ], dim=-1)

#             edge_type_str = '__'.join(edge_type)  # 🔥 fix here
#             scores[edge_type] = torch.sigmoid(self.edge_predictor[edge_type_str](edge_feat)).squeeze()

#         return scores



import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import HeteroConv, GATConv

class HeteroGAT(torch.nn.Module):
    def __init__(self, metadata, input_dims, hidden_dim, heads=4):
        super().__init__()
        self.metadata = metadata
        self.hidden_dim = hidden_dim
        self.heads = heads

        # 1. Per-node-type input projection
        self.node_proj = torch.nn.ModuleDict({
            ntype: Linear(in_dim, hidden_dim)
            for ntype, in_dim in input_dims.items()
        })

        # 2. Shared GAT layer across all edge types
        self.convs = HeteroConv({
            edge_type: GATConv((hidden_dim, hidden_dim), hidden_dim, heads=heads, concat=True, add_self_loops=False)
            for edge_type in metadata[1]
        }, aggr="sum")

        # 3. Per-edge-type link predictor
        self.edge_predictor = torch.nn.ModuleDict({
            '__'.join(edge_type): Linear(hidden_dim * heads * 2, 1)
            for edge_type in metadata[1]
        })

    def forward(self, x_dict, edge_index_dict):
        # 1. Project node features to shared hidden size
        x_dict = {
            ntype: self.node_proj[ntype](x.float())
            for ntype, x in x_dict.items()
        }

        # 2. Apply GATConv
        x_dict = self.convs(x_dict, edge_index_dict)
        x_dict = {k: F.elu(v) for k, v in x_dict.items()}
        return x_dict

    def predict_links(self, x_dict, edge_label_index_dict):
        scores = {}
        for edge_type, edge_index in edge_label_index_dict.items():
            src_type, _, dst_type = edge_type

            # Skip if embeddings were not returned
            if src_type not in x_dict or dst_type not in x_dict:
                print(f"⚠️ Skipping {edge_type} — embeddings missing for {src_type} or {dst_type}")
                continue

            src, dst = edge_index
            src_emb = x_dict[src_type][src]
            dst_emb = x_dict[dst_type][dst]
            edge_input = torch.cat([src_emb, dst_emb], dim=-1)
            edge_type_str = '__'.join(edge_type)

            logits = self.edge_predictor[edge_type_str](edge_input)
            scores[edge_type] = torch.sigmoid(logits).squeeze()
        return scores


In [12]:
from torch_geometric.utils import negative_sampling
from torch_geometric.data import HeteroData

def custom_link_split(data: HeteroData, neg_ratio=1, min_test_edges=10):
    train_data = HeteroData()
    test_data = HeteroData()

    # Copy node features
    for ntype in data.node_types:
        train_data[ntype].x = data[ntype].x
        train_data[ntype].num_nodes = data[ntype].num_nodes
        test_data[ntype].x = data[ntype].x
        test_data[ntype].num_nodes = data[ntype].num_nodes

    for edge_type in data.edge_types:
        edge_index = data[edge_type].edge_index
        edge_label = data[edge_type].edge_label

        # Check if test_mask is valid
        if "test_mask" in data[edge_type] and data[edge_type].test_mask.sum() > 0:
            test_mask = data[edge_type].test_mask
        else:
            # Fallback random test split
            num_edges = edge_index.size(1)
            perm = torch.randperm(num_edges)
            test_size = max(min_test_edges, int(0.1 * num_edges))
            test_mask = torch.zeros(num_edges, dtype=torch.bool)
            test_mask[perm[:test_size]] = True

        # Split
        pos_train_idx = (~test_mask).nonzero(as_tuple=False).view(-1)
        pos_test_idx = test_mask.nonzero(as_tuple=False).view(-1)

        if len(pos_train_idx) == 0 or len(pos_test_idx) == 0:
            print(f"⚠️ Skipping {edge_type} due to no train or test positives.")
            continue

        train_pos_edges = edge_index[:, pos_train_idx]
        test_pos_edges = edge_index[:, pos_test_idx]

        train_pos_labels = edge_label[pos_train_idx]
        test_pos_labels = edge_label[pos_test_idx]

        # Negative sampling
        num_train_neg = int(train_pos_edges.size(1) * neg_ratio)
        num_test_neg = int(test_pos_edges.size(1) * neg_ratio)

        neg_train_edges = negative_sampling(
            edge_index=train_pos_edges,
            num_nodes=(data[edge_type[0]].num_nodes, data[edge_type[2]].num_nodes),
            num_neg_samples=num_train_neg
        )
        neg_test_edges = negative_sampling(
            edge_index=test_pos_edges,
            num_nodes=(data[edge_type[0]].num_nodes, data[edge_type[2]].num_nodes),
            num_neg_samples=num_test_neg
        )

        train_data[edge_type].edge_index = torch.cat([train_pos_edges, neg_train_edges], dim=1)
        train_data[edge_type].edge_label = torch.cat([train_pos_labels, torch.zeros(num_train_neg, dtype=torch.long)], dim=0)
        train_data[edge_type].edge_label_index = torch.cat([train_pos_edges, neg_train_edges], dim=1)

        test_data[edge_type].edge_index = torch.cat([test_pos_edges, neg_test_edges], dim=1)
        test_data[edge_type].edge_label = torch.cat([test_pos_labels, torch.zeros(num_test_neg, dtype=torch.long)], dim=0)
        test_data[edge_type].edge_label_index = torch.cat([test_pos_edges, neg_test_edges], dim=1)

    return train_data, test_data


In [13]:
data = torch.load('./data/sw/sw_with_violations.pt', weights_only=False)
for edge_type in data.edge_types:
    if 'edge_label' in data[edge_type]:
        edge_label = data[edge_type]['edge_label']
        count_ones = (edge_label == 1).sum().item()
        count_zeros = (edge_label == 0).sum().item()
        print(f"{edge_type}: count of 1s in edge_label = {count_ones} and count of 0s = {count_zeros}")

train_data, test_data = custom_link_split(data)


('Character', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 173
('Character', 'BELONGS_TO', 'Faction'): count of 1s in edge_label = 0 and count of 0s = 105
('Character', 'DIED', 'Film'): count of 1s in edge_label = 0 and count of 0s = 20
('Character', 'HOMEWORLD', 'Planet'): count of 1s in edge_label = 0 and count of 0s = 87
('Character', 'KILLED', 'Character'): count of 1s in edge_label = 0 and count of 0s = 6
('Character', 'OF', 'Species'): count of 1s in edge_label = 0 and count of 0s = 82
('Character', 'PILOT', 'Starship'): count of 1s in edge_label = 0 and count of 0s = 31
('Character', 'PILOT', 'Vehicle'): count of 1s in edge_label = 0 and count of 0s = 13
('Planet', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 34
('Species', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 76
('Species', 'HOMEWORLD', 'Planet'): count of 1s in edge_label = 0 and count of 0s = 36
('Starship', 'APPEARED_IN', 'Film'): count 

In [10]:
train_data

HeteroData(
  Film={
    x=[7, 1],
    num_nodes=7,
  },
  Character={
    x=[87, 385],
    num_nodes=87,
  },
  Planet={
    x=[61, 0],
    num_nodes=61,
  },
  Species={
    x=[37, 384],
    num_nodes=37,
  },
  Vehicle={
    x=[39, 384],
    num_nodes=39,
  },
  Starship={
    x=[37, 0],
    num_nodes=37,
  },
  Faction={
    x=[22, 384],
    num_nodes=22,
  },
  (Character, APPEARED_IN, Film)={
    edge_index=[2, 221],
    edge_label=[221],
    edge_label_index=[2, 221],
  },
  (Character, BELONGS_TO, Faction)={
    edge_index=[2, 123],
    edge_label=[123],
    edge_label_index=[2, 123],
  },
  (Character, DIED, Film)={
    edge_index=[2, 19],
    edge_label=[19],
    edge_label_index=[2, 19],
  },
  (Character, HOMEWORLD, Planet)={
    edge_index=[2, 100],
    edge_label=[100],
    edge_label_index=[2, 100],
  },
  (Character, OF, Species)={
    edge_index=[2, 91],
    edge_label=[91],
    edge_label_index=[2, 91],
  },
  (Character, PILOT, Starship)={
    edge_index=[2, 23],
   

In [7]:
test_data

HeteroData(
  Film={
    x=[7, 1],
    num_nodes=7,
  },
  Character={
    x=[87, 385],
    num_nodes=87,
  },
  Planet={
    x=[61, 0],
    num_nodes=61,
  },
  Species={
    x=[37, 384],
    num_nodes=37,
  },
  Vehicle={
    x=[39, 384],
    num_nodes=39,
  },
  Starship={
    x=[37, 0],
    num_nodes=37,
  },
  Faction={
    x=[22, 384],
    num_nodes=22,
  },
  (Character, APPEARED_IN, Film)={
    edge_index=[2, 22],
    edge_label=[22],
    edge_label_index=[2, 22],
  },
  (Character, BELONGS_TO, Faction)={
    edge_index=[2, 13],
    edge_label=[13],
    edge_label_index=[2, 13],
  },
  (Character, DIED, Film)={
    edge_index=[2, 13],
    edge_label=[13],
    edge_label_index=[2, 13],
  },
  (Character, HOMEWORLD, Planet)={
    edge_index=[2, 13],
    edge_label=[13],
    edge_label_index=[2, 13],
  },
  (Character, OF, Species)={
    edge_index=[2, 13],
    edge_label=[13],
    edge_label_index=[2, 13],
  },
  (Character, PILOT, Starship)={
    edge_index=[2, 13],
    edge_lab

In [8]:
for edge_type in data.edge_types:
    if 'edge_label' in data[edge_type]:
        edge_label = data[edge_type]['edge_label']
        count_ones = (edge_label == 1).sum().item()
        count_zeros = (edge_label == 0).sum().item()
        print(f"{edge_type}: count of 1s in edge_label = {count_ones} and count of 0s = {count_zeros}")


('Character', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 173
('Character', 'BELONGS_TO', 'Faction'): count of 1s in edge_label = 0 and count of 0s = 105
('Character', 'DIED', 'Film'): count of 1s in edge_label = 0 and count of 0s = 20
('Character', 'HOMEWORLD', 'Planet'): count of 1s in edge_label = 0 and count of 0s = 87
('Character', 'KILLED', 'Character'): count of 1s in edge_label = 0 and count of 0s = 6
('Character', 'OF', 'Species'): count of 1s in edge_label = 0 and count of 0s = 82
('Character', 'PILOT', 'Starship'): count of 1s in edge_label = 0 and count of 0s = 31
('Character', 'PILOT', 'Vehicle'): count of 1s in edge_label = 0 and count of 0s = 13
('Planet', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 34
('Species', 'APPEARED_IN', 'Film'): count of 1s in edge_label = 0 and count of 0s = 76
('Species', 'HOMEWORLD', 'Planet'): count of 1s in edge_label = 0 and count of 0s = 36
('Starship', 'APPEARED_IN', 'Film'): count 

In [13]:
# Initialize Model
metadata = (data.node_types, data.edge_types)
model = HeteroGAT(metadata, hidden_dim=32, out_dim=16)



In [15]:
from torch.nn import BCEWithLogitsLoss

def train(model, data, optimizer, epochs=50, device='cpu'):
    model.to(device)
    data = data.to(device)
    for edge_type in data.edge_types:
        if data[edge_type].edge_index.dtype != torch.long:
            print(f"⚠️ Casting edge_index for {edge_type} to torch.long")
            data[edge_type].edge_index = data[edge_type].edge_index.long()
        if 'edge_label_index' in data[edge_type]:
            if data[edge_type].edge_label_index.dtype != torch.long:
                print(f"⚠️ Casting edge_label_index for {edge_type} to torch.long")
                data[edge_type].edge_label_index = data[edge_type].edge_label_index.long()
    loss_fn = BCEWithLogitsLoss()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        x_dict = model(data.x_dict, data.edge_index_dict)
        x_dict = {k: v.float() for k, v in x_dict.items()}  # force dtype consistency

        total_loss = 0
        skipped = 0

        for edge_type in data.edge_types:
            if 'edge_label' not in data[edge_type]:
                continue

            if edge_type not in data.edge_label_index_dict:
                continue

            edge_label_index = data[edge_type].edge_label_index
            # Inside for edge_type in data.edge_types
            edge_label = data[edge_type].edge_label.float()

            # Check label class imbalance
            if edge_label.sum() == 0 or edge_label.sum() == edge_label.numel():
                print(f"⚠️ Skipping {edge_type} due to only one class in edge_label")
                continue

            src, dst = edge_label_index
            src_emb = x_dict[src_type][src]
            dst_emb = x_dict[dst_type][dst]

            # Check for NaNs
            if src_emb.isnan().any() or dst_emb.isnan().any():
                print(f"❌ NaNs in embeddings for {edge_type}")
                continue

            edge_input = torch.cat([src_emb, dst_emb], dim=-1)

            if edge_input.isnan().any() or edge_input.isinf().any():
                print(f"❌ NaNs/Infs in edge_input for {edge_type}")
                continue

            edge_type_str = '__'.join(edge_type)
            logits = model.edge_predictor[edge_type_str](edge_input).squeeze()

            loss = loss_fn(logits, edge_label)
            total_loss += loss

        total_loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1:02d} | Loss: {total_loss.item():.4f} | Skipped edge types: {skipped}")


In [20]:
for edge_type in test_data.edge_types:
    if 'edge_label_index' in test_data[edge_type]:
        if test_data[edge_type].edge_label_index.dtype != torch.long:
            test_data[edge_type].edge_label_index = test_data[edge_type].edge_label_index.long()


In [12]:
for ntype in train_data.x_dict:
    print(f"{ntype}: {train_data[ntype].x.shape}, dtype={train_data[ntype].x.dtype}")


Film: torch.Size([7, 1]), dtype=torch.float32
Character: torch.Size([87, 385]), dtype=torch.float32
Planet: torch.Size([61, 0]), dtype=torch.float32
Species: torch.Size([37, 384]), dtype=torch.float32
Vehicle: torch.Size([39, 384]), dtype=torch.float32
Starship: torch.Size([37, 0]), dtype=torch.float32
Faction: torch.Size([22, 384]), dtype=torch.float32


In [13]:
for edge_type in data.edge_types:
    src_type, _, dst_type = edge_type
    edge_index = data[edge_type].edge_index

    src_max = data[src_type].x.shape[0] - 1
    dst_max = data[dst_type].x.shape[0] - 1

    if edge_index[0].max() > src_max:
        print(f"🚨 Out-of-bounds src index in {edge_type}")
    if edge_index[1].max() > dst_max:
        print(f"🚨 Out-of-bounds dst index in {edge_type}")



In [16]:
input_dims = {ntype: train_data[ntype].x.size(1) for ntype in train_data.node_types}
for edge_type in test_data.edge_types:
    edge_index = test_data[edge_type].edge_index
    if edge_index.dtype != torch.long:
        print(f"⚠️ Casting {edge_type} edge_index to long")
        test_data[edge_type].edge_index = edge_index.long()
for edge_type in test_data.edge_types:
    if 'edge_label_index' in test_data[edge_type]:
        if test_data[edge_type].edge_label_index.dtype != torch.long:
            test_data[edge_type].edge_label_index = test_data[edge_type].edge_label_index.long()


model = HeteroGAT(metadata=train_data.metadata(), input_dims=input_dims, hidden_dim=64)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train(model, train_data, optimizer, epochs=50, device='cpu')  # or 'cuda' if applicable


⚠️ Skipping ('Character', 'APPEARED_IN', 'Film') due to only one class in edge_label
⚠️ Skipping ('Character', 'BELONGS_TO', 'Faction') due to only one class in edge_label
⚠️ Skipping ('Character', 'DIED', 'Film') due to only one class in edge_label
⚠️ Skipping ('Character', 'HOMEWORLD', 'Planet') due to only one class in edge_label
⚠️ Skipping ('Character', 'OF', 'Species') due to only one class in edge_label
⚠️ Skipping ('Character', 'PILOT', 'Starship') due to only one class in edge_label
⚠️ Skipping ('Character', 'PILOT', 'Vehicle') due to only one class in edge_label
⚠️ Skipping ('Planet', 'APPEARED_IN', 'Film') due to only one class in edge_label
⚠️ Skipping ('Species', 'APPEARED_IN', 'Film') due to only one class in edge_label
⚠️ Skipping ('Species', 'HOMEWORLD', 'Planet') due to only one class in edge_label
⚠️ Skipping ('Starship', 'APPEARED_IN', 'Film') due to only one class in edge_label
⚠️ Skipping ('Starship', 'BELONGS_TO', 'Faction') due to only one class in edge_label
⚠️ 

AttributeError: 'int' object has no attribute 'backward'

In [75]:
test_data

HeteroData(
  edge_label_index_dict={
    (Account, Repay, Loan)=[2, 0],
    (Account, Transfer, Account)=[2, 0],
    (Account, Withdraw, Account)=[2, 0],
    (Company, Apply, Loan)=[2, 0],
    (Company, Guarantee, Company)=[2, 0],
    (Company, Invest, Company)=[2, 0],
    (Company, Own, Account)=[2, 0],
    (Loan, Deposit, Account)=[2, 0],
    (Medium, SignIn, Account)=[2, 0],
    (Person, Apply, Loan)=[2, 0],
    (Person, Guarantee, Person)=[2, 0],
    (Person, Invest, Company)=[2, 0],
    (Person, Own, Account)=[2, 0],
  },
  Account={
    x=[20409, 2],
    num_nodes=20409,
  },
  Company={
    x=[3892, 1],
    num_nodes=3892,
  },
  Loan={
    x=[13833, 1],
    num_nodes=13833,
  },
  Medium={
    x=[9699, 1],
    num_nodes=9699,
  },
  Person={
    x=[8771, 384],
    num_nodes=8771,
  },
  (Account, Repay, Loan)={
    edge_index=[2, 0],
    edge_label=[0],
    edge_label_index=[2, 0],
  },
  (Account, Transfer, Account)={
    edge_index=[2, 0],
    edge_label=[0],
    edge_label_

In [26]:
model.eval()
x_dict = model(test_data.x_dict, test_data.edge_index_dict)
for edge_type in test_data.edge_types:
    if 'edge_index' in test_data[edge_type] and 'edge_label' in test_data[edge_type]:
        test_data[edge_type].edge_label_index = test_data[edge_type].edge_index


scores = model.predict_links(x_dict, test_data.edge_label_index_dict)
scores

⚠️ Skipping ('Character', 'APPEARED_IN', 'Film') — embeddings missing for Character or Film
⚠️ Skipping ('Character', 'BELONGS_TO', 'Faction') — embeddings missing for Character or Faction
⚠️ Skipping ('Character', 'DIED', 'Film') — embeddings missing for Character or Film
⚠️ Skipping ('Character', 'HOMEWORLD', 'Planet') — embeddings missing for Character or Planet
⚠️ Skipping ('Character', 'OF', 'Species') — embeddings missing for Character or Species
⚠️ Skipping ('Character', 'PILOT', 'Starship') — embeddings missing for Character or Starship
⚠️ Skipping ('Character', 'PILOT', 'Vehicle') — embeddings missing for Character or Vehicle


{('Planet',
  'APPEARED_IN',
  'Film'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Species',
  'APPEARED_IN',
  'Film'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Species',
  'HOMEWORLD',
  'Planet'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Starship',
  'APPEARED_IN',
  'Film'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Starship',
  'BELONGS_TO',
  'Faction'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Vehicle',
  'APPEARED_IN',
  'Film'): tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        grad_fn=<SqueezeBackward0>),
 ('Vehicle',
  'BELONGS_TO',
  'Faction'): tensor([nan, nan, nan, nan, nan, nan, nan,

In [21]:
type(scores)

dict

In [22]:
x_dict

{'Loan': tensor([[ 0.0472, -0.0254, -0.0076,  ..., -0.0315,  0.0160,  0.0019],
         [-0.1229, -0.3722,  0.1223,  ...,  0.3089,  0.2167, -0.0874],
         [ 0.0472, -0.0254, -0.0076,  ..., -0.0315,  0.0160,  0.0019],
         ...,
         [ 0.1447, -0.0338,  0.1699,  ...,  0.2122,  0.2029, -0.1371],
         [-0.1229, -0.3722,  0.1223,  ...,  0.3089,  0.2167, -0.0874],
         [ 0.0472, -0.0254, -0.0076,  ..., -0.0315,  0.0160,  0.0019]],
        grad_fn=<EluBackward0>),
 'Account': tensor([[-0.0197,  0.4517, -0.1674,  ..., -0.0354,  0.3186,  0.5537],
         [ 0.4111,  0.5238,  0.0223,  ...,  0.0992, -0.0566,  0.1245],
         [ 0.0306,  0.0035,  0.0276,  ..., -0.0334,  0.0408,  0.0720],
         ...,
         [ 0.0306,  0.0035,  0.0276,  ..., -0.0334,  0.0408,  0.0720],
         [ 0.0306,  0.0035,  0.0276,  ..., -0.0334,  0.0408,  0.0720],
         [ 0.0306,  0.0035,  0.0276,  ..., -0.0334,  0.0408,  0.0720]],
        grad_fn=<EluBackward0>),
 'Company': tensor([[-0.3749,  0.

In [78]:
from sklearn.metrics import accuracy_score

for edge_type, preds in scores.items():
    preds_binary = (preds > 0.5).long()  # 1 if prob > 0.5, else 0
    labels = test_data[edge_type].edge_label.long()

    acc = accuracy_score(labels.cpu(), preds_binary.cpu())
    print(f"✅ Accuracy for {edge_type}: {acc:.4f}")


✅ Accuracy for ('Account', 'Repay', 'Loan'): 1.0000
✅ Accuracy for ('Account', 'Transfer', 'Account'): 0.9870
✅ Accuracy for ('Account', 'Withdraw', 'Account'): 1.0000
✅ Accuracy for ('Company', 'Apply', 'Loan'): 0.9496
✅ Accuracy for ('Company', 'Guarantee', 'Company'): 1.0000
✅ Accuracy for ('Company', 'Invest', 'Company'): 1.0000
✅ Accuracy for ('Company', 'Own', 'Account'): 1.0000
✅ Accuracy for ('Loan', 'Deposit', 'Account'): 1.0000
✅ Accuracy for ('Person', 'Apply', 'Loan'): 1.0000
✅ Accuracy for ('Person', 'Guarantee', 'Person'): 0.7954
✅ Accuracy for ('Person', 'Invest', 'Company'): 1.0000
✅ Accuracy for ('Person', 'Own', 'Account'): 1.0000


HeteroData(
  Account={
    x=[20409, 2],
    node_id=[20409],
  },
  Company={
    x=[3892, 1],
    node_id=[3892],
  },
  Loan={
    x=[13833, 1],
    node_id=[13833],
  },
  Medium={
    x=[9699, 1],
    node_id=[9699],
  },
  Person={
    x=[8771, 384],
    node_id=[8771],
  },
  (Account, Repay, Loan)={
    edge_index=[2, 26556],
    edge_label=[26556],
    test_mask=[26556],
  },
  (Account, Transfer, Account)={
    edge_index=[2, 79909],
    edge_label=[79909],
    test_mask=[79909],
  },
  (Account, Withdraw, Account)={
    edge_index=[2, 90241],
    edge_label=[90241],
    test_mask=[90241],
  },
  (Company, Apply, Loan)={
    edge_index=[2, 4560],
    edge_label=[4560],
    test_mask=[4560],
  },
  (Company, Guarantee, Company)={
    edge_index=[2, 1820],
    edge_label=[1820],
    test_mask=[1820],
  },
  (Company, Invest, Company)={
    edge_index=[2, 6798],
    edge_label=[6798],
    test_mask=[6798],
  },
  (Company, Own, Account)={
    edge_index=[2, 6840],
    edge_labe

In [80]:
# Save node embeddings (x_dict) and edge scores (scores)
torch.save(x_dict, './data/finbench/x_dict.pt')
torch.save(scores, './data/finbench/scores.pt')


In [81]:
torch.save(test_data.edge_label_index_dict, "./data/finbench/edge_label_index_dict.pt")
