In [7]:
def extract_species(file_path):
    species = []
    with open(file_path, 'r') as f:
        for line in f:
            columns = line.strip().split('\t')
            if len(columns) > 1:
                species.append(columns[1])
    species = list(set(species))
    return species

caption_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/captions.txt'
species = extract_species(caption_path)

species

['brydes whale',
 'globis',
 'long finned pilot whale',
 'fin whale',
 'spotted dolphin',
 'frasiers dolphin',
 'pilot whale',
 'blue whale',
 'beluga',
 'rough toothed dolphin',
 'cuviers beaked whale',
 'killer whale',
 'melon headed whale',
 'bottlenose dolphin',
 'gray whale',
 'spinner dolphin',
 'common dolphin',
 'white sided dolphin',
 'sei whale',
 'southern right whale',
 'short finned pilot whale',
 'pygmy killer whale',
 'pantropic spotted dolphin',
 'commersons dolphin',
 'false killer whale',
 'bottlenose dolpin',
 'dusky dolphin',
 'minke whale',
 'humpback whale',
 'kiler whale']

In [6]:
#classification

import torch, os
from PIL import Image
import numpy as np
import open_clip

pretrained = '/mnt/g/Logtemp/open_clip/Whale_Dolphin_Identification/2023_08_18-21_14_32-model_coca_ViT-L-14-lr_1e-06-b_32-j_4-p_amp/checkpoints/epoch_16.pt'
# model, _, preprocess = open_clip.create_model_and_transforms('coca_ViT-L-14', pretrained='mscoco_finetuned_laion2B-s13B-b90k')
model, _, preprocess = open_clip.create_model_and_transforms(
  model_name="coca_ViT-L-14",
  pretrained=pretrained, #mscoco_finetuned_laion2B-s13B-b90k
)

tokenizer = open_clip.get_tokenizer('coca_ViT-L-14')

image_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/images/false_killer_whale_bdea86a4d11fa9.jpg'
img = Image.open(image_path).convert('RGB')
img = img.resize((224, 224), Image.Resampling.LANCZOS)
image = preprocess(img).unsqueeze(0)
text = tokenizer(species)

filename = os.path.splitext(os.path.basename(image_path))[0]
ground_truth = "_".join(filename.split('_')[:-1]).replace("_", " ")

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Construct the dictionary
    class_predict_dict = dict(zip(species, text_probs))
    # Extract the key with the largest value
    class_predict = max(class_predict_dict, key=class_predict_dict.get)
    
    if ground_truth == class_predict:
        predicted_correct = 1
    else:
        predicted_correct = 0

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]
print("Ground truth:", ground_truth)  # prints: 0
print("Predicted class:", class_predict)  # prints: 0
print("Predicted correct:", predicted_correct)  # prints: 1

  from .autonotebook import tqdm as notebook_tqdm


Label probs: tensor([[2.4179e-05, 3.1715e-02, 1.0379e-03, 6.5446e-06, 4.6195e-05, 2.4105e-05,
         7.0505e-06, 2.9267e-08, 2.7521e-07, 4.8915e-05, 7.4288e-05, 7.7751e-06,
         1.3812e-04, 3.5976e-05, 5.6864e-08, 3.4625e-05, 2.3820e-04, 3.1189e-07,
         9.3151e-06, 1.7222e-06, 9.1095e-07, 1.0159e-03, 4.8858e-03, 2.0357e-04,
         9.6033e-01, 5.9075e-06, 3.7162e-05, 1.4038e-05, 7.8107e-06, 4.4372e-05]])
Ground truth: false killer whale
Predicted class: brydes whale
Predicted correct: 0


In [10]:
# Batch test for Direct Classification benchmark
import torch, os, open_clip
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm

def extract_species(file_path):
    species = []
    with open(file_path, 'r') as f:
        for line in f:
            columns = line.strip().split('\t')
            if len(columns) > 1:
                species.append(columns[1])
    species = list(set(species))
    return species

caption_path = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/captions.txt'
species = extract_species(caption_path)


def process_single_epoch(pretrained_model, sentences):
    img_dir = '/mnt/g/Datasets/Whale_Dolphin_Identification/Square/test/images/'
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name="coca_ViT-L-14",
        pretrained=pretrained_model,
    )
    pt_dir = os.path.dirname(pretrained_model)
    epoch_number = os.path.basename(pretrained_model).split('epoch_')[1].split('.pt')[0]
    df_all = pd.DataFrame()

    # Prepare the tokenizer for sentences
    tokenizer = open_clip.get_tokenizer('coca_ViT-L-14')
    text = tokenizer(sentences)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    def read_single_image(img_dir, img_path, sentences, text_features):
        #get the ground truth value
        # Extract the filename without extension
        filename = os.path.splitext(os.path.basename(img_path))[0]
        image_path = os.path.join(img_dir, img_path)
        ground_truth = "_".join(filename.split('_')[:-1]).replace("_", " ")
        
        # load a sample image
        image = preprocess(Image.open(image_path)).unsqueeze(0)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            text_probs = text_probs.cpu().tolist()[0]

        # Construct the dictionary
        class_predict_dict = dict(zip(sentences, text_probs))
        # Extract the key with the largest value
        class_predict = max(class_predict_dict, key=class_predict_dict.get)
        
        if ground_truth == class_predict:
            predicted_correct = 1
        else:
            predicted_correct = 0
        
        pred_dict = {'img_path': [img_path], 'class_predict_dict': [class_predict_dict], 'ground_truth': [ground_truth], 'class_predict': [class_predict], 'predicted_correct': [predicted_correct]}
        df_single = pd.DataFrame(pred_dict)
        
        return df_single

    all_images = os.listdir(img_dir)
    for img_path in tqdm(all_images, desc=f"epoch_{epoch_number}: Processing images", unit="image"):
        df = read_single_image(img_dir, img_path, sentences, text_features)
        df_all = pd.concat([df_all, df]).reset_index(drop=True)

    #save df_all to csv
    csv_file = f"temp/epoch_{epoch_number}.csv"
    df_all.to_csv(csv_file, index=False)

    total_rows = len(df_all)
    # Count the occurrences of '1' in the 'predicted_correct' column
    count_ones = df_all['predicted_correct'].sum()
    epoch_filename = os.path.splitext(os.path.basename(pretrained_model))[0]
    predicted_correct_pct = count_ones /total_rows *100
    print(f"Epoch_file: {epoch_filename}. Total images: {total_rows}, predicted_correct': {count_ones}, predicted_correct (%)': {predicted_correct_pct}%")
    
    epoch_dict = {'pt_dir': [pt_dir],
                'epoch_filename': [epoch_filename], 
                'sentences': [sentences], 
                'total_rows': [total_rows], 
                'predicted_correct': [count_ones], 
                'predicted_correct (%)': [predicted_correct_pct]}
    df_single_epoch = pd.DataFrame(epoch_dict)
    return df_single_epoch


def process_single_pretrained_model(pretrained_model, csv_file):
    df_single_pretrained_model = pd.DataFrame()
    df = process_single_epoch(pretrained_model, species)
    df_single_pretrained_model = pd.concat([df_single_pretrained_model, df]).reset_index(drop=True)
    save_to_csv(df_single_pretrained_model, csv_file)
    
def extract_epoch_num(filepath):
    # Split the filename from the path
    filename = os.path.basename(filepath) #filepath.split('/')[-1]
    # Extract the number between "epoch_" and ".pt"
    epoch_num = int(filename.split('epoch_')[1].split('.pt')[0])
    return epoch_num

def save_to_csv(df, csv_file):
    if not os.path.exists(csv_file):
        df.to_csv(csv_file, index=False)
    else:
        # If the CSV file already exists, append without header
        df.to_csv(csv_file, mode='a', header=False, index=False)

def main():
    pt_folder_path = '/mnt/g/Logtemp/open_clip/Whale_Dolphin_Identification/2023_08_18-21_14_32-model_coca_ViT-L-14-lr_1e-06-b_32-j_4-p_amp/checkpoints/'
    pt_files = [os.path.join(pt_folder_path, filename) for filename in os.listdir(pt_folder_path) if filename.endswith('.pt')]
    #filter pt_files with low epoch number
    pt_files = [pt_file for pt_file in pt_files if extract_epoch_num(pt_file) >= 41]
    # Sort the list
    sorted_pt_files = sorted(pt_files, key=extract_epoch_num)
    csv_file = "benchmark_all_epochs.csv"
    for pretrained_model in sorted_pt_files:
        process_single_pretrained_model(pretrained_model, csv_file)


main()

epoch_34: Processing images: 100%|██████████| 5115/5115 [48:43<00:00,  1.75image/s]


Epoch_file: epoch_34. Total images: 5115, predicted_correct': 5030, predicted_correct (%)': 98.33822091886609%


epoch_35: Processing images: 100%|██████████| 5115/5115 [52:26<00:00,  1.63image/s]  


Epoch_file: epoch_35. Total images: 5115, predicted_correct': 5026, predicted_correct (%)': 98.26001955034212%


epoch_36: Processing images: 100%|██████████| 5115/5115 [50:31<00:00,  1.69image/s]  


Epoch_file: epoch_36. Total images: 5115, predicted_correct': 5026, predicted_correct (%)': 98.26001955034212%


epoch_37: Processing images: 100%|██████████| 5115/5115 [40:55<00:00,  2.08image/s]


Epoch_file: epoch_37. Total images: 5115, predicted_correct': 5005, predicted_correct (%)': 97.84946236559139%


epoch_38: Processing images: 100%|██████████| 5115/5115 [40:33<00:00,  2.10image/s]


Epoch_file: epoch_38. Total images: 5115, predicted_correct': 5021, predicted_correct (%)': 98.16226783968719%


epoch_39: Processing images: 100%|██████████| 5115/5115 [40:26<00:00,  2.11image/s]


Epoch_file: epoch_39. Total images: 5115, predicted_correct': 5024, predicted_correct (%)': 98.22091886608015%


epoch_40: Processing images: 100%|██████████| 5115/5115 [41:06<00:00,  2.07image/s]


Epoch_file: epoch_40. Total images: 5115, predicted_correct': 5013, predicted_correct (%)': 98.0058651026393%


epoch_41: Processing images:   2%|▏         | 89/5115 [00:54<51:14,  1.63image/s]


KeyboardInterrupt: 