In [2]:

import pandas as pd
csv_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/train_new3.csv'
df = pd.read_csv(csv_path)
species = list(set(df['species'].tolist()))
species = [a.replace('_', ' ') for a in species]
species

['globis',
 'kiler whale',
 'spotted dolphin',
 'frasiers dolphin',
 'long finned pilot whale',
 'pantropic spotted dolphin',
 'killer whale',
 'rough toothed dolphin',
 'false killer whale',
 'short finned pilot whale',
 'cuviers beaked whale',
 'bottlenose dolphin',
 'commersons dolphin',
 'melon headed whale',
 'beluga',
 'white sided dolphin',
 'humpback whale',
 'common dolphin',
 'bottlenose dolpin',
 'sei whale',
 'pygmy killer whale',
 'pilot whale',
 'dusky dolphin',
 'minke whale',
 'gray whale',
 'fin whale',
 'brydes whale',
 'spinner dolphin',
 'southern right whale',
 'blue whale']

In [3]:
#classification

import torch, os
from PIL import Image
import numpy as np
import open_clip

pretrained = '/mnt/g/Logtemp/open_clip/Whale_Dolphin_Identification/2023_09_26-22_11_30-model_coca_ViT-L-14-lr_1e-06-b_32-j_4-p_amp/checkpoints/epoch_61.pt'
# model, _, preprocess = open_clip.create_model_and_transforms('coca_ViT-L-14', pretrained='mscoco_finetuned_laion2B-s13B-b90k')
model, _, preprocess = open_clip.create_model_and_transforms(
  model_name="coca_ViT-L-14",
  pretrained=pretrained, #mscoco_finetuned_laion2B-s13B-b90k
  color_image=True,
)

tokenizer = open_clip.get_tokenizer('coca_ViT-L-14')

image_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/images/false_killer_whale_bdea86a4d11fa9.jpg'
img = Image.open(image_path).convert('RGB')
img = img.resize((224, 224), Image.Resampling.LANCZOS)
image = preprocess(img).unsqueeze(0)
text = tokenizer(species)

filename = os.path.splitext(os.path.basename(image_path))[0]
ground_truth = "_".join(filename.split('_')[:-1]).replace("_", " ")

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Construct the dictionary
    class_predict_dict = dict(zip(species, text_probs))
    # Extract the key with the largest value
    class_predict = max(class_predict_dict, key=class_predict_dict.get)
    
    if ground_truth == class_predict:
        predicted_correct = 1
    else:
        predicted_correct = 0

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]
print("Ground truth:", ground_truth)  # prints: 0
print("Predicted class:", class_predict)  # prints: 0
print("Predicted correct:", predicted_correct)  # prints: 1

TODO: color_image = True
TODO: color_image = True
Label probs: tensor([[4.0117e-03, 9.9646e-03, 8.8292e-08, 2.8080e-04, 1.3491e-02, 2.4473e-05,
         1.0136e-03, 2.0245e-02, 2.2700e-01, 4.9068e-03, 5.2500e-05, 5.4739e-07,
         1.2419e-03, 2.0707e-04, 2.9619e-05, 6.1123e-12, 6.0200e-04, 2.1139e-08,
         2.0880e-06, 4.5062e-05, 7.0411e-01, 1.2382e-02, 5.9633e-10, 1.1547e-07,
         2.5476e-06, 3.7558e-05, 3.0966e-07, 4.8484e-05, 3.0047e-04, 1.9472e-08]])
Ground truth: false killer whale
Predicted class: globis
Predicted correct: 0


In [2]:
# Batch test for Direct Classification benchmark
import torch, os, open_clip
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm

csv_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/train_new3.csv'
df = pd.read_csv(csv_path)
species = list(set(df['species'].tolist()))
species = [a.replace('_', ' ') for a in species]


def process_single_epoch(pretrained_model, sentences):
    img_dir = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/images/'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name="coca_ViT-L-14",
        pretrained=pretrained_model,
        color_image=True,
    )
    pt_dir = os.path.dirname(pretrained_model)
    epoch_number = os.path.basename(pretrained_model).split('epoch_')[1].split('.pt')[0]
    df_all = pd.DataFrame()

    # Prepare the tokenizer for sentences
    tokenizer = open_clip.get_tokenizer('coca_ViT-L-14')
    text = tokenizer(sentences)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    def read_single_image(img_dir, img_path, sentences, text_features):
        #get the ground truth value
        # Extract the filename without extension
        filename = os.path.splitext(os.path.basename(img_path))[0]
        image_path = os.path.join(img_dir, img_path)
        ground_truth = "_".join(filename.split('_')[:-1]).replace("_", " ")
        
        # load a sample image
        image = preprocess(Image.open(image_path)).unsqueeze(0)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            text_probs = text_probs.cpu().tolist()[0]

        # Construct the dictionary
        class_predict_dict = dict(zip(sentences, text_probs))
        # Extract the key with the largest value
        class_predict = max(class_predict_dict, key=class_predict_dict.get)
        
        if ground_truth == class_predict:
            predicted_correct = 1
        else:
            predicted_correct = 0
        
        pred_dict = {'img_path': [img_path], 'class_predict_dict': [class_predict_dict], 'ground_truth': [ground_truth], 'class_predict': [class_predict], 'predicted_correct': [predicted_correct]}
        df_single = pd.DataFrame(pred_dict)
        
        return df_single

    all_images = os.listdir(img_dir)
    for img_path in tqdm(all_images, desc=f"epoch_{epoch_number}: Processing images", unit="image"):
        df = read_single_image(img_dir, img_path, sentences, text_features)
        df_all = pd.concat([df_all, df]).reset_index(drop=True)

    #save df_all to csv
    csv_file = f"temp/epoch_{epoch_number}.csv"
    df_all.to_csv(csv_file, index=False)

    total_rows = len(df_all)
    # Count the occurrences of '1' in the 'predicted_correct' column
    count_ones = df_all['predicted_correct'].sum()
    epoch_filename = os.path.splitext(os.path.basename(pretrained_model))[0]
    predicted_correct_pct = count_ones /total_rows *100
    print(f"Epoch_file: {epoch_filename}. Total images: {total_rows}, predicted_correct': {count_ones}, predicted_correct (%)': {predicted_correct_pct}%")
    
    epoch_dict = {'pt_dir': [pt_dir],
                'epoch_filename': [epoch_filename], 
                'sentences': [sentences], 
                'total_rows': [total_rows], 
                'predicted_correct': [count_ones], 
                'predicted_correct (%)': [predicted_correct_pct]}
    df_single_epoch = pd.DataFrame(epoch_dict)
    return df_single_epoch


def process_single_pretrained_model(pretrained_model, csv_file):
    df_single_pretrained_model = pd.DataFrame()
    df = process_single_epoch(pretrained_model, species)
    df_single_pretrained_model = pd.concat([df_single_pretrained_model, df]).reset_index(drop=True)
    save_to_csv(df_single_pretrained_model, csv_file)
    
def extract_epoch_num(filepath):
    # Split the filename from the path
    filename = os.path.basename(filepath) #filepath.split('/')[-1]
    # Extract the number between "epoch_" and ".pt"
    epoch_num = int(filename.split('epoch_')[1].split('.pt')[0])
    return epoch_num

def save_to_csv(df, csv_file):
    if not os.path.exists(csv_file):
        df.to_csv(csv_file, index=False)
    else:
        # If the CSV file already exists, append without header
        df.to_csv(csv_file, mode='a', header=False, index=False)

def main():
    pt_folder_path = '/mnt/g/Logtemp/open_clip/Whale_Dolphin_Identification/2023_09_28-21_57_51-model_coca_ViT-L-14-lr_2e-05-b_32-j_4-p_amp/checkpoints/'
    pt_files = [os.path.join(pt_folder_path, filename) for filename in os.listdir(pt_folder_path) if filename.endswith('.pt')]
    #filter pt_files with low epoch number
    pt_files = [pt_file for pt_file in pt_files if extract_epoch_num(pt_file) >= 69]
    # Sort the list
    sorted_pt_files = sorted(pt_files, key=extract_epoch_num, reverse=True) #
    csv_file = "benchmark_all_epochs.csv"
    for pretrained_model in sorted_pt_files:
        process_single_pretrained_model(pretrained_model, csv_file)


main()

epoch_100: Processing images:  18%|█▊        | 946/5115 [08:45<38:37,  1.80image/s]


KeyboardInterrupt: 