# Objective

To create a simple model that can predict which medication from a given list
would be appropriate for a patient based on their symptoms/treatment goals.

In [None]:
!pip install pandas
!pip install datasets

In [1]:
import os
import re
import glob
import shutil
import string
import pathlib


data_dir = os.path.abspath(os.path.join(os.getcwd(),'..','data'))


os.environ['MPLCONFIGDIR'] = os.path.join(data_dir,'plt_configs')
import matplotlib.pyplot as plt


os.environ['HF_HOME'] = os.path.join(data_dir,'hf_cache')
import datasets

import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras.layers import TextVectorization

# Data collection

Drug review dataset obtained from Hugging Face

In [2]:
dataset = datasets.load_dataset("flxclxc/encoded_drug_reviews")

Using custom data configuration flxclxc--encoded_drug_reviews-ee0cdba36988e67d
Found cached dataset json (/tf/data/hf_cache/datasets/flxclxc___json/flxclxc--encoded_drug_reviews-ee0cdba36988e67d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
df = pd.DataFrame(dataset)

In [4]:
print(df)

                                                   train
0      {'patient_id': 184648, 'drugName': 'Efudex', '...
1      {'patient_id': 25268, 'drugName': 'Flector Pat...
2      {'patient_id': 172019, 'drugName': 'Amitiza', ...
3      {'patient_id': 196063, 'drugName': 'Stendra', ...
4      {'patient_id': 225264, 'drugName': 'Bupropion'...
...                                                  ...
53466  {'patient_id': 199190, 'drugName': 'Depo-Prove...
53467  {'patient_id': 188476, 'drugName': 'ParaGard',...
53468  {'patient_id': 105752, 'drugName': 'Methylpred...
53469  {'patient_id': 56713, 'drugName': 'Meclizine',...
53470  {'patient_id': 215006, 'drugName': 'Fluoxetine...

[53471 rows x 1 columns]


In [5]:
complete_dataset = dataset['train'].to_pandas()
complete_dataset.head()

Unnamed: 0,patient_id,drugName,condition,review,rating,date,usefulCount,review_length,encoded
0,184648,Efudex,basal cell carcinoma,"""I have BCC on my upper arm and SCC on upper l...",1.0,"August 30, 2016",16,36,"[-0.0633561835, 0.0115883639, -0.0027463636, 0..."
1,25268,Flector Patch,pain,"""I tore my shoulder labrum and the pain can be...",8.0,"May 29, 2014",40,45,"[-0.083280459, 0.0182377025, 0.0619471855, 0.0..."
2,172019,Amitiza,irritable bowel syndrome,"""Amitiza is the best if you have ibs!""",10.0,"July 13, 2016",9,8,"[-0.0300639421, -0.0081300493, 0.0343461707, 0..."
3,196063,Stendra,erectile dysfunction,"""Viagra works in a strong, crude way with side...",10.0,"November 10, 2014",82,141,"[-0.0037669495, -0.0845683292, 0.0196341239, 0..."
4,225264,Bupropion,depression,"""I really wanted Wellbutrin to work. I was giv...",3.0,"October 4, 2015",15,62,"[-0.0633124188, 0.0167291258, 0.0707527027, 0...."


In [6]:
print(df.info()) # 53,471 total drug reviews in dataset

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53471 entries, 0 to 53470
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   train   53471 non-null  object
dtypes: object(1)
memory usage: 417.9+ KB
None


In [7]:
complete_dataset.info() 

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53471 entries, 0 to 53470
Data columns (total 9 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   patient_id     53471 non-null  int64  
 1   drugName       53471 non-null  object 
 2   condition      53471 non-null  object 
 3   review         53471 non-null  object 
 4   rating         53471 non-null  float64
 5   date           53471 non-null  object 
 6   usefulCount    53471 non-null  int64  
 7   review_length  53471 non-null  int64  
 8   encoded        53471 non-null  object 
dtypes: float64(1), int64(3), object(5)
memory usage: 3.7+ MB


# Data Cleaning

-Isolated drugs from the dataset

In [8]:
# unique values in column "drugName"
drugs = complete_dataset['drugName'].unique()
for drug in drugs:
    print(drug)

Efudex
Flector Patch
Amitiza
Stendra
Bupropion
Amiodarone
Sprintec
Acetaminophen / tramadol
Sertraline
Lisinopril
Levonorgestrel
Liraglutide
Fluoxetine
Methylprednisolone
Doxylamine / pyridoxine
Zoladex
Clarithromycin
Tranylcypromine
Phentermine
Prednisolone
Tarceva
Ezetimibe / simvastatin
Belsomra
Enbrel
Flexeril
Ethinyl estradiol / levonorgestrel
Prozac
Clomipramine
Klonopin
Ethinyl estradiol / norgestimate
Zovirax
Adipex-P
Vilazodone
Etonogestrel
Buspirone
Phentermine / topiramate
Micardis HCT
Dapsone
TriNessa
Estarylla
Venlafaxine
Aviane
Geodon
Tiotropium
Remicade
Synthroid
Minocycline
Ativan
Eletriptan
Levoxyl
Duloxetine
Dexbrompheniramine / pseudoephedrine
Loestrin 24 Fe
Haloperidol
Nexplanon
Robaxin
Cefdinir
Ortho Tri-Cyclen Lo
Desogen
Hiprex
Lorazepam
Loratadine / pseudoephedrine
Lyrica
Norco
Ethinyl estradiol / norelgestromin
Ambrisentan
Trulicity
Duac
Zyrtec
Propranolol
Alli
Exemestane
Mirena
NuvaRing
Ethinyl estradiol / norethindrone
Desogestrel / ethinyl estradiol
Magnesium

-Selected drugs with a review count > 200 reviews

In [9]:
pd.set_option("display.max_rows", None)
frequence = complete_dataset['drugName'].value_counts()
print(frequence)

Levonorgestrel                                                                                      1265
Etonogestrel                                                                                        1081
Ethinyl estradiol / norethindrone                                                                    869
Nexplanon                                                                                            736
Ethinyl estradiol / norgestimate                                                                     649
Ethinyl estradiol / levonorgestrel                                                                   591
Phentermine                                                                                          539
Sertraline                                                                                           506
Escitalopram                                                                                         452
Mirena                                                 

-Removed brand name drugs from dataset so that model did not classify brand/generic as two separate medications 

ex. Lexapro/Escitalopram or Chantix/Varenicline

In [10]:
base_dir = os.path.join(data_dir,'drugs')
if not os.path.exists(base_dir):
    os.makedirs(base_dir)

drug_names = ["Levonorgestrel",
              "Etonogestrel",
              "Ethinyl estradiol / norethindrone",
              "Ethinyl estradiol / norgestimate",
              "Ethinyl estradiol / levonorgestrel",
              "Phentermine",
              "Sertraline",
              "Escitalopram",
              "Mirena",
              "Gabapentin",
              "Miconazole",
              "Bupropion",
              "Venlafaxine",
              "Duloxetine",
              "Tramadol",
              "Clonazepam",
              "Citalopram",
              "Medroxyprogesterone",
              "Bupropion / naltrexone",
              "Varenicline",
              "Metronidazole",
              "Drospirenone / ethinyl estradiol",
              "Tioconazole",
              "Depo-Provera",
              "Liraglutide",
              "Fluoxetine",
              "Quetiapine",
              "Lo Loestrin Fe",
              "Alprazolam",
              "Amitriptyline",
              "Doxycycline",
              "Desvenlafaxine",
              "Trazodone",
              "Suprep Bowel Prep Kit",
              "Paroxetine",
              "Bisacodyl",
              "Lorcaserin"]
drug_directories = []

for drug_name in drug_names:
    current_drug_dir = os.path.join(data_dir,'drugs',drug_name)
    print(current_drug_dir)
    drug_directories.append(current_drug_dir)
    if not os.path.exists(current_drug_dir):
        os.makedirs(current_drug_dir)


/tf/data/drugs/Levonorgestrel
/tf/data/drugs/Etonogestrel
/tf/data/drugs/Ethinyl estradiol / norethindrone
/tf/data/drugs/Ethinyl estradiol / norgestimate
/tf/data/drugs/Ethinyl estradiol / levonorgestrel
/tf/data/drugs/Phentermine
/tf/data/drugs/Sertraline
/tf/data/drugs/Escitalopram
/tf/data/drugs/Mirena
/tf/data/drugs/Gabapentin
/tf/data/drugs/Miconazole
/tf/data/drugs/Bupropion
/tf/data/drugs/Venlafaxine
/tf/data/drugs/Duloxetine
/tf/data/drugs/Tramadol
/tf/data/drugs/Clonazepam
/tf/data/drugs/Citalopram
/tf/data/drugs/Medroxyprogesterone
/tf/data/drugs/Bupropion / naltrexone
/tf/data/drugs/Varenicline
/tf/data/drugs/Metronidazole
/tf/data/drugs/Drospirenone / ethinyl estradiol
/tf/data/drugs/Tioconazole
/tf/data/drugs/Depo-Provera
/tf/data/drugs/Liraglutide
/tf/data/drugs/Fluoxetine
/tf/data/drugs/Quetiapine
/tf/data/drugs/Lo Loestrin Fe
/tf/data/drugs/Alprazolam
/tf/data/drugs/Amitriptyline
/tf/data/drugs/Doxycycline
/tf/data/drugs/Desvenlafaxine
/tf/data/drugs/Trazodone
/tf/data

In [11]:
data = np.array(["Levonorgestrel",
                 "Etonogestrel",
                 "Ethinyl estradiol / norethindrone",
                 "Ethinyl estradiol / norgestimate",
                 "Ethinyl estradiol / levonorgestrel",
                 "Phentermine",
                 "Sertraline",
                 "Escitalopram",
                 "Mirena",
                 "Gabapentin",
                 "Miconazole",
                 "Bupropion",
                 "Venlafaxine",
                 "Duloxetine",
                 "Tramadol",
                 "Clonazepam",
                 "Citalopram",
                 "Medroxyprogesterone",
                 "Bupropion / naltrexone",
                 "Varenicline",
                 "Metronidazole",
                 "Drospirenone / ethinyl estradiol",
                 "Tioconazole",
                 "Depo-Provera",
                 "Liraglutide",
                 "Fluoxetine",
                 "Quetiapine",
                 "Lo Loestrin Fe",
                 "Alprazolam",
                 "Amitriptyline",
                 "Doxycycline",
                 "Desvenlafaxine",
                 "Trazodone",
                 "Suprep Bowel Prep Kit",
                 "Paroxetine",
                 "Bisacodyl",
                 "Lorcaserin"])
s = pd.Series(data)

In [12]:
print(s[:])

0                         Levonorgestrel
1                           Etonogestrel
2      Ethinyl estradiol / norethindrone
3       Ethinyl estradiol / norgestimate
4     Ethinyl estradiol / levonorgestrel
5                            Phentermine
6                             Sertraline
7                           Escitalopram
8                                 Mirena
9                             Gabapentin
10                            Miconazole
11                             Bupropion
12                           Venlafaxine
13                            Duloxetine
14                              Tramadol
15                            Clonazepam
16                            Citalopram
17                   Medroxyprogesterone
18                Bupropion / naltrexone
19                           Varenicline
20                         Metronidazole
21      Drospirenone / ethinyl estradiol
22                           Tioconazole
23                          Depo-Provera
24              

In [13]:
drug_datasets = []
for drug_name in drug_names:

    filtered_df = complete_dataset[complete_dataset['drugName'] == drug_name]
    drug_datasets.append(filtered_df)

In [14]:
drug_names
drug_directories
drug_datasets

for i in range(len(drug_names)):
    review_counter = 0
    for text in drug_datasets[i]['review']:
        with open(os.path.join(drug_directories[i],str(review_counter)+'.txt'), 'w') as f:
            f.write(text)
        review_counter+=1

# Training model using the dataset

In [15]:
batch_size = 32
seed = 42

raw_train_ds = tf.keras.utils.text_dataset_from_directory(
    pathlib.Path(base_dir),
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)

Found 13923 files belonging to 35 classes.
Using 11139 files for training.


In [16]:
for text_batch, label_batch in raw_train_ds.take(1):
    for i in range(10):
        print("Patient Review: ", text_batch.numpy()[i])
        print("Label:", label_batch.numpy()[i])

Patient Review:  b'"This is so awesome reading everyone\'s reviews. These are real reviews to me. You read other diet adds, they are not real. I started in late August, 28 2013. I have lost a total of 15lbs so far! Looking to lose more. I was near the 200 lb mark. It\'s been like that since I had my last child in 2005. All this time I had been working out, I saw no changes. I needed a change because I have insulin resistance. That is scary to me. Diabetes runs in my family both parents. I love this pill!. It\'s amazing how much energy I feel. I did feel jittery when taking at first and not feel hungry but I ate anyways. I watch what I eat. Insomnia and dry mouth. Drink lots of water. I just want to feel better about myself. "'
Label: 26
Patient Review:  b'"Gabapentin is a miracle worker. While I was suffering with depression and severe anxiety, I was against taking any type of medications to help me however things were starting to get out control. I am very happy with how things are no

In [17]:
for i, label in enumerate(raw_train_ds.class_names):
    print("Label", i, "corresponds to", label)

Label 0 corresponds to Alprazolam
Label 1 corresponds to Amitriptyline
Label 2 corresponds to Bisacodyl
Label 3 corresponds to Bupropion
Label 4 corresponds to Bupropion 
Label 5 corresponds to Citalopram
Label 6 corresponds to Clonazepam
Label 7 corresponds to Depo-Provera
Label 8 corresponds to Desvenlafaxine
Label 9 corresponds to Doxycycline
Label 10 corresponds to Drospirenone 
Label 11 corresponds to Duloxetine
Label 12 corresponds to Escitalopram
Label 13 corresponds to Ethinyl estradiol 
Label 14 corresponds to Etonogestrel
Label 15 corresponds to Fluoxetine
Label 16 corresponds to Gabapentin
Label 17 corresponds to Levonorgestrel
Label 18 corresponds to Liraglutide
Label 19 corresponds to Lo Loestrin Fe
Label 20 corresponds to Lorcaserin
Label 21 corresponds to Medroxyprogesterone
Label 22 corresponds to Metronidazole
Label 23 corresponds to Miconazole
Label 24 corresponds to Mirena
Label 25 corresponds to Paroxetine
Label 26 corresponds to Phentermine
Label 27 corresponds to 

# Create validation set

In [18]:
raw_val_ds = tf.keras.utils.text_dataset_from_directory(
    pathlib.Path(base_dir),
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)


Found 13923 files belonging to 35 classes.
Using 2784 files for validation.


# Prepare dataset for training

-Utilized TensorFlow Text Vectorization

In [19]:
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')

In [20]:
max_features = 10000
sequence_length = 250

vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)

In [21]:
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)

In [22]:
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label

In [23]:
text_batch, label_batch = next(iter(raw_train_ds))
first_review, first_label = text_batch[0], label_batch[0]
print("Review", first_review)
print("Label", raw_train_ds.class_names[first_label])
print("Vectorized review", vectorize_text(first_review, first_label))

Review tf.Tensor(b'"I\'m on Mylans generic version of Generess Fe and so far its been okay. I\'ve been on 4 different birth controls in the past year and a half and I was put on this one to try to eliminate a cyst I\'ve had since November. While I can say the cyst pain had gone away I have now noticed that I get extremely nauseous on this pill. I\'m sure it\'s just my body adjusting but I feel terrible. I take the pill at 9 every night so I think I have to make sure that I don\'t have an empty stomach when I take it. This pill has also made me break out a bit but hopefully that will go away in the next few months."', shape=(), dtype=string)
Label Ethinyl estradiol 
Vectorized review (<tf.Tensor: shape=(1, 250), dtype=int64, numpy=
array([[  25,   13, 8084,  575, 1488,   12, 2842,  571,    3,   22,  152,
          43,   24,  659,   33,   24,   13,  139,  248,   90, 1117,   17,
           4,  304,  122,    3,    8,  272,    3,    2,   10,  156,   13,
          14,   79,    5,  190,    5,

In [24]:
print("1287 ---> ",vectorize_layer.get_vocabulary()[1287])
print(" 313 ---> ",vectorize_layer.get_vocabulary()[313])
print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))

1287 --->  contraceptive
 313 --->  8
Vocabulary size: 10000


In [25]:
train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)

In [26]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Create the model

Loss function: Sparse Categorical Crossentropy (multi classification)

Optimizer: nadam (Nesterov-accelerated Adaptive Moment Estimation)

Activation: Softmax

In [27]:
embedding_dim = 16

In [28]:
model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(len(drug_names), activation=tf.nn.softmax)])


In [29]:
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              optimizer='nadam',
              metrics=['accuracy'])

In [30]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, None, 16)          160016    
                                                                 
 dropout (Dropout)           (None, None, 16)          0         
                                                                 
 global_average_pooling1d (G  (None, 16)               0         
 lobalAveragePooling1D)                                          
                                                                 
 dropout_1 (Dropout)         (None, 16)                0         
                                                                 
 dense (Dense)               (None, 37)                629       
                                                                 
Total params: 160,645
Trainable params: 160,645
Non-trainable params: 0
__________________________________________________

# Train the model

In [31]:
epochs = 300
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300


Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78/300
Epoch 79/300
Epoch 80/300
Epoch 81/300
Epoch 82/300
Epoch 83/300
Epoch 84/300
Epoch 85/300
Epoch 86/300
Epoch 87/300
Epoch 88/300
Epoch 89/300
Epoch 90/300
Epoch 91/300
Epoch 92/300
Epoch 93/300
Epoch 94/300
Epoch 95/300
Epoch 96/300
Epoch 97/300
Epoch 98/300
Epoch 99/300
Epoch 100/300
Epoch 101/300
Epoch 102/300
Epoch 103/300
Epoch 104/300
Epoch 105/300
Epoch 106/300
Epoch 107/300
Epoch 108/300
Epoch 109/300
Epoch 110/300
Epoch 111/300
Epoch 112/300
Epoch 113/300
Epoch 114/300


Epoch 115/300
Epoch 116/300
Epoch 117/300
Epoch 118/300
Epoch 119/300
Epoch 120/300
Epoch 121/300
Epoch 122/300
Epoch 123/300
Epoch 124/300
Epoch 125/300
Epoch 126/300
Epoch 127/300
Epoch 128/300
Epoch 129/300
Epoch 130/300
Epoch 131/300
Epoch 132/300
Epoch 133/300
Epoch 134/300
Epoch 135/300
Epoch 136/300
Epoch 137/300
Epoch 138/300
Epoch 139/300
Epoch 140/300
Epoch 141/300
Epoch 142/300
Epoch 143/300
Epoch 144/300
Epoch 145/300
Epoch 146/300
Epoch 147/300
Epoch 148/300
Epoch 149/300
Epoch 150/300
Epoch 151/300
Epoch 152/300
Epoch 153/300
Epoch 154/300
Epoch 155/300
Epoch 156/300
Epoch 157/300
Epoch 158/300
Epoch 159/300
Epoch 160/300
Epoch 161/300
Epoch 162/300
Epoch 163/300
Epoch 164/300
Epoch 165/300
Epoch 166/300
Epoch 167/300
Epoch 168/300
Epoch 169/300
Epoch 170/300
Epoch 171/300
Epoch 172/300
Epoch 173/300
Epoch 174/300
Epoch 175/300
Epoch 176/300
Epoch 177/300
Epoch 178/300
Epoch 179/300
Epoch 180/300
Epoch 181/300
Epoch 182/300
Epoch 183/300
Epoch 184/300
Epoch 185/300
Epoch 

Epoch 228/300
Epoch 229/300
Epoch 230/300
Epoch 231/300
Epoch 232/300
Epoch 233/300
Epoch 234/300
Epoch 235/300
Epoch 236/300
Epoch 237/300
Epoch 238/300
Epoch 239/300
Epoch 240/300
Epoch 241/300
Epoch 242/300
Epoch 243/300
Epoch 244/300
Epoch 245/300
Epoch 246/300
Epoch 247/300
Epoch 248/300
Epoch 249/300
Epoch 250/300
Epoch 251/300
Epoch 252/300
Epoch 253/300
Epoch 254/300
Epoch 255/300
Epoch 256/300
Epoch 257/300
Epoch 258/300
Epoch 259/300
Epoch 260/300
Epoch 261/300
Epoch 262/300
Epoch 263/300
Epoch 264/300
Epoch 265/300
Epoch 266/300
Epoch 267/300
Epoch 268/300
Epoch 269/300
Epoch 270/300
Epoch 271/300
Epoch 272/300
Epoch 273/300
Epoch 274/300
Epoch 275/300
Epoch 276/300
Epoch 277/300
Epoch 278/300
Epoch 279/300
Epoch 280/300
Epoch 281/300
Epoch 282/300
Epoch 283/300
Epoch 284/300
Epoch 285/300
Epoch 286/300
Epoch 287/300
Epoch 288/300
Epoch 289/300
Epoch 290/300
Epoch 291/300
Epoch 292/300
Epoch 293/300
Epoch 294/300
Epoch 295/300
Epoch 296/300
Epoch 297/300
Epoch 298/300
Epoch 

In [32]:
loss, accuracy = model.evaluate(train_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)

Loss:  0.38051557540893555
Accuracy:  0.8928090333938599


In [33]:
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)

Loss:  0.8986846208572388
Accuracy:  0.7108476758003235


# Export the model

In [34]:
export_model = tf.keras.Sequential([
  vectorize_layer,
  model,
  layers.Activation('softmax')
])

export_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='nadam',
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)


# Inference on new data

In [40]:
examples = [
  "I need constipation relief. I haven't had a bowel movement in days.",
  "I need to quit smoking. I've been trying to quit for 5 years but haven't been able to stop.",
  "I am clinically obese and my doctor recommends that I lose some weight."
]

predictions=export_model.predict(examples)




In [41]:
for j in range(len(examples)):
    print(examples[j])
    prediction=predictions[j]
    for i, label in enumerate(raw_train_ds.class_names):
        print(label+':'+str(prediction[i]))

I need constipation relief. I haven't had a bowel movement in days.
Alprazolam:0.027992737
Amitriptyline:0.028081547
Bisacodyl:0.039304838
Bupropion:0.026375666
Bupropion :0.026448157
Citalopram:0.026295884
Clonazepam:0.026724856
Depo-Provera:0.026290558
Desvenlafaxine:0.026602026
Doxycycline:0.02651452
Drospirenone :0.026258783
Duloxetine:0.026896669
Escitalopram:0.026440028
Ethinyl estradiol :0.026290426
Etonogestrel:0.026256127
Fluoxetine:0.026316678
Gabapentin:0.026647054
Levonorgestrel:0.026720267
Liraglutide:0.0263603
Lo Loestrin Fe:0.026253564
Lorcaserin:0.026412234
Medroxyprogesterone:0.026308527
Metronidazole:0.027736874
Miconazole:0.026398908
Mirena:0.02646992
Paroxetine:0.0263931
Phentermine:0.026704764
Quetiapine:0.026496897
Sertraline:0.026322342
Suprep Bowel Prep Kit:0.026763124
Tioconazole:0.027407339
Tramadol:0.028944781
Trazodone:0.027157381
Varenicline:0.026279496
Venlafaxine:0.026651198
I need to quit smoking. I've been trying to quit for 5 years but haven't been abl

# Constipation relief 

Model predicts Bisacodyl:0.039304838

Bisacodyl is a laxative. Good prediction!

# Smoking cessation 

Model predicts Bupropion:0.033370513 and Varenicline:0.04371334

Bupropion and Varenicline (Brand name Chantix) are used for smoking cessation.

# Weight loss

Model predicts Phentermine:0.02896901 and Lorcaserin:0.030372752

Phentermine and Locaserin are prescribed for weight loss, so these are good choices. However, I would have also expected the model to predict Bupropion / Naltrexone (0.027201299) and Liraglutide (0.026953785 with a higher degree of confidence as these are also used to treat obesity. 