In [None]:
!pip install datasets
!pip install transformers
!pip install bitsandbytes
!pip install accelerate
!pip install -i https://pypi.org/simple/ bitsandbytes

In [None]:
!pip install --upgrade bitsandbytes

In [None]:
import csv
import os

import torch
from datasets import load_dataset
from transformers import (
    BitsAndBytesConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

In [None]:
# Load InstructBLIP (Vicuna-7B) with 4-bit weights via bitsandbytes.
# Quantization goes through `quantization_config`; passing `load_in_4bit=True`
# straight to `from_pretrained` is deprecated. The 4-bit compute dtype is set to
# bfloat16 so it matches `torch_dtype` (bitsandbytes otherwise computes in float32).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-vicuna-7b",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

In [None]:
# Image preprocessing + tokenization for the same checkpoint the model was loaded from.
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

In [None]:
# Hugging Face Hub datasets to caption, as (repo id, config name or None, split).
# Each entry is passed to load_dataset(repo_id, config, split=split) by the loop cell.
# NOTE(review): assumes every split exposes an `image` column; the loop reads it.
# NOTE: this name shadows the `datasets` package name; the loop cell depends on it.
datasets = [
    # repo id                               config  split
    ("detection-datasets/fashionpedia",     None,   "val"),
    ("keremberke/nfl-object-detection",     "mini", "test"),
    ("keremberke/plane-detection",          "mini", "train"),
    ("Matthijs/snacks",                     None,   "validation"),
    ("rokmr/mini_pets",                     None,   "test"),
    ("keremberke/pokemon-classification",   "mini", "train"),
]

In [None]:
# Two complementary instructions. The captioning loop asks the model each one
# and joins both answers into a single description per image.
prompt1 = (
    "describe this image in full detail. "
    "describe each and every aspect of the image "
    "so that an artist could re create the image"
)
prompt2 = (
    "create an extensive description of this image"
)

In [None]:
# Caption every image of every dataset listed above with both prompts and write:
#   images/<id>.jpg  - the image, re-encoded as an RGB JPEG
#   description.csv  - one row per image: id, combined description
# `id` is a running counter across all datasets, so it matches the image filename.
# Each run restarts the ids at 0 and overwrites the images, so the CSV is rewritten
# too (appending would duplicate ids on re-run).


def describe_image(image, prompts):
    """Return one description of `image` built from the model's answer to each prompt.

    Runs deterministic beam search once per prompt and joins the decoded answers
    with single spaces. Every run of whitespace (including newlines and tabs in
    the generated text) is collapsed, so the result fits on a single line.
    """
    answers = []
    for prompt in prompts:
        inputs = processor(
            images=image,
            text=prompt,
            return_tensors="pt",  # previously misspelled `return_rensors`: no tensors were built
        ).to(model.device, torch.bfloat16)  # same device / dtype as the model
        outputs = model.generate(
            **inputs,
            # Beam search without sampling; sampling-only knobs (top_p, temperature)
            # are ignored in this mode, so they are not passed.
            do_sample=False,
            num_beams=10,
            max_length=512,
            min_length=16,
            repetition_penalty=1.5,
        )
        answers.append(
            processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        )
    return " ".join(" ".join(answers).split())


os.makedirs("images", exist_ok=True)  # image.save() fails if the folder is missing

counter = 0
# csv.writer quotes descriptions containing commas/quotes; newline="" is what the
# csv module requires for the file object.
with open("description.csv", "w", newline="", encoding="utf-8") as csv_file:
    writer = csv.writer(csv_file)
    for name, config, split in datasets:
        dataset = load_dataset(name, config, split=split)
        for example in dataset:
            # JPEG cannot store alpha or palette images, so normalize to RGB first.
            image = example["image"].convert("RGB")
            desc = describe_image(image, (prompt1, prompt2))
            image.save(f"images/{counter}.jpg")
            print(counter, desc)

            writer.writerow([counter, desc])
            csv_file.flush()  # keep completed rows on disk if the run is interrupted

            counter += 1
            torch.cuda.empty_cache()
