In [1]:

import torch
import sys
sys.path.append('/home/careinfolab/Dr_Luo/Rohan/ICD_Codes')  
from Model_rnn import *
from Transformer_Model import *
from tokenizers import Tokenizer
import pandas as pd
from utils import *

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def load_seq_model(pth_file):
    """Deserialize a saved model and put it in evaluation mode.

    Args:
        pth_file: Checkpoint location, passed straight to ``torch.load``.

    Returns:
        The object read from ``pth_file`` with its storages mapped onto the
        notebook-level ``device``, after ``.eval()`` has been called on it.
    """
    # NOTE(review): the checkpoint is read as a whole pickled object, so the
    # classes it refers to presumably have to be importable here — confirm.
    restored = torch.load(pth_file, map_location=device)
    restored.eval()
    return restored

In [4]:
def decode_tokens(tokens, tokenizer,skip_special_tokens=True):
    """Turn a sequence of token ids back into text.

    Args:
        tokens: Token ids handed to ``tokenizer.decode``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        skip_special_tokens: Truthy to drop special tokens from the result,
            falsy to keep them. The value is normalised to a real ``bool``
            before being forwarded.

    Returns:
        Whatever ``tokenizer.decode`` returns for the given ids.
    """
    drop_specials = bool(skip_special_tokens)
    return tokenizer.decode(tokens, skip_special_tokens=drop_specials)

In [5]:
def predict(model, input_tensor, target_tensor, pred="full"):
    """Run a recurrent model on one sequence and return the predicted token ids.

    The model is switched to eval mode and called as
    ``model(input_tensor, hidden)`` under ``torch.no_grad()``, where ``hidden``
    comes from ``model.rnn.init_zero_hidden(1)``. A window of the output is
    then selected along dim 1 and reduced with ``argmax`` over dim 2.

    Args:
        model: Object exposing ``eval()``, ``rnn.init_zero_hidden(batch)`` and
            ``__call__(input, hidden)``. Its output is indexed as a rank-3
            tensor.
        input_tensor: Input ids; moved to the notebook-level ``device``.
            The notebook passes a tensor with a leading batch dimension of 1.
        target_tensor: Reference sequence. Only its length along dim 0 is
            used, to size the window; it is not read when ``pred`` is "full".
        pred: Which output positions to keep (case-insensitive):
            ``"first"``  - the first ``len(target)`` positions;
            ``"middle"`` - up to three positions around ``len(target) // 2``;
            ``"last"``   - the last ``len(target)`` positions;
            ``"full"``   - every position.

    Returns:
        list: Predicted ids after ``squeeze(0)`` (a flat list when the batch
        dimension is 1).

    Raises:
        ValueError: If ``pred`` is not one of the four supported modes.
    """
    # Validate up front so a typo in `pred` fails before any model work.
    mode = pred.lower()
    if mode not in ('first', 'middle', 'last', 'full'):
        raise ValueError("Invalid pred type. Choose from: first, middle, last, full")

    model.eval()
    hidden = model.rnn.init_zero_hidden(1)
    input_tensor = input_tensor.to(device)

    # The hidden state may be a single tensor or a tuple/list of tensors.
    if isinstance(hidden, (list, tuple)):
        hidden = tuple(h.to(device) for h in hidden)
    else:
        hidden = hidden.to(device)

    with torch.no_grad():
        outputs = model(input_tensor, hidden)

    if mode == 'full':
        pred_logits = outputs
    else:
        min_len = target_tensor.size(0)
        if mode == 'first':
            pred_logits = outputs[:, :min_len, :]
        elif mode == 'middle':
            mid = min_len // 2
            # Clamp at 0: for very short targets `mid - 1` is negative and a
            # negative slice start would wrap around to the end of the sequence.
            start = max(mid - 1, 0)
            end = mid + 2
            pred_logits = outputs[:, start:end, :]
        else:  # mode == 'last'
            # Use an explicit non-negative start: `-0:` would select the whole
            # sequence instead of nothing when the target is empty.
            start = max(outputs.size(1) - min_len, 0)
            pred_logits = outputs[:, start:, :]

    preds = pred_logits.argmax(dim=2)
    return preds.squeeze(0).cpu().tolist()


In [6]:
def clean_pred_tokens(tokens, pad_id=0):
    """Drop padding ids from a list of predicted token ids.

    Args:
        tokens: Iterable of token ids.
        pad_id: Id to remove. Defaults to 0, the value this function has
            always filtered, so existing one-argument calls are unaffected.

    Returns:
        list: The ids from ``tokens`` that differ from ``pad_id``, in their
        original order.
    """
    return [token for token in tokens if token != pad_id]

In [7]:
# Inputs for the RNN-family demo: the table of stringified token tensors
# (columns `code_padded` / `desc_padded`) and the tokenizer definition file.
tokens_csv_path = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Dataset/icd10-codes-and-descriptions/Tokens.csv'
tokenizer_json_path = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Notebook/bpe_tokenizer.json'

tokenizer = Tokenizer.from_file(tokenizer_json_path)
df = pd.read_csv(tokens_csv_path)

In [8]:
# Preview the first five rows (the default row count for `head`).
df.head()

Unnamed: 0.1,Unnamed: 0,code_padded,desc_padded
0,0,"[tensor(2), tensor(13998), tensor(4), tensor(5...","[tensor(2), tensor(15427), tensor(302), tensor..."
1,1,"[tensor(2), tensor(13998), tensor(4), tensor(6...","[tensor(2), tensor(15427), tensor(302), tensor..."
2,2,"[tensor(2), tensor(13998), tensor(4), tensor(1...","[tensor(2), tensor(15427), tensor(97), tensor(..."
3,3,"[tensor(2), tensor(8278), tensor(4), tensor(60...","[tensor(2), tensor(7741), tensor(3267), tensor..."
4,4,"[tensor(2), tensor(8278), tensor(4), tensor(35...","[tensor(2), tensor(7741), tensor(4503), tensor..."


In [9]:
def _as_id_tensor(cell):
    """Return ``cell`` as a tensor of token ids, converting it only if needed."""
    # A value that is already a tensor was converted by an earlier run of this
    # cell; pass it through so re-executing the cell does not push tensors
    # back into `extract_tensor_ids`.
    if isinstance(cell, torch.Tensor):
        return cell
    # NOTE(review): `extract_tensor_ids` is not defined in this notebook;
    # presumably it comes from `utils` via the star import — confirm.
    return torch.tensor(extract_tensor_ids(cell))


# Replace the raw CSV cells of both sequence columns with id tensors.
for _column in ("desc_padded", "code_padded"):
    df[_column] = df[_column].apply(_as_id_tensor)

In [10]:
# Row used as the running example for the RNN-family models.
example_idx = 1250

# The code sequence stays unbatched; `predict` only reads its length.
target_tensor = df.loc[example_idx, 'code_padded']
# The description is the model input and gets a leading batch dimension of 1.
source_tensor = torch.unsqueeze(df.loc[example_idx, 'desc_padded'], dim=0)

In [11]:
# Show the raw id tensors of the selected example, one item per line.
print(
    "\nExample Sample:",
    f"Source tensor: {source_tensor}",
    f"Target tensor: {target_tensor}",
    sep="\n",
)


Example Sample:
Source tensor: tensor([[   2,  749, 3202,   67,  127,  903, 1087, 3920,  749, 3202,   67,  903,
         1087, 3920,    3,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]])
Target tensor: tensor([   2, 6670,    4,  276,    3])


In [12]:
# Decode the example's id sequences back to text (special tokens skipped).
source_ids = df.loc[example_idx, 'desc_padded'].tolist()
target_ids = df.loc[example_idx, 'code_padded'].tolist()

print(f"Source   : {decode_tokens(source_ids, tokenizer)}")
print(f"Target   : {decode_tokens(target_ids, tokenizer)}")

Source   : malignant melanoma of left eyelid including canthus malignant melanoma of eyelid including canthus
Target   : C43 . 12


In [13]:
# Display name -> checkpoint file for every recurrent model compared below.
# Insertion order is the order in which the predictions are printed.
models_info = dict(
    RNN='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/RNN_first/checkpoints/best_model.pt',
    LSTM='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/LSTM_last/checkpoints/best_model.pt',
    GRU='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/GRU_last/checkpoints/best_model.pt',
    DeepRNN='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/DeepRNN/checkpoints/best_model.pt',
    DeepLSTM='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/DeepLSTM/checkpoints/best_model.pt',
    DeepGRU='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/DeepGRU/checkpoints/best_model.pt',
    BiRNN='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/BiRNN/checkpoints/best_model.pt',
    BiLSTM='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/BiLSTM/checkpoints/best_model.pt',
    BiGRU='/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/BiGRU/checkpoints/best_model.pt',
)

In [14]:
# Restate the example so the predictions can be read against it.
example_source_ids = df.loc[example_idx, 'desc_padded'].tolist()
example_target_ids = df.loc[example_idx, 'code_padded'].tolist()

print("\n Example:")
print(f"Source   : {decode_tokens(example_source_ids, tokenizer)}")
print(f"Target   : {decode_tokens(example_target_ids, tokenizer)}")

# Output window read from each architecture; any name not listed here
# (currently only the plain RNN) falls back to "first".
pred_mode_by_model = dict.fromkeys(('DeepRNN', 'LSTM', 'GRU', 'DeepLSTM', 'DeepGRU'), "last")
pred_mode_by_model.update(dict.fromkeys(('BiRNN', 'BiLSTM', 'BiGRU'), "middle"))

print("\n Model Predictions:")
for model_name, model_path in models_info.items():
    model = load_seq_model(model_path)
    pred_mode = pred_mode_by_model.get(model_name, "first")
    raw_pred_ids = predict(model, source_tensor, target_tensor, pred=pred_mode)
    # Strip padding ids before turning the prediction back into text.
    pred_tokens = clean_pred_tokens(raw_pred_ids)
    pred_text = decode_tokens(pred_tokens, tokenizer)
    print(f"{model_name} Predicted: {pred_text}")


 Example:
Source   : malignant melanoma of left eyelid including canthus malignant melanoma of eyelid including canthus
Target   : C43 . 12

 Model Predictions:


  model = torch.load(pth_file, map_location=device)


RNN Predicted: C50 . 72
LSTM Predicted: C43 . 12
GRU Predicted: C43 . 12
DeepRNN Predicted: I66 . 21
DeepLSTM Predicted: C43 . 12
DeepGRU Predicted: C43 . 12
BiRNN Predicted: C43 . 12
BiLSTM Predicted: C43 . 12
BiGRU Predicted: C43 . 12


# Transformer

In [16]:
# Rebinds `df`: from here on it holds the text table whose 'Description' and
# 'ICD_Code' columns are read by the Transformer cells below, replacing the
# token table used in the RNN section.
codes_desc_csv_path = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Dataset/icd10-codes-and-descriptions/Codes&Desc_cleaned.csv'
df = pd.read_csv(codes_desc_csv_path)

In [39]:
# Wrap one row (positional index 1) in a TransformerDataset and take its
# only item as the sample to inspect.
single_df = df.iloc[[1]]
single_dataset = TransformerDataset(single_df, tokenizer, 'Description', 'ICD_Code', seq_len=128)
sample = single_dataset[0]

# Unpack the fields of the sample that are shown below.
(encoder_input, decoder_input, encoder_mask, decoder_mask,
 label, src_text, tgt_text) = (
    sample[field]
    for field in ('encoder_input', 'decoder_input', 'encoder_mask',
                  'decoder_mask', 'label', 'src_text', 'tgt_text')
)

print("Raw Tensors:")
print(f"Encoder Input IDs: {encoder_input}")
print(f"Decoder Input IDs: {decoder_input}")
print(f"Encoder Mask Shape: {encoder_mask.shape}")
print(f"Decoder Mask Shape: {decoder_mask.shape}")
print(f"Target Labels IDs: {label}")

print("\nText Decoding:")
# Special tokens are kept here on purpose so that padding and the start/end
# markers stay visible in the decoded strings.
encoder_text, decoder_text, target_text = (
    tokenizer.decode(ids.tolist(), skip_special_tokens=False)
    for ids in (encoder_input, decoder_input, label)
)

print(f"Source Text: {src_text}")
print(f"Target Text: {tgt_text}")

print(f"\nDecoded Encoder Text (input to model): {encoder_text}")
print(f"Decoded Decoder Text (input to model decoder): {decoder_text}")
print(f"Decoded Target Text (true labels): {target_text}")



Preprocessing and tokenizing dataset...


Tokenizing: 100%|██████████| 1/1 [00:00<00:00, 2880.70it/s]

Loaded 1 valid examples.
Raw Tensors:
Encoder Input IDs: tensor([    2,     2, 15427,   302,   136, 12988, 15428,   353, 15921, 19766,
            3,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
       




In [40]:
def predict_transformer(model, sample, tokenizer):
    """Greedy-decode one sample with an encoder/decoder model.

    The source is encoded once, then the decoder is run step by step starting
    from the ``[SOS]`` id; at each step the arg-max id of the last position is
    appended to the decoder input. Decoding stops at ``[EOS]`` or after
    ``max_len`` steps.

    Args:
        model: Object exposing ``eval()``, ``encode(src, src_mask)``,
            ``decode(enc_out, src_mask, tgt, tgt_mask)`` and ``project(x)``.
        sample: Mapping with ``'encoder_input'``, ``'encoder_mask'`` and
            ``'label'`` entries. The encoder input gets a leading batch
            dimension; the mask is used as stored.
        tokenizer: Object exposing ``token_to_id`` for ``[PAD]``, ``[SOS]``
            and ``[EOS]``.

    Returns:
        list[int]: Generated token ids, without the start id and without the
        end id.
    """
    model.eval()
    encoder_input = sample['encoder_input'].unsqueeze(0).to(device)
    encoder_mask = sample['encoder_mask'].to(device)
    pad_token_id = tokenizer.token_to_id("[PAD]")
    sos_token_id = tokenizer.token_to_id("[SOS]")
    eos_token_id = tokenizer.token_to_id("[EOS]")

    # The step budget is the label length when a label is present, otherwise
    # the encoder sequence length.
    target_tokens = sample['label']
    max_len = target_tokens.size(0) if target_tokens is not None else encoder_input.size(1)

    decoder_input = torch.tensor([[sos_token_id]], device=device)
    generated_tokens = []

    with torch.no_grad():
        # The encoder output does not depend on the decoding step, so it is
        # computed once instead of once per generated token.
        encoder_output = model.encode(encoder_input, encoder_mask)

        for _ in range(max_len):
            # Combine the padding mask with the causal mask for the tokens
            # generated so far.
            # NOTE(review): `causal_mask` is not defined in this notebook;
            # presumably it comes from `Transformer_Model` — confirm.
            tgt_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).to(device)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, tgt_mask)
            proj_output = model.project(decoder_output)

            # Greedy choice: highest-scoring id at the last decoder position.
            next_token = proj_output[:, -1, :].argmax(dim=-1)
            next_token_id = next_token.item()
            if next_token_id == eos_token_id:
                break
            generated_tokens.append(next_token_id)
            decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0)], dim=1)

    return generated_tokens

In [41]:
# Load the Transformer checkpoint onto `device`, switch it to eval mode and
# let the cell display the module (both `.to` and `.eval` return the module).
transformer_ckpt_path = '/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/Transformer/checkpoints/best_model.pt'
model = torch.load(transformer_ckpt_path, map_location=device)
model.to(device).eval()

  model = torch.load('/home/careinfolab/Dr_Luo/Rohan/ICD_Codes/Results/Desc_to_ICD/Transformer/checkpoints/best_model.pt', map_location=device)


Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-3): 4 x EncoderBlock(
        (self_attn): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=128, out_features=128, bias=False)
          (w_k): Linear(in_features=128, out_features=128, bias=False)
          (w_v): Linear(in_features=128, out_features=128, bias=False)
          (w_o): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.23262214988255547, inplace=False)
        )
        (ff): FeedForwardBlock(
          (linear_1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.23262214988255547, inplace=False)
          (linear_2): Linear(in_features=2048, out_features=128, bias=True)
        )
        (res_conns): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.23262214988255547, inplace=False)
            (norm): LayerNormalization()
          )
        )
      )
    )
    (norm): LayerNor

In [42]:
# Compare the greedy Transformer prediction with the reference for `sample`.
encoder_input, target_tokens = sample['encoder_input'], sample['label']

print("\nExample:")

# Special tokens are skipped here so only the readable text is shown.
source_text = tokenizer.decode(encoder_input.tolist(), skip_special_tokens=True)
print("Source:", source_text)

target_text = tokenizer.decode(target_tokens.tolist(), skip_special_tokens=True)
print("Target:", target_text)

pred_tokens = predict_transformer(model, sample, tokenizer)
pred_text = tokenizer.decode(pred_tokens, skip_special_tokens=True)
print(f"Prediction : {pred_text}")


Example:
Source: cholera due to vibrio cholerae 01 biovar eltor
Target: A00 . 1
Prediction : A00 . 1
