In [1]:
import pandas as pd 
from torch_geometric.data import Dataset, Data # Installieren
import os
import numpy as np
import torch # Installieren
import scipy.sparse # Installieren
from tqdm import tqdm

import gseapy as gp
import time

In [2]:
# Load the KEGG_2021_Human gene-set library through gseapy, passing the
# min_size=3 / max_size=2000 filters to the parser.
kegg_gmt = gp.parser.get_library(
    'KEGG_2021_Human',
    organism='Human',
    min_size=3,
    max_size=2000,
)

In [5]:
# Load the GDSC table and split its columns into gene and drug lists.
# The data directory and the gene/drug column boundary are named once here
# instead of being repeated as literals in every statement.
DATA_DIR = '/sybig/home/tmu/TUGDA/data'
N_GENES = 1780  # the first N_GENES columns are used as genes, the remaining ones as drugs

gdsc_dataset = pd.read_csv(f'{DATA_DIR}/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
gene_list = gdsc_dataset.columns[:N_GENES]
drug_list = gdsc_dataset.columns[N_GENES:]
pathway_list = list(kegg_gmt.keys())
expression_data = gdsc_dataset.iloc[:, :N_GENES].sort_index()

# Stack the three per-fold response tables (k1..k3) into one frame, keeping
# the original index labels.
response_1, response_2, response_3 = (
    pd.read_csv(f'{DATA_DIR}/cl_y_test_o_k{fold}.csv', index_col=0)
    for fold in (1, 2, 3)
)
response_data = pd.concat(
    [response_1, response_2, response_3], axis=0, ignore_index=False
).sort_index()

# labels_df is an alias of response_data (same object, no copy).
labels_df = response_data

In [6]:
# Keep only the first row for every index label that occurs more than once.
expr_is_repeat = expression_data.index.duplicated(keep='first')
label_is_repeat = labels_df.index.duplicated(keep='first')
expression_data = expression_data.loc[~expr_is_repeat]
labels_df = labels_df.loc[~label_is_repeat]

In [None]:
class DrugNetworkDataset(Dataset):
    """Graph dataset with one sample per (drug, cell line) pair.

    The graph structure comes from ``<root>/drug_matrices_csv/<drug>_matrix.csv``
    and is shared by all cell lines of a drug; node features are filled in per
    cell line from ``expression_data`` and the label is read from ``labels_df``.

    Args:
        root: Dataset root directory handed to the base ``Dataset``.
        drug_list: Drug names; each must be a column of ``labels_df`` and have
            a matching ``<drug>_matrix.csv`` file.
        gene_list: Node names that are treated as genes.
        pathway_list: Node names that are treated as pathways.
        labels_df: Frame indexed by cell line with one column per drug.
        expression_data: Frame indexed by cell line with one column per gene.
        transform: Forwarded to the base ``Dataset``.
        pre_transform: Forwarded to the base ``Dataset``.
        save_processed: When True (default, previous behaviour) ``process()``
            writes one ``.pt`` file per sample. ``get()`` always rebuilds the
            sample in memory and never reads those files, so passing False
            skips that pass entirely.
    """

    def __init__(self, root, drug_list, gene_list, pathway_list, labels_df, expression_data,
                 transform=None, pre_transform=None, save_processed=True):
        self.drug_list = drug_list
        self.gene_list = gene_list
        self.pathway_list = pathway_list
        self.save_processed = save_processed

        # Drop repeated index labels (first occurrence wins) so that a
        # `.loc[cell_line, ...]` lookup addresses exactly one row.
        self.expression_data = expression_data[~expression_data.index.duplicated(keep='first')]
        self.labels_df = labels_df[~labels_df.index.duplicated(keep='first')]

        # Hash sets for the per-node gene/pathway membership tests.
        self._gene_set = set(gene_list)
        self._pathway_set = set(pathway_list)

        # Every combination of drug and cell line, drug-major.
        self.samples = [
            (drug, cell_line)
            for drug in self.drug_list
            for cell_line in self.labels_df.index
        ]

        self.custom_raw_dir = os.path.join(root, 'drug_matrices_csv')

        # drug -> {"nodes": [...], "edge_index": LongTensor}; filled lazily.
        self.graph_cache = {}

        # All attributes above must exist before this call: the recorded run
        # shows that constructing the dataset triggers process().
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Names of the per-drug adjacency CSV files."""
        return [f"{drug}_matrix.csv" for drug in self.drug_list]

    @property
    def raw_dir(self):
        """Directory holding the adjacency CSVs (``<root>/drug_matrices_csv``)."""
        return self.custom_raw_dir

    @property
    def processed_file_names(self):
        """Names of the per-sample ``.pt`` files; empty when ``save_processed`` is False."""
        if not self.save_processed:
            return []
        return [f'drug_{drug}_cellline_{cell_line}.pt' for drug, cell_line in self.samples]

    def download(self):
        """No-op: the raw CSV files are expected to exist already."""
        pass

    def _load_adjacency_matrix(self, drug):
        """Read the adjacency matrix CSV of ``drug`` into a DataFrame."""
        csv_path = os.path.join(self.raw_dir, f"{drug}_matrix.csv")
        return pd.read_csv(csv_path, index_col=0)

    def _get_node_features(self, nodes, cell_line):
        """Build the ``[len(nodes), 3]`` float feature matrix for one cell line.

        Each row is ``[expr, is_gene, is_pathway]``:
        genes get ``[expression value, 1, 0]``, pathways ``[0, 0, 1]`` and
        unknown nodes ``[0, 0, 0]``. A name that is both a gene and a pathway
        is treated as a gene.

        Raises:
            KeyError: if ``cell_line`` or a gene node is missing from
                ``expression_data``.
        """
        n_nodes = len(nodes)
        is_gene = np.fromiter((node in self._gene_set for node in nodes), dtype=bool, count=n_nodes)
        is_pathway = np.fromiter((node in self._pathway_set for node in nodes), dtype=bool, count=n_nodes)
        is_pathway &= ~is_gene  # gene membership takes precedence

        features = np.zeros((n_nodes, 3), dtype=np.float32)
        if is_gene.any():
            gene_nodes = [node for node, flag in zip(nodes, is_gene) if flag]
            # One row lookup for all genes instead of one scalar lookup per node.
            expr_values = self.expression_data.loc[cell_line, gene_nodes]
            features[is_gene, 0] = expr_values.to_numpy(dtype=np.float32)
        features[is_gene, 1] = 1.0
        features[is_pathway, 2] = 1.0
        return torch.from_numpy(features)

    def get_node_name_by_index(self, idx, node_index):
        """Return the node name at ``node_index`` in the graph of sample ``idx``.

        Raises:
            IndexError: if ``node_index`` is outside the graph's node list.
        """
        drug, _cell_line = self.samples[idx]
        # The cached graph already carries the node list, so the CSV is not
        # parsed again on every call.
        nodes = self._load_graph(drug)["nodes"]
        try:
            return nodes[node_index]
        except IndexError:
            raise IndexError(f"Node index {node_index} out of range for this graph.") from None

    def _load_graph(self, drug):
        """Return ``{"nodes", "edge_index"}`` for ``drug``, reading the CSV only once."""
        cached = self.graph_cache.get(drug)
        if cached is not None:
            return cached

        adj_df = self._load_adjacency_matrix(drug)
        nodes = adj_df.columns.tolist()

        # Every non-zero matrix entry (row, col) becomes one edge.
        edge_index = torch.tensor(np.array(np.nonzero(adj_df.values)), dtype=torch.long)

        graph_data = {
            "nodes": nodes,
            "edge_index": edge_index
        }
        self.graph_cache[drug] = graph_data
        return graph_data

    def len(self):
        """Number of (drug, cell line) samples."""
        return len(self.samples)

    def get(self, idx):
        """Build the ``Data`` object of sample ``idx`` in memory."""
        drug, cell_line = self.samples[idx]

        # 1. Graph structure of the drug (cached after the first load).
        graph_data = self._load_graph(drug)
        nodes = graph_data["nodes"]
        edge_index = graph_data["edge_index"]

        # 2. Node features depend on the cell line.
        x = self._get_node_features(nodes, cell_line)

        # 3. Label: the labels_df entry for this cell line and drug.
        y = torch.tensor([self.labels_df.loc[cell_line, drug]], dtype=torch.float)

        # 4. Assemble the sample; node names are attached for later lookup.
        data = Data(x=x, edge_index=edge_index, y=y, drug=drug, cell_line=cell_line)
        data.nodes = nodes

        return data

    def process(self):
        """Write every sample to ``processed_dir`` unless ``save_processed`` is False.

        Files that already exist are left untouched. ``get()`` does not read
        these files.
        """
        if not self.save_processed:
            return

        print(f"Processing dataset: {len(self.samples)} graphs to save")
        for idx, (drug, cell_line) in tqdm(enumerate(self.samples), total=len(self.samples), desc="Processing Graphs"):
            filename = f'drug_{drug}_cellline_{cell_line}.pt'
            pt_path = os.path.join(self.processed_dir, filename)

            if not os.path.exists(pt_path):
                data = self.get(idx)
                torch.save(data, pt_path)

In [112]:
# Build the dataset; per the recorded output below, construction runs the
# processing pass over all samples.
dataset = DrugNetworkDataset(
    "./results/Network/",
    drug_list,
    gene_list,
    pathway_list,
    labels_df=labels_df,
    expression_data=expression_data,
)

Processing...


Processing dataset: 160800 graphs to save


Processing Graphs:   7%|▋         | 11878/160800 [03:30<43:57, 56.47it/s]  


KeyboardInterrupt: 

In [105]:
# Inspect sample 0: identifiers, connectivity, node features and label.
data = dataset[0]

print("Drug Name:", data.drug)
print("Cell Line Name:", data.cell_line)

edge_pairs = data.edge_index.t()  # one (source, target) row per edge
print("Edge Index:")
print(edge_pairs)
print(edge_pairs.shape)

print("\nNode Features (x):")
print(data.x)
print(data.x.shape)

nodes = data.nodes

# Preview at most the first ten nodes; slicing (instead of range(10)) avoids
# an IndexError on graphs with fewer than ten nodes.
for i, node_name in enumerate(nodes[:10]):
    feature_value = data.x[i][0].item()  # column 0 is the expression feature
    print(f"Index {i}: {node_name} → Feature: {feature_value:.4f}")

print("\nLabel (y):")
print(data.y)
print(data.y.shape)




Drug Name: Camptothecin
Cell Line Name: 22RV1
Edge Index:
tensor([[   0,    0],
        [   0,    8],
        [   0,    9],
        ...,
        [1368, 1368],
        [1369,  420],
        [1369, 1369]])
torch.Size([23662, 2])

Node Features (x):
tensor([[-0.8910,  1.0000,  0.0000],
        [ 0.5178,  1.0000,  0.0000],
        [-0.5985,  1.0000,  0.0000],
        ...,
        [-0.1359,  1.0000,  0.0000],
        [ 0.6368,  1.0000,  0.0000],
        [ 0.6063,  1.0000,  0.0000]])
torch.Size([1370, 3])
Index 0: ABL1 → Feature: -0.8910
Index 1: ACVR1B → Feature: 0.5178
Index 2: ADORA1 → Feature: -0.5985
Index 3: AR → Feature: 4.2872
Index 4: ATF4 → Feature: -0.1862
Index 5: ATM → Feature: -0.3401
Index 6: ATR → Feature: -0.1959
Index 7: AURKA → Feature: -0.3216
Index 8: BAX → Feature: -0.6138
Index 9: BBC3 → Feature: 0.4687

Label (y):
tensor([-3.1426])
torch.Size([1])


In [106]:
# Inspect sample 1: identifiers, connectivity, node features and label.
data = dataset[1]

print("Drug Name:", data.drug)
print("Cell Line Name:", data.cell_line)

edge_pairs = data.edge_index.t()  # one (source, target) row per edge
print("Edge Index:")
print(edge_pairs)
print(edge_pairs.shape)

print("\nNode Features (x):")
print(data.x)
print(data.x.shape)

nodes = data.nodes

# Preview at most the first ten nodes; slicing (instead of range(10)) avoids
# an IndexError on graphs with fewer than ten nodes.
for i, node_name in enumerate(nodes[:10]):
    feature_value = data.x[i][0].item()  # column 0 is the expression feature
    print(f"Index {i}: {node_name} → Feature: {feature_value:.4f}")

print("\nLabel (y):")
print(data.y)
print(data.y.shape)


Drug Name: Camptothecin
Cell Line Name: 23132-87
Edge Index:
tensor([[   0,    0],
        [   0,    8],
        [   0,    9],
        ...,
        [1368, 1368],
        [1369,  420],
        [1369, 1369]])
torch.Size([23662, 2])

Node Features (x):
tensor([[-1.5698,  1.0000,  0.0000],
        [ 0.6585,  1.0000,  0.0000],
        [-0.5985,  1.0000,  0.0000],
        ...,
        [ 0.4112,  1.0000,  0.0000],
        [ 0.8274,  1.0000,  0.0000],
        [ 0.0980,  1.0000,  0.0000]])
torch.Size([1370, 3])
Index 0: ABL1 → Feature: -1.5698
Index 1: ACVR1B → Feature: 0.6585
Index 2: ADORA1 → Feature: -0.5985
Index 3: AR → Feature: -0.4920
Index 4: ATF4 → Feature: 0.8800
Index 5: ATM → Feature: 0.0323
Index 6: ATR → Feature: 0.2790
Index 7: AURKA → Feature: 1.0195
Index 8: BAX → Feature: -0.2325
Index 9: BBC3 → Feature: 0.6691

Label (y):
tensor([-3.2769])
torch.Size([1])


In [None]:
# Quick sanity check: dataset size plus the tensor shapes and label of sample 0.
n_samples = len(dataset)
print("Anzahl Samples:", n_samples)

data = dataset[0]
x_shape = data.x.shape
edge_shape = data.edge_index.shape
print("Node Features Shape:", x_shape)
print("Edge Index Shape:", edge_shape)
print("Label:", data.y.item())

In [87]:
# Spot check: the ABL1 column of the expression table, to compare against the
# first node feature printed for individual samples.
abl1_column = expression_data.loc[:, ["ABL1"]]
print(abl1_column)


              ABL1
22RV1    -0.890968
23132-87 -1.569799
42-MG-BA  0.536505
5637      0.873951
639-V    -0.708590
...            ...
YAPC      0.712502
YH-13     1.926564
YT       -2.203333
ZR-75-30 -0.708590
huH-1    -0.890968

[804 rows x 1 columns]


In [97]:
import torch
from torch_geometric.loader import DataLoader

batch_size = 8  # adjust to graph size and available GPU memory

# Shuffled mini-batches over the full dataset.
train_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [102]:
 # Create a GCN model structure that contains two GCNConv layers relu activation and a dropout rate of 0.5. The model consists of 16 hidden channels.  

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.optim as optim

# ----------------------------
# 2. Modell definieren
# ----------------------------
class GCNModel(nn.Module):
    """Two-layer GCN with a mean-pool readout and a linear head with one output.

    Args:
        num_features: Number of input features per node.
        hidden_channels: Width of both GCNConv layers.
        dropout: Dropout probability applied after the first convolution
            (default 0.5, the value that was previously hard-coded).
    """

    def __init__(self, num_features, hidden_channels, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        """Return one prediction per graph; ``batch`` is passed to ``global_mean_pool`` to group nodes."""
        x = self.conv1(x, edge_index).relu()
        # Active only in training mode (self.training).
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)


# Move the model to `device` before creating the optimizer: train() and
# evaluate() send every batch to `device`, so a model left on the CPU fails
# with a device mismatch as soon as CUDA is available.
model = GCNModel(num_features=3, hidden_channels=16).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [103]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Returns:
        Mean loss per graph over the loader's dataset.
    """
    model.train()
    total_loss = 0.0
    for data_batch in tqdm(train_loader):  # iterate over mini-batches
        data_batch = data_batch.to(device)

        optimizer.zero_grad()
        out = model(data_batch.x, data_batch.edge_index, data_batch.batch)
        # reshape(-1) keeps predictions and targets 1-D even for a batch with a
        # single graph, where squeeze() would collapse the output to a 0-d
        # tensor and make the loss rely on broadcasting.
        loss = criterion(out.reshape(-1), data_batch.y.reshape(-1))
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size.
        total_loss += loss.item() * data_batch.num_graphs

    # Normalise by the data actually iterated, not by an unrelated global.
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    """Compute the mean per-graph loss of ``model`` over ``loader`` without gradients.

    Args:
        loader: Iterable of batches exposing ``.dataset``; each batch provides
            ``x``, ``edge_index``, ``batch``, ``y``, ``num_graphs`` and ``.to()``.

    Returns:
        Summed batch losses (weighted by batch size) divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data_batch in tqdm(loader):
            data_batch = data_batch.to(device)
            out = model(data_batch.x, data_batch.edge_index, data_batch.batch)
            # reshape(-1) keeps both sides 1-D even for a single-graph batch.
            loss = criterion(out.reshape(-1), data_batch.y.reshape(-1))
            total_loss += loss.item() * data_batch.num_graphs
    # Normalise by the loader that was evaluated; the global `dataset` is only
    # correct when `loader` happens to wrap it.
    return total_loss / len(loader.dataset)

In [104]:
# Train for a fixed number of epochs and report the mean loss after each one.
NUM_EPOCHS = 100
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.6f}')

  0%|          | 0/20100 [00:00<?, ?it/s]

  0%|          | 40/20100 [00:45<6:17:21,  1.13s/it]


KeyboardInterrupt: 

In [None]:
# Compare the prediction for sample 0 with its true label.
model.eval()
with torch.no_grad():
    # Place the sample on the device the model's parameters live on, so the
    # forward pass cannot fail with a CPU/GPU mismatch.
    model_device = next(model.parameters()).device
    data_example = dataset[0].to(model_device)
    pred = model(data_example.x, data_example.edge_index, data_example.batch)
    print("Prediction:", pred.item())
    print("True Value:", data_example.y.item())

# Test

In [93]:
# NOTE(review): `data` is whatever an earlier cell assigned last; the recorded
# output (cell_line='23132-87') matches dataset[1], not dataset[0].
print(data)

Data(x=[1370, 3], edge_index=[2, 23662], y=[1], drug='Camptothecin', cell_line='23132-87', nodes=[1370])


In [94]:
 # Create a GCN model structure that contains two GCNConv layers relu activation and a dropout rate of 0.5. The model consists of 16 hidden channels.  

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.optim as optim

# ----------------------------
# 2. Modell definieren
# ----------------------------
class GCNModel(nn.Module):
    """Two-layer GCN with a mean-pool readout and a linear head with one output.

    Note: this cell re-declares a class name that is also defined in an
    earlier cell; whichever cell runs last wins.

    Args:
        num_features: Number of input features per node.
        hidden_channels: Width of both GCNConv layers.
        dropout: Dropout probability applied after the first convolution
            (default 0.5, the value that was previously hard-coded).
    """

    def __init__(self, num_features, hidden_channels, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        """Return one prediction per graph; ``batch`` is passed to ``global_mean_pool`` to group nodes."""
        x = self.conv1(x, edge_index).relu()
        # Active only in training mode (self.training).
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

# Fresh model, Adam optimiser and MSE loss for the single-graph experiment.
model = GCNModel(3, 16)
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

# ----------------------------
# 3. Training
# ----------------------------
def train():
    """Take one optimiser step on the single notebook-level graph ``data``.

    Returns:
        The loss value computed before the step, as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    prediction = model(data.x, data.edge_index, data.batch)
    target = data.y.unsqueeze(0)  # add a leading dimension to the label before the loss
    step_loss = criterion(prediction, target)

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

# Fit the single graph for 100 steps, logging every 50th epoch.
for epoch in range(1, 101):
    loss = train()
    if epoch % 50 != 0:
        continue
    print(f'Epoch: {epoch:03d}, Loss: {loss:.6f}')

# Compare the fitted prediction with the label of the same graph.
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index, data.batch)
print("Prediction:", pred.item())
print("True Value:", data.y.item())

Epoch: 050, Loss: 0.070937
Epoch: 100, Loss: 0.001628
Prediction: -3.2174384593963623
True Value: -3.276945114135742


# Encoder

Komponenten: 
- GNNEncoder: Drug-Specific Network Graphen 
- Ziel: tugda_mtl verarbeitet bisher Rohdaten aus Expressionsprofilen → stattdessen die drug-spezifischen Graphen aus DrugNetworkDataset verwenden und den GNN-Encoder als Feature-Extractor einbauen

- Wir ersetzen:
- Die bisherigen nn.Linear-Schichten durch den GNNEncoder. Die Eingabe von X_train durch drug_data (vom Typ Data). Der Output des Encoders wird an die S, A etc. Schichten weitergeleitet

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class GNNEncoder(nn.Module):
    """Encode a graph into a fixed-size embedding.

    Three stages run in sequence; each applies a 3-head GATConv, a linear
    layer that maps the concatenated heads back to ``embedding_size``, ReLU and
    TopKPooling (ratios 0.8, 0.5, 0.2). After every stage the surviving nodes
    are summarised by concatenating global max and mean pooling; the three
    summaries are added and projected to ``output_dim``.

    Args:
        feature_size: Number of input features per node.
        embedding_size: Hidden width of every stage.
        output_dim: Size of the returned graph embedding.
    """

    def __init__(self, feature_size=1, embedding_size=512, output_dim=256):
        super().__init__()

        n_heads = 3
        concat_width = embedding_size * n_heads  # GATConv output width with 3 concatenated heads

        # Stage 1
        self.conv1 = GATConv(feature_size, embedding_size, heads=n_heads, dropout=0.6)
        self.head_transform1 = nn.Linear(concat_width, embedding_size)
        # NOTE(review): bn1/bn2/bn3 are registered but never applied in forward().
        self.bn1 = nn.BatchNorm1d(embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)

        # Stage 2
        self.conv2 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.6)
        self.head_transform2 = nn.Linear(concat_width, embedding_size)
        self.bn2 = nn.BatchNorm1d(embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)

        # Stage 3
        self.conv3 = GATConv(embedding_size, embedding_size, heads=n_heads, dropout=0.6)
        self.head_transform3 = nn.Linear(concat_width, embedding_size)
        self.bn3 = nn.BatchNorm1d(embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Projection head; the readout is max-pool and mean-pool concatenated,
        # hence twice the embedding width.
        self.linear1 = nn.Linear(embedding_size * 2, 512)
        self.linear2 = nn.Linear(512, output_dim)

    @staticmethod
    def _run_stage(conv, head_transform, pool, x, edge_index, batch):
        """Apply one conv -> head merge -> ReLU -> TopK stage and return its readout."""
        x = conv(x, edge_index)
        x = F.relu(head_transform(x))
        x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
        readout = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
        return x, edge_index, batch, readout

    def forward(self, data):
        """Return the graph embedding for ``data`` (reads ``x``, ``edge_index``, ``batch``)."""
        x, edge_index, batch = data.x, data.edge_index, data.batch

        stages = (
            (self.conv1, self.head_transform1, self.pool1),
            (self.conv2, self.head_transform2, self.pool2),
            (self.conv3, self.head_transform3, self.pool3),
        )

        # Sum of the per-stage readouts, accumulated in stage order.
        combined = None
        for conv, head_transform, pool in stages:
            x, edge_index, batch, readout = self._run_stage(
                conv, head_transform, pool, x, edge_index, batch
            )
            combined = readout if combined is None else combined + readout

        hidden = self.linear1(combined).relu()
        hidden = F.dropout(hidden, p=0.8, training=self.training)
        return self.linear2(hidden)

In [51]:
# Fetch a real sample; indexing the dataset returns a Data object.
data = dataset[0]

# Size the encoder input from the sample itself: the dataset emits one row of
# [expr, is_gene, is_pathway] per node, so a hard-coded feature_size=1 no
# longer matches data.x.
encoder = GNNEncoder(feature_size=data.x.size(1), embedding_size=512, output_dim=256)

# Run the encoder in eval mode without tracking gradients.
encoder.eval()
with torch.no_grad():
    embedding = encoder(data)

print("Echtes Embedding Shape:", embedding.shape)
print(embedding)

# Output: the graph's representation in the latent space

Echtes Embedding Shape: torch.Size([1, 256])
tensor([[-0.0340,  0.0056,  0.0333, -0.0014,  0.0465, -0.0045, -0.0033, -0.0056,
         -0.0343, -0.0003, -0.0384,  0.0118,  0.0132,  0.0353, -0.0043,  0.0051,
         -0.0179,  0.0193, -0.0360, -0.0123,  0.0289,  0.0181, -0.0017, -0.0368,
         -0.0129,  0.0491, -0.0041, -0.0287, -0.0133,  0.0006,  0.0312, -0.0048,
         -0.0107, -0.0153,  0.0420,  0.0122,  0.0162, -0.0282, -0.0230, -0.0037,
         -0.0423, -0.0312,  0.0334,  0.0155,  0.0018, -0.0340, -0.0045,  0.0048,
         -0.0317,  0.0129, -0.0150, -0.0003, -0.0010,  0.0195,  0.0010,  0.0294,
         -0.0101,  0.0239,  0.0062, -0.0011,  0.0234, -0.0264,  0.0053, -0.0193,
          0.0415, -0.0296,  0.0008, -0.0174,  0.0538, -0.0461, -0.0371,  0.0447,
         -0.0331, -0.0163, -0.0219,  0.0326,  0.0300,  0.0454, -0.0077, -0.0094,
         -0.0374, -0.0422,  0.0356,  0.0216, -0.0058, -0.0117, -0.0180,  0.0456,
          0.0244, -0.0266,  0.0173,  0.0364, -0.0243,  0.0089, -

# Cell Encoder

In [106]:
class Cell_EmbedNet(nn.Module):
    """MLP that maps a cell-line expression vector to an embedding.

    Layout: ``Linear(fc_in_dim, fc_hid_dim[0])`` + ReLU, then one
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` block per consecutive pair in
    ``fc_hid_dim``, then ``Linear(fc_hid_dim[-1], embed_dim)``.

    Args:
        fc_in_dim: Size of the input vector.
        fc_hid_dim: Sequence of hidden widths (at least one entry).
        embed_dim: Size of the returned embedding.
        dropout: Dropout probability inside the hidden blocks.

    Raises:
        ValueError: if ``fc_hid_dim`` is empty.
    """

    def __init__(self, fc_in_dim=1780, fc_hid_dim=(512, 512), embed_dim=256, dropout=0.5):
        super().__init__()
        # Copy into a private list: the previous list default was one shared
        # object, so mutating `fc_hid_dim` on one instance changed the default
        # for later instances (and aliased any list passed in by the caller).
        hidden_dims = list(fc_hid_dim)
        if not hidden_dims:
            raise ValueError("fc_hid_dim must contain at least one hidden width")

        self.fc_hid_dim = hidden_dims
        self.fc = nn.Linear(fc_in_dim, self.fc_hid_dim[0])
        self.act = nn.ReLU()
        self.dropout = dropout
        self.classifier = nn.ModuleList()

        for input_size, output_size in zip(self.fc_hid_dim, self.fc_hid_dim[1:]):
            self.classifier.append(
                nn.Sequential(
                    nn.Linear(input_size, output_size),
                    nn.BatchNorm1d(output_size),
                    self.act,
                    nn.Dropout(p=self.dropout)
                )
            )
        self.fc2 = nn.Linear(self.fc_hid_dim[-1], embed_dim)

        # Xavier-initialise the Linear layer of every hidden block; `fc` and
        # `fc2` keep the default initialisation.
        for layer in self.classifier:
            nn.init.xavier_uniform_(layer[0].weight)

    def forward(self, x):
        """Return the embedding for a batch ``x`` whose last dimension is ``fc_in_dim``."""
        x = F.relu(self.fc(x))
        for fc in self.classifier:
            x = fc(x)
        x = self.fc2(x)
        return x

In [107]:
class DrugCellPredictor(nn.Module):
    """MLP head that scores a (drug embedding, cell embedding) pair.

    Args:
        input_dim: Width of the concatenated drug and cell embeddings.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, input_dim=512, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1),  # single regression output, e.g. an AUC value
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, drug_emb, cell_emb):
        """Concatenate both embeddings along dim 1 and return a ``[batch, 1]`` prediction."""
        pair_features = torch.cat((drug_emb, cell_emb), dim=1)
        return self.mlp(pair_features)

In [None]:
class DrugResponseModel(nn.Module):
    """Predict drug response from a drug graph and a cell-line expression vector.

    The graph is encoded by ``GNNEncoder``, the expression vector by
    ``Cell_EmbedNet``; both embeddings are concatenated and scored by
    ``DrugCellPredictor``.

    Args:
        node_feature_dim: Features per graph node. Defaults to 3 because
            ``DrugNetworkDataset`` emits ``[expr, is_gene, is_pathway]`` per
            node; the encoder's own default of 1 does not match those graphs.
        cell_feature_dim: Length of the cell-line expression vector.
        embed_dim: Size of each of the two embeddings.
    """

    def __init__(self, node_feature_dim=3, cell_feature_dim=1780, embed_dim=256):
        super().__init__()
        self.gnn_encoder = GNNEncoder(feature_size=node_feature_dim, output_dim=embed_dim)
        self.cell_encoder = Cell_EmbedNet(fc_in_dim=cell_feature_dim, embed_dim=embed_dim)
        # The predictor consumes both embeddings side by side.
        self.predictor = DrugCellPredictor(input_dim=2 * embed_dim)

    def forward(self, drug_data, cell_expression):
        """Return the predictor's output for the given graph batch and expression batch."""
        drug_emb = self.gnn_encoder(drug_data)
        cell_emb = self.cell_encoder(cell_expression)
        prediction = self.predictor(drug_emb, cell_emb)
        return prediction