# Extract Captions using Llava

In [1]:
import torch
from transformers import AutoTokenizer, AutoProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, LlavaNextProcessor
import array
import os
import pandas as pd
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import LlamaTokenizer, LlamaTokenizerFast

In [None]:
config = {'dataset': 'kodak'}

In [10]:
path = "llava-hf/llava-v1.6-vicuna-7b-hf"
model = LlavaNextForConditionalGeneration.from_pretrained(path, device_map = "auto")

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.29s/it]
Some parameters are on the meta device device because they were offloaded to the cpu.


In [11]:
tokenizer = LlamaTokenizerFast.from_pretrained(path, device_map = "auto")
processor = LlavaNextProcessor.from_pretrained(path, device_map = "auto")

In [6]:
path = "/path/to/your/dataset/"
Image_names = os.listdir(path)
Images = [Image.open(os.path.join(path, image_name)) for image_name in Image_names]

In [7]:
prompts = [
          "USER: <image> \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n \
           ASSISTANT: "
          ]

In [8]:
columns = ["image_name", "item1", "item2", "item3", "item1_description", "item2_description", "item3_description", "overall_description", "d1_len", "d2_len", "d3_len", "dall_len"]
df = pd.DataFrame(columns = columns)
df.loc[:, "image_name"] = Image_names
df = df.set_index("image_name")

In [9]:
import re

In [10]:
def extract_assistant_text(text):
    match = re.search(r'ASSISTANT: (.*)', text)
    if match:
        return match.group(1)
    return None

In [11]:
def extract_only_all(prompts, Image_names, Images, model, processor):
  cnt = 0
  for image_name, image in zip(Image_names, Images):
    answers = []
    for prompt in prompts:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
        # Generate
        generate_ids = model.generate(**inputs, max_new_tokens=150, temperature = 0)
        ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(ans)
    try:
      text = extract_assistant_text(ans[0].replace('\n', ''))
      if(text is None):
        print(f"Error: text = {text}")
        continue
      df.loc[image_name, "overall_description"] = text
    except Exception:
      print(f"Error: items = {items}")
      continue
    print(f"Processed image {image_name}, overall_description = {text}")
    cnt+=1
    if cnt % 50 == 0: #save
      df.to_csv(f"output_llavanext_{config['dataset']}.csv")
      


In [33]:
def extract_description(prompts, Image_names, Images, model, processor):
  for image_name, image in zip(Image_names, Images):
    answers = []
    for prompt in prompts:
      try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        # Generate
        generate_ids = model.generate(**inputs, max_new_tokens=150)
        ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(ans)
      except Exception:
        print("API Error", Exception)
        continue
      answers.append(ans)
    ans_text = [a for a in answers]
    items = ans_text[0].split("$")
    if len(items) != 3:
      print(f"Error: items = {items}")
      continue
    try:
      item1, item1_desc = items[0].split("@")
      item2, item2_desc = items[1].split("@")
      item3, item3_desc = items[2].split("@")
      df.loc[image_name, "item1"] = item1
      df.loc[image_name, "item2"] = item2
      df.loc[image_name, "item3"] = item3
      df.loc[image_name, "item1_description"] = item1_desc
      df.loc[image_name, "item2_description"] = item2_desc
      df.loc[image_name, "item3_description"] = item3_desc
      df.loc[image_name, "overall_description"] = ans_text[1]
    except Exception:
      print(f"Error: items = {items}")
      continue
    print(f"Processed image {image_name}, {item1}:{item1_desc},{item2}:{item2_desc},{item3}:{item3_desc}")


In [12]:
extract_only_all(prompts, Image_names, Images, model, processor)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


['USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n\nThe image presents a modern and stylish bar setting. The bar counter, crafted from white marble, is adorned with a variety of decorative elements. A wooden shelf, filled with an assortment of bottles, adds a touch of rustic charm. Above the counter, a black metal shelf holds a collection of vases, each housing a different plant, contributing to the lively atmosphere. The bar stools, upholstered in white, are arranged neatly along the counter, ready for patrons. The overall ambiance is one of sophistication and comfort, inviting patrons to relax and enjoy their time.']
Processed image c42476f8f2747cf77b4eaac719943677.png, overall_description = The image presents a modern and stylish bar setting. The bar counter, crafted from white marble, is adorned with a variety of decorative elements. A wooden shelf, filled with an assortment of bottl



['USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n\nIn the image, a person is seen walking on a sidewalk, pulling a black suitcase behind them. The individual is dressed in a green shirt and blue jeans, and they appear to be in motion, possibly heading towards or returning from a journey. The suitcase is large and seems to be well-suited for travel. The setting is an urban environment, with a clear sky overhead and a building visible in the background. The overall scene suggests a sense of movement and travel.']
Processed image 85e6554aedd269c6a08c24ee4cd1f74b.png, overall_description = In the image, a person is seen walking on a sidewalk, pulling a black suitcase behind them. The individual is dressed in a green shirt and blue jeans, and they appear to be in motion, possibly heading towards or returning from a journey. The suitcase is large and seems to be well-suited for travel. The se



["USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n\nIn the image, a man is seen riding a black motorcycle on a road under a bridge. He is wearing a black jacket, black pants, and a black helmet. The motorcycle has a blue gas tank and is equipped with a black seat. The man appears to be in motion, as suggested by the blurred background and the slight tilt of the motorcycle. The bridge in the background adds an urban touch to the scene. The man's attire and the motorcycle's design suggest a preference for a classic and timeless style."]
Processed image 0f6cd456ade0adf0c1a0d4e522954d00.png, overall_description = In the image, a man is seen riding a black motorcycle on a road under a bridge. He is wearing a black jacket, black pants, and a black helmet. The motorcycle has a blue gas tank and is equipped with a black seat. The man appears to be in motion, as suggested by the blurred backgroun



["USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n\nIn the heart of the action, a Formula 1 race car, painted in vibrant hues of orange and black, is captured in motion. The car, adorned with the number 25, is leaning into a turn on a track, demonstrating the driver's skill and the car's agility. The background reveals a blurred cityscape, adding a sense of speed and urgency to the scene. The image encapsulates the thrill and excitement of a Formula 1 race, where every second counts."]
Processed image 586460a320d3e725cdf3f6032dd09f52.png, overall_description = In the heart of the action, a Formula 1 race car, painted in vibrant hues of orange and black, is captured in motion. The car, adorned with the number 25, is leaning into a turn on a track, demonstrating the driver's skill and the car's agility. The background reveals a blurred cityscape, adding a sense of speed and urgency to the 



["USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n            In the image, a young couple is captured in a moment of intimacy. They are seated on the floor, leaning against a wall adorned with geometric patterns. The man, dressed in a gray shirt and black pants, holds the woman's hand, creating a sense of closeness. The woman, wearing a blue jacket and white sneakers, gazes at the man with a soft expression. A black backpack rests nearby, adding a touch of everyday life to the scene. The overall atmosphere is one of warmth and connection."]
Processed image 81522064f61ba1c9ec111bb349851bca.png, overall_description =             In the image, a young couple is captured in a moment of intimacy. They are seated on the floor, leaning against a wall adorned with geometric patterns. The man, dressed in a gray shirt and black pants, holds the woman's hand, creating a sense of closeness. The wom



["USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \n\nIn the image, a solitary figure is captured in the midst of a rainy city night. The person, clad in a black jacket and holding an umbrella, is walking on a wet sidewalk, their footprints leaving a trail behind. The cityscape is bathed in the soft glow of streetlights, casting long shadows and reflecting off the wet pavement. The scene is a blend of urban life and the tranquility of a rainy night, with the person's journey illuminated by the city's ambient light."]
Processed image e22eb9ac525d8d5178014d96f024bb8d.png, overall_description = In the image, a solitary figure is captured in the midst of a rainy city night. The person, clad in a black jacket and holding an umbrella, is walking on a wet sidewalk, their footprints leaving a trail behind. The cityscape is bathed in the soft glow of streetlights, casting long shadows and reflectin



["USER:  \n please give overall description of this image in 50 words. generate only an informative and natural paragraph. \n            ASSISTANT: \nIn the image, a young woman stands in front of a rusted metal gate, holding an open umbrella. She is dressed in a long coat and shorts, with knee-high socks and black shoes. The gate is old and weathered, with a chain-link fence behind it. The ground is covered in fallen leaves, suggesting it might be autumn. The woman's pose and the umbrella she holds add a sense of anticipation or perhaps a hint of melancholy to the scene. The overall atmosphere is one of quiet solitude, with the woman as the central figure in this urban landscape."]
Processed image d542411baac09dfc9e75b9c1b0532d94.png, overall_description = In the image, a young woman stands in front of a rusted metal gate, holding an open umbrella. She is dressed in a long coat and shorts, with knee-high socks and black shoes. The gate is old and weathered, with a chain-link fence beh



KeyboardInterrupt: 

In [28]:
df.to_csv(f"{config['dataset']}_overall_description.csv")

# Caption Compression

# LLMzip requirements

In [5]:
import torch
import array
import zlib
import argparse

class LlamaZip:
    def __init__(self, model, tokenizer):
        self.CONTEXT_SIZE = 8
        self.BATCH_SIZE = 10
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    def text_to_tokens(self, text):
        # ignore the warning that this gives about too many tokens
        tokens = self.tokenizer(text, return_tensors="pt", add_special_tokens=True)
        tokens = tokens["input_ids"].squeeze()
        return tokens

    def tokens_to_text(self, tokens):
        tokens = tokens.reshape((1, -1))
        text = self.tokenizer.batch_decode(tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        return text[0]

    def pad(self, tokens, padding_val):
        pad_len = self.CONTEXT_SIZE - tokens.shape[0] % self.CONTEXT_SIZE
        if pad_len != self.CONTEXT_SIZE:
            padding = torch.tensor([padding_val]*pad_len)

            tokens = torch.cat((tokens, padding))

        else:
            pad_len = 0


        return tokens, pad_len

    @torch.no_grad()
    def get_logits(self, tokens, token_index, past=None):
        my_inputs = {}
        my_inputs['input_ids'] = tokens[:, token_index].reshape(-1, 1)

        output = self.model(**my_inputs, past_key_values=past)
        logits = output.logits
        if len(logits.shape) > 2:
            logits = logits.reshape((logits.shape[0], -1))
        return logits, output.past_key_values


    def encode_one_batch(self, tokens, token_index, past=None):

        assert len(tokens.shape) == 2

        logits, past = self.get_logits(tokens, token_index, past)
        assert len(logits.shape) == 2
        logits, sorted_tokens = torch.sort(logits, descending=True)

        assert len(sorted_tokens.shape) == 2


        next_tokens = tokens[:, token_index + 1]
        next_tokens_expanded = next_tokens.view(-1, 1).expand_as(sorted_tokens)
        next_tokens_expanded = next_tokens_expanded

        # Find score as index of next tokens
        scores = (sorted_tokens==next_tokens_expanded).nonzero(as_tuple=True)

        scores = scores[1] # remove index column

        return scores, past

    def decode_one_batch(self, input_tokens, scores, score_index, past=None):
        assert len(scores.shape) == 2
        logits, past = self.get_logits(input_tokens, score_index, past)

        logits, sorted_tokens = torch.sort(logits, descending=True)
        assert len(sorted_tokens.shape) == 2
        # the scores give the indexes of the decoded tokens
        indexes = scores[:, score_index].flatten()
        decoded_tokens = sorted_tokens[torch.arange(indexes.shape[0]), indexes]

        return decoded_tokens.int(), past


    def encode(self, text):
        tokens = self.text_to_tokens(text)
        print("tokens:", tokens)
        return self.encode_tokens(tokens)

    def encode_tokens(self, tokens):

        tokens, pad_len = self.pad(tokens, self.tokenizer.eos_token_id)
        tokens = tokens.view(-1, self.CONTEXT_SIZE)

        output_scores = torch.zeros(tokens.shape)


        # Add eos to the start of each block (to give it somewhere to start)
        eos = torch.tensor([self.tokenizer.eos_token_id]*tokens.shape[0]).unsqueeze(1)
        tokens = torch.cat((eos, tokens), 1)

        tokens = tokens.to(self.device)

        batches = tokens.shape[0]//self.BATCH_SIZE
        if tokens.shape[0] % self.BATCH_SIZE != 0:
            batches += 1

        # score each batch
        print("Encoding")
        for i in range(batches):
            cur_tokens = tokens[i*self.BATCH_SIZE:(i + 1)*self.BATCH_SIZE]
            cur_output_scores = torch.zeros((cur_tokens.shape[0], cur_tokens.shape[1]-1))
            past = None
            print(i, "out of", batches)

            for j in range(cur_tokens.shape[1]-1):
                cur_output_scores[:, j], past = self.encode_one_batch(cur_tokens, j, past)
                print("encoded: ",j, "out of", cur_tokens.shape[1]-1)
            output_scores[i*self.BATCH_SIZE:(i + 1)*self.BATCH_SIZE] = cur_output_scores

        output_scores = output_scores.flatten().int()
        if pad_len > 0:
            output_scores = output_scores[:-pad_len]
        return output_scores

    def decode(self, scores):
        output_tokens = self.decode_tokens(scores)
        output_tokens = output_tokens.unsqueeze(0)
        print("Decoded tokens: ", output_tokens)
        text = self.tokenizer.batch_decode(output_tokens)
        #text = "".join(text)
        #text = text.replace("<|endoftext|>", "")
        return text[0]

    def decode_tokens(self, scores):

        scores, pad_len = self.pad(scores, self.tokenizer.eos_token_id)

        scores = scores.view(-1, self.CONTEXT_SIZE) # all rows, CONTEXT_SIZE

        output_tokens = torch.zeros(scores.shape, dtype=int)


        # Add eos to the start of each block (to give it somewhere to start)
        eos = torch.tensor([self.tokenizer.eos_token_id]*output_tokens.shape[0]).unsqueeze(1)
        output_tokens = torch.cat((eos, output_tokens), 1) # all rows, CONTEXT_SIZE + 1

        output_tokens = output_tokens.to(self.device)

        batches = scores.shape[0]//self.BATCH_SIZE
        if scores.shape[0] % self.BATCH_SIZE != 0:
            batches += 1

        # score each batch
        print("Decoding")
        for i in range(batches):
            print(i, "out of", batches)
            cur_scores = scores[i*self.BATCH_SIZE:(i + 1)*self.BATCH_SIZE] # BATCH_SIZE, CONTEXT_SIZE

            cur_output_tokens = output_tokens[i*self.BATCH_SIZE:(i + 1)*self.BATCH_SIZE] # BATCH_SIZE, CONTEXT_SIZE
            cur_output_tokens = cur_output_tokens.to(self.device)
            past = None
            for j in range(scores.shape[1]):

                cur_output_tokens[:, j+1], past = self.decode_one_batch(cur_output_tokens, cur_scores, j, past) # BATCH_SIZE

            output_tokens[i*self.BATCH_SIZE:(i + 1)*self.BATCH_SIZE] = cur_output_tokens



        output_tokens = output_tokens[:, 1:].int()
        output_tokens = output_tokens.flatten()

        if pad_len != 0:
            output_tokens = output_tokens[:-pad_len]

        return output_tokens



    def encode_and_zip(self, text):
        encoded = self.encode(text)
        encoded = array.array("H", encoded)
        return zlib.compress(encoded, level=9)


    def unzip_and_decode(self, zipped):
        unzipped = zlib.decompress(zipped)
        print("unzipped sequence: ", unzipped)
        unzipped = array.array("H", unzipped)
        print("unzipped array: ", unzipped)
        print("len: ", len(unzipped))
        decoded = self.decode(torch.tensor(unzipped))
        print("decoded: ", decoded)
        return decoded

    def zip_file(self, text_file, zip_file):
        with open(text_file, encoding="utf-8") as f:
            text = f.read()

        zipped = self.encode_and_zip(text)

        with open(zip_file, "wb") as f:
            f.write(zipped)

    def unzip_file(self, zip_file, text_file):
        with open(zip_file, "rb") as f:
            zipped = f.read()
        text = self.unzip_and_decode(zipped)

        with open(text_file, "w", encoding="utf-8") as f:
            f.write(text)







# Compression

In [12]:
llama_zip = LlamaZip(model, tokenizer)

In [11]:
def evaluation(llama_zip, df):
  #dfの全要素について処理する
  for i, image_name in enumerate(df.index):
    if(df.loc[image_name, "dall_len"] > 0):
      print("skipping ", image_name)
      continue
    text = df.loc[image_name, "overall_description"]
    zipped = llama_zip.encode_and_zip(text)
    df.loc[image_name, "dall_len"] = len(zipped) * 8
    if(i % 50 == 0):
      df.to_csv(f"/root/shared_smurai/df/{config['dataset']}_overall_description.csv")

In [12]:
evaluation(llama_zip, df)

skipping  0
skipping  1
skipping  2
skipping  3
skipping  4
skipping  5
skipping  6
skipping  7
skipping  8
skipping  9
skipping  10
skipping  11
skipping  12
skipping  13
skipping  14
skipping  15
skipping  16
skipping  17
skipping  18
skipping  19
skipping  20
skipping  21
skipping  22
skipping  23
skipping  24
skipping  25
skipping  26
skipping  27
skipping  28
skipping  29
skipping  30
skipping  31
skipping  32
skipping  33
skipping  34
skipping  35
skipping  36
skipping  37
skipping  38
skipping  39
skipping  40
skipping  41
skipping  42
skipping  43
skipping  44
skipping  45
skipping  46
skipping  47
skipping  48
skipping  49
skipping  50
skipping  51
skipping  52
skipping  53
skipping  54
skipping  55
skipping  56
skipping  57
skipping  58
skipping  59
skipping  60
skipping  61
skipping  62
skipping  63
skipping  64
skipping  65
skipping  66
skipping  67
skipping  68
skipping  69
skipping  70
skipping  71
skipping  72
skipping  73
skipping  74
skipping  75
skipping  76
skipping 

In [13]:
df.to_csv(f"{config['dataset']}_llava_1.6_vicuna.csv")

# Ablation Study

In [15]:
import zlib
df['zlib_len'] = df['overall_description'].apply(lambda x: len(zlib.compress(x.encode('utf-8'))) * 8)

# 結果を表示
print(df[['overall_description', 'zlib_len']].sum())

overall_description    The image presents a modern and stylish bar se...
zlib_len                                                          598960
dtype: object


In [18]:
zlib_len_sum = df['zlib_len'].sum()

In [19]:
print(zlib_len_sum)

598960


In [20]:
zlib_len_sum / overall_description_length_sum * 100

56.72743252867816