In [1]:
import torch
import numpy as np
import torch
from PIL import Image
import random

# Load Checkpoint for Inference

In [13]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def inference(model, processor, image_paths, seed=42):

    image_tokens = "<image>"
    prompt = f"<|im_start|>user {image_tokens}\nProvide a description of the findings and impressions in the radiology images given the following images of the study.|im_end|><|im_start|>assistant"

    # chat template in interleaved format work same as in sampling videos. Just pass in as many images you want for a prompt
    conversation = [
        {

        "role": "user",
        "content": [
            {"type": "text", "text": "Provide a description of the findings and impressions in the radiology images given the following images of the study."},
            {"type": "image"},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)


    images = []   
    for image_path in image_paths:
        images.append(
            Image.open(image_path).convert("RGB")
        )

    inputs_images = processor(text=prompt, images=images, padding=True, return_tensors="pt").to(model.device)

    # Ensure correct dtypes for different inputs
    for key, value in inputs_images.items():
        if torch.is_tensor(value):
            if key in ['input_ids', 'attention_mask']:
                inputs_images[key] = value.long()
            elif key == 'pixel_values':
                inputs_images[key] = value.half()

    # Use deterministic generation settings
    output = model.generate(
        **inputs_images,
        max_new_tokens=512,
        do_sample=True,  # Use greedy decoding
        temperature=0.7,  # Fixed temperature
        # num_beams=1,  # No beam search
        top_k=1,  # Only consider the most likely token
        top_p=0.9,  # No nucleus sampling
    )
    
    print(processor.decode(output[0], skip_special_tokens=True))


In [14]:
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# Set a global seed for reproducibility
GLOBAL_SEED = 42
set_seed(GLOBAL_SEED)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4", 
)

original_model_id = "llava-hf/llava-interleave-qwen-7b-hf" 
# processor is not changed so we still load from the original model repo

image_path = [
      "/working/datasets/mimic-cxr/mimic-cxr-images-512/img/p10/p10046166/s50051329/427446c1-881f5cce-85191ce1-91a58ba9-0a57d3f5.jpg",
    ]

processor = AutoProcessor.from_pretrained(original_model_id)

model_id = "/working/rajan/multiview-llm/Models/Finetune/lmms-finetune/checkpoints/llava-interleave-qwen-7b_lora-True_qlora-True_single"
new_model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True, 
)
inference(new_model, processor, image_path, seed=GLOBAL_SEED)
del new_model


Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.53s/it]
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


user

Provide a description of the findings and impressions in the radiology images given the following images of the study.
assistant
There is a new right-sided chest tube with a small right apical pneumothorax. There is a small right pleural effusion. There is a right-sided chest wall subcutaneous emphysema. There is a right-sided rib fracture. There is a right-sided chest wall deformity. There is a left-sided rib fracture. There is a left-sided chest wall deformity. There is a left-sided chest tube. There is a small left apical pneumothorax. There is a small left pleural effusion. There is a left-sided rib fracture. There is a right-sided rib fracture. There is a right-sided chest wall deformity. There is a right-sided rib fracture. There is a right-sided chest wall deformity. There is a right-sided rib fracture. There is a right-sided rib fracture. There is a right-sided rib fracture. There is a right-sided rib fracture. There is a right-sided rib fracture. There is a right-sided r

In [15]:
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# Set a global seed for reproducibility
GLOBAL_SEED = 42
set_seed(GLOBAL_SEED)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4", 
)

original_model_id = "llava-hf/llava-interleave-qwen-7b-hf" 
# processor is not changed so we still load from the original model repo
image_paths = [
      "/working/datasets/mimic-cxr/mimic-cxr-images-512/img/p10/p10046166/s50051329/427446c1-881f5cce-85191ce1-91a58ba9-0a57d3f5.jpg",
]

processor = AutoProcessor.from_pretrained(original_model_id)


old_model = LlavaForConditionalGeneration.from_pretrained(
    original_model_id, 
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True, 
)
print("Original Model Output:")
inference(old_model, processor, image_path, seed=GLOBAL_SEED)
del old_model


model_id = "/working/rajan/multiview-llm/Models/Finetune/lmms-finetune/checkpoints/llava-interleave-qwen-7b_lora-True_qlora-True_single"
new_model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True, 
)
print("Fine-tuned Model Output:")
inference(new_model, processor, image_paths, seed=GLOBAL_SEED)
del new_model



Loading checkpoint shards: 100%|██████████| 4/4 [00:15<00:00,  3.84s/it]
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Original Model Output:
user

Provide a description of the findings and impressions in the radiology images given the following images of the study.
assistant
The images provided are X-ray films, which are used to visualize the internal structure of the body. In this case, the X-ray is focused on the chest area, allowing us to see the bones, organs, and other structures within the chest cavity.

The bones are prominent, with the spine, rib cage, and pelvis being clearly visible. The ribs appear as a series of curved bones that encase the heart and lungs. The spine is a complex structure with multiple vertebrae, which are the bony blocks that make up the back.

The lungs are also visible, with their characteristic wedge-shaped structure and the presence of air pockets. The heart is recognizable by its four chambers and the presence of valves that regulate blood flow.

There are no significant abnormalities or injuries visible in the X-ray images provided. The bones appear healthy and wel

Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.49s/it]
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Fine-tuned Model Output:
user

Provide a description of the findings and impressions in the radiology images given the following images of the study.
assistant
There is a new right-sided chest tube with tip projecting over the right lung apex. There is a small right apical pneumothorax. There is a small right pleural effusion. There is a small left pleural effusion. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left apical pneumothorax. There is a small right apical pneumothorax. There is a small left 

# Generate

In [6]:
import json
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import os

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_model_and_processor(model_path, original_model_id):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )
    
    processor = AutoProcessor.from_pretrained(original_model_id)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        low_cpu_mem_usage=True,
    )
    
    return model, processor

def generate_description(model, processor, image_paths):
    images = [Image.open(img_path).convert("RGB") for img_path in image_paths]

    # chat template in interleaved format work same as in sampling videos. Just pass in as many images you want for a prompt
    conversation = [
        {
        "role": "user",
        "content": [
            {"type": "text", "text": "Provide a description of the findings and impressions in the radiology images given the following images of the study."},
            {"type": "image"},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, padding=True, return_tensors="pt").to(model.device)
    
    for key, value in inputs.items():
        if torch.is_tensor(value):
            if key in ['input_ids', 'attention_mask']:
                inputs[key] = value.long()
            elif key == 'pixel_values':
                inputs[key] = value.half()
    
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_k=1,
        top_p=0.9,
    )

    # temperature=1.0,  # Fixed temperature
    # num_beams=1,  # No beam search
    # top_k=1,  # Only consider the most likely token
    # top_p=1.0,  # No nucleus sampling
    
    return processor.decode(output[0], skip_special_tokens=True)

def extract_study_id(image_path):
    # Split the path and find the component starting with 's'
    path_components = image_path.split('/')
    for component in path_components:
        if component.startswith('s') and component[1:].isdigit():
            return component
    return None  # Return None if no valid study_id is found

def evaluate_test_set(test_set_path, model_path, original_model_id, output_path):
    set_seed(42)  # Set a fixed seed for reproducibility
    model, processor = load_model_and_processor(model_path, original_model_id)
    
    with open(test_set_path, 'r') as f:
        test_set = json.load(f)
    
    results = []
    
    for item in test_set:
        study_id = extract_study_id(item['image'][0])
        if study_id is None:
            print(f"Warning: Could not extract study_id from {item['image'][0]}")
            continue
        
        ground_truth = item['conversations'][1]['value']
        # system_prompt = item['system_prompt']
        
        generated_output = generate_description(model, processor, item['image'])
        
        results.append({
            'study_id': study_id,
            'ground_truth': ground_truth,
            'generated_output': generated_output
        })
    
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"Evaluation complete. Results saved to {output_path}")


In [13]:
# Evaluate
test_set_path = '/working/rajan/multiview-llm/Models/Finetune/dataset/data/mimic_cxr_single_test_findings.json'
model_path = '/working/rajan/multiview-llm/Models/Finetune/lmms-finetune/checkpoints/llava-interleave-qwen-7b_lora-True_qlora-True_single'
original_model_id = 'llava-hf/llava-interleave-qwen-7b-hf'
output_path = 'outputs/results_single.json'

evaluate_test_set(test_set_path, model_path, original_model_id, output_path)

Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00,  2.72s/it]
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for

Evaluation complete. Results saved to outputs/results_single.json


In [14]:
def extract_assistant_response(generated_output):
    parts = generated_output.split('assistant')
    if len(parts) > 1:
        return parts[-1].strip()
    else:
        return generated_output.strip()

def post_process_results(input_path, output_path):
    # Read the existing results
    with open(input_path, 'r') as f:
        results = json.load(f)
    
    # Process each result
    for result in results:
        result['generated_output'] = extract_assistant_response(result['generated_output'])
    
    # Save the processed results
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"Post-processing complete. Updated results saved to {output_path}")

In [16]:
# Usage
input_path = 'outputs/results_single.json'
output_path = 'outputs/processed_results_singleview.json'

post_process_results(input_path, output_path)

Post-processing complete. Updated results saved to outputs/processed_results_singleview.json


In [17]:
import json
# Load results from a JSON file
with open('/working/rajan/multiview-llm/Models/Finetune/usage/outputs/processed_results_singleview.json', 'r') as f:
    results = json.load(f)

In [18]:
results[0]

{'study_id': 's50239281',
 'ground_truth': '<findings>Left PICC tip is seen terminating in the region of the distal left brachiocephalic vein. Tracheostomy tube is in unchanged standard position. The heart is moderately enlarged. Marked calcification of the aortic knob is again present. Mild pulmonary vascular congestion is similar. Bibasilar streaky airspace opacities are minimally improved. Previously noted left pleural effusion appears to have resolved. No pneumothorax is identified. Percutaneous gastrostomy tube is seen in the left upper quadrant.</findings><impression>1. Left PICC tip appears to terminate in the distal left brachiocephalic vein. 2. Mild pulmonary vascular congestion. 3. Interval improvement in aeration of the lung bases with residual streaky opacity likely reflective of atelectasis. Interval resolution of the left pleural effusion.</impression>',
 'generated_output': '<findings>There is a new right-sided chest tube with tip projecting over the right lung apex. The

# Evaluate

## Lexical Metrics

In [19]:
import nltk
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from rouge import Rouge
import numpy as np
from tqdm import tqdm

nltk.download('punkt', quiet=True)

def calculate_metrics(results):
    references = []
    hypotheses = []
    cider_gts = {}
    cider_res = {}
    meteor_gts = {}
    meteor_res = {}
    
    print("Preprocessing data...")
    for i, result in enumerate(tqdm(results)):
        reference = nltk.word_tokenize(result['ground_truth'])
        hypothesis = nltk.word_tokenize(result['generated_output'])
        references.append([reference])
        hypotheses.append(hypothesis)
        
        cider_gts[i] = [result['ground_truth']]
        cider_res[i] = [result['generated_output']]
        meteor_gts[i] = [result['ground_truth']]
        meteor_res[i] = [result['generated_output']]
    
    print("Calculating BLEU scores...")
    smoothie = SmoothingFunction().method1
    bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0), smoothing_function=smoothie)
    bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothie)
    bleu_3 = corpus_bleu(references, hypotheses, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothie)
    bleu_4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie)
    
    print("Calculating CIDEr score...")
    cider_scorer = Cider()
    cider_score, _ = cider_scorer.compute_score(cider_gts, cider_res)
    
    print("Calculating ROUGE scores...")
    rouge_scorer = Rouge()
    rouge_scores = rouge_scorer.get_scores([r['generated_output'] for r in results], 
                                           [r['ground_truth'] for r in results], avg=True)
    
    return {
        'BLEU-1': bleu_1,
        'BLEU-2': bleu_2,
        'BLEU-3': bleu_3,
        'BLEU-4': bleu_4,
        'CIDEr': cider_score,
        'ROUGE-1': rouge_scores['rouge-1']['f'],
        'ROUGE-2': rouge_scores['rouge-2']['f'],
        'ROUGE-L': rouge_scores['rouge-l']['f']
    }

In [20]:
# Calculate metrics - generated from `write_radiologic_report()` method
metrics = calculate_metrics(results)

# Print results
for metric, score in metrics.items():
    print(f"{metric}: {score:.4f}")

Preprocessing data...


100%|██████████| 627/627 [00:01<00:00, 458.35it/s]


Calculating BLEU scores...
Calculating CIDEr score...
Calculating ROUGE scores...
BLEU-1: 0.2325
BLEU-2: 0.1583
BLEU-3: 0.1178
BLEU-4: 0.0832
CIDEr: 0.0292
ROUGE-1: 0.2782
ROUGE-2: 0.0760
ROUGE-L: 0.2687


## CheXbert

In [21]:
results[12]

{'study_id': 's51129150',
 'ground_truth': '<findings>A left Port-A-Cath terminates in the right atrium, unchanged from prior. Lung volumes are extremely low resulting in bronchovascular crowding and limited evaluation of the lung bases. Diffuse interstitial opacities have increased, and despite the low lung volumes, findings are consistent with superimposed pulmonary edema on a background of pulmonary fibrosis. No large pleural effusion is evident. There is no pneumothorax. Cardiomediastinal and hilar contours are within normal limits. High density material within multiple mid thoracic vertebral bodies is likely related to prior kyphoplasty, unchanged from prior.</findings><impression>Superimposed pulmonary edema on a background of pulmonary fibrosis. Low lung volumes limit assessment for basilar consolidation.</impression>',
 'generated_output': '<findings>There is a new right lower lobe infiltrate. There is also a new left lower lobe infiltrate. There is pulmonary vascular redistrib

In [22]:
from chexbert import CheXbertMetrics
import torch
torch.cuda.set_device(0)

# Setup CheXbertMetrics
test_chexbert_metrics = CheXbertMetrics(
    bert_path='bert-base-uncased',
    checkpoint_path='/working/rajan/multiview-llm/Models/Multi/cxrmate/checkpoints/stanford/chexbert/chexbert.pth',
    ckpt_dir='ckpt',
    mbatch_size=1,
    exp_dir='metrics',
)

# Prepare the data
res = [result['generated_output'] for result in results]
gt = [result['ground_truth'] for result in results]
ids = list(range(len(res)))

# Update CheXbert metrics
test_chexbert_metrics.update(res, gt, ids)

# compute CheXbert metrics
chexbert_scores = test_chexbert_metrics.compute()

# Print CheXbert metrics
print("CheXbert Metrics:")
for key, value in chexbert_scores.items():
    if isinstance(value, torch.Tensor):
        print(f'{key}: {value.item():.3f}')
    else:
        print(f'{key}: {value:.3f}')



CheXbert Metrics:
ce_precision_macro: 0.301
ce_recall_macro: 0.232
ce_f1_macro: 0.207
ce_precision_micro: 0.414
ce_recall_micro: 0.373
ce_f1_micro: 0.392
ce_precision_example: 0.392
ce_recall_example: 0.354
ce_f1_example: 0.350
ce_num_examples: 627.000


In [23]:

# Combine with other metrics
all_metrics = {
    **metrics,  # Your previously calculated metrics
    **chexbert_scores  # CheXbert metrics
}

# Print all metrics
print("\nAll Metrics:")
for key, value in all_metrics.items():
    print(f'{key}: {value:.3f}')


All Metrics:
BLEU-1: 0.232
BLEU-2: 0.158
BLEU-3: 0.118
BLEU-4: 0.083
CIDEr: 0.029
ROUGE-1: 0.278
ROUGE-2: 0.076
ROUGE-L: 0.269
ce_precision_macro: 0.301
ce_recall_macro: 0.232
ce_f1_macro: 0.207
ce_precision_micro: 0.414
ce_recall_micro: 0.373
ce_f1_micro: 0.392
ce_precision_example: 0.392
ce_recall_example: 0.354
ce_f1_example: 0.350
ce_num_examples: 627.000


## CXRBERT

In [24]:
from cxr_bert import CXRBERT

# CXR-BERT:
cxr_bert_metric = CXRBERT(
    ckpt_dir='ckpt', 
    mbatch_size=10, 
    exp_dir='metrics',
    split='test',
    accumulate_over_dicoms = False)

# calculate CXR-BERT metric
study_ids = [result['study_id'] for result in results]
# dicom_ids = [result['dicom_id'] for result in results]
cxr_bert_metric.update(res, [[label] for label in gt], study_ids)
cxr_bert_scores = cxr_bert_metric.compute(epoch=1)

# Add CXR-BERT metric to all_metrics
for key, value in cxr_bert_scores.items():
    if isinstance(value, torch.Tensor):
        all_metrics[f"CXR-BERT_{key}"] = value.item()
    else:
        all_metrics[f"CXR-BERT_{key}"] = value

# Print CXR-BERT metrics
print("CXR-BERT Metrics:")
for key, value in cxr_bert_scores.items():
    if isinstance(value, torch.Tensor):
        print(f'{key}: {value.item():.3f}')
    else:
        print(f'{key}: {value:.3f}')



CXR-BERT Metrics:
cxr_bert_metric: 0.329


In [26]:
# Print summary of all metrics
print("\n===== Test Epoch End Metrics Summary =====")
for key, value in all_metrics.items():
    print(f"{key}: {value:.3f}")
print("==========================================\n")


===== Test Epoch End Metrics Summary =====
BLEU-1: 0.232
BLEU-2: 0.158
BLEU-3: 0.118
BLEU-4: 0.083
CIDEr: 0.029
ROUGE-1: 0.278
ROUGE-2: 0.076
ROUGE-L: 0.269
ce_precision_macro: 0.301
ce_recall_macro: 0.232
ce_f1_macro: 0.207
ce_precision_micro: 0.414
ce_recall_micro: 0.373
ce_f1_micro: 0.392
ce_precision_example: 0.392
ce_recall_example: 0.354
ce_f1_example: 0.350
ce_num_examples: 627.000
CXR-BERT_cxr_bert_metric: 0.329



In [None]:
output_path = '/outputs/metrics_singleview'
with open(output_path, 'w') as f:
    json.dump(results, f, indent=2)