In [3]:
import pandas as pd
import random
import re
import os

# Gensim (for Word2Vec) needs to be installed for the following two imports
import gensim
from gensim.models import Word2Vec
from gensim import utils
from time import time

# Training Word2Vec Model

This notebook takes the raw discharge summaries and trains Word2Vec model. Model is saved as a file and use during Symptom embedding generation using BiLSTM training in 05_bilstm_training notebook. 

---

- Input : discharge_summary.csv - Filtered notes file with raw discharge summary notes (generated by 01_symptom_extraction)
- Output : word2vec trained model - Used by 05_bilstm_training.

In [5]:
cwd = os.getcwd()
print(f"Current working directory : {cwd}")
# Let's define some constants that will be used below in our processing
MAX_NUMBER_OF_DISEASE = 50
RUN_TAG = "_v2.0"
data_dir = cwd + "/../../data/"
MODEL_DIR = data_dir + "/word2vec/"
DISCHARGE_SUMMARY = data_dir + f"discharge_summary_{RUN_TAG}.csv"
MODEL_FILE_PATH = MODEL_DIR + f"word2vec_model_sg_128_{RUN_TAG}"
MODEL_TEXT_FILE_PATH = MODEL_DIR + f"word2vec_model_sg_128_{RUN_TAG}.txt"

Current working directory : /Users/vijaymi/Studies/CS-598-DL4Health/Project/135-Disease-Inference-Method/disease_pred_using_bilstm/source


### Load Discharge Summaries

In [5]:
# Read discharge summaries csv file
discharge_summaries_df = pd.read_csv(DISCHARGE_SUMMARY)
discharge_summaries_df.head(5)

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
0,174,22532,167853.0,2151-08-04,,,Discharge summary,Report,,,Admission Date: [**2151-7-16**] Dischar...
1,175,13702,107527.0,2118-06-14,,,Discharge summary,Report,,,Admission Date: [**2118-6-2**] Discharg...
2,176,13702,167118.0,2119-05-25,,,Discharge summary,Report,,,Admission Date: [**2119-5-4**] D...
3,177,13702,196489.0,2124-08-18,,,Discharge summary,Report,,,Admission Date: [**2124-7-21**] ...
4,178,26880,135453.0,2162-03-25,,,Discharge summary,Report,,,Admission Date: [**2162-3-3**] D...


In [6]:
# Find the total number of rows in the discharge summaries data frame
len(discharge_summaries_df.index)

59652

### Split Data to Train and Test

In [7]:
# Set seed
seed = 1234
# Helper function to split data to train, and test
def build_data_buckets(num_records):    
    index = list(range(num_records))
    random.seed(seed)
    random.shuffle(index)
    index_train = index[0 : int(num_records * 0.80)]
    index_test = index[int(num_records * 0.80) : num_records]

    return index_train, index_test


index_train, index_test = build_data_buckets(discharge_summaries_df.shape[0])
len(index_train)

47721

In [8]:
# Copy training data into a separate dataframe
training_data_df = discharge_summaries_df.iloc[index_train].copy()
training_data_df.head(5)

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
45907,47509,20849,195254.0,2137-10-11,,,Discharge summary,Report,,,Admission Date: [**2137-9-15**] ...
56307,56268,25466,144778.0,2194-12-30,,,Discharge summary,Addendum,,,"Name: [**Known lastname 6610**], [**Known fir..."
1132,1605,65003,116768.0,2167-06-02,,,Discharge summary,Report,,,Admission Date: [**2167-5-26**] ...
4766,4772,76853,119383.0,2113-05-25,,,Discharge summary,Report,,,Admission Date: [**2113-5-16**] ...
54224,48638,9602,163874.0,2177-08-21,,,Discharge summary,Report,,,Admission Date: [**2177-8-18**] ...


### Clean and Prep Discharge Summaries Notes Data

In [9]:
# Remove some special characters etc., and split each discharge summary note into a list of tokens
def clean_note_text(note_text):
    note_text = re.sub('<[^>]*>', '', note_text)
    note_text = re.sub('[\W]+', ' ', note_text.lower())
    note_tokens = note_text.split()
    return note_tokens

notes_tokens_list = list(training_data_df['TEXT'].apply(clean_note_text))
len(notes_tokens_list)
# notes_tokens_list[:5]

47721

### Train the Word2Vec Model

Note: We used gensim.models.Word2Vec version 4.2.0. "vector_size" parameter was called "size" in older versions (in case issue is encountered).

In [12]:
# Word2Vec model parameters from the paper:

# Window size (window): 5, model will use total (left + right) 5 words for context
# Min count (min_count): 5 (default) i.e. words occurring in less than 5 notes will be removed
# Size of output vector: 128, each word will be mapped to 128 dimension vector
# Skip gram: sg = 1 implies skip gram is used (paper uses skip gram instead of CBOW)
# Negative (negative): 5, negative sampling speeds up the training process
# Down sampling (sample): 1e-3, parameter for down sampling high frequency words

# Create model
word2vec_model = gensim.models.Word2Vec(window = 5, vector_size = 128, sample = 1e-3, negative = 5, sg = 1)

# Build vocabulary using tokens created from the discharge summary notes
word2vec_model.build_vocab(notes_tokens_list)

index = list(range(len(notes_tokens_list)))
# print(len(index))
start_time = time()

# Do multiple runs, shuffling the data for each run for improved accuracy
for epoch in range(5):
    random.shuffle(index)
    note_tokens = [notes_tokens_list[i] for i in index]
    # print(len(note_tokens))
    word2vec_model.train(note_tokens, total_examples = word2vec_model.corpus_count, epochs = word2vec_model.epochs)
    print(epoch)
    
training_time = time() - start_time
print("Time taken to train the Word2Vec model: ", training_time, "seconds")

0
1
2
3
4
Time taken to train the Word2Vec model:  3353.0938024520874 seconds


### Save Trained Model

In [20]:
# Save the trained Word2Vec model to be used later
word2vec_model.save(MODEL_FILE_PATH)

# Open a saved Word2Vec model 
# word2vec_model = gensim.models.Word2Vec.load(model_dir + 'word2vec_model')

# Store the input-hidden weight matrix
word2vec_model.wv.save_word2vec_format(MODEL_TEXT_FILE_PATH, binary = False)

### Test Model

#### Load the saved model and run sanity checks on model

In [6]:
word2vec_model = gensim.models.Word2Vec.load(MODEL_FILE_PATH)

In [11]:
word2vec_model.wv.most_similar('heart')

[('congestive', 0.6031102538108826),
 ('rate', 0.5973581671714783),
 ('irregular', 0.5703706741333008),
 ('rhythm', 0.5551130175590515),
 ('diastolic', 0.5521538257598877),
 ('attack', 0.5483540892601013),
 ('lungs', 0.5457122921943665),
 ('heartrate', 0.5342849493026733),
 ('beating', 0.5156849026679993),
 ('systolic', 0.5133267641067505)]

In [8]:
'heart' in word2vec_model.wv

True

In [10]:
# Embedding for heart
word2vec_model.wv['heart']

array([ 7.06949905e-02, -1.90536574e-01, -3.61810595e-01,  2.97478259e-01,
        2.05318749e-01, -2.75612116e-01,  4.81133163e-02,  2.22510353e-01,
       -9.02033001e-02,  5.45390368e-01,  8.47483128e-02, -4.81615394e-01,
        6.02432154e-02, -1.21645816e-01, -1.77577406e-01,  3.19426596e-01,
       -5.93667209e-01,  2.34584898e-01, -3.02697867e-01, -7.76162744e-02,
       -6.01359189e-01, -1.23697087e-01, -4.12711710e-01,  1.96824282e-01,
       -3.11441328e-02, -4.04376417e-01,  8.70955884e-02, -4.95336682e-01,
       -3.43563586e-01,  1.33600667e-01, -3.18972945e-01, -6.05305694e-02,
        3.98243606e-01,  4.11699861e-02, -2.16017634e-01, -8.56307521e-02,
        6.52283952e-02,  3.46383512e-01,  3.64983141e-01, -2.23941520e-01,
        2.74442524e-01,  6.42432630e-01,  3.74905497e-01, -2.65075177e-01,
        2.82089055e-01,  4.76425231e-01, -1.72828078e-01,  8.18187222e-02,
       -1.20404899e-01,  2.27561608e-01,  4.16058898e-01,  1.02656819e-01,
        9.49631035e-02,  