In [1]:
%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

In [2]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


<a name="Data"></a>
### Load Data



In [3]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [4]:
!ls 'drive/Othercomputers/My Laptop/ML-Quiz-XRay-ReportGeneration/data'

annotation_quiz_all.json  images


In [5]:
import os

In [6]:
project_dir = 'drive/Othercomputers/My Laptop/ML-Quiz-XRay-ReportGeneration/'
data_dir = os.path.join(project_dir, 'data')
image_dir = os.path.join(data_dir, 'images')

In [7]:
import json

In [8]:
## Load json
anno_json_fname = os.path.join(data_dir, 'annotation_quiz_all.json')
with open(anno_json_fname, "r") as file:
    data = json.load(file)

### Run Conversion Using LLama

In [9]:
import re

In [10]:
def strip_response(input_text):
    pattern = '### Output:\n(.*)' # Use re.search to find the text after the specified pattern match = re.search(pattern, text, re.DOTALL)
    match = re.search(pattern, input_text, re.DOTALL) # re.DOTALL allows the dot to match newline characters
    if match:
        response_text = match.group(1)
        # print("Captured text:")
        response_text = response_text.replace('<|end_of_text|>', '')
        response_text = response_text.replace('```json', '')
        response_text = response_text.replace('```', '')
        # print(response_text)
        return response_text
    else:
        print("Pattern not found.")
        return 'invalid response'


def extract_json_from_response(input_text):
    """
    Function to extract json output from LLM prompt response
    """
    # see if it's even json format
    try:
        json_data = json.loads(input_text)
    except json.JSONDecodeError as e:
        print(f"Invalid JSON string: {e}")
        print(input_text)
        return 'invalid json'

    try:
        # Ensure that all required keys are present and are of right datatype
        all_relevant_keys = ['lung', 'heart', 'bone', 'mediastinal', 'others']
        for key_name in json_data.keys():
            assert key_name in all_relevant_keys
            assert type(json_data[key_name]) is str
        for key_name in all_relevant_keys:
            assert key_name in json_data.keys()
    except:
        return 'invalid keys or values'
    return json_data

In [11]:
strip_response('### Output:\n asdf')

' asdf'

In [12]:
json_format_dict={
    'lung': 'summary of lung category related findings or leave as empty string if no lung related findings.',
    'heart': 'summary of heart category or leave as empty string if no heart related findings.',
    'bone': 'summary of bone category related findings or leave as empty string if no bone related findings.',
    'mediastinal': 'summary of mediastinal category related findings or leave as empty string if no mediastinal related findings.',
    'others': 'summary of any other findings that are NOT lung related, NOT heart related, NOT bone related, NOT mediastinal related. Leave as empty string if no findings.'
}
json_format_str = json.dumps(json_format_dict)
# base_prompt="""
# ### Instruction:
# Convert the input text provided into a json formatted summary. Do NOT return code, return the json response.

# Common words and prefixes associated with each category

# lung: Lungs, lung, pleural, pneumothorax, pneumo, pleuro
# heart: cardiac, cardio, heart, cardiomediastinal, cardiomediastinum, aorta, vein, vasculature
# bone: bone, bony, spine, humerus, tibula, fibula, rib, osseous, osteo
# mediastinal: mediastinal, cardiomediastinal, mediastinum


# If there are no relevant findings for a given tissue, do not report anything.
# Do not return chatbot style output.
# Do not report false information.
# Only report on information present in the input.
# Keep text entries in json as close to the original text as possible.
# Keep response strictly in noted Response Format, with lung, heart, bone, and mediastinal entries.

# ### Input:
# {}

# ### Response Format:
# {}
# """
base_prompt="""
### Instruction:
Convert the input text provided into a json formatted summary as per the Response Format.

Common words and prefixes associated with each category

lung: Lungs, lung, pleural, pneumothorax, pneumo, pleuro
heart: cardiac, cardio, heart, cardiomediastinal, cardiomediastinum, aorta, vein, vasculature
bone: bone, bony, spine, humerus, tibula, fibula, rib, osseous, osteo
mediastinal: mediastinal, cardiomediastinal, mediastinum


If there are no relevant findings for a given tissue, do not report anything.
Do not return chatbot style output.
Do not report false information.
Only report on information present in the input.
Keep text entries in json as close to the original text as possible.
Do not copy examples or write additional examples.

Do not return python code or other code. Only return json formatted response

#### Example 1:
input:
'The heart is folded. Lungs appear distorted. Pneumothorax.'
output:
{{
    \"lung\": \"Lungs appear distorted. Pneumothorax\",
    \"heart\": \"The heart is folded\",
    \"bone\": \"\",
    \"mediastinal\": \"\",
    \"others\": \"\"
}}
#### Example 2:
input:
'The pulmonary artery appears to have a small tear, which is alarming. Lungs appear normal in size. Spinal fracture apparent, likely from blunt trauma. Mediastinum has normal curvature'
output:
{{
    \"lung\": \"Lungs appear normal in size\",
    \"heart\": \"The pulmonary artery appears to have a small tear\"
    \"bone\": \"Spinal fracture likely from blunt trauma\",
    \"mediastinal\": \"Mediastinum has normal curvature\",
    'others': ''
}}
#### Example 3:
input:
'The aorta is enlarged. Pleural leakage. No abnormalities in right lung. There is a broken rib that may have punctured the left lung. Mediastinum is disfigured. The XXXX appears normal'
output:
{{
    \"lung\": \"Pleural leakage. No abnormalities in right lung. The left lung appears punctured, possibly by a rib\",
    \"heart\": \"The aorta is enlarged\",
    \"bone\": \"Broken rib that may have punctured left lung\",
    \"mediastinal\": \"Mediastinum is disfigured\",
    \"others\": \"The XXXX appears normal\"
}}
#### Example 4:
input:
'Mediastinal curvature normal. Cardiac muscle appears atrophied. Patient appears to have osteoperosis. Connective tissue is damaged'
output:
{{
    \"lung\": \"\",
    \"heart\": \"Cardiac muscle appears atrophied\",
    \"bone\": \"Patient appears to have osteoperosis.\",
    \"mediastinal\": \"Mediastinal curvature normal\",
    \"others\": \"Connective tissue is damaged\"
}}

#### Example 5:
input:
'The cardiomediastinal silhouette and pulmonary vasculature are within normal limits in size. The lungs are mildly hypoinflated but grossly clear of focal airspace disease, pneumothorax, or pleural effusion. There are mild degenerative endplate changes in the thoracic spine. There are no acute bony findings.'
output:
{{
\"lung\": \"Lungs are mildly hypoinflated but grossly clear of focal airspace disease, pneumothorax, or pleural effusion. Pulmonary vasculature are within normal limits in size.\",
\"heart\": "Cardiac silhouette within normal limits in size.\",
\"mediastinal\": "Mediastinal contours within normal limits in size.\",
\"bone\": "Mild degenerative endplate changes in the thoracic spine. No acute bony findings.\",
\"others\": \"\"
}}
#### Example 6:
input:
'Mild degenerative changes of the spine. Heart size within normal limits. No pleural effusion or pneumothorax. No focal air space opacity to suggest pneumonia. Mediastinum curvature within normal limits'
output:
{{
\"lung\": \"No pleural effusion or pneumothorax. No focal air space opacity to suggest pneumonia.\",
\"heart\": "Heart size within normal limits\",
\"mediastinal\": "Mediastinum curvature within normal limits\",
\"bone\": "Mild degenerative changes of the spine.\",
\"others\": \"\"
}}

### Input:
{}

### Response Format:
{}


"""


first_prompt = """ Below is an instruction that describes a task, paired with an input that provides further context. Write a response in the requested format.""" +  base_prompt + """
### Output:
{}

"""

# prior_incorrect_prompt = """ Below is an instruction that describes a task, paired with an input that provides further context. The prior response given is incorrect (see Incorrect Response). Write a response in the requested format.""" +  base_prompt + """
# ### Incorrect Response:
# {}

# ### Output:
# {}
# """

In [13]:
# set to inference mode
FastLanguageModel.for_inference(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaExtendedRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): L

In [14]:
## loop through the validation set, modifying entries
n_items_val = len(data['val'])
temperature=0.5
for i in range(n_items_val):
    print('####################################################')
    print("Running llama for validation set item {} of {}".format(str(i+1), str(n_items_val)))
    val_i = data['val'][i]
    # we pop out original response
    original_report_i = val_i.pop('original_report')
    inputs = tokenizer(
    [
        first_prompt.format(
            original_report_i, # input
            json_format_str, # output format
            "", # output - leave this blank for generation!
        )
    ], return_tensors = "pt").to("cuda")
    # extract json
    outputs = model.generate(**inputs, max_new_tokens = 250, temperature=temperature, use_cache = True)
    content_result = strip_response(tokenizer.batch_decode(outputs)[0])
    max_tries = 100
    for j in range(max_tries + 1):
        print('try {}'.format(j+1))
        print("original report: {}".format(original_report_i))
        try:
            try:
                json_out_i = extract_json_from_response(content_result)
                print("json extraction result: {}".format(json_out_i))
                assert type(json_out_i) is dict
            except:
                if type(json_out_i) is str:
                    if json_out_i == 'invalid json':
                        try:
                            print('attempting appending of curly brace to end.')
                            # try correcting for the most common error, which is forgetting an end bracket.
                            # this is hacky but should deal with majority of errors
                            content_result_mod = content_result + '}'
                            json_out_i = extract_json_from_response(content_result_mod)
                            assert type(json_out_i) is dict
                        except:
                            # raise exception
                            raise ValueError('content result not fixed by adding end bracket, submit modified query')
                    else:
                        raise ValueError('content result could not have valid json extracted, reason: {}'.format(json_out_i))
                else:
                    # just set json_out_i to invalid json.
                    json_out_i = 'invalid json'
                    raise ValueError('unknown output from extract_json_from_response. suggest debugging that function.')

            print('succeeded after {} tries'.format(str(j + 1)))
            val_i['report'] = json_out_i
            break
        except:
            if j == max_tries:
                print(content_result)
                raise ValueError('Reached Max Tries')
            del inputs
            del outputs
            inputs = tokenizer(
            [
                first_prompt.format(
                    original_report_i, # input
                    json_format_str, # output format
                    "", # output - leave this blank for generation!
                )
            ], return_tensors = "pt").to("cuda")
            outputs = model.generate(**inputs, max_new_tokens = 250, temperature=temperature, use_cache = True)
            content_result = strip_response(tokenizer.batch_decode(outputs)[0])


[1;30;43mStreaming output truncated to the last 5000 lines.[0m

{"lung": "The lungs are clear.", "heart": "The heart is normal in size.", "bone": "", "mediastinal": "The mediastinum is unremarkable", "others": ""}

### Input
json extraction result: invalid json
attempting appending of curly brace to end.
Invalid JSON string: Extra data: line 5 column 1 (char 153)


{"lung": "The lungs are clear.", "heart": "The heart is normal in size.", "bone": "", "mediastinal": "The mediastinum is unremarkable", "others": ""}

### Input:
The heart is normal in size. The mediastinum is unremarkable. The lungs are clear.

### Response Format:
{"lung": "summary of lung category related findings or leave as empty string if no lung related findings.", "heart": "summary of heart category or leave as empty string if no heart related findings.", "bone": "summary of bone category related findings or leave as empty string if no bone related findings.", "mediastinal": "summary of mediastinal category related

### Save Output

In [15]:
top_results_dir = os.path.join(project_dir, 'results')
if not os.path.exists(top_results_dir):
    os.mkdir(top_results_dir)
output_dir = os.path.join(top_results_dir, 'task1_convert_validation_annotations')
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

output_fname = os.path.join(output_dir, 'annotation_quiz_all_modified.json')
with open(output_fname, 'w') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)