In [9]:
from graphcare import *

kg = "GPT-KG"
dataset = "mimic3"
task = "mortality"

# load dataset
sample_dataset, G, ent2id, rel2id, ent_emb, rel_emb, \
            map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_rel_inv, \
                ccscm_id2clus, ccsproc_id2clus, atc3_id2clus = load_everything(dataset, task, kg)

# label direct ehr node
print("Labeling direct ehr nodes...")
sample_dataset = label_ehr_nodes(task, sample_dataset, len(map_cluster), ccscm_id2clus, ccsproc_id2clus, atc3_id2clus)
print("Splitting dataset...")
train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], seed=528)
G_tg = from_networkx(G)

# get embedding
print("Getting embedding...")
rel_emb = get_rel_emb(map_cluster_rel)
node_emb = G_tg.x 

Labeling direct ehr nodes...


100%|██████████| 9717/9717 [00:00<00:00, 14781.50it/s]


Splitting dataset...
Getting embedding...


In [10]:
G_tg.edge_index

tensor([[   0,    0,    0,  ..., 4598, 4598, 4598],
        [ 275, 1997,    0,  ..., 3388, 3924, 4283]])

In [11]:
train_set = Dataset(G=G_tg, dataset=train_dataset, task=task)
val_set = Dataset(G=G_tg, dataset=val_dataset, task=task)
test_set = Dataset(G=G_tg, dataset=test_dataset, task=task)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, drop_last=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, drop_last=True)

In [12]:
node_emb.shape

torch.Size([4599, 1536])

In [13]:
import torch.nn as nn

node_embedding = nn.Embedding.from_pretrained(node_emb, freeze=False)
relation_embedding = nn.Embedding.from_pretrained(rel_emb, freeze=False)
alpha_attn = nn.Linear(node_emb.shape[0], node_emb.shape[0])
beta_attn = nn.Linear(node_emb.shape[0], 1)
leakyrelu = nn.LeakyReLU(0.1)

In [14]:
node_embedding, relation_embedding

(Embedding(4599, 1536), Embedding(1077, 1536))

In [15]:
from torch.nn import Softmax

for batch in train_loader:
    # print(batch.ehr_nodes)
    # print(node_embedding(torch.tensor(batch.ehr_nodes[0])))
    node_ids = batch.y
    visit_node = batch.visit_padded_node.reshape(int(train_loader.batch_size), int(len(batch.visit_padded_node)/train_loader.batch_size), batch.visit_padded_node.shape[1]).double()
    x = node_embedding(node_ids)
    alpha = torch.softmax((leakyrelu(alpha_attn(visit_node.float()))), dim=1)
    beta = torch.softmax((leakyrelu(beta_attn(visit_node.float()))), dim=0)
    j = torch.arange(visit_node.shape[1], device=x.device).float()
    lambda_j = torch.exp(0.03 * (visit_node.shape[1] - j)).unsqueeze(0).reshape(1, visit_node.shape[1], 1)
    attn = alpha*beta*lambda_j
    attn = torch.sum(attn, dim=1)
    ehr_nodes = batch.ehr_nodes.reshape(int(train_loader.batch_size), int(len(batch.ehr_nodes)/train_loader.batch_size)).float()
    xj_batch = batch.batch[batch.edge_index[0]]
    xj_node_ids = batch.y[batch.edge_index[0]]
    print(batch)
    print(attn[xj_batch, xj_node_ids].shape)
    # print(batch.batch.shape)
    # print(attn.shape)
    # print(batch.batch[batch.edge_index[0]].shape)
    # print(batch.y[batch.edge_index[0]].shape)
    # print(ehr_nodes.shape)
    # print(ehr_nodes[1].view(1, -1) @ node_embedding.weight / torch.sum(ehr_nodes[1]))
    # attn = attn[batch.edge_index[0]]
    # print(attn.shape)
    
    # print(x.shape)
    # print(visit_node.shape)
    # visit_emb = x.view(visit_node.shape).sum(dim=2) / visit_node.sum(dim=2).clamp(min=1).view(visit_node.shape[:2] + (1,))
    # print(visit_emb.shape)
    # print(node_emb[:4599].shape)
    # print(visit_node.shape)
    # print(x.shape)
    # print((visit_node @ node_emb[:4599]).shape)
    break

DataBatch(x=[90729, 1536], edge_index=[2, 54985], y=[90729], relation=[54985], label=[64], visit_padded_node=[1792, 4599], ehr_nodes=[294336], batch=[90729], ptr=[65])
torch.Size([54985])


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv
from pyhealth.models import RETAINLayer
from torch_geometric.nn.inits import reset

from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
)
from torch_geometric.utils import spmm
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import softmax
from torch.nn import LeakyReLU


class BiAttentionGNNConv(MessagePassing):
    def __init__(self, nn: torch.nn.Module, eps: float = 0.,
                 train_eps: bool = False, edge_dim: Optional[int] = None,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        self.W_R = torch.nn.Linear(edge_dim, edge_dim)

        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

        self.reset_parameters()

    def reset_parameters(self):
        self.nn.reset_parameters()
        self.eps.data.fill_(self.initial_eps)
        if self.W_R is not None:
            self.W_R.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None, attn: Tensor = None) -> Tensor:

        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size, attn=attn)

        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor, edge_attr: Tensor, attn: Tensor) -> Tensor:

        h_R = self.W_R(edge_attr)
        out = (x_j * attn + h_R).relu()
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'


def masked_softmax(src: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    out = src.masked_fill(~mask, float('-inf'))
    out = torch.softmax(out, dim=dim)
    out = out.masked_fill(~mask, 0)
    return out

class GraphCare(nn.Module):
    def __init__(self, num_nodes, num_rels, max_visit, embedding_dim, hidden_dim, out_channels, layers=3, dropout=0.5, decay_rate=0.03, node_emb=None, rel_emb=None):
        super(GraphCare, self).__init__()

        self.embedding_dim = embedding_dim
        self.decay_rate = decay_rate

        j = torch.arange(max_visit).float()
        self.lambda_j = torch.exp(self.decay_rate * (max_visit - j)).unsqueeze(0).reshape(1, max_visit, 1).float()

        if node_emb is None:
            self.node_emb = nn.Embedding(num_nodes, embedding_dim)
        else:
            self.node_emb = nn.Embedding.from_pretrained(node_emb, freeze=False)

        if rel_emb is None:
            self.rel_emb = nn.Embedding(num_rels, embedding_dim)
        else:
            self.rel_emb = nn.Embedding.from_pretrained(rel_emb, freeze=False)

        self.lin = nn.Linear(embedding_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)

        self.layers = layers
        self.dropout = dropout

        self.alpha_attn = nn.ModuleDict()
        self.beta_attn = nn.ModuleDict()
        self.conv = nn.ModuleDict()

        self.leakyrelu = nn.LeakyReLU(0.1)

        for layer in range(1, layers+1):
            self.alpha_attn[str(layer)] = nn.Linear(num_nodes, num_nodes)
            self.beta_attn[str(layer)] = nn.Linear(num_nodes, 1)
            self.conv[str(layer)] = BiAttentionGNNConv(nn.Linear(hidden_dim, hidden_dim), edge_dim=hidden_dim)

        self.MLP = nn.Linear(hidden_dim * 2, out_channels)
        

    def to(self, device):
        super().to(device)
        self.lambda_j = self.lambda_j.float().to(device)


    def forward(self, node_ids, rel_ids, edge_index, batch, visit_node, ehr_nodes):
        x = self.node_emb(node_ids).float()
        edge_attr = self.rel_emb(rel_ids).float()

        x = self.bn1(self.lin(x))
        edge_attr = self.bn1(self.lin(edge_attr))


        for layer in range(1, self.layers+1):
            alpha = masked_softmax((self.leakyrelu(self.alpha_attn[str(layer)](visit_node.float()))), mask=visit_node>1, dim=1)
            beta = masked_softmax((self.leakyrelu(self.beta_attn[str(layer)](visit_node.float()))), mask=visit_node>1, dim=0) * self.lambda_j

            attn = alpha * beta
            attn = torch.sum(attn, dim=1)
            xj_node_ids = node_ids[edge_index[0]]
            xj_batch = batch[edge_index[0]]
            attn = attn[xj_batch, xj_node_ids].reshape(-1, 1)

            x = F.relu(self.conv[str(layer)](x, edge_index, edge_attr, attn=attn))
            x = F.dropout(x, p=0.3, training=self.training)

        # patient graph embedding through global mean pooling
        x_graph = global_mean_pool(x, batch)
        x_graph = F.dropout(x_graph, p=self.dropout, training=self.training)

        # patient node embedding through local (direct EHR) mean pooling
        x_node = torch.stack([ehr_nodes[i].view(1, -1) @ self.node_emb.weight / torch.sum(ehr_nodes[i]) for i in range(batch.max().item() + 1)])
        x_node = self.lin(x_node).squeeze(1)
        x_node = F.dropout(x_node, p=self.dropout, training=self.training)

        # concatenate patient graph embedding and patient node embedding
        x_concat = torch.cat((x_graph, x_node), dim=1)
        x_concat = F.dropout(x_concat, p=self.dropout, training=self.training)

        # MLP for prediction
        logits = self.MLP(x_concat)

        return logits




In [17]:
from tqdm import tqdm
from pyhealth.metrics import multilabel_metrics_fn
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score, f1_score, precision_score, recall_score, jaccard_score
    
def train(model, device, train_loader, optimizer):
    model.train()
    training_loss = 0
    tot_loss = 0
    pbar= tqdm(enumerate(train_loader))
    for i, data in pbar:
        pbar.set_description(f'loss: {training_loss}')

        data = data.to(device)
        optimizer.zero_grad()

        node_ids = data.y
        rel_ids = data.relation

        out = model(
                node_ids = node_ids, 
                rel_ids = rel_ids,
                edge_index = data.edge_index,
                batch = data.batch,
                visit_node = data.visit_padded_node.reshape(int(train_loader.batch_size), int(len(data.visit_padded_node)/train_loader.batch_size), data.visit_padded_node.shape[1]).float(), 
                ehr_nodes = data.ehr_nodes.reshape(int(train_loader.batch_size), int(len(data.ehr_nodes)/train_loader.batch_size)).float()
                
            )
        try:
            label = data.label.reshape(int(train_loader.batch_size), int(len(data.label)/train_loader.batch_size))
        except:
            continue
        # print(out.shape, label.shape)
        loss = F.binary_cross_entropy_with_logits(out, label.float())
        loss.backward()
        training_loss = loss
        tot_loss += loss
        optimizer.step()
    
    return tot_loss

def evaluate(model, device, loader):
    model.eval()
    y_prob_all = []
    y_true_all = []

    for data in tqdm(loader):
        data = data.to(device)
        with torch.no_grad():    
            
            node_ids = data.y
            rel_ids = data.relation

            logits = model(
                    node_ids = node_ids, 
                    rel_ids = rel_ids,
                    edge_index = data.edge_index,
                    batch = data.batch,
                    visit_node = data.visit_padded_node.reshape(int(loader.batch_size), int(len(data.visit_padded_node)/loader.batch_size), data.visit_padded_node.shape[1]).float(), 
                    ehr_nodes = data.ehr_nodes.reshape(int(loader.batch_size), int(len(data.ehr_nodes)/loader.batch_size)).float()               
                )

            y_prob = torch.sigmoid(logits)
            try:
                y_true = data.label.reshape(int(loader.batch_size), int(len(data.label)/loader.batch_size))
            except:
                continue
            y_prob_all.append(y_prob.cpu())
            y_true_all.append(y_true.cpu())
            
    y_true_all = np.concatenate(y_true_all, axis=0)
    y_prob_all = np.concatenate(y_prob_all, axis=0)

    return y_true_all, y_prob_all

def train_loop(train_loader, val_loader, model, optimizer, device, epochs):
    best_acc = 0
    best_f1 = 0
    for epoch in range(1, epochs+1):
        loss = train(model, device, train_loader, optimizer)
        y_true_all, y_prob_all = evaluate(model, device, val_loader)

        y_pred_all = (y_prob_all >= 0.5).astype(int)
        
        val_pr_auc = average_precision_score(y_true_all, y_prob_all)
        val_roc_auc = roc_auc_score(y_true_all, y_prob_all)
        val_jaccard = jaccard_score(y_true_all, y_pred_all, average='macro', zero_division=1)
        val_acc = accuracy_score(y_true_all, y_pred_all)
        val_f1 = f1_score(y_true_all, y_pred_all, average='macro', zero_division=1)
        val_precision = precision_score(y_true_all, y_pred_all, average='macro', zero_division=1)
        val_recall = recall_score(y_true_all, y_pred_all, average='macro', zero_division=1)

        if val_acc >= best_acc and val_f1 >= best_f1:
            # torch.save(model.state_dict(), '../../../data/pj20/exp_data/saved_weights_gin_mimic3_readmission_dynamic.pkl')
            # print("best model saved")
            best_acc = val_acc
            best_f1 = val_f1

        print(f'Epoch: {epoch}, Training loss: {loss}, Val PRAUC: {val_pr_auc:.4f}, Val ROC_AUC: {val_roc_auc:.4f}, Val acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val precision: {val_precision:.4f}, Val recall: {val_recall:.4f}, Val jaccard: {val_jaccard:.4f}')


In [18]:
sample_dataset[0]['visit_padded_node'].shape[0]

28

In [19]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

model = GraphCare(
    num_nodes=node_emb.shape[0],
    num_rels=rel_emb.shape[0],
    max_visit=sample_dataset[0]['visit_padded_node'].shape[0],
    embedding_dim=node_emb.shape[1],
    hidden_dim=512,
    out_channels=1,
    layers=3,
    dropout=0.5,
    decay_rate=0.01,
    node_emb=node_emb,
    rel_emb=rel_emb
)

model.to(device)



In [20]:
model

GraphCare(
  (node_emb): Embedding(4599, 1536)
  (rel_emb): Embedding(1077, 1536)
  (lin): Linear(in_features=1536, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (alpha_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=4599, bias=True)
    (2): Linear(in_features=4599, out_features=4599, bias=True)
    (3): Linear(in_features=4599, out_features=4599, bias=True)
  )
  (beta_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=1, bias=True)
    (2): Linear(in_features=4599, out_features=1, bias=True)
    (3): Linear(in_features=4599, out_features=1, bias=True)
  )
  (conv): ModuleDict(
    (1): BiAttentionGNNConv(nn=Linear(in_features=512, out_features=512, bias=True))
    (2): BiAttentionGNNConv(nn=Linear(in_features=512, out_features=512, bias=True))
    (3): BiAttentionGNNConv(nn=Linear(in_features=512, out_features=512, bias=True))
  )
  (leakyrelu): LeakyReLU(negative_slope=0.1)
 

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

train_loop(train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, device=device, epochs=100)

loss: 0.2067977339029312: : 120it [00:50,  2.37it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.24it/s]


Epoch: 1, Training loss: 31.50913429260254, Val PRAUC: 0.1073, Val ROC_AUC: 0.6298, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.15553808212280273: : 120it [00:49,  2.43it/s]
100%|██████████| 15/15 [00:05<00:00,  2.77it/s]


Epoch: 2, Training loss: 29.789306640625, Val PRAUC: 0.1221, Val ROC_AUC: 0.6443, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.2628225088119507: : 120it [00:54,  2.21it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.28it/s]


Epoch: 3, Training loss: 29.22536277770996, Val PRAUC: 0.1219, Val ROC_AUC: 0.6374, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.2506409287452698: : 120it [00:49,  2.41it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.08it/s]


Epoch: 4, Training loss: 28.71989631652832, Val PRAUC: 0.1247, Val ROC_AUC: 0.6289, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.17904116213321686: : 120it [00:50,  2.40it/s]
100%|██████████| 15/15 [00:04<00:00,  3.54it/s]


Epoch: 5, Training loss: 28.603946685791016, Val PRAUC: 0.1274, Val ROC_AUC: 0.6521, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.18198521435260773: : 120it [00:46,  2.59it/s]
100%|██████████| 15/15 [00:04<00:00,  3.61it/s]


Epoch: 6, Training loss: 28.40140151977539, Val PRAUC: 0.1251, Val ROC_AUC: 0.6398, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.34278881549835205: : 120it [00:46,  2.56it/s]
100%|██████████| 15/15 [00:04<00:00,  3.42it/s]


Epoch: 7, Training loss: 28.36269187927246, Val PRAUC: 0.1255, Val ROC_AUC: 0.6459, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.39694470167160034: : 120it [00:48,  2.47it/s]
100%|██████████| 15/15 [00:04<00:00,  3.50it/s]


Epoch: 8, Training loss: 28.33040428161621, Val PRAUC: 0.1216, Val ROC_AUC: 0.6462, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.22200188040733337: : 120it [00:46,  2.58it/s]
100%|██████████| 15/15 [00:04<00:00,  3.12it/s]


Epoch: 9, Training loss: 28.119611740112305, Val PRAUC: 0.1213, Val ROC_AUC: 0.6465, Val acc: 0.9229, Val F1: 0.4800, Val precision: 0.4619, Val recall: 0.4994, Val jaccard: 0.4615


loss: 0.05911890044808388: : 120it [00:46,  2.59it/s]
100%|██████████| 15/15 [00:04<00:00,  3.52it/s]


Epoch: 10, Training loss: 28.490995407104492, Val PRAUC: 0.1307, Val ROC_AUC: 0.6407, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.19967162609100342: : 120it [00:43,  2.73it/s]
100%|██████████| 15/15 [00:04<00:00,  3.58it/s]


Epoch: 11, Training loss: 27.972232818603516, Val PRAUC: 0.1286, Val ROC_AUC: 0.6400, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.18729358911514282: : 120it [00:43,  2.74it/s]
100%|██████████| 15/15 [00:03<00:00,  3.83it/s]


Epoch: 12, Training loss: 27.9657039642334, Val PRAUC: 0.1269, Val ROC_AUC: 0.6454, Val acc: 0.9229, Val F1: 0.4800, Val precision: 0.4619, Val recall: 0.4994, Val jaccard: 0.4615


loss: 0.12111091613769531: : 120it [00:44,  2.72it/s]
100%|██████████| 15/15 [00:03<00:00,  4.08it/s]


Epoch: 13, Training loss: 27.88951301574707, Val PRAUC: 0.1273, Val ROC_AUC: 0.6405, Val acc: 0.9229, Val F1: 0.4800, Val precision: 0.4619, Val recall: 0.4994, Val jaccard: 0.4615


loss: 0.1285816729068756: : 120it [00:44,  2.70it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.65it/s]


Epoch: 14, Training loss: 27.78513526916504, Val PRAUC: 0.1312, Val ROC_AUC: 0.6513, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.2933382987976074: : 120it [00:45,  2.64it/s] 
100%|██████████| 15/15 [00:03<00:00,  3.78it/s]


Epoch: 15, Training loss: 27.727758407592773, Val PRAUC: 0.1284, Val ROC_AUC: 0.6362, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.18093577027320862: : 120it [00:45,  2.62it/s]
100%|██████████| 15/15 [00:03<00:00,  3.79it/s]


Epoch: 16, Training loss: 27.63042449951172, Val PRAUC: 0.1296, Val ROC_AUC: 0.6434, Val acc: 0.9229, Val F1: 0.4800, Val precision: 0.4619, Val recall: 0.4994, Val jaccard: 0.4615


loss: 0.3249710202217102: : 120it [00:45,  2.63it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.67it/s]


Epoch: 17, Training loss: 27.789810180664062, Val PRAUC: 0.1277, Val ROC_AUC: 0.6407, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.15263307094573975: : 120it [00:42,  2.81it/s]
100%|██████████| 15/15 [00:04<00:00,  3.28it/s]


Epoch: 18, Training loss: 27.643404006958008, Val PRAUC: 0.1266, Val ROC_AUC: 0.6387, Val acc: 0.9229, Val F1: 0.4800, Val precision: 0.4619, Val recall: 0.4994, Val jaccard: 0.4615


loss: 0.3869945704936981: : 120it [00:44,  2.68it/s] 
100%|██████████| 15/15 [00:04<00:00,  3.73it/s]


Epoch: 19, Training loss: 27.582056045532227, Val PRAUC: 0.1255, Val ROC_AUC: 0.6400, Val acc: 0.9208, Val F1: 0.4794, Val precision: 0.4619, Val recall: 0.4983, Val jaccard: 0.4604


loss: 0.3811080753803253: : 120it [00:43,  2.75it/s]  
100%|██████████| 15/15 [00:04<00:00,  3.49it/s]


Epoch: 20, Training loss: 27.32731819152832, Val PRAUC: 0.1264, Val ROC_AUC: 0.6396, Val acc: 0.9240, Val F1: 0.4802, Val precision: 0.9620, Val recall: 0.5000, Val jaccard: 0.4620


loss: 0.07026718556880951: : 14it [00:05,  2.58it/s]


KeyboardInterrupt: 