In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ensemble_utils import ensemble_generate

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Root of the BioNLP 2023 workspace, relative to this notebook's working directory.
BIONLP_ROOT = '../../../BioNLP2023'
# Clinical-T5-Large base checkpoint directory (used for the tokenizer and model architecture).
T5_CLIN_LARGE_PATH = f'{BIONLP_ROOT}/al826/supervised/physionet.org/files/clinical-t5/1.0.0/Clinical-T5-Large'
# Test split CSV of shared task 1A (problem-list summarization).
TEST_PATH = f'{BIONLP_ROOT}/released_14.04.2023/bionlp-workshop-2023-shared-task-1a-problem-list-summarization-1.1.0/BioNLP2023-1A-Test.csv'

In [5]:
def load_model(path, device, base_model_path=None):
    """Instantiate a seq2seq model and load a saved state dict into it.

    The architecture and config come from a Hugging Face checkpoint directory.
    The weights come from the separate file at ``path``, which is read with
    ``torch.load`` and applied with ``load_state_dict``.

    Args:
        path: Path to the state-dict file (e.g. ``.../models/model.pt``).
        device: Device the weights are mapped to while loading and the model
            is moved to afterwards.
        base_model_path: Checkpoint directory providing the architecture.
            Defaults to ``T5_CLIN_LARGE_PATH``, resolved at call time, so
            existing two-argument calls behave exactly as before.

    Returns:
        The model in eval mode, on ``device``.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict by default) if the saved
            keys/shapes do not match the base architecture.
    """
    if base_model_path is None:
        base_model_path = T5_CLIN_LARGE_PATH
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path, return_dict=True)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # On torch >= 1.13, weights_only=True would restrict loading to tensors.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.eval()  # inference only: disables dropout
    model.to(device)
    return model

In [6]:
# Build an ensemble of three checkpoints, each a state dict loaded onto the
# Clinical-T5-Large architecture by `load_model`.
# Order matters: the generation cell feeds member 0 the {A} input and members 1-2
# the {A}+{S} input. (The run names suggest "A" vs "AS" training formats;
# confirm against the training runs before reordering.)
base_dir = "../../../BioNLP2023/al826/supervised/experiments/submissions-yf"
checkpoint_runs = [
    "clint5-large-A-1e5-10e11/seed-3",
    "clint5-large-AS-1e5-10e11/seed-3",
    "clint5-large-AS-1e5-10e11/seed-5",
]
ensemble = [
    load_model(f"{base_dir}/models/{run}/models/model.pt", device)
    for run in checkpoint_runs
]

In [7]:
# Load the shared-task test split.
data = pd.read_csv(TEST_PATH)
# Fail fast with a clear message if the CSV lacks the note sections read for inference.
missing_cols = {'Assessment', 'Subjective Sections'} - set(data.columns)
assert not missing_cols, f"test CSV is missing columns: {sorted(missing_cols)}"
# Use the tokenizer that ships with the base checkpoint. `return_dict` is a model-output
# option with no meaning for tokenizers, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(T5_CLIN_LARGE_PATH)

In [8]:
# Build model inputs for one test note; change `idx` to try another row.
idx = 0
note = data.iloc[idx]

# {A} format: the Assessment section on its own.
input_text_A = note['Assessment']

# {A}+{S} format: Assessment, then the marker " \nSSSS\n", then the Subjective Sections.
# Because of the escaped backslashes, the marker contains a literal backslash followed
# by "n" (not newline characters).
# NOTE(review): presumably this must match the separator used to train the AS
# checkpoints; confirm before changing. A missing Subjective section would render as "nan".
input_text1, input_text2 = note['Assessment'], note['Subjective Sections']
input_text_AS = f"{input_text1} \\nSSSS\\n{input_text2}"

# Tokenize both variants to PyTorch tensors on the device the models were moved to.
inputs_A, inputs_AS = (
    tokenizer(text, return_tensors="pt").to(device)
    for text in (input_text_A, input_text_AS)
)

In [9]:
# Token-level ensemble generation.
# One input per ensemble member, in ensemble order: member 0 gets the {A} ids,
# members 1 and 2 get the {A}+{S} ids.
member_input_ids = [inputs_A.input_ids] + [inputs_AS.input_ids] * 2

# Beam-search decoding settings forwarded to `ensemble_generate`.
generation_kwargs = dict(
    num_beams=4,
    length_penalty=0.6,
    max_length=256,
    min_length=5,
    no_repeat_ngram_size=4,
)

summary_ids = ensemble_generate(ensemble, member_input_ids, **generation_kwargs)
summary_txt = tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)

In [10]:
# Show the decoded problem-list summary as plain text.
print(summary_txt, end="\n")

1. Coronary Artery Disease; 2. Epistaxis; 3. Acute Renal Failure; 4. Chronic Obstructive Pulmonary Disease
