In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [4]:
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={'': 0}
)
model = PeftModel.from_pretrained(model, "mistralai-medical")
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
model.hf_device_map

{'': 0}

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [8]:
device = "cuda:0"

In [22]:
def stream(prompt, step):
    messages = [
        {"role": "user", "content": f"Given the input: {prompt}\n\n{step}\n\n"}
    ]
    
    
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
    
    model_inputs = encodeds.to(device)
    
    
    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
    decoded = tokenizer.batch_decode(generated_ids)
    return decoded[0]
    # print(decoded[0])

In [29]:
def entire_process(prompt):
    step1 = "Extract medical terms from the input above and number them line by line:\n"
    step1_output = stream(prompt, step1)
    print(step1_output)
    print('\n'*3)
    # print(step1_output[len(prompt) + 9:].replace('[/INST]', '').replace('<s>', '').replace('</s>', '').strip())
    # print(step1_output.replace('[INST]', '').replace('[/INST]', '').replace('<s>', '').replace('</s>', '').strip())
    
    input_sentence = input("Enter the context you want to find:\n")
    step2 = f"Extract set of sentences from above context which have the same meaning as: {input_sentence}"
    step2_output = stream(prompt, step2)
    # print(step2_output)
    print(step2_output.replace("<s> [INST] Given the input:", '').replace(prompt, '').replace('</s>', '').replace('[/INST]', ''))
    # print(step2_output.replace("<s> [INST] Given the input:", '').replace(prompt, ''))
    print('\n'*3)
    # print(step2_output[len(prompt) + 9:].replace('[INST]', '').replace('[/INST]', '').replace('<s>', '').replace('</s>', '').strip())
    
    step3 = "Summarize the content above:\n"
    step3_output = stream(prompt, step3)
    print(step3_output.replace("<s> [INST] Given the input:", '').replace(prompt, '').replace('</s>', '').replace('[/INST]', ''))
    # print(step3_output.replace("<s> [INST] Given the input:", '').replace(prompt, ''))
    print('\n'*3)
    # print(step3_output[len(prompt) + 9:].replace('[INST]', '').replace('[/INST]', '').replace('<s>', '').replace('</s>', '').strip())    

In [30]:
with open('sample_input.txt', 'r') as f:
    prompt = f.read()

In [31]:
entire_process(prompt)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> [INST] Given the input: Most Responsible Diagnosis: COPD Exacerbation  Active Issues Managed in Hospital:  Pulmonary edema Microcytic anemia Gout Purpuric rash NYD  Course in Hospital:  Mr. Johnson arrived in the ER from nursing home with a three-day history of worsening shortness of breath, yellow-green sputum, and increased sputum production. He was subsequently diagnosed with a COPD exacerbation and was satting at 84% on 4L O2 by nasal prongs. He was stepped up to BiPAP for 24 hours and prednisone, ciprofloxacin, and around the clock puffers were initiated. By day 2 of admission he was stepped down to oxygen by nasal prongs and QID puffers.  In terms of respiratory complications, Mr. Johnson had a sudden hypoxic resp failure on day 3 of admission. CCOT was involved, but ICU was avoided. He was found to be in pulmonary edema that responded to diuresis. Last documented echo was completed 15 years ago and a repeat echo was ordered to be completed as an outpatient.    Unfortunately 

Enter the context you want to find:
 microcytic anemia


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


 

Extract set of sentences from above context which have the same meaning as: microcytic anemia

  "Mr. Johnson was found to have a microcytic anemia." 
"Lastly, upon admission Mr. Johnson was found to have a microcytic anemia." 
"Further testing revealed iron deficiency anemia and therapy with ferrous fumarate was initiated."




 

Summarize the content above:


  The patient is a male with a COPD exacerbation who was admitted to the hospital on April 18th, 2019 with a three-day history of worsening shortness of breath, yellow-green sputum, and increased sputum production. He arrived in the ER from a nursing home. On admission, he was diagnosed with a COPD exacerbation and was treated with BiPAP, prednisone, ciprofloxacin, and around the clock puffers. By day 2 of admission, he was stepped down to oxygen by nasal prongs and a puffers regimen. However, he experienced a sudden hypoxic respiratory failure on day 3 of admission, which required the involvement of CCOT but avoided ICU adm