# Image Classification Evaluation
### Imports & Config

In [1]:
import sys
import torch
import yaml
import pandas as pd
from tqdm import tqdm
import random
import json
import os

from fromage.imgcls_dataset import RSNAPneumoniaDataset, COVIDDataset
from fromage.data import cxr_image_transform
from fromage.experiment import Experiment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
rsna_dataset_path = '"../data/datasets/RSNA'
covid_dataset_path = "../data/datasets/COVID"

ckpt_path = "../logs/checkpoints/lm_med_vis_med/last.ckpt"
config_path = "../config/train-untied_lm_med_vis_med.yaml"

### Load datasets and model

In [3]:
transform = cxr_image_transform(resize=512, center_crop_size=480, train=False) 

RSNAdataset = RSNAPneumoniaDataset(rsna_dataset_path, transform)
COVIDdataset = COVIDDataset(covid_dataset_path, transform)

FileNotFoundError: [Errno 2] No such file or directory: '"../data/datasets/RSNA/stage_2_train_labels_short.csv'

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open (config_path) as file:
    config = yaml.safe_load(file)
    
model = Experiment(config)
model = model.load_from_checkpoint(ckpt_path)
model = model.model.to(device)
model.device = device

### RSNA dataset evaluation
The RSNA Pneumonia detection dataset has two classes: pneumonia (1) or no pneumonia (0). 

In [5]:
img, correct_class = RSNAdataset[5]
prompt = "Question: Does this image have Pneumonia? Yes or No? Answer: "
# prompt = "Pneumonia or no pneumonia?"
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Does this image have Pneumonia? Yes or No? Answer: 
Predicted class:  
Pneumothorax. No pleural effusion. No sphenorum. No pleural effusion. No hemangioma. No
Correct Answer:  0


In [6]:
img, correct_class = RSNAdataset[4]
prompt = "Question: Does this image have Pneumonia? Yes or No? Answer: "
# prompt = "Pneumonia or no pneumonia?"
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Does this image have Pneumonia? Yes or No? Answer: 
Predicted class:  
Yes, there is no right sided pleural effusion. No focal lung mass. No focal lung mass. No focal lung mass. No focal lung mass
Correct Answer:  1


In [7]:
import string 

right_answers = 0
total_answers = 0

def get_model_response(prompts):
    model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)
    model_ans = model_ans_full.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    try: 
        model_ans = model_ans.split()[0] # take only the first word, sometimes model makes a whole sentence
        return str(model_ans)
    except:
        return str(model_ans)

for idx in tqdm(RSNAdataset):
    img, ans = idx 
    if ans == 0:
        ans = 'yes'
    else:
        ans = 'no'
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], "Question: Does this image have Pneumonia? Yes or No? Answer: "] 
        for _ in range(4): # try 5 times to get the correct answer
            model_ans = get_model_response(prompts)
            if model_ans.lower() == ans.lower():
                right_answers += 1
                break
            else:
                pass
        total_answers += 1        

print(right_answers, '/', total_answers )
print((right_answers/total_answers)*100, '% correct')

100%|██████████| 3538/3538 [1:02:28<00:00,  1.06s/it]

577 / 3538
16.308648954211417 % correct





### COVID dataset evaluation
The COVID dataset has three classes: (1) Normal, (2) COVID-19, and (3) Non-COVID

In [8]:
img, correct_class = COVIDdataset[0]
prompt = "Question: Choose from the following classes: Normal, COVID-19, Non-Covid. Answer: "
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Choose from the following classes: Normal, COVID-19, Non-Covid. Answer: 
Predicted class:  
The patient is asymptomatic. The patient is asymptomatic. The patient is asymptomatic. The patient is asymptomatic
Correct Answer:  COVID-19


In [9]:
import string 

right_answers = 0
total_answers = 0

def get_model_response(prompts):
    model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)
    model_ans = model_ans_full.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    try: 
        model_ans = model_ans.split()[0] # take only the first word, sometimes model makes a whole sentence
        return str(model_ans)
    except:
        return str(model_ans)

for idx in tqdm(COVIDdataset):
    img, ans = idx 
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], "Question: Choose if the patient case is Normal, COVID-19, or Non-Covid. Answer: "] 
        for _ in range(4): # try 5 times to get the correct answer
            model_ans = get_model_response(prompts)
            if model_ans.lower() == ans.lower():
                right_answers += 1
                break
            else:
                pass
        total_answers += 1        

print(right_answers, '/', total_answers )
print((right_answers/total_answers)*100, '% correct')

100%|██████████| 3000/3000 [57:08<00:00,  1.14s/it] 

0 / 3000
0.0 % correct



