### **Install Required Libraries**

In [1]:
# Install necessary libraries
# !pip install --upgrade transformers accelerate sentencepiece safetensors torch Pillow pandas huggingface_hub
# !pip install git+https://github.com/haotian-liu/LLaVA.git
# !pip install bitsandbytes

### **Import Libraries**

In [1]:
import os
import torch
from PIL import Image
import pandas as pd
import random
import string
from pathlib import Path
from huggingface_hub import login
import warnings
import gc
import logging
import traceback
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline


In [2]:
from huggingface_hub import login

# add huggingface token
login(token='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')

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\Roshan\.cache\huggingface\token
Login successful


### **Set Up the LLava Model and Processor**

In [3]:
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [4]:
# torch.cuda.empty_cache()
# del image_to_text_pipe
# del text_gen_pipe

# Memory currently allocated by tensors
allocated_memory = torch.cuda.memory_allocated(0) / (1024 ** 2)  # Convert to MB
print(f"Allocated Memory: {allocated_memory:.2f} MB")

# Memory reserved by PyTorch for caching purposes
reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 2)  # Convert to MB
print(f"Reserved Memory: {reserved_memory:.2f} MB")

Allocated Memory: 0.00 MB
Reserved Memory: 0.00 MB


In [5]:
import logging
import traceback

from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
from PIL import Image
import requests

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device.upper()}")

# Define model ID
model_id = "YouLiXiya/tinyllava-v1.0-1.1b-hf"

# Load the model
try:
    logging.info(f"Loading LLava model: {model_id}")
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).to(device)
    logging.info("LLava model loaded successfully.")
except Exception as e:
    logging.error(f"Error loading LLava model: {e}")
    logging.error(traceback.format_exc())

# Load the processor
try:
    logging.info(f"Loading processor for model: {model_id}")
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    logging.info("Processor loaded successfully.")
except Exception as e:
    logging.error(f"Error loading processor: {e}")
    logging.error(traceback.format_exc())


2024-10-03 15:12:21,132 - INFO - Using device: CUDA
2024-10-03 15:12:21,134 - INFO - Loading LLava model: YouLiXiya/tinyllava-v1.0-1.1b-hf
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
2024-10-03 15:12:25,086 - INFO - LLava model loaded successfully.
2024-10-03 15:12:25,087 - INFO - Loading processor for model: YouLiXiya/tinyllava-v1.0-1.1b-hf
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2024-10-03 15:12:26,983 - INFO - Processor loaded successfully.


In [6]:
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Version:", torch.version.cuda)
print("Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA Device")

PyTorch Version: 2.1.2+cu118
CUDA Available: True
CUDA Version: 11.8
Device Name: NVIDIA GeForce GTX 980


**Note:** Replace `'/path/to/llava-13b-v0'` with the actual path to your LLava model.

### **Define the LLava Interaction Function**

In [16]:
import logging
import traceback
import gc
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import os

# log setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# primary function
def llava_classify_and_respond_all(model, processor, image_path, questions, device="cuda"):
    record = {} 
    try:
        logging.info(f"Processing image: {image_path}")
        
        # load image
        image = Image.open(image_path).convert('RGB')
        logging.info("Image loaded successfully.")

        # define classifcation prompt
        classification_prompt = (
            "USER: <image>\n"
            "What is the english name of the species of bird that is most similar to the one in the image?\n"
            "ASSISTANT:"
        )
        logging.info(f"Using classification prompt:\n{classification_prompt}")  # **Change 1**

        # process prompt with image
        classification_inputs = processor(
            classification_prompt, image, return_tensors="pt"
        ).to(device)
        logging.info("Classification inputs processed successfully.")

        # generate
        classification_ids = model.generate(
            **classification_inputs,
            max_new_tokens=200,
            do_sample=False
        )

        # decode 
        classification_text = processor.decode(
            classification_ids[0], skip_special_tokens=True
        ).strip()
        logging.info(f"Generated classification:\n{classification_text}")

        # extract only assistant's answer
        if "ASSISTANT:" in classification_text:
            predicted_class = classification_text.split("ASSISTANT:")[-1].strip()
        else:
            predicted_class = classification_text
        logging.info(f"Predicted classification: {predicted_class}")
        record['Predicted Class'] = predicted_class

        # answer each question with the context of the predicted class
        for i, question in enumerate(questions, 1):
            # Construct the prompt for the current question
            qa_prompt = (
                f"USER: <image>\n"
                f"What is the english name of the species of bird that is most similar to the one in the image?\n"
                f"ASSISTANT: {predicted_class}\n"
                f"Question {i}: {question}\n"
                f"ASSISTANT:"
            )
            logging.info(f"Using prompt for Q{i}:\n{qa_prompt}")

            # process prompt with the image
            qa_inputs = processor(
                qa_prompt, image, return_tensors="pt"
            ).to(device)
            logging.info(f"Inputs for Q{i} processed successfully.")

            # generate
            answer_ids = model.generate(
                **qa_inputs,
                max_new_tokens=200,
                do_sample=False
            )

            # decode
            answer_text = processor.decode(
                answer_ids[0], skip_special_tokens=True
            ).strip()
            logging.info(f"Generated response for Q{i}:\n{answer_text}")  # **Change 4**

            # extract only assistant's answer
            if "ASSISTANT:" in answer_text:
                answer_clean = answer_text.split("ASSISTANT:")[-1].strip()
            else:
                answer_clean = answer_text
            logging.info(f"Cleaned Generated Answer for Q{i}:\n{answer_clean}")  # **Change 5**

            # add answer to record
            record[f'Answer {i}'] = answer_clean

            # log Q&A
            logging.info(f"Question {i}: {question}")  # **Change 6**
            logging.info(f"Answer {i}: {answer_clean}\n")  # **Change 7**

        return record

    except Exception as e:
        logging.error(f"Error processing {os.path.basename(image_path)}: {e}")
        logging.error(traceback.format_exc())
        return None

    finally:
        # safety for UnboundLocalError
        variables = [
            'image', 'classification_inputs', 'classification_ids',
            'classification_text', 'qa_inputs', 'answer_ids', 'answer_text'
        ]
        for var in variables:
            if var in locals():
                del locals()[var]
        gc.collect()
        torch.cuda.empty_cache()


### **Define the 10 Explainability Questions**

In [14]:
# Define the 10 Explainability Questions with the initial classificaiton question as the first question
questions = [
    "What characteristics make you think that this is the species you identified?",
    "Which physical features distinguish this species from other similar species?",
    "Can you describe any unique color patterns or markings that helped in your identification?",
    "What habitat or environment is this species typically found in, and does the image reflect that?",
    "Are there any behaviors or poses characteristic of this species evident in the image?",
    "How confident are you in your identification, and what factors contribute to your confidence level?",
    "Could this animal be mistaken for another species? If so, which one and why?",
    "What anatomical features (e.g., beak shape, fin structure) were most significant in your identification?",
    "Does the size or scale of the animal in the image influence your identification? How?",
    "Explain the step-by-step reasoning process you used to determine the species."
]


**Note:** Replace `'path_to_your_images_directory'` with the actual path to your images.

### **Set Up Image Directory**

### **Process images and collect to csv**

In [25]:
from tqdm import tqdm  # For progress bar
import logging
import traceback
import os
import pandas as pd

#directory containing images and output csv file path
image_dir = 'images' 
csv_file = os.path.abspath('llava_responses.csv')

# headers
csv_headers = ['Original File Name', 'Actual Class', 'Predicted Class'] + questions

# initialize the CSV with headers if it doesn't exist
file_exists = os.path.isfile(csv_file)
if not file_exists:
    df_init = pd.DataFrame(columns=csv_headers)
    df_init.to_csv(csv_file, index=False, encoding='utf-8-sig')
    logging.info(f"CSV headers written to {csv_file}")

# loop through each subfolder (bird species) in the images directory
for subdir in os.listdir(image_dir):
    subdir_path = os.path.join(image_dir, subdir)
    if os.path.isdir(subdir_path):
        # derive class from parent directory name e.g. (Cattle-Egret -> Cattle Egret)
        actual_class = subdir.replace('-', ' ').replace('_', ' ').strip()
        logging.info(f"Processing subdirectory: {subdir_path} as class: {actual_class}")

        # get images
        image_files = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
        logging.info(f"Found {len(image_files)} images in {subdir_path}")

        # main iteration using iterate using tqdm
        for filename in tqdm(image_files, desc=f"Processing {actual_class}"):
            try:
                original_path = os.path.join(subdir_path, filename)

                # CALL PRIMARY FUNCTION
                record = llava_classify_and_respond_all(
                    model=model,
                    processor=processor,
                    image_path=original_path,
                    questions=questions,
                    device=device
                )

                if record is None:
                    logging.warning(f"Skipping image due to processing error: {filename}")
                    continue  # Skip this image due to processing error

                # partially prepare response data
                row = {
                    'Original File Name': filename,
                    'Actual Class': actual_class,
                    'Predicted Class': record.get('Predicted Class', '')
                }

                # add remaining answers from question asnwers
                for i, question in enumerate(questions, 1):
                    answer_key = f'Answer {i}'
                    row[question] = record.get(answer_key, '')

                # convert to df and append
                df_row = pd.DataFrame([row])
                df_row.to_csv(csv_file, mode='a', header=False, index=False, encoding='utf-8-sig')
                logging.info(f"Written data to CSV for image: {filename}")

            except Exception as e:
                logging.error(f"Error processing {filename}: {e}")
                logging.error(traceback.format_exc())


2024-10-03 15:42:01,943 - INFO - CSV headers written to C:\Users\Roshan\Documents\GitHub\explainability\llava_responses.csv
2024-10-03 15:42:01,945 - INFO - Processing subdirectory: images\Asian-Green-Bee-Eater as class: Asian Green Bee Eater
2024-10-03 15:42:01,946 - INFO - Found 10 images in images\Asian-Green-Bee-Eater
Processing Asian Green Bee Eater:   0%|                                                         | 0/10 [00:00<?, ?it/s]2024-10-03 15:42:01,949 - INFO - Processing image: images\Asian-Green-Bee-Eater\Brown-Headed-Barbet_1.jpg
2024-10-03 15:42:01,955 - INFO - Image loaded successfully.
2024-10-03 15:42:01,956 - INFO - Using classification prompt:
USER: <image>
What is the english name of the species of bird that is most similar to the one in the image?
ASSISTANT:
2024-10-03 15:42:02,000 - INFO - Classification inputs processed successfully.
2024-10-03 15:42:04,203 - INFO - Generated classification:
USER:  
What is the english name of the species of bird that is most sim

KeyboardInterrupt: 

### **Test process single image**

In [17]:
import os

image_dir = 'images' 

test_image_path = os.path.join(image_dir, 'Cattle-Egret', 'Cattle-Egret_1.jpg')  # Update as needed

if not os.path.exists(test_image_path):
    print(f"Test image not found at {test_image_path}")
else:
    test_record = llava_classify_and_respond_all(
        model=model,
        processor=processor,
        image_path=test_image_path,
        questions=questions,
        device=device
    )
    
    if test_record:
        test_record['Original File Name'] = 'Cattle-Egret_1.jpg'
        test_record['Actual Class'] = 'Cattle Egret'
        test_record['Randomized File Name'] = 'TestRandomName.jpg'  # Placeholder, not used further

        ordered_test_record = {
            'Original File Name': test_record.pop('Original File Name'),
            'Randomized File Name': test_record.pop('Randomized File Name'),
            'Actual Class': test_record.pop('Actual Class'),
            'Predicted Class': test_record.pop('Predicted Class')
        }
        # add answers to ordered record
        for i in range(1, 11):
            ordered_test_record[f'Answer {i}'] = test_record.pop(f'Answer {i}')

        # print Q&A
        print(f"Predicted Class: {ordered_test_record['Predicted Class']}\n")
        for i in range(1, 11):
            question = questions[i-1]
            answer = ordered_test_record[f'Answer {i}']
            print(f"Question {i}: {question}")
            print(f"Answer {i}: {answer}\n")
    else:
        print("Failed to process the test image.")


2024-10-03 15:18:34,987 - INFO - Processing image: images\Cattle-Egret\Cattle-Egret_1.jpg
2024-10-03 15:18:34,998 - INFO - Image loaded successfully.
2024-10-03 15:18:34,999 - INFO - Using classification prompt:
USER: <image>
What is the english name of the species of bird that is most similar to the one in the image?
ASSISTANT:
2024-10-03 15:18:35,041 - INFO - Classification inputs processed successfully.
2024-10-03 15:18:36,858 - INFO - Generated classification:
USER:  
What is the english name of the species of bird that is most similar to the one in the image?
ASSISTANT: The bird in the image is most similar to the white egret.
2024-10-03 15:18:36,859 - INFO - Predicted classification: The bird in the image is most similar to the white egret.
2024-10-03 15:18:36,860 - INFO - Using prompt for Q1:
USER: <image>
Please classify the image below, what species of bird is it?
ASSISTANT: The bird in the image is most similar to the white egret.
Question 1: What characteristics make you thi

Predicted Class: The bird in the image is most similar to the white egret.

Question 1: What characteristics make you think that this is the species you identified?
Answer 1: The bird has a long neck, a white body, and a yellow beak. These features are characteristic of the white egret.

Question 2: Which physical features distinguish this species from other similar species?
Answer 2: The white egret has a long neck and a long beak, which are distinctive features that distinguish it from other similar species.

Question 3: Can you describe any unique color patterns or markings that helped in your identification?
Answer 3: The bird has a white body with a yellow head, which is a distinctive feature. Additionally, it has a long neck and legs, which are also characteristic of the white egret.

Question 4: What habitat or environment is this species typically found in, and does the image reflect that?
Answer 4: The bird is typically found in grassy areas, such as fields and meadows, where 

### **Test the text generation pipeline independently**

In [55]:
# test
test_prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"
test_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"

try:
    test_image = Image.open(requests.get(test_image_url, stream=True).raw).convert('RGB')
    logging.info("Test image loaded successfully.")
    
    prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"
    logging.info(f"Using test prompt: {prompt}")
    
    inputs = processor(prompt, test_image, return_tensors="pt").to(device)
    logging.info("Test inputs processed successfully.")
    
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False
    )
    generated_text = processor.decode(output_ids[0], skip_special_tokens=True).strip()
    logging.info(f"Generated response: {generated_text}")
    
    # extract answer
    if generated_text.startswith(prompt):
        answer = generated_text.replace(prompt, '').strip()
    else:
        answer = generated_text
    logging.info(f"Cleaned Generated Answer: {answer}")
    
    print(f"Generated Answer: {answer}")
    
except Exception as e:
    print(f"Error in LLava pipeline test: {e}")
    logging.error(traceback.format_exc())


In [14]:
# import pprint
# pprint.pprint(ordered_test_record)

In [27]:
import pandas as pd
df = pd.read_csv('llava_responses.csv')


In [28]:
print(df.head())


           Original File Name           Actual Class  \
0   Brown-Headed-Barbet_1.jpg  Asian Green Bee Eater   
1  Brown-Headed-Barbet_12.jpg  Asian Green Bee Eater   
2  Brown-Headed-Barbet_14.jpg  Asian Green Bee Eater   
3  Brown-Headed-Barbet_16.jpg  Asian Green Bee Eater   
4  Brown-Headed-Barbet_25.jpg  Asian Green Bee Eater   

                                     Predicted Class  \
0  The bird in the image is most similar to the h...   
1  The bird in the image is most similar to the g...   
2  The bird in the image is most similar to the g...   
3  The bird in the image is most similar to the g...   
4  The bird in the image is most similar to the g...   

  What characteristics make you think that this is the species you identified?  \
0  The bird is perched on a branch, and it has a ...                             
1  The bird has a green and yellow coloration, wh...                             
2  The bird has a long tail and is perched on a b...                            