In [1]:
import pandas as pd # read Dataframe 
import torch
from torch_geometric.data import Dataset, Data # for creating graph-based datasets
import os # for file and directory operations 
import numpy as np # for numerical computations with arrays 
import gseapy as gp # for retrieving pathway information


# Creating a Custom Dataset in Pytorch Geometric 

# -----------------------------------
# Step 0: Load and Prepare DataFrames
# -----------------------------------

print("\n" + "="*31)
print("Step 0: Download and Preprocess")
print("="*31)

# Load the full GDSC dataset (FPKM + AUC values for all drugs)
# gdsc_dataset = pd.read_csv('/sybig/home/tmu/TUGDA/data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
gdsc_dataset = pd.read_csv('/Users/tm03/Desktop/TUGDA_1/data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
# Extract gene and drug columns:
# - First 1780 columns correspond to gene expression data
# - Remaining columns represent drug AUC values
gene_list = gdsc_dataset.columns[0:1780]
drug_list = gdsc_dataset.columns[1780:] 

# Retrieve KEGG pathways using gseapy
kegg_gmt = gp.parser.get_library('KEGG_2021_Human', organism='Human', min_size=3, max_size=2000)
pathway_list = list(kegg_gmt.keys())

# Gene Expression Data (FPKM values)
expression_data = gdsc_dataset.iloc[:, :1780]

# Response Data: Combine log_IC50 values from 3-fold cross-validation test sets
# response_1 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k1.csv", index_col=0)
# response_2 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k2.csv", index_col=0)
# response_3 = pd.read_csv("/sybig/home/tmu/TUGDA/data/cl_y_test_o_k3.csv", index_col=0)

response_1 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k1.csv", index_col=0)
response_2 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k2.csv", index_col=0)
response_3 = pd.read_csv("/Users/tm03/Desktop/TUGDA_1/data/cl_y_test_o_k3.csv", index_col=0)
response_data = pd.concat([response_1, response_2, response_3], axis=0, ignore_index=False)

# Sort both datasets by index to ensure alignment
expression_data = expression_data.sort_index()
response_data = response_data.sort_index()

# Remove duplicate indices (keep first occurrence) to avoid conflicts during merging
expression_data = expression_data[~expression_data.index.duplicated(keep='first')]
labels_df = response_data[~response_data.index.duplicated(keep='first')] 

print("\n" + "Done!")

###################
### GNN Dataset ###
###################

print("\n" + "="*33)
print("Step 1: GNN Dataset & Quick Check")
print("="*33)

class DrugNetworkDataset(Dataset):
    def __init__(self, root, drug_list, gene_list, pathway_list, labels_df, expression_data, transform=None, pre_transform=None):

        """
        A custom PyTorch Geometric Dataset for drug-cell line interaction graphs.

        Parameters:
            - root = Where the datase4t should be stored. This folder is split 
            - into raw_dir (downloaded datset) and processed_dir (processed data)

            - drug_list: List of drugs (tasks) (200)
            - gene_list: Lisz of genes sample (1780)
            - pathway_list: List of pathways from KEGG 
            - labels_df: response data (log_IC50)
            - expression_data: Gene expression values (preprocessed according to Mourragui et al. (2020): library-size using TMM, log-transformed, gene-level-mean-centering and standardization)
        """

        # Define all files that the dataset needs
        self.drug_list = drug_list
        self.gene_list = gene_list
        self.pathway_list = pathway_list
        self.labels_df = labels_df
        self.expression_data = expression_data

        # Get all combination of Drug + Cell Line
        self.samples = [
            (drug, cell_line) 
            for drug in self.drug_list 
            for cell_line in self.labels_df.index
        ]

        # Define custom raw_dir
        self.custom_raw_dir = os.path.join(root, 'drug_matrices_csv')

        super(DrugNetworkDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If these files exist in raw_dir, the download is not triggered """
        return [f"{drug}_matrix.csv" for drug in self.drug_list]

    @property
    def raw_dir(self):
        return self.custom_raw_dir

    @property
    def processed_file_names(self):
        """ If these files exist in processed_dir, processing is skipped """
        return ['placeholder.pt'] # not implemented 
        #  return [f'data_{i}.pt' for i in range(len(self.samples))]

    def download(self):
        pass 

    def _load_adjacency_matrix(self, drug):
        """ Loads the adjacency matrix for every drug """
        csv_path = os.path.join(self.raw_dir, f"{drug}_matrix.csv")
        df = pd.read_csv(csv_path, index_col=0)
        return df

    def _get_node_features(self, nodes, cell_line):
        x = []
        for node in nodes:
            if node in self.gene_list:
                 # If the node is a gene, get its expression value for the given cell line and type [expr, is_gene, is_pathway]
                expr_value = self.expression_data.loc[cell_line, node]
                x.append([float(expr_value), 1.0, 0.0])  # [expr, is_gene, is_pathway]
            elif node in self.pathway_list:
                # If the node is a pathway, use a default feature value of 0.0
                x.append([0.0, 0.0, 1.0])  # [expr_dummy, is_gene, is_pathway]
            else:
                # Unknown node
                x.append([0.0, 0.0, 0.0])  # [expr_dummy, is_gene, is_pathway]
        return torch.tensor(x, dtype=torch.float)

        ''' Code for just Gene expression value: 
        x = []
        for node in nodes:
             # If the node is a gene, get its expression value for the given cell line
            if node in self.gene_list:
                expr_value = self.expression_data.loc[cell_line, node]
                x.append([float(expr_value)]) # Append as a single-element list 
            elif node in self.pathway_list:
                # If the node is a pathway, use a default feature value of 0.0
                x.append([0.0])
            else:
                x.append([0.0])
        return torch.tensor(x, dtype=torch.float)
        ''' 
    
    # Qucik check fuction for debugging
    def get_node_name_by_index(self, idx, node_index):
        """ Returns the name of the node at the specified index for the given sample """
        drug, cell_line = self.samples[idx]
        
        adj_df = self._load_adjacency_matrix(drug)
        nodes = adj_df.columns.tolist()
        
        if node_index < len(nodes):
            return nodes[node_index]
        else:
            raise IndexError(f"Node index {node_index} out of range for this graph.")

    def len(self):
        """ 
        Returns the total number of samples (drug-cell line combinations)
        useful for classes such as datasets for machine learning so that they are compatible with len() and for loops 
        """
        return len(self.samples)

    def get(self, idx):
        """
        In standard PyG datasets, get() loads preprocessed data saved via process(); 
        here the preprocessing step is skipped and construct each drug–cell line graph directly in get(), 
        as preprocessing wouldn't reduce computation time
        """
        drug, cell_line = self.samples[idx]

        # Load adjacency matric for each drug 
        adj_df = self._load_adjacency_matrix(drug)
        nodes = adj_df.columns.tolist()
        adj_matrix = adj_df.values
        # Build edge_index in COO format
        edge_index = torch.tensor(np.array(np.nonzero(adj_matrix)), dtype=torch.long)
        # Get node_features
        x = self._get_node_features(nodes, cell_line)
        # Get label info (log_IC50 value for each drug-cell line combination)
        # y = torch.tensor([self.labels_df.loc[cell_line, drug]], dtype=torch.float)

        label_value = self.labels_df.loc[cell_line, drug]
        if pd.isna(label_value):
            return self.__getitem__((idx + 1) % len(self))  # nächstes Sample probieren
        
        y = torch.tensor([float(label_value)], dtype=torch.float)

        # Create data object
        data = Data(x=x, edge_index=edge_index, y=y, drug=drug, cell_line=cell_line)

        # Attach node names for debugging purposes
        data.nodes = nodes

        return data
    
    def process(self):
        pass # Skip right now, can be implemented to store the output files in your
        '''
        for i in range(len(self.samples)):
            data = self.get(i)
            torch.save(data, os.path.join(self.processed_dir, f'data_{i}.pt'))
        '''


Step 0: Download and Preprocess

Done!

Step 1: GNN Dataset & Quick Check


In [2]:
# Encoder 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

class GNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, heads, ratio):
        super().__init__()
        # self.conv = GATConv(in_channels, out_channels, heads=heads, dropout=0.6)
        # self.transform = nn.Linear(out_channels * heads, out_channels)
        self.conv = GCNConv(in_channels, out_channels)  # Kein heads-Parameter
        self.transform = nn.Linear(out_channels, out_channels)  # Ohne heads * out_channels
        self.pool = TopKPooling(out_channels, ratio=ratio)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index)
        x = F.relu(self.transform(x))
        x = self.bn(x)
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, None, batch)
        return x, edge_index, batch

class GNNEncoder(nn.Module):
    def __init__(self, feature_size=1, embedding_size=32, output_dim=16):
        super(GNNEncoder, self).__init__()

        self.block1 = GNNBlock(in_channels=feature_size, out_channels=embedding_size, heads=3, ratio=0.8)
        self.block2 = GNNBlock(in_channels=embedding_size, out_channels=embedding_size, heads=3, ratio=0.5)
        self.block3 = GNNBlock(in_channels=embedding_size, out_channels=embedding_size, heads=3, ratio=0.2)

        # Final projection
        self.linear1 = nn.Linear(embedding_size * 2, 128)
        self.linear2 = nn.Linear(128, output_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Block 1
        x, edge_index, batch = self.block1(x, edge_index, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Block 2
        x, edge_index, batch = self.block2(x, edge_index, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Block 3
        x, edge_index, batch = self.block3(x, edge_index, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Combine pooled outputs
        x = x1 + x2 + x3

        # Final projection
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)  # Optional: dropout p=0.8 → 0.5
        graph_embedding = self.linear2(x)

        return graph_embedding



In [31]:
# Hole einen echten Datensatz
data = dataset[0]  # Dataset[Index] gibt ein Data-Objekt zurück

encoder = GNNEncoder(feature_size=3, embedding_size=512, output_dim=256)

# Encoder aufrufen
encoder.eval()
with torch.no_grad():
    embedding = encoder(data)

print("Echtes Embedding Shape:", embedding.shape)
print(embedding)

Echtes Embedding Shape: torch.Size([1, 256])
tensor([[ 4.8114e-03, -3.0771e-02, -2.0618e-02,  1.1778e-03, -3.1502e-03,
          4.1565e-02,  3.6314e-02, -3.2110e-02,  2.9935e-02,  4.1352e-02,
         -3.3374e-02, -3.5360e-02,  4.2497e-02,  2.8199e-02,  2.7739e-02,
         -3.4205e-02,  2.1935e-02,  3.2177e-02, -3.1567e-03, -1.9492e-02,
          1.8535e-02,  5.3408e-03, -1.8785e-02,  8.4074e-03, -2.4234e-02,
         -2.9076e-02, -2.4892e-02, -2.5499e-02, -2.2481e-02,  5.0862e-02,
         -4.7569e-02, -1.3031e-02, -1.7625e-02,  3.3975e-02, -7.3230e-03,
          3.5741e-02,  4.4565e-02, -4.6033e-02, -3.7713e-02,  4.9750e-02,
         -4.0473e-02,  1.9499e-02, -5.4038e-02,  1.5889e-04,  3.6636e-02,
         -2.4025e-02,  3.3086e-02,  6.4988e-03, -3.4207e-02,  1.7666e-02,
         -2.0782e-03, -3.1399e-02,  1.6178e-02, -2.8136e-02,  1.9356e-02,
          3.5235e-02,  2.6857e-02,  6.6925e-03, -3.2557e-03,  1.5801e-02,
          2.1448e-02,  3.4143e-02, -3.1542e-02, -1.1062e-02, -3.622

# TUGDA

In [None]:
#get list of drugs to be trained and predicted
folder = 'data/'
drug_list = pd.read_csv('{}/cl_y_test_o_k1.csv'.format(folder), index_col=0 )
drug_list = drug_list.columns

#3-fold training and test data;
train_data_report = {}
test_data_report = {}

for k in range(1,4):
    train_data_report['x_k_fold{}'.format(k)] = pd.read_csv('{}/cl_x_train_o_k{}.csv'.format(folder, k), index_col=0)
    train_data_report['y_k_fold{}'.format(k)] = pd.read_csv('{}/cl_y_train_o_k{}.csv'.format(folder, k), index_col=0)
    
    test_data_report['x_k_fold{}'.format(k)] = pd.read_csv('{}/cl_x_test_o_k{}.csv'.format(folder, k), index_col=0)
    test_data_report['y_k_fold{}'.format(k)] = pd.read_csv('{}/cl_y_test_o_k{}.csv'.format(folder, k), index_col=0)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


In [5]:
from pytorch_lightning import Callback

class MetricsCallback(Callback):
    """PyTorch Lightning metric callback."""

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_test_end(self, trainer, pl_module):
        self.metrics.append(trainer.callback_metrics)

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [21]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch_geometric.data import Data


class tugda_mtl(pl.LightningModule):
    def __init__(self, params, drug_list):
        super(tugda_mtl, self).__init__()
        
        # Hyperparameter
        self.learning_rate = params['lr']
        self.mu = params['mu']           # L1 regularization
        self.lambda_ = params['lambda_'] # L2 regularization
        self.gamma = params['gamma']     # Autoencoder loss weight
        self.passes = params['passes']   # MC-Dropout simulations
        self.drug_list = drug_list
        self.num_tasks = len(drug_list)
        
        # Encoder laden
        feature_size = 3  # expr_value, is_gene, is_pathway
        embedding_size = params.get('hidden_units_1', 512)
        output_dim = params.get('latent_space', 256)
        
        self.gnn_encoder = GNNEncoder(feature_size=feature_size,
                                      embedding_size=embedding_size,
                                      output_dim=output_dim)

        # Task-specific Head (S-Matrix analog)
        self.S = nn.Linear(output_dim, self.num_tasks)

        # Decoder (Autoencoder Loss)
        self.A = nn.Sequential(
            nn.Linear(self.num_tasks, output_dim),
            nn.ReLU()
        )

        # Uncertainty Tracking
        self.log_vars = torch.zeros(self.num_tasks, requires_grad=True, device=self.device)

    def forward(self, data):
        graph_emb = self.gnn_encoder(data)
        preds = self.S(graph_emb)
        rec_emb = self.A(preds)
        return preds, graph_emb, rec_emb

    def configure_optimizers(self):
        params = list(self.parameters()) + [self.log_vars]
        optimizer = torch.optim.Adagrad(params, lr=self.learning_rate)
        return optimizer

    def _mse_ignore_nan(self, preds, labels, reduction='mean'):
        mse_loss = torch.nn.MSELoss(reduction='none')

        # Sorge dafür, dass preds und labels 2D sind
        if preds.dim() == 1:
            preds = preds.unsqueeze(1)  # [batch_size] -> [batch_size, 1]
        if labels.dim() == 1:
            labels = labels.unsqueeze(1)  # [batch_size] -> [batch_size, 1]

        per_task_loss = torch.zeros(labels.shape[1], device=self.device)

        for k in range(labels.shape[1]):
            precision = torch.exp(-self.log_vars[k])
            mask = ~torch.isnan(labels[:, k])
            if mask.sum() == 0:
                continue
            diff = mse_loss(preds[mask, k], labels[mask, k])
            per_task_loss[k] = torch.mean(precision * diff) + self.log_vars[k]

        if reduction == 'mean':
            return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss
        elif reduction == 'none':
            return per_task_loss

    def training_step(self, train_batch, batch_idx):
        data = train_batch

        preds_simulation = torch.zeros(data.y.shape[0], self.num_tasks, self.passes, device=self.device)
        for sim in range(self.passes):
            pred, h, h_hat = self(data)
            preds_simulation[:, :, sim] = pred

        preds_mean = preds_simulation.mean(dim=2)  # [batch_size, num_tasks]

        # Stelle sicher, dass y 2D ist
        if data.y.dim() == 1:
            data.y = data.y.unsqueeze(1)  # [batch_size] -> [batch_size, 1]

        # Loss Berechnungen
        local_loss, task_loss = self._mse_ignore_nan(preds_mean, data.y)
        recon_loss = self.gamma * torch.nn.MSELoss()(h, h_hat)

        # Berechne a: muss Shape [num_tasks] haben
        var_preds = preds_simulation.var(dim=2).mean(dim=0)  # [num_tasks]

        # Sicherstellen, dass A richtig skaliert ist
        weight_sum = torch.abs(self.A[0].weight).sum(dim=1)  # [out_features] → muss [num_tasks] sein!

        # 🚨 Falls A nicht die richtige Form hat (z. B. bei num_tasks=1), fixe sie manuell:
        if weight_sum.shape[0] != self.num_tasks:
            weight_sum = weight_sum.view(self.num_tasks)

        a = 1 + (var_preds + weight_sum)  # [num_tasks]

        # Validierungsmaske
        valid_mask = ~torch.isnan(task_loss)  # [num_tasks]

        # Prüfung: Shapes müssen übereinstimmen
        assert a.shape == task_loss.shape, f"a.shape {a.shape} != task_loss.shape {task_loss.shape}"

        print("preds_mean:", preds_mean.shape)
        print("data.y:", data.y.shape)
        print("var_preds:", var_preds.shape)
        print("weight_sum:", weight_sum.shape)
        print("a:", a.shape)
        print("task_loss:", task_loss.shape)

        weighted_loss = (a[valid_mask] * task_loss[valid_mask]).sum()

        l1_S = self.mu * self.S.weight.norm(1)
        l2_L = self.lambda_ * (
            self.gnn_encoder.linear1.weight.norm(2) +
            self.gnn_encoder.linear2.weight.norm(2)
        )

        total_loss = weighted_loss + recon_loss + l1_S + l2_L
        self.log("train_loss", total_loss)

        return {'loss': total_loss}

    def test_step(self, test_batch, batch_idx):
        data = test_batch

        # MC Dropout aktivieren
        self.gnn_encoder.train()
        self.S.train()
        self.A.train()

        preds_simulation = torch.zeros(
            data.y.shape[0], 
            self.num_tasks, 
            self.passes, 
            device=self.device
        )

        for sim in range(self.passes):
            with torch.no_grad():
                pred, _, _ = self(data)
                preds_simulation[:, :, sim] = pred

        preds_mean = preds_simulation.mean(dim=2)  # Shape: [batch_size, num_tasks]

        # Hole den Index des aktuellen Drugs für jede Probe im Batch
        drug_indices = torch.tensor(
            [self.drug_list.index(drug) for drug in data.drug], 
            device=self.device
        )

        # Wähle nur die Vorhersage des zugehörigen Drugs aus
        preds_task_specific = preds_mean[torch.arange(preds_mean.size(0)), drug_indices]  # Shape: [batch_size]
        
        # Labels extrahieren
        labels = data.y.squeeze()  # Shape: [batch_size]

        # Berechne MSE pro Sample
        per_sample_loss = F.mse_loss(preds_task_specific, labels, reduction='none')  # Shape: [batch_size]

        # Erstelle Tensor zur Zuordnung: sample -> task
        task_ids = drug_indices  # Shape: [batch_size]

        # Initialisiere leeren Tensor für Task-spezifische Verluste
        task_losses = torch.zeros(self.num_tasks, device=self.device)
        counts = torch.zeros(self.num_tasks, device=self.device)

        # Aggregiere Loss pro Task
        for i in range(self.num_tasks):
            mask = (task_ids == i)
            if mask.any():
                task_losses[i] = per_sample_loss[mask].mean()
                counts[i] = mask.sum()
            else:
                task_losses[i] = float('nan')

        return {
            'test_preds': preds_mean.cpu().numpy(),
            'test_task_losses_per_class': task_losses.cpu().numpy(),
            'test_task_counts': counts.cpu().numpy()
        }

In [6]:
#best set of hyperparamters found on this dataset setting (GDSC)
# net_params = {
#  #tunned hyperparameters
#  'hidden_units_1': 1024,
#  'latent_space': 700,
#  'lr': 0.001,
#  'dropout': 0.1,
#  'mu': 0.01,
#  'lambda_': 0.001,
#  'gamma': 0.0001,
#  'bs': 300,
#  'passes': 50,
#  'num_tasks': 200,
#  'epochs': 100}

net_params = {
    # Reduzierte Größen für CPU & Einzel-Drug Training
    'hidden_units_1': 64,      # embedding_size in GNNEncoder
    'latent_space': 32,        # output_dim in GNNEncoder
    'lr': 0.001,               # Lernrate (kann so bleiben)
    'dropout': 0.1,            # kann so bleiben
    'mu': 0.01,                # L1 Reg.
    'lambda_': 0.001,          # L2 Reg.
    'gamma': 0.0001,           # Autoencoder Loss weight
    'bs': 2,                   # Batch Size runter auf 2–4
    'passes': 10,              # MC Dropout Simulationen runter
    'num_tasks': 1,            # Nur eine Drug!
    'epochs': 50               # Für schnelles Debugging reichen 50 Epochen
}

In [22]:
import torch
import random
import numpy as np
from pytorch_lightning import Trainer
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold

# -------------------------------
# 1. Seed setzen (für Reproduzierbarkeit)
# -------------------------------
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------------
# 2. Liste für Fehler speichern
# -------------------------------
error_list = []
pcorr_list = []  # Falls du Pearson Korrelation später hinzufügst

dataset = DrugNetworkDataset(
    root="./results/Network/",
    drug_list=drug_list,
    gene_list=gene_list,
    pathway_list=pathway_list,
    labels_df=labels_df, 
    expression_data=expression_data
)

# -------------------------------
# 3. Cell Lines splitten (KFold)
# -------------------------------
all_samples = dataset.samples
all_cell_lines = [cl for _, cl in all_samples]
unique_cell_lines = list(set(all_cell_lines))

kf = KFold(n_splits=3, shuffle=True, random_state=42)

fold = 1
for train_cl_indices, val_cl_indices in kf.split(unique_cell_lines):

    # Hole Cell Lines
    train_cl = [unique_cell_lines[i] for i in train_cl_indices]
    val_cl = [unique_cell_lines[i] for i in val_cl_indices]

    # Finde gültige Indices im Dataset
    # train_indices = [i for i, (drug, cl) in enumerate(all_samples) if cl in train_cl]

    selected_drugs = ['Doxorubicin']
    train_indices = [i for i, (drug, cl) in enumerate(all_samples) if cl in train_cl and drug in selected_drugs]

    val_indices = [i for i, (drug, cl) in enumerate(all_samples) if cl in val_cl]

    print(f"Fold {fold}:")
    print(f"  Train: {len(train_cl)} unique cell lines -> {len(train_indices)} samples")
    print(f"  Val:   {len(val_cl)} unique cell lines -> {len(val_indices)} samples")

    # -------------------------------
    # 4. DataLoader erstellen
    # -------------------------------
    train_dataset = dataset[train_indices]
    val_dataset = dataset[val_indices]

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=2, num_workers=0)

    # -------------------------------
    # 5. Seed setzen & Modell erstellen
    # -------------------------------
    seed = 42
    seed_everything(seed)

    model = tugda_mtl(net_params, drug_list)

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,  # Funktioniert sowohl mit als auch ohne GPU
        log_every_n_steps=10,
        deterministic=True
    )

    # -------------------------------
    # 6. Training & Testing
    # -------------------------------
    trainer.fit(model, train_loader)
    results = trainer.test(model, val_loader)

    # -------------------------------
    # 7. Fehler pro Drug speichern
    # -------------------------------
    task_losses = results[0]['test_task_losses_per_class']  # shape: (num_tasks,)
    error_mtl_nn_results = np.concatenate((
        np.array(drug_list, ndmin=2).T,
        np.array(task_losses, ndmin=2).T
    ), axis=1)

    error_list.append(error_mtl_nn_results)
    fold += 1

Processing...
Done!
INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | gnn_encoder | GNNEncoder | 42.3 K | train
1 | S           | Linear     | 6.6 K  | train
2 | A           | Sequential | 6.4 K  | train
---------------------------------------------------
55.3 K    Trainable params
0         Non-trainable params
55.3 K    Total params
0.221     Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode
/sybig/hom

Fold 1:
  Train: 536 unique cell lines -> 536 samples
  Val:   268 unique cell lines -> 53600 samples
Epoch 0:   0%|          | 0/268 [00:00<?, ?it/s] 

RuntimeError: shape '[200]' is invalid for input of size 32