In [1]:
import numpy as np
import pandas as pd
import json, os, time, pickle
from modules.llm import LlavaModel, LlamaModel
from modules.prompts import CoTPrompt
from dotenv import load_dotenv; load_dotenv()

# Dataset locations used by the feature-generation cells further down.
caltech_images_path = '../database/Caltech/'
caltech_class_meta_path = '../data/caltech-101/meta/caltech_220_images.json'

# Prompt object; it is handed to ImageFeatures as `PromptSys` later on.
CoT = CoTPrompt("CoT")

# API token taken from the process environment (load_dotenv() has already run
# in the import section). The value is None when the variable is not set.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")

# Both model wrappers are built from the same token.
LLaVA = LlavaModel(REPLICATE_API_TOKEN)
Llama = LlamaModel(REPLICATE_API_TOKEN)

### LLM Calls for Generating Additional Information

In [4]:
from modules.gen_features import ImageFeatures
# Parquet file passed to ImageFeatures as its working file.
caltech_image_working_file = "../data/caltech-101/cotPrompt/llava/image_features.parquet"

# Bind the image root, class metadata, working file, LLaVA client and prompt
# object into one ImageFeatures instance.
Cal101Features = ImageFeatures(
    caltech_images_path,
    caltech_class_meta_path,
    caltech_image_working_file,
    llm_model=LLaVA,
    PromptSys=CoT,
)

# Slow step: the saved output of this cell reports roughly 20 seconds per
# image over 202 items.
Cal101Features.gen_info()

Loaded parquet!
Loaded parquet!
1/202
gerenuk_0011 done!
2/202
Image: gerenuk_0010
Time taken per label: 21.26 seconds
--------------------------------------------------
3/202
Image: hawksbill_0085
Time taken per label: 20.03 seconds
--------------------------------------------------
4/202
Image: hawksbill_0017
Time taken per label: 18.75 seconds
--------------------------------------------------
5/202
Image: headphone_0024
Time taken per label: 18.34 seconds
--------------------------------------------------
6/202
Image: headphone_0004
Time taken per label: 22.71 seconds
--------------------------------------------------
7/202
Image: ant_0011
Time taken per label: 21.39 seconds
--------------------------------------------------
8/202
Image: ant_0037
Time taken per label: 24.06 seconds
--------------------------------------------------
9/202
Image: butterfly_0089
Time taken per label: 19.24 seconds
--------------------------------------------------
10/202
Image: butterfly_0038
Time tak

In [3]:
# Display the `img_features` table held by the ImageFeatures instance. The saved
# output shows the columns file_name, label_id, init_pred and img_desc.
# NOTE(review): in the saved output only row 0 has init_pred/img_desc filled in,
# and this cell's execution count (3) is lower than the generation cell's (4) --
# presumably the table was captured before generation finished; confirm and
# re-run before relying on this output.
Cal101Features.img_features

Unnamed: 0,file_name,label_id,init_pred,img_desc
0,gerenuk_0011,gerenuk,Gerenuk,"The main object in the image is a small, brown..."
1,gerenuk_0010,gerenuk,,
2,hawksbill_0085,hawksbill,,
3,hawksbill_0017,hawksbill,,
4,headphone_0024,headphone,,
...,...,...,...,...
197,elephant_0039,elephant,,
198,tick_0011,tick,,
199,tick_0029,tick,,
200,metronome_0031,metronome,,


In [None]:
from modules.gen_features import LabelFeatures
# Parquet file passed to LabelFeatures as its working file.
caltech_label_working_file = "../data/caltech-101/cotPrompt/llava/label_features.parquet"

# Same image root and class metadata as the image-feature step, but driven by
# the Llama client instead of LLaVA.
Cal101_LabelFeatures = LabelFeatures(
    caltech_images_path,
    caltech_class_meta_path,
    caltech_label_working_file,
    llm_model=Llama,
)

# The generation call is intentionally left disabled here:
# Cal101_LabelFeatures.gen_info()

In [None]:
# Cal101_LabelFeatures.label_features

### Encoder

In [None]:
from modules.prompts import DefaultPrompt
from modules.encoder import FeaturesEncoder

# Directory of the encoder checkpoint (the saved output logs
# "Loading model: clip-vit-large-patch14").
model = "../models/clip-vit-large-patch14"

# Image directory used by the encoder. It is deliberately stored under its own
# name instead of rebinding `caltech_images_path`: the feature-generation cells
# above read that global with a different value ('../database/Caltech/'), and
# overwriting it here would make re-running them after this cell silently use
# another directory.
caltech_category_images_path = '../database/Caltech/caltech-101/101_ObjectCategories'

# Working files written by the feature-generation cells. They are repeated here
# with the same values so this cell can be run without those cells.
caltech_image_working_file = "../data/caltech-101/cotPrompt/llava/image_features.parquet"
caltech_label_working_file = "../data/caltech-101/cotPrompt/llava/label_features.parquet"

# Pickle paths handed to the encoder for the image and label encodings; the
# prediction section loads the same two paths.
encoding_images_path = "../data/caltech-101/cotPrompt/llava/image_features.pkl"
encoding_labels_path = "../data/caltech-101/cotPrompt/llava/label_features.pkl"

FE = FeaturesEncoder(
    caltech_category_images_path,
    encoding_images_path,
    encoding_labels_path,
    img_file_type='jpg',
    model=model,
)

# NOTE(review): DefPrompt is not referenced by any other visible cell -- kept
# because the constructor may have side effects; confirm whether it is needed.
DefPrompt = DefaultPrompt('Default')

# Template passed to FE.encode_labels in the following cell.
human_design_prompt = "A photo of {}"

# Encode the images listed in the image working file. Label encoding is run in
# its own cell right after this one.
FE.encode_images(caltech_image_working_file)


Loading model: clip-vit-large-patch14
Creating embedding dict...
1/202
Time taken per label: 7.03 seconds
--------------------------------------------------
2/202
Time taken per label: 0.53 seconds
--------------------------------------------------
3/202
Time taken per label: 0.48 seconds
--------------------------------------------------
4/202
Time taken per label: 0.48 seconds
--------------------------------------------------
5/202
Time taken per label: 0.46 seconds
--------------------------------------------------
6/202
Time taken per label: 0.46 seconds
--------------------------------------------------
7/202
Time taken per label: 0.46 seconds
--------------------------------------------------
8/202
Time taken per label: 0.44 seconds
--------------------------------------------------
9/202
Time taken per label: 0.49 seconds
--------------------------------------------------
10/202
Time taken per label: 0.5 seconds
--------------------------------------------------
11/202
Time tak

In [None]:
# Run label encoding, passing the label working file and the
# `human_design_prompt` template ("A photo of {}") defined in the encoder cell.
# NOTE(review): the call that generates the label working file is commented out
# earlier in this notebook, so this presumably relies on a parquet produced by a
# previous run -- confirm the file exists before executing.
FE.encode_labels(caltech_label_working_file, human_design_prompt)

### Prediction

In [10]:
import pickle
from modules.classifier import ImageClassifier
# Classification inputs: the pickled image and label encodings.
encoded_image_file = "../data/caltech-101/cotPrompt/llava/image_features.pkl"
encoded_text_file  = "../data/caltech-101/cotPrompt/llava/label_features.pkl"


def _load_pickle(path):
    """Open `path` in binary mode and return the unpickled object."""
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only point this at files produced by this pipeline.
    with open(path, "rb") as handle:
        return pickle.load(handle)


img_features = _load_pickle(encoded_image_file)
label_features = _load_pickle(encoded_text_file)

In [11]:
import pandas as pd
# One row of metrics per image-feature variant; the row label (index) is the
# variant name passed to ImageClassifier as `ifeature`.
acc_df = pd.DataFrame(columns=['accuracy', 'precision', 'recall', 'f1'])
for X in ['X_if', 'X_df', 'X_pf', 'X_q']:
    # Same classifier mode and label encodings every time; only the image
    # feature variant changes between iterations.
    I4P = ImageClassifier(label_features, mode='M4', img_features=img_features, ifeature=X)
    print("="*50)
    df = I4P.classify()
    accuracy, precision, recall, f1 = I4P.evaluation(df)
    acc_df.loc[X] = [accuracy, precision, recall, f1]

save_path = "../data/accuracies/caltech_llava_cot.csv"
# Write the index as well: it is the only place the variant names
# (X_if, X_df, X_pf, X_q) are stored, so exporting with index=False produced a
# CSV whose rows could no longer be told apart.
acc_df.to_csv(save_path, index=True, index_label="ifeature")
acc_df.head()

Using model M4: Fused Features Embedding
Using Image Feature: Encoded Image X_if
Accuracy: 0.9208
Precision: 0.9208
Recall: 0.9125
F1-score: 0.9073
Using model M4: Fused Features Embedding
Using Image Feature: Encoded Image Description X_df
Accuracy: 0.6832
Precision: 0.6832
Recall: 0.6335
F1-score: 0.6303
Using model M4: Fused Features Embedding
Using Image Feature: Encoded Init Prediction X_pf
Accuracy: 0.7178
Precision: 0.7178
Recall: 0.6541
F1-score: 0.6615
Using model M4: Fused Features Embedding
Using Image Feature: Encoded Fused Image Feature X_q
Accuracy: 0.7525
Precision: 0.7525
Recall: 0.7078
F1-score: 0.7042


Unnamed: 0,accuracy,precision,recall,f1
X_if,0.920792,0.920792,0.912541,0.907261
X_df,0.683168,0.683168,0.633534,0.630344
X_pf,0.717822,0.717822,0.654125,0.661528
X_q,0.752475,0.752475,0.707779,0.704212
