In [37]:
import torch
from torch_geometric.data import Data
import numpy as np
import scanpy as sc
import pandas as pd
from graph_construction import construct_graph
from tcn_autoencoder import train_and_extract_features
from data_splitting import create_train_val_tf, generate_train_val_kfold
from transformer_model import train_model, test_model
from test_split import split_edges, split_tfs


In [38]:
file_name = "mESC"
adata = sc.read("./" +file_name + ".h5ad")

filtered_refnet = adata.uns["grn"]
fea_df = pd.DataFrame(adata.uns["gpt_emb"])
fea_df["Gene"] = adata.var_names
X_norm = adata.X.T
tfs = np.unique(filtered_refnet["Gene1"])

In [39]:
data_orig, tfs_index, le = construct_graph(filtered_refnet, adata.var_names,fea_df, tfs)
data_orig

Data(x=[500, 1536], edge_index=[2, 2347])

In [40]:

tcauto_model, features = train_and_extract_features(X_norm,learning_rate=0.001, weight_decay=1e-4)
print(features.shape)




Epoch [1/200], Train Loss: 0.0245
Epoch [2/200], Train Loss: 0.0242
Epoch [3/200], Train Loss: 0.0241
Epoch [4/200], Train Loss: 0.0240
Epoch [5/200], Train Loss: 0.0239
Epoch [6/200], Train Loss: 0.0239
Epoch [7/200], Train Loss: 0.0239
Epoch [8/200], Train Loss: 0.0238
Epoch [9/200], Train Loss: 0.0237
Epoch [10/200], Train Loss: 0.0235
Epoch [11/200], Train Loss: 0.0233
Epoch [12/200], Train Loss: 0.0228
Epoch [13/200], Train Loss: 0.0219
Epoch [14/200], Train Loss: 0.0192
Epoch [15/200], Train Loss: 0.0120
Epoch [16/200], Train Loss: 0.0068
Epoch [17/200], Train Loss: 0.0056
Epoch [18/200], Train Loss: 0.0051
Epoch [19/200], Train Loss: 0.0046
Epoch [20/200], Train Loss: 0.0044
Epoch [21/200], Train Loss: 0.0043
Epoch [22/200], Train Loss: 0.0041
Epoch [23/200], Train Loss: 0.0040
Epoch [24/200], Train Loss: 0.0038
Epoch [25/200], Train Loss: 0.0036
Epoch [26/200], Train Loss: 0.0035
Epoch [27/200], Train Loss: 0.0034
Epoch [28/200], Train Loss: 0.0033
Epoch [29/200], Train Loss: 0

In [41]:
reduced_features_df = pd.DataFrame(features.cpu().numpy())
reduced_features_df["Gene"] = adata.var_names
reduced_features_df

# create features of nodes
node_features_1 = fea_df.set_index('Gene').reindex(le.classes_).fillna(0).values
node_features_2 = reduced_features_df.set_index('Gene').reindex(le.classes_).fillna(0).values


x = torch.tensor(node_features_1, dtype=torch.float)
x_additional = torch.tensor(node_features_2, dtype=torch.float)
# build graph object
data = Data(x=x, x_additional=x_additional, edge_index=data_orig.edge_index)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
data

Data(x=[500, 1536], edge_index=[2, 2347], x_additional=[500, 36])

In [42]:
train_val_sets_kfold = generate_train_val_kfold(data,10)


In [43]:
train_val_sets_kfold

[(Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 704], edge_label=[704]),
  Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 3286], edge_label=[3286])),
 (Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 704], edge_label=[704]),
  Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 3286], edge_label=[3286])),
 (Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 704], edge_label=[704]),
  Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 3286], edge_label=[3286])),
 (Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 704], edge_label=[704]),
  Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 3286], edge_label=[3286])),
 (Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 704], edge_label=[704]),
  Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2, 3286], edge_label=[3286])),
 (Data(x=[500, 1536], x_additional=[500, 36], edge_label_index=[2

In [44]:


train_val_sets = train_val_sets_kfold

results = []

for fold, (train_data, val_data) in enumerate(train_val_sets):
    print(f'Fold {fold + 1}:')
    
    # training
    model = train_model(
        train_data, 
        
        hidden_channels=64, 
        num_heads=16, 
        dropout=0.5, 
        lr=0.000005, 
        weight_decay=1e-3, 
        num_epochs=200, 
        print_interval=10
    )
    
    results.append({
        'fold': fold + 1,
       
        'model': model
    })
    
    print('--------------')



model_list = [r['model'] for r in results]
test_auc = [test_model(i, val_data) for i in model_list]
test_auc

Fold 1:
--------------
Fold 2:
--------------
Fold 3:
--------------
Fold 4:
--------------
Fold 5:
--------------
Fold 6:
--------------
Fold 7:
--------------
Fold 8:
--------------
Fold 9:
--------------
Fold 10:
--------------


[(0.9666951940880183, 0.9637874247670541),
 (0.9605497388406883, 0.9691244398393895),
 (0.9555319653212806, 0.9596006444278072),
 (0.9534135787739697, 0.9591827813750138),
 (0.9703399510623356, 0.974674461343778),
 (0.9776808524388558, 0.9805952992629237),
 (0.9436331563778506, 0.9492533476276084),
 (0.9554744009407463, 0.9545422047240011),
 (0.9701376971790318, 0.9723343541589413),
 (0.9783275368297054, 0.98069124476884)]