# Imports

In [2]:
# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.

import torch
from torch_geometric.data import Data

import numpy as np
import sparse

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as tgmnn
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataListLoader as GraphLoader
from torch_geometric.data import Batch

from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
import time
from sklearn import preprocessing
import math
from torch.utils.data import Dataset
import copy
import sklearn.metrics as skm
import pandas as pd
import random
from torch.utils.data.dataset import Dataset
import pytorch_pretrained_bert as Bert
import itertools
from einops import rearrange, repeat
import ast
from typing import Optional, Tuple, Union
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor, PairTensor, SparseTensor
from torch_geometric.utils import softmax
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import LayerNorm
import torch.nn.functional as F
from torch import Tensor
import pickle
from sklearn.model_selection import ShuffleSplit
import transformers
import os
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve, auc
from sklearn.metrics import precision_score, recall_score, f1_score
from torchmetrics.classification import BinaryAccuracy
from tqdm import tqdm
from matplotlib import pyplot as plt

# MODELE

In [59]:


class TransformerConv(MessagePassing):
    _alpha: OptTensor
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        self.layernorm1 = LayerNorm(out_channels)
        self.layernorm2 = LayerNorm(out_channels)
        self.gelu = nn.GELU()
        self.proj = Linear(heads * out_channels, out_channels)
        self.ffn = Linear(out_channels, out_channels)
        self.ffn2 = Linear(out_channels, out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)


        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()


    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, batch=None, return_attention_weights=None):
        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""Runs the forward pass of the module.

        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels
        residual = x
        x = self.layernorm1(x, batch)
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)
        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)
        # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr, size=None)
        alpha = self._alpha
        self._alpha = None
        if self.concat:
            out = self.proj(out.view(-1, self.heads * self.out_channels))
        else:
            out = out.mean(dim=1)
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = out+residual
        residual = out

        out = self.layernorm2(out)
        out = self.gelu(self.ffn(out))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.ffn2(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = out + residual
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:


        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')


class GraphTransformer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()

        self.transformerconv1 = TransformerConv(config.hidden_size // 5, config.hidden_size // 5, heads=2, edge_dim=config.hidden_size // 5, dropout=config.hidden_dropout_prob, concat=True)
        self.transformerconv2 = TransformerConv(config.hidden_size // 5, config.hidden_size // 5, heads=2, edge_dim=config.hidden_size // 5, dropout=config.hidden_dropout_prob, concat=True)
        self.transformerconv3 = TransformerConv(config.hidden_size // 5, config.hidden_size // 5, heads=2, edge_dim=config.hidden_size // 5, dropout=config.hidden_dropout_prob, concat=False)

        self.embed = nn.Embedding(config.vocab_size, config.hidden_size // 5)
        self.embed_ee = nn.Embedding(config.node_attr_size, config.hidden_size // 5)

    def forward(self, x, edge_index, edge_attr, batch):
        indices = (x==0).nonzero().squeeze()
        h_nodes = self.transformerconv1(x=self.embed(x), edge_index=edge_index, edge_attr=self.embed_ee(edge_attr), batch=batch)
        h_nodes = nn.GELU()(h_nodes)
        h_nodes = self.transformerconv2(x=h_nodes, edge_index=edge_index, edge_attr=self.embed_ee(edge_attr), batch=batch)
        h_nodes = nn.GELU()(h_nodes)
        h_nodes = self.transformerconv3(x=h_nodes, edge_index=edge_index, edge_attr=self.embed_ee(edge_attr), batch=batch)
        x = h_nodes[indices]
        return x



class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, segment, age
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        #self.word_embeddings = nn.Linear(config.vocab_size, config.hidden_size)
        self.word_embeddings = GraphTransformer(config)
        self.type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size//5)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size//5). \
            from_pretrained(embeddings=self._init_posi_embedding(config.age_vocab_size, config.hidden_size//5))
        self.time_embeddings = nn.Embedding(config.time_vocab_size , config.hidden_size//5). \
            from_pretrained(embeddings=self._init_posi_embedding(config.time_vocab_size, config.hidden_size//5))
        self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size//5). \
            from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size//5))


        self.seq_layers = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU()
        )

        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.acti = nn.GELU()
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))

    def forward(self, nodes, edge_index,  edge_attr, batch, age_ids, time_ids,  type_ids, posi_ids):
        word_embed = self.word_embeddings(nodes, edge_index, edge_attr, batch)
        type_embeddings = self.type_embeddings(type_ids)
        age_embed = self.age_embeddings(age_ids)
        time_embeddings = self.time_embeddings(time_ids)
        posi_embeddings = self.posi_embeddings(posi_ids)

        word_embed = torch.reshape(word_embed, type_embeddings.shape)
        embeddings = torch.cat((word_embed, type_embeddings, posi_embeddings, age_embed, time_embeddings), dim=2)
        
        b, n, _ = embeddings.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = self.seq_layers(embeddings)
        embeddings = self.LayerNorm(embeddings)
        
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        def even_code(pos, idx):
            return np.sin(pos / (10000 ** (2 * idx / hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos / (10000 ** (2 * idx / hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # reset table parameters with hard encoding
        # set even dimension
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
        # set odd dimension
        for pos in range(max_position_embedding):
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)



class BertModel(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config=config)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids, attention_mask=None, output_all_encoded_layers=True):

        if attention_mask is None:
            attention_mask = torch.ones_like(age_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers)
        
        sequence_output = encoded_layers[-1]

        pooled_output = self.pooler(sequence_output)
        
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        return encoded_layers, pooled_output




class BertForMTR(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMTR, self).__init__(config)
        self.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_bert_weights)

    def forward(self, nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids, attention_mask=None, labels=None):
        encoded_layer, pooled_output = self.bert(nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids, attention_mask, output_all_encoded_layers=False)

        
        logits = self.classifier(pooled_output).squeeze(dim=1)
        
        weights = torch.where(labels == 1, torch.tensor(1.5), torch.tensor(1.0))  #### à voir
        bce_logits_loss = nn.BCEWithLogitsLoss(reduction='mean', weight=weights)
        discr_supervised_loss = bce_logits_loss(logits, labels)
        
        return encoded_layer, pooled_output
    

    

class Pre_training2(Bert.modeling.BertPreTrainedModel):
    def __init__(self, config):
        super(Pre_training2, self).__init__(config)
        self.bert = BertForMTR(config)
        self.linear1 = nn.Linear(config.hidden_size, self.config.vocab_size)
        self.linear2 = nn.Linear(config.hidden_size, self.config.type_vocab_size -1)
        self.gru = nn.GRU(config.hidden_size, config.hidden_size  , 1, batch_first = True, bidirectional=False)
        self.apply(self.init_bert_weights)
    
    def forward(self, nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids, attention_mask=None, labels=None):
        encoded_layer, pooled_output = self.bert(nodes, edge_index, edge_attr, batch, age_ids, time_ids, type_ids, posi_ids, attention_mask, labels)
        pooled_output = pooled_output.unsqueeze(dim=1)

        hidden = None
        output_sequence = []
        nb_seq = 50
        for i in range(nb_seq):
            output,hidden = self.gru(pooled_output, hidden)
            output_sequence.append(output)
        output_sequence = torch.stack(output_sequence, dim=1).squeeze(dim=2)
        
        b,s,h = output_sequence.shape
        output_sequence = output_sequence.view(b*s,h)
        output_sequence1 = self.linear1(output_sequence)
        output_sequence2 = self.linear2(output_sequence)
        
        return output_sequence1,output_sequence2



class BertConfig(Bert.modeling.BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings = config.get('max_position_embedding'),
            initializer_range=config.get('initializer_range'),
        )
        self.age_vocab_size = config.get('age_vocab_size')
        self.type_vocab_size = config.get('type_vocab_size')
        self.time_vocab_size = config.get('time_vocab_size')
        self.graph_dropout_prob = config.get('graph_dropout_prob')
        self.node_attr_size = config.get('node_attr_size')



class TrainConfig(object):
    def __init__(self, config):
        self.batch_size = config.get('batch_size')
        self.use_cuda = config.get('use_cuda')
        self.max_len_seq = config.get('max_len_seq')
        self.train_loader_workers = config.get('train_loader_workers')
        self.test_loader_workers = config.get('test_loader_workers')
        self.device = config.get('device')
        self.output_dir = config.get('output_dir')
        self.output_name = config.get('output_name')
        self.best_name = config.get('best_name')



class GDSet(Dataset):
    def __init__(self, g):
        self.g = g

    def __getitem__(self, index):

        g = self.g[index]
        for i in range(len(g)):
          g[i]['posi_ids'] = i
        return g

    def __len__(self):
        return len(self.g)

# DATA

In [20]:
path = '../../data/'

In [21]:

with open(path + 'data_pad100.pkl', 'rb') as handle:
    dataset_loaded = pickle.load(handle)

In [22]:
#dataset = dataset_loaded[:100]

In [23]:
#with open('data_pad100', "wb") as f:
    #pickle.dump(dataset, f)

In [24]:
dataset=dataset_loaded

In [25]:
print(len(dataset))
print(len(dataset[0]))
print(dataset[0][-1])

100
50
Data(subject_id=[1], hadm_id=[1], label=[1], age=[1], rang=[1], type=[1], x=[56], edge_index=[2, 1540], edge_attr=[1540], mask_v=[1], time=[1])


In [26]:
noeud_unique = set()
edge_attr_unique = set()
age_unique = set()
time_unique = set()
type_unique = set()
label_unique = set()
hadm_unique = set()
subject_unique = set()
mask_v_unique = set()
rang_unique = set()
for patient in dataset:
    for visite in patient:
        noeuds = visite.x.tolist()
        edge = visite.edge_attr.tolist()
        label = visite.label.tolist()
        age = visite.age.tolist()
        time = visite.time.tolist()
        typ = visite.type.tolist()
        mask_v = visite.mask_v.tolist()
        rang = visite.rang.tolist()
        hadm = visite.hadm_id.tolist()
        subject = visite.subject_id.tolist()
        for noeud in noeuds:
            noeud_unique.add(noeud)
        for attribut in edge:
            edge_attr_unique.add(attribut)
        for lab in label:
            label_unique.add(lab)
        for a in age:
            age_unique.add(a)
        for t in time:
            time_unique.add(t)
        for ty in typ:
            type_unique.add(ty)
        for mask in mask_v:
            mask_v_unique.add(mask)
        for r in rang:
            rang_unique.add(r)
        for h in hadm:
            hadm_unique.add(h)
        for s in subject:
            subject_unique.add(s)
        

vocab_size = len(noeud_unique)
edge_attr_size = len(edge_attr_unique)
age_size = len(age_unique)
time_size = len(time_unique)
type_size = len(type_unique)
label_size = len(label_unique)
hadm_size = len(hadm_unique)
subject_size = len(subject_unique)
mask_v_size = len(mask_v_unique)
rang_size = len(rang_unique)

print('vocab_size',vocab_size)
print('max noeud',max(noeud_unique))
print('node_attr_size',edge_attr_size)
print('max edge_attr',max(edge_attr_unique))
print('age_size',age_size)
print('max age',max(age_unique))
print('time_size',time_size)
print('max time',max(time_unique))
print('type_size',type_size)
print('max type',max(type_unique))
print('label_size',label_size)
print('max label',max(label_unique))
print('hadm_size',hadm_size)
print('subject_size',subject_size)
print('maskv_size',mask_v_size)
print('max maskv',max(mask_v_unique))
print('rang_size',rang_size)
print('max rang',max(rang_unique))

vocab_size 1055
max noeud 9377
node_attr_size 8
max edge_attr 7
age_size 48
max age 130
time_size 248
max time 367
type_size 9
max type 10
label_size 2
max label 1
hadm_size 394
subject_size 100
maskv_size 2
max maskv 1
rang_size 24
max rang 51


In [27]:
train_l = int(len(dataset)*0.80)
val_l = int(len(dataset)*0.10)
test_l = len(dataset) - val_l - train_l

In [28]:
def split_dataset(dataset, random_seed=1, few_shots=1):
  rs = ShuffleSplit(n_splits=1, test_size=.20, random_state=random_seed)

  k = 5

  for i, (train_index_tmp, test_index) in enumerate(rs.split(dataset)):
    rs2 = ShuffleSplit(n_splits=1, test_size=0.1, random_state=random_seed)
    for j, (train_index, val_index) in enumerate(rs2.split(train_index_tmp)):
      train_index = train_index_tmp[train_index]
      if few_shots < 1:
        train_index = random.sample(list(train_index), int(len(train_index) * few_shots))
      val_index = train_index_tmp[val_index]

      trainDSet = [dataset[x] for x in train_index]
      valDSet = [dataset[x] for x in val_index]
      testDSet = [dataset[x] for x in test_index]
      return trainDSet, valDSet, testDSet

# Config file 

In [113]:
train_params = {
    'batch_size': 3,
    'use_cuda': True,
    'max_len_seq': 50,
    'device': "cuda"if torch.cuda.is_available() else "cpu",
    'data_len' : len(dataset),
    'train_data_len' : train_l,
    'val_data_len' : val_l,
    'test_data_len' : test_l,
    'epochs' : 100,
    'lr': 0.001,
    'weight_decay': 0.0001,
}

model_config = {
    'vocab_size': 9405, # number of disease + symbols for word embedding
    'edge_relationship_size': 8, # number of vocab for edge_attr
    'hidden_size': 50*5, # word embedding and seg embedding hidden size
    'age_vocab_size': 151, # number of vocab for age embedding
    'time_vocab_size': 380, # number of vocab for time embedding
    'type_vocab_size': 11+1 , # number of vocab for type embedding
    'node_attr_size': 8, # number of vocab for node_attr embedding
    'num_labels': 1,
    'max_position_embedding': 50, # maximum number of tokens
    'hidden_dropout_prob': 0.2, # dropout rate
    'graph_dropout_prob': 0.2, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 2, # number of attention heads
    'attention_probs_dropout_prob': 0.2, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
    'n_layers' : 3 - 1,
    'alpha' : 0.1
}

# fonction entrainement

In [118]:
import torch
import random
import copy

def creation_edge_index(x):
    # Edges (graphe complet)
    edge_index = []
    all_edges = []

    for i in range(len(x)):
        for j in range(i+1,len(x)):
            all_edges.append((i, j))
    source, target = zip(*all_edges)

    edge_index = torch.tensor([source, target], dtype=torch.int64)

    return edge_index
'''
def remove_nodes(graph_batch, max_nodes_per_visit=3):
    num_nodes = graph_batch.x.size(0)
    num_visits = graph_batch.batch.max() + 1
    graph_batch2 = copy.deepcopy(graph_batch)
    
    nodes_to_keep = []  # Liste pour stocker les nœuds à conserver
    nodes_to_remove = []  # Liste pour stocker les nœuds à supprimer
    to_delete = []
    
    for id, visit_id in enumerate(range(num_visits)):
        visit_nodes = (graph_batch.batch == visit_id).nonzero(as_tuple=False).squeeze()

        num_nodes_in_visit = len(visit_nodes)
        
        if num_nodes_in_visit > max_nodes_per_visit:
            # Choisir un nœud à supprimer au hasard
            node_to_remove = random.choice(visit_nodes.tolist())
            nodes_to_remove.append(node_to_remove)

            # Ajouter tous les nœuds sauf celui à supprimer
            nodes_to_keep.extend(visit_nodes[visit_nodes != node_to_remove].tolist())
            to_delete.append(id)

        else:
            nodes_to_keep.extend(visit_nodes.tolist())
    
    # Supprimer les nœuds
    print(graph_batch2.x.shape)
    graph_batch2.x = graph_batch.x[nodes_to_keep]
    true = graph_batch.x[nodes_to_remove]

     # Supprimer les arêtes associées aux nœuds supprimés
    edge_index = graph_batch.edge_index
    print(edge_index)
    print(edge_index.size(1))
    mask = torch.ones(edge_index.size(1), dtype=torch.bool)
    for node in nodes_to_remove:
        mask = mask & ~((edge_index[0] == node) | (edge_index[1] == node))
    edge_index = edge_index[:, mask]
    
    graph_batch2.edge_index = edge_index

    # Supprimer les attributs d'arête associés aux arêtes supprimées
    graph_batch2.edge_attr = graph_batch2.edge_attr[mask]

    graph_batch2.batch = graph_batch.batch[nodes_to_keep]
   
    return graph_batch2, true, to_delete
'''

def rem_node(data, min_node_per_visit=3):
    dataset2 = []
    nodes_to_remove = []
    indices_to_remove = []
    for i, patient in enumerate(data):
        patient2=[]
        for j, visite in enumerate(patient):
            node_initial = visite.x
            edge_index_initial = visite.edge_index
            edge_attr_initial = visite.edge_attr
            batch_initial = visite.batch
            num_nodes_in_visit = len(visite.x)
            visit2 = copy.deepcopy(visite)
            if num_nodes_in_visit > min_node_per_visit:
                # Choisir un nœud à supprimer au hasard
                node_to_remove = random.choice(node_initial.tolist()[1:])
                nodes_to_remove.append(node_to_remove)
                indices_to_remove.append(True)
                visit2.x = node_initial[node_initial != node_to_remove]
                mask = torch.ones(edge_index_initial.size(1), dtype=torch.bool)
                mask = mask & ~ ((edge_index_initial[0] == node_to_remove) | (edge_index_initial[1] == node_to_remove))
                visit2.edge_index = creation_edge_index(visit2.x)
                visit2.edge_attr = edge_attr_initial[mask]

                patient2.append(visit2)
            else:
                patient2.append(visit2)
                indices_to_remove.append(False)
            
        dataset2.append(patient2)

    return dataset2, torch.tensor(nodes_to_remove, dtype=torch.int64), indices_to_remove




def train(model, optim_model, trainload, device,scheduler=None):
    tr_loss = 0
    start = time.time()
    model.train()
    loss_type= torch.nn.CrossEntropyLoss()
    loss_node = torch.nn.CrossEntropyLoss()

    for step, data in tqdm(enumerate(trainload)):
        optim_model.zero_grad()

        graph_batch, true_node , nodes_to_remove_idx = rem_node(data,3)

        batched_data = Batch()
        graph_batch = batched_data.from_data_list(list(itertools.chain.from_iterable(data)))
        graph_batch = graph_batch.to(device)
        nodes = graph_batch.x
        edge_index = graph_batch.edge_index
        edge_attr = graph_batch.edge_attr
        batch = graph_batch.batch
        

        # pour le type 
        p = graph_batch.type.shape
        mask_type = (torch.ones(p) * 11).to(torch.int64).to(device)
        true_type = graph_batch.type
        graph_batch.type = mask_type



        age_ids = torch.reshape(graph_batch.age, [graph_batch.age.shape[0] // 50, 50])
        time_ids = torch.reshape(graph_batch.time, [graph_batch.time.shape[0] // 50, 50])
        type_ids = torch.reshape(graph_batch.type, [graph_batch.type.shape[0] // 50, 50])
        posi_ids = torch.reshape(graph_batch.posi_ids, [graph_batch.posi_ids.shape[0] // 50, 50])
        attMask = torch.reshape(graph_batch.mask_v, [graph_batch.mask_v.shape[0] // 50, 50])
        attMask = torch.cat((torch.ones((attMask.shape[0], 1)).to(device), attMask), dim=1)
        labels = torch.reshape(graph_batch.label, [graph_batch.label.shape[0] // 50, 50])[:, -1].float()


        pred_node, pred_type = model(nodes, edge_index, edge_attr, batch, age_ids, time_ids,type_ids,posi_ids,attMask, labels)
        pred_node = pred_node[nodes_to_remove_idx]

        loss1 = loss_type(pred_type, true_type)
        loss2 = loss_node(pred_node, true_node.to(device))
        total_loss = loss1 + loss2

        
        total_loss.backward()
        tr_loss += total_loss.item()
        optim_model.step()
        if scheduler is not None:
            scheduler.step()
        del loss1
        del loss2
        del total_loss
        
    
    print("TOTAL TRAIN LOSS",(tr_loss * train_params['batch_size']) / len(trainload))
    cost = time.time() - start
    print("TRAINING TIME", cost)

    return tr_loss, cost


def eval(model, optim_model, _valload, saving, device):
    tr_loss = 0
    start = time.time()
    model.eval()
    loss_type = nn.CrossEntropyLoss()
    loss_node = nn.CrossEntropyLoss()

    with torch.no_grad():
        for step, data in enumerate(_valload):
            optim_model.zero_grad()

            graph_batch, true_node , nodes_to_remove_idx = rem_node(data,3)

            batched_data = Batch()
            graph_batch = batched_data.from_data_list(list(itertools.chain.from_iterable(data)))
            graph_batch = graph_batch.to(device)
            nodes = graph_batch.x
            edge_index = graph_batch.edge_index
            edge_attr = graph_batch.edge_attr
            batch = graph_batch.batch
            

            # pour le type 
            p = graph_batch.type.shape
            mask_type = (torch.ones(p) * 11).to(torch.int64).to(device)
            true_type = graph_batch.type
            graph_batch.type = mask_type



            age_ids = torch.reshape(graph_batch.age, [graph_batch.age.shape[0] // 50, 50])
            time_ids = torch.reshape(graph_batch.time, [graph_batch.time.shape[0] // 50, 50])
            type_ids = torch.reshape(graph_batch.type, [graph_batch.type.shape[0] // 50, 50])
            posi_ids = torch.reshape(graph_batch.posi_ids, [graph_batch.posi_ids.shape[0] // 50, 50])
            attMask = torch.reshape(graph_batch.mask_v, [graph_batch.mask_v.shape[0] // 50, 50])
            attMask = torch.cat((torch.ones((attMask.shape[0], 1)).to(device), attMask), dim=1)
            labels = torch.reshape(graph_batch.label, [graph_batch.label.shape[0] // 50, 50])[:, -1].float()


            pred_node, pred_type = model(nodes, edge_index, edge_attr, batch, age_ids, time_ids,type_ids,posi_ids,attMask, labels)
            pred_node = pred_node[nodes_to_remove_idx]

            loss1 = loss_type(pred_type, true_type)
            loss2 = loss_node(pred_node, true_node.to(device))
            total_loss = loss1+loss2
            tr_loss += total_loss.item()
            del loss1
            del loss2
            del total_loss

    print("TOTAL VAL LOSS",(tr_loss * train_params['batch_size']) / len(_valload))


    cost = time.time() - start
    print("EVAL TIME", cost)

    return tr_loss, cost


def test(testload, model, device):
    model.eval()
    tr_loss = 0
    start = time.time()
    loss_type = nn.CrossEntropyLoss()
    loss_node = nn.CrossEntropyLoss()
    with torch.no_grad():
        for step, data in enumerate(testload):
            # Process the batch data and move it to the device

            graph_batch, true_node , nodes_to_remove_idx = rem_node(data,3)

            batched_data = Batch()
            graph_batch = batched_data.from_data_list(list(itertools.chain.from_iterable(data)))
            graph_batch = graph_batch.to(device)
            nodes = graph_batch.x
            edge_index = graph_batch.edge_index
            edge_attr = graph_batch.edge_attr
            batch = graph_batch.batch
            

            # pour le type 
            p = graph_batch.type.shape
            mask_type = (torch.ones(p) * 11).to(torch.int64).to(device)
            true_type = graph_batch.type
            graph_batch.type = mask_type


            age_ids = torch.reshape(graph_batch.age, [graph_batch.age.shape[0] // 50, 50])
            time_ids = torch.reshape(graph_batch.time, [graph_batch.time.shape[0] // 50, 50])
            type_ids = torch.reshape(graph_batch.type, [graph_batch.type.shape[0] // 50, 50])
            posi_ids = torch.reshape(graph_batch.posi_ids, [graph_batch.posi_ids.shape[0] // 50, 50])
            attMask = torch.reshape(graph_batch.mask_v, [graph_batch.mask_v.shape[0] // 50, 50])
            attMask = torch.cat((torch.ones((attMask.shape[0], 1)).to(device), attMask), dim=1)
            labels = torch.reshape(graph_batch.label, [graph_batch.label.shape[0] // 50, 50])[:, -1].float()


            pred_node, pred_type = model(nodes, edge_index, edge_attr, batch, age_ids, time_ids,type_ids,posi_ids,attMask, labels)
            pred_node = pred_node[nodes_to_remove_idx]

            loss1 = loss_type(pred_type, true_type)
            loss2 = loss_node(pred_node, true_node.to(device))
            total_loss = loss1 + loss2
            tr_loss += total_loss.item()
            del loss1
            del loss2
            del total_loss

    print("TOTAL TEST LOSS ", (tr_loss * train_params['batch_size']) / len(testload))
    cost = time.time() - start
    print("TEST TEST TIME", cost)
    

    return tr_loss, cost



def run_epoch(model, optim_model, trainload, valload, device,scheduler=None):
    best_val = math.inf
    
    with open(path + "v_behrt_log_train.txt", 'a') as f:
        f.write("TRAINING\n")

    for e in range(train_params["epochs"]):
        print("Epoch n" + str(e))

        train_loss, train_time_cost = train(model, optim_model, trainload, device,scheduler)
        val_loss, val_time_cost = eval(model, optim_model, valload, False, device)

        train_loss = (train_loss * train_params['batch_size']) / len(trainload)
        val_loss = (val_loss * train_params['batch_size']) / len(valload)
        with open(path + "GT_behrt_log_pretrain.txt", 'a') as f:
            f.write("Epoch n" + str(e) + '\n TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
            f.write('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost) + '\n\n\n')

        if val_loss < best_val:
            print("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = model.module if hasattr(model, 'module') else model
            save_model(model_to_save.state_dict(), path + 'pretrain2')
            best_val = val_loss
        print('\n')
    return train_loss, val_loss, train_time_cost, val_time_cost

from torch.optim.lr_scheduler import StepLR  

def experiment(num_experiments=5):
    conf = BertConfig(model_config)
    model = Pre_training2(conf).to(train_params['device'])
    transformer_vars = [i for i in model.parameters()]
    optim_model = torch.optim.AdamW(transformer_vars, lr=train_params['lr'], weight_decay=train_params['weight_decay'])
    scheduler = None
    df = pd.DataFrame(columns=['Experiment', 'Model', 'Metric', 'Score'])

    for exp in tqdm(range(num_experiments)):
        print(f"\n Experiment {exp + 1}")
        trainDSet, valDSet, testDSet = split_dataset(dataset, random_seed=exp)
        trainload =  GraphLoader(GDSet(trainDSet), batch_size=train_params['batch_size'], shuffle=False)
        valload =  GraphLoader(GDSet(valDSet), batch_size=train_params['batch_size'], shuffle=False)
        testload =  GraphLoader(GDSet(testDSet), batch_size=train_params['batch_size'], shuffle=False)
        #pretrain à ajouter ici
        train_loss, val_loss, train_time_cost, val_time_cost = run_epoch(model, optim_model, trainload, valload, train_params['device'],scheduler)
        test_loss, test_cost = test(testload, model, train_params['device'])

        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Train Loss', train_loss]
        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Val Loss', val_loss]
        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Train Time', train_time_cost]
        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Val Time', val_time_cost]
        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Test Time', test_cost]
        df.loc[len(df)] = [exp + 1, 'GT_BERT', 'Test Loss', test_loss]
        
    df.to_csv(path + 'GT_behrt_results.csv')

    return df


def save_model(_model_dict, file_name):
    torch.save(_model_dict, file_name)


# Main

In [119]:
import time as time

df = experiment(num_experiments=1)

  0%|          | 0/1 [00:00<?, ?it/s]


 Experiment 1
Epoch n0


24it [00:05,  4.61it/s]


TOTAL TRAIN LOSS 25.844415843486786
TRAINING TIME 5.214238405227661
TOTAL VAL LOSS 24.60110330581665
EVAL TIME 0.37101221084594727
** ** * Saving fine - tuned model ** ** * 


Epoch n1


24it [00:05,  4.18it/s]


TOTAL TRAIN LOSS 23.008140325546265
TRAINING TIME 5.746732234954834
TOTAL VAL LOSS 22.918253421783447
EVAL TIME 0.4776427745819092
** ** * Saving fine - tuned model ** ** * 


Epoch n2


24it [00:05,  4.29it/s]


TOTAL TRAIN LOSS 21.680137813091278
TRAINING TIME 5.606027603149414
TOTAL VAL LOSS 26.169431686401367
EVAL TIME 0.4256870746612549


Epoch n3


24it [00:05,  4.26it/s]


TOTAL TRAIN LOSS 20.85476052761078
TRAINING TIME 5.64177131652832
TOTAL VAL LOSS 21.374399662017822
EVAL TIME 0.426922082901001
** ** * Saving fine - tuned model ** ** * 


Epoch n4


24it [00:05,  4.50it/s]


TOTAL TRAIN LOSS 21.34863668680191
TRAINING TIME 5.338130235671997
TOTAL VAL LOSS 20.394027709960938
EVAL TIME 0.47754740715026855
** ** * Saving fine - tuned model ** ** * 


Epoch n5


24it [00:06,  3.97it/s]


TOTAL TRAIN LOSS 20.718711137771606
TRAINING TIME 6.0542213916778564
TOTAL VAL LOSS 22.0828275680542
EVAL TIME 0.5438523292541504


Epoch n6


24it [00:05,  4.01it/s]


TOTAL TRAIN LOSS 20.251133918762207
TRAINING TIME 5.984379529953003
TOTAL VAL LOSS 23.94084596633911
EVAL TIME 0.3939831256866455


Epoch n7


24it [00:05,  4.33it/s]


TOTAL TRAIN LOSS 22.23294848203659
TRAINING TIME 5.5455756187438965
TOTAL VAL LOSS 23.468125820159912
EVAL TIME 0.38953161239624023


Epoch n8


24it [00:05,  4.69it/s]


TOTAL TRAIN LOSS 21.154667794704437
TRAINING TIME 5.119402647018433
TOTAL VAL LOSS 20.983065128326416
EVAL TIME 0.3706340789794922


Epoch n9


24it [00:05,  4.60it/s]


TOTAL TRAIN LOSS 19.86701887845993
TRAINING TIME 5.225611209869385
TOTAL VAL LOSS 21.469193935394287
EVAL TIME 0.4241516590118408


Epoch n10


24it [00:05,  4.17it/s]


TOTAL TRAIN LOSS 20.066703855991364
TRAINING TIME 5.75654149055481
TOTAL VAL LOSS 20.330344200134277
EVAL TIME 0.4792957305908203
** ** * Saving fine - tuned model ** ** * 


Epoch n11


24it [00:05,  4.12it/s]


TOTAL TRAIN LOSS 19.40305107831955
TRAINING TIME 5.836641311645508
TOTAL VAL LOSS 19.13809061050415
EVAL TIME 0.5106475353240967
** ** * Saving fine - tuned model ** ** * 


Epoch n12


24it [00:05,  4.29it/s]


TOTAL TRAIN LOSS 19.441198706626892
TRAINING TIME 5.601373672485352
TOTAL VAL LOSS 20.46739959716797
EVAL TIME 0.730478048324585


Epoch n13


24it [00:04,  4.82it/s]


TOTAL TRAIN LOSS 19.947130799293518
TRAINING TIME 4.983723402023315
TOTAL VAL LOSS 19.661897659301758
EVAL TIME 0.3277289867401123


Epoch n14


24it [00:04,  5.52it/s]


TOTAL TRAIN LOSS 19.681620061397552
TRAINING TIME 4.354038953781128
TOTAL VAL LOSS 20.189478874206543
EVAL TIME 0.34484219551086426


Epoch n15


24it [00:04,  5.49it/s]


TOTAL TRAIN LOSS 20.56254768371582
TRAINING TIME 4.376643657684326
TOTAL VAL LOSS 18.97108745574951
EVAL TIME 0.3336198329925537
** ** * Saving fine - tuned model ** ** * 


Epoch n16


24it [00:04,  5.28it/s]


TOTAL TRAIN LOSS 20.0654998421669
TRAINING TIME 4.5495383739471436
TOTAL VAL LOSS 24.619837284088135
EVAL TIME 0.35869598388671875


Epoch n17


24it [00:05,  4.76it/s]


TOTAL TRAIN LOSS 19.401538729667664
TRAINING TIME 5.0504679679870605
TOTAL VAL LOSS 19.55725336074829
EVAL TIME 0.38744449615478516


Epoch n18


24it [00:05,  4.79it/s]


TOTAL TRAIN LOSS 18.9180548787117
TRAINING TIME 5.011150360107422
TOTAL VAL LOSS 22.403276920318604
EVAL TIME 0.342393159866333


Epoch n19


24it [00:04,  4.93it/s]


TOTAL TRAIN LOSS 19.48395347595215
TRAINING TIME 4.875501394271851
TOTAL VAL LOSS 21.544434070587158
EVAL TIME 0.31901049613952637


Epoch n20


24it [00:04,  5.08it/s]


TOTAL TRAIN LOSS 19.63357013463974
TRAINING TIME 4.726126670837402
TOTAL VAL LOSS 19.34280490875244
EVAL TIME 0.36465024948120117


Epoch n21


24it [00:05,  4.12it/s]


TOTAL TRAIN LOSS 19.47041690349579
TRAINING TIME 5.833360910415649
TOTAL VAL LOSS 22.27788019180298
EVAL TIME 0.32999634742736816


Epoch n22


24it [00:04,  4.90it/s]


TOTAL TRAIN LOSS 18.814604818820953
TRAINING TIME 4.904999732971191
TOTAL VAL LOSS 20.865407466888428
EVAL TIME 0.3858494758605957


Epoch n23


17it [00:04,  4.08it/s]
  0%|          | 0/1 [02:15<?, ?it/s]


KeyboardInterrupt: 

In [62]:
df

Unnamed: 0,Experiment,Model,Metric,Score
0,1,GT_BERT,Train Loss,44.956565
1,1,GT_BERT,Val Loss,42.408109
2,1,GT_BERT,Train Time,5.391731
3,1,GT_BERT,Val Time,0.187244
4,1,GT_BERT,Test Time,0.485468


In [16]:
# Group by Model and Metric and calculate average and standard deviation
result_df = df.groupby(['Model', 'Metric']).agg({'Score': ['mean', 'std']}).reset_index()

# Rename columns for clarity
result_df.columns = ['Model', 'Metric', 'Average Score', 'Standard Deviation']

result_df['Average Score'] = result_df['Average Score'].round(2)
result_df['Standard Deviation'] = result_df['Standard Deviation'].round(2)

# Print the result
print(result_df)

     Model         Metric  Average Score  Standard Deviation
0  GT_BERT     Test AUPRC           0.06                 NaN
1  GT_BERT     Test AUROC           0.22                 NaN
2  GT_BERT  Test Accuracy           0.90                 NaN
3  GT_BERT        Test F1           0.85                 NaN
4  GT_BERT      Test Time           0.15                 NaN
5  GT_BERT     Train Loss           2.32                 NaN
6  GT_BERT     Train Time           2.80                 NaN
7  GT_BERT   Val Accuracy           0.60                 NaN
8  GT_BERT       Val Loss           4.07                 NaN
9  GT_BERT       Val Time           0.31                 NaN


In [17]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters())
count_parameters(model)

NameError: name 'model' is not defined