In [None]:
# Shell escape: list the conda environments available on this machine.
!conda env list

In [None]:
!pip install -q timm

In [None]:
import transformers

In [None]:
# Text prompts handed to the processor; "instruction" is paired with every image.
PROMPTS = dict(instruction="Describe the image")

In [None]:
# Identifier of the pretrained checkpoint passed to every `from_pretrained` call.
MODEL_ID = "Salesforce/instructblip-vicuna-7b"
# transformers classes used below to build the config, the model and the
# processor. Kept as aliases so the rest of the notebook does not name the
# InstructBLIP classes directly.
CONFIG_CLASS = transformers.InstructBlipConfig
MODEL_CLASS = transformers.InstructBlipForConditionalGeneration
PROCESSOR_CLASS = transformers.InstructBlipProcessor

In [None]:
# Filesystem-friendly form of MODEL_ID: every "/" and " " becomes "_".
MODEL_NAME = MODEL_ID.translate(str.maketrans("/ ", "__"))

In [None]:
# Root directory under which the Hugging Face cache, the cached model
# checkpoint and the extracted embeddings are placed; also passed to
# menutils.get_subject_images.
BASE_DIR = "/tmp/akshett.jindal"
# Subject number: passed to menutils.get_subject_images and embedded in the
# output folder name.
SUBJECT = 1
# When True, the checkpoint is fetched with `from_pretrained` and re-saved to
# MODEL_CHECKPOINTS_DIR before the dispatch-loading step runs.
TO_CACHE = False

In [None]:
import os.path

# Download cache handed to the `from_pretrained` calls via `cache_dir`.
HUGGINGFACE_CACHE_DIR = os.path.join(BASE_DIR, ".huggingface_cache")
# Destination of the pickled hidden-state batches for this model and subject.
# NOTE(review): the "0" prefix is hard-coded, so SUBJECT >= 10 would yield
# e.g. "subject_010" — confirm whether two-digit subjects are ever used.
OUTPUT_DIR = os.path.join(BASE_DIR, "image_embeddings", MODEL_NAME, f"subject_0{SUBJECT}")
# Local directory holding the `save_pretrained` copy of the model weights.
MODEL_CHECKPOINTS_DIR = os.path.join(BASE_DIR, "cached_models", MODEL_NAME)

In [None]:
from os import makedirs

# Create the output directory (and any missing parents); a no-op if it already exists.
makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

In [None]:
# Load the processor that turns images + instruction text into model inputs,
# using HUGGINGFACE_CACHE_DIR as the download cache.
processor = PROCESSOR_CLASS.from_pretrained(
    MODEL_ID,
    cache_dir=HUGGINGFACE_CACHE_DIR,
)

In [None]:
# Load the model configuration. `cache_dir` is passed so the config is stored
# in the same HUGGINGFACE_CACHE_DIR as the processor and the model weights
# instead of the default Hugging Face cache location.
model_config = CONFIG_CLASS.from_pretrained(
    MODEL_ID,
    cache_dir=HUGGINGFACE_CACHE_DIR,
)
# Ask the top-level model and the vision / Q-Former sub-models to return their
# per-layer hidden states; the extraction loop reads them from the outputs.
model_config.output_hidden_states = True
model_config.vision_config.output_hidden_states = True
model_config.qformer_config.output_hidden_states = True

In [None]:
# Optional one-off caching step: when TO_CACHE is True, fetch the pretrained
# weights through `from_pretrained` and re-save them under
# MODEL_CHECKPOINTS_DIR, the directory the dispatch-loading step reads from.
if TO_CACHE:
    model = MODEL_CLASS.from_pretrained(
        MODEL_ID,
        config=model_config,
        cache_dir=HUGGINGFACE_CACHE_DIR,
    )
    model.save_pretrained(MODEL_CHECKPOINTS_DIR)
    # Drop the reference to this fully loaded copy; `model` is rebuilt afterwards.
    del model

In [None]:
# Build the model object from the config inside accelerate's
# `init_empty_weights` context, i.e. without loading real weights yet.
with init_empty_weights():
    model = MODEL_CLASS(model_config)

# Fill in the weights from MODEL_CHECKPOINTS_DIR and let accelerate decide the
# device placement (device_map="auto").
# NOTE(review): this reads a checkpoint that only the TO_CACHE branch writes,
# so it presumably fails on a machine where that branch has never been run.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=MODEL_CHECKPOINTS_DIR,
    device_map="auto",
)

In [None]:
import nsd_dataset.mind_eye_nsd_utils as menutils
from datasets import Dataset

# Fetch the image ids and the images for the chosen subject.
# NOTE(review): the container types returned by get_subject_images are not
# visible here; the code only requires that both can be iterated in parallel.
image_ids, images = menutils.get_subject_images(BASE_DIR, SUBJECT)

def data_generator():
    """Yield one ``{"id": ..., "image": ...}`` record per (image_id, image) pair."""
    for image_id, image in zip(image_ids, images):
        yield {"id": image_id, "image": image}

# Materialise the records as a `datasets.Dataset` with "id" and "image" columns.
dataset = Dataset.from_generator(data_generator)
# Display the dataset summary as the cell output.
dataset

In [None]:
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slice indexing. The final slice is
    shorter when the length is not a multiple of ``n``.
    """
    total = len(iterable)
    starts = range(0, total, n)
    yield from (iterable[start:min(start + n, total)] for start in starts)

In [None]:
# Number of dataset rows sent through the processor and model per forward pass.
BATCH_SIZE = 2

In [None]:
from collections import OrderedDict

def to_cpu(value):
    """Recursively move every ``torch.Tensor`` found inside ``value`` to the CPU.

    Containers are rebuilt as plain built-ins: tuples stay tuples, lists stay
    lists, sets stay sets, and any ``dict`` subclass (``OrderedDict``
    included) becomes a plain ``dict`` with the same keys.

    Any other leaf (``None``, numbers, strings, arbitrary objects) is returned
    unchanged. The previous fallback referenced an undefined name and so
    raised ``NameError`` on the first such leaf; had it not raised, it would
    have replaced the leaf with ``None``.

    Args:
        value: A tensor, a nested tuple/list/set/dict of tensors, or any
            other object.

    Returns:
        The same structure with all reachable tensors on the CPU. Tensors
        hidden inside objects that are not one of the handled container types
        are not visited.
    """
    if isinstance(value, torch.Tensor):
        return value.cpu()
    if isinstance(value, tuple):
        return tuple(to_cpu(v) for v in value)
    if isinstance(value, list):
        return [to_cpu(v) for v in value]
    if isinstance(value, set):
        return {to_cpu(v) for v in value}
    # OrderedDict is a dict subclass, so this single check covers both.
    if isinstance(value, dict):
        return {k: to_cpu(v) for k, v in value.items()}
    # Non-tensor leaf: nothing to move, pass it through untouched.
    return value

In [None]:
from tqdm.auto import tqdm
import pickle
import json

# Records of processed batches waiting to be written to disk.
buffer = []

# Number of processed batches accumulated in memory before they are written
# out together as one pickle file.
BATCHES_PER_FILE = 100


def _flush_buffer(records, last_batch_num):
    """Pickle ``records`` to ``OUTPUT_DIR/batch_<last_batch_num + 1>.pkl``."""
    batch_file = os.path.join(OUTPUT_DIR, f"batch_{last_batch_num+1}.pkl")
    with open(batch_file, "wb") as f:
        pickle.dump(records, f)


# Ceiling division so the progress bar also counts a final, shorter batch.
num_batches = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE

model.eval()

with torch.inference_mode():
    for batch_num, d in tqdm(enumerate(batch(dataset, n=BATCH_SIZE)), total=num_batches):
        # The last batch holds fewer than BATCH_SIZE rows when the dataset
        # size is not a multiple of BATCH_SIZE, so size the prompt list from
        # the batch itself to keep exactly one prompt per image.
        num_images = len(d["id"])

        inputs = processor(
            images=torch.tensor(d["image"]),
            text=[PROMPTS["instruction"]] * num_images,
            return_tensors="pt",
        )

        # `outputs` stays a notebook-level name on purpose: the inspection
        # cell further down reads the outputs of the last processed batch.
        outputs = model(**inputs)
        outputs = to_cpu(outputs)

        vision_outputs = outputs["vision_outputs"]
        qformer_outputs = outputs["qformer_outputs"]

        # Hidden states are present because output_hidden_states was enabled
        # on the model, vision and Q-Former configs.
        vision_hidden_states = list(vision_outputs["hidden_states"])
        qformer_hidden_states = list(qformer_outputs["hidden_states"])

        buffer.append({
            "image_ids": d["id"],
            "vision_hidden_states": vision_hidden_states,
            "qformer_hidden_states": qformer_hidden_states,
        })

        if len(buffer) == BATCHES_PER_FILE:
            _flush_buffer(buffer, batch_num)
            buffer = []

# Write whatever is left once the loop ends (fewer than BATCHES_PER_FILE records).
if len(buffer) > 0:
    _flush_buffer(buffer, batch_num)

In [None]:
# Inspect the outputs of the last batch processed by the extraction loop
# (`outputs` is whatever that loop left behind, already moved to the CPU).
logits = outputs["logits"]
vision_outputs = outputs["vision_outputs"]
qformer_outputs = outputs["qformer_outputs"]
language_model_outputs = outputs["language_model_outputs"]

print(f"{outputs.keys() = }")
print(f"{logits.shape = }")
print(f"{vision_outputs['hidden_states'][0].shape = }")
print(f"{qformer_outputs.keys() = }")
print(f"{language_model_outputs.keys() = }")

# Pasted from an earlier run (presumably the vision hidden-state shape; the
# leading batch dimension of 1 suggests a different batch size was used):
# shape = torch.Size([1, 257, 1408])

In [None]:
import pickle

# Hard-coded path to one batch file written by the extraction loop.
# NOTE(review): this duplicates what OUTPUT_DIR resolves to for SUBJECT = 1
# and goes stale if BASE_DIR, MODEL_ID or SUBJECT change.
filepath = "/tmp/akshett.jindal/image_embeddings/Salesforce_instructblip-vicuna-7b/subject_01/batch_100.pkl"

# SECURITY: pickle.load can execute arbitrary code; only open files that this
# notebook produced itself.
with open(filepath, "rb") as f:
    output = pickle.load(f)

# Display the loaded list of per-batch records as the cell output.
# NOTE(review): this renders every stored tensor and can be very large.
output

In [None]:
# Shape of the first vision hidden-state tensor in the first stored record.
output[0]['vision_hidden_states'][0].shape

In [None]:
import torch

# Re-save the loaded records with torch.save to a scratch file.
# NOTE(review): the destination is a hard-coded absolute path.
torch.save(output, "/tmp/akshett.jindal/test.pt")