In [None]:
import pandas as pd # read Dataframe 
import torch
from torch_geometric.data import Dataset, Data # for creating graph-based datasets
import os # for file and directory operations 
import numpy as np # for numerical computations with arrays 
import gseapy as gp # for retrieving pathway information


# Creating a Custom Dataset in Pytorch Geometric 

# -----------------------------------
# Step 0: Load and Prepare DataFrames
# -----------------------------------

print("\n" + "="*31)
print("Step 0: Download and Preprocess")
print("="*31)

# Load the full GDSC dataset (FPKM + AUC values for all drugs)
gdsc_dataset = pd.read_csv('/sybig/home/tmu/TUGDA/data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
# gdsc_dataset = pd.read_csv('/Users/tm03/Desktop/TUGDA_1/data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
# Extract gene and drug columns:
# - First 1780 columns correspond to gene expression data
# - Remaining columns represent drug AUC values
gene_list = gdsc_dataset.columns[0:1780]
drug_list = gdsc_dataset.columns[1780:] 

# Retrieve KEGG pathways using gseapy
kegg_gmt = gp.parser.get_library('KEGG_2021_Human', organism='Human', min_size=3, max_size=2000)
pathway_list = list(kegg_gmt.keys())

# Gene Expression Data (FPKM values)
expression_data = gdsc_dataset.iloc[:, :1780]

# Response Data: Combine log_IC50 values from 3-fold cross-validation test sets
response_1 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k1.csv", index_col=0)
response_2 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k2.csv", index_col=0)
response_3 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k3.csv", index_col=0)

# response_1 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k1.csv", index_col=0)
# response_2 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k2.csv", index_col=0)
# response_3 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k3.csv", index_col=0)
# response_data = pd.concat([response_1, response_2, response_3], axis=0, ignore_index=False)

# Sort both datasets by index to ensure alignment
expression_data = expression_data.sort_index()
response_data = response_data.sort_index()

# Remove duplicate indices (keep first occurrence) to avoid conflicts during merging
expression_data = expression_data[~expression_data.index.duplicated(keep='first')]
labels_df = response_data[~response_data.index.duplicated(keep='first')] 

print("\n" + "Done!")


Step 0: Download and Preprocess


FileNotFoundError: [Errno 2] No such file or directory: '/sybig/home/tmu/TUGDA/data/GDSCDA_fpkm_AUC_all_drugs.zip'

In [4]:
nan_count_total = expression_data.isna().sum().sum()
print(f"Anzahl der NaN-Werte insgesamt: {nan_count_total}")

nan_count_total = labels_df.isna().sum().sum()
print(f"Anzahl der NaN-Werte insgesamt: {nan_count_total}")

Anzahl der NaN-Werte insgesamt: 0
Anzahl der NaN-Werte insgesamt: 24340


In [5]:
print(labels_df)

          Oxaliplatin  Ulixertinib  Fulvestrant  Uprosertib  Dactinomycin  \
22RV1        3.081337     2.803572     4.227332    1.133014     -4.923521   
23132-87     4.391634     2.526244     3.697545    0.188871     -4.530814   
42-MG-BA     4.129735     3.147708     4.045351    1.682197     -4.231256   
5637         4.391109     2.884360     3.535578    4.453362     -4.262836   
639-V        3.253815     2.546112     3.550254    4.297324     -4.760262   
...               ...          ...          ...         ...           ...   
YAPC         5.245318     4.309617     4.761417    5.532699     -2.255783   
YH-13        3.945620     3.377953     3.812616    2.172373     -3.734196   
YT           0.461330     1.745128     2.024318    4.395238     -5.679264   
ZR-75-30     7.154014     6.182055     5.587853    3.443328     -0.055238   
huH-1        6.006727     4.090465     4.028713    3.325009     -0.850808   

          Docetaxel  Camptothecin  5-Fluorouracil  Afatinib  Taselisib  ...

In [6]:
###################
### GNN Dataset ###
###################

print("\n" + "="*33)
print("Step 1: GNN Dataset & Quick Check")
print("="*33)

class DrugNetworkDataset(Dataset):
    def __init__(self, root, drug_list, gene_list, pathway_list, labels_df, expression_data, transform=None, pre_transform=None):

        """
        A custom PyTorch Geometric Dataset for drug-cell line interaction graphs.

        Parameters:
            - root = Where the datase4t should be stored. This folder is split 
            - into raw_dir (downloaded datset) and processed_dir (processed data)

            - drug_list: List of drugs (tasks) (200)
            - gene_list: Lisz of genes sample (1780)
            - pathway_list: List of pathways from KEGG 
            - labels_df: response data (log_IC50)
            - expression_data: Gene expression values (preprocessed according to Mourragui et al. (2020): library-size using TMM, log-transformed, gene-level-mean-centering and standardization)
        """

        # Define all files that the dataset needs
        self.drug_list = drug_list
        self.gene_list = gene_list
        self.pathway_list = pathway_list
        self.labels_df = labels_df
        self.expression_data = expression_data

        # Get all combination of Drug + Cell Line
        self.samples = [
            (drug, cell_line) 
            for drug in self.drug_list 
            for cell_line in self.labels_df.index
        ]

        # Define custom raw_dir
        self.custom_raw_dir = os.path.join(root, 'drug_matrices_csv')

        super(DrugNetworkDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If these files exist in raw_dir, the download is not triggered """
        return [f"{drug}_matrix.csv" for drug in self.drug_list]

    @property
    def raw_dir(self):
        return self.custom_raw_dir

    @property
    def processed_file_names(self):
        """ If these files exist in processed_dir, processing is skipped """
        return ['placeholder.pt'] # not implemented 
        #  return [f'data_{i}.pt' for i in range(len(self.samples))]

    def download(self):
        pass 

    def _load_adjacency_matrix(self, drug):
        """ Loads the adjacency matrix for every drug """
        csv_path = os.path.join(self.raw_dir, f"{drug}_matrix.csv")
        df = pd.read_csv(csv_path, index_col=0)
        return df

    def _get_node_features(self, nodes, cell_line):
        x = []
        for node in nodes:
            if node in self.gene_list:
                 # If the node is a gene, get its expression value for the given cell line and type [expr, is_gene, is_pathway]
                expr_value = self.expression_data.loc[cell_line, node]
                x.append([float(expr_value), 1.0, 0.0])  # [expr, is_gene, is_pathway]
            elif node in self.pathway_list:
                # If the node is a pathway, use a default feature value of 0.0
                x.append([0.0, 0.0, 1.0])  # [expr_dummy, is_gene, is_pathway]
            else:
                # Unknown node
                x.append([0.0, 0.0, 0.0])  # [expr_dummy, is_gene, is_pathway]
        return torch.tensor(x, dtype=torch.float)

        ''' Code for just Gene expression value: 
        x = []
        for node in nodes:
             # If the node is a gene, get its expression value for the given cell line
            if node in self.gene_list:
                expr_value = self.expression_data.loc[cell_line, node]
                x.append([float(expr_value)]) # Append as a single-element list 
            elif node in self.pathway_list:
                # If the node is a pathway, use a default feature value of 0.0
                x.append([0.0])
            else:
                x.append([0.0])
        return torch.tensor(x, dtype=torch.float)
        ''' 
    
    # Qucik check fuction for debugging
    def get_node_name_by_index(self, idx, node_index):
        """ Returns the name of the node at the specified index for the given sample """
        drug, cell_line = self.samples[idx]
        
        adj_df = self._load_adjacency_matrix(drug)
        nodes = adj_df.columns.tolist()
        
        if node_index < len(nodes):
            return nodes[node_index]
        else:
            raise IndexError(f"Node index {node_index} out of range for this graph.")

    def len(self):
        """ 
        Returns the total number of samples (drug-cell line combinations)
        useful for classes such as datasets for machine learning so that they are compatible with len() and for loops 
        """
        return len(self.samples)

    def get(self, idx):
        """
        In standard PyG datasets, get() loads preprocessed data saved via process(); 
        here the preprocessing step is skipped and construct each drug–cell line graph directly in get(), 
        as preprocessing wouldn't reduce computation time
        """
        drug, cell_line = self.samples[idx]

        # Load adjacency matric for each drug 
        adj_df = self._load_adjacency_matrix(drug)
        nodes = adj_df.columns.tolist()
        adj_matrix = adj_df.values
        # Build edge_index in COO format
        edge_index = torch.tensor(np.array(np.nonzero(adj_matrix)), dtype=torch.long)
        # Get node_features
        x = self._get_node_features(nodes, cell_line)
        # Get label info (log_IC50 value for each drug-cell line combination)
        # y = torch.tensor([self.labels_df.loc[cell_line, drug]], dtype=torch.float)

        label_value = self.labels_df.loc[cell_line, drug]
        if pd.isna(label_value):
            raise IndexError(f"Label is NaN for {drug} + {cell_line}. Skipping.")
        else:
            y = torch.tensor([float(label_value)], dtype=torch.float)

        # Create data object
        data = Data(x=x, edge_index=edge_index, y=y, drug=drug, cell_line=cell_line)

        # Attach node names for debugging purposes
        data.nodes = nodes

        return data
    
    def process(self):
        pass # Skip right now, can be implemented to store the output files in your
        '''
        for i in range(len(self.samples)):
            data = self.get(i)
            torch.save(data, os.path.join(self.processed_dir, f'data_{i}.pt'))
        '''

# Run class
dataset = DrugNetworkDataset(
    root="./results/Network/",
    drug_list=drug_list,
    gene_list=gene_list,
    pathway_list=pathway_list,
    labels_df=labels_df, 
    expression_data=expression_data
)

# Qucik Ckeck with results
data = dataset[1] # graph-base representations of drug-cell line pairs (200 * 1780)

print("Drug Name:", data.drug)
print("Cell Line Name:", data.cell_line)

print("Edge Index (COO format):") # tensor([a,b], [c,d]): node a is conntected to node b and node c is conntected to node d
print(data.edge_index.t())
print(data.edge_index.t().shape) # Tensor of shape [num_edges, 2]

print("\nNode Features (x):") # Gene expression values of gene_x with Cell line and is_gene, is_pathway
print(data.x)
print(data.x.shape) # Tensor of shape [num_nodes, num_node_features]

nodes = data.nodes

for i in range(10):
    node_name = nodes[i]
    feature_value = data.x[i][0].item()  # Just take the first feature (gene expression)
    print(f"Index {i}: {node_name} → Feature: {feature_value:.4f}")

print("\nLabel (y):") # log_IC50 value for the drug-cell line combination
print(data.y)
print(data.y.shape) # Tensor of shape [1]


Step 1: GNN Dataset & Quick Check


Processing...
Done!


Drug Name: Camptothecin
Cell Line Name: 23132-87
Edge Index (COO format):
tensor([[   0,    0],
        [   0,    4],
        [   0,    6],
        ...,
        [1825, 1825],
        [1826,  440],
        [1826, 1826]])
torch.Size([52461, 2])

Node Features (x):
tensor([[-1.5698,  1.0000,  0.0000],
        [ 0.6585,  1.0000,  0.0000],
        [-0.5985,  1.0000,  0.0000],
        ...,
        [-0.2597,  1.0000,  0.0000],
        [-0.4080,  1.0000,  0.0000],
        [ 0.1387,  1.0000,  0.0000]])
torch.Size([1827, 3])
Index 0: ABL1 → Feature: -1.5698
Index 1: ACVR1B → Feature: 0.6585
Index 2: ADORA1 → Feature: -0.5985
Index 3: AKT1 → Feature: 0.2557
Index 4: AR → Feature: -0.4920
Index 5: ATF4 → Feature: 0.8800
Index 6: ATM → Feature: 0.0323
Index 7: ATR → Feature: 0.2790
Index 8: AURKA → Feature: 1.0195
Index 9: BAX → Feature: -0.2325

Label (y):
tensor([-3.2769])
torch.Size([1])


In [53]:
dataset[dataset.samples.index(("BMS-754807", "23132-87"))]

IndexError: Label is NaN for BMS-754807 + 23132-87. Skipping.

In [7]:
# Check
print(expression_data[["ABL1"]])


              ABL1
22RV1    -0.890968
23132-87 -1.569799
42-MG-BA  0.536505
5637      0.873951
639-V    -0.708590
...            ...
YAPC      0.712502
YH-13     1.926564
YT       -2.203333
ZR-75-30 -0.708590
huH-1    -0.890968

[804 rows x 1 columns]


# First Encoder 

From the GDSC Data Portal, the gene expression data (RMA normalised expression data for cell-lines, Cell_line_RMA_proc_basalExp.txt) and two annotations (methSampleId_2_cosmicIds.xlsx, Mapping between cell-line COSMIC identifiers and cell-line methylation data identifiers + TableS1E.xlsx, Annotation of cell lines used in the GDSC dataset) was downloaded.

- three GNN layers with GATConv, Transform, TopK-Pooling

In [62]:
print(data)

Data(x=[1827, 3], edge_index=[2, 52461], y=[1], drug='Camptothecin', cell_line='23132-87', nodes=[1827])


In [27]:
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATConv, Linear, TopKPooling, global_mean_pool, global_max_pool
import torch.nn
import torch.nn.functional

class GNN(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN, self).__init__()
        num_classes = 1 
        embedding_size = 256

        # three GNN layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio = 0.8)

        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio = 0.5)

        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio = 0.2)
    
        # Linear Layers
        self.linear1 = Linear(embedding_size*2, 256)
        self.linear2 = Linear(256, num_classes)

    def forward(self, x, edge_index, batch_index):
        # First block
        x = self.conv1(x, edge_index)
        x = self.head_transform1(x)

        x, edge_index, _, batch_index, _, _ = self.pool1(x, edge_index, None, batch_index)
        x1 = torch.cat([global_max_pool(x, batch_index), global_mean_pool(x, batch_index)], dim=1)

        # Second block
        x = self.conv2(x, edge_index)
        x = self.head_transform2(x)

        x, edge_index, _, batch_index, _, _ = self.pool1(x, edge_index, None, batch_index)
        x2 = torch.cat([global_max_pool(x, batch_index), global_mean_pool(x, batch_index)], dim=1)

        # Third block
        x = self.conv3(x, edge_index)
        x = self.head_transform3(x)

        x, edge_index, _, batch_index, _, _ = self.pool1(x, edge_index, None, batch_index)
        x3 = torch.cat([global_max_pool(x, batch_index), global_mean_pool(x, batch_index)], dim=1)

        # Concat pooled vectors
        x = x1 + x2 + x3
        
        # Output block
        x = self.linear1(x).relu()
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)

        return x



In [28]:
# Batch erstellen (wichtig für Pooling-Layer)
data = dataset[0]
data.batch = torch.zeros(data.x.shape[0], dtype=torch.long)  # Alle Knoten sind im selben Graph

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool

# ----------------------------
# 1. Modell, Optimizer & Loss definieren
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GNN(feature_size=data.x.shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

data = data.to(device)

# ----------------------------
# 2. Training Loop
# ----------------------------
def train():
    model.train()
    optimizer.zero_grad()

    out = model(data.x, data.edge_index, data.batch)
    loss = criterion(out, data.y.unsqueeze(0))  # y muss shape [batch_size, 1] haben
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(1, 201):
    loss = train()
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.6f}')

# ----------------------------
# 3. Evaluation
# ----------------------------
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index, data.batch)
    print("Prediction:", pred.item())
    print("True Value:", data.y.item())

Epoch: 020, Loss: 0.032193
Epoch: 040, Loss: 2.162785
Epoch: 060, Loss: 2.746925
Epoch: 080, Loss: 0.493325
Epoch: 100, Loss: 1.822758
Epoch: 120, Loss: 0.406313
Epoch: 140, Loss: 0.111263
Epoch: 160, Loss: 1.357738
Epoch: 180, Loss: 0.004807
Epoch: 200, Loss: 0.022333
Prediction: -1.7428950071334839
True Value: -3.1426310539245605


# gcncONV 

In [66]:
 # Create a GCN model structure that contains two GCNConv layers relu activation and a dropout rate of 0.5. The model consists of 16 hidden channels.  

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.optim as optim

# ----------------------------
# 2. Modell definieren
# ----------------------------
class GCNModel(nn.Module):
    def __init__(self, num_features, hidden_channels):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = torch.dropout(x, p=0.5, train=self.training)
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = GCNModel(num_features=3, hidden_channels=16)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# ----------------------------
# 3. Training
# ----------------------------
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.batch)
    loss = criterion(out, data.y.unsqueeze(0))
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(1, 201):
    loss = train()
    if epoch % 50 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.6f}')

model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index, data.batch)
    print("Prediction:", pred.item())
    print("True Value:", data.y.item())

Epoch: 050, Loss: 0.078836
Epoch: 100, Loss: 0.000333
Epoch: 150, Loss: 0.000822
Epoch: 200, Loss: 0.000166
Prediction: -3.1559970378875732
True Value: -3.1426310539245605


In [None]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# ----------------------------
# 2. Modell definieren
# ----------------------------
class GCNModel(nn.Module):
    def __init__(self, num_features, hidden_channels):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = torch.dropout(x, p=0.5, train=self.training)
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

model = GCNModel(num_features=3, hidden_channels=16)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# ----------------------------
# 3. DataLoader erstellen
# ----------------------------
loader = DataLoader(dataset[:50], batch_size=4, shuffle=True)

# ----------------------------
# 4. Training Loop mit tqdm
# ----------------------------
def train():
    model.train()
    total_loss = 0
    for data in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch).squeeze()
        loss = criterion(out, data.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

for epoch in range(1, 51):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.6f}')

                                                         

Epoch: 001, Loss: 9.173643


                                                         

Epoch: 002, Loss: 3.504404


                                                         

Epoch: 003, Loss: 2.370892


                                                         

Epoch: 004, Loss: 2.589767


                                                         

Epoch: 005, Loss: 2.235986


                                                         

Epoch: 006, Loss: 2.487135


                                                         

Epoch: 007, Loss: 2.589848


                                                         

Epoch: 008, Loss: 2.373393


                                                         

Epoch: 009, Loss: 2.463561


                                                         

Epoch: 010, Loss: 2.231457


                                                         

Epoch: 011, Loss: 2.176494


                                                         

Epoch: 012, Loss: 2.226722


                                                         

Epoch: 013, Loss: 2.207750


                                                         

Epoch: 014, Loss: 2.268195


                                                         

Epoch: 015, Loss: 2.517760


                                                         

Epoch: 016, Loss: 2.250463


                                                         

Epoch: 017, Loss: 2.322152


                                                         

Epoch: 018, Loss: 2.809301


                                                         

Epoch: 019, Loss: 2.335249


                                                         

Epoch: 020, Loss: 2.218212


                                                         

Epoch: 021, Loss: 2.332426


                                                         

Epoch: 022, Loss: 2.188250


                                                         

Epoch: 023, Loss: 2.170479


                                                         

Epoch: 024, Loss: 2.253595


                                                         

Epoch: 025, Loss: 2.287416


                                                         

Epoch: 026, Loss: 2.254548


                                                         

Epoch: 027, Loss: 2.238691


                                                         

Epoch: 028, Loss: 2.375519


                                                         

Epoch: 029, Loss: 2.409533


                                                         

Epoch: 030, Loss: 2.263764


                                                         

Epoch: 031, Loss: 2.347821


                                                         

Epoch: 032, Loss: 2.233757


                                                         

Epoch: 033, Loss: 2.288306


                                                         

Epoch: 034, Loss: 2.480765


                                                         

Epoch: 035, Loss: 2.269089


                                                         

Epoch: 036, Loss: 2.343753


                                                         

Epoch: 037, Loss: 2.163118


                                                         

Epoch: 038, Loss: 2.840684


                                                         

Epoch: 039, Loss: 2.288268


                                                         

Epoch: 040, Loss: 2.117000


                                                         

Epoch: 041, Loss: 2.294795


                                                         

Epoch: 042, Loss: 2.244043


                                                         

Epoch: 043, Loss: 2.206588


                                                         

Epoch: 044, Loss: 2.288073


                                                         

Epoch: 045, Loss: 2.161090


                                                         

Epoch: 046, Loss: 2.147072


                                                         

Epoch: 047, Loss: 2.875207


                                                         

Epoch: 048, Loss: 2.254188


                                                         

Epoch: 049, Loss: 2.100991


                                                         

Epoch: 050, Loss: 2.178792


Evaluating: 100%|██████████| 13/13 [00:10<00:00,  1.21it/s]


TypeError: got an unexpected keyword argument 'squared'

In [82]:
# ----------------------------
# 5. Evaluation für alle Samples
# ----------------------------
from sklearn.metrics import r2_score
from scipy.stats import pearsonr

model.eval()
preds = []
truths = []

with torch.no_grad():
    for data in tqdm(loader, desc="Evaluating"):
        out = model(data.x, data.edge_index, data.batch).squeeze()
        preds.extend(out.cpu().numpy())
        truths.extend(data.y.cpu().numpy())

r2 = r2_score(truths, preds)
pearson = pearsonr(preds, truths)[0]

print("\n" + "="*40)
print("Evaluation Metrics")
print("="*40)
print(f"R² Score: {r2:.4f}")
print(f"Pearson Correlation: {pearson:.4f}")

# ----------------------------
# 6. Einzelner Graph evaluieren
# ----------------------------
data_single = dataset[0]
pred_single = model(data_single.x, data_single.edge_index, data_single.batch)
print("\nEinzelvorhersage:")
print("Prediction:", pred_single.item())
print("True Value:", data_single.y.item())

Evaluating: 100%|██████████| 13/13 [00:11<00:00,  1.15it/s]



Evaluation Metrics
R² Score: 0.0472
Pearson Correlation: 0.2418

Einzelvorhersage:
Prediction: -3.2071785926818848
True Value: -3.1426310539245605


In [83]:
from sklearn.metrics import mean_squared_error

model.eval()
preds = []
truths = []

with torch.no_grad():
    for data in loader:  # Alle Daten durchgehen
        out = model(data.x, data.edge_index, data.batch).squeeze()
        preds.extend(out.cpu().numpy())
        truths.extend(data.y.cpu().numpy())

# Berechne MSE
mse = mean_squared_error(truths, preds)

print("\n" + "="*30)
print("Evaluation Metrics")
print("="*30)
print(f"MSE Loss: {mse:.6f}")


Evaluation Metrics
MSE Loss: 2.153636


# Encoder

Komponenten: 
- GNNEncoder: Drug-Specific Network Graphen 
- Aim: tugda_mtl verarbeitet Rohdaten aus Expressionsprofilen --> Drug-spezfische Graphen aus DrugNetworkDataset verwenden, GNN-Encoder als Feature Extractor einbauen 

- Wir ersetzen:
- Die bisherigen nn.Linear-Schichten durch den GNNEncoder. Die Eingabe von X_train durch drug_data (vom Typ Data). Der Output des Encoders wird an die S, A etc. Schichten weitergeleitet

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.data import DataLoader

class GNNEncoder(nn.Module):
    def __init__(self, feature_size=1, embedding_size=512, output_dim=256):
        super(GNNEncoder, self).__init__()

        # GNN Layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout=0.6)
        self.head_transform1 = nn.Linear(embedding_size * 3, embedding_size)
        self.bn1 = nn.BatchNorm1d(embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)

        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.6)
        self.head_transform2 = nn.Linear(embedding_size * 3, embedding_size)
        self.bn2 = nn.BatchNorm1d(embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)

        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.6)
        self.head_transform3 = nn.Linear(embedding_size * 3, embedding_size)
        self.bn3 = nn.BatchNorm1d(embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Final projection
        self.linear1 = nn.Linear(embedding_size * 2, 512)
        self.linear2 = nn.Linear(512, output_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # First Block
        x = self.conv1(x, edge_index)
        x = F.relu(self.head_transform1(x))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Second Block
        x = self.conv2(x, edge_index)
        x = F.relu(self.head_transform2(x))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Third Block
        x = self.conv3(x, edge_index)
        x = F.relu(self.head_transform3(x))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Combine pooled features
        x = x1 + x2 + x3

        # Final layers
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.8, training=self.training)
        graph_embedding = self.linear2(x)

        return graph_embedding

In [85]:
# Hole einen echten Datensatz
data = dataset[0]  # Dataset[Index] gibt ein Data-Objekt zurück

encoder = GNNEncoder(feature_size=3, embedding_size=512, output_dim=256)

# Encoder aufrufen
encoder.eval()
with torch.no_grad():
    embedding = encoder(data)

print("Echtes Embedding Shape:", embedding.shape)
print(embedding)

# Output: Gewichte in einem latenten Raum

Echtes Embedding Shape: torch.Size([1, 256])
tensor([[-0.0012,  0.0287, -0.0383, -0.0205,  0.0359, -0.0504, -0.0081,  0.0101,
         -0.0466,  0.0271, -0.0150,  0.0136, -0.0172,  0.0354, -0.0151, -0.0158,
          0.0106, -0.0343,  0.0383, -0.0310,  0.0229,  0.0041, -0.0144,  0.0159,
          0.0294,  0.0265, -0.0489,  0.0140,  0.0096,  0.0385, -0.0157,  0.0044,
          0.0420,  0.0287, -0.0119,  0.0313,  0.0144,  0.0081, -0.0035,  0.0197,
         -0.0032,  0.0357, -0.0419, -0.0226,  0.0123, -0.0388,  0.0438, -0.0119,
         -0.0273, -0.0189,  0.0179,  0.0180,  0.0099,  0.0300,  0.0194,  0.0073,
         -0.0089, -0.0176, -0.0224, -0.0379,  0.0397,  0.0220, -0.0062, -0.0254,
          0.0340, -0.0484, -0.0065,  0.0296, -0.0079,  0.0006, -0.0468,  0.0034,
         -0.0219,  0.0361,  0.0011,  0.0002, -0.0355, -0.0314,  0.0272,  0.0355,
          0.0253,  0.0509, -0.0271, -0.0354, -0.0239, -0.0430,  0.0208, -0.0206,
          0.0351,  0.0190,  0.0282, -0.0193, -0.0428, -0.0007, -