# Health Expenditure Predictor Demonstration

In [1]:
import torch
import torch.nn as nn
from models.SubLayers import MultiHeadAttention


import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import time
import math
import random
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class HealthExpenditureModule(nn.Module):
    """
    Health Expenditures Module

    Runs multihead attention with the queries attending over the visit
    embeddings, flattens the attended queries, optionally appends the
    demographic embedding, and maps the result to d_out values with a
    two-layer feed-forward head.

    Args:
        d_out (int): prediction output dimension
        d_demo (int): demographic embedding dimension if it exists, or 0
        d_model (int): (multihead attention) model dimension
        n_head (int): the number of (multihead attention) heads
        dropout (float): dropout ratio of a multihead attention and a final linear layer

    Raises:
        ValueError: if d_model is not divisible by n_head

    """
    def __init__(self, d_out, d_demo=0, d_model=512, n_head=8, dropout=0.1):
        super().__init__()
        self.d_demo = d_demo

        # Explicit check instead of `assert`, which is stripped under `python -O`
        if d_model % n_head != 0:
            raise ValueError(
                'd_model ({}) must be divisible by n_head ({})'.format(d_model, n_head))
        d_head = d_model // n_head
        self.multihead_attn = MultiHeadAttention(n_head, d_model, d_k=d_head, d_v=d_head, dropout=dropout)

        # The head is sized for the flattened attention output of exactly two
        # query tokens (2 * d_model) plus the optional demographic embedding.
        d_inter = d_model * 2 + d_demo

        self.linear = nn.Sequential(
            nn.Linear(d_inter, d_inter // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_inter // 2, d_out)
        )

    def forward(self, q, kv, emb_demo=None):
        """
        Predict expenditures from queries, visit embeddings and optional demographics.

        Args:
            q: (b, query_len, d_model); query_len has to be 2 to match the width of the linear head
            kv: (b, visit_len, d_model), used as both keys and values
            emb_demo: (b, d_demo), or None when the module was built with d_demo == 0

        Returns:
            (output, attn): output is (b, d_out); attn is whatever the multihead attention returns as its second value

        Raises:
            ValueError: if emb_demo is missing although d_demo != 0, or if its last dimension differs from d_demo
        """
        # Validate up front: otherwise a missing/mis-sized demographic embedding
        # only surfaces as a shape mismatch inside the linear head.
        if emb_demo is None:
            if self.d_demo != 0:
                raise ValueError(
                    'emb_demo is required because the module was built with d_demo={}'.format(self.d_demo))
        elif emb_demo.size(-1) != self.d_demo:
            raise ValueError(
                'emb_demo has last dimension {}, expected d_demo={}'.format(emb_demo.size(-1), self.d_demo))

        output, attn = self.multihead_attn(q, kv, kv)
        # reshape (not view) so a non-contiguous attention output is still flattened
        output = output.reshape(q.size(0), -1)  # output: (b, d_model * 2)

        if emb_demo is not None:
            output = torch.cat([output, emb_demo], dim=1)

        output = self.linear(output)  # (b, d_out)

        return output, attn

In [3]:
class HealthExpenditurePredictor(nn.Module):
    """
    Health Expenditures Prediction Module

    Embeds the disease index and the severity value into a two-token query,
    projects the visit embeddings to keys/values, optionally projects the
    demographic embedding, and delegates the prediction to
    HealthExpenditureModule.

    Args:
        n_disease (int): the number of prediction diseases
        d_out (int): prediction output dimension
        d_visit (int): visit embedding dimension
        d_query (int): combined query dimension
        d_demo (int): demographic embedding dimension if it exists, or 0
        d_model (int): (multihead attention) model dimension
        n_head (int): the number of (multihead attention) heads
        dropout (float): dropout ratio of a multihead attention and a final linear layer of HealthExpenditureModule

    """
    def __init__(self, n_disease, d_out, d_visit, d_query, d_demo=0, d_model=512, n_head=8, dropout=0.1):
        super().__init__()

        # Remembered so forward() can check emb_demo against the configuration
        self.d_demo = d_demo

        self.module = HealthExpenditureModule(d_out, d_demo, d_model, n_head, dropout)

        # Embedding layers for query, visit, demo
        self.w_q1 = nn.Embedding(n_disease, d_query)  # looks up one d_query vector per disease index
        self.w_q2 = nn.Linear(1, d_query, bias=False)
        self.w_qs = nn.Linear(d_query, d_model, bias=False)
        self.w_visit = nn.Linear(d_visit, d_model)
        # w_demo only exists when a demographic embedding is configured
        if d_demo != 0:
            self.w_demo = nn.Linear(d_demo, d_demo, bias=False)

    def forward(self, q_1, q_2, emb_visit, emb_demo=None):
        """
        Predict expenditures for a batch of patients.

        Args:
            q_1 (torch.LongTensor): (b) disease indices
            q_2: (b, 1) severity values
            emb_visit: (b, visit_len, d_visit)
            emb_demo: (b, d_demo), or None when the model was built with d_demo == 0

        Returns:
            (output, attn) as returned by HealthExpenditureModule

        Raises:
            ValueError: if emb_demo is given although d_demo == 0, or is missing although d_demo != 0
        """
        # Without these checks the first case fails with AttributeError (w_demo
        # was never created) and the second with a shape error inside the head.
        if emb_demo is not None and self.d_demo == 0:
            raise ValueError('emb_demo was given but the model was built with d_demo=0')
        if emb_demo is None and self.d_demo != 0:
            raise ValueError(
                'emb_demo is required because the model was built with d_demo={}'.format(self.d_demo))

        q_1 = self.w_q1(q_1)  # q_1: (b, d_query)
        q_2 = self.w_q2(q_2)  # q_2: (b, d_query)

        # Stacking the two queries along a new token axis
        q = torch.stack((q_1, q_2), dim=1)  # q: (b, query_len, d_query)

        # Query
        q = self.w_qs(q)  # q: (b, query_len, d_model)

        # Key, Value
        kv = self.w_visit(emb_visit)  # kv: (b, visit_len, d_model)

        # Projecting demographic embeddings
        if emb_demo is not None:
            emb_demo = self.w_demo(emb_demo)  # emb_demo: (b, d_demo)

        output, attn = self.module(q, kv, emb_demo)

        return output, attn

# Demo 1

## Data

In [4]:
"""
Disease, disease severity (Q)
"""

# Six diseases A..F, each encoded as a single-element index tensor (0..5).
# NOTE(review): binding `F` here rebinds the `torch.nn.functional as F` alias
# imported at the top of the notebook.
A, B, C, D, E, F = (torch.LongTensor([code]) for code in range(6))

# Seven patients, one disease each
query1 = torch.cat([E, B, F, A, C, A, D], dim=0)
print('query1: ', query1)

# Severity level per patient, drawn from a normal distribution
query2 = torch.randn(7, 1)
print('query2: ', query2)

query1:  tensor([4, 1, 5, 0, 2, 0, 3])
query2:  tensor([[-0.0722],
        [-0.3962],
        [-0.2648],
        [-0.6019],
        [-1.4203],
        [ 2.1846],
        [-1.0118]])


In [5]:
"""
Visit representation (K, V)
"""

# Synthetic visit embeddings, batch first: (patients, visits, embedding width)
n_patients, T, visit_embed_size = 7, 14, 5
x_v = torch.randn(n_patients, T, visit_embed_size)

"""
Demographic embedding
"""
# Synthetic demographic embeddings: (patients, embedding width)
demo_embed_size = 9
x_d = torch.randn(n_patients, demo_embed_size)


## Model

In [14]:
# Model hyperparameters for the small demo
num_disease: int = 6  # number of distinct diseases
dim_output: int = 3  # three expenditures: total, patient, NHIS
dim_visit: int = 5  # width of a visit embedding
dim_query: int = 5  # width of the query embeddings
dim_demo: int = 9  # width of the demographic embedding

In [15]:
# Build the demo predictor; d_model, n_head and dropout keep their defaults.
hep = HealthExpenditurePredictor(
    n_disease=num_disease,
    d_out=dim_output,
    d_visit=dim_visit,
    d_query=dim_query,
    d_demo=dim_demo,
)
hep  # cell echo: shows the module tree in the notebook

HealthExpenditurePredictor(
  (module): HealthExpenditureModule(
    (multihead_attn): MultiHeadAttention(
      (w_qs): Linear(in_features=512, out_features=512, bias=False)
      (w_ks): Linear(in_features=512, out_features=512, bias=False)
      (w_vs): Linear(in_features=512, out_features=512, bias=False)
      (fc): Linear(in_features=512, out_features=512, bias=False)
      (attention): ScaledDotProductAttention(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
    )
    (linear): Sequential(
      (0): Linear(in_features=1033, out_features=516, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): Linear(in_features=516, out_features=3, bias=True)
    )
  )
  (w_q1): Embedding(6, 5)
  (w_q2): Linear(in_features=1, out_features=5, bias=False)
  (w_qs): Linear(in_features=5, out_features=512, bias=False)
  (w_visit): Linear(in_

In [18]:
# Single forward pass over the demo patients
hep_output, hep_attn = hep(q_1=query1, q_2=query2, emb_visit=x_v, emb_demo=x_d)

print(hep_output.shape)
hep_output  # cell echo: shows the predictions in the notebook

torch.Size([7, 3])


tensor([[-0.0630,  0.2212,  0.0496],
        [-0.1097,  0.0945,  0.1078],
        [-0.1006,  0.0935,  0.0080],
        [-0.1850, -0.1710,  0.0608],
        [ 0.0460,  0.0794,  0.0278],
        [-0.0991, -0.0287,  0.1441],
        [ 0.0513,  0.1180, -0.0397]], grad_fn=<AddmmBackward0>)

# Demo 2

## Data

In [28]:
"""
Disease, disease severity (Q)
"""

# Six diseases A..F, each encoded as a single-element index tensor (0..5).
# NOTE(review): binding `F` here rebinds the `torch.nn.functional as F` alias
# imported at the top of the notebook.
A, B, C, D, E, F = (torch.LongTensor([code]) for code in range(6))

# 10000 patients: a 1-D tensor with one disease index per patient,
# each index drawn uniformly from the six codes above
query1 = torch.LongTensor([random.randint(0, 5) for _ in range(10000)])
print('query1: ', query1)

# Severity level per patient, drawn from a normal distribution
query2 = torch.randn(10000, 1)
print('query2: ', query2)

query1:  tensor([2, 3, 4,  ..., 1, 2, 0])
query2:  tensor([[ 0.5559],
        [-2.7790],
        [ 0.4456],
        ...,
        [-0.9780],
        [ 1.3913],
        [-0.1570]])


In [9]:
"""
Visit representation (K, V)
"""

# Synthetic visit embeddings, batch first: (patients, visits, embedding width)
n_patients, T, visit_embed_size = 10000, 14, 5
x_v = torch.randn(n_patients, T, visit_embed_size)

"""
Demographic embedding
"""
# Synthetic demographic embeddings: (patients, embedding width)
demo_embed_size = 9
x_d = torch.randn(n_patients, demo_embed_size)


In [14]:
"""
Treatment cost (Y)
"""

# Synthetic expenditure targets, three values per patient: (n_patients, 3)
y = torch.randn(n_patients, 3)

In [72]:
"""
Dataset
"""
# create a dataloader for the following data: query1, query2, x_v, x_d, y
# total number of data points = 10000
# batch size = 64

class MyDataset(data.Dataset):
    """
    Map-style dataset that bundles five index-aligned fields per sample.

    Args:
        query1: first query field (indexable, one entry per sample)
        query2: second query field
        x_v: visit embeddings
        x_d: demographic embeddings
        y: prediction targets

    Raises:
        ValueError: if the five fields do not all have the same length
    """
    def __init__(self, query1, query2, x_v, x_d, y):
        # __len__ only looks at query1, so a shorter sibling field would
        # otherwise surface as an IndexError in the middle of an epoch and a
        # longer one would be silently truncated.
        lengths = [len(field) for field in (query1, query2, x_v, x_d, y)]
        if len(set(lengths)) != 1:
            raise ValueError(
                'all fields must have the same length, got {}'.format(lengths))

        self.query1 = query1
        self.query2 = query2
        self.x_v = x_v
        self.x_d = x_d
        self.y = y

    def __getitem__(self, index):
        """Return the tuple (query1, query2, x_v, x_d, y) for the sample at `index`."""
        return self.query1[index], self.query2[index], self.x_v[index], self.x_d[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.query1)
    
dataset = MyDataset(query1, query2, x_v, x_d, y)

"""
Data split
"""

# split the data into train, validation, test
# train: 80%
# validation: 10%
# test: 10% (takes the remainder so the three sizes always sum to len(dataset))

train_size = math.floor(0.8 * len(dataset))
valid_size = math.floor(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = data.random_split(dataset, [train_size, valid_size, test_size])

# Only the training loader shuffles. The evaluation loops average per-batch
# losses, so shuffling there would change which samples land in the smaller
# last batch and make the reported validation/test loss vary between passes.
train_dataloader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
test_dataloader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

## Model

In [73]:
"""
Get model
"""

# Model hyperparameters
num_disease: int = 6  # number of distinct diseases
dim_output: int = 3  # three expenditures: total, patient, NHIS
dim_visit: int = 5  # width of a visit embedding
dim_query: int = 5  # width of the query embeddings
dim_demo: int = 9  # width of the demographic embedding

# Predictor to be trained; d_model, n_head and dropout keep their defaults.
model = HealthExpenditurePredictor(
    n_disease=num_disease,
    d_out=dim_output,
    d_visit=dim_visit,
    d_query=dim_query,
    d_demo=dim_demo,
)

## Train

In [77]:
"""
Train model on train_dataloader
"""

# loss function
criterion = nn.MSELoss()

# optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# number of epochs
n_epochs = 100

# training: save the model with the lowest validation loss
# (np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0)
valid_loss_min = np.inf
for epoch in range(n_epochs):
    start_time = time.time()
    train_loss = 0.0
    valid_loss = 0.0

    # train the model
    # The batch_* names keep the loops from rebinding the full-dataset tensors
    # (query1, query2, x_v, x_d, y) that earlier cells created.
    model.train()
    for batch_q1, batch_q2, batch_visit, batch_demo, batch_y in train_dataloader:
        optimizer.zero_grad()
        output, _ = model(batch_q1, batch_q2, batch_visit, batch_demo)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # validate the model; no_grad skips building the autograd graph
    model.eval()
    with torch.no_grad():
        for batch_q1, batch_q2, batch_visit, batch_demo, batch_y in valid_dataloader:
            output, _ = model(batch_q1, batch_q2, batch_visit, batch_demo)
            loss = criterion(output, batch_y)
            valid_loss += loss.item()

    # calculate average losses (mean of the per-batch means)
    train_loss = train_loss/len(train_dataloader)
    valid_loss = valid_loss/len(valid_dataloader)

    # print training/validation statistics
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTime: {:.2f} sec'.format(
        epoch+1, train_loss, valid_loss, time.time() - start_time))

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), 'model.pt')
        valid_loss_min = valid_loss


Epoch: 1 	Training Loss: 1.010723 	Validation Loss: 0.974676 	Time: 6.11 sec
Validation loss decreased (inf --> 0.974676).  Saving model ...
Epoch: 2 	Training Loss: 1.010812 	Validation Loss: 0.976384 	Time: 5.98 sec
Epoch: 3 	Training Loss: 1.010791 	Validation Loss: 0.978181 	Time: 5.94 sec
Epoch: 4 	Training Loss: 1.010724 	Validation Loss: 0.979848 	Time: 5.93 sec
Epoch: 5 	Training Loss: 1.010738 	Validation Loss: 0.977986 	Time: 5.96 sec
Epoch: 6 	Training Loss: 1.010725 	Validation Loss: 0.979549 	Time: 5.94 sec
Epoch: 7 	Training Loss: 1.010707 	Validation Loss: 0.982019 	Time: 5.94 sec
Epoch: 8 	Training Loss: 1.010734 	Validation Loss: 0.982883 	Time: 5.92 sec
Epoch: 9 	Training Loss: 1.010743 	Validation Loss: 0.979857 	Time: 6.01 sec
Epoch: 10 	Training Loss: 1.010715 	Validation Loss: 0.982338 	Time: 5.95 sec
Epoch: 11 	Training Loss: 1.010709 	Validation Loss: 0.978592 	Time: 5.95 sec
Epoch: 12 	Training Loss: 1.010711 	Validation Loss: 0.985971 	Time: 5.95 sec
Epoch: 13

KeyboardInterrupt: 

## Test

In [78]:
"""
Test model on test_dataloader
"""

# load the model that got the best validation loss
model.load_state_dict(torch.load('model.pt'))

# get test loss
test_loss = 0.0
model.eval()
# no_grad skips building the autograd graph during evaluation; the batch_*
# names keep the loop from rebinding the full-dataset tensors
# (query1, query2, x_v, x_d, y) that earlier cells created.
with torch.no_grad():
    for batch_q1, batch_q2, batch_visit, batch_demo, batch_y in test_dataloader:
        output, _ = model(batch_q1, batch_q2, batch_visit, batch_demo)
        loss = criterion(output, batch_y)
        test_loss += loss.item()

# calculate and print avg test loss (mean of the per-batch means)
test_loss = test_loss/len(test_dataloader)

print('Test Loss: {:.6f}'.format(test_loss))


Test Loss: 1.011417


In [2]:
import tensorflow as tf

tf.__version__

'1.13.1'

In [None]:
import tensorflow as tf

# NOTE(review): this class reuses the name of the PyTorch HealthExpenditureModule
# defined earlier in the notebook, so running this cell rebinds that name.
class HealthExpenditureModule(object):
    """
    Health Expenditures Module (TensorFlow / Keras variant)

    Attends from the queries over the visit embeddings, flattens the result,
    optionally appends the demographic embedding and applies a two-layer
    dense head.

    As far as this cell shows, the class is a plain object with no __call__,
    so forward() has to be invoked explicitly.
    
    Args:
        d_out (int): prediction output dimension
        d_demo (int): demographic embedding dimension if it exists, or 0
        d_model (int): (multihead attention) model dimension
        n_head (int): the number of (multihead attention) heads
        dropout (float): dropout ratio of a multihead attention and a final linear layer
    
    """
    def __init__(self, d_out, d_demo=0, d_model=512, n_head=8, dropout=0.1):
        self.d_demo = d_demo
        
        assert d_model % n_head == 0
        # NOTE(review): the only MultiHeadAttention visible in this notebook is
        # imported from models.SubLayers next to the torch imports; presumably a
        # TensorFlow implementation is needed here -- confirm before running.
        self.multihead_attn = MultiHeadAttention(n_head, d_model, d_k = (d_model // n_head), d_v = (d_model // n_head), dropout=dropout)
        
        # Width of the flattened attention output for two query tokens plus the
        # demographic embedding; only used to size the hidden dense layer.
        d_inter = d_model * 2 + d_demo
        
        self.linear = tf.keras.Sequential([
            tf.keras.layers.Dense(d_inter // 2),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(d_out)
        ])
    
    def forward(self, q, kv, emb_demo=None):
        # q: (b, query_len, d_model)
        # kv: (b, visit_len, d_model)
        # emb_demo: (b, d_demo)
        
        # kv is passed as both keys and values
        output, attn = self.multihead_attn(q, kv, kv)
        # Flatten everything after the batch axis
        output = tf.reshape(output, [tf.shape(q)[0], -1]) # output: (b, d_model * 2)
        
        if emb_demo is not None:
            # Compares the static last dimension against the configured d_demo
            assert emb_demo.get_shape().as_list()[-1] == self.d_demo # model sanity
            
            output = tf.concat([output, emb_demo], axis=-1)

        output = self.linear(output) # (b, d_out)
        
        return output, attn
