In [None]:
from datasets import load_dataset,Dataset,load_from_disk,Image
from collections import defaultdict
import random
import os
import tqdm
import torch
from sae_lens.activation_visualization import load_llava_model,load_sae,generate_with_saev



In [None]:
# ---- Configuration: fill in the empty paths below with your own locations ----
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

# Filesystem locations (intentionally blank; set them before running).
model_path = ""
sae_path = ""
dataset_path = ""
save_path = ""

# Device placement: the LLaVA model and the SAE are loaded on different devices.
device = "cuda:0"
sae_device = "cuda:7"

# Model-input field names; they match the keys returned by `prepare_inputs`.
columns_to_read = ["input_ids", "pixel_values", "attention_mask", "image_sizes"]

# Load the processor and the hooked language model (n_devices=8 — presumably
# sharded across 8 GPUs; confirm against load_llava_model), then the SAE.
processor, hook_language_model = load_llava_model(MODEL_NAME, model_path, device, n_devices=8)
sae = load_sae(sae_path, sae_device)





In [3]:
# Zero-shot classification prompt for the 10-class subset. The model is asked to
# reply with the label index only; the accuracy cell compares the generated
# answer against str(label).
example_prompt = """You are provided with an image and a list of 10 possible labels. Your task is to classify the image by selecting the most appropriate label from the list below:

Labels:
0: "bonnet, poke bonnet"
1: "green mamba"
2: "langur"
3: "Doberman, Doberman pinscher"
4: "gyromitra"
5: "Saluki, gazelle hound"
6: "vacuum, vacuum cleaner"
7: "window screen"
8: "cocktail shaker"
9: "garden spider, Aranea diademata"

Carefully analyze the content of the image and identify which label best describes it. Then, output only the **corresponding number** from the list without any additional text or explanation.
"""
# Dataset saved earlier with `save_to_disk`; later cells read its "image" and
# "label" columns.
local_dataset = load_from_disk(dataset_path)



In [None]:
# Quick look at the data: the label column, then the first image.
print(local_dataset['label'])
# Select the row first so only one image is decoded; `local_dataset['image'][0]`
# may materialize (decode) the entire image column just to keep element 0.
print(local_dataset[0]['image'])

In [None]:
from PIL import Image
system_prompt = " "
user_prompt = 'USER: \n<image> {input}'
assistant_prompt = '\nASSISTANT: {output}'


def prepare_inputs(prompt, image, processor):
    """Turn one (prompt, image) pair into the inputs the LLaVA model consumes.

    The prompt is wrapped in the module-level template: the system prefix, a
    ``USER:`` turn holding the ``<image>`` placeholder, and an empty
    ``ASSISTANT:`` turn for the model to complete. The image is resized to
    336x336 and converted to RGBA before it reaches the image processor.

    Args:
        prompt: Instruction text substituted into ``user_prompt``.
        image: PIL-like image exposing ``resize`` and ``convert``.
        processor: Object exposing ``tokenizer`` and ``image_processor``
            callables (e.g. a Hugging Face multimodal processor).

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and
        ``image_sizes``, copied straight from the tokenizer / image-processor
        outputs (requested with ``return_tensors="pt"``).
    """
    prepared_image = image.resize((336, 336)).convert('RGBA')
    user_turn = user_prompt.format(input=prompt)
    assistant_turn = assistant_prompt.format(output="")

    tokenized = processor.tokenizer(system_prompt + user_turn + assistant_turn, return_tensors="pt")
    pixels = processor.image_processor(images=prepared_image, return_tensors="pt")

    return dict(
        input_ids=tokenized["input_ids"],
        attention_mask=tokenized["attention_mask"],
        pixel_values=pixels["pixel_values"],
        image_sizes=pixels["image_sizes"],
    )

# Build one set of model inputs per dataset example. The generation loop zips
# this list with `local_dataset`, so an empty list (as before) made that loop
# run zero iterations. Note: every example's prepared tensors are held in memory.
input_list = [
    prepare_inputs(example_prompt, example["image"], processor)
    for example in tqdm.tqdm(local_dataset, desc="preparing inputs")
]

# Optional sanity check on a single local image. Leave `image_path` empty to
# inspect the first dataset example instead (Image.open("") would raise).
image_path = ""
if image_path:
    image = Image.open(image_path)
    inputs = prepare_inputs(example_prompt, image, processor)
else:
    inputs = input_list[0]

print(inputs["input_ids"].shape)

In [None]:

# Text file of selected SAE features, one "feature_index,score" pair per line
# (presumably cosine/OSI scores — confirm against the script that writes it).
cosi_file_path = ""


def read_feature_scores(path):
    """Parse a ``key,value`` text file into a list of ``(int, float)`` tuples.

    Blank lines are skipped; previously a blank line made the two-field
    unpacking raise. Any other line without exactly two comma-separated fields
    raises ``ValueError`` naming the file and line.
    """
    scores = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split(",")
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{line_number}: expected 'index,score', got {stripped!r}"
                )
            key, value = fields
            scores.append((int(key), float(value)))
    return scores


osi_list = read_feature_scores(cosi_file_path)
print(osi_list)

In [None]:

# Generation pass. For every (prepared inputs, dataset row) pair: move the
# inputs to the model device, run generate_with_saev restricted to the selected
# SAE features, and record the last generated output plus a per-image record.
output_list = []
image_list = []
for inputs, data in tqdm.tqdm(zip(input_list, local_dataset), total=len(input_list)):
    inputs = {
        key: inputs[key].to(device)
        for key in ("input_ids", "attention_mask", "pixel_values", "image_sizes")
    }
    (
        total_activation_l0_norms_list,
        patch_features_list,
        feature_act_list,
        image_indice,
        output,
    ) = generate_with_saev(
        inputs,
        hook_language_model,
        processor,
        save_path,
        data["image"],
        sae,
        sae_device,
        max_new_tokens=1,
        selected_feature_indices=osi_list,
    )
    output_list.append(output[-1])
    # Only the first L0-norm entry is kept; later cells reshape/validate it as 24x24.
    image_list.append(
        dict(
            image=data["image"].resize((336, 336)),
            label=data["label"],
            activation_l0=total_activation_l0_norms_list[0],
        )
    )

   

In [None]:
# Sanity check: list the fields stored in each per-image record.
print(image_list[0].keys())

In [None]:
from datasets import Dataset, Features, Array3D, Array2D,ClassLabel, Value
import numpy as np
import datasets
from PIL import Image as PILImage

IMAGE_SIDE = 336  # images were resized to 336x336 in the generation loop
PATCH_GRID = 24   # 336 px / 14 px patches -> 24x24 activation map


def _to_rgb_uint8(image, idx):
    """Return ``image`` as a (336, 336, 3) uint8 array, or raise ``ValueError``.

    PIL images go through ``convert("RGB")``, so grayscale, palette ("P"),
    RGBA and CMYK sources all become true RGB. Keeping the first three channels
    of a 4-channel array is only right for RGBA: CMYK came out with wrong
    colours, and palette images were stacked as if their indices were gray levels.
    """
    if isinstance(image, PILImage.Image):
        image = image.convert("RGB")
    img_array = np.array(image)
    if img_array.ndim == 2:
        if img_array.shape != (IMAGE_SIDE, IMAGE_SIDE):
            raise ValueError(f"Image at index {idx} has unexpected shape {img_array.shape}")
        img_array = np.stack([img_array] * 3, axis=-1)
    elif img_array.shape == (IMAGE_SIDE, IMAGE_SIDE, 4):
        img_array = img_array[..., :3]
    if img_array.shape != (IMAGE_SIDE, IMAGE_SIDE, 3):
        raise ValueError(f"Image at index {idx} has shape {img_array.shape}, expected (336, 336, 3)")
    if img_array.dtype != np.uint8:
        img_array = img_array.astype(np.uint8)
    return img_array


def _to_activation_grid(activation, idx):
    """Return ``activation`` as a float32 (24, 24) array, or raise ``ValueError``.

    A flat vector of 576 values is reshaped row-major; any other shape is rejected.
    """
    grid = np.array(activation, dtype=np.float32)
    if grid.ndim == 1:
        if grid.size != PATCH_GRID * PATCH_GRID:
            raise ValueError(f"Activation at index {idx} has size {grid.size}, cannot reshape to (24,24)")
        grid = grid.reshape((PATCH_GRID, PATCH_GRID))
    elif grid.shape != (PATCH_GRID, PATCH_GRID):
        raise ValueError(f"Activation at index {idx} has shape {grid.shape}, expected (24,24)")
    return grid


def _label_feature(labels):
    """Build a ``ClassLabel`` whose class indices line up with the raw labels.

    Integer labels are stored by ``ClassLabel`` as class indices, so the name
    table must cover ``0..max(label)``. Building it only from the observed
    labels breaks as soon as a class is absent (a label would then exceed
    ``num_classes``). Non-integer labels keep the sorted-unique-names behaviour.
    """
    unique_labels = sorted(set(labels))
    if unique_labels and all(isinstance(label, (int, np.integer)) for label in unique_labels):
        return ClassLabel(names=[str(i) for i in range(int(unique_labels[-1]) + 1)])
    return ClassLabel(names=[str(label) for label in unique_labels])


raw_data_dict = {"image": [], "label": [], "activation_cosi": []}
for idx, entry in enumerate(image_list):
    raw_data_dict["image"].append(_to_rgb_uint8(entry["image"], idx))
    raw_data_dict["label"].append(entry["label"])
    # The per-patch L0 activation map is stored under the column name "activation_cosi".
    raw_data_dict["activation_cosi"].append(_to_activation_grid(entry["activation_l0"], idx))

features = Features({
    "image": datasets.Image(),
    "label": _label_feature(raw_data_dict["label"]),
    "activation_cosi": Array2D(dtype="float32", shape=(PATCH_GRID, PATCH_GRID)),
})

raw_hf_dataset = Dataset.from_dict(raw_data_dict, features=features)

print(raw_hf_dataset)

raw_hf_dataset.save_to_disk("imagenet10_activated_cosi_dataset")

In [None]:
# Zero-shot accuracy of the generated answers against the ground-truth labels.
# The denominator is the number of scored predictions (not a hard-coded 1000),
# and labels come from the label column so no images are decoded.
total = len(output_list)
acc = sum(
    # The generated text may carry surrounding whitespace; compare the bare answer.
    str(output).strip() == str(label)
    for output, label in zip(output_list, local_dataset["label"])
)
print(acc)
print(f"accuracy:{acc / total if total else 0.0}")

In [None]:
# Shape of the first stored L0 activation map. Records built in the generation
# loop use the key 'activation_l0' ('activation' raised KeyError); np.shape also
# works when the stored value is a plain list rather than an array/tensor.
print(np.shape(image_list[0]['activation_l0']))

In [12]:

import numpy as np
from PIL import Image
from IPython.display import display
import datasets
from datasets import Dataset, Features, Array3D, ClassLabel
import numpy as np




In [None]:

threshold_percentile = 25          # percent of patches blacked out, lowest activation first
patch_count = 24                   # activation maps are 24x24, one value per patch
patch_size = 336 // patch_count    # 14-pixel patches on the 336x336 images

mask_image_list = []
for item in image_list:
    img = item['image']
    if isinstance(img, Image.Image):
        # convert('RGB') handles grayscale, palette, RGBA and CMYK sources. A raw
        # np.array() kept CMYK as 4 channels, which Image.fromarray then
        # reinterpreted as RGBA.
        img = img.convert('RGB')
    # np.array copies, so masking never mutates the images held in image_list.
    img = np.array(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    activation_mask = np.asarray(item['activation_l0']).reshape(patch_count, patch_count)

    # Black out exactly `threshold_percentile`% of the patches (the lowest ones).
    # The old `<= np.percentile(...)` test also removed every patch tied at the
    # threshold, so tied values (e.g. integer L0 counts) could mask far more than
    # intended. That broke parity with the random-patch baseline, which masks a
    # fixed count. Ties are broken by patch order (stable sort).
    num_masked = int(round(activation_mask.size * threshold_percentile / 100))
    masked_flat = np.argsort(activation_mask, axis=None, kind='stable')[:num_masked]
    for flat_idx in masked_flat:
        i, j = divmod(int(flat_idx), patch_count)
        img[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size, :] = 0

    mask_image_list.append({"image": Image.fromarray(img), "label": item['label']})

masked_data_dict = {
    "image": [np.array(entry["image"]) for entry in mask_image_list],
    "label": [entry["label"] for entry in mask_image_list],
}

# Sorted, index-aligned class names. list(set(...)) gave an arbitrary order, and
# integer labels are class indices, so the names must cover 0..max(label).
unique_labels = sorted(set(masked_data_dict["label"]))
if unique_labels and all(isinstance(label, (int, np.integer)) for label in unique_labels):
    label_names = [str(k) for k in range(int(unique_labels[-1]) + 1)]
else:
    label_names = [str(label) for label in unique_labels]

features = Features({
    "image": datasets.Image(),
    "label": ClassLabel(names=label_names),
})

hf_dataset = Dataset.from_dict(masked_data_dict, features=features)

print(hf_dataset)
hf_dataset.save_to_disk(f"imagenet10_cosi_mask{100-threshold_percentile}_dataset")


In [None]:
# Visual spot-check of the first activation-masked image.
display(mask_image_list[0]['image'])

In [None]:
from PIL import Image
import numpy as np
from datasets import load_from_disk, Dataset, Features, ClassLabel as DsImage
import datasets
SEED = 0  # fixes which patches the random baseline keeps, so the dataset is reproducible
data = load_from_disk("")


def sample_patches(example, keep_ratio, rng=None, patch_size=14):
    """Return a copy of ``example["image"]`` with a random subset of patches zeroed.

    The image is resized to 336x336 RGB and split into ``patch_size`` x
    ``patch_size`` patches. ``int(total_patches * keep_ratio)`` patches, chosen
    uniformly without replacement, are kept; all other patches are set to black.

    Args:
        example: Mapping with an ``"image"`` entry (PIL image or array-like).
        keep_ratio: Fraction of patches to keep, in ``[0, 1]``.
        rng: Optional ``numpy.random.Generator`` (anything with a compatible
            ``choice``). Defaults to the global ``numpy.random`` state, as before.
        patch_size: Patch edge in pixels (14 gives the 24x24 grid used elsewhere).

    Returns:
        PIL.Image.Image: The masked RGB image.
    """
    if rng is None:
        rng = np.random
    image = example["image"]
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))
    pixels = np.array(image.resize((336, 336)).convert('RGB'))

    grid_h = pixels.shape[0] // patch_size
    grid_w = pixels.shape[1] // patch_size
    total_patches = grid_h * grid_w
    keep_count = int(total_patches * keep_ratio)

    # Row-major boolean keep-mask over the patch grid.
    keep = np.zeros(total_patches, dtype=bool)
    keep[rng.choice(total_patches, keep_count, replace=False)] = True
    keep = keep.reshape(grid_h, grid_w)

    masked = pixels.copy()
    for i, j in zip(*np.nonzero(~keep)):
        masked[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size, :] = 0
    return Image.fromarray(masked.astype(np.uint8))


keep_ratio = 0.75
rng = np.random.default_rng(SEED)
new_images = []
new_labels = []

for example in data:
    transformed_image = sample_patches(example, keep_ratio=keep_ratio, rng=rng)
    new_images.append(np.array(transformed_image))
    new_labels.append(example["label"])

features = Features({
    "image": datasets.Image(),
    # Reuse the source label feature so class names and indices stay identical.
    "label": data.features["label"],
})

new_dataset = Dataset.from_dict({"image": new_images, "label": new_labels}, features=features)

new_dataset.save_to_disk("")


In [None]:
#119 206 633 841
from PIL import Image
import numpy as np
from datasets import load_from_disk, Dataset, Features, ClassLabel as DsImage
import datasets

# Reload a saved dataset (fill in the path) and view one sample. The indices in
# the comment above appear to be samples singled out for manual inspection.
data=load_from_disk("")
display(data[841]["image"])