In [1]:
# Keep large model downloads on the AI Sauna shared scratch directory by
# pointing the Hugging Face cache (HF_HOME) there before any model is loaded.

import os
os.environ.update(HF_HOME="/scratch/project_462000584/huggingface")

In [2]:
# load the LLaVA-13B model

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Hugging Face Hub identifier of the checkpoint to load.
model_path = "liuhaotian/llava-v1.5-13b"

# Arguments are passed positionally as (model_path, model_base, model_name);
# no separate base model is supplied.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    get_model_name_from_path(model_path),
)

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
# load the Finna-HKM-images dataset

import datasets
from datasets import load_dataset, Image

# Load the train split, then switch the "image" column to decode=False so
# images are not turned into PIL Image objects automatically (the captioning
# code opens them itself).
dataset = load_dataset("/users/oisuomin/Finna-HKM-images", split='train')
dataset = dataset.cast_column("image", Image(decode=False))
print(dataset)

Resolving data files:   0%|          | 0/5948 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['image', 'formats', 'id', 'imageRights', 'images', 'languages', 'nonPresenterAuthors', 'onlineUrls', 'presenters', 'rating', 'series', 'subjects', 'title', 'year', 'rawData'],
    num_rows: 5947
})


In [18]:
# adapted from https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/run_llava.py

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)

import torch

from llava.utils import disable_torch_init
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from PIL import Image as PILImage

def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` and return the list of parts."""
    return args.image_file.split(args.sep)

def load_image(image_file):
    """Open an image from a local path or an HTTP(S) URL and convert it to RGB.

    Args:
        image_file: A filesystem path, or a URL beginning with "http".

    Returns:
        The object returned by ``PILImage.open(...).convert("RGB")``.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    # "https://..." also starts with "http", so one prefix test covers both.
    if image_file.startswith("http"):
        # Imported locally: the notebook never imports these names at cell
        # level, so the URL branch previously failed with NameError.
        from io import BytesIO

        import requests

        response = requests.get(image_file, timeout=60)
        # Fail with a clear HTTP error rather than handing an error page to PIL.
        response.raise_for_status()
        image = PILImage.open(BytesIO(response.content)).convert("RGB")
    else:
        image = PILImage.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    """Return a list with ``load_image`` applied to each entry, in order."""
    return [load_image(path) for path in image_files]

def my_eval_model(args):
    """Run one LLaVA generation for the query and image(s) described by ``args``.

    Unlike the upstream ``eval_model`` this is adapted from, the model is not
    loaded here: the function uses the notebook-level ``model``, ``tokenizer``
    and ``image_processor`` objects, so those must already exist.

    Args:
        args: Object with attributes ``model_path``, ``query``, ``conv_mode``,
            ``image_file``, ``sep``, ``temperature``, ``top_p``, ``num_beams``
            and ``max_new_tokens``. If ``args.conv_mode`` is None it is set to
            the conversation mode inferred from the model name.

    Returns:
        The first decoded output sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    # Imported locally: the notebook never imports `re` at cell level, so the
    # IMAGE_PLACEHOLDER branch below previously failed with NameError.
    import re

    model_name = get_model_name_from_path(args.model_path)

    # Make sure the query carries the image token: substitute it for an explicit
    # placeholder if one is present, otherwise prepend it on its own line.
    qs = args.query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if model.config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
    else:
        if model.config.mm_use_im_start_end:
            qs = image_token_se + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    # Infer the conversation template from the model name. Order matters:
    # the more specific substrings are tested before the generic "v1".
    lowered_name = model_name.lower()
    if "llama-2" in lowered_name:
        conv_mode = "llava_llama_2"
    elif "mistral" in lowered_name:
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in lowered_name:
        conv_mode = "chatml_direct"
    elif "v1" in lowered_name:
        conv_mode = "llava_v1"
    elif "mpt" in lowered_name:
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    # An explicitly requested conversation mode wins over the inferred one.
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                conv_mode, args.conv_mode, args.conv_mode
            )
        )
    else:
        args.conv_mode = conv_mode

    # Build the prompt: one user turn with the query, and an empty assistant
    # turn for the model to complete.
    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_files = image_parser(args)
    images = load_images(image_files)
    image_sizes = [x.size for x in images]
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Place the token ids on the same device as the image tensor instead of
    # assuming the default CUDA device.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            # Sample only when a positive temperature is requested.
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return outputs

In [19]:
# high level captioning function

def caption(image_file, prompt):
    """Caption ``image_file`` with ``prompt`` using fixed generation settings.

    Builds an ad-hoc ``Args`` object (temperature 0, single beam, at most 512
    new tokens, "," as the image separator) and passes it to ``my_eval_model``.
    """
    settings = dict(
        model_path=model_path,
        model_base=None,
        model_name=get_model_name_from_path(model_path),
        query=prompt,
        conv_mode=None,
        image_file=image_file,
        sep=",",
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )
    # The settings become class attributes of a throwaway class, which is then
    # instantiated so they can be read as plain attributes.
    args_cls = type('Args', (), settings)
    return my_eval_model(args_cls())

In [24]:
%%time

# caption the first 100 images

# The prompt is identical for every image, so build it once before the loop.
# The report-writing cell also reads this notebook-level `prompt` variable.
prompt = "This is an alt text description. What can be seen in the front? what can be seen in the back? Is the photo coloured or black and white? indicate in the description if there's text in the picture. Do not use words image or picture in the description. Don't count the amount of things."

# Cap at the dataset size so a dataset with fewer rows does not raise IndexError.
num_images = min(100, len(dataset))

data = []

for i in range(num_images):
    # Fetch the row once instead of re-indexing the dataset for every field.
    row = dataset[i]
    print(i, 'https://api.finna.fi' + row['images'][0])
    image_file = row['image']['path']
    text = caption(image_file, prompt)
    data.append({'text': text, 'idx': i})
    print(text)
    print('')

0 https://api.finna.fi/Cover/Show?source=Solr&id=hkm.00049815-1E48-4C06-8C23-C5C8DA9F11BA&index=0&size=large
In the front, there is a building with a person standing in front of it. In the back, there is a building with a person standing in front of it. The photo is black and white. There is no text in the picture.

1 https://api.finna.fi/Cover/Show?source=Solr&id=hkm.000BC921-3333-425F-AD9D-4F6DF8F1F411&index=0&size=large
In the front, there is a building with a fence and a tree. In the back, there is a building with a fence and a tree. The photo is black and white. There is no text in the picture.

2 https://api.finna.fi/Cover/Show?source=Solr&id=hkm.001AC4AD-E464-45E1-AE48-B2916B0F7524&index=0&size=large
In the front, there is a building with a large archway. In the back, there is a large building with a clock tower. The photo is black and white. There is a text in the picture, which is located in the bottom left corner.

3 https://api.finna.fi/Cover/Show?source=Solr&id=hkm.001F7288

In [46]:
# format the results into a Markdown document

# Markdown preamble of the report; filled via str.format with
# {model_name} and {prompt}.
HEADER = """
# Generated alt text for historical images

Visual language model used: {model_name}

Image dataset: NatLibFi/Finna-HKM-images

Prompt text:

{prompt}

"""


# Markdown section emitted once per captioned image; filled via str.format
# with {title}, {author}, {year}, {image_url}, {caption} and {description}.
TEMPLATE = """
---

## {title} ({author}, {year})

![]({image_url})

### Generated alt text

{caption}

### Original description

{description}

"""

def get_author(idx):
    """Return the authors of dataset row ``idx`` joined with '; '.

    Reads ``dataset[idx]['rawData']['author_facet']``. Returns 'unknown' when
    that field is absent or is not a joinable sequence of strings (for example
    None).
    """
    try:
        return '; '.join(dataset[idx]['rawData']['author_facet'])
    # Catch only the "no usable author data" cases; a bare except would also
    # hide unrelated problems such as KeyboardInterrupt or a NameError.
    except (KeyError, TypeError):
        return 'unknown'

# Write the report as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open('output.md', 'w', encoding='utf-8') as outfile:
    print(HEADER.format(model_name=model_path, prompt=prompt), file=outfile)

    for item in data:
        idx = item['idx']
        # Look the dataset row up once rather than once per field.
        row = dataset[idx]
        info = {
            'title': row['title'],
            'caption': item['text'],
            'image_url': 'https://api.finna.fi' + row['images'][0],
            'description': row['rawData']['description'],
            'author': get_author(idx),
            'year': row['year']
        }
        print(TEMPLATE.format(**info), file=outfile)