In [1]:
import torch.nn as nn
from tqdm import tqdm
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, auc
import wandb
import pandas as pd
from sklearn.model_selection import KFold

import pickle
import timeit
import os
import random
import glob
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

from sklearn.metrics import roc_curve, auc


class ProteinProteinInteractionPrediction(nn.Module):
    def __init__(self,mod_embed,prot_embed,dim=20,layer_gnn=2):
        super(ProteinProteinInteractionPrediction, self).__init__()
        self.prot_embed_fingerprint = nn.Embedding(n_fingerprint, dim)
        self.mod_embed_fingerprint = nn.Embedding(nmod_fingerprint, dim)
        self.mod_W_gnn = nn.ModuleList([nn.Linear(dim, dim) for _ in range(layer_gnn)])
        self.prot_gnn = nn.ModuleList([nn.Linear(dim, dim) for _ in range(layer_gnn)])
        self.prot_W1_attention = nn.Linear(dim, dim)
        self.prot_W2_attention = nn.Linear(dim, dim)
        self.prot_w = nn.Parameter(torch.zeros(dim))
        self.W1_attention = nn.Linear(dim, dim)
        self.W2_attention = nn.Linear(2*dim, dim)  # Modified to accept concatenated protein vector
        self.w = nn.Parameter(torch.zeros(dim))
        self.W_out = nn.Sequential(
            nn.Linear(2*dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64,16),
            nn.ReLU(),
            nn.Linear(16, 2)
        )
        
        self.mod_embed = mod_embed
        self.prot_embed = prot_embed

        self.rdkit_linear = nn.Linear(2048, dim)

    def gnn(self, xs1, A1, type):
        for i in range(layer_gnn):
            if type == "mod":
                hs1 = torch.relu(self.mod_W_gnn[i](xs1))
            elif type == "prot":
                hs1 = torch.relu(self.prot_gnn[i](xs1))
            xs1 = torch.matmul(A1, hs1)

        return xs1
    
    def mutual_attention(self, h1, h2):
        x1 = self.W1_attention(h1)
        x2 = self.W2_attention(h2)

        m1 = x1.size()[0]
        m2 = x2.size()[0]

        c1 = x1.repeat(1, m2).view(m1, m2, dim)
        c2 = x2.repeat(m1, 1).view(m1, m2, dim)

        d = torch.tanh(c1 + c2)
        alpha = torch.matmul(d, self.w).view(m1, m2)

        b1 = torch.mean(alpha, 1)
        p1 = torch.softmax(b1, 0)
        s1 = torch.matmul(torch.t(x1), p1).view(-1, 1)

        b2 = torch.mean(alpha, 0)
        p2 = torch.softmax(b2, 0)
        s2 = torch.matmul(torch.t(x2), p2).view(-1, 1)

        return torch.cat((s1, s2), 0).view(1, -1), p1, p2
    def prot_mutual_attention(self, h1, h2):
        x1 = self.prot_W1_attention(h1)
        x2 = self.prot_W2_attention(h2)

        m1 = x1.size()[0]
        m2 = x2.size()[0]

        c1 = x1.repeat(1, m2).view(m1, m2, dim)
        c2 = x2.repeat(m1, 1).view(m1, m2, dim)

        d = torch.tanh(c1 + c2)
        alpha = torch.matmul(d, self.prot_w).view(m1, m2)

        b1 = torch.mean(alpha, 1)
        p1 = torch.softmax(b1, 0)
        s1 = torch.matmul(torch.t(x1), p1).view(-1, 1)

        b2 = torch.mean(alpha, 0)
        p2 = torch.softmax(b2, 0)
        s2 = torch.matmul(torch.t(x2), p2).view(-1, 1)

        return torch.cat((s1, s2), 0).view(1, -1), p1, p2

    def forward(self, inputs, train = True):
        fingerprints1, adjacency1, fingerprints2, adjacency2, fingerprints3, adjacency3, smiles = inputs

        """Protein vector with GNN."""
        x_fingerprints1 = self.mod_embed_fingerprint(fingerprints1)
        x_fingerprints2 = self.prot_embed_fingerprint(fingerprints2)
        x_fingerprints3 = self.prot_embed_fingerprint(fingerprints3)
        if self.mod_embed == "gnn":
            x_mod = self.gnn(x_fingerprints1, adjacency1, "mod")
        elif self.mod_embed == "rdkit":
            # Implement RDKit molecular fingerprinting
            from rdkit import Chem
            from rdkit.Chem import rdFingerprintGenerator

            # Convert SMILES to RDKit mol object
            mol = Chem.MolFromSmiles(smiles)
            # Generate Morgan fingerprint
            morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=2048)
            morgan_fp = morgan_gen.GetFingerprint(mol)
            # Convert bit vector to PyTorch tensor
            x_mod = torch.tensor([int(b) for b in morgan_fp.ToBitString()]).float().view(1, -1)
            x_mod = self.rdkit_linear(x_mod)            
        elif self.mod_embed == "graphmvp":
            x_mod = self.graphmvp_embed(fingerprints1)
        elif self.mod_embed == "infograph":
            x_mod = self.infograph_embed(fingerprints1)

        if self.prot_embed == "gnn":
            x_protein2 = self.gnn(x_fingerprints2, adjacency2, "prot")
            x_protein3 = self.gnn(x_fingerprints3, adjacency3, "prot")
        # x_mod = N, dim

        """Concatenate protein vectors"""
        x_proteins, p1, p2 = self.prot_mutual_attention(x_protein2, x_protein3)
        # print(f"shape of x_proteins: {x_proteins.shape}")
        # print(f"shape of x_protein2: {x_protein2.shape}")
        # print(f"shape of x_protein3: {x_protein3.shape}")
        """Protein vector with mutual-attention."""
        y, p1, p2 = self.mutual_attention(x_mod, x_proteins)

        z_interaction = self.W_out(y)
        # ADD CHEMICAL FEATURE VECTOR HERE 
        return z_interaction, p1, p2

    def __call__(self, data, train=True):
        inputs, t_interaction = data[:-1], data[-1]
        z_interaction, p1, p2 = self.forward(inputs, train)
        if train:
            loss = F.cross_entropy(z_interaction, t_interaction)
            return loss
        else:
            z = F.softmax(z_interaction, 1).to("cpu").data[0].numpy()
            t = int(t_interaction.to("cpu").data[0].numpy())
            return z, t, p1, p2

In [2]:

class Trainer(object):
    def __init__(self, model, lr = 1e-4, train_size = 3000):
        self.model = model
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.train_size = train_size

    def train(self, dataset):
        self.model.train()  # Set model to training mode
        sampling = random.choices(dataset, k=self.train_size)
        loss_total = 0
        for data in tqdm(sampling):
            try:
                mod, p1, p2, interaction, family = data
                A1 = np.load(
                        prot_data[prot_data['protein'] == p1].iloc[0]['adj_path'],
                        allow_pickle=True
                    )
                A2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['adj_path'],
                    allow_pickle=True
                )

                P1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                P2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_fp = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_adj = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                smiles = mod_data[mod_data['inchikey'] == mod].iloc[0]['smiles']
            except Exception as e:
                # print(f"failed for {e}")
                continue
            protein1 = torch.LongTensor(P1.astype(np.float32))
            protein2 = torch.LongTensor(P2.astype(np.float32))
            adjacency1 = torch.FloatTensor(A1.astype(np.float32))
            adjacency2 = torch.FloatTensor(A2.astype(np.float32))
            mod_fp = torch.LongTensor(mod_fp.astype(np.float32))
            mod_adj = torch.FloatTensor(mod_adj.astype(np.float32))
            interaction = torch.LongTensor([interaction.astype(int)])

            # comb = (protein1.to(device), adjacency1.to(device), protein2.to(device), adjacency2.to(device), interaction.to(device))
            # with torch.no_grad():
            # prot_loss, prot_embed = self.prot_model(comb, train=True)
            
            comb = (mod_fp.to(device), mod_adj.to(device), 
                    protein1.to(device), adjacency1.to(device), 
                    protein2.to(device), adjacency2.to(device), smiles,
                    interaction.to(device))
            loss = self.model(comb)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_total += loss.item()
        return loss_total / len(sampling)


In [3]:


class Tester(object):
    def __init__(self, model, test_size = 1000, wandb_project_name=False):
        self.model = model
        self.test_size = test_size

        # Initialize wandb if a project name is provided
        # if wandb_project_name:
        #     wandb.init(project=wandb_project_name)
        #     self.use_wandb = True
        # else:
        #     self.use_wandb = False

    def test(self, dataset, epoch=None):
        sampling = random.choices(dataset, k=self.test_size)
        z_list, t_list = [], []
        for data in tqdm(sampling):
            try:
                mod, p1, p2, interaction, _ = data
                A1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                A2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                P1 = np.load(
                    prot_data[prot_data['protein'] == p1].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                P2 = np.load(
                    prot_data[prot_data['protein'] == p2].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_fp = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['fp_path'],
                    allow_pickle=True
                )
                mod_adj = np.load(
                    mod_data[mod_data['inchikey'] == mod].iloc[0]['adj_path'],
                    allow_pickle=True
                )
                smiles = mod_data[mod_data['inchikey'] == mod].iloc[0]['smiles']
            except Exception as e:
                # print(f"failed for {mod}, {p1}, and {p2}: {e}")
                continue

            protein1 = torch.LongTensor(P1.astype(np.float32))
            protein2 = torch.LongTensor(P2.astype(np.float32))
            adjacency1 = torch.FloatTensor(A1.astype(np.float32))
            adjacency2 = torch.FloatTensor(A2.astype(np.float32))
            mod_fp = torch.LongTensor(mod_fp.astype(np.float32))
            mod_adj = torch.FloatTensor(mod_adj.astype(np.float32))
            interaction = torch.LongTensor([interaction.astype(int)])

            # comb = (protein1.to(device), adjacency1.to(device), protein2.to(device), adjacency2.to(device), interaction.to(device))
            # with torch.no_grad():
            # prot_loss, prot_embed = self.prot_model(comb, train=True)
            
            comb = (mod_fp.to(device), mod_adj.to(device), 
                    protein1.to(device), adjacency1.to(device), 
                    protein2.to(device), adjacency2.to(device), smiles,
                    interaction.to(device))
            z, _, _, _ = self.model(comb, train=False)
            # print(z,interaction)
            # print(z,torch.argmax(torch.FloatTensor(z)).item())
            z_list.append(z)
            t_list.append(interaction)

        score_list, label_list = [], []
        for z in z_list:
            score_list.append(z[1].item())
            label_list.append(torch.argmax(torch.FloatTensor(z)).item())

        labels = np.array(label_list)
        y_true = np.array([t.item() for t in t_list])
        y_pred = np.array(score_list)

        (
            tp,
            fp,
            tn,
            fn,
            accuracy,
            precision,
            sensitivity,
            recall,
            specificity,
            MCC,
            F1_score,
            Q9,
            ppv,
            npv,
        ) = self.calculate_performace(len(sampling), labels, y_true)
        roc_auc_val = roc_auc_score(y_true, y_pred)
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        auc_val = auc(fpr, tpr)

        # Log results to wandb
        # if self.use_wandb:
        #     wandb.log({
        #         "Epoch": epoch,
        #         "Accuracy": accuracy,
        #         "Precision": precision,
        #         "Recall": recall,
        #         "Sensitivity": sensitivity,
        #         "Specificity": specificity,
        #         "MCC": MCC,
        #         "F1 Score": F1_score,
        #         "ROC AUC": roc_auc_val,
        #         "AUC": auc_val,
        #         "TP": tp,
        #         "FP": fp,
        #         "TN": tn,
        #         "FN": fn,
        #         "PPV": ppv,
        #         "NPV": npv,
        #     })

        return (
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
        )

    def result(
        self,
        epoch,
        time,
        loss,
        accuracy,
        precision,
        recall,
        sensitivity,
        specificity,
        MCC,
        F1_score,
        roc_auc_val,
        auc_val,
        Q9,
        ppv,
        npv,
        tp,
        fp,
        tn,
        fn,
        file_name,
    ):
        import os

        # Ensure the directory exists
        os.makedirs(os.path.dirname(file_name), exist_ok=True)

        # Open the file in append mode, creating it if it doesn't exist
        with open(file_name, "a+") as f:
            result = map(
                str,
                [
                    epoch,
                    time,
                    loss,
                    accuracy,
                    precision,
                    recall,
                    sensitivity,
                    specificity,
                    MCC,
                    F1_score,
                    roc_auc_val,
                    auc_val,
                    Q9,
                    ppv,
                    npv,
                    tp,
                    fp,
                    tn,
                    fn,
                ],
            )
            f.write("\t".join(result) + "\n")

    def calculate_performace(self, test_num, pred_y, labels):
        tp = 0
        fp = 0
        tn = 0
        fn = 0
        test_num = len(pred_y)
        for index in range(test_num):
            if labels[index] == 1:
                if labels[index] == pred_y[index]:
                    tp = tp + 1
                else:
                    fn = fn + 1
            else:
                if labels[index] == pred_y[index]:
                    tn = tn + 1
                else:
                    fp = fp + 1

        if (tp + fn) == 0:
            q9 = float(tn - fp) / (tn + fp + 1e-06)
        if (tn + fp) == 0:
            q9 = float(tp - fn) / (tp + fn + 1e-06)
        if (tp + fn) != 0 and (tn + fp) != 0:
            q9 = 1 - float(np.sqrt(2)) * np.sqrt(float(fn * fn) / ((tp + fn) * (tp + fn)) + float(fp * fp) / ((tn + fp) * (tn + fp)))

        Q9 = (float)(1 + q9) / 2
        accuracy = float(tp + tn) / test_num
        precision = float(tp) / (tp + fp + 1e-06)
        sensitivity = float(tp) / (tp + fn + 1e-06)
        recall = float(tp) / (tp + fn + 1e-06)
        specificity = float(tn) / (tn + fp + 1e-06)
        ppv = float(tp) / (tp + fp + 1e-06)
        npv = float(tn) / (tn + fn + 1e-06)
        F1_score = float(2 * tp) / (2 * tp + fp + fn + 1e-06)
        MCC = float(tp * tn - fp * fn) / (np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))

        return tp, fp, tn, fn, accuracy, precision, sensitivity, recall, specificity, MCC, F1_score, Q9, ppv, npv

    def save_model(self, model, file_name):
        torch.save(model.state_dict(), file_name)

In [4]:
z = torch.FloatTensor(np.array([0.1, 0.2], dtype=np.float32))
print(z, torch.argmax(z,dim=0))

tensor([0.1000, 0.2000]) tensor(1)


In [5]:

# get data
prot_data = pd.read_csv("prot_data.csv")
mod_data = pd.read_csv("mod_data.csv")
train_data = pd.read_csv("interaction_data.csv")
examples = np.array(train_data.values.tolist())
# examples = np.array(random.choices(examples_all, k=500))

# setup folders
prot_fp_folder = "protein_fingerprints"
mod_fp_folder = "mod_fingerprints"
prot_fp_dict = np.load("protein_fingerprints/prot_fingerprint_dict.pickle",allow_pickle=True)
mod_fp_dict = np.load("mod_fingerprints/mod_fingerprint_dict.pickle",allow_pickle=True)


n_fingerprint = len(prot_fp_dict) + 100
nmod_fingerprint = len(mod_fp_dict) + 100


### Hyperparameters ###

radius         = 1
dim        = 100
layer_gnn      = 2
lr             = 1e-4
lr_decay       = 0.5
decay_interval = 10
iteration      = 30

import wandb
import timeit
import torch

fold_count = 2
### Check if GPU is available ###
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Define k-folds ###
num_kfolds = 5
kfold      = KFold(n_splits=num_kfolds, shuffle=True, random_state=1)  # Updated


In [6]:
train_size = 3000
test_size = 1000
log_wandb = True

for train, test in kfold.split(examples):
    dataset_train = examples[train]  # mod, prot1, prot2, int, int_family
    dataset_test = examples[test]

    start = timeit.default_timer()

    model = ProteinProteinInteractionPrediction(dim=dim, layer_gnn=layer_gnn, mod_embed="rdkit", prot_embed = "gnn").to(device)
    trainer = Trainer(model, lr = lr, train_size=train_size, )
    file_model = "ppim/model/" + "rdkit_fold_" + str(fold_count)
    file_result = "ppim/result/" + "rdkit_fold_" + str(fold_count) + ".txt"
    if log_wandb:
        wandb.init(project="promisegat4")
        # Log file paths to wandb
        wandb.config.update({
            "fold": fold_count,
            "model_path": file_model,
            "result_path": file_result,
            "protein_embed": "gcn",
            "mod_embed": "gcn",
            "radius": radius,
            "dim": dim,
            "layer_gnn": layer_gnn,
            "lr": lr,
            "lr_decay": lr_decay,
            "decay_interval": decay_interval,
            "iteration": iteration
        })

    for epoch in range(iteration):
        if (epoch + 1) % decay_interval == 0:
            trainer.optimizer.param_groups[0]["lr"] *= lr_decay

        loss = trainer.train(dataset_train)
        print(f"finished with loss {loss}")

        # Log training loss and GPU usage to wandb
        gpu_usage = torch.cuda.memory_allocated(device=device) / 1024 ** 3  # in GB
        if log_wandb:
            wandb.log({"epoch": epoch, "loss": loss, "gpu_usage_gb": gpu_usage, "fold": fold_count})

        tester = Tester(model, test_size = test_size)
        (
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
        ) = tester.test(dataset_test, epoch=epoch)
        end = timeit.default_timer()
        time = end - start

        # Log results to wandb
        if log_wandb:
            wandb.log({
            "epoch": epoch,
            "time": time,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "sensitivity": sensitivity,
            "specificity": specificity,
            "MCC": MCC,
            "F1_score": F1_score,
            "ROC_AUC": roc_auc_val,
            "AUC": auc_val,
            "Q9": Q9,
            "PPV": ppv,
            "NPV": npv,
            "TP": tp,
            "FP": fp,
            "TN": tn,
            "FN": fn,
            "fold": fold_count
        })

        tester.result(
            epoch,
            time,
            loss,
            accuracy,
            precision,
            recall,
            sensitivity,
            specificity,
            MCC,
            F1_score,
            roc_auc_val,
            auc_val,
            Q9,
            ppv,
            npv,
            tp,
            fp,
            tn,
            fn,
            file_result,
        )
        tester.save_model(model, file_model)


        print("Epoch: " + str(epoch))
        print("Accuracy: " + str(accuracy))
        print("Precision: " + str(precision))
        print("Recall: " + str(recall))
        print("Sensitivity: " + str(sensitivity))
        print("Specificity: " + str(specificity))
        print("MCC: " + str(MCC))
        print("F1-score: " + str(F1_score))
        print("ROC-AUC: " + str(roc_auc_val))
        print("AUC: " + str(auc_val))
        print("Q9: " + str(Q9))
        print("PPV: " + str(ppv))
        print("NPV: " + str(npv))
        print("TP: " + str(tp))
        print("FP: " + str(fp))
        print("TN: " + str(tn))
        print("FN: " + str(fn))
        print("\n")

        torch.manual_seed(1234)
    if log_wandb:
        wandb.finish()
    fold_count += 1

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01113756018878323, max=1.0)…

wandb: Network error (ConnectionError), entering retry loop.


  0%|          | 0/3000 [00:00<?, ?it/s][34m[1mwandb[0m: Network error resolved after 0:00:09.007714, resuming normal operation.
100%|██████████| 3000/3000 [33:11<00:00,  1.51it/s]    


finished with loss 0.5844416664093732


100%|██████████| 1000/1000 [01:47<00:00,  9.28it/s]


Epoch: 0
Accuracy: 0.5863636363636363
Precision: 0.4683098575059512
Recall: 0.38439306247285243
Sensitivity: 0.38439306247285243
Specificity: 0.7172284630763512
MCC: 0.10617251348402833
F1-score: 0.4222222215520282
ROC-AUC: 0.5902448528934208
AUC: 0.5902448528934208
Q9: 0.5209740917767137
PPV: 0.4683098575059512
NPV: 0.6426174485862124
TP: 133
FP: 151
TN: 383
FN: 213




100%|██████████| 3000/3000 [18:23<00:00,  2.72it/s]   


finished with loss 0.5729459016571442


100%|██████████| 1000/1000 [01:27<00:00, 11.46it/s]


Epoch: 1
Accuracy: 0.6064073226544623
Precision: 0.5294117632903429
Recall: 0.5409836050792798
Sensitivity: 0.5409836050792798
Specificity: 0.6535433058001116
MCC: 0.19396853778557266
F1-score: 0.5351351344119796
ROC-AUC: 0.6510826771653544
AUC: 0.6510826771653544
Q9: 0.5933500955898594
PPV: 0.5294117632903429
NPV: 0.663999998672
TP: 198
FP: 176
TN: 332
FN: 168




100%|██████████| 3000/3000 [06:15<00:00,  7.99it/s]


finished with loss 0.5336938950838521


100%|██████████| 1000/1000 [01:44<00:00,  9.62it/s]


Epoch: 2
Accuracy: 0.6472564389697648
Precision: 0.576271181557024
Recall: 0.20420420359097838
Sensitivity: 0.20420420359097838
Specificity: 0.9107142840880103
MCC: 0.16410211782088463
F1-score: 0.30155210576152525
ROC-AUC: 0.7429617117117118
AUC: 0.7429617117117118
Q9: 0.43375672746522886
PPV: 0.576271181557024
NPV: 0.6580645152799167
TP: 68
FP: 50
TN: 510
FN: 265




100%|██████████| 3000/3000 [06:33<00:00,  7.61it/s]


finished with loss 0.47699855659335544


100%|██████████| 1000/1000 [01:18<00:00, 12.76it/s]


Epoch: 3
Accuracy: 0.6938775510204082
Precision: 0.5967741921733314
Recall: 0.7316384160123208
Sensitivity: 0.7316384160123208
Specificity: 0.6685606047943928
MCC: 0.3923835082400499
F1-score: 0.6573604052571569
ROC-AUC: 0.7559386235233693
AUC: 0.7559386235233693
Q9: 0.698445684406604
PPV: 0.5967741921733314
NPV: 0.7879464268126196
TP: 259
FP: 175
TN: 353
FN: 95




100%|██████████| 3000/3000 [06:10<00:00,  8.11it/s]


finished with loss 0.4298171651582234


100%|██████████| 1000/1000 [01:30<00:00, 11.03it/s]


Epoch: 4
Accuracy: 0.7642369020501139
Precision: 0.693820222770168
Recall: 0.715942026910313
Sensitivity: 0.715942026910313
Specificity: 0.7954971842485982
MCC: 0.5087529332143521
F1-score: 0.7047075596223858
ROC-AUC: 0.8367349158441418
AUC: 0.8367349158441418
Q9: 0.7525021899726329
PPV: 0.693820222770168
NPV: 0.8122605348424128
TP: 247
FP: 109
TN: 424
FN: 98




100%|██████████| 3000/3000 [06:16<00:00,  7.96it/s]


finished with loss 0.39243348833860364


100%|██████████| 1000/1000 [01:29<00:00, 11.15it/s]


Epoch: 5
Accuracy: 0.7839643652561247
Precision: 0.7854984870528746
Recall: 0.6788511731622685
Sensitivity: 0.6788511731622685
Specificity: 0.8621359206560467
MCC: 0.5546057684037559
F1-score: 0.7282913155065948
ROC-AUC: 0.8465512433775255
AUC: 0.8465512433775255
Q9: 0.7528734415804534
PPV: 0.7854984870528746
NPV: 0.7830687816877094
TP: 260
FP: 71
TN: 444
FN: 123




100%|██████████| 3000/3000 [06:39<00:00,  7.51it/s]


finished with loss 0.3582613538142323


100%|██████████| 1000/1000 [01:22<00:00, 12.13it/s]


Epoch: 6
Accuracy: 0.8379888268156425
Precision: 0.7983193254949038
Recall: 0.796089383251147
Sensitivity: 0.796089383251147
Specificity: 0.8659217860969799
MCC: 0.6623213116400156
F1-score: 0.7972027960878283
ROC-AUC: 0.8950641365750133
AUC: 0.8950641365750133
Q9: 0.8274362353693029
PPV: 0.7983193254949038
NPV: 0.8643122660514642
TP: 285
FP: 72
TN: 465
FN: 73




100%|██████████| 3000/3000 [08:51<00:00,  5.64it/s]  


finished with loss 0.3190947156609873


100%|██████████| 1000/1000 [20:24<00:00,  1.22s/it]  


Epoch: 7
Accuracy: 0.813692480359147
Precision: 0.816513758970906
Recall: 0.7158176924508909
Sensitivity: 0.7158176924508909
Specificity: 0.8841698824629926
MCC: 0.6141127678777935
F1-score: 0.762857141767347
ROC-AUC: 0.8881369880029398
AUC: 0.8881369880029398
Q9: 0.7830020753942379
PPV: 0.816513758970906
NPV: 0.8120567361488356
TP: 267
FP: 60
TN: 458
FN: 106




100%|██████████| 3000/3000 [10:34<00:00,  4.73it/s]  


finished with loss 0.297444284398747


100%|██████████| 1000/1000 [01:54<00:00,  8.71it/s]


Epoch: 8
Accuracy: 0.8223234624145785
Precision: 0.8356164354944643
Recall: 0.6931818162125517
Sensitivity: 0.6931818162125517
Specificity: 0.9087452454206364
MCC: 0.6261347993085237
F1-score: 0.7577639739786274
ROC-AUC: 0.9075732371240925
AUC: 0.9075732371240925
Q9: 0.7736542173882605
PPV: 0.8356164354944643
NPV: 0.8156996573110927
TP: 244
FP: 48
TN: 478
FN: 108




100%|██████████| 3000/3000 [07:27<00:00,  6.70it/s]


finished with loss 0.2821158175678046


100%|██████████| 1000/1000 [01:46<00:00,  9.35it/s]


Epoch: 9
Accuracy: 0.8604910714285714
Precision: 0.8266666639111111
Recall: 0.7725856673751226
Sensitivity: 0.7725856673751226
Specificity: 0.9095652158094518
MCC: 0.6930796215942083
F1-score: 0.7987117539473241
ROC-AUC: 0.9190898008939457
AUC: 0.9190898008939457
Q9: 0.8269454890879542
PPV: 0.8266666639111111
NPV: 0.8775167770511464
TP: 248
FP: 52
TN: 523
FN: 73




100%|██████████| 3000/3000 [06:56<00:00,  7.20it/s]


finished with loss 0.26547160598860947


100%|██████████| 1000/1000 [01:32<00:00, 10.80it/s]


Epoch: 10
Accuracy: 0.8597081930415263
Precision: 0.8553459092599185
Recall: 0.7749287727210007
Sensitivity: 0.7749287727210007
Specificity: 0.9148148131207133
MCC: 0.7034726869109025
F1-score: 0.8131539599205472
ROC-AUC: 0.9270391474095179
AUC: 0.9270391474095179
Q9: 0.8298330641251122
PPV: 0.8553459092599185
NPV: 0.8621291433470696
TP: 272
FP: 46
TN: 494
FN: 79




100%|██████████| 3000/3000 [07:32<00:00,  6.63it/s]


finished with loss 0.2583987712395542


100%|██████████| 1000/1000 [02:08<00:00,  7.80it/s]


Epoch: 11
Accuracy: 0.8660022148394242
Precision: 0.8529411739619377
Recall: 0.8033240974977172
Sensitivity: 0.8033240974977172
Specificity: 0.9077490758159611
MCC: 0.7189089410630818
F1-score: 0.827389442471627
ROC-AUC: 0.9358613322975334
AUC: 0.9358613322975334
Q9: 0.8463906863930856
PPV: 0.8529411739619377
NPV: 0.8738898741138723
TP: 290
FP: 50
TN: 492
FN: 71




100%|██████████| 3000/3000 [07:05<00:00,  7.06it/s]


finished with loss 0.24787745296370406


100%|██████████| 1000/1000 [01:42<00:00,  9.73it/s]


Epoch: 12
Accuracy: 0.8231981981981982
Precision: 0.8544303770429419
Recall: 0.7086614154628309
Sensitivity: 0.7086614154628309
Specificity: 0.9092702151690922
MCC: 0.6388006373562447
F1-score: 0.7747489228482799
ROC-AUC: 0.9055687565681508
AUC: 0.9055687565681508
Q9: 0.7842338496353438
PPV: 0.8544303770429419
NPV: 0.8059440545350628
TP: 270
FP: 46
TN: 461
FN: 111




100%|██████████| 3000/3000 [06:58<00:00,  7.17it/s]


finished with loss 0.24497829783441125


100%|██████████| 1000/1000 [01:32<00:00, 10.83it/s]


Epoch: 13
Accuracy: 0.8491620111731844
Precision: 0.8338278907008074
Recall: 0.7805555533873457
Sensitivity: 0.7805555533873457
Specificity: 0.8953271011302297
MCC: 0.6840177868651884
F1-score: 0.8063127678532098
ROC-AUC: 0.9207502596053997
AUC: 0.9207502596053997
Q9: 0.8280810080226881
PPV: 0.8338278907008074
NPV: 0.858422937529708
TP: 281
FP: 56
TN: 479
FN: 79




100%|██████████| 3000/3000 [15:50<00:00,  3.16it/s]  


finished with loss 0.2156337700015431


100%|██████████| 1000/1000 [01:27<00:00, 11.48it/s]


Epoch: 14
Accuracy: 0.8571428571428571
Precision: 0.8571428546938775
Recall: 0.7936507915511883
Sensitivity: 0.7936507915511883
Specificity: 0.9034749017307434
MCC: 0.7056535456712133
F1-score: 0.8241758230437145
ROC-AUC: 0.9358542215685073
AUC: 0.9358542215685073
Q9: 0.8389144804457481
PPV: 0.8571428546938775
NPV: 0.8571428555729984
TP: 300
FP: 50
TN: 468
FN: 78




100%|██████████| 3000/3000 [06:27<00:00,  7.75it/s]


finished with loss 0.2211919988439531


100%|██████████| 1000/1000 [01:26<00:00, 11.50it/s]


Epoch: 15
Accuracy: 0.8509454949944383
Precision: 0.8631921795987225
Recall: 0.7422969166882439
Sensitivity: 0.7422969166882439
Specificity: 0.9225092233902044
MCC: 0.6859599098312033
F1-score: 0.7981927698822399
ROC-AUC: 0.9298143611688219
AUC: 0.9298143611688219
Q9: 0.8097163981414526
PPV: 0.8631921795987225
NPV: 0.8445945931679145
TP: 265
FP: 42
TN: 500
FN: 92




 50%|████▉     | 1485/3000 [03:23<03:15,  7.76it/s]