In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.neighbors import kneighbors_graph
from sklearn.feature_selection import VarianceThreshold
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from lifelines import CoxPHFitter

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
USE_DAAL4PY_SKLEARN variable is deprecated for Intel(R) Extension for Scikit-learn
and will be delete in the 2022.1 release.
Please, use new construction of global patching:
python sklearnex.glob patch_sklearn
Read more: https://intel.github.io/scikit-learn-intelex/global_patching.html


In [2]:
def prefilter_variance(df, top_k=300, time_col="survival_time", event_col="event_status"):
    X = df.drop(columns=[time_col, event_col])
    y = df[[time_col, event_col]]
    variances = X.var().sort_values(ascending=False)
    top_features = variances.head(top_k).index
    X_top = X[top_features]
    return pd.concat([y, X_top], axis=1)

# applying Cox
def cox_feature_selection(df, time_col="survival_time", event_col="event_status", p_threshold=0.05):
    cph = CoxPHFitter(penalizer=0.1)
    cph.fit(df, duration_col=time_col, event_col=event_col)
    selected = cph.summary[cph.summary['p'] < p_threshold].index.tolist()
    return df[selected + [time_col, event_col]]

mirna_df = pd.read_csv("mirna.csv")
rnaseq_df = pd.read_csv("rnaseq.csv")


In [3]:
mirna_labels_raw = mirna_df["Unnamed: 0"]
rnaseq_labels_raw = rnaseq_df["Unnamed: 0"]

mirna_df = prefilter_variance(mirna_df, top_k=300)
rnaseq_df = prefilter_variance(rnaseq_df, top_k=300)

mirna_df = cox_feature_selection(mirna_df)
rnaseq_df = cox_feature_selection(rnaseq_df)

#Supervised labels are back here
mirna_df.insert(0, "Unnamed: 0", mirna_labels_raw)
rnaseq_df.insert(0, "Unnamed: 0", rnaseq_labels_raw)

def preprocess(df):
    le = LabelEncoder()
    labels = le.fit_transform(df["Unnamed: 0"].values)
    features = df.drop(columns=["Unnamed: 0", "survival_time", "event_status"], errors="ignore")
    features = StandardScaler().fit_transform(features)
    return features, labels

mirna_features, mirna_labels = preprocess(mirna_df)
rnaseq_features, rnaseq_labels = preprocess(rnaseq_df)

In [4]:
def build_graph(features, labels, k=5):
    knn = kneighbors_graph(features, k, include_self=False)
    edge_index = torch.tensor(np.vstack(knn.nonzero()), dtype=torch.long)
    x = torch.tensor(features, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.long)
    return Data(x=x, edge_index=edge_index, y=y)

mirna_graph = build_graph(mirna_features, mirna_labels)
rnaseq_graph = build_graph(rnaseq_features, rnaseq_labels)

In [5]:
mirna_graph = build_graph(mirna_features, mirna_labels)
rnaseq_graph = build_graph(rnaseq_features, rnaseq_labels)

In [6]:
class GAT(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, heads=1):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.elu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

In [7]:
def train_gat_embedding(model, data, epochs=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        return model(data)

In [8]:
mirna_model = GAT(mirna_features.shape[1], 32, 16)
rnaseq_model = GAT(rnaseq_features.shape[1], 32, 16)

mirna_embed = train_gat_embedding(mirna_model, mirna_graph)
rnaseq_embed = train_gat_embedding(rnaseq_model, rnaseq_graph)


In [9]:
mirna_norm = F.normalize(mirna_embed, p=2, dim=1)
rnaseq_norm = F.normalize(rnaseq_embed, p=2, dim=1)
combined_features = torch.cat([mirna_norm, rnaseq_norm], dim=1)
combined_labels = torch.tensor(mirna_labels, dtype=torch.long)

In [10]:
X_train, X_test, y_train, y_test = train_test_split(
    combined_features, combined_labels,
    test_size=0.2, random_state=42, stratify=combined_labels
)


In [11]:
class FinalANN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(FinalANN, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.model(x)

final_ann = FinalANN(
    in_dim=X_train.shape[1],
    hidden_dim=64,
    out_dim=len(set(combined_labels.numpy()))
)

optimizer = torch.optim.Adam(final_ann.parameters(), lr=0.01, weight_decay=5e-4)

In [12]:
for epoch in range(100):
    final_ann.train()
    optimizer.zero_grad()
    out = final_ann(X_train)
    loss = F.cross_entropy(out, y_train)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"[ANN] Epoch {epoch+1:03d} - Loss: {loss.item():.4f}")

[ANN] Epoch 010 - Loss: 0.5925
[ANN] Epoch 020 - Loss: 0.2380
[ANN] Epoch 030 - Loss: 0.1282
[ANN] Epoch 040 - Loss: 0.0921
[ANN] Epoch 050 - Loss: 0.0638
[ANN] Epoch 060 - Loss: 0.0625
[ANN] Epoch 070 - Loss: 0.0527
[ANN] Epoch 080 - Loss: 0.0617
[ANN] Epoch 090 - Loss: 0.0560
[ANN] Epoch 100 - Loss: 0.0488


In [13]:
final_ann.eval()
with torch.no_grad():
    preds = final_ann(X_test).argmax(dim=1)
    print("\n--- Classification Report: ANN on Fused GAT Embeddings ---")
    print(classification_report(y_test.numpy(), preds.numpy()))


--- Classification Report: ANN on Fused GAT Embeddings ---
              precision    recall  f1-score   support

           0       0.98      0.96      0.97        51
           1       0.95      0.97      0.96        39
           2       1.00      1.00      1.00        13

    accuracy                           0.97       103
   macro avg       0.98      0.98      0.98       103
weighted avg       0.97      0.97      0.97       103

