## Loading Data

In [1]:
import numpy as np
import pandas as pd
import pickle
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import pydot
import graphviz

import matplotlib.pyplot as plt

tf.get_logger().setLevel('ERROR')

In [2]:
ekg_denoised = pd.read_pickle('/home/sanjaycollege15/PredictingDiagnoses/Data/ekg_denoised_v2.pkl')
ekg_denoised = ekg_denoised.sample(frac=1)

In [3]:
ekg_denoised.head()

Unnamed: 0,ICD9_CODE,TEXT
18174,0,sinus rhythm atrial premature complexes possib...
20919,0,"sinus bradycardia, rate 53. non-diagnostic q w..."
7614,3,sinus rhythm. compared to the previous tracing...
57370,2,sinus rhythm. since the previous tracing of se...
25264,0,sinus rhythm. delayed precordial r wave progre...


### Converting to Tensorflow dataset

In [4]:
training_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (
            tf.cast(ekg_denoised['TEXT'].values, tf.string),
            tf.cast(ekg_denoised['ICD9_CODE'].values, tf.int32)
        )
    )
)

In [5]:
training_dataset

<TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int32)>

In [6]:
train_text = ekg_denoised['TEXT'].to_numpy()

In [7]:
tf.convert_to_tensor(train_text)

<tf.Tensor: shape=(68180,), dtype=string, numpy=
array([b'sinus rhythm atrial premature complexes possible baseline artifact versus atrial pacer spikes right bundle branch block since previous tracing of , right bundle branch block present',
       b'sinus bradycardia, rate 53. non-diagnostic q waves in leads ii, iii and avf with j point elevation and concave upward st segment elevations in leads i, ii iii, avf and v1-v6. there is subtle pr segment depression in most leads with pr segment elevation in lead avr. the tracing suggest pericarditis possibly related to acute inferior myocardial infarction. compared to the previous tracing of atrial fibrillation with a rapid ventricular response has give way to sinus bradycardia. generalized st segment elevation, most marked anteriorly and inferolaterally, persist. the tracings are compatible with acute inferior myocardial infarction in evolution with supra added pericardial effusion and or pericardial process.',
       b'sinus rhythm. compar

### Setting up train, dev, and test datasets 

In [8]:
AUTOTUNE = tf.data.AUTOTUNE
DATASET_SIZE = len(training_dataset)

train_size = int(0.8 * DATASET_SIZE)
val_size = int(0.1 * DATASET_SIZE)
test_size = int(0.1 * DATASET_SIZE)

full_dataset = training_dataset
train_dataset = full_dataset.take(train_size).cache().prefetch(buffer_size=AUTOTUNE)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size).cache().prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.take(test_size).cache().prefetch(buffer_size=AUTOTUNE)


In [9]:
training_dataset

<TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int32)>

In [10]:
for text_batch, label_batch in train_dataset.take(3):
    print(f'EKG reading: {text_batch.numpy()}')
    print(f'label: {label_batch}\n')

EKG reading: b'sinus rhythm atrial premature complexes possible baseline artifact versus atrial pacer spikes right bundle branch block since previous tracing of , right bundle branch block present'
label: 0

EKG reading: b'sinus bradycardia, rate 53. non-diagnostic q waves in leads ii, iii and avf with j point elevation and concave upward st segment elevations in leads i, ii iii, avf and v1-v6. there is subtle pr segment depression in most leads with pr segment elevation in lead avr. the tracing suggest pericarditis possibly related to acute inferior myocardial infarction. compared to the previous tracing of atrial fibrillation with a rapid ventricular response has give way to sinus bradycardia. generalized st segment elevation, most marked anteriorly and inferolaterally, persist. the tracings are compatible with acute inferior myocardial infarction in evolution with supra added pericardial effusion and or pericardial process.'
label: 0

EKG reading: b'sinus rhythm. compared to the pre

In [11]:
for text_batch, label_batch in val_dataset.take(3):
    print(f'EKG reading: {text_batch.numpy()}')
    print(f'label: {label_batch}\n')

EKG reading: b"sinus rhythm. leftward axis. rsr' pattern in leads v1-v2. predominately lateral t wave abnormalities. since the previous tracing of no significant change."
label: 0

EKG reading: b'a-v paced rhythm with an atrial premature beat. since the previous tracing atrial premature beat is again seen. tracing 2'
label: 0

EKG reading: b'ectopic atrial bradycardia. compared to tracing 1 there is no significant diagnostic change. tracing 2'
label: 3



In [12]:
for text_batch, label_batch in test_dataset.take(3):
    print(f'EKG reading: {text_batch.numpy()}')
    print(f'label: {label_batch}\n')

EKG reading: b'sinus rhythm. left bundle-branch block. compared to the previous tracing of no change.'
label: 3

EKG reading: b'sinus tachycardia, rate 109. delayed precordial r wave progression, possibly a normal variant, possibly old anterior myocardial infarction. possible inferior myocardial infarction of indeterminate age. generalized non-specific repolarization changes. q-t interval prolongation. compared to the previous tracing of anterior repolarization changes are less striking.'
label: 0

EKG reading: b'sinus rhythm with occasional atrial premature beat. left atrial abnormality. left axis deviation. intraventricular conduction defect. left ventricular hypertrophy with secondary repolarization abnormality. compared to the previous tracing of frequency of atrial premature beats has decreased. otherwise, anteroseptal st-t wave abnormality appears more prominent. clinical correlation is suggested.'
label: 1



## Modeling 

### Setting up BERT Pre-Process Model and SmallBERT

In [25]:
bert_preprocess_model = hub.KerasLayer(
    'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')

In [26]:
bert_model = hub.KerasLayer(
    'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')

### Running sample text

In [15]:
test_text = ['sinus rhythm. prolonged a-v conduction. left bundle-branch block. compared to the previous tracing of transmural inferior wall myocardial infarction was previously present if left bundle-branch block patterning is absent, although the patterning barely resembles left bundle-branch block. the current tracing has a wider qrs interval.']
text_preprocessed = bert_preprocess_model(test_text)

print(f'Keys       : {list(text_preprocessed.keys())}')
print(f'Shape      : {text_preprocessed["input_word_ids"].shape}')
print(f'Word Ids   : {text_preprocessed["input_word_ids"][0, :128]}')
print(f'Input Mask : {text_preprocessed["input_mask"][0, :12]}')
print(f'Type Ids   : {text_preprocessed["input_type_ids"][0, :12]}')


Keys       : ['input_mask', 'input_word_ids', 'input_type_ids']
Shape      : (1, 128)
Word Ids   : [  101  8254  2271  6348  1012 15330  1037  1011  1058  6204  3258  1012
  2187 14012  1011  3589  3796  1012  4102  2000  1996  3025 16907  1997
  9099 16069  2140 14092  2813  2026 24755 25070  1999 14971  7542  2001
  3130  2556  2065  2187 14012  1011  3589  3796  5418  2075  2003  9962
  1010  2348  1996  5418  2075  4510 12950  2187 14012  1011  3589  3796
  1012  1996  2783 16907  2038  1037  7289  1053  2869 13483  1012   102
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0]
Input Mask : [1 1 1 1 1 1 1 1 1 1 1 1]
Type Ids   : [0 0 0 0 0 0 0 0 0 0 0 0]


#### Notes:
- Increasing the number of sentences changes the shape of the input word IDs (2 sentences becomes (2,128)).
- The input is limited to 128 words
- Start and stop of sentences is encoded with 101 and 102

In [16]:
bert_results = bert_model(text_preprocessed)

print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')

Pooled Outputs Shape:(1, 512)
Pooled Outputs Values:[ 0.9887249  -0.7988825  -0.32717526  0.10474511 -0.29970285  0.99474674
  0.98696184 -0.98571074 -0.5750031  -0.32717916 -0.57857615 -0.97410107]
Sequence Outputs Shape:(1, 128, 512)
Sequence Outputs Values:[[ 0.4081876   0.28475899 -0.7222308  ... -0.54963666 -0.17958543
   0.05423299]
 [ 0.2108407   0.48801696  0.15987705 ...  0.63800824  0.01293866
  -1.1388142 ]
 [ 0.5426497   0.87804437 -1.1415874  ...  0.867369   -0.52439797
   0.10186939]
 ...
 [-0.20698738 -0.00587239 -1.0680933  ...  0.8720527   0.3271039
  -0.5672984 ]
 [ 0.02397693 -0.26796004 -0.3877494  ... -1.0148929   0.0331724
  -0.7092347 ]
 [ 0.07354864 -0.2509057  -0.44517717 ... -1.1513165  -0.33980703
   0.78608596]]


#### Notes:
- Bert model outputs a 512 valued vector

### Defining Model Pipeline

In [18]:
def build_classifier_model():
  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
  preprocessing_layer = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3', name='preprocessing')
  encoder_inputs = preprocessing_layer(text_input)
  encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1', trainable=True, name='BERT_encoder')
  outputs = encoder(encoder_inputs)
  net = outputs['pooled_output']
  net = tf.keras.layers.Dropout(0.1)(net)
  net = tf.keras.layers.Dense(4, activation=None, name='classifier')(net)
  return tf.keras.Model(text_input, net)

classifier_model = build_classifier_model()


In [19]:
bert_raw_result = classifier_model(tf.constant(test_text))
print(tf.sigmoid(bert_raw_result))

tf.Tensor([[0.30737662 0.25915897 0.44741687 0.6134175 ]], shape=(1, 4), dtype=float32)


### Loss Metrics

In [20]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.BinaryAccuracy()

In [21]:
epochs = 5
steps_per_epoch = tf.data.experimental.cardinality(train_dataset).numpy()
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1*num_train_steps)
init_lr = 3e-5

In [22]:
classifier_model.compile(loss=loss, metrics=metrics)

In [23]:
tf.constant(test_text)

<tf.Tensor: shape=(1,), dtype=string, numpy=
array([b'sinus rhythm. prolonged a-v conduction. left bundle-branch block. compared to the previous tracing of transmural inferior wall myocardial infarction was previously present if left bundle-branch block patterning is absent, although the patterning barely resembles left bundle-branch block. the current tracing has a wider qrs interval.'],
      dtype=object)>

In [24]:
history = classifier_model.fit(x=test_dataset,
                               validation_data=val_dataset,
                               epochs=epochs,
                               )

Epoch 1/5


ValueError: in user code:

    File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/keras/engine/training.py", line 878, in train_function  *
        return step_function(self, iterator)
    File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/keras/engine/training.py", line 867, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/keras/engine/training.py", line 860, in run_step  **
        outputs = model.train_step(data)
    File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/keras/engine/training.py", line 808, in train_step
        y_pred = self(x, training=True)
    File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer "preprocessing" (type KerasLayer).
    
    in user code:
    
        File "/home/sanjaycollege15/anaconda3/lib/python3.8/site-packages/tensorflow_hub/keras_layer.py", line 237, in call  *
            result = smart_cond.smart_cond(training,
    
        ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
          Positional arguments (3 total):
            * Tensor("inputs:0", shape=(), dtype=string)
            * False
            * None
          Keyword arguments: {}
        
         Expected these arguments to match one of the following 4 option(s):
        
        Option 1:
          Positional arguments (3 total):
            * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
            * False
            * None
          Keyword arguments: {}
        
        Option 2:
          Positional arguments (3 total):
            * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
            * True
            * None
          Keyword arguments: {}
        
        Option 3:
          Positional arguments (3 total):
            * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
            * False
            * None
          Keyword arguments: {}
        
        Option 4:
          Positional arguments (3 total):
            * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
            * True
            * None
          Keyword arguments: {}
    
    
    Call arguments received:
      • inputs=tf.Tensor(shape=(), dtype=string)
      • training=True
