In [1]:
from google.colab import drive

# Mount Google Drive; the prediction CSV is later written under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import numpy
from IPython.display import clear_output

# fix triton in colab
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

!git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
!cd mixtral-offloading && pip install -q -r requirements.txt

!huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo \
--quiet --local-dir \
Mixtral-8x7B-Instruct-v0.1-offloading-demo

clear_output()

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m73.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.3/78.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m73.3 MB/s[0m eta [36m0:00:00

In [None]:
import sys

# The repo is only cloned and its requirements installed (the repo itself is not
# pip-installed), so put it on sys.path to make its `src` package importable.
# Guarded so re-running this cell does not stack duplicate sys.path entries.
REPO_DIR = "mixtral-offloading"
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import snapshot_download
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

from src.build_model import OffloadConfig, QuantConfig, build_model

In [None]:
# Authenticate with the Hugging Face Hub for the downloads below (e.g. the tokenizer for
# `model_name`). `!huggingface-cli login` prompts for the token inside a subprocess, which
# a notebook frontend may not be able to feed input to; notebook_login() renders an
# in-notebook token prompt instead.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Checkpoints: `model_name` is the original (non-quantized) repo, `quantized_model_name`
# the pre-quantized demo repo whose config is read here, and `state_path` the local
# directory that the setup cell downloads that repo into.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"

device = torch.device("cuda:0")
config = AutoConfig.from_pretrained(quantized_model_name)

# Number of experts per layer to offload (presumably kept off-GPU by the offloading
# cache — see OffloadConfig). Use 5 instead of 4 if the GPU has only 12 GB of VRAM.
offload_per_layer = 4

num_experts = config.num_local_experts
num_layers = config.num_hidden_layers
resident_per_layer = num_experts - offload_per_layer

offload_config = OffloadConfig(
    main_size=num_layers * resident_per_layer,
    offload_size=num_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)

# HQQ quantization: 4-bit attention weights whose scales are themselves quantized in
# groups of 256, and 2-bit feed-forward (ffn) weights.
attn_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True)
attn_config["scale_quant_params"]["group_size"] = 256

ffn_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True)

quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)

model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

In [None]:
import json
import re

import pandas as pd
from transformers import TextStreamer

# Tokenizer (and chat template) from the original, non-quantized Mixtral repo.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
import json
import os
import re

# ---- run configuration -------------------------------------------------------
VAL_CSV = 'MEDIQA-CORR-2024-MS-ValidationSet-1-Full.csv'
OUTPUT_CSV = '/content/drive/MyDrive/EA_ST/mixtral_predictions_200.csv'
START_INDEX = 200  # rows with index < START_INDEX are skipped (left blank in OUTPUT_CSV)
SAVE_EVERY = 5     # checkpoint OUTPUT_CSV whenever the row index is a multiple of this
SEED = 42          # generation samples (do_sample=True); seeding makes reruns repeatable
                   # up to GPU nondeterminism

# Few-shot prompt. `{text}` receives each row's clinical note via str.format; doubled
# braces are literal JSON braces. The leading spaces on each line are part of the text
# the model sees.
PROMPT_TEMPLATE = """You are an AI trained in medical knowledge. Below are examples of clinical texts (delimited by triple quotes) followed by analysis of whether there is a diagnostic error and, if so, the reason for error.
    ####
    Here are some examples:
    Clinical Text: ```A 17-year-old boy comes to the physician because of body aches and sore throat for 1 week. He has no history of serious illness and takes no medications. He lives with his parents; they recently adopted a cat from an animal shelter. He is sexually active with one female partner, and they use condoms consistently. His temperature is 38.7 C (101.7 F), pulse is 99/min, and blood pressure is 110/72 mm Hg. Examination shows bilateral posterior cervical lymphadenopathy. The pharynx is red and swollen. Laboratory studies show:
    Hemoglobin 15 g/dL
    Leukocyte count 11,500/mm3
    Segmented neutrophils 48%
    Band forms 2%
    Basophils 0.5%
    Eosinophils 1%
    Lymphocytes 45%
    Monocytes 3.5%
    When the patient's serum is added to a sample of horse erythrocytes, the cells aggregate together. The causal pathogen is cytomegalovirus.```
    Output= {{"Error": "yes",
    "Reason": "the sentence stating 'The causal pathogen is cytomegalovirus' is incorrect. The correct medical diagnosis should be 'the causal pathogen is Epstein-Barr virus'."
    }}
    Clinical Text: ```A previously healthy 18-year-old woman comes to the emergency department because of diarrhea and abdominal cramps since the previous evening. She has had around 3–4 episodes of watery stools. She feels nauseous and has vomited twice. She recollects eating out 2 days ago. She has been on a vegan diet for 6 months. She takes no medications and has not traveled anywhere recently. Her temperature is 36.8 (98.2 F), pulse is 73/min, and blood pressure is 110/70 mm Hg. Examination shows dry mucous membranes. Abdominal examination is unremarkable. Norovirus was determined to be the casual organism.```
    Output={{
    "Error": "No",
    "Reason": "the context given in text aligns with the diagnosis"
    }}
    Clinical Text: ```A 5-year-old boy is brought to the emergency department by his grandmother because of difficulty breathing. Over the past two hours, the grandmother has noticed his voice getting progressively hoarser and occasionally muffled, with persistent drooling. He has not had a cough. The child recently immigrated from Africa, and the grandmother is unsure if his immunizations are up-to-date. He appears uncomfortable and is sitting up and leaning forward with his chin hyperextended. His temperature is 39.5 C (103.1 F), pulse is 110/min, and blood pressure is 90/70 mm Hg. Pulse oximetry on room air shows an oxygen saturation of 95%. Pulmonary examination shows inspiratory stridor and scattered rhonchi throughout both lung fields, along with poor air movement. Pharyngoscopy is ordered. ```
    Output= {{
      "Error":"yes",
      "Reason": "The clinical text says 'Pharyngoscopy is ordered' but it should have been 'Nasotracheal intubation is performed' as per the context given in clinical text"
    }}
    Clinical Text:```A 28-year-old man is brought to the emergency department after being struck by a car an hour ago as he was crossing the street. He did not lose consciousness. He is complaining of pain in his right arm, forehead, and pelvis. He also has the urge to urinate, but has been unable to do so since the accident. He takes no medications. His temperature is 37.1 C (98.9 F), pulse is 72/min, respirations are 18/min, and blood pressure is 118/82 mm Hg. There are abrasions over his scalp and face and a 1x3 cm area of ecchymosis above his right eye. Abdominal examination shows suprapubic tenderness. There is a scant amount of blood at the urethral meatus. There is no cervical spinal tenderness. Musculoskeletal examination shows tenderness and ecchymosis over his right distal forearm. An x-ray of the pelvis shows a fracture of the pelvic ramus. Retrograde urethrogram was then performed. A CT scan of the head and neck show no abnormalities.```
    Output = {{
      "Error": "No",
      "Reason": "the context given in text aligns with the diagnosis."
    }}
    ###
    Now, you are given below a new clinical text delimited by triple quotes. Carefully evaluate and analyse the information presented in clinical text such as symptoms, clinical examination findings, patient history and other details. Determine if any of the given sentences contain a diagnostic error or not. Use your knowledge and the context provided to make your assessment. Provide the output only in JSON Format with the following keys. Do not provide explanations or notes.
    Error, Reason
    ```
    Clinical Text: {text}
    ```
    """


def parse_model_output(output_text):
    """Extract the "Error" and "Reason" fields from the model's answer.

    The prompt asks for a JSON object with keys Error and Reason, but the model's
    formatting varies (e.g. `"Error":"yes"` without a space, or prose around the
    object). First try json.loads on the first {...} span; if that does not yield a
    dict, fall back to a lenient per-key regex.

    Args:
        output_text: decoded text generated by the model (prompt excluded).

    Returns:
        (error, reason) as stripped strings; a field that cannot be found is ''.
    """
    keys = ("Error", "Reason")
    obj_match = re.search(r"\{.*?\}", output_text, flags=re.DOTALL)
    if obj_match:
        try:
            parsed = json.loads(obj_match.group(0))
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return tuple(
                "" if parsed.get(key) is None else str(parsed.get(key)).strip()
                for key in keys
            )
    fields = []
    for key in keys:
        field_match = re.search(r'"%s"\s*:\s*"([^"]*)"' % key, output_text)
        fields.append(field_match.group(1).strip() if field_match else "")
    return tuple(fields)


torch.manual_seed(SEED)

df_val = pd.read_csv(VAL_CSV)
print(f"df_val shape: {df_val.shape}")
# Start the result columns as strings: the parsed values are text ("yes"/"No", a
# sentence), and writing text into an int column (as a `= 0` initialisation creates)
# triggers pandas' incompatible-dtype FutureWarning.
df_val['mixtral_error_flag'] = ''
df_val['mixtral_reason'] = ''

# to_csv does not create missing directories.
os.makedirs(os.path.dirname(OUTPUT_CSV), exist_ok=True)

try:
    for i, row in df_val.loc[df_val.index >= START_INDEX].iterrows():
        print(f"index={i}", flush=True)
        user_entry = dict(role="user", content=PROMPT_TEMPLATE.format(text=row['Text']))
        input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)

        # Every note is an independent single-turn request: no KV cache is carried
        # between rows, and hidden states are not requested (they only cost memory).
        result = model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )
        # Decode only the newly generated tokens instead of splitting the whole
        # transcript on "[/INST]" (which breaks if that marker occurs in the note).
        generated_ids = result.sequences[0, input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        error, reason = parse_model_output(output_text)
        df_val.at[i, 'mixtral_error_flag'] = error
        df_val.at[i, 'mixtral_reason'] = reason

        if i % SAVE_EVERY == 0:
            print("saving the file", flush=True)
            df_val.to_csv(OUTPUT_CSV, index=False)
finally:
    # Also persist rows processed after the last periodic checkpoint, including when
    # the loop is interrupted or generation raises (e.g. CUDA out-of-memory).
    df_val.to_csv(OUTPUT_CSV, index=False)