In [20]:
from torch.utils.data import DataLoader, SequentialSampler
from dataclass_paired_vanilla import PairedVanilla

# Test-Dataset erstellen (Pfad und Dictionaries müssen übereinstimmen)
precision = "allele"
embed_base_dir = f"../../data_10x/embeddings/paired/{precision}"


# Lade Dictionaries (falls benötigt)
traV_dict = {}  # Lade die Dictionaries entsprechend
traJ_dict = {}
trbV_dict = {}
trbJ_dict = {}
mhc_dict = {}

test_file_path = "../../data_10x/splitted_datasets/allele/paired/test.tsv"
test_dataset = PairedVanilla(test_file_path, embed_base_dir, traV_dict, traJ_dict, trbV_dict, trbJ_dict, mhc_dict)

# Test-Dataloader erstellen
SEQ_MAX_LENGTH = 200  # Passe dies an dein Training an
BATCH_SIZE = 1  # Evaluierung oft Batchgröße 1

def pad_collate(self, batch):
        epitope_embeddings, tra_cdr3_embeddings, trb_cdr3_embeddings = [], [], []
        v_alpha, j_alpha, v_beta, j_beta = [], [], [], []
        epitope_sequence, tra_cdr3_sequence, trb_cdr3_sequence = [], [], []
        mhc = []
        task = []
        labels = []

        for item in batch:
            epitope_embeddings.append(item["epitope_embedding"])
            epitope_sequence.append(item["epitope_sequence"])
            tra_cdr3_embeddings.append(item["tra_cdr3_embedding"])
            tra_cdr3_sequence.append(item["tra_cdr3_sequence"])
            trb_cdr3_embeddings.append(item["trb_cdr3_embedding"])
            trb_cdr3_sequence.append(item["trb_cdr3_sequence"])
            v_alpha.append(item["v_alpha"])
            j_alpha.append(item["j_alpha"])
            v_beta.append(item["v_beta"])
            j_beta.append(item["j_beta"])
            mhc.append(item["mhc"])
            task.append(item["task"])
            labels.append(item["label"])

        max_length = self.seq_max_length

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    sampler=SequentialSampler(test_dataset),
    num_workers=0,  # Keep as 0 for compatibility
    collate_fn=pad_collate
)

Unmapped TRA_CDR3 sequences: 0 []
Unmapped TRB_CDR3 sequences: 0 []
Unmapped Epitope sequences: 0 []


In [7]:
import os
print(os.path.abspath("./data_10x/splitted_datasets/allele/paired/test.tsv"))


/home/ubuntu/PA-Cancer-Immunotherapy-Transformer/BA_ZHAW/models/paired_vanilla/data_10x/splitted_datasets/allele/paired/test.tsv


In [29]:
import torch
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from vanilla_model import VanillaModel  # Modellstruktur importieren

# Parameter wie beim Training definieren
EMBEDDING_SIZE = 1024
SEQ_MAX_LENGTH = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 0

# Eingebettete Dimensionen
traV_embed_len = 118
traJ_embed_len = 113
trbV_embed_len = 132
trbJ_embed_len = 29
mhc_embed_len = 79

# Hyperparameter wie im Training
hyperparameters = {
    "optimizer": "sgd",
    "learning_rate": 5e-3,
    "weight_decay": 0.075,
    "dropout_attention": 0.3,
    "dropout_linear": 0.45,
}

# Modell erstellen
model = VanillaModel(
    EMBEDDING_SIZE, 
    SEQ_MAX_LENGTH, 
    DEVICE, 
    traV_embed_len, 
    traJ_embed_len, 
    trbV_embed_len, 
    trbJ_embed_len, 
    mhc_embed_len, 
    hyperparameters
)

# Modellgewichte laden
checkpoint = torch.load("VanillaModel.pth", map_location=DEVICE)
model.load_state_dict(checkpoint, strict=True)
model.eval()

print("Modell erfolgreich geladen mit SEQ_MAX_LENGTH = 30.")


  checkpoint = torch.load("VanillaModel.pth", map_location=DEVICE)


Modell erfolgreich geladen mit SEQ_MAX_LENGTH = 30.


In [60]:
def filter_nan(values, default_value=0):
    import math
    cleaned_values = []
    for v in values:
        if v is None or (isinstance(v, float) and math.isnan(v)):
            # Ersetze None oder NaN durch default_value
            cleaned_values.append(default_value)
        elif isinstance(v, (int, float)):
            try:
                # Nur gültige Zahlen in int konvertieren
                cleaned_values.append(int(v))
            except (ValueError, TypeError):
                print(f"Warnung: Wert {v} konnte nicht konvertiert werden. Ersetze durch {default_value}.")
                cleaned_values.append(default_value)
        else:
            print(f"Warnung: Unerwarteter Wert {v}. Ersetze durch {default_value}.")
            cleaned_values.append(default_value)
    return cleaned_values

print(f"Vor Filterung: v_alpha={v_alpha}, j_alpha={j_alpha}, v_beta={v_beta}, j_beta={j_beta}")
v_alpha = filter_nan(v_alpha)
j_alpha = filter_nan(j_alpha)
v_beta = filter_nan(v_beta)
j_beta = filter_nan(j_beta)
print(f"Nach Filterung: v_alpha={v_alpha}, j_alpha={j_alpha}, v_beta={v_beta}, j_beta={j_beta}")


Vor Filterung: v_alpha=tensor([0], device='cuda:0', dtype=torch.int32), j_alpha=tensor([0], device='cuda:0', dtype=torch.int32), v_beta=tensor([0], device='cuda:0', dtype=torch.int32), j_beta=tensor([0], device='cuda:0', dtype=torch.int32)
Warnung: Unerwarteter Wert 0. Ersetze durch 0.
Warnung: Unerwarteter Wert 0. Ersetze durch 0.
Warnung: Unerwarteter Wert 0. Ersetze durch 0.
Warnung: Unerwarteter Wert 0. Ersetze durch 0.
Nach Filterung: v_alpha=[0], j_alpha=[0], v_beta=[0], j_beta=[0]


In [66]:
def pad_collate(self, batch):
    epitope_embeddings, tra_cdr3_embeddings, trb_cdr3_embeddings = [], [], []
    v_alpha, j_alpha, v_beta, j_beta = [], [], [], []
    epitope_sequence, tra_cdr3_sequence, trb_cdr3_sequence = [], [], []
    mhc = []
    task = []
    labels = []

    for item in batch:
        epitope_embeddings.append(item["epitope_embedding"])
        epitope_sequence.append(item["epitope_sequence"])
        tra_cdr3_embeddings.append(item["tra_cdr3_embedding"])
        tra_cdr3_sequence.append(item["tra_cdr3_sequence"])
        trb_cdr3_embeddings.append(item["trb_cdr3_embedding"])
        trb_cdr3_sequence.append(item["trb_cdr3_sequence"])
        v_alpha.append(item["v_alpha"])
        j_alpha.append(item["j_alpha"])
        v_beta.append(item["v_beta"])
        j_beta.append(item["j_beta"])
        mhc.append(item["mhc"])
        task.append(item["task"])
        labels.append(item["label"])

    # Entferne ungültige Werte
    v_alpha = filter_nan(v_alpha)
    j_alpha = filter_nan(j_alpha)
    v_beta = filter_nan(v_beta)
    j_beta = filter_nan(j_beta)

    # Padding Embeddings
    def pad_embeddings(embeddings):
        return torch.stack([
            torch.nn.functional.pad(embedding, (0, 0, 0, self.seq_max_length - embedding.size(0)), "constant", 0)
            for embedding in embeddings
        ])

    epitope_embeddings = pad_embeddings(epitope_embeddings)
    tra_cdr3_embeddings = pad_embeddings(tra_cdr3_embeddings)
    trb_cdr3_embeddings = pad_embeddings(trb_cdr3_embeddings)

    v_alpha = torch.tensor(v_alpha, dtype=torch.int32)
    j_alpha = torch.tensor(j_alpha, dtype=torch.int32)
    v_beta = torch.tensor(v_beta, dtype=torch.int32)
    j_beta = torch.tensor(j_beta, dtype=torch.int32)
    mhc = torch.tensor(mhc, dtype=torch.int32)

    labels = torch.stack(labels)

    return {
        "epitope_embedding": epitope_embeddings,
        "epitope_sequence": epitope_sequence,
        "tra_cdr3_embedding": tra_cdr3_embeddings,
        "tra_cdr3_sequence": tra_cdr3_sequence,
        "trb_cdr3_embedding": trb_cdr3_embeddings,
        "trb_cdr3_sequence": trb_cdr3_sequence,
        "v_alpha": v_alpha,
        "j_alpha": j_alpha,
        "v_beta": v_beta,
        "j_beta": j_beta,
        "mhc": mhc,
        "task": task,
        "label": labels,
    }


In [67]:
test_batch = [{
    "epitope_embedding": torch.randn(SEQ_MAX_LENGTH, EMBEDDING_SIZE),
    "epitope_sequence": "DUMMY",
    "tra_cdr3_embedding": torch.randn(SEQ_MAX_LENGTH, EMBEDDING_SIZE),
    "tra_cdr3_sequence": "DUMMY",
    "trb_cdr3_embedding": torch.randn(SEQ_MAX_LENGTH, EMBEDDING_SIZE),
    "trb_cdr3_sequence": "DUMMY",
    "v_alpha": [None],  # Soll durch default_value ersetzt werden
    "j_alpha": [3],     # Bleibt int
    "v_beta": [float('NaN')],  # Soll durch default_value ersetzt werden
    "j_beta": [None],   # Soll durch default_value ersetzt werden
    "mhc": [1],         # Bleibt int
    "task": "TASK",
    "label": torch.tensor(1),
}]

# Testen Sie pad_collate
pad_collate_instance = PadCollate(SEQ_MAX_LENGTH)
processed_batch = pad_collate_instance.pad_collate(test_batch)

# Geben Sie die Ergebnisse aus
print("Processed Batch:")
for key, value in processed_batch.items():
    print(f"{key}: {value}")


Processed Batch:
epitope_embedding: tensor([[[-1.3231, -0.3799,  0.4517,  ..., -0.2787,  0.4963,  0.0502],
         [-0.1506,  0.0459,  0.2559,  ...,  0.0403, -0.7913,  0.2771],
         [ 0.8813,  0.0524,  1.0743,  ..., -0.0716, -1.0506, -0.4369],
         ...,
         [-0.6735,  0.6787,  0.1282,  ..., -0.6883,  0.2265,  1.0553],
         [ 1.1066, -1.2512, -2.7224,  ..., -0.0762, -1.4145, -1.7557],
         [ 0.1911,  0.1089, -1.9436,  ..., -0.2423, -1.3605,  1.8166]]])
epitope_sequence: ['DUMMY']
tra_cdr3_embedding: tensor([[[ 0.8789, -0.3927, -0.9342,  ..., -1.0972,  1.9187,  0.4476],
         [ 0.2759, -0.6735,  1.6230,  ...,  3.3177, -0.0772, -0.1254],
         [-0.6491, -0.8555,  0.7213,  ..., -0.4010, -0.3711, -0.8523],
         ...,
         [-1.0976,  1.0830, -1.7744,  ..., -2.0512,  0.5174,  1.6622],
         [-0.7100, -1.3833, -0.9068,  ...,  0.5175,  2.0412, -0.2838],
         [-1.8839,  0.3632,  1.0765,  ..., -0.2857, -0.5553, -2.6480]]])
tra_cdr3_sequence: ['DUMMY']
trb

In [68]:
# Funktion zur Ersetzung von NaN-Werten
def replace_nan_in_dataset(dataset, default_value=0):
    for idx, item in enumerate(dataset):
        for key in ['v_alpha', 'j_alpha', 'v_beta', 'j_beta', 'mhc']:
            if key in item:
                value = item[key]
                if value is None or (isinstance(value, float) and math.isnan(value)):
                    print(f"Ersetze NaN in Datensatz {idx}, Schlüssel '{key}' durch {default_value}")
                    item[key] = default_value
    return dataset

# Vor der Erstellung des DataLoaders
# Ersetze NaN-Werte in test_dataset
test_dataset = replace_nan_in_dataset(test_dataset)

# Test DataLoader
pad_collate_fn = PadCollate(SEQ_MAX_LENGTH).pad_collate
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    sampler=test_sampler,
    num_workers=NUM_WORKERS,
    collate_fn=pad_collate_fn
)

# Funktion zur Evaluierung
def evaluate_model_predictions(model, test_dataloader):
    model.eval()
    all_true_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in test_dataloader:
            inputs = {
                "epitope": batch["epitope_embedding"].to(DEVICE),
                "tra_cdr3": batch["tra_cdr3_embedding"].to(DEVICE),
                "trb_cdr3": batch["trb_cdr3_embedding"].to(DEVICE),
                "v_alpha": batch["v_alpha"].to(DEVICE),
                "j_alpha": batch["j_alpha"].to(DEVICE),
                "v_beta": batch["v_beta"].to(DEVICE),
                "j_beta": batch["j_beta"].to(DEVICE),
                "mhc": batch["mhc"].to(DEVICE),
            }
            true_labels = batch["label"].cpu().numpy()

            outputs = model(
                inputs["epitope"], 
                inputs["tra_cdr3"], 
                inputs["trb_cdr3"], 
                inputs["v_alpha"], 
                inputs["j_alpha"], 
                inputs["v_beta"], 
                inputs["j_beta"], 
                inputs["mhc"]
            ).squeeze(1)

            predictions = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
            all_true_labels.extend(true_labels)
            all_predictions.extend(predictions)

    all_true_labels = np.array(all_true_labels)
    all_predictions = np.array(all_predictions)

    conf_matrix = confusion_matrix(all_true_labels, all_predictions)
    accuracy = accuracy_score(all_true_labels, all_predictions)
    precision = precision_score(all_true_labels, all_predictions, zero_division=0)
    recall = recall_score(all_true_labels, all_predictions, zero_division=0)

    print("Confusion Matrix:")
    print(conf_matrix)
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")

    print("\nBeispielhafte Vorhersagen (erste 20):")
    print("True Labels:     ", all_true_labels[:20])
    print("Predicted Labels:", all_predictions[:20])

# Evaluierung durchführen
evaluate_model_predictions(model, test_dataloader)


Ersetze NaN in Datensatz 0, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 0, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 0, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 0, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 0, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 1, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 1, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 1, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 1, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 1, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 2, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 2, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 2, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 2, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 2, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 3, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 3, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 3, Schlüssel 'v_beta' durch 0
Ersetze NaN

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Ersetze NaN in Datensatz 17568, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 17568, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 17568, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 17568, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 17568, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 17569, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 17569, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 17569, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 17569, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 17569, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 17570, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 17570, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 17570, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 17570, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 17570, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 17571, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 17571, Schlüssel 'j_alpha' durch 

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Ersetze NaN in Datensatz 35329, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 35329, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 35329, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 35329, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 35329, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 35330, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 35330, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 35330, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 35330, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 35330, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 35331, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 35331, Schlüssel 'j_alpha' durch 0
Ersetze NaN in Datensatz 35331, Schlüssel 'v_beta' durch 0
Ersetze NaN in Datensatz 35331, Schlüssel 'j_beta' durch 0
Ersetze NaN in Datensatz 35331, Schlüssel 'mhc' durch 0
Ersetze NaN in Datensatz 35332, Schlüssel 'v_alpha' durch 0
Ersetze NaN in Datensatz 35332, Schlüssel 'j_alpha' durch 

ValueError: cannot convert float NaN to integer

In [26]:
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

# Funktion zur Auswertung des Modells
def evaluate_predictions(model):
    # Extrahiere gespeicherte Vorhersagen und Labels
    test_predictions = torch.stack(model.test_predictions).squeeze(1).cpu().numpy()
    test_labels = torch.stack(model.test_labels).squeeze(1).cpu().numpy()
    
    # Berechnung der binären Vorhersagen (Schwellenwert: 0.5)
    binary_predictions = (test_predictions > 0.5).astype(int)
    
    # Metriken berechnen
    conf_matrix = confusion_matrix(test_labels, binary_predictions)
    accuracy = accuracy_score(test_labels, binary_predictions)
    precision = precision_score(test_labels, binary_predictions, zero_division=0)
    recall = recall_score(test_labels, binary_predictions, zero_division=0)

    # Ergebnisse anzeigen
    print("Confusion Matrix:")
    print(conf_matrix)
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    
    # Vorhersagen und Labels anzeigen (erste 20)
    print("\nBeispielhafte Vorhersagen (erste 20):")
    print("True Labels:     ", test_labels[:20])
    print("Predicted Labels:", binary_predictions[:20])

# Beispielaufruf
evaluate_predictions(model)


RuntimeError: stack expects a non-empty TensorList

In [1]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

def evaluate_loaded_model(model, dataloader):
    true_labels = []
    predicted_labels = []

    with torch.no_grad():  # Keine Gradientenberechnung erforderlich
        for data in dataloader:
            inputs, labels = data  # Passe dies ggf. an deine Datenstruktur an
            outputs = model(inputs.to(DEVICE))
            _, predicted = torch.max(outputs, 1)

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    # Berechne Metriken
    conf_matrix = confusion_matrix(true_labels, predicted_labels)
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels, zero_division=0)
    recall = recall_score(true_labels, predicted_labels, zero_division=0)

    print("Confusion Matrix:")
    print(conf_matrix)
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")

# Beispiel für den Aufruf der Evaluierungsfunktion
# Annahme: test_dataloader wurde bereits erstellt, wie in deinem Trainingsskript
evaluate_loaded_model(model, test_dataloader)


NameError: name 'model' is not defined

In [None]:
# Importieren der notwendigen Bibliotheken
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, SequentialSampler
from vanilla_model import VanillaModel
from dataclass_paired_vanilla import PairedVanilla
import os

In [19]:
# Parameter setzen
EMBEDDING_SIZE = 1024
SEQ_MAX_LENGTH = 30  # Maximal erlaubte Sequenzlänge
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Setzen Sie den Pfad zu den Testdaten und Embeddings
precision = "allele"
data_dir = "../../data_10x/splitted_datasets/allele/paired/"
embed_base_dir = "../../data_10x/embeddings/paired/allele"

test_file_path = f"{data_dir}/test.tsv"

# Hilfsfunktion: Spalte in Dictionary umwandeln
def column_to_dictionary(df, column_name):
    list_of_column = df[column_name].unique()
    return {item: index for index, item in enumerate(list_of_column)}

# Hilfsfunktion: Länge der Einbettungen berechnen
def get_embed_len(df, column_name):
    return len(df[column_name].unique())

# Daten laden und Dictionaries erstellen
df_test = pd.read_csv(test_file_path, sep="\t")
traV_dict = column_to_dictionary(df_test, "TRAV")
traJ_dict = column_to_dictionary(df_test, "TRAJ")
trbV_dict = column_to_dictionary(df_test, "TRBV")
trbJ_dict = column_to_dictionary(df_test, "TRBJ")
mhc_dict = column_to_dictionary(df_test, "MHC")

traV_embed_len = get_embed_len(df_test, "TRAV")
traJ_embed_len = get_embed_len(df_test, "TRAJ")
trbV_embed_len = get_embed_len(df_test, "TRBV")
trbJ_embed_len = get_embed_len(df_test, "TRBJ")
mhc_embed_len = get_embed_len(df_test, "MHC")

# Test-Dataset erstellen
test_dataset = PairedVanilla(test_file_path, embed_base_dir, traV_dict, traJ_dict, trbV_dict, trbJ_dict, mhc_dict)
print("Eintrag im Dataset:", test_dataset[0])

# Test-Dataloader erstellen
class PadCollate:
    def __init__(self, seq_max_length):
        self.seq_max_length = seq_max_length

    def pad_collate(self, batch):
        epitope_embeddings, tra_cdr3_embeddings, trb_cdr3_embeddings = [], [], []
        v_alpha, j_alpha, v_beta, j_beta = [], [], [], []
        epitope_sequence, tra_cdr3_sequence, trb_cdr3_sequence = [], [], []
        mhc, task, labels = [], [], []

        for item in batch:
            epitope_embeddings.append(item["epitope_embedding"])
            epitope_sequence.append(item["epitope_sequence"])
            tra_cdr3_embeddings.append(item["tra_cdr3_embedding"])
            tra_cdr3_sequence.append(item["tra_cdr3_sequence"])
            trb_cdr3_embeddings.append(item["trb_cdr3_embedding"])
            trb_cdr3_sequence.append(item["trb_cdr3_sequence"])
            v_alpha.append(item["v_alpha"])
            j_alpha.append(item["j_alpha"])
            v_beta.append(item["v_beta"])
            j_beta.append(item["j_beta"])
            mhc.append(item["mhc"])
            task.append(item["task"])

            # Labels sammeln
            if "Binding" in item:
                labels.append(int(item["Binding"]))  # Binding-Spalte verwenden
            else:
                print("Warnung: 'Binding' fehlt im Batch-Item:", item)

        max_length = self.seq_max_length

        # Debugging für Labels
        print("Labels Debugging:")
        print(labels)  # Zeigt die gesammelten Labels
        print("Label-Typ:", type(labels))  # Überprüfen Sie den Typ von `labels`

        # Konvertieren Sie `labels` zu einem Tensor
        try:
            labels = torch.tensor(labels, dtype=torch.int64)
        except Exception as e:
            print("Fehler bei der Konvertierung von Labels:", e)
            raise

        def pad_embeddings(embeddings):
            return torch.stack([
                torch.nn.functional.pad(embedding, (0, 0, 0, max_length - embedding.size(0)), "constant", 0)
                for embedding in embeddings
            ])

        epitope_embeddings = pad_embeddings(epitope_embeddings)
        tra_cdr3_embeddings = pad_embeddings(tra_cdr3_embeddings)
        trb_cdr3_embeddings = pad_embeddings(trb_cdr3_embeddings)

        v_alpha = torch.tensor(v_alpha, dtype=torch.int32)
        j_alpha = torch.tensor(j_alpha, dtype=torch.int32)
        v_beta = torch.tensor(v_beta, dtype=torch.int32)
        j_beta = torch.tensor(j_beta, dtype=torch.int32)
        mhc = torch.tensor(mhc, dtype=torch.int32)

        return {
            "epitope_embedding": epitope_embeddings,
            "tra_cdr3_embedding": tra_cdr3_embeddings,
            "trb_cdr3_embedding": trb_cdr3_embeddings,
            "v_alpha": v_alpha,
            "j_alpha": j_alpha,
            "v_beta": v_beta,
            "j_beta": j_beta,
            "mhc": mhc,
            "label": labels  # bleibt für Konsistenz in nachfolgendem Code
        }




pad_collate = PadCollate(SEQ_MAX_LENGTH).pad_collate

test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    sampler=test_sampler,
    collate_fn=pad_collate
)

hyperparameters = {}
hyperparameters["optimizer"] = "sgd"  # oder "adam", je nach Bedarf
hyperparameters["learning_rate"] = 5e-3
hyperparameters["weight_decay"] = 0.075
hyperparameters["dropout_attention"] = 0.3
hyperparameters["dropout_linear"] = 0.45

checkpoint_path = "VanillaModel.pth"
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")


# Modell erstellen und Gewichte laden
model = VanillaModel(EMBEDDING_SIZE, SEQ_MAX_LENGTH, DEVICE, traV_embed_len, traJ_embed_len, trbV_embed_len, trbJ_embed_len, mhc_embed_len, hyperparameters)
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
filtered_state_dict = {k: v for k, v in checkpoint.items() if k in model.state_dict() and model.state_dict()[k].size() == v.size()}
model.load_state_dict(filtered_state_dict, strict=False)
model.eval()
print("Model successfully loaded and set to evaluation mode.")

# Vorhersagen sammeln
predictions, labels = [], []
print("Labels before conversion:", labels)
print("Type of labels:", type(labels))

with torch.no_grad():
    for batch in test_dataloader:
        epitope_embedding = batch["epitope_embedding"].to(DEVICE)
        tra_cdr3_embedding = batch["tra_cdr3_embedding"].to(DEVICE)
        trb_cdr3_embedding = batch["trb_cdr3_embedding"].to(DEVICE)
        v_alpha = batch["v_alpha"].to(DEVICE)
        j_alpha = batch["j_alpha"].to(DEVICE)
        v_beta = batch["v_beta"].to(DEVICE)
        j_beta = batch["j_beta"].to(DEVICE)
        mhc = batch["mhc"].to(DEVICE)
        label = batch["label"].to(DEVICE)

        output = model(epitope_embedding, tra_cdr3_embedding, trb_cdr3_embedding, v_alpha, j_alpha, v_beta, j_beta, mhc)
        probs = torch.sigmoid(output)

        predictions.extend(probs.cpu().numpy())
        labels.extend(label.cpu().numpy())

# Visualisierung der Ergebnisse
plt.figure(figsize=(10, 6))
plt.scatter(range(len(predictions)), predictions, label="Predictions", alpha=0.6)
plt.scatter(range(len(labels)), labels, label="True Labels", alpha=0.6)
plt.xlabel("Samples")
plt.ylabel("Probability / Label")
plt.title("Predictions vs. True Labels")
plt.legend()
plt.show()

# Histogramm der Vorhersagen
plt.figure(figsize=(10, 6))
plt.hist([p for p, l in zip(predictions, labels) if l == 1], bins=20, alpha=0.5, label="Positive Class Predictions")
plt.hist([p for p, l in zip(predictions, labels) if l == 0], bins=20, alpha=0.5, label="Negative Class Predictions")
plt.xlabel("Predicted Probability")
plt.ylabel("Frequency")
plt.title("Prediction Distribution")
plt.legend()
plt.show()

# Ergebnisse speichern
results_df = pd.DataFrame({"Prediction": predictions, "Label": labels})
results_df.to_csv("prediction_results.csv", index=False)

from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score

roc_auc = roc_auc_score(labels, predictions)
ap_score = average_precision_score(labels, predictions)
accuracy = accuracy_score(labels, predicted_classes)

print(f"ROC-AUC: {roc_auc:.4f}, Average Precision: {ap_score:.4f}, Accuracy: {accuracy:.4f}")

# Speichern Sie die Metriken in einer Datei
metrics_df = pd.DataFrame({"Metric": ["ROC-AUC", "AP", "Accuracy"], "Value": [roc_auc, ap_score, accuracy]})
metrics_df.to_csv("evaluation_metrics.csv", index=False)


Unmapped TRA_CDR3 sequences: 0 []
Unmapped TRB_CDR3 sequences: 0 []
Unmapped Epitope sequences: 0 []
Eintrag im Dataset: {'epitope_embedding': tensor([[ 0.2024, -0.1784, -0.1524,  ..., -0.0598, -0.0730,  0.0883],
        [ 0.0946, -0.0478, -0.2511,  ...,  0.0762, -0.0088,  0.0973],
        [ 0.2836, -0.0068, -0.0652,  ..., -0.0279,  0.0056, -0.1149],
        ...,
        [ 0.0949,  0.0307,  0.0768,  ..., -0.0136, -0.0991, -0.0265],
        [ 0.1097, -0.0155, -0.0418,  ...,  0.2618, -0.0979,  0.0138],
        [ 0.0736, -0.0596,  0.1575,  ...,  0.1424, -0.1303, -0.1070]]), 'epitope_sequence': 'TTDPSFLGRY', 'tra_cdr3_embedding': tensor([[ 0.1556, -0.0179, -0.3081,  ...,  0.1541, -0.0970, -0.0399],
        [ 0.1623, -0.2130, -0.1466,  ..., -0.1051,  0.2095,  0.1247],
        [ 0.2351, -0.0936, -0.0034,  ...,  0.2117,  0.1583,  0.1143],
        ...,
        [ 0.0273,  0.2817, -0.1794,  ...,  0.0676, -0.1381, -0.0646],
        [-0.0170,  0.1087, -0.0216,  ...,  0.2675, -0.1181, -0.2048],
   

  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


Model successfully loaded and set to evaluation mode.
Labels before conversion: []
Type of labels: <class 'list'>
Warnung: 'Binding' fehlt im Batch-Item: {'epitope_embedding': tensor([[ 0.2024, -0.1784, -0.1524,  ..., -0.0598, -0.0730,  0.0883],
        [ 0.0946, -0.0478, -0.2511,  ...,  0.0762, -0.0088,  0.0973],
        [ 0.2836, -0.0068, -0.0652,  ..., -0.0279,  0.0056, -0.1149],
        ...,
        [ 0.0949,  0.0307,  0.0768,  ..., -0.0136, -0.0991, -0.0265],
        [ 0.1097, -0.0155, -0.0418,  ...,  0.2618, -0.0979,  0.0138],
        [ 0.0736, -0.0596,  0.1575,  ...,  0.1424, -0.1303, -0.1070]]), 'epitope_sequence': 'TTDPSFLGRY', 'tra_cdr3_embedding': tensor([[ 0.1556, -0.0179, -0.3081,  ...,  0.1541, -0.0970, -0.0399],
        [ 0.1623, -0.2130, -0.1466,  ..., -0.1051,  0.2095,  0.1247],
        [ 0.2351, -0.0936, -0.0034,  ...,  0.2117,  0.1583,  0.1143],
        ...,
        [ 0.0273,  0.2817, -0.1794,  ...,  0.0676, -0.1381, -0.0646],
        [-0.0170,  0.1087, -0.0216,  ...

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)