# Rodrigo Barraza's Inscriptions: Blip 2 Mass Captioning
Large amounts of RAM and VRAM are required to load the larger models. RAM should be at least 24-32 GB, with 64 GB being optimal. VRAM should be at least 16 GB.

In [None]:
!pip3 install transformers

In [None]:
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoTokenizer, Blip2Model
import torch
import sys
import validators

import os
import re
from pathlib import Path
from collections import OrderedDict
from IPython.display import clear_output
from PIL import Image


# from lavis.models import load_model_and_preprocess

In [None]:
# Device that processor/tokenizer outputs are moved to before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint ids noted by the author as alternatives:
# Salesforce/blip2-opt-2.7b
# Salesforce/blip2-opt-2.7b-coco
# Salesforce/blip2-opt-6.7b
# Salesforce/blip2-opt-6.7b-coco
# Salesforce/blip2-flan-t5-xl
# Salesforce/blip2-flan-t5-xl-coco
# Salesforce/blip2-flan-t5-xxl

# The tokenizer, processor and model must come from the same checkpoint, so the
# id is defined once instead of being repeated in each from_pretrained call.
modelName = "Salesforce/blip2-opt-6.7b-coco"

# Author's note: pass add_prefix_space=True here when bad_words_ids is used.
tokenizer = AutoTokenizer.from_pretrained(modelName)
processor = Blip2Processor.from_pretrained(modelName)
model = Blip2ForConditionalGeneration.from_pretrained(
    modelName,
    device_map='auto',
    # load_in_8bit=True,  # alternative to torch_dtype, left disabled by the author
    torch_dtype=torch.float16)

In [None]:
###############################################################################
# Start of Options
# Previous configuration, kept commented out for reference:
# imagesDirectory = "/mnt/d/dataset-1080"
# useFolderNamesAsTokens = True  # Append the folder names to the beginning of the caption
# overwriteExistingCaptions = False  # Overwrite existing captions
# tokensStartOrEnd = 'start'  # end or start
# useNucleusSampling = False
# appendStyles = True
# showImages = False
imagesDirectory = "/mnt/d/fixes"  # Root directory; it is walked recursively with os.walk
useFolderNamesAsTokens = True  # Add sub-folder names (those without "_") to the caption
overwriteExistingCaptions = True  # When False, images that already have a .txt file are skipped
tokensStartOrEnd = 'start'  # 'end' appends the folder names; any other value prepends them
appendStyles = True  # Also ask the model for style/theme/medium and append new answers
showImages = False  # Call display() on each image before captioning it

# PREPROCESSOR SETTINGS
skipSpecialTokens = True  # Passed as skip_special_tokens to processor.batch_decode

# MODEL GENERATION SETTINGS (forwarded to model.generate)
useNucleusSampling = False  # Passed as do_sample
numberOfBeams = 3  # The number of beams to use for beam search
lengthPenalty = 1
minTokenLength = 5  # The amount of minimum tokens to generate
maxTokenLength = 72  # The maximum amount of tokens to generate
repetitionPenalty = 1
topP = 0.9
temperature = 1.0

# TOKENIZER SETTINGS
enableForceWords = False  # When True, forceWordsList is tokenized and passed as force_words_ids
forceWordsList = ["water"]
# NOTE: badWordsList is tokenized when removeBadWords is True, but the
# bad_words_ids argument of model.generate is commented out in generateCaption,
# so this option currently has no effect on the output.
removeBadWords = False
badWordsList = ["teddy bear"]
padding = True  # Tokenizer padding argument for the two word lists above
add_special_tokens = False  # Tokenizer add_special_tokens argument for the two word lists above


# prompt = "Describe the style in 1 word. Answer:"
# NOTE(review): `prompt` is not read anywhere else in the visible code.
prompt = None
# End of Options
###############################################################################

# Token-id lists for forced / banned words; generateCaption fills them in when
# the matching option is enabled.
forceWordsIds = None
badWordsIds = None

# File extensions (compared in lower case) that are treated as images.
# Every entry needs its leading dot: a bare "gif" would also match any file
# whose name merely ends in the letters "gif".
imageExtensions = (".jpg", ".png", ".jpeg", ".webp", ".gif")

# Count the total number of images in the directory and subdirectories
totalImages = 0
for dirpath, dirnames, filenames in os.walk(imagesDirectory):
    totalImages += sum(filename.lower().endswith(imageExtensions) for filename in filenames)
# Running counter incremented by captionImages for the progress output.
processedImages = 0

def generateCaption(rawImage):
    """Generate a caption for one image, optionally tagged with folder names.

    The image is run through the module-level `processor` and `model` using the
    generation options defined at the top of the cell. The decoded caption is
    stripped, "laying" is rewritten to "lying", and, when
    `useFolderNamesAsTokens` is set, the names of the folders between
    `imagesDirectory` and the current directory are added at the start or end
    (per `tokensStartOrEnd`).

    NOTE: the current directory is read from the module-level name `dirpath`,
    i.e. the loop variable of the os.walk loop that drives the captioning. It
    is not a parameter of this function.

    Args:
        rawImage: the image to caption, as accepted by `processor(images=...)`.

    Returns:
        The caption after `cleanUpCaption` has been applied.
    """
    # Both id lists live at module level so the two options behave the same way.
    global forceWordsIds, badWordsIds

    if enableForceWords:
        forceWords = tokenizer(
            forceWordsList,
            padding=padding,
            add_special_tokens=add_special_tokens,
            return_tensors="pt").to(device).input_ids
        forceWordsIds = forceWords.tolist()  # Convert the tensor to a list of lists

    if removeBadWords:
        badWords = tokenizer(
            badWordsList,
            padding=padding,
            add_special_tokens=add_special_tokens,
            return_tensors="pt").to(device).input_ids
        badWordsIds = badWords.tolist()  # Convert the tensor to a list of lists

    inputs = processor(images=rawImage, return_tensors="pt").to(device, torch.float16)

    generated_ids = model.generate(
        **inputs,
        do_sample=useNucleusSampling,
        force_words_ids=forceWordsIds,
        # bad_words_ids=badWordsIds,  # left disabled, as in the original
        num_beams=numberOfBeams,
        max_length=maxTokenLength,
        min_length=minTokenLength,
        repetition_penalty=repetitionPenalty,
        length_penalty=lengthPenalty,
        top_p=topP,
        # num_return_sequences=5,
        temperature=temperature)
    imageCaption = processor.batch_decode(generated_ids, skip_special_tokens=skipSpecialTokens)

    # Strip once, up front, so surrounding whitespace (e.g. a trailing newline
    # from the decoder) never ends up in the middle of the final caption when
    # folder names are added around it.
    rawCaption = imageCaption[0].strip()
    # Fix grammatical spelling errors by BLIP2
    caption = rawCaption.replace('laying', 'lying')

    # Folder names between imagesDirectory and the current directory; names
    # containing "_" and the "." placeholder for the root itself are ignored.
    relpath = os.path.relpath(dirpath, imagesDirectory)
    relpathParts = [part for part in relpath.split(os.sep) if "_" not in part and part != "."]
    # Only add a folder name that does not already occur in the caption
    # (substring test of the lower-cased name, before and after the fix above).
    validParts = [
        part for part in relpathParts
        if part.lower() not in rawCaption and part.lower() not in caption
    ]

    if useFolderNamesAsTokens and validParts:
        folderTokens = ', '.join(validParts)
        if tokensStartOrEnd == 'end':
            generatedCaption = f"{caption}, {folderTokens}"
        else:
            generatedCaption = f"{folderTokens}, {caption}"
    else:
        generatedCaption = caption
    print(generatedCaption)
    return cleanUpCaption(generatedCaption)

def _askOneWordQuestion(rawImage, question):
    """Ask the model `question` about `rawImage` and return the decoded answers."""
    inputs = processor(images=rawImage, text=question, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        do_sample=False,
        num_beams=1,
        max_length=10,
        min_length=1,
        repetition_penalty=repetitionPenalty,
        length_penalty=lengthPenalty,
        top_p=topP,
        num_return_sequences=1,
        temperature=temperature)
    return processor.batch_decode(generated_ids, skip_special_tokens=skipSpecialTokens)


def generateExtraDescriptors(rawImage, caption):
    """Append style, theme and medium descriptors to an existing caption.

    The model is asked three one-word questions (style, theme, medium). The
    answers are split on commas and the word "and", lower-cased, cleaned and
    de-duplicated in the order they were produced. Descriptors that already
    occur in `caption` as whole words are dropped; the rest are appended,
    comma separated.

    Args:
        rawImage: the image to describe, as accepted by `processor(images=...)`.
        caption: the caption generated so far.

    Returns:
        The caption with any new descriptors appended, after `cleanUpCaption`.
        When no new descriptor is found the caption itself is returned
        (cleaned), never an empty string.
    """
    decodedAnswers = []
    for aspect in ("style", "theme", "medium"):
        decodedAnswers += _askOneWordQuestion(rawImage, f"Describe the {aspect} in 1 word. Answer:")

    # Characters / sequences that mark an answer as garbage.
    rejectMarkers = ('_', '??', '!!', '—', '~', '@', '|')

    # Keys of an ordered mapping give de-duplication that keeps first-seen order.
    descriptors = OrderedDict()
    for decodedAnswer in decodedAnswers:
        # \b keeps the split to the standalone word "and", so answers such as
        # "landscape" or "sandy" are not cut into pieces.
        for fragment in re.split(r',|\band\b', decodedAnswer):
            word = fragment.strip().lower()
            if not word or any(marker in word for marker in rejectMarkers):
                continue

            # Drop a leading article / conjunction.
            if word.startswith(('a ', 'the ', 'and ')):
                word = word.split(' ', 1)[1]

            # Drop one trailing punctuation character.
            if word.endswith(('.', ',', '!', '?')):
                word = word[:-1]

            if len(word) > 1:
                descriptors[word] = None

    captionLower = caption.lower()
    newDescriptors = [
        word for word in descriptors
        if not re.search(rf'\b{re.escape(word)}\b', captionLower)
    ]

    # Joining with the caption first means that an empty descriptor list simply
    # yields the caption unchanged.
    return cleanUpCaption(', '.join([caption] + newDescriptors))

def cleanUpCaption(caption):
    """Normalise known tokenisation artefacts in a generated caption.

    Every rule is triggered by the content of the ORIGINAL `caption`, while the
    replacements are applied cumulatively to the working copy, in the order
    listed below.

    Args:
        caption: the caption text to clean.

    Returns:
        The cleaned caption.
    """
    def applySwaps(text, swaps):
        # Run the (old, new) replacements one after another.
        for old, new in swaps:
            text = text.replace(old, new)
        return text

    # (triggers, swaps): the swaps run when ANY trigger occurs in `caption`.
    substringRules = (
        ((" - ",), (
            ("t - shirt", "t-shirt"),
            (" - man", "-man"),
            (" - men", "-men"),
            ("t - rex", "t-rex"),
            ("sci - fi", "sci-fi"),
            ("x - files", "x-files"),
        )),
        ((" - man",), ((" - man", "-man"),)),
        (("http", "www"), (("http", ""), ("www", ""))),
        (("/",), (("/", ""),)),
        ((" - stock image",), ((" - stock image", ""),)),
        (("sci - fi",), (("sci - fi", "sci-fi"),)),
        (("t.v", "t.v."), (("t.v.", "television"), ("t.v", "television"))),
    )

    result = caption
    for triggers, swaps in substringRules:
        if any(trigger in caption for trigger in triggers):
            result = applySwaps(result, swaps)

    # Only a five-character caption such as "a & b" has its ampersand tightened.
    if len(caption) == 5 and " & " in caption:
        result = result.replace(" & ", "&")

    # Needs both colour words to be present in the original caption.
    if all(colour in caption for colour in ("black", "white")):
        result = result.replace("black & white", "black and white")

    # Character-specific corrections applied only to captions mentioning "blanka".
    if "blanka" in caption:
        result = applySwaps(result, (("hulk", "blanka"), ("green hair", "orange hair")))

    return result

def captionImages(dirpath):
    """Caption every image directly inside `dirpath` and save each as a .txt file.

    For each image file the progress counters are printed, then, unless a
    caption file already exists and `overwriteExistingCaptions` is False, the
    image is captioned with `generateCaption` (plus `generateExtraDescriptors`
    when `appendStyles` is set) and the result is written next to the image
    with the image's extension replaced by ".txt".

    Images that cannot be opened or decoded are reported and skipped so that
    one bad file does not abort the whole run.

    Args:
        dirpath: directory whose image files are processed (not recursive).
    """
    global processedImages
    # Every extension needs its leading dot; a bare "gif" would also match
    # non-image files whose name merely ends in those letters.
    imageExtensions = (".jpg", ".png", ".jpeg", ".webp", ".gif")
    imageSuspects = [
        filename for filename in os.listdir(dirpath)
        if filename.lower().endswith(imageExtensions)
    ]

    # Process each image
    for imageName in imageSuspects:
        processedImages += 1
        remainingImages = totalImages - processedImages
        caption = ""
        print(f"Processed images: {processedImages}/{totalImages}")
        print(f"Remaining images: {remainingImages}")

        # Paths of the image and of its caption file
        imageFilePath = f"{dirpath}/{imageName}"
        textFilePath = f"{Path(imageFilePath).with_suffix('')}.txt"

        # If the image hasn't already been processed, caption it
        if overwriteExistingCaptions or not os.path.exists(textFilePath):
            try:
                # The context manager closes the file handle; convert() returns
                # an independent in-memory copy that stays valid afterwards.
                with Image.open(imageFilePath) as imageFile:
                    rawImage = imageFile.convert('RGB')
            except OSError as error:
                # Raised for missing, truncated or unidentifiable image files.
                print(f"Skipping unreadable image {imageFilePath}: {error}")
                continue

            # Display the image as it's been processed
            if showImages:
                display(rawImage)

            caption = generateCaption(rawImage)

            if appendStyles:
                caption = generateExtraDescriptors(rawImage, caption)

            # Save Caption as .txt file (explicit encoding so the result does
            # not depend on the platform default)
            with open(textFilePath, 'w', encoding='utf-8') as f:
                f.write(caption)
        # clear_output(wait=True)
        print(caption)


# Iterate through directories inside directories
# NOTE: the loop variable must stay named `dirpath`: generateCaption reads the
# module-level `dirpath` to work out the folder names relative to
# imagesDirectory, so renaming it here would break the folder-name tokens.
for dirpath, dirnames, filenames in os.walk(imagesDirectory):
    captionImages(dirpath)