In [None]:
'''Loading packages'''
import pickle
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import (AutoTokenizer,
                          AutoModel,
                          AutoConfig,
                          BertTokenizer,
                          BertModel
                         )

def loadBert(device, language_model, num_labels):
    """Load a pretrained language model with frozen weights and its tokenizer.

    Args:
        device: Target handed to ``model.to(device)``.
        language_model: Key selecting the checkpoint. Supported keys are
            ``"BioBert"`` (loads ``emilyalsentzer/Bio_ClinicalBERT``),
            ``"bioRoberta"``, ``"Bert"``, ``"bioLongformer"`` and
            ``"ClinicalBERT"``.
        num_labels: Not used by the loading logic. It is kept so existing
            callers keep working; it previously only fed a config object
            that was built and then discarded.

    Returns:
        A ``(model, tokenizer)`` tuple. Every model parameter has
        ``requires_grad`` set to ``False`` and the model has been moved to
        ``device``.

    Raises:
        AssertionError: If ``language_model`` is ``None``.
        ValueError: If ``language_model`` is not one of the supported keys.
    """
    assert language_model is not None, "language_model is None."

    # Checkpoints that are loaded through the generic Auto* classes.
    auto_checkpoints = {
        "BioBert": "emilyalsentzer/Bio_ClinicalBERT",
        "bioRoberta": "allenai/biomed_roberta_base",
        "bioLongformer": "yikuan8/Clinical-Longformer",
        "ClinicalBERT": "medicalai/ClinicalBERT",
    }

    # Non-string (possibly unhashable) values must end up in the ValueError
    # below rather than failing inside the dict lookup.
    name = language_model if isinstance(language_model, str) else None
    if name != "Bert" and name not in auto_checkpoints:
        raise ValueError("language_model should be BioBert, bioRoberta, bioLongformer, ClinicalBERT or Bert")

    # Announce the choice only once it is known to be valid.
    print(language_model, 'is used as the language model.')

    if name == "Bert":
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained("bert-base-uncased")
    else:
        checkpoint = auto_checkpoints[name]
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModel.from_pretrained(checkpoint)

    # Freeze the encoder: it is used purely as a feature extractor.
    for param in model.parameters():
        param.requires_grad = False

    model = model.to(device)
    return model, tokenizer

'''Frozen Text Representation'''
class BertForRepresentation(nn.Module):
    """Turn batches of tokenized notes into one embedding vector per note.

    The wrapped encoder is called once per element of the input sequence.
    From each call the first output element is taken, sliced at position 0
    along dimension 1 (presumably the [CLS]/first-token hidden state --
    confirm against the encoder in use), and passed through dropout.
    """

    def __init__(self, BioBert, language_modelname):
        """Store the encoder and build the dropout layer from its config.

        Args:
            BioBert: Encoder module exposing a ``config`` attribute.
            language_modelname: Name of the encoder. ``'ClinicalBERT'`` reads
                the dropout probability from ``config.dropout``; any other
                name reads ``config.hidden_dropout_prob``.
        """
        super().__init__()
        self.bert = BioBert
        self.language_model = language_modelname

        encoder_config = BioBert.config
        drop_prob = (
            encoder_config.dropout
            if self.language_model in ('ClinicalBERT',)
            else encoder_config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, input_ids_sequence, attention_mask_sequence):
        """Encode every (input_ids, attention_mask) pair and stack the results.

        Args:
            input_ids_sequence: Iterable of token-id tensors, one per encoder
                call.
            attention_mask_sequence: Iterable of attention masks, paired
                element-wise with ``input_ids_sequence``.

        Returns:
            ``torch.stack`` of the per-call embeddings, in input order.
        """

        def embed(ids, mask):
            # Index 0 of the encoder output, then position 0 along dim 1.
            hidden = self.bert(ids, attention_mask=mask)[0]
            return self.dropout(hidden[:, 0, :])

        pooled = [
            embed(ids, mask)
            for ids, mask in zip(input_ids_sequence, attention_mask_sequence)
        ]
        return torch.stack(pooled)

In [25]:
'''Checking if GPU is available'''
# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Name of the encoder to load; it also selects which config field the
# representation wrapper reads its dropout probability from.
language_model = "ClinicalBERT"

# `tokenizer` stays at module level because template_representations reads it
# as a global.
BioBert, tokenizer = loadBert(device, language_model, num_labels=2)

# Wrap the encoder, then move the wrapper to the chosen device.
text_extractor = BertForRepresentation(BioBert, language_model)
text_extractor = text_extractor.to(device)

def template_representations(templates, text_extractor, device):
    """Tokenize task templates and encode them with ``text_extractor``.

    Tokenization uses the module-level ``tokenizer``, so that name must be
    bound before this function is called.

    Args:
        templates: Iterable of template strings; must not be empty.
        text_extractor: Callable taking ``(text_token, attention_mask)`` and
            returning the template embeddings.
        device: Device the token and mask tensors are moved to.

    Returns:
        ``(text_token, attention_mask, representations)``. The first two are
        long tensors with a leading dimension of size 1 added in front of the
        padded template batch; the third is whatever ``text_extractor``
        returns for them.
    """
    text_token = []
    attention_mask = []
    for template in templates:
        template_codes = tokenizer(template,
                                   padding=True,
                                   max_length=512,
                                   add_special_tokens=True,
                                   return_attention_mask=True,
                                   truncation=True)
        text_token.append(torch.tensor(template_codes['input_ids'], dtype=torch.long))
        attention_mask.append(torch.tensor(template_codes['attention_mask'], dtype=torch.long))

    print(text_token[0], attention_mask[0].shape)
    # Pad all templates to a common width of 512 tokens.
    text_token, attention_mask = padding(text_token, attention_mask, max_length=512)

    # Add a leading dimension of size 1: the extractor iterates over dim 0.
    text_token = text_token.unsqueeze(dim=0)
    attention_mask = attention_mask.unsqueeze(dim=0)

    text_token = text_token.clone().detach().to(device, dtype=torch.long)
    attention_mask = attention_mask.clone().detach().to(device, dtype=torch.long)

    # A freshly constructed nn.Module is in training mode, in which dropout
    # randomly zeroes and rescales activations. Templates are stored as fixed
    # vectors, so encode them in eval mode and then restore every submodule's
    # own flag (a blanket .train() would also flip submodules that were
    # already in eval mode).
    saved_modes = []
    if isinstance(text_extractor, torch.nn.Module):
        saved_modes = [(module, module.training) for module in text_extractor.modules()]
        text_extractor.eval()
    try:
        representations = text_extractor(text_token, attention_mask)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return text_token, attention_mask, representations

def padding(text_token, atten_mask, max_length):
    """Batch variable-length token and mask tensors and pad them to a fixed width.

    Args:
        text_token: List of 1-D token-id tensors, one per note.
        atten_mask: List of 1-D attention-mask tensors, one per note.
        max_length: Target size of the last dimension of both outputs.

    Returns:
        ``(tokens, masks)``, each laid out as ``num_note x max_length`` with
        zeros in the padded positions.
    """
    # First bring every sequence to the length of the longest one in the list
    # (rows are notes because of batch_first=True).
    token_batch, mask_batch = (
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (text_token, atten_mask)
    )

    print(token_batch.shape)

    def widen(batch):
        # Right-pad the last dimension with zeros up to max_length.
        extra = max_length - batch.size(1)
        return torch.nn.functional.pad(batch, (0, extra), value=0)

    return widen(token_batch), widen(mask_batch)

ClinicalBERT is used as the language model.


In [26]:
'''Based on previous work, we extract the task template using hard-encoding templates'''
templates1 = [
    'The patient is admitted to the hospital after an emergency visit, belonging to class 1.',
    'The patient did not require hospitalization after the emergency visit, belonging to class 0.',
    'The patient [mask] after the emergency visit, belonging to class unknown.',
]

# Store the embeddings under their own name. Assigning them to
# `template_representations` would replace the function with a tensor and
# make every later call to it fail.
text_token, attention_mask, template_representations1 = template_representations(templates1, text_extractor, device)
with open('./config/task1_template.pkl', 'wb') as f:
    pickle.dump(template_representations1, f)

# import pickle
# text = pickle.load(open('./config/task1_template.pkl', "rb"))

tensor([  101, 10105, 38607, 10124, 40345, 10114, 10105, 18141, 10662, 10151,
        44461, 27541,   117, 54188, 10114, 13596,   122,   119,   102]) torch.Size([19])
torch.Size([3, 19])


In [3]:
'''Based on previous work, we extract the task template using hard-encoding templates'''
templates2 = [
    'The patient died during their hospitalization or was urgently transferred to the Intensive Care Unit (ICU) within 12 hours, belonging to class 1.',
    'The patient remained alive throughout their hospitalization and did not undergo an emergency transfer to the Intensive Care Unit (ICU) within 12 hours, belonging to class 0.',
    'The patient [mask] throughout their hospitalization and [mask] to the Intensive Care Unit (ICU) within 12 hours, belonging to class unknown.',
]

text_token2, attention_mask2, template_representations2 = template_representations(templates2, text_extractor, device)
with open('./config/task2_template.pkl', 'wb') as f:
    pickle.dump(template_representations2, f)

# Read the file back as a sanity check. The context manager closes the
# handle, which a bare open() inside pickle.load() would leave open.
with open('./config/task2_template.pkl', "rb") as f:
    text = pickle.load(f)