In [1]:
# Re-import edited project modules (src.data, src.models, ...) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from torch_geometric.utils import trim_to_layer # add trim?
from torch.utils.data import DataLoader,TensorDataset
from src.data import DatasetLoader,GraphParamBuilder # build dataset
from torch_geometric.loader import NeighborLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch
import os

# --- Dataset ---------------------------------------------------------------
# Node table (presumably one row per genomic bin) and edge table (one row per
# bin1-bin2 contact) for the chosen p-value threshold and bin resolution.
loader = DatasetLoader()
node_df, edge_df = loader.load(p_value='0_001', resolution='1kb')

target = node_df['expression_level']  # regression target
mask = node_df['gene_in_bin']  # presumably marks bins that contain a gene -- verify in src.data
input_features = node_df.loc[:, 'clamp':'psq']  # label slice: columns 'clamp' through 'psq', inclusive
seed = 42

builder = GraphParamBuilder(
    node_df=node_df,
    edge_df=edge_df,
    target=target,
    mask=mask,
    input_features=input_features,
    seed=seed,
)
tensors = builder.convert_to_tensors()

# The split masks are treated as positional: one entry per node_df row, in row
# order (the selections below already relied on this). Give them node_df's
# index so boolean selection aligns row-for-row even when node_df does not use
# a default RangeIndex -- a bare pd.Series gets a fresh 0..n-1 index and pandas
# would then align on labels instead. Passing the index also makes pandas raise
# if a mask's length differs from node_df. .cpu() keeps .numpy() valid if the
# builder ever returns tensors on an accelerator.
train_mask = pd.Series(tensors['train_mask'].cpu().numpy(), index=node_df.index)
test_mask = pd.Series(tensors['test_mask'].cpu().numpy(), index=node_df.index)
assert not (train_mask & test_mask).any(), "train and test masks overlap"

X_train = input_features[train_mask]
X_test = input_features[test_mask]
y_train = target[train_mask]
y_test = target[test_mask]

In [3]:
# First rows of the edge table: one row per contact between bin1 and bin2.
# NOTE(review): the leading "Unnamed: 0" column looks like a CSV index that was
# saved and re-read as data; presumably the loader should read with
# index_col=0 -- verify in src.data.
edge_df.head(n=5)

Unnamed: 0,chr1,start1,end1,chr2,start2,end2,contactCount,p-value,q-value,bias1,bias2,ExpCC,loop_size,bin1,bin2,p-value_transformed,loop_size_transformed
0,chr2L,5000,6000,chr2L,90000,91000,3,0.000876,1.0,0.53777,1.185121,0.181888,84.0,5,90,3.057712,1.924279
1,chr2L,5000,6000,chr2L,513000,514000,2,0.000307,1.0,0.53777,1.398324,0.024992,507.0,5,513,3.512654,2.705008
2,chr2L,5000,6000,chr2L,6521000,6522000,1,0.000952,1.0,0.53777,0.597264,0.000952,6515.0,5,6521,3.02153,3.813914
3,chr2L,5000,6000,chr2L,7354000,7355000,1,0.000946,1.0,0.53777,0.67552,0.000946,7348.0,5,7354,3.024241,3.866169
4,chr2L,5000,6000,chr2L,12729000,12730000,1,0.000701,1.0,0.53777,0.902604,0.000701,12723.0,5,12729,3.154235,4.10459


In [4]:
# Summary statistics (count, mean, std, min, quartiles, max) of the numeric edge columns.
# NOTE(review): in the head() rows loop_size equals bin2 - bin1 - 1 (e.g. 90 - 5 - 1 = 84),
# so adjacent bins give loop_size 0.0; loop_size_transformed (~log10(loop_size))
# then bottoms out at -8.0, presumably via a 1e-8 floor. Confirm this outlier is
# intended before using the column as a feature.
edge_df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,start1,end1,start2,end2,contactCount,p-value,q-value,bias1,bias2,ExpCC,loop_size,bin1,bin2,p-value_transformed,loop_size_transformed
count,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0,480644.0
mean,10314800.0,10315800.0,17248860.0,17249860.0,2.124899,0.000474,0.992108,1.061166,1.059414,0.38335,6933.062951,64294.23889,71228.301841,3.525585,3.496185
std,7040185.0,7040185.0,7192359.0,7192359.0,5.396204,0.000292,0.082847,0.322093,0.324198,2.645173,6328.934783,37310.367827,37998.520945,1.298602,0.822594
min,1000.0,2000.0,8000.0,9000.0,1.0,0.0,0.0,0.500016,0.500016,2.9e-05,0.0,5.0,8.0,3.0,-8.0
25%,4524000.0,4525000.0,12112000.0,12113000.0,1.0,0.000215,1.0,0.80594,0.803502,0.000307,1881.0,32821.0,39661.0,3.135471,3.274389
50%,9234000.0,9235000.0,17935000.0,17936000.0,1.0,0.000446,1.0,1.039888,1.034701,0.000623,4822.0,63545.0,70862.0,3.350656,3.683227
75%,15348000.0,15349000.0,22462000.0,22463000.0,1.0,0.000732,1.0,1.280151,1.278789,0.000939,10908.0,94032.0,103275.0,3.666971,4.037745
max,32070000.0,32071000.0,32072000.0,32073000.0,250.0,0.001,1.0,1.999974,1.999927,56.29553,31836.0,133751.0,133872.0,71.373651,4.502918


In [5]:
import torch

# --- Toy example: per-node, per-head top-k pruning of attention logits -------
# Six edges point into three receiver nodes; dst[e] is the receiver of edge e.
dst = torch.tensor([0, 0, 0, 1, 1, 2])   # target nodes (receivers)
H = 2  # number of attention heads
E = len(dst)  # number of edges

# Raw (pre-softmax) attention logits: one row per edge, one column per head.
alpha = torch.tensor([
    [0.2, 0.1],   # edge 0 -> node 0
    [0.9, 0.3],   # edge 1 -> node 0
    [1.2, 0.01],  # edge 2 -> node 0
    [0.5, 0.7],   # edge 3 -> node 1
    [0.4, 0.8],   # edge 4 -> node 1
    [0.1, 0.9],   # edge 5 -> node 2
])
assert alpha.shape == (E, H)

print("Raw alpha (E x H):\n", alpha)


def topk_keep_mask(scores: torch.Tensor, receivers: torch.Tensor, k: int) -> torch.Tensor:
    """Mark, for every receiver node and every head, its ``k`` highest-scoring incoming edges.

    Args:
        scores: attention logits of shape (num_edges, num_heads).
        receivers: receiver node of each edge, shape (num_edges,).
        k: edges to keep per (node, head); a node with fewer than ``k``
            incoming edges keeps all of them.

    Returns:
        Bool tensor shaped like ``scores``; True where the edge is kept.

    Raises:
        ValueError: if ``k < 1``. Every edge would be dropped, and a softmax
            over groups that are entirely ``-inf`` is undefined (NaN).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    keep = torch.zeros_like(scores, dtype=torch.bool)
    # The receivers and their incoming-edge indices do not depend on the head,
    # so look them up once per node rather than once per (node, head).
    for node in torch.unique(receivers).tolist():
        edge_ids = (receivers == node).nonzero(as_tuple=False).view(-1)
        k_local = min(k, edge_ids.numel())
        for head in range(scores.size(1)):
            _, top_pos = torch.topk(scores[edge_ids, head], k=k_local)
            keep[edge_ids[top_pos], head] = True
    return keep


k = 2  # incoming edges kept per receiver node and head
keep_mask = topk_keep_mask(alpha, dst, k)
print("\nKeep mask (E x H):\n", keep_mask)

# Dropped edges become -inf so they contribute exp(-inf) = 0 after the softmax.
alpha_pruned = alpha.masked_fill(~keep_mask, float('-inf'))
print("\nPruned alpha (E x H):\n", alpha_pruned)

# Normalise the pruned logits within each receiver's group of incoming edges
# (grouped by dst), independently per head column. Pruned entries are -inf and
# end up with weight 0; with k >= 1 every receiver keeps at least one edge, so
# no group is entirely -inf.
from torch_geometric.utils import softmax
alpha_sm = softmax(src=alpha_pruned, index=dst)
print("\nSoftmax Applied to alpha (E x H):\n", alpha_sm)

Raw alpha (E x H):
 tensor([[0.2000, 0.1000],
        [0.9000, 0.3000],
        [1.2000, 0.0100],
        [0.5000, 0.7000],
        [0.4000, 0.8000],
        [0.1000, 0.9000]])

Keep mask (E x H):
 tensor([[False,  True],
        [ True,  True],
        [ True, False],
        [ True,  True],
        [ True,  True],
        [ True,  True]])

Pruned alpha (E x H):
 tensor([[  -inf, 0.1000],
        [0.9000, 0.3000],
        [1.2000,   -inf],
        [0.5000, 0.7000],
        [0.4000, 0.8000],
        [0.1000, 0.9000]])

Softmax Applied to alpha (E x H):
 tensor([[0.0000, 0.4502],
        [0.4256, 0.5498],
        [0.5744, 0.0000],
        [0.5250, 0.4750],
        [0.4750, 0.5250],
        [1.0000, 1.0000]])


In [25]:
from src.models import TopKGATConv
def make_topk_gat(use_topk: bool, k: int = 2) -> TopKGATConv:
    """Build the small TopKGATConv used on the toy graph below.

    The torch RNG is re-seeded with the notebook-wide ``seed`` before each
    construction. This makes the weight initialisation reproducible across
    kernel restarts. Assuming ``use_topk`` does not change which parameters
    are created (TODO confirm in src.models), both variants also start from
    identical weights, so their outputs are directly comparable.

    Args:
        use_topk: whether the layer prunes attention to the top-``k`` incoming edges.
        k: presumably the number of incoming edges kept per target node and head
            when ``use_topk`` is True.

    Returns:
        A layer mapping 2 input channels to 4 output channels per head, with 2
        concatenated heads and 1-dimensional edge features.
    """
    torch.manual_seed(seed)
    return TopKGATConv(
        in_channels=2,
        out_channels=4,
        heads=2,
        concat=True,
        use_topk=use_topk,
        k=k,
        edge_dim=1,
    )


layerk = make_topk_gat(use_topk=True, k=2)  # keep the top 2 neighbours per node per head
layer = make_topk_gat(use_topk=False, k=2)  # baseline with full attention; k presumably unused here

In [26]:
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse

# --- Tiny graph ---
# Path graph node0 <-> node1 <-> node2, with one directed edge per direction.
# Entry [i, j] of edge_counts is the scalar feature (e.g. a contact count) of
# edge i -> j; zero means "no edge".
edge_counts = torch.tensor([
    [0, 1, 0],
    [2, 0, 2],
    [0, 1, 0],
])
adj = (edge_counts > 0).long()  # binary adjacency with the same sparsity pattern

# dense_to_sparse returns the non-zero positions in row-major order together
# with the matrix values at those positions. edge_attr is therefore aligned with
# edge_index by construction: 0->1: 1.0, 1->0: 2.0, 1->2: 2.0, 2->1: 1.0.
edge_index, edge_weight = dense_to_sparse(edge_counts)
edge_attr = edge_weight.to(torch.float).unsqueeze(-1)  # shape (num_edges, 1) to match edge_dim=1

# Node features: one 2-d vector per node.
x = torch.tensor([
    [1.0, 0.0],  # node 0
    [0.0, 1.0],  # node 1
    [1.0, 1.0],  # node 2
])

# Node 1 has two incoming edges (from nodes 0 and 2); nodes 0 and 2 have one each.

# --- Model ---
# Baseline without top-k pruning; uncomment to compare against out2:
# out = layer(x, edge_index, edge_attr)
# NOTE(review): layerk uses k=2 and no node here has more than 2 incoming edges,
# so top-k pruning should keep every edge. out2 is then expected to equal an
# unpruned layer with the same weights -- assuming TopKGATConv adds no self-loops
# (its debug output scores only 4 edges). Use k=1 or a denser graph to actually
# exercise the pruning path.
out2 = layerk(x, edge_index, edge_attr)

tensor([[[-0.6464,  0.1877,  0.0896, -0.2850],
         [ 0.7310, -0.4835, -0.3028,  0.3192]],

        [[ 0.3003,  0.3848, -0.4977, -0.4591],
         [ 0.1963,  0.4662,  0.5596,  0.0278]],

        [[-0.3461,  0.5725, -0.4081, -0.7441],
         [ 0.9274, -0.0173,  0.2568,  0.3470]]], grad_fn=<ViewBackward0>) tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]]) tensor([[1.],
        [2.],
        [2.],
        [1.]])
Alpha (tensor([[-0.1556,  0.4697],
        [-0.0802,  0.0899],
        [-0.2357,  0.5596]], grad_fn=<SumBackward1>), tensor([[-0.2454,  0.4520],
        [ 0.4682,  0.0157],
        [ 0.2227,  0.4677]], grad_fn=<SumBackward1>))
Entering Edge Update
tensor([[ 0.3126,  0.4854],
        [-0.3256,  0.5419],
        [ 0.1426,  0.5576],
        [ 0.2324,  0.5753]], grad_fn=<AddBackward0>)
tensor([[1.],
        [2.],
        [2.],
        [1.]])
Alpha Edge tensor([[0.1078, 0.3753],
        [0.2156, 0.7506],
        [0.2156, 0.7506],
        [0.1078, 0.3753]], grad_fn=<SumBackward1>)
Alp

In [27]:
# Layer output: 3 nodes x 8 features (2 heads x 4 channels, concatenated).
# Rows 0 and 2 are identical. Both nodes have node 1 as their only in-neighbour
# (over edges with the same attribute 2.0), which presumably explains it.
out2

tensor([[ 0.3003,  0.3848, -0.4977, -0.4591,  0.1963,  0.4662,  0.5596,  0.0278],
        [-0.5023,  0.3724, -0.1493, -0.5053,  0.8336, -0.2399, -0.0104,  0.3337],
        [ 0.3003,  0.3848, -0.4977, -0.4591,  0.1963,  0.4662,  0.5596,  0.0278]],
       grad_fn=<AddBackward0>)