<a href="https://colab.research.google.com/github/steinhaug/stable-diffusion/blob/main/BLIP-2/AutoCaptioner-LLaVA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## AutoCaptioner using LLaVA

Remember to run the postWorker notebook once this one has finished.  

[![Buy me a beer](https://raw.githubusercontent.com/steinhaug/stable-diffusion/main/assets/buy-me-a-beer.png ) ](https://steinhaug.com/donate/)

__NB!__  
This notebook is not made for "people in a hurry": make sure you download the correct images you want to caption, and that inference targets the correct folders.  
If you would like a "1, 2, 3, go" notebook for image captioning, ask me and I'll put it together.  

In [1]:
# Mount Google Drive; later cells read archives and model weights from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@markdown Download some images
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!mkdir /content/images
#!wget --header 'Authorization: Bearer TOKEN_HERE' https://huggingface.co/camenduru/polaroid/resolve/main/style_name_fix.zip
!cp /content/drive/MyDrive/data/blip2.zip /content/images/style_name_fix.zip
!unzip /content/images/style_name_fix.zip -d /content/images

## Pre-install: System

In [3]:
#@markdown 1.0: Install dependencies
from IPython.display import clear_output
%cd /content
# Fork of LLaVA (dev branch) that provides the `llava` package used by the later cells.
# NOTE(review): on a re-run the clone fails because the folder already exists; the rest still runs.
!git clone -b dev https://github.com/camenduru/LLaVA
%cd /content/LLaVA
# fastcopy.py is the threaded copy helper the model-loader cell uses to pull weights from Drive.
!wget https://raw.githubusercontent.com/L0garithmic/fastcolabcopy/main/fastcopy.py

!pip install -q transformers==4.36.2
# NOTE(review): ninja and flash-attn are unpinned, so re-runs may install different builds.
!pip install ninja
!pip install flash-attn --no-build-isolation

# Editable install of the cloned repo so `import llava` resolves to /content/LLaVA.
!pip install -e .
clear_output()
print('[1;32mDone! ✓')

[1;32mDone! ✓


In [1]:
#@markdown 1.2: Load notebook functions, needs to reload if you restart session.
import fnmatch
import glob
import os
import shutil
import tarfile
import zipfile

def compress_tar(directory_path, output_tar_file=None, inclusion_pattern=None):
    """Pack files under ``directory_path`` into a gzip-compressed tarball.

    Args:
        directory_path: root folder, walked recursively; member names are
            stored relative to it.
        output_tar_file: target path. ``.gz`` is appended when missing, so
            ``'x.tar'`` is written as ``'x.tar.gz'``. When ``None`` the archive
            is written inside ``directory_path`` and named after that folder.
        inclusion_pattern: optional shell-style pattern (e.g. ``'*.txt'``)
            matched against file names; ``None`` includes every file.

    Returns:
        The path of the archive that was written (including ``.gz``).
    """
    if output_tar_file is None:
        output_tar_file = f"{directory_path}/{return__folderName(directory_path)}.tar"
    if not output_tar_file.endswith('.gz'):
        output_tar_file += '.gz'
    with tarfile.open(output_tar_file, 'w:gz') as tar:
        for root, dirs, files in os.walk(directory_path):
            if inclusion_pattern:
                # Match only the file names: glob.glob(os.path.join(root, pattern))
                # also interpreted characters such as '[' in folder names as pattern
                # syntax (silently skipping those folders) and could match
                # directories, which tar.add() would then archive recursively.
                # Hidden files are skipped unless the pattern starts with '.',
                # the same rule glob applied.
                selected = [
                    name for name in fnmatch.filter(files, inclusion_pattern)
                    if not name.startswith('.') or inclusion_pattern.startswith('.')
                ]
            else:
                selected = files
            for file_name in selected:
                file_path = os.path.join(root, file_name)
                arcname = os.path.relpath(file_path, directory_path)
                print(f"{arcname}")
                tar.add(file_path, arcname=arcname)
    return output_tar_file

def decompress_tar(tar_file, destination=None, flatten_structure=False):
    """Extract a tar archive, optionally flattening its folder structure.

    Args:
        tar_file: path of the archive (opened with mode ``'r'``, so tarfile
            detects the compression).
        destination: target folder, created if missing. ``None`` extracts
            relative to the current working directory, or next to ``tar_file``
            when flattening.
        flatten_structure: when true, every file is written without its
            internal folders into ``<destination>/<archive name without its
            last extension>/``. Directory entries are skipped.

    When the running Python provides tarfile extraction filters, the ``'data'``
    filter is used so members cannot be written outside the target folder.
    """
    # Archive name without its last extension, e.g. "set.tar" -> "set".
    folder = os.path.splitext(os.path.basename(os.path.normpath(tar_file)))[0]
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

    with tarfile.open(tar_file, 'r') as tar:
        if destination is not None:
            print(f"Create: {destination}")
            os.makedirs(destination, exist_ok=True)

        for member in tar.getmembers():
            if flatten_structure:
                if member.isdir():
                    continue  # folders are dropped when flattening
                member.name = os.path.basename(member.name)
                if destination is not None:
                    member_path = os.path.join(destination, folder)
                else:
                    member_path = os.path.dirname(tar_file)
            elif destination is not None:
                member_path = destination
            else:
                # Extract relative to the CWD. The old code passed dirname(member.name)
                # here, which duplicated the member's folders ("a/b.jpg" -> "a/a/b.jpg").
                member_path = ''
            tar.extract(member, path=member_path, **extract_kwargs)
def decompress_tar_gz(file_path, output_dir='.'):
    """Extract a ``.tar.gz`` archive into ``output_dir`` (created if missing).

    Tar errors (corrupt or non-gzip input, and on Pythons with extraction
    filters, members rejected by the ``'data'`` filter) are reported with a
    message instead of raised. Other errors, e.g. a missing file, propagate.
    """
    os.makedirs(output_dir, exist_ok=True)
    # The archives come from downloaded datasets. When the running Python provides
    # tarfile extraction filters, use 'data' so members cannot escape output_dir.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    try:
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(output_dir, **extract_kwargs)
        print(f"Successfully decompressed '{file_path}' to '{output_dir}'.")
    except tarfile.TarError as e:
        print(f"Error decompressing '{file_path}': {e}")

def decompress_zip_files(directory):
    """Extract every ``*.zip`` found directly inside ``directory``.

    ``name.zip`` is unpacked into ``directory/name/``. Folder structure inside
    the archive is discarded: every file lands at the top of that folder, and
    files sharing a base name overwrite each other. Prints a message and
    returns ``None`` when ``directory`` does not exist.
    """
    if not os.path.isdir(directory):
        print("The specified directory does not exist.")
        return

    archives = [name for name in os.listdir(directory) if name.endswith('.zip')]
    for archive in archives:
        target_dir = os.path.join(directory, archive[:-len('.zip')])
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        with zipfile.ZipFile(os.path.join(directory, archive), 'r') as bundle:
            for entry in bundle.infolist():
                if entry.is_dir():
                    continue
                flat_path = os.path.join(target_dir, os.path.basename(entry.filename))
                with bundle.open(entry.filename) as src, open(flat_path, 'wb') as dst:
                    shutil.copyfileobj(src, dst)

        print(f"Extracted {archive} into {target_dir}")

def write_the_file(path, data_string):
    """Write ``str(data_string)`` to ``path``, or delete ``path`` if that text is empty.

    The file is written as UTF-8 explicitly. Captions produced by the model
    can contain non-ASCII characters, and the platform default encoding is
    not guaranteed to handle them.

    Returns:
        Always ``''`` (kept for compatibility with existing callers).
    """
    text = str(data_string)
    if text:
        with open(path, 'w', encoding='utf-8') as fw:
            fw.write(text)
    elif os.path.exists(path):
        os.remove(path)
    return ''

def delete_dir(directory_path):
    """Recursively remove ``directory_path`` and print what happened."""
    if not os.path.exists(directory_path):
        print("The directory does not exist.")
        return
    shutil.rmtree(directory_path)
    print(f"Directory and all contents deleted: {directory_path}")

def return__folderName(directory_path, verify_folder=False):
    """Return the last component of ``directory_path``, ignoring trailing slashes.

    With ``verify_folder=True`` the path must be an existing directory;
    otherwise ``None`` is returned.
    """
    if verify_folder and not os.path.isdir(directory_path):
        return None  # Return None for invalid paths
    return os.path.basename(os.path.normpath(directory_path))

def ensure_array(input_var):
    """Return ``input_var`` as a list: a str is wrapped and a list is returned as-is.

    Raises:
        ValueError: for any other type.
    """
    if isinstance(input_var, str):
        return [input_var]
    if isinstance(input_var, list):
        return input_var
    raise ValueError("Input must be a string or a list")

def array__prefix_with(filter_extensions, prefix='.'):
    """Return a new list in which every item starts with ``prefix`` (added only when missing)."""
    def _with_prefix(ext):
        if ext.startswith(prefix):
            return ext
        return prefix + ext
    return list(map(_with_prefix, filter_extensions))

def return__fileCount(directory_path, extensions=None):
    """Count files under ``directory_path`` recursively.

    Args:
        directory_path: folder to walk; a missing folder counts as 0.
        extensions: ``None`` counts every file. Otherwise a string or list of
            extensions (leading dot optional) matched case-insensitively
            against the end of each file name.
    """
    wanted = None
    if extensions is not None:
        extensions = array__prefix_with(ensure_array(extensions)) #['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
        wanted = tuple(ext.lower() for ext in extensions)

    total = 0
    for _root, _dirs, names in os.walk(directory_path):
        if wanted is None:
            total += len(names)
        else:
            total += sum(1 for name in names if name.lower().endswith(wanted))
    return total



## Pre-install: Download and prepare files

Download the image set, decompress the files, and then launch the captioning loop.

In [None]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')

# Log in through the Python API instead of `!huggingface-cli login --token {HF_TOKEN}`:
# the shell form places the secret on a process command line.
if not HF_TOKEN:
    raise RuntimeError("The Colab secret 'HF_TOKEN' is empty; set it before running this cell.")
from huggingface_hub import login, snapshot_download
login(token=HF_TOKEN)

In [None]:
#@markdown Download image sets
import os
SAVE_PATH = '/content/datasets'
REPO_ID = 'steinhaug/onceUponAtimeInPornVille'
# Only the SexArt/ sub-folder of the dataset repo is fetched (allow_patterns), into SAVE_PATH/REPO_ID.
# NOTE(review): newer huggingface_hub releases deprecate local_dir_use_symlinks — confirm against the installed version.
os.makedirs(f"{SAVE_PATH}/{REPO_ID}", exist_ok=True)
path = snapshot_download(repo_id=REPO_ID, repo_type="dataset", revision="main", allow_patterns="SexArt/*", local_dir=f"{SAVE_PATH}/{REPO_ID}", local_dir_use_symlinks=False)

In [11]:
#@markdown Decompress captions
# Restore a previously saved caption archive from Drive into the local dataset folder.
tar_gz_file_path = '/content/drive/MyDrive/datasets/SexArt.tar.gz'
output_directory = '/content/datasets/SexArt'

os.makedirs(output_directory, exist_ok=True)
decompress_tar_gz(tar_gz_file_path, output_directory)


Successfully decompressed '/content/drive/MyDrive/datasets/SexArt.tar.gz' to '/content/datasets/SexArt'.


In [None]:
#@markdown Decompress all .tar folders
IMAGE_FOLDER = '/content/datasets/steinhaug/onceUponAtimeInPornVille/SexArt'
output_directory = '/content/datasets/SexArt'

# Unpack every downloaded .tar into the shared output folder, keeping each
# archive's internal folder layout (flatten_structure=False).
for item_name in os.listdir(IMAGE_FOLDER):
    file_path = os.path.join(IMAGE_FOLDER, item_name)
    root, extension = os.path.splitext(file_path)
    if extension != '.tar':
        continue
    decompress_tar(file_path, output_directory, False)
    print(f"Decompressed: {file_path}")


In [None]:
# Delete the Domai folder of the dataset to free disk space.
# NOTE(review): the download cell only fetches "SexArt/*", so this folder exists only if an
# earlier/different run downloaded it — confirm this cell is still needed.
!rm -Rf /content/datasets/steinhaug/onceUponAtimeInPornVille/Domai

## Model CONFIG

In [71]:
#@markdown __[\*]>=-  Unload model -=<[\*]__ <br>
#@markdown Run this cell if you need to load another model, or reset the runtime.
import gc
import torch

# Drop every reference to the loaded model objects, then let Python and the
# CUDA caching allocator release the memory.
image_processor = vision_tower = None
model = tokenizer = None
del model, tokenizer
gc.collect()
torch.cuda.empty_cache()

In [1]:
# @markdown 1/3: Model select
from IPython.display import clear_output
# Work from the cloned repo: the install cell downloaded fastcopy.py there, so `import fastcopy` needs this CWD.
%cd /content/LLaVA
clear_output()

import fastcopy
import os
def return__folderName(directory_path, verify_folder=False):
    """Return the last component of ``directory_path`` (trailing slashes ignored).

    With ``verify_folder=True``, ``None`` is returned unless the path is an
    existing directory.
    """
    # NOTE(review): same helper as in the "1.2: Load notebook functions" cell,
    # presumably repeated so this cell works on its own after a restart.
    is_valid = (not verify_folder) or os.path.isdir(directory_path)
    return os.path.basename(os.path.normpath(directory_path)) if is_valid else None

# NOTE(review): the loader cell also supports '4bit/llava-v1.5-13b-4GB-8bit', which is not offered in this dropdown.
model_id = "4bit/llava-v1.5-13b-3GB" # @param ['liuhaotian/llava-v1.5-7b', '4bit/llava-v1.5-13b-3GB']

# Weights are kept on Drive (model_drive_path) and copied to local disk (model_path) by the next cell.
model_drive_path = f"/content/drive/MyDrive/models/LLaVA/{model_id}"
model_path = f"/content/models/{return__folderName(model_id)}"

print(f"[1;32mSelected model: {return__folderName(model_id)} ✓")

[1;32mSelected model: llava-v1.5-13b-3GB ✓


In [2]:
#@markdown 2/3: Model loader, prepare files and load model into GPU...
import os
# Copy the weights from Drive only when the local folder does not exist yet.
# NOTE(review): if an earlier copy was interrupted, the folder exists and the copy is skipped;
# delete model_path to force a fresh copy.
if not os.path.isdir(model_path):
    os.makedirs(model_path)
    # Three fastcopy passes with different thread counts and --size-limit values
    # (presumably small files first, then larger shards — confirm against fastcopy.py),
    # followed by rsync to fill any gaps and remove files not present on Drive (--delete).
    !python fastcopy.py "$model_drive_path/". "$model_path" --thread 20 --size-limit 400mb
    !python fastcopy.py "$model_drive_path/". "$model_path" --thread 3 --size-limit 800mb
    !python fastcopy.py "$model_drive_path/". "$model_path" --thread 3 --size-limit 3500mb
    !rsync -r -v --size-only --progress $model_drive_path/. $model_path --delete

from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
import torch

# Build the from_pretrained() kwargs for the selected checkpoint: shared loading
# options plus a bitsandbytes quantization flag and config.
if model_id == 'liuhaotian/llava-v1.5-7b':
    quant_flag = 'load_in_8bit'
    quant_config = BitsAndBytesConfig(
        load_in_8bit=True
    )
elif model_id == '4bit/llava-v1.5-13b-3GB':
    quant_flag = 'load_in_4bit'
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )
elif model_id == '4bit/llava-v1.5-13b-4GB-8bit':
    quant_flag = 'load_in_4bit'
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
else:
    # Fail here with a clear message. Previously this only printed "Error...", leaving
    # `kwargs` undefined (or stale from an earlier run), so the failure surfaced later.
    raise ValueError(f"Unsupported model_id {model_id!r}: add a quantization config for it first.")

kwargs = {"device_map": "auto", "low_cpu_mem_usage": True}
kwargs[quant_flag] = True
kwargs['quantization_config'] = quant_config
print("Model download and config complete.")




def old_ones():
    """Legacy loader settings, kept for reference only.

    Nothing in the visible notebook calls this function. Every assignment
    below is to a local variable and nothing is returned, so calling it has
    no lasting effect. The active configuration is the if/elif chain on
    ``model_id`` earlier in this cell.
    """
    #@ markdown llava-v1.5-13b-3GB
    from transformers import AutoTokenizer, BitsAndBytesConfig
    from llava.model import LlavaLlamaForCausalLM
    import torch

    model_path = "/content/llava-v1.5-13b-3GB"
    kwargs = {"device_map": "auto", "low_cpu_mem_usage": True }
    kwargs['load_in_4bit'] = True
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )


    #@ markdown llava-v1.5-7b
    from transformers import AutoTokenizer, BitsAndBytesConfig
    from llava.model import LlavaLlamaForCausalLM
    import torch

    model_path = "/content/llava-v1.5-7b"
    kwargs = {"device_map": "auto", "low_cpu_mem_usage": True }
    kwargs['load_in_8bit'] = True
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_8bit=True
    )



# Load the selected checkpoint and move its vision tower to the GPU.
my_list = ['llava-v1.5-7b','llava-v1.5-13b-3GB','llava-v1.5-13b-4GB-8bit']
model_folder = return__folderName(model_path)
if model_folder in my_list:
    model = LlavaLlamaForCausalLM.from_pretrained(model_path, **kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    # The vision encoder may not be loaded yet after from_pretrained(); load it on demand.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device='cuda')
    image_processor = vision_tower.image_processor
else:
    # The old message referenced the undefined name `element_to_check` and raised NameError.
    print(f"{model_folder} is not a valid model.")


[2024-04-26 14:14:35,826] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Model download and config complete.


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [3]:
#@markdown 3/3: Load the caption_image func
import os
import requests
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer

def caption_image(image_file, prompt, temperature=0.2, max_new_tokens=1024):
    """Caption one image with the loaded LLaVA model.

    Args:
        image_file: local file path, or an ``http://`` / ``https://`` URL.
        prompt: user instruction sent to the model together with the image.
        temperature: sampling temperature for ``model.generate``. Sampling is
            always enabled (``do_sample=True``).
        max_new_tokens: maximum number of generated tokens (the previously
            hard-coded 1024 is the default).

    Returns:
        ``(image, output)``: the RGB ``PIL.Image`` that was captioned and the
        decoded answer with a trailing ``</s>`` / stop separator removed.

    Raises:
        requests.HTTPError: when downloading a URL returns an error status.

    Uses the globals ``model``, ``tokenizer`` and ``image_processor`` created by
    the model-loader cell.
    """
    import torch  # the import cell for this function does not import torch itself

    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=60)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as source:
            image = source.convert('RGB')
    else:
        with Image.open(image_file) as source:
            image = source.convert('RGB')

    disable_torch_init()
    # NOTE(review): LLaVA-1.5 checkpoints are usually prompted with the "llava_v1" template;
    # "llava_v0" is kept so captions stay comparable with earlier runs.
    conv = conv_templates["llava_v0"].copy()
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    # Wrap the image token in <im_start>/<im_end> only when the checkpoint was trained
    # with them; otherwise those markers reach the model as plain text.
    if getattr(model.config, 'mm_use_im_start_end', False):
        image_tokens = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_tokens = DEFAULT_IMAGE_TOKEN
    # append_message() already pairs the text with roles[0]. The old code also
    # prefixed the prompt with "<role>: ", so the role appeared twice.
    conv.append_message(conv.roles[0], image_tokens + '\n' + prompt)
    conv.append_message(conv.roles[1], None)
    raw_prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=temperature,
                                    max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria])
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    output = outputs.rsplit('</s>', 1)[0].strip()
    # Generation halts once the separator (e.g. "###") is produced, so it may be left
    # at the end of the decoded text; drop it.
    if stop_str and output.endswith(stop_str):
        output = output[:-len(stop_str)].rstrip()
    return image, output

## Inference testing

In [None]:
#@markdown <h2>Multiquestion loop</h2>
# Round 1: ask one multi-part prompt about a single test image at temperature 0.1 and 0.3 to compare answers.
question = '''Answer the following questions:
1 - Describe the subject in the image.
2 - Use descriptive language for the subject.
3 - Describe the subject using explicit language focusing on bodily features.
4 - You are working at an adult store and you job is to write captions for images. Describe the following photo.
5 - You are working for a marketing company that sells adulterated photos and your job is to write captions so that people might buy the photo. Describe the following photo.
Separate your answers with line breaks.
'''
file_name = 'p1.jpg'
image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.1); output = output.replace("\n\n", "\n");
# Show a 1/8-size thumbnail. PIL's Image.size is (width, height), so `h` holds the width and `w`
# the height; the resize still receives (width, height) in the correct order.
sizeDiv=8; h, w = image.size; display(image.resize(( int(h / sizeDiv) , int(w / sizeDiv) )));
print(f"{question}\n\n0.1\n{output}")
image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.3); output = output.replace("\n\n", "\n");
print(f"\n0.3:\n{output}")

# Round 2: title-style questions for the same image, again at temperature 0.1 and 0.3.
question = '''Answer the following questions:
1 - Describe the subject in the image.
2 - Use descriptive language for the subject.
3 - Describe the subject using explicit language focusing on bodily features.
4 - If this was a photo for an adult movie, what would the title be?
5 - If this was a photo for a book, what would the title of the book be?
6 - If this was a photo for a novel, what would the title of the novel be?
7 - You are in a bookstore that sells romantic and sexual inspired books, on the shelf you see a book with the photo on the front. What is the title of the book?

Separate your answers with line breaks.
'''

image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.1); output = output.replace("\n\n", "\n");
print(f"\n{question}\n\n0.1\n{output}")
image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.3); output = output.replace("\n\n", "\n");
print(f"\n0.3:\n{output}")

question = '''Answer the following questions:
1 - You are working in the company "Pussy Inc" where everybody share their love for the womans vagina in all forms. Your job is to write 50 or so words describing the photo vividly.
2 - You are working in the company "Pussy Inc" where everybody share their love for the womans vagina in all forms. Your job is to write atleast 50 words describing the photo vividly.
3 - You are working in the company "Pussy Inc" where everybody share their love for the womans vagina in all forms. Your job is to write atleast 50 words describing the picture with sexualized and perverted language, for the picture of the month in our newspaper.

Separate your answers with line breaks.
'''

# Round 3: run the prompt defined just above on the same image at temperature 0.1 and 0.3.
image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.1); output = output.replace("\n\n", "\n");
print(f"\n{question}\n\n0.1\n{output}")
image, output = caption_image(f'/content/images/blip2/{file_name}', question, 0.3); output = output.replace("\n\n", "\n");
print(f"\n0.3:\n{output}\n")


In [None]:
#@markdown <h2>Benchmark loop</h2>
question = 'Describe the image and color details.'

# Caption every file of the local test folder in alphabetical order with one generic
# prompt; a failure is reported and the loop moves on to the next file.
benchmark_dir = '/content/images/blip2'
file_names = os.listdir(benchmark_dir)
sorted_file_names = sorted(file_names)
for file_name in sorted_file_names:
    try:
        image, output = caption_image(os.path.join(benchmark_dir, file_name), question)
        print(f"{file_name}:{output}")
    except Exception as e:
        print(f"Error processing {file_name}: {str(e)}")

## AutoCaptioning testing

In [None]:
import os
# Pose/visibility questions used to spot-check the captioner on a few hand-picked dataset images.
question = '''You are captioning adult images, answer the following questions:
1 - Is the subjects head visible? Yes / No
2 - Are the subjects hands visible? Yes / No
3 - Are the subjects feet visible? Yes / No
4 - Describe what the hands are doing, or can't tell if you do not know.
5 - If possible, describe what the hands are doing.
6 - Would you say the subject is: standing, sitting or lying?
7 - Try to reason if the photo is indoors or outdoors?
8 - Try to reason what time of day the photo is taken?

Separate your answers with line breaks.'''

# Absolute paths of the images to test (after the dataset cells have unpacked them).
file_names = [
    '/content/datasets/SexArt/23.03.23.Paulina.Pace.Stunning.View/SexArt_Stunning-View_Paulina-Pace_high_0004.jpg',
]

for file_name in file_names:
    image, output = caption_image(file_name, question, 0.2)
    output = output.replace("\n\n", "\n")
    # Show a 1/24-size thumbnail. PIL's Image.size is (width, height): `h` holds the
    # width and `w` the height, and they are passed back in that same order.
    sizeDiv = 24
    h, w = image.size
    display(image.resize(tuple(int(side / sizeDiv) for side in (h, w))))
    print(os.path.basename(file_name))
    print(f"\n0.2\n{output}")

## AutoCaptioning

In [None]:
#@markdown **CAPTIONER** AutoCaptioner Loop
# Dataset root: the captioning loop walks it recursively and writes a .txt caption next to each image.
IMAGE_FOLDER = "/content/datasets/SexArt" # @param {type:"string"}

question = '''You are captioning adult images, answer the following questions:
1 - If this was a photo for an adult movie, what would the title be?
2 - Use descriptive language for the subject.
3 - How would you rate the subject? Completely naked, almost naked, panties and bra, clothed or can't tell.
4 - If the subject has a vagina visible in the photo, how would you describe it? Shaved pussy, trimmed pussy, hairy pussy or can't tell.
5 - If the subject is a woman, how would you rate her breats? Tiny, small, medium, large, mega or can't tell.
6 - How would you rate the subjects body build? Skinny, thin, normal, big, fat or unsure.
7 - Is the subjects head visible? Yes / No
8 - Are the subjects hands visible? Yes / No
9 - Are the subjects feet or lower legs visible? Yes / No
10 - Would you say the subject is: standing, kneeling, lying or sitting? If lying then on the back or stomach?
11 - Describe what the hands are doing, or can't tell if you do not know.

Separate your answers with line breaks.'''

image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
total_count = return__fileCount(IMAGE_FOLDER, image_extensions)

# Caption every image under IMAGE_FOLDER that has no "<image name>.txt" yet, so an
# interrupted run resumes where it stopped. Folders and files are visited in sorted
# order so the "n/total" progress numbering is the same on every run.
image_count = 0
for root, dirs, files in os.walk(IMAGE_FOLDER):
    dirs.sort()
    for file in sorted(files):
        file_path = os.path.join(root, file)
        file_root, file_ext = os.path.splitext(file_path)
        if file_ext.lower() not in image_extensions:
            continue
        image_count += 1
        caption_file = f"{file_root}.txt"
        print(f"-= {image_count}/{total_count} : {caption_file} =-")
        if os.path.isfile(caption_file):
            continue  # already captioned
        try:
            image, caption_result = caption_image(file_path, question)
            caption_result = caption_result.replace("\n\n", "\n")
            write_the_file(caption_file, caption_result)
            sizeDiv = 25
            w, h = image.size  # PIL's Image.size is (width, height)
            display(image.resize((int(w / sizeDiv), int(h / sizeDiv))))
            print(caption_result)
        except Exception as e:
            # Name the file that failed. The old message used `item_name`, a leftover from
            # the .tar cell, so it showed the wrong file or raised NameError inside the handler.
            print(f"Error processing {image_count}/{total_count} ({file_path}): {e}")

In [None]:
# Compresses all the .txt files
# Archive every caption (*.txt) under the dataset folder to Drive. compress_tar appends ".gz",
# so the file written is SexArt3.tar.gz.
compress_tar('/content/datasets/SexArt', '/content/drive/MyDrive/datasets/SexArt3.tar', '*.txt')


In [None]:
#@markdown ### 📈 Analyze Tags
#@markdown Perhaps you need another look at your dataset.
show_top_tags = 50 #@param {type:"number"}

from collections import Counter
top_tags = Counter()

# The captioner writes each .txt next to its image inside the dataset's sub-folders,
# so walk recursively; a plain os.listdir(IMAGE_FOLDER) only saw the top level.
for caption_dir, _subdirs, caption_names in os.walk(IMAGE_FOLDER):
    for txt in caption_names:
        if not txt.lower().endswith(".txt"):
            continue
        with open(os.path.join(caption_dir, txt), 'r', encoding='utf-8', errors='replace') as f:
            # split() without arguments splits on any whitespace, so words on separate
            # answer lines are counted separately and no empty strings are counted.
            top_tags.update(f.read().split())

print(f"📊 Top {show_top_tags} tags:")
for k, v in top_tags.most_common(show_top_tags):
    print(f"{k} ({v})")