In [None]:
from google.colab import drive
import torch
import pandas as pd
import json
from tqdm import tqdm
from collections import Counter
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

In [None]:
# Print the installed PyTorch version string so the run environment is
# recorded in the notebook output.
print(torch.__version__)

2.3.0+cu121


In [None]:
# Mount Google Drive at /content/drive. Other cells in this notebook read the
# annotation file from, and write the results file to, paths under
# /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load the annotation file and keep only its 'val' split as a DataFrame.
# The encoding is given explicitly so the JSON text is decoded as UTF-8
# regardless of the runtime's locale default.
with open('/content/drive/MyDrive/ML-Quiz-XRay-ReportGeneration/data/annotation_quiz_all.json', 'r', encoding='utf-8') as f:
  data = json.load(f)
# Later cells read the 'original_report' and 'id' fields of each row.
df = pd.DataFrame(data['val'])

In [None]:
# One-shot example shown to the model: `correct_format` is the JSON the model
# is expected to produce for the free-text report in `report_format`. Both are
# substituted into the `prompt` template when the system message is built.
# The JSON keys used here are "lung", "heart", "mediastinal", "bone", "others".
correct_format = """{
  "lung": "Lungs are mildly hypoinflated but grossly clear of focal airspace disease, pneumothorax, or pleural effusion. Pulmonary vasculature are within normal limits in size.",
  "heart": "Cardiac silhouette within normal limits in size.",
  "mediastinal": "Mediastinal contours within normal limits in size.",
  "bone": "Mild degenerative endplate changes in the thoracic spine. No acute bony findings.",
  "others": ""
  }"""
# The example report: the same sentences as above, as one unstructured paragraph.
report_format = "Lungs are mildly hypoinflated but grossly clear of focal airspace disease, pneumothorax, or pleural effusion. Pulmonary vasculature are within normal limits in size. Cardiac silhouette within normal limits in size. Mild degenerative endplate changes in the thoracic spine. No acute bony findings. Mediastinal contours within normal limits in size."

In [None]:
# System-prompt template. It has exactly two positional `{}` placeholders,
# filled via `prompt.format(<example report>, <example JSON output>)`.
#
# Fixes relative to the previous wording:
#   * the region names now match the keys of the worked example ("lung", not
#     "lungs"), so the instructions and the example no longer contradict each
#     other about the expected JSON keys;
#   * the example output is no longer wrapped in double quotes (a quoted JSON
#     object is not the valid JSON the prompt asks for);
#   * the stray trailing quote after the last sentence is removed;
#   * "than" -> "then" and minor grammar.
prompt = """
You are a medical practitioner, tasked with reorganizing an X-ray report into pre-defined anatomical regions. All output must be valid JSON. Don't add any explanation beyond the JSON.
The anatomical regions are as follows: lung, heart, mediastinal, bone, others. Use exactly these five names as the JSON keys. If a part of the report (one or many lines) is about the lungs then put it into "lung", and likewise for "heart", "bone", and "mediastinal". If one line is part of two regions, add it to both regions. However, if you cannot put a part of the report into any of these four regions then put it into "others".

An example report is as follows:
Report: "{}"
The output of that report will be:
Output: {}
Note that the output should contain ONLY valid JSON.
"""

In [None]:
# Checkpoint to load from the Hugging Face Hub.
model_id = "hiieu/Meta-Llama-3-8B-Instruct-function-calling-json-mode"

# bitsandbytes settings: 4-bit weight loading with bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Tokenizer and quantized causal LM come from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Conversation prefix shared by every request: a single system turn holding the
# prompt template filled with the example report and its expected JSON output.
messages = [
    dict(role="system", content=prompt.format(report_format, correct_format)),
]

# Token ids on which generation stops: the tokenizer's EOS id, followed by the
# id of the "<|eot_id|>" token.
terminators = [tokenizer.eos_token_id]
terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))

In [None]:
# Run the model over every report in `df` and collect one record per report.
#
# Each record is a dict holding:
#   * the fields parsed from the model's JSON answer (only when parsing succeeds),
#   * 'status': 'successfull' or 'failed' (spelling kept as-is: the value is
#     written to the results file and counted in a later cell),
#   * 'id': the report id taken from the row,
#   * 'raw_output': the undecodable model text (only on failure), so failed
#     reports can be inspected or re-parsed without rerunning generation.

# Generation budget per report = this factor x the report's own token count.
MAX_NEW_TOKENS_FACTOR = 10

# Sampling is enabled below; seed the RNG so reruns start from the same state.
torch.manual_seed(0)

result = []
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
  report = row['original_report']

  # Build a fresh conversation for each report instead of appending to and then
  # deleting from the shared `messages` list: if an iteration is interrupted or
  # raises, the shared list is left untouched and a rerun starts clean.
  conversation = messages + [{"role": "user", "content": report}]

  input_ids = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    return_tensors="pt"
  ).to(model.device)

  outputs = model.generate(
    input_ids,
    max_new_tokens=len(tokenizer(report)['input_ids']) * MAX_NEW_TOKENS_FACTOR,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1, # To reduce model creativity
    top_p=0.9, # To reduce randomness
    pad_token_id=tokenizer.eos_token_id
  )

  # Keep only the newly generated tokens (everything after the prompt).
  response = outputs[0][input_ids.shape[-1]:]
  raw_output = tokenizer.decode(response, skip_special_tokens=True)

  try:
    output = json.loads(raw_output)
    # Valid JSON that is not an object (e.g. a list or a bare string) cannot
    # carry the per-region keys, so treat it as a failure explicitly.
    if not isinstance(output, dict):
      raise ValueError('model output is valid JSON but not a JSON object')
    output['status'] = 'successfull'
  except (ValueError, RecursionError) as ex:
    # json.JSONDecodeError is a ValueError subclass, so malformed JSON lands here.
    print(ex, raw_output)
    output = {'status': 'failed', 'raw_output': raw_output}

  output['id'] = row['id']
  result.append(output)

100%|██████████| 296/296 [57:46<00:00, 11.71s/it]


In [None]:
# Write the collected per-report records to Drive as a single JSON array.
output_path = '/content/drive/MyDrive/ML-Quiz-XRay-ReportGeneration/Task1_Result.json'
with open(output_path, 'w') as out_file:
  json.dump(result, out_file)

In [None]:
# Tally how many records ended in each 'status' value and show the totals.
data_count = Counter(record['status'] for record in result)
print(data_count)

Counter({'successfull': 296})
