In [1]:
import torch
import sys
sys.path.append('/home/careinfolab/Dr_Luo/Rohan/ICD_Codes')  
from Model_rnn import *
from tokenizers import Tokenizer
import pandas as pd
from utils import *

In [2]:
# Compute device used for model loading and inference; kept as a plain string
# ('cuda' or 'cpu') so it can be passed to both map_location and .to().
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
def load_seq_model(pth_file, map_location=None):
    """Load a fully pickled model checkpoint and switch it to inference mode.

    Args:
        pth_file: path to a checkpoint saved as a whole module object (the
            result is used directly as a model: ``.eval()``, ``.rnn``, called
            with inputs), not a bare ``state_dict``.
        map_location: device to map tensors onto. Defaults to the notebook-level
            ``device`` when omitted.

    Returns:
        The loaded model, already in ``eval()`` mode.
    """
    target = device if map_location is None else map_location
    # Whole-module checkpoints cannot be restored with weights_only=True (the
    # default from PyTorch 2.6 onward, and the cause of the FutureWarning on
    # earlier 2.x releases). Opting out explicitly keeps the original behaviour.
    # SECURITY: weights_only=False unpickles arbitrary objects -- only load
    # checkpoint files from trusted sources.
    model = torch.load(pth_file, map_location=target, weights_only=False)
    model.eval()
    return model

In [4]:
def decode_tokens(tokens, tokenizer):
    """Turn a sequence of token ids back into text.

    Special tokens are dropped by asking the tokenizer to skip them.

    Args:
        tokens: token ids to decode, passed to the tokenizer unchanged.
        tokenizer: object exposing ``decode(ids, skip_special_tokens=...)``.

    Returns:
        The decoded string.
    """
    text = tokenizer.decode(tokens, skip_special_tokens=True)
    return text

In [5]:
def predict(model, input_tensor, target_tensor, pred="full"):
    """Run one forward pass and return argmax token ids for a window of the output.

    Args:
        model: sequence model exposing ``model.rnn.init_zero_hidden(n)`` and
            ``model(input, hidden)``. The output is indexed here as a rank-3
            tensor whose dim 1 is sequence position and dim 2 is scored for
            argmax (presumably vocabulary logits -- TODO confirm in Model_rnn).
        input_tensor: source token ids. The hidden state is built with
            ``init_zero_hidden(1)``, presumably a batch size of 1.
        target_tensor: target token ids; only ``size(0)`` is used, to size the
            prediction window.
        pred: output window to decode, case-insensitive:
            ``"first"``  -- the first ``len(target)`` positions;
            ``"middle"`` -- 3 positions starting one step before ``len(target)//2``;
            ``"last"``   -- the last ``len(target)`` positions;
            ``"full"``   -- every position.

    Returns:
        list: predicted token ids for the selected window (after squeezing dim 0).

    Raises:
        ValueError: if ``pred`` is not a supported mode. Checked before the
            forward pass so an invalid mode does not cost a model call.
    """
    mode = pred.lower()
    if mode not in ("first", "middle", "last", "full"):
        raise ValueError("Invalid pred type. Choose from: first, middle, last, full")

    model.eval()
    hidden = model.rnn.init_zero_hidden(1)
    input_tensor = input_tensor.to(device)

    # The hidden state may be a sequence of tensors (e.g. an LSTM-style (h, c)
    # pair); move every component to the compute device.
    if isinstance(hidden, (list, tuple)):
        hidden = tuple(h.to(device) for h in hidden)
    else:
        hidden = hidden.to(device)

    with torch.no_grad():
        outputs = model(input_tensor, hidden)

    min_len = target_tensor.size(0)
    seq_len = outputs.size(1)

    if mode == "first":
        pred_logits = outputs[:, :min_len, :]
    elif mode == "middle":
        # NOTE(review): the window width is fixed at 3 and is anchored on the
        # *target* length, not the output length -- confirm this matches how the
        # bidirectional models were trained. Clamp the start so targets shorter
        # than 2 tokens don't yield a negative index that wraps to the end.
        mid = min_len // 2
        start = max(mid - 1, 0)
        pred_logits = outputs[:, start:mid + 2, :]
    elif mode == "last":
        # Use an explicit non-negative start: `-min_len:` with min_len == 0
        # would become `0:` and return the whole sequence instead of nothing.
        pred_logits = outputs[:, max(seq_len - min_len, 0):, :]
    else:  # "full"
        pred_logits = outputs

    preds = pred_logits.argmax(dim=2)
    return preds.squeeze(0).cpu().tolist()


In [6]:
def clean_pred_tokens(tokens, pad_id=0):
    """Remove padding ids from a list of predicted token ids.

    Args:
        tokens: iterable of integer token ids (e.g. the list returned by
            ``predict``).
        pad_id: id to strip. Defaults to 0, the padding value visible in the
            source tensors printed in this notebook.

    Returns:
        list: the remaining ids, in their original order.
    """
    return [t for t in tokens if t != pad_id]

In [7]:
# Input artefacts: the pre-tokenised ICD-10 code/description table and the BPE
# tokenizer used to produce it.
TOKENS_CSV = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Dataset/icd10-codes-and-descriptions/Tokens.csv'
TOKENIZER_JSON = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Notebook/bpe_tokenizer.json'

df = pd.read_csv(TOKENS_CSV)
tokenizer = Tokenizer.from_file(TOKENIZER_JSON)

In [8]:
# Peek at the first five rows (pandas' default for head()).
df.head()

Unnamed: 0.1,Unnamed: 0,code_padded,desc_padded
0,0,"[tensor(2), tensor(13998), tensor(4), tensor(5...","[tensor(2), tensor(15427), tensor(302), tensor..."
1,1,"[tensor(2), tensor(13998), tensor(4), tensor(6...","[tensor(2), tensor(15427), tensor(302), tensor..."
2,2,"[tensor(2), tensor(13998), tensor(4), tensor(1...","[tensor(2), tensor(15427), tensor(97), tensor(..."
3,3,"[tensor(2), tensor(8278), tensor(4), tensor(60...","[tensor(2), tensor(7741), tensor(3267), tensor..."
4,4,"[tensor(2), tensor(8278), tensor(4), tensor(35...","[tensor(2), tensor(7741), tensor(4503), tensor..."


In [9]:
def _to_id_tensor(cell):
    """Parse one stringified "[tensor(2), tensor(13998), ...]" cell into a tensor.

    Cells that are already tensors are returned unchanged, so re-running this
    notebook cell does not try to re-parse its own output.
    """
    if torch.is_tensor(cell):
        return cell
    return torch.tensor(extract_tensor_ids(cell))


# The CSV stores each token sequence as text; convert both columns to tensors.
for col in ("desc_padded", "code_padded"):
    df[col] = df[col].apply(_to_id_tensor)

In [10]:
example_idx = 1250
# The source gets a leading batch dimension of 1; the target stays unbatched
# because predict() only uses its length.
source_tensor = df.at[example_idx, 'desc_padded'].unsqueeze(0)
target_tensor = df.at[example_idx, 'code_padded']

In [11]:
# Show the raw id tensors for the chosen example, one item per line.
print(
    "\nExample Sample:",
    f"Source tensor: {source_tensor}",
    f"Target tensor: {target_tensor}",
    sep="\n",
)


Example Sample:
Source tensor: tensor([[   2,  749, 3202,   67,  127,  903, 1087, 3920,  749, 3202,   67,  903,
         1087, 3920,    3,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]])
Target tensor: tensor([   2, 6670,    4,  276,    3])


In [12]:
# Decode the example back to text (special tokens removed) for a readable check.
source_text = decode_tokens(df.loc[example_idx, 'desc_padded'].tolist(), tokenizer)
target_text = decode_tokens(df.loc[example_idx, 'code_padded'].tolist(), tokenizer)
print(f"Source   : {source_text}")
print(f"Target   : {target_text}")

Source   : malignant melanoma of left eyelid including canthus malignant melanoma of eyelid including canthus
Target   : C43 . 12


In [13]:
# Best checkpoint for each architecture, keyed by display name. Every run lives
# under the same results directory; only the run sub-directory differs.
RESULTS_DIR = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD'

_run_dirs = {
    'RNN': 'RNN_first',
    'LSTM': 'LSTM_last',
    'GRU': 'GRU_last',
    'DeepRNN': 'DeepRNN',
    'DeepLSTM': 'DeepLSTM',
    'DeepGRU': 'DeepGRU',
    'BiRNN': 'BiRNN',
    'BiLSTM': 'BiLSTM',
    'BiGRU': 'BiGRU',
}

models_info = {
    name: f'{RESULTS_DIR}/{run_dir}/checkpoints/best_model.pt'
    for name, run_dir in _run_dirs.items()
}

In [14]:
# Output window passed to predict() for each architecture. Names not listed
# here (currently only 'RNN') fall back to "first".
PRED_MODE_BY_MODEL = {
    'DeepRNN': 'last',
    'LSTM': 'last',
    'GRU': 'last',
    'DeepLSTM': 'last',
    'DeepGRU': 'last',
    'BiRNN': 'middle',
    'BiLSTM': 'middle',
    'BiGRU': 'middle',
}

print("\n Example:")
print(f"Source   : {decode_tokens(df.loc[example_idx, 'desc_padded'].tolist(), tokenizer)}")
print(f"Target   : {decode_tokens(df.loc[example_idx, 'code_padded'].tolist(), tokenizer)}")

print("\n Model Predictions:")
for model_name, model_path in models_info.items():
    model = load_seq_model(model_path)
    pred_mode = PRED_MODE_BY_MODEL.get(model_name, "first")
    pred_tokens = clean_pred_tokens(
        predict(model, source_tensor, target_tensor, pred=pred_mode)
    )
    pred_text = decode_tokens(pred_tokens, tokenizer)
    print(f"{model_name} Predicted: {pred_text}")


 Example:
Source   : malignant melanoma of left eyelid including canthus malignant melanoma of eyelid including canthus
Target   : C43 . 12

 Model Predictions:


  model = torch.load(pth_file, map_location=device)


RNN Predicted: C50 . 72
LSTM Predicted: C43 . 12
GRU Predicted: C43 . 12
DeepRNN Predicted: P07 . 12
DeepLSTM Predicted: C43 . 12
DeepGRU Predicted: C43 . 12
BiRNN Predicted: C43 . 12
BiLSTM Predicted: C43 . 12
BiGRU Predicted: C43 . 12
