In [1]:
!conda env list

# conda environments:
#
base                     /home2/akshett.jindal/miniconda3
subba                 *  /home2/akshett.jindal/miniconda3/envs/subba



In [None]:
!pip install -q timm

In [1]:
import transformers

In [2]:
PROMPT = "Describe the image"

In [3]:
MODEL_ID = "Salesforce/instructblip-vicuna-7b"
CONFIG_CLASS = transformers.InstructBlipConfig
MODEL_CLASS = transformers.InstructBlipForConditionalGeneration
PROCESSOR_CLASS = transformers.InstructBlipProcessor

In [4]:
MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

In [5]:
BASE_DIR = "/tmp/akshett.jindal"
SUBJECT = 1
TO_CACHE = False

In [6]:
import os.path

HUGGINGFACE_CACHE_DIR = os.path.join(BASE_DIR, ".huggingface_cache")
OUTPUT_DIR = os.path.join(BASE_DIR, "image_embeddings", MODEL_NAME, f"subject_0{SUBJECT}")
MODEL_CHECKPOINTS_DIR = os.path.join(BASE_DIR, "cached_models", MODEL_NAME)

In [7]:
from os import makedirs

makedirs(OUTPUT_DIR, exist_ok=True)

In [8]:
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

In [9]:
processor = PROCESSOR_CLASS.from_pretrained(
    MODEL_ID,
    cache_dir=HUGGINGFACE_CACHE_DIR,
)

In [10]:
model_config = CONFIG_CLASS.from_pretrained(MODEL_ID)
model_config.output_hidden_states = True
model_config.vision_config.output_hidden_states = True
model_config.qformer_config.output_hidden_states = True

In [11]:
if TO_CACHE:
    model = MODEL_CLASS.from_pretrained(
        MODEL_ID,
        config=model_config,
        cache_dir=HUGGINGFACE_CACHE_DIR,
    )
    model.save_pretrained(MODEL_CHECKPOINTS_DIR)
    del model

In [12]:
with init_empty_weights():
    model = MODEL_CLASS(model_config)

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=MODEL_CHECKPOINTS_DIR,
    device_map="auto",
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


  0%|          | 0/805 [00:00<?, ?w/s]

  0%|          | 0/50 [00:00<?, ?w/s]

  0%|          | 0/54 [00:00<?, ?w/s]

  0%|          | 0/54 [00:00<?, ?w/s]

  0%|          | 0/54 [00:00<?, ?w/s]

  0%|          | 0/54 [00:00<?, ?w/s]

  0%|          | 0/25 [00:00<?, ?w/s]

In [13]:
import nsd_dataset.mind_eye_nsd_utils as menutils
from datasets import Dataset

image_ids, images = menutils.get_subject_images(BASE_DIR, SUBJECT)

def data_generator():
    for image_id, image in zip(image_ids, images):
        yield {"id": image_id, "image": image}

dataset = Dataset.from_generator(data_generator)
dataset

Dataset({
    features: ['id', 'image'],
    num_rows: 10000
})

In [14]:
def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

In [15]:
BATCH_SIZE = 2

In [16]:
from collections import OrderedDict

def to_cpu(value):
    if isinstance(value, torch.Tensor):
        return value.cpu()
    elif isinstance(value, tuple):
        return tuple(to_cpu(v) for v in value)
    elif isinstance(value, list):
        return list(to_cpu(v) for v in value)
    elif isinstance(value, set):
        return set(list(to_cpu(v) for v in value))
    elif isinstance(value, dict) or isinstance(value, OrderedDict):
        return {
            k: to_cpu(v)
            for k, v in value.items()
        }
    else:
        print('unknown type:', key, value)

In [17]:
from tqdm.auto import tqdm
import pickle
import json

buffer = []

model.eval()

with torch.inference_mode():
    for batch_num, d in tqdm(enumerate(batch(dataset, n=BATCH_SIZE)), total=len(dataset) // BATCH_SIZE):
        inputs = processor(
            images=torch.tensor(d["image"]),
            text=[PROMPT] * BATCH_SIZE,
            return_tensors="pt",
        )

        outputs = model(**inputs)
        outputs = to_cpu(outputs)

        logits = outputs["logits"]
        vision_outputs = outputs["vision_outputs"]
        qformer_outputs = outputs["qformer_outputs"]
        language_model_outputs = outputs["language_model_outputs"]

        vision_hidden_states = tuple(
            torch.mean(hs, dim=1).numpy()
            for hs in vision_outputs["hidden_states"]
        )
        qformer_hidden_states = tuple(
            torch.mean(hs, dim=1).numpy()
            for hs in qformer_outputs["hidden_states"]
        )

        buffer.append({
            "image_ids": d["id"],
            "vision_hidden_states": vision_hidden_states,
            "qformer_hidden_states": qformer_hidden_states,
        })

        if len(buffer) == 100:
            batch_file = os.path.join(OUTPUT_DIR, f"batch_{batch_num+1}.pkl")
            with open(batch_file, "wb") as f:
                pickle.dump(buffer, f)
            buffer = []

if len(buffer) > 0:
    batch_file = os.path.join(OUTPUT_DIR, f"batch_{batch_num+1}.pkl")
    with open(batch_file, "wb") as f:
        pickle.dump(buffer, f)

  0%|          | 0/5000 [00:00<?, ?it/s]