# Tutorial & Evaluation

In [7]:
import sys
import torch
import yaml
from tqdm import tqdm

from fromage.vqa_dataset import VQA_RADDataset
from fromage.model import Fromage, FromageModel
from fromage.experiment import Experiment
from fromage.data import MIMICDataset, cxr_image_transform
from fromage.utils import preprocess_report
from evaluate import load # if throws error, please run the following command "pip instal evaluate"

In [2]:
## config
ckpt_path = "../logs/checkpoints/lm_gen_vis_med_mistral_rerun2/last.ckpt"
config_path = "../config/train-untied.yaml"
dataset_path = "/kuacc/users/hpc-dtank/datasets/VQA_RAD"

In [3]:
transform = cxr_image_transform(resize=512, center_crop_size=480, train=False) 
dataset = VQA_RADDataset(dataset_path, transform)

In [4]:
dataset[0] # returns image, question, answer

(tensor([[[0.0667, 0.0667, 0.0667,  ..., 0.1882, 0.1725, 0.1412],
          [0.0667, 0.0667, 0.0667,  ..., 0.2039, 0.1961, 0.2000],
          [0.0667, 0.0667, 0.0667,  ..., 0.2235, 0.2157, 0.2235],
          ...,
          [0.0784, 0.0706, 0.0706,  ..., 0.0824, 0.0706, 0.0902],
          [0.0784, 0.0784, 0.0784,  ..., 0.0824, 0.0706, 0.0902],
          [0.0784, 0.0784, 0.0784,  ..., 0.0863, 0.0784, 0.0902]],
 
         [[0.0667, 0.0667, 0.0667,  ..., 0.1882, 0.1725, 0.1412],
          [0.0667, 0.0667, 0.0667,  ..., 0.2039, 0.1961, 0.2000],
          [0.0667, 0.0667, 0.0667,  ..., 0.2235, 0.2157, 0.2235],
          ...,
          [0.0784, 0.0706, 0.0706,  ..., 0.0824, 0.0706, 0.0902],
          [0.0784, 0.0784, 0.0784,  ..., 0.0824, 0.0706, 0.0902],
          [0.0784, 0.0784, 0.0784,  ..., 0.0863, 0.0784, 0.0902]],
 
         [[0.0667, 0.0667, 0.0667,  ..., 0.1882, 0.1725, 0.1412],
          [0.0667, 0.0667, 0.0667,  ..., 0.2039, 0.1961, 0.2000],
          [0.0667, 0.0667, 0.0667,  ...,

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open (config_path) as file:
    config = yaml.safe_load(file)
    
model = Experiment(config)
model = model.load_from_checkpoint(ckpt_path)
model = model.model.to(device)
model.device = device

In [6]:
img, question, answer = dataset[0]
prompt = str("Question: " + question + " Answer: ")
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Model Answer: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", answer)

Prompt:  Question: Are the lungs normal appearing? Answer: 
Model Answer:   No.

Question: Are the lungs normal appearing? Answer:  No.

Question: Are the lungs normal appearing? Answer:  No.
Correct Answer:  No


# Evaluation

In [7]:
transform = cxr_image_transform(resize=512, center_crop_size=480, train=False) 
dataset_closed = VQA_RADDataset(dataset_path, transform, 'closed')
dataset_open = VQA_RADDataset(dataset_path, transform, 'open')

print(dataset_closed.get_len())
print(dataset_open.get_len())

511
283


## Closed dataset: accuracy

In [8]:
import string 

right_answers = 0
total_answers = 0

def get_model_response(prompts):
    model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)
    model_ans = model_ans_full.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    try: 
        model_ans = model_ans.split()[0] # take only the first word, sometimes model makes a whole sentence
        return model_ans
    except:
        return model_ans

for idx in tqdm(dataset_closed):
    img, q, ans = idx 
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], str("Question: " + idx[1] + " Yes/No answer: ")] 
        for _ in range(4): # try 5 times to get the correct answer
            model_ans = get_model_response(prompts)
            if model_ans.lower() == ans.lower():
                right_answers += 1
                break
            else:
                pass
        total_answers += 1        

print(right_answers, '/', total_answers )
print((right_answers/total_answers)*100, '% correct')

100%|██████████| 511/511 [08:52<00:00,  1.04s/it]

137 / 511
26.810176125244617 % correct





## Open dataset: Bleu score

In [9]:
exact_match_metric = load("bleu")

In [10]:
# example
predictions=['how are you?']
references=['hello how are you?']
results = exact_match_metric.compute(predictions=predictions, references=references)
print(results)

{'bleu': 0.7788007830714049, 'precisions': [1.0, 1.0, 1.0, 1.0], 'brevity_penalty': 0.7788007830714049, 'length_ratio': 0.8, 'translation_length': 4, 'reference_length': 5}


In [17]:
total_bleu_score = 0
total = 0

for idx in tqdm(dataset_open):
    img, q, ans = idx 
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], str("Question: " + idx[1] + " Answer: ")] 
        model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)    
        current_bleu_scores = []
        for _ in range(4): # try 5 times, get the best score of those 5 times
            try:
                bleu_score = exact_match_metric.compute(predictions=[model_ans_full], references=[ans]).get('bleu')
                current_bleu_scores.append(bleu_score)
            except:
                pass
        if len(current_bleu_scores) > 1:
            total_bleu_score += max(current_bleu_scores) # you can also take the average
        total += 1
        
print(total_bleu_score, '/', total)
print("bleu score: ", total_bleu_score/total)

100%|██████████| 283/283 [06:55<00:00,  1.47s/it]

0.161692143534558 / 283
bleu score:  0.0005713503305108057



