In [1]:
# Environment setup -- uncomment to install the dependencies in a fresh environment.
# NOTE(review): prefer `%pip` over `!pip` so packages land in the kernel's env, and pin versions.
# MedCLIP, installed from its GitHub repository:
# !pip install git+https://github.com/RyanWangZf/MedCLIP.git

# torchxrayvision:
# !pip install torchxrayvision

# Imports shared with the evaluation cells that follow.
# NOTE(review): xrv, plt, Image and pd are not used by any cell visible here.
import torchxrayvision as xrv
import matplotlib.pyplot as plt
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
from medclip import PromptClassifier
from PIL import Image
import pandas as pd
from tqdm import tqdm

# Decision threshold: the evaluation cells predict label 1 when the classifier's scalar
# logit (`output['logits'].item()`) is >= this value, and label 0 otherwise.
threshold = 0.5  # decision threshold on the classifier logit

In [1]:
import sys
import torch
import yaml
import pandas as pd
from tqdm import tqdm
import random
import json
import os

# Zero-shot binary classification (COVID-19 vs. not) of the COVID dataset with MedCLIP.
# Uses `threshold` and the MedCLIP classes imported in the setup cell.

# Guard the append so re-running this cell does not keep growing sys.path.
if '../fromage' not in sys.path:
    sys.path.append('../fromage')
from imgclsDataset import COVIDDataset
# Imported once here rather than on every iteration of the inner loop.
from medclip.prompts import generate_covid_class_prompts, process_class_prompts

covid_dataset_path = "/kuacc/users/hpc-dtank/hpc_run/datasets/COVID"
max_prompt_tries = 5  # prompt sets tried per image before counting it as wrong
n_prompts = 10        # passed as `n` to generate_covid_class_prompts
seed = 0

# The retry loop only makes sense if each call yields a different prompt set, so seed
# the RNGs to make a run reproducible.
# NOTE(review): assumes the prompt generator draws from Python's `random` or torch; confirm.
random.seed(seed)
torch.manual_seed(seed)

COVIDdataset = COVIDDataset(covid_dataset_path)

# loading the model
processor = MedCLIPProcessor()
model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
model.from_pretrained()
clf = PromptClassifier(model, ensemble=True)
clf.cuda()
# torch modules start in training mode; eval() disables dropout for inference.
# (assumes PromptClassifier is a torch.nn.Module, as its .cuda() method suggests)
clf.eval()

print('dataset length:', len(COVIDdataset))

# correct_predictions: images where ANY of up to `max_prompt_tries` prompt sets produced the
#   right label. Because the ground truth decides whether to retry, this is an optimistic
#   best-of-N ("oracle") score, not a plain accuracy.
# first_try_correct: images classified correctly by the first prompt set alone -- the
#   honest zero-shot accuracy.
correct_predictions = 0
first_try_correct = 0
n_evaluated = 0

progress = tqdm(COVIDdataset)  # wrap the dataset itself so tqdm knows the total length
with torch.no_grad():  # inference only: no autograd graph needed
    for sample in progress:
        image = sample[0]
        # Binary ground truth: 1 for 'COVID-19', 0 otherwise. Fixed per image, so computed once.
        covid_label = 1 if sample[1] == 'COVID-19' else 0
        inputs = processor(images=image, return_tensors="pt")

        for j in range(max_prompt_tries):
            # Draw a (new) prompt set and classify the image with it.
            inputs['prompt_inputs'] = process_class_prompts(generate_covid_class_prompts(n=n_prompts))
            pred = clf(**inputs)['logits']
            predicted_label = 1 if pred.item() >= threshold else 0

            if predicted_label == covid_label:
                correct_predictions += 1
                if j == 0:
                    first_try_correct += 1
                break

        n_evaluated += 1
        # Running scores on the progress bar instead of a printed line per image.
        progress.set_postfix(
            best_of_n=f"{correct_predictions / n_evaluated * 100:.1f}%",
            first_try=f"{first_try_correct / n_evaluated * 100:.1f}%",
        )

# Guard against an empty dataset (previously left `accuracy` undefined).
accuracy = correct_predictions / n_evaluated * 100 if n_evaluated else 0.0
first_try_accuracy = first_try_correct / n_evaluated * 100 if n_evaluated else 0.0
print(f"Best-of-{max_prompt_tries} accuracy:", correct_predictions, '/', n_evaluated, ' = ', accuracy)
print("First-try accuracy:", first_try_correct, '/', n_evaluated, ' = ', first_try_accuracy)

In [2]:
import sys
import torch
import yaml
import pandas as pd
from tqdm import tqdm
import random
import json
import os

# Zero-shot binary classification of the RSNA pneumonia dataset with MedCLIP.
# Uses `threshold` and the MedCLIP classes imported in the setup cell.

# Guard the append so re-running this cell does not keep growing sys.path.
if '../fromage' not in sys.path:
    sys.path.append('../fromage')
from imgclsDataset import RSNAPneumoniaDataset
# Imported once here rather than on every iteration of the inner loop.
from medclip.prompts import generate_rsna_class_prompts, process_class_prompts

rsna_dataset_path = '/kuacc/users/hpc-dtank/hpc_run/datasets/RSNA'
max_prompt_tries = 5  # prompt sets tried per image before counting it as wrong
n_prompts = 10        # passed as `n` to generate_rsna_class_prompts
seed = 0

# The retry loop only makes sense if each call yields a different prompt set, so seed
# the RNGs to make a run reproducible.
# NOTE(review): assumes the prompt generator draws from Python's `random` or torch; confirm.
random.seed(seed)
torch.manual_seed(seed)

RSNAdataset = RSNAPneumoniaDataset(rsna_dataset_path)

# loading the model (kept here so this cell does not depend on the COVID cell's state)
processor = MedCLIPProcessor()
model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
model.from_pretrained()
clf = PromptClassifier(model, ensemble=True)
clf.cuda()
# torch modules start in training mode; eval() disables dropout for inference.
# (assumes PromptClassifier is a torch.nn.Module, as its .cuda() method suggests)
clf.eval()

print('dataset length:', len(RSNAdataset))

# correct_predictions: images where ANY of up to `max_prompt_tries` prompt sets produced the
#   right label. Because the ground truth decides whether to retry, this is an optimistic
#   best-of-N ("oracle") score, not a plain accuracy.
# first_try_correct: images classified correctly by the first prompt set alone -- the
#   honest zero-shot accuracy.
correct_predictions = 0
first_try_correct = 0
n_evaluated = 0

progress = tqdm(RSNAdataset)  # wrap the dataset itself so tqdm knows the total length
with torch.no_grad():  # inference only: no autograd graph needed
    for sample in progress:
        image = sample[0]
        # NOTE(review): compared directly with the 0/1 prediction below, so sample[1] is
        # presumably already a 0/1 label (1 = pneumonia); confirm against RSNAPneumoniaDataset.
        rsna_label = sample[1]
        inputs = processor(images=image, return_tensors="pt")

        for j in range(max_prompt_tries):
            # Draw a (new) prompt set and classify the image with it.
            inputs['prompt_inputs'] = process_class_prompts(generate_rsna_class_prompts(n=n_prompts))
            pred = clf(**inputs)['logits']
            predicted_label = 1 if pred.item() >= threshold else 0

            if rsna_label == predicted_label:
                correct_predictions += 1
                if j == 0:
                    first_try_correct += 1
                break

        n_evaluated += 1
        # Running scores on the progress bar instead of a printed line per image.
        progress.set_postfix(
            best_of_n=f"{correct_predictions / n_evaluated * 100:.1f}%",
            first_try=f"{first_try_correct / n_evaluated * 100:.1f}%",
        )

# Guard against an empty dataset (previously left `accuracy` undefined).
accuracy = correct_predictions / n_evaluated * 100 if n_evaluated else 0.0
first_try_accuracy = first_try_correct / n_evaluated * 100 if n_evaluated else 0.0
print(f"Best-of-{max_prompt_tries} accuracy:", correct_predictions, '/', n_evaluated, ' = ', accuracy)
print("First-try accuracy:", first_try_correct, '/', n_evaluated, ' = ', first_try_accuracy)