In [None]:
import pandas as pd
def extract_classification(file_path):
    """Return the distinct labels found in the ``tissue_types`` column of a CSV.

    Args:
        file_path: Path of a CSV file readable by ``pandas.read_csv`` that
            contains a ``tissue_types`` column.

    Returns:
        list: Every distinct ``tissue_types`` value exactly once, in order of
        first appearance in the file.

    Raises:
        KeyError: If the CSV has no ``tissue_types`` column.
    """
    df = pd.read_csv(file_path)
    # drop_duplicates keeps first-appearance order, so the label list is the
    # same on every run; list(set(...)) ordering varies with string hash
    # randomisation between interpreter sessions.
    return df['tissue_types'].drop_duplicates().tolist()

# CSV that supplies the 'tissue_types' labels (absolute, machine-specific path).
csv_path = '/mnt/e/Temp_data/Body_Parts_XRay/Original/train_df_new1.csv'
# Distinct values of the CSV's 'tissue_types' column.
classification = extract_classification(csv_path)

# Bare expression as the last statement so the notebook displays the label list.
classification

In [None]:
#classification

import torch, os
from PIL import Image
import numpy as np
import open_clip

# Fine-tuned CoCa checkpoint to evaluate (absolute, machine-specific path).
pretrained = '/mnt/g/Logtemp/open_clip/Whale_Dolphin_Identification/2023_08_18-21_14_32-model_coca_ViT-L-14-lr_1e-06-b_32-j_4-p_amp/checkpoints/epoch_16.pt'
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained=pretrained,  # alternative: the released tag 'mscoco_finetuned_laion2B-s13B-b90k'
)
# Inference only: switch to eval mode, as the open_clip usage example does
# right after loading a model.
model.eval()

tokenizer = open_clip.get_tokenizer('coca_ViT-L-14')

image_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/images/false_killer_whale_bdea86a4d11fa9.jpg'
img = Image.open(image_path).convert('RGB')
img = img.resize((224, 224), Image.Resampling.LANCZOS)
image = preprocess(img).unsqueeze(0)  # add the batch dimension
# NOTE(review): `species` (the list of candidate class-name strings) is not
# defined in this cell; it must already exist in the kernel or this line
# raises NameError on a fresh "Restart & Run All".
text = tokenizer(species)

# Ground truth comes from the file name: drop the trailing id token and turn
# the remaining underscores into spaces
# ('false_killer_whale_bdea86a4d11fa9' -> 'false killer whale').
filename = os.path.splitext(os.path.basename(image_path))[0]
ground_truth = "_".join(filename.split('_')[:-1]).replace("_", " ")

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalise so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# text_probs has shape (1, len(species)). Zipping the 2-D tensor directly pairs
# only species[0] with the whole row, which makes the "prediction" always the
# first class; take row 0 so every class name gets its own probability.
class_probs = text_probs[0].cpu().tolist()
class_predict_dict = dict(zip(species, class_probs))
# Class name with the highest probability.
class_predict = max(class_predict_dict, key=class_predict_dict.get)

# 1 when the prediction matches the label parsed from the file name, else 0.
predicted_correct = int(ground_truth == class_predict)

print("Label probs:", text_probs)  # (1, len(species)) tensor of class probabilities
print("Ground truth:", ground_truth)  # label parsed from the file name
print("Predicted class:", class_predict)  # highest-probability entry of `species`
print("Predicted correct:", predicted_correct)  # 1 if they match, otherwise 0

In [None]:
# Batch test for Direct Classification benchmark
import torch, os, open_clip
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm

def extract_classification(file_path):
    """Load the label CSV and return its distinct labels together with the table.

    Note: this redefines the single-return ``extract_classification`` from the
    first cell of the notebook; once this cell has run, the name returns a
    tuple.

    Args:
        file_path: Path of a CSV file readable by ``pandas.read_csv`` that
            contains a ``tissue_types`` column.

    Returns:
        tuple: ``(classification, df_classification)`` where ``classification``
        is a list holding each distinct ``tissue_types`` value once, in order
        of first appearance, and ``df_classification`` is the full DataFrame
        read from the file.

    Raises:
        KeyError: If the CSV has no ``tissue_types`` column.
    """
    df_classification = pd.read_csv(file_path)
    # drop_duplicates keeps first-appearance order, so the label list (and the
    # prompts built from it) is identical on every run; list(set(...)) ordering
    # varies with string hash randomisation.
    classification = df_classification['tissue_types'].drop_duplicates().tolist()
    return classification, df_classification

# CSV that supplies the 'filename' and 'tissue_types' columns (absolute, machine-specific path).
classification_csv_file = '/mnt/e/Temp_data/Body_Parts_XRay/Original/train_df_new1.csv'
# `classification`: distinct labels, later passed to process_single_epoch as the candidate sentences.
# `df_classification`: the full table, read as a module-level global by process_single_epoch
# to look up each image's ground truth.
classification, df_classification = extract_classification(classification_csv_file)


def process_single_epoch(pretrained_model, sentences):
    """Benchmark one checkpoint on every file in the test-image directory.

    Each image is scored against ``sentences`` with a softmax over
    100 x cosine similarity between the normalised image and text embeddings.
    The highest-probability sentence is the prediction and is compared with
    the ``tissue_types`` value stored for that file name in the module-level
    ``df_classification`` table. Per-image results are written to
    ``temp/epoch_<N>.csv`` (the ``temp`` directory is created if needed).

    Args:
        pretrained_model: Path of a checkpoint whose base name contains
            ``epoch_<N>.pt``.
        sentences: List of candidate label strings to score each image against.

    Returns:
        pandas.DataFrame: A single row with the columns ``pt_dir``,
        ``epoch_filename``, ``sentences``, ``total_rows``,
        ``predicted_correct`` and ``predicted_correct (%)``.

    Raises:
        IndexError: If the checkpoint base name contains no ``epoch_`` marker,
            or a file in the image directory has no row in
            ``df_classification``.
    """
    img_dir = '/mnt/e/Temp_data/Body_Parts_XRay/Original/test/images/'
    result_columns = ['img_filename', 'class_predict_dict', 'ground_truth',
                      'class_predict', 'predicted_correct']

    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name="coca_ViT-B-32",
        pretrained=pretrained_model,
    )
    # Inference only: switch to eval mode, as the open_clip usage example does
    # right after loading a model.
    model.eval()
    pt_dir = os.path.dirname(pretrained_model)
    epoch_number = os.path.basename(pretrained_model).split('epoch_')[1].split('.pt')[0]

    # The text embeddings do not depend on the image, so encode them once.
    # NOTE(review): nothing here moves the model or tensors to a GPU, so the
    # CUDA autocast context presumably has no effect on these ops - confirm
    # whether GPU inference was intended.
    tokenizer = open_clip.get_tokenizer('coca_ViT-B-32')
    text = tokenizer(sentences)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    def read_single_image(img_dir, img_filename, sentences, text_features):
        """Classify one image and return its result as a flat record (dict)."""
        image_path = os.path.join(img_dir, img_filename)
        # First matching row wins; raises IndexError when the file is unlabelled.
        ground_truth = df_classification[df_classification['filename']==img_filename]['tissue_types'].iloc[0]

        # Close the file handle as soon as the tensor has been built.
        with Image.open(image_path) as img:
            image = preprocess(img).unsqueeze(0)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            # Row 0 of the (1, len(sentences)) result as plain Python floats.
            text_probs = text_probs.cpu().tolist()[0]

        class_predict_dict = dict(zip(sentences, text_probs))
        # Sentence with the highest probability.
        class_predict = max(class_predict_dict, key=class_predict_dict.get)

        return {
            'img_filename': img_filename,
            'class_predict_dict': class_predict_dict,
            'ground_truth': ground_truth,
            'class_predict': class_predict,
            'predicted_correct': int(ground_truth == class_predict),
        }

    # Sorted so the per-image CSV has a reproducible row order.
    all_images = sorted(os.listdir(img_dir))
    # Collect plain records and build the frame once; concatenating a
    # DataFrame per image re-copies every previous row on each iteration.
    records = [
        read_single_image(img_dir, img_filename, sentences, text_features)
        for img_filename in tqdm(all_images, desc=f"epoch_{epoch_number}: Processing images", unit="image")
    ]
    # Explicit columns keep the header (and the column lookup below) valid
    # even when the image directory is empty.
    df_all = pd.DataFrame(records, columns=result_columns)

    # Save the per-image results; to_csv does not create missing directories.
    os.makedirs("temp", exist_ok=True)
    csv_file = f"temp/epoch_{epoch_number}.csv"
    df_all.to_csv(csv_file, index=False)

    total_rows = len(df_all)
    # 'predicted_correct' holds 0/1, so its sum is the number of correct predictions.
    count_ones = int(df_all['predicted_correct'].sum())
    epoch_filename = os.path.splitext(os.path.basename(pretrained_model))[0]
    # Guard the empty-directory case instead of dividing by zero.
    predicted_correct_pct = count_ones / total_rows * 100 if total_rows else 0.0
    print(f"Epoch_file: {epoch_filename}. Total images: {total_rows}, predicted_correct': {count_ones}, predicted_correct (%)': {predicted_correct_pct}%")

    epoch_dict = {'pt_dir': [pt_dir],
                'epoch_filename': [epoch_filename],
                'sentences': [sentences],
                'total_rows': [total_rows],
                'predicted_correct': [count_ones],
                'predicted_correct (%)': [predicted_correct_pct]}
    df_single_epoch = pd.DataFrame(epoch_dict)
    return df_single_epoch


def process_single_pretrained_model(pretrained_model, csv_file):
    """Benchmark one checkpoint and append its summary row to ``csv_file``.

    The candidate labels are taken from the module-level ``classification``
    list.

    Args:
        pretrained_model: Path of the checkpoint passed to
            ``process_single_epoch``.
        csv_file: Path of the summary CSV the row is written or appended to.
    """
    # process_single_epoch already returns a fresh one-row frame, so there is
    # nothing to accumulate; the former concat into an empty DataFrame was a no-op.
    df_single_pretrained_model = process_single_epoch(pretrained_model, classification)
    save_to_csv(df_single_pretrained_model.reset_index(drop=True), csv_file)
    
def extract_epoch_num(filepath):
    """Return the integer found between ``epoch_`` and ``.pt`` in a checkpoint's file name.

    Args:
        filepath: Path (or bare file name) of a checkpoint such as
            ``.../epoch_16.pt``.

    Returns:
        int: The epoch number parsed from the file name.
    """
    # Only the final path component is inspected.
    checkpoint_name = os.path.basename(filepath)
    # Text following the first "epoch_" marker ...
    after_marker = checkpoint_name.split('epoch_')[1]
    # ... truncated at the ".pt" extension.
    epoch_text = after_marker.split('.pt')[0]
    return int(epoch_text)

def save_to_csv(df, csv_file):
    """Write ``df`` to ``csv_file``, appending when the file already has content.

    A header row is written only when the file is missing or empty; otherwise
    the rows are appended without a header. The index is never written.

    Args:
        df: DataFrame whose rows are written.
        csv_file: Path of the CSV file to create or extend.
    """
    # A missing or zero-byte file has no header row yet. Appending without one
    # would leave a headerless CSV whose first data row gets read as the header.
    needs_header = not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0
    df.to_csv(csv_file, mode='w' if needs_header else 'a', header=needs_header, index=False)

def main(pt_folder_path='/mnt/x/Log/open_clip_2920X/Body_Parts_XRay/2023_08_19-22_56_49-model_coca_ViT-B-32-lr_1e-06-b_16-j_4-p_amp/checkpoints/',
         csv_file="benchmark_all_epochs.csv"):
    """Benchmark every numbered checkpoint in a folder, in ascending epoch order.

    Args:
        pt_folder_path: Directory holding the ``.pt`` checkpoints. Defaults to
            the run folder this notebook was written for.
        csv_file: Summary CSV that receives one row per checkpoint. Rows are
            appended if the file already exists.
    """
    numbered_checkpoints = []  # (epoch number, full path)
    for filename in os.listdir(pt_folder_path):
        if not filename.endswith('.pt'):
            continue
        checkpoint_path = os.path.join(pt_folder_path, filename)
        try:
            epoch_num = extract_epoch_num(checkpoint_path)
        except (IndexError, ValueError):
            # A name without an integer epoch (e.g. "epoch_latest.pt") would
            # otherwise crash the sort key below; report it and move on.
            print(f"Skipping checkpoint without a numeric epoch: {filename}")
            continue
        numbered_checkpoints.append((epoch_num, checkpoint_path))

    # Evaluate in ascending epoch order.
    numbered_checkpoints.sort(key=lambda pair: pair[0])
    for _, pretrained_model in numbered_checkpoints:
        process_single_pretrained_model(pretrained_model, csv_file)


# Run the benchmark over the checkpoints when this cell executes.
main()