In [48]:
import pandas as pd 
from torch_geometric.data import Dataset, Data # Installieren
import os
import numpy as np
import torch # Installieren
import scipy.sparse # Installieren
from tqdm import tqdm

import gseapy as gp
import time

In [2]:
# Fetch the 'KEGG_2021_Human' gene-set library via gseapy.
# min_size / max_size presumably bound the number of genes per set — confirm in the gseapy docs.
kegg_gmt = gp.parser.get_library(
    'KEGG_2021_Human',
    organism='Human',
    min_size=3,
    max_size=2000,
)

In [30]:
# --- Input locations and table layout -------------------------------------
# The data directory can be overridden with the TUGDA_DATA_DIR environment
# variable; the default is the location the notebook was originally run from.
DATA_DIR = os.environ.get("TUGDA_DATA_DIR", "/Users/tm03/Desktop/TUGDA_1/data")
# The first N_GENE_COLUMNS columns of the GDSC table are used as genes,
# every remaining column is used as a drug.
N_GENE_COLUMNS = 1780

# --- Gene / drug / pathway lists and the expression matrix ----------------
gdsc_dataset = pd.read_csv(os.path.join(DATA_DIR, "GDSCDA_fpkm_AUC_all_drugs.zip"), index_col=0)
gene_list = gdsc_dataset.columns[:N_GENE_COLUMNS]
drug_list = gdsc_dataset.columns[N_GENE_COLUMNS:]
pathway_list = list(kegg_gmt.keys())
expression_data = gdsc_dataset.iloc[:, :N_GENE_COLUMNS].sort_index()

# --- Combine the three cross-validation fold files ------------------------
# The per-fold names expression_1..3 are kept for backward compatibility with
# other cells, although the concatenation is used as the response/label table.
expression_1, expression_2, expression_3 = (
    pd.read_csv(os.path.join(DATA_DIR, f"cl_y_test_o_k{fold}.csv"), index_col=0)
    for fold in (1, 2, 3)
)
response_data = pd.concat(
    [expression_1, expression_2, expression_3], axis=0, ignore_index=False
).sort_index()
labels_df = response_data

In [69]:
def _keep_first_per_index_label(frame):
    """Return `frame` without rows whose index label already occurred earlier."""
    is_repeat = frame.index.duplicated(keep='first')
    return frame.loc[~is_repeat]


expression_data = _keep_first_per_index_label(expression_data)
labels_df = _keep_first_per_index_label(labels_df)

In [91]:
class DrugNetworkDataset(Dataset):
    """Dataset with one graph per (drug, cell line) pair.

    The graph topology comes from a per-drug adjacency matrix stored as
    ``<root>/drug_matrices_csv/<drug>_matrix.csv``. Node features are the
    expression value of the node's gene in the given cell line (0.0 for every
    node that is not in ``gene_list``), and the label is
    ``labels_df.loc[cell_line, drug]``.

    Args:
        root: Base directory; adjacency CSVs are read from ``root/drug_matrices_csv``.
        drug_list: Drug names; one adjacency CSV is expected per drug.
        gene_list: Names of nodes that receive an expression value as feature.
        pathway_list: Pathway node names (kept for reference; such nodes get 0.0).
        labels_df: DataFrame indexed by cell line with one column per drug.
        expression_data: DataFrame indexed by cell line with one column per gene.
        transform, pre_transform: Forwarded to ``torch_geometric.data.Dataset``.
    """

    def __init__(self, root, drug_list, gene_list, pathway_list, labels_df, expression_data, transform=None, pre_transform=None):
        import warnings  # local import: only used for the one-off consistency warning

        self.drug_list = drug_list
        self.gene_list = gene_list
        self.pathway_list = pathway_list

        # Remove duplicated index labels so that .loc[...] yields a single row / scalar.
        self.expression_data = expression_data[~expression_data.index.duplicated(keep='first')]
        self.labels_df = labels_df[~labels_df.index.duplicated(keep='first')]

        # O(1) membership test independent of the container type passed as gene_list.
        self._gene_set = set(gene_list)
        # drug -> (node names, edge_index); avoids re-parsing the CSV for every sample.
        self._graph_cache = {}

        # Only keep pairs for which both an expression row and a label column exist;
        # any other pair would raise KeyError as soon as it is accessed.
        cells_with_expression = set(self.expression_data.index)
        drugs_with_labels = set(self.labels_df.columns)
        cell_lines = [c for c in self.labels_df.index if c in cells_with_expression]
        drugs = [d for d in self.drug_list if d in drugs_with_labels]
        skipped_cells = len(self.labels_df.index) - len(cell_lines)
        skipped_drugs = len(self.drug_list) - len(drugs)
        if skipped_cells or skipped_drugs:
            warnings.warn(
                f"DrugNetworkDataset: skipping {skipped_cells} cell line(s) without "
                f"expression data and {skipped_drugs} drug(s) without a label column."
            )

        # All combinations of drug + cell line (drug-major order).
        self.samples = [
            (drug, cell_line)
            for drug in drugs
            for cell_line in cell_lines
        ]

        self.custom_raw_dir = os.path.join(root, 'drug_matrices_csv')

        super(DrugNetworkDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """File names of the per-drug adjacency matrices inside ``raw_dir``."""
        return [f"{drug}_matrix.csv" for drug in self.drug_list]

    @property
    def raw_dir(self):
        """Directory holding the adjacency CSVs (replaces the base class's raw directory)."""
        return self.custom_raw_dir

    @property
    def processed_file_names(self):
        # Dummy entry so the base class has something to report; nothing is
        # ever written because this class defines no process() step.
        return ['placeholder.pt']

    def download(self):
        pass  # Not needed: the adjacency CSVs are expected to exist locally.

    def _load_adjacency_matrix(self, drug):
        """Read the adjacency matrix of `drug` from its CSV file as a DataFrame."""
        csv_path = os.path.join(self.raw_dir, f"{drug}_matrix.csv")
        df = pd.read_csv(csv_path, index_col=0)
        return df

    def _get_graph_structure(self, drug):
        """Return ``(nodes, edge_index)`` for `drug`, parsing its CSV only once.

        Node names are taken from the matrix columns.
        NOTE(review): assumes the CSV rows are in the same order as the columns — confirm.
        """
        cached = self._graph_cache.get(drug)
        if cached is None:
            adj_df = self._load_adjacency_matrix(drug)
            nodes = adj_df.columns.tolist()
            # Every non-zero entry (row, col) of the matrix becomes one edge row -> col.
            edge_index = torch.tensor(np.array(np.nonzero(adj_df.values)), dtype=torch.long)
            cached = (nodes, edge_index)
            self._graph_cache[drug] = cached
        return cached

    def _get_node_features(self, nodes, cell_line):
        """Build the ``[num_nodes, 1]`` float feature matrix for one cell line.

        Nodes listed in ``gene_list`` get their expression value in `cell_line`;
        all other nodes (pathways or unknown names) get 0.0.
        """
        # Look the row up once instead of once per node.
        expression_row = self.expression_data.loc[cell_line]
        x = [
            [float(expression_row[node])] if node in self._gene_set else [0.0]
            for node in nodes
        ]
        return torch.tensor(x, dtype=torch.float)

    def get_node_name_by_index(self, idx, node_index):
        """Return the name of node `node_index` in the graph of sample `idx`.

        Raises:
            IndexError: If `node_index` is not smaller than the number of nodes.
        """
        drug, _ = self.samples[idx]
        nodes, _ = self._get_graph_structure(drug)

        if node_index < len(nodes):
            return nodes[node_index]
        else:
            raise IndexError(f"Node index {node_index} out of range for this graph.")

    def len(self):
        """Number of (drug, cell line) samples."""
        return len(self.samples)

    def get(self, idx):
        """Assemble the ``Data`` object for sample `idx`."""
        drug, cell_line = self.samples[idx]

        # 1./2. Node list and edge index of the drug's network (cached per drug)
        nodes, edge_index = self._get_graph_structure(drug)

        # 3. Node features
        x = self._get_node_features(nodes, cell_line)

        # 4. Label
        y = torch.tensor([self.labels_df.loc[cell_line, drug]], dtype=torch.float)

        # 5. Create the Data object; hand out copies so that in-place changes made
        #    by transforms cannot leak into the per-drug cache.
        data = Data(x=x, edge_index=edge_index.clone(), y=y, drug=drug, cell_line=cell_line)
        data.nodes = list(nodes)

        return data

In [92]:
# Collect the constructor inputs first so they can be inspected or reused,
# then build the dataset from them.
dataset_inputs = {
    "root": "./results/Network/",
    "drug_list": drug_list,
    "gene_list": gene_list,
    "pathway_list": pathway_list,
    "labels_df": labels_df,
    "expression_data": expression_data,
}
dataset = DrugNetworkDataset(**dataset_inputs)

In [122]:
# Inspect the first (drug, cell line) sample of the dataset.
data = dataset[0]

print("Drug Name:", data.drug)
print("Cell Line Name:", data.cell_line)

# Edges are printed transposed, i.e. one (source, target) pair per row.
print("Edge Index:")
print(data.edge_index.t())
print(data.edge_index.t().shape)

print("\nNode Features (x):")
print(data.x)
print(data.x.shape)

nodes = data.nodes

# Show at most the first 10 nodes; a fixed range(10) would raise IndexError
# for graphs with fewer than 10 nodes.
for i in range(min(10, len(nodes))):
    node_name = nodes[i]
    feature_value = data.x[i].item()
    print(f"Index {i}: {node_name} → Feature: {feature_value:.4f}")

print("\nLabel (y):")
print(data.y)
print(data.y.shape)




Drug Name: Camptothecin
Cell Line Name: 22RV1
Edge Index:
tensor([[   0,    0],
        [   0,  182],
        [   0,  195],
        ...,
        [1337, 1337],
        [1338,  418],
        [1338, 1338]])
torch.Size([22505, 2])

Node Features (x):
tensor([[-0.8910],
        [ 0.5178],
        [-0.5985],
        ...,
        [-0.1359],
        [ 0.6368],
        [ 0.6063]])
torch.Size([1339, 1])
Index 0: ABL1 → Feature: -0.8910
Index 1: ACVR1B → Feature: 0.5178
Index 2: ADORA1 → Feature: -0.5985
Index 3: AR → Feature: 4.2872
Index 4: ATF4 → Feature: -0.1862
Index 5: ATM → Feature: -0.3401
Index 6: ATR → Feature: -0.1959
Index 7: AURKA → Feature: -0.3216
Index 8: BAX → Feature: -0.6138
Index 9: BBC3 → Feature: 0.4687

Label (y):
tensor([-3.1426])
torch.Size([1])


In [96]:
# Sanity check: print the ABL1 expression column so it can be compared by eye
# with the node feature printed for ABL1 in the sample above.
abl1_expression = expression_data[["ABL1"]]
print(abl1_expression)


              ABL1
22RV1    -0.890968
23132-87 -1.569799
42-MG-BA  0.536505
5637      0.873951
639-V    -0.708590
...            ...
YAPC      0.712502
YH-13     1.926564
YT       -2.203333
ZR-75-30 -0.708590
huH-1    -0.890968

[804 rows x 1 columns]


# Encoder

Komponenten: 
- GNNEncoder: Drug-Specific Network Graphen 
- Ziel: tugda_mtl verarbeitet bisher Rohdaten aus Expressionsprofilen → stattdessen die drug-spezifischen Graphen aus DrugNetworkDataset verwenden und den GNN-Encoder als Feature-Extractor einbauen

- Wir ersetzen:
  - die bisherigen nn.Linear-Schichten durch den GNNEncoder,
  - die Eingabe X_train durch drug_data (vom Typ Data).
- Der Output des Encoders wird an die nachfolgenden Schichten (S, A etc.) weitergeleitet.

In [119]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class GNNEncoder(nn.Module):
    """Graph encoder mapping a (batched) graph to one embedding vector per graph.

    Three identical stages are applied in sequence:
    GATConv (multi-head) -> Linear merging the heads -> ReLU -> TopKPooling.
    After every stage a graph-level readout (global max pool concatenated with
    global mean pool, size ``2 * embedding_size``) is taken. The three readouts
    are summed and projected by two linear layers to ``output_dim``.

    Args:
        feature_size: Number of input features per node.
        embedding_size: Hidden node embedding size used in all stages.
        output_dim: Size of the returned graph embedding.
        heads: Number of attention heads of every GATConv layer.
        attn_dropout: Dropout passed to every GATConv layer.
        final_dropout: Dropout probability applied before the last linear layer.
    """

    def __init__(self, feature_size=1, embedding_size=512, output_dim=256,
                 heads=3, attn_dropout=0.6, final_dropout=0.8):
        super(GNNEncoder, self).__init__()
        self.final_dropout = final_dropout

        # GNN layers. Attribute names and creation order are kept as before so
        # existing checkpoints and seeded initialisation stay compatible.
        # NOTE(review): bn1/bn2/bn3 are registered but never applied in forward();
        # decide whether they should be used or removed.
        self.conv1 = GATConv(feature_size, embedding_size, heads=heads, dropout=attn_dropout)
        self.head_transform1 = nn.Linear(embedding_size * heads, embedding_size)
        self.bn1 = nn.BatchNorm1d(embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)

        self.conv2 = GATConv(embedding_size, embedding_size, heads=heads, dropout=attn_dropout)
        self.head_transform2 = nn.Linear(embedding_size * heads, embedding_size)
        self.bn2 = nn.BatchNorm1d(embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)

        self.conv3 = GATConv(embedding_size, embedding_size, heads=heads, dropout=attn_dropout)
        self.head_transform3 = nn.Linear(embedding_size * heads, embedding_size)
        self.bn3 = nn.BatchNorm1d(embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Final projection (input is the max+mean readout, hence 2 * embedding_size)
        self.linear1 = nn.Linear(embedding_size * 2, 512)
        self.linear2 = nn.Linear(512, output_dim)

    def _stage(self, conv, head_transform, pool, x, edge_index, batch):
        """Run one conv -> head merge -> ReLU -> pooling stage and take its readout."""
        x = conv(x, edge_index)
        x = F.relu(head_transform(x))
        # TopKPooling returns (x, edge_index, edge_attr, batch, perm, score);
        # no edge attributes are used, so None is passed in that slot.
        x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
        readout = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
        return x, edge_index, batch, readout

    def forward(self, data):
        """Encode `data` (needs ``x`` and ``edge_index``; ``batch`` is optional).

        Returns:
            Tensor of shape ``[num_graphs, output_dim]``.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, "batch", None)
        if batch is None:
            # A single, un-batched graph: assign every node to graph 0 explicitly
            # instead of relying on how pooling/readout treat batch=None.
            batch = edge_index.new_zeros(x.size(0))

        x, edge_index, batch, x1 = self._stage(self.conv1, self.head_transform1, self.pool1, x, edge_index, batch)
        x, edge_index, batch, x2 = self._stage(self.conv2, self.head_transform2, self.pool2, x, edge_index, batch)
        x, edge_index, batch, x3 = self._stage(self.conv3, self.head_transform3, self.pool3, x, edge_index, batch)

        # Combine the readouts of all three stages
        x = x1 + x2 + x3

        # Final layers
        x = self.linear1(x).relu()
        x = F.dropout(x, p=self.final_dropout, training=self.training)
        graph_embedding = self.linear2(x)

        return graph_embedding

In [120]:
# Fetch a real sample: indexing the dataset returns a Data object
data = dataset[0]

# Instantiate the encoder in this cell so the notebook also runs top to bottom
# on a fresh kernel (previously `encoder` was only available as leftover state).
encoder = GNNEncoder()

# Call the encoder in eval mode (dropout disabled) and without gradient tracking
encoder.eval()
with torch.no_grad():
    embedding = encoder(data)

print("Echtes Embedding Shape:", embedding.shape)
print(embedding)

# Output: the embedding of the graph in latent space, one row per graph

Echtes Embedding Shape: torch.Size([1, 256])
tensor([[ 0.0416, -0.0418, -0.0064, -0.0336,  0.0083, -0.0032,  0.0064, -0.0348,
          0.0170,  0.0100, -0.0263,  0.0370, -0.0149,  0.0343,  0.0127, -0.0105,
         -0.0248, -0.0263,  0.0216, -0.0502,  0.0193,  0.0387, -0.0115, -0.0072,
          0.0065,  0.0292,  0.0005, -0.0315, -0.0124, -0.0299,  0.0396, -0.0292,
         -0.0184, -0.0171,  0.0421, -0.0211,  0.0296,  0.0060, -0.0079,  0.0204,
         -0.0182,  0.0210, -0.0380, -0.0179,  0.0075,  0.0362,  0.0327, -0.0203,
         -0.0386,  0.0243, -0.0028,  0.0519,  0.0480, -0.0109,  0.0186, -0.0027,
         -0.0100, -0.0132,  0.0316, -0.0089,  0.0231, -0.0022, -0.0098, -0.0124,
          0.0508, -0.0189,  0.0070,  0.0277,  0.0263,  0.0277, -0.0169,  0.0167,
         -0.0045,  0.0102, -0.0247, -0.0313,  0.0104,  0.0041, -0.0474,  0.0233,
         -0.0481,  0.0526, -0.0039, -0.0140, -0.0386, -0.0350, -0.0024,  0.0520,
          0.0157, -0.0334,  0.0287, -0.0038, -0.0131, -0.0184,  

# Cell Encoder

In [106]:
class Cell_EmbedNet(nn.Module):
    """MLP that embeds a cell line's expression vector.

    Architecture: ``Linear(fc_in_dim, fc_hid_dim[0]) -> ReLU``, then one
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` block for every consecutive
    pair of sizes in ``fc_hid_dim``, then ``Linear(fc_hid_dim[-1], embed_dim)``.

    Args:
        fc_in_dim: Size of the input vector.
        fc_hid_dim: Sequence of hidden layer sizes (at least one entry).
        embed_dim: Size of the returned embedding.
        dropout: Dropout probability used inside the hidden blocks.
    """

    def __init__(self, fc_in_dim=1780, fc_hid_dim=(512, 512), embed_dim=256, dropout=0.5):
        super(Cell_EmbedNet, self).__init__()
        # Store a private copy: the default is immutable and later changes to a
        # caller-owned list can no longer alter this module's configuration.
        self.fc_hid_dim = list(fc_hid_dim)
        self.fc = nn.Linear(fc_in_dim, self.fc_hid_dim[0])
        self.act = nn.ReLU()
        self.dropout = dropout
        self.classifier = nn.ModuleList()

        for input_size, output_size in zip(self.fc_hid_dim, self.fc_hid_dim[1:]):
            self.classifier.append(
                nn.Sequential(
                    nn.Linear(input_size, output_size),
                    nn.BatchNorm1d(output_size),
                    self.act,
                    nn.Dropout(p=self.dropout)
                )
            )
        self.fc2 = nn.Linear(self.fc_hid_dim[-1], embed_dim)

        # Weight init: Xavier-uniform for the Linear layer of every hidden block
        # (the input and output layers keep PyTorch's default initialisation).
        for layer in self.classifier:
            nn.init.xavier_uniform_(layer[0].weight)

    def forward(self, x):
        """Map ``x`` of shape ``[batch, fc_in_dim]`` to ``[batch, embed_dim]``."""
        x = F.relu(self.fc(x))
        for fc in self.classifier:
            x = fc(x)
        x = self.fc2(x)
        return x

In [107]:
class DrugCellPredictor(nn.Module):
    """Two-layer MLP head that scores a concatenated (drug, cell) embedding pair."""

    def __init__(self, input_dim=512, hidden_dim=256):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1),  # single output value per pair, e.g. an AUC value
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, drug_emb, cell_emb):
        """Concatenate both embeddings along the feature axis and return ``[batch, 1]``."""
        # The joint width must equal input_dim (512 with the default sizes).
        joint = torch.cat((drug_emb, cell_emb), dim=1)
        return self.mlp(joint)

In [None]:
class DrugResponseModel(nn.Module):
    """End-to-end model: graph encoder + cell encoder + prediction head.

    Args:
        cell_input_dim: Length of the cell expression vector fed to the cell encoder.
        embed_dim: Embedding size produced by both encoders. The predictor input
            is derived as ``2 * embed_dim`` so the three parts cannot get out of sync.
    """

    def __init__(self, cell_input_dim=1780, embed_dim=256):
        super().__init__()
        self.gnn_encoder = GNNEncoder(output_dim=embed_dim)
        self.cell_encoder = Cell_EmbedNet(fc_in_dim=cell_input_dim, embed_dim=embed_dim)
        # The head receives the drug and cell embeddings concatenated.
        self.predictor = DrugCellPredictor(input_dim=2 * embed_dim)

    def forward(self, drug_data, cell_expression):
        """Return the predictor output for the given graph batch and expression batch.

        `drug_data` is passed to the graph encoder and `cell_expression` to the
        cell encoder; both must describe the same number of samples.
        """
        drug_emb = self.gnn_encoder(drug_data)
        cell_emb = self.cell_encoder(cell_expression)
        prediction = self.predictor(drug_emb, cell_emb)
        return prediction