In [9]:
PROJ_DIR = '/content/drive/My Drive/Courses/CS598-DL4H/project'

## Import the packages you need
import tensorflow as tf
import numpy as np
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
import os, sys
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import time
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, f1_score

# print(os.path.exists(PROJ_DIR))
sys.path.append(os.path.join(PROJ_DIR, 'tensorflow'))
from graph_convolutional_transformer import SequenceExampleParser, FeatureEmbedder, create_matrix_vdp


## Data

In [10]:
# TFRecord inputs for this notebook live under eicu_samples/proc_data/fold_0.
data_dir = os.path.join(PROJ_DIR, 'eicu_samples/proc_data/fold_0')
# Build both split paths in one go: <data_dir>/train.tfrecord and <data_dir>/validation.tfrecord.
train_data_path, valid_data_path = (
    os.path.join(data_dir, f'{split}.tfrecord') for split in ('train', 'validation')
)
# The test split is not loaded in this notebook:
# test_data_path = os.path.join(data_dir, 'test.tfrecord')

In [11]:
def count_examples_in_tfrecord(tfrecord_path):
    """Return how many records the TFRecord file at `tfrecord_path` contains.

    The file is opened with `tf.data.TFRecordDataset` and iterated once from
    start to finish; each record contributes one to the tally, so the cost is
    a full pass over the file.
    """
    return sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

# Count the records in the training TFRecord file (one full pass over the file).
total_train_examples = count_examples_in_tfrecord(train_data_path)
print(f"Total training examples: {total_train_examples}")

# Same count for the validation TFRecord file.
total_valid_examples = count_examples_in_tfrecord(valid_data_path)
print(f"Total validation examples: {total_valid_examples}")

Total training examples: 32820
Total validation examples: 4103


In [12]:
# One parser per split. The batch sizes (15000 / 1000) are smaller than the
# split sizes printed by the counting cell, and only the first batch of each
# generator is pulled below, so later cells work on a subset of each split.
train_seqex_reader = SequenceExampleParser(batch_size=15000)
valid_seqex_reader = SequenceExampleParser(batch_size=1000)

# Label key handed to both parsers.
label_key='label.readmission'
# The third positional argument is True for both splits; its meaning is defined
# by SequenceExampleParser (presumably a training-mode flag — TODO confirm).
train_generator = train_seqex_reader(train_data_path, label_key, True)
valid_generator = valid_seqex_reader(valid_data_path, label_key, True)  # author's note: passing False here breaks training; the reason is not visible in this notebook

train_generator_iter = iter(train_generator)
valid_generator_iter = iter(valid_generator)

# Take exactly one (features, labels) batch from each split.
train_features_org, train_labels_org = next(train_generator_iter)
valid_features_org, valid_labels_org = next(valid_generator_iter)

In [13]:
def process_samples(features, labels, embedding_size, *, feature_keys=None,
                    vocab_sizes=None, max_num_codes=50, prior_scalar=0.5,
                    use_guide=True, use_prior=True, feature_embedder=None):
  """Turn one parsed TensorFlow batch into PyTorch tensors.

  The batch is embedded with a `FeatureEmbedder`, the per-key embeddings and
  masks are concatenated along axis 1 (the 'visit' entry first, then each
  feature key in order), the guide/prior matrices are built with
  `create_matrix_vdp`, and every result is converted to a `torch.Tensor`
  through its NumPy value.

  Args:
    features: Batch of features as produced by the sequence-example parser;
      passed unchanged to `FeatureEmbedder.lookup` and `create_matrix_vdp`.
    labels: TensorFlow tensor of labels for the batch (must support `.numpy()`).
    embedding_size: Embedding width used when a new `FeatureEmbedder` is built.
      Ignored when `feature_embedder` is supplied.
    feature_keys: Feature names to embed, in concatenation order. Defaults to
      ['dx_ints', 'proc_ints'].
    vocab_sizes: Mapping from feature key to vocabulary size. Defaults to
      {'dx_ints': 3249, 'proc_ints': 2210}. Ignored when `feature_embedder`
      is supplied.
    max_num_codes: Forwarded to `FeatureEmbedder.lookup` and
      `create_matrix_vdp`.
    prior_scalar: Forwarded to `create_matrix_vdp`.
    use_guide: Forwarded to `create_matrix_vdp`.
    use_prior: Forwarded to `create_matrix_vdp`.
    feature_embedder: Optional pre-built embedder exposing
      `lookup(features, max_num_codes)`. When omitted, a fresh
      `FeatureEmbedder` is created on every call.
      NOTE(review): if `FeatureEmbedder` initialises its tables randomly,
      separate calls (e.g. train vs. validation) embed with different tables;
      pass one shared instance to keep splits consistent — TODO confirm.

  Returns:
    Tuple `(embeddings, masks, guide, prior, labels)` of `torch.Tensor`s.
  """
  # Resolve defaults here rather than in the signature to avoid sharing
  # mutable default objects between calls.
  if feature_keys is None:
    feature_keys = ['dx_ints', 'proc_ints']
  else:
    feature_keys = list(feature_keys)
  if vocab_sizes is None:
    vocab_sizes = {'dx_ints': 3249, 'proc_ints': 2210}

  if feature_embedder is None:
    feature_embedder = FeatureEmbedder(vocab_sizes, feature_keys, embedding_size)
  embedding_dict, mask_dict = feature_embedder.lookup(features, max_num_codes)

  # 'visit' goes first, followed by the feature keys in the given order.
  keys = ['visit'] + feature_keys
  embeddings = tf.concat([embedding_dict[key] for key in keys], axis=1)
  masks = tf.concat([mask_dict[key] for key in keys], axis=1)

  guide, prior = create_matrix_vdp(features, masks, use_prior, use_guide, max_num_codes, prior_scalar)

  # Cross the framework boundary via NumPy; torch.tensor copies the data.
  embeddings = torch.tensor(embeddings.numpy())
  masks = torch.tensor(masks.numpy())
  guide = torch.tensor(guide.numpy())
  prior = torch.tensor(prior.numpy())
  labels = torch.tensor(labels.numpy())

  return embeddings, masks, guide, prior, labels

In [14]:
# Convert the sampled training and validation batches to PyTorch tensors,
# using 128-wide embeddings for both.
train_embeddings, train_masks, train_guide, train_prior, train_labels = process_samples(
    train_features_org, train_labels_org, embedding_size=128)
valid_embeddings, valid_masks, valid_guide, valid_prior, valid_labels = process_samples(
    valid_features_org, valid_labels_org, embedding_size=128)

# Report every tensor's shape, one per line: training tensors first, then validation.
print(*(t.shape for t in (
    train_embeddings, train_masks, train_guide, train_prior, train_labels,
    valid_embeddings, valid_masks, valid_guide, valid_prior, valid_labels,
)), sep='\n')

# The original TensorFlow batches are not used after this point; drop the names.
del train_features_org, train_labels_org, valid_features_org, valid_labels_org

torch.Size([15000, 101, 128])
torch.Size([15000, 101])
torch.Size([15000, 101, 101])
torch.Size([15000, 101, 101])
torch.Size([15000])
torch.Size([1000, 101, 128])
torch.Size([1000, 101])
torch.Size([1000, 101, 101])
torch.Size([1000, 101, 101])
torch.Size([1000])


In [15]:
# Output directories for the converted tensors, one per split.
train_dir = os.path.join(PROJ_DIR, 'pytorch/train')
valid_dir = os.path.join(PROJ_DIR, 'pytorch/valid')


def _save_split(out_dir, named_tensors):
  """Save each tensor in `named_tensors` to `<out_dir>/<name>.pt`, creating `out_dir` if missing."""
  # exist_ok=True replaces the separate os.path.exists() check, which could
  # race with another process creating the directory in between.
  os.makedirs(out_dir, exist_ok=True)
  for name, tensor in named_tensors.items():
    torch.save(tensor, os.path.join(out_dir, f'{name}.pt'))


_save_split(train_dir, {
    'embeddings': train_embeddings,
    'masks': train_masks,
    'guide': train_guide,
    'prior': train_prior,
    'labels': train_labels,
})

_save_split(valid_dir, {
    'embeddings': valid_embeddings,
    'masks': valid_masks,
    'guide': valid_guide,
    'prior': valid_prior,
    'labels': valid_labels,
})