# Rodrigo Barraza's Inscriptions: BLIP-2 Mass Captioning
Loading the larger models requires a lot of RAM and VRAM: system RAM should be at least 24–32 GB (64 GB is optimal), and VRAM should be at least 16 GB.

In [None]:
# !pip3 install transformers

In [None]:
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, Blip2Model
import torch
import sys
import validators
# from lavis.models import load_model_and_preprocess

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Available BLIP-2 checkpoints:
# Salesforce/blip2-opt-2.7b
# Salesforce/blip2-opt-2.7b-coco
# Salesforce/blip2-opt-6.7b
# Salesforce/blip2-opt-6.7b-coco
# Salesforce/blip2-flan-t5-xl
# Salesforce/blip2-flan-t5-xl-coco
# Salesforce/blip2-flan-t5-xxl
#
# The tokenizer, processor and model must all come from the same checkpoint,
# so the name is defined once and reused for all three.
modelName = "Salesforce/blip2-opt-6.7b-coco"

tokenizer = AutoTokenizer.from_pretrained(
    modelName,
    # add_prefix_space=True,  # Required to use bad_words_ids
)
processor = Blip2Processor.from_pretrained(modelName)
# Weights are loaded in float16 to halve memory use.
# NOTE(review): later cells also cast inputs to float16; CPU-only use is untested.
model = Blip2ForConditionalGeneration.from_pretrained(
    modelName,
    device_map='auto',
    # load_in_8bit=True,
    torch_dtype=torch.float16)

In [None]:
###############################################################################
# Start of Options
# imagesDirectory = "/mnt/d/dataset-1080"
# useFolderNamesAsTokens = True  # Append the folder names to the beginning of the caption
# overwriteExistingCaptions = False  # Overwrite existing captions
# tokensStartOrEnd = 'start'  # end or start
# useNucleusSampling = False
# appendStyles = True
# showImages = False
image_url_or_path = "/mnt/d/fixes/0033836.jpg"  # Local file path or http(s) URL

# PREPROCESSOR SETTINGS
skipSpecialTokens = True

# MODEL GENERATION SETTINGS
useNucleusSampling = False
numberOfBeams = 1  # The number of beams to use for beam search (force words need more than 1)
minTokenLength = 15  # The amount of minimum tokens to generate
maxTokenLength = 15  # The maximum amount of tokens to generate
repetitionPenalty = 1.0
lengthPenalty = 1.0
topP = 0.9  # Only used when useNucleusSampling is True
temperature = 1.0  # Only used when useNucleusSampling is True

# TOKENIZER SETTINGS
enableForceWords = False
forceWordsList = ["a kid named turtle"]
removeBadWords = False
badWordsList = ["teddy bear"]
add_special_tokens = False


# prompt = "Describe the style in 1 word. Answer:"
prompt = None  # Optional question asked about the image after captioning
# End of Options
###############################################################################

# Load the image from a URL or from disk.
if validators.url(image_url_or_path):
    response = requests.get(image_url_or_path, stream=True, timeout=30)
    # Fail with a clear HTTP error instead of an opaque PIL decode error.
    response.raise_for_status()
    image = Image.open(response.raw).convert('RGB')
else:
    image = Image.open(image_url_or_path).convert('RGB')

# Tokenize each phrase on its own into plain lists of ids. Padding is
# deliberately not used: it would append pad-token ids to the shorter phrases,
# and generate() would then force/ban those pad ids as part of the phrase.
forceWordsIds = None
if enableForceWords:
    forceWordsIds = tokenizer(forceWordsList, add_special_tokens=add_special_tokens).input_ids

badWordsIds = None
if removeBadWords:
    badWordsIds = tokenizer(badWordsList, add_special_tokens=add_special_tokens).input_ids

# Decoding options shared by the caption and the optional prompted question.
generationOptions = dict(
    do_sample=useNucleusSampling,
    force_words_ids=forceWordsIds,
    bad_words_ids=badWordsIds,
    num_beams=numberOfBeams,
    repetition_penalty=repetitionPenalty,
    length_penalty=lengthPenalty,
    top_p=topP,
    temperature=temperature,
)

# Unprompted caption.
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(
    **inputs,
    max_length=maxTokenLength,
    min_length=minTokenLength,
    **generationOptions)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=skipSpecialTokens)[0].strip()
display(image.resize((int(image.width * 0.333), int(image.height * 0.333))))
print("Caption: ", generated_text)

# Optional question about the same image, using the configured prompt text.
if prompt:
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        max_length=10,
        min_length=7,
        **generationOptions)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=skipSpecialTokens)[0].strip()
    print(prompt)
    print(generated_text)

# Captioner

In [None]:
# url = "https://huggingface.co/spaces/Salesforce/BLIP2/resolve/main/house.png"
# image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# prompt = "Question: How could someone get out of the house? Answer:"
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

# print(generated_text)


import validators
from transformers import GPT2Tokenizer  # Needed by the force-words experiment at the end of this cell

url = '/mnt/d/dataset-1080/_/0014596.jpg'  # Local file path or http(s) URL

if validators.url(url):
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw).convert('RGB')
else:
    image = Image.open(url).convert('RGB')

# Preview at a fixed width while keeping the original aspect ratio.
previewWidth = 596
display(image.resize((previewWidth, max(1, round(image.height * previewWidth / image.width)))))
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

# Greedy Decoding (default)
print("###Greedy###")
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Beam Search Decoding. top_p / temperature are omitted because they only
# affect sampling, and do_sample is False here.
print("###Beam Search###")
generated_ids = model.generate(**inputs,
                               num_beams=5,
                               max_length=20,
                               min_length=15,
                               repetition_penalty=1.0,
                               length_penalty=1.0)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Nucleus Sampling Decoding
print("###Nucleus Sampling###")
generated_ids = model.generate(**inputs, do_sample=True, top_p=0.9)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Force-words experiment: token ids for the word "planet".
# NOTE(review): these ids come from the stock "gpt2" vocabulary, not from the
# BLIP-2 checkpoint's tokenizer; confirm they match before passing them to
# model.generate(force_words_ids=...).
t = GPT2Tokenizer.from_pretrained('gpt2')
force_words = t("planet", add_prefix_space=True, add_special_tokens=False).input_ids


# image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
# model.generate({"image": image}, use_nucleus_sampling=False, num_captions=1, min_length=15, max_length=20)



# prompt = "Question: Describe the style in 1 word? Answer:"
# inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs, num_beams=2)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# print(generated_text)


#### Load BLIP2 captioning model

In [None]:
import gc

# LAVIS is only needed from this cell onwards.
from lavis.models import load_model_and_preprocess

# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

# This cell rebinds `model`. Drop the Hugging Face BLIP-2 model from the earlier
# cell (if it was loaded) *before* loading, so its memory can be released instead
# of both large models being held at the same time.
if "model" in globals():
    del model
    gc.collect()
    torch.cuda.empty_cache()

# we associate a model with its preprocessors to make it easier for inference.
model, vis_processors, _ = load_model_and_preprocess(
    # name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
    # name="blip2_opt", model_type="pretrain_opt6.7b", is_eval=True, device=device
    # name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=device
    name="blip2_opt", model_type="caption_coco_opt6.7b", is_eval=True, device=device
    # name="blip2_t5", model_type="pretrain_flant5xl", is_eval=True, device=device
    # name="blip2_t5", model_type="caption_coco_flant5xl", is_eval=True, device=device
    # This next model is one scary devil in terms of size...
    # ... it requires at least 32GB of VRAM to run...
    # ... and will not load on 3090s or 4090s.
    # name="blip2_t5", model_type="pretrain_flant5xxl", is_eval=True, device=device
)

vis_processors.keys()

#### Auto Caption

In [None]:
import os
import re
from pathlib import Path
from collections import OrderedDict
from IPython.display import clear_output
from PIL import Image
###############################################################################
# Start of Options
imagesDirectory = "/mnt/d/dataset-1080"
useFolderNamesAsTokens = True  # Add the folder names to the caption (position set by tokensStartOrEnd)
overwriteExistingCaptions = False  # Overwrite existing captions
tokensStartOrEnd = 'start'  # end or start
minTokenLength = 15  # The amount of minimum tokens to generate
maxTokenLength = 20  # The maximum amount of tokens to generate

useNucleusSampling = False
repetitionPenalty = 1

appendStyles = True  # Append one-word person/outfit/style/theme/... answers to the caption

showImages = False  # Display each image while it is being captioned
# End of Options
###############################################################################

numberOfCaptions = 1  # How many captions to generate

# Count the images in the directory and all subdirectories up front so progress
# can be reported. Keep the extension list in sync with process_images.
totalImages = 0
for dirpath, dirnames, filenames in os.walk(imagesDirectory):
    totalImages += sum(filename.lower().endswith((".jpg", ".png", ".jpeg", ".webp", ".gif")) for filename in filenames)

processedImages = 0


def _ask_style_questions(image):
    """Ask the module-level LAVIS ``model`` a fixed set of one-word questions.

    Args:
        image: The preprocessed image produced by ``vis_processors["eval"]``
            (already batched and moved to ``device``).

    Returns:
        list[str]: The raw answers in the order person, outfit, background,
        style, theme, medium, color. This is the order in which they are
        later appended to the caption.
    """
    questions = (
        "Describe the person in 1 word. Answer:",
        "Describe the outfit in 1 word. Answer:",
        "Describe object in the background in 1 word. Answer:",
        "Describe the style in 1 word. Answer:",
        "Describe the theme in 1 word. Answer:",
        "Describe the medium in 1 word. Answer:",
        "Describe the color in 1 word. Answer:",
    )
    answers = []
    for question in questions:
        answers += model.generate({"image": image, "prompt": question}, use_nucleus_sampling=False, num_captions=1, min_length=5, max_length=15)
    return answers


def _clean_answers(rawAnswers):
    """Split, normalise and de-duplicate raw BLIP-2 answers.

    Each answer is lowercased and split on commas and on the standalone word
    "and", so words such as "hand" or "landscape" stay intact. Each piece is
    then processed as follows:

    - Pieces containing "_", "??" or "!!" are discarded as noise.
    - "f**k" is spelled out.
    - A leading "a ", "the " or "and " is removed.
    - One trailing ".", ",", "!" or "?" is removed.
    - Pieces of one character or less are dropped.

    Args:
        rawAnswers: Iterable of answer strings.

    Returns:
        list[str]: The unique cleaned answers in first-seen order.
    """
    unique = OrderedDict()
    for rawAnswer in rawAnswers:
        for piece in re.split(r',|\band\b', rawAnswer.lower()):
            word = piece.strip()
            if '_' in word or '??' in word or '!!' in word:
                continue
            word = word.replace('f**k', 'fuck')
            if word.startswith(('a ', 'the ', 'and ')):
                word = word.split(' ', 1)[1]
            if word.endswith(('.', ',', '!', '?')):
                word = word[:-1]
            if len(word) > 1:
                unique[word] = None
    return list(unique)


def process_images(dirpath):
    """Caption every image directly inside ``dirpath`` and save it as a .txt file.

    For each image file (.jpg, .png, .jpeg, .webp, .gif; case-insensitive):

    - Generate a caption with the LAVIS BLIP-2 ``model``.
    - Optionally add the names of the folders between ``imagesDirectory`` and
      ``dirpath``.
    - Optionally append one-word answers to the style questions.
    - Write the result next to the image, using the same name with a ``.txt``
      suffix.

    Images that already have a caption file are skipped unless
    ``overwriteExistingCaptions`` is set. Subdirectories are not visited; the
    caller walks the tree.

    The function reads the module-level options (``imagesDirectory``,
    ``useFolderNamesAsTokens``, ``tokensStartOrEnd``, ``appendStyles``,
    ``showImages``, ...) plus ``model``, ``vis_processors``, ``device`` and
    ``totalImages``, and it increments the global ``processedImages`` counter.

    Args:
        dirpath: Directory whose image files should be captioned.
    """
    global processedImages
    imageNames = [
        name for name in os.listdir(dirpath)
        if name.lower().endswith((".jpg", ".png", ".jpeg", ".webp", ".gif"))
    ]

    for imageName in imageNames:
        processedImages += 1
        remainingImages = totalImages - processedImages
        caption = ""
        print(f"Processed images: {processedImages}/{totalImages}")
        print(f"Remaining images: {remainingImages}")

        imageFilePath = os.path.join(dirpath, imageName)
        textFilePath = Path(imageFilePath).with_suffix('.txt')

        # If the image hasn't already been processed, caption it
        if overwriteExistingCaptions or not textFilePath.exists():
            # The context manager closes the file once the pixels are copied.
            with Image.open(imageFilePath) as sourceImage:
                rawImage = sourceImage.convert('RGB')
            # Display the image as it's being processed
            if showImages:
                display(rawImage)
            image = vis_processors["eval"](rawImage).unsqueeze(0).to(device)
            imageCaption = model.generate({"image": image}, min_length=minTokenLength, max_length=maxTokenLength, use_nucleus_sampling=useNucleusSampling, num_captions=numberOfCaptions, repetition_penalty=repetitionPenalty)

            # Fix the "laying" -> "lying" grammar error made by BLIP2. Only
            # whole words are replaced, so e.g. "relaying" is left untouched.
            modifiedCaption = re.sub(r'\blaying\b', 'lying', imageCaption[0])
            captionLower = modifiedCaption.lower()

            # Folder names between imagesDirectory and dirpath become tokens.
            # Folders whose name contains "_" are ignored, as are names that
            # already appear in the caption (case-insensitive).
            relpath = os.path.relpath(dirpath, imagesDirectory)
            relpathParts = [part for part in relpath.split(os.sep) if "_" not in part and part != "."]
            validParts = [part for part in relpathParts if part.lower() not in captionLower]
            if useFolderNamesAsTokens and validParts:
                folderTokens = ', '.join(validParts)
                if tokensStartOrEnd == 'end':
                    caption = f"{modifiedCaption}, {folderTokens}"
                else:
                    caption = f"{folderTokens}, {modifiedCaption}"
            else:
                # Keep the grammar fix applied above.
                caption = modifiedCaption

            # Append answers that are not already in the caption as whole words.
            if appendStyles:
                captionLowered = caption.lower()
                newAnswers = [
                    answer for answer in _clean_answers(_ask_style_questions(image))
                    if not re.search(rf'\b{re.escape(answer)}\b', captionLowered)
                ]
                if newAnswers:
                    caption += ', ' + ', '.join(newAnswers)

            # Remove periods (str.replace returns a new string, so reassign it).
            caption = caption.replace('.', '')

            # Save Caption as .txt file
            with open(textFilePath, 'w', encoding='utf-8') as f:
                f.write(caption)

        clear_output(wait=True)
        print(caption)
        print(imageFilePath)
                     
# Caption every folder of the dataset tree, starting with imagesDirectory itself;
# process_images handles only the files directly inside the folder it is given.
for folderPath, _subfolders, _files in os.walk(imagesDirectory):
    process_images(folderPath)