In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import numpy as np

In [2]:
model_config = {
    'vocab_size': 100, # number of disease + symbols for word embedding
    'edge_relationship_size': 9, # number of vocab for edge_attr
    'hidden_size': 50*5, # word embedding and seg embedding hidden size
    'seg_vocab_size': 2, # number of vocab for seg embedding
    'age_vocab_size': 103, # number of vocab for age embedding
    'delta_size': 144, # number of vocab for age embedding
    'gender_vocab_size': 2,
    'ethnicity_vocab_size': 2,
    'race_vocab_size': 6,
    'num_labels':1,
    'max_position_embedding': 50, # maximum number of tokens
    'hidden_dropout_prob': 0.2, # dropout rate
    'graph_dropout_prob': 0.2, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 2, # number of attention heads
    'attention_probs_dropout_prob': 0.2, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
    'number_output' : 1,
    'n_layers' : 3 - 1,
    'alpha' : 0.1
}

In [3]:
def genere_graph(num_nodes, config):
    x = torch.randint(1, config["vocab_size"]+1, (num_nodes,))
    x = torch.cat([torch.tensor([0]), x]) # on rajoute vst
    
    all_edges = []
    for i in range(len(x)):
        for j in range(i+1,len(x)):
            all_edges.append((i, j))
    source, target = zip(*all_edges)
    edge_index = torch.tensor([source, target])

    edge_attr = torch.randint(1, config["edge_relationship_size"] + 1, (edge_index.size(1),))

    return x, edge_index, edge_attr

def visit(nodes, edge_index, edge_attr, age, time, type, label, mask_v, mask, delta=None, los=None, edge_index_readout=None,):
    return Data(x=nodes, edge_index=edge_index, edge_attr=edge_attr, age=age, time=time, type=type, label=label, mask_v=mask_v, mask=mask, delta=delta, los=los, edge_index_readout=edge_index_readout)

def genere_visit(num_nodes, mask_v, config):
    x, edge_index, edge_attr = genere_graph(num_nodes, model_config)
    age = torch.randint(1, config["age_vocab_size"], (1,))
    time = torch.randint(1, 367, (1,))
    type = torch.randint(1, 11, (1,))
    label = torch.randint(0, 2, (1,))
    mask_v = torch.tensor(mask_v)
    mask = torch.tensor(1)
    return visit(x, edge_index, edge_attr, age, time, type, label, mask_v, mask)

def genere_patient(config):
    visits = []
    num_visits = np.random.randint(2, config["max_position_embedding"]+1)
    for i in range(config["max_position_embedding"]-num_visits):
        num_nodes = np.random.randint(10, 30)
        visits.append(genere_visit(num_nodes, 0, config)) 
    for i in range(config["max_position_embedding"]-num_visits, config["max_position_embedding"]):
        num_nodes = 1
        visits.append(genere_visit(num_nodes, 1, config))
    return visits

def genere_batch(num_patients, num_visits, config):
    patients = []
    for i in range(num_patients):
        patients.append(genere_patient(config))
    return patients

In [4]:
dataset = genere_batch(100, 10, model_config)

for patient in dataset:
    for visit in patient:
        print(visit)
    print("-----")



Data(x=[27], edge_index=[2, 351], edge_attr=[351], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[26], edge_index=[2, 325], edge_attr=[325], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[22], edge_index=[2, 231], edge_attr=[231], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[28], edge_index=[2, 378], edge_attr=[378], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[21], edge_index=[2, 210], edge_attr=[210], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[20], edge_index=[2, 190], edge_attr=[190], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[29], edge_index=[2, 406], edge_attr=[406], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[17], edge_index=[2, 136], edge_attr=[136], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[15], edge_index=[2, 105], edge_attr=[105], age=[1], time=[1], type=[1], label=[1], mask_v=0, mask=1)
Data(x=[29], edge_i

In [76]:
import pickle
with open("data", 'wb') as file:
    pickle.dump(dataset, file)

In [56]:
with open("data", 'rb') as file:
    loaded_data = pickle.load(file)

print("Loaded data:", loaded_data)

Loaded data: [[Data(x=[13], edge_index=[2, 78], edge_attr=[78], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[13]), Data(x=[13], edge_index=[2, 78], edge_attr=[78], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[13]), Data(x=[15], edge_index=[2, 105], edge_attr=[105], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[15]), Data(x=[17], edge_index=[2, 136], edge_attr=[136], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[17]), Data(x=[22], edge_index=[2, 231], edge_attr=[231], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[22]), Data(x=[17], edge_index=[2, 136], edge_attr=[136], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[17]), Data(x=[19], edge_index=[2, 171], edge_attr=[171], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[19]), Data(x=[20], edge_index=[2, 190], edge_attr=[190], age=[1], time=[1], type=[1], label=[1], mask_v=1, mask=[20]), Data(x=[27], edge_index=[2, 351], edge_attr=[351], age=[1], time=[1], type=[1], label