# Image Classification Evaluation
### Imports & Config

In [1]:
import sys
import torch
import yaml
import pandas as pd
from tqdm import tqdm
import random
import json
import os

sys.path.append('../fromage')
from imgclsDataset import RSNAPneumoniaDataset, COVIDDataset
from data import cxr_image_transform
from experiment import Experiment

In [2]:
rsna_dataset_path = '/kuacc/users/hpc-dtank/hpc_run/datasets/RSNA'
covid_dataset_path = "/kuacc/users/hpc-dtank/hpc_run/datasets/COVID"

ckpt_path = "/kuacc/users/hpc-dtank/hpc_run/projects/models/vl_eval_5/last.ckpt"
config_path = "/kuacc/users/hpc-dtank/hpc_run/projects/models/vl_eval_5/train-vleval-5.yaml"

### Preprocessing, only need to run once

In [None]:
# # RSNA csv preprocessing
# RSNAdata = pd.read_csv(rsna_dataset_path + '/stage_2_train_labels.csv')
# random_indices = random.sample(range(len(RSNAdata)), 3538)
# selected_rows = RSNAdata.loc[random_indices]
# selected_rows.to_csv(rsna_dataset_path + '/stage_2_train_labels_short.csv', index=False)
# print('done')

In [None]:
# # COVID dataset preprocessing
# import os
# import json

# data_root = covid_dataset_path
# output_file = covid_dataset_path + "/COVID.json"

# def create_json(folder_path, class_label):
#     json_data = []
#     image_folder = os.path.join(folder_path, "images")
#     for image_name in os.listdir(image_folder):
#         image_path = os.path.join(image_folder, image_name)
#         json_entry = {
#             "image_path": os.path.relpath(image_path, data_root),
#             "class": class_label
#         }
#         json_data.append(json_entry)
#     return json_data

# def main():
#     all_data = []

#     for data_type in ["Infection Segmentation Data", "Lung Segmentation Data"]:
#         data_type_folder = os.path.join(data_root, data_type)
#         subfolders = [subfolder for subfolder in os.listdir(data_type_folder) if os.path.isdir(os.path.join(data_type_folder, subfolder))]
        
#         for subfolder in subfolders:
#             class_labels = ["Normal", "Non-COVID", "COVID-19"]
#             for class_label in class_labels:
#                 test_folder = os.path.join(data_type_folder, subfolder, "Test", class_label)
#                 json_data = create_json(test_folder, class_label)
#                 all_data.extend(json_data)

#     with open(output_file, "w") as json_file:
#         json.dump(all_data, json_file, indent=4)

# if __name__ == "__main__":
#     main()

In [None]:
# # Read the input JSON file    
# f = open(covid_dataset_path + '/COVID.json')
# covid_data = json.load(f)

# # Randomly sample 3000 items
# random_indices = random.sample(range(len(covid_data)), 3000)
# selected_rows = [covid_data[i] for i in random_indices]

# # Write the selected data to a new JSON file
# output_json_path = covid_dataset_path + '/COVIDshort.json'
# with open(output_json_path, 'w') as json_output_file:
#     json.dump(selected_rows, json_output_file)
# print('done')

### Load datasets and model

In [3]:
transform = cxr_image_transform(resize=512, center_crop_size=480, train=False) 

RSNAdataset = RSNAPneumoniaDataset(rsna_dataset_path, transform)
COVIDdataset = COVIDDataset(covid_dataset_path, transform)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open (config_path) as file:
    config = yaml.safe_load(file)
    
model = Experiment(config)
model = model.load_from_checkpoint(ckpt_path)
model = model.model.to(device)
model.device = device

### RSNA dataset evaluation
The RSNA Pneumonia detection dataset has two classes: pneumonia (1) or no pneumonia (0). 

In [5]:
img, correct_class = RSNAdataset[5]
prompt = "Question: Does this image have Pneumonia? Yes or No? Answer: "
# prompt = "Pneumonia or no pneumonia?"
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Does this image have Pneumonia? Yes or No? Answer: 
Predicted class:  
Pneumothorax. No pleural effusion. No sphenorum. No pleural effusion. No hemangioma. No
Correct Answer:  0


In [6]:
img, correct_class = RSNAdataset[4]
prompt = "Question: Does this image have Pneumonia? Yes or No? Answer: "
# prompt = "Pneumonia or no pneumonia?"
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Does this image have Pneumonia? Yes or No? Answer: 
Predicted class:  
Yes, there is no right sided pleural effusion. No focal lung mass. No focal lung mass. No focal lung mass. No focal lung mass
Correct Answer:  1


In [7]:
import string 

right_answers = 0
total_answers = 0

def get_model_response(prompts):
    model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)
    model_ans = model_ans_full.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    try: 
        model_ans = model_ans.split()[0] # take only the first word, sometimes model makes a whole sentence
        return str(model_ans)
    except:
        return str(model_ans)

for idx in tqdm(RSNAdataset):
    img, ans = idx 
    if ans == 0:
        ans = 'yes'
    else:
        ans = 'no'
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], "Question: Does this image have Pneumonia? Yes or No? Answer: "] 
        for _ in range(4): # try 5 times to get the correct answer
            model_ans = get_model_response(prompts)
            if model_ans.lower() == ans.lower():
                right_answers += 1
                break
            else:
                pass
        total_answers += 1        

print(right_answers, '/', total_answers )
print((right_answers/total_answers)*100, '% correct')

100%|██████████| 3538/3538 [1:02:28<00:00,  1.06s/it]

577 / 3538
16.308648954211417 % correct





### COVID dataset evaluation
The COVID dataset has three classes: (1) Normal, (2) COVID-19, and (3) Non-COVID

In [8]:
img, correct_class = COVIDdataset[0]
prompt = "Question: Choose from the following classes: Normal, COVID-19, Non-Covid. Answer: "
print("Prompt: ", prompt)

with torch.inference_mode():
    model.eval()
    prompts = [img, prompt] 
    print("Predicted class: ", model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5))
    
print("Correct Answer: ", correct_class)

Prompt:  Question: Choose from the following classes: Normal, COVID-19, Non-Covid. Answer: 
Predicted class:  
The patient is asymptomatic. The patient is asymptomatic. The patient is asymptomatic. The patient is asymptomatic
Correct Answer:  COVID-19


In [9]:
import string 

right_answers = 0
total_answers = 0

def get_model_response(prompts):
    model_ans_full = model.generate_for_images_and_texts(prompts, top_p=0.9, temperature=0.5)
    model_ans = model_ans_full.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    try: 
        model_ans = model_ans.split()[0] # take only the first word, sometimes model makes a whole sentence
        return str(model_ans)
    except:
        return str(model_ans)

for idx in tqdm(COVIDdataset):
    img, ans = idx 
    with torch.inference_mode():
        model.eval()
        prompts = [idx[0], "Question: Choose if the patient case is Normal, COVID-19, or Non-Covid. Answer: "] 
        for _ in range(4): # try 5 times to get the correct answer
            model_ans = get_model_response(prompts)
            if model_ans.lower() == ans.lower():
                right_answers += 1
                break
            else:
                pass
        total_answers += 1        

print(right_answers, '/', total_answers )
print((right_answers/total_answers)*100, '% correct')

100%|██████████| 3000/3000 [57:08<00:00,  1.14s/it] 

0 / 3000
0.0 % correct



