In [2]:
import logging
import math
import os
import random
import re
import sys
import time
from argparse import Namespace
from collections import namedtuple

import numpy as np
import sentencepiece as spm
import torch
from fairseq_cli.generate import get_symbols_to_strip_from_output
from PIL import Image
from torchvision import transforms

from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU when CUDA is usable; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# Plain record types; field names are listed explicitly, one per entry.
Batch = namedtuple(
    "Batch",
    [
        "ids",
        "src_tokens",
        "src_lengths",
        "constraints",
        "img_src_tokens",
        "img_gpt_input_mask",
        "img_path_batch",
    ],
)
Translation = namedtuple(
    "Translation",
    ["src_str", "hypos", "pos_scores", "alignments"],
)

In [11]:
# Send log records to stdout; the level is taken from the LOGLEVEL
# environment variable and defaults to INFO when it is unset.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("fairseq_cli.interactive")

In [12]:
def square_transform(size=224):
    """Build the preprocessing pipeline that turns an image into a model input.

    The pipeline resizes the image to ``size`` x ``size`` with bicubic
    interpolation, converts it to a tensor, and normalizes each channel with
    fixed mean/std constants.

    Args:
        size (int): Target height and width in pixels.

    Returns:
        A ``transforms.Compose`` applying resize, tensor conversion and
        normalization in that order.
    """
    channel_mean = [0.48145466, 0.4578275, 0.40821073]
    channel_std = [0.26862954, 0.26130258, 0.27577711]
    normalize = transforms.Compose(
        [transforms.Normalize(mean=channel_mean, std=channel_std)]
    )

    resize = transforms.Resize(
        (size, size), interpolation=transforms.InterpolationMode.BICUBIC
    )
    return transforms.Compose([resize, transforms.ToTensor(), normalize])

def split_string(string, separators):
    """
    Split a string on any of the given separators, keeping the separators.

    Args:
    string (str): The input string to be split.
    separators (list): Separator strings to split on. Empty separators are
        ignored; if no usable separator remains, the string is not split.

    Returns:
    A list of the non-empty pieces of ``string`` in their original order,
    with every matched separator included as its own element.
    """
    # An empty separator (or an empty list) would produce a pattern that
    # matches the empty string, which makes re.split cut between every
    # character. Drop such separators up front.
    usable = [separator for separator in separators if separator]
    if not usable:
        return [string] if string else []

    pattern = "|".join(re.escape(separator) for separator in usable)
    # The capturing group makes re.split return the separators as well.
    pieces = re.split(f"({pattern})", string)
    return [piece for piece in pieces if piece]

def get_interactive_tokens_and_lengths(self, lines, encode_fn, tokenizer=None):
    """
    Convert interactive input lines into token ids, image tensors and masks.

    line format: [image]path<tab>text<tab>[image]path
    model input: `<s> <image> image hidden </image> My cat looking very dignified.</s>`

    Args:
        self: Task-like object. This function reads ``args.image_feature_length``,
            ``args.input_resolution``, ``dictionary`` and ``source_dictionary``
            from it.
        lines (list[str]): Input lines whose segments are separated by the
            literal text ``<tab>``. Segments starting with ``[image]`` are
            image paths; every other segment is text.
        encode_fn (callable): Applied to each (pre-tokenized) text segment
            before it is encoded with ``source_dictionary``.
        tokenizer: Optional SentencePiece-style processor providing
            ``vocab_size()`` and ``encode(text, out_type=str)``. When ``None``,
            text segments are handed to ``encode_fn`` unchanged.

    Returns:
        tuple: ``(tokens, lengths, img_src_tokens, img_gpt_input_masks)`` where
        each list has one entry per input line: a ``LongTensor`` of token ids,
        its element count, the stacked image tensors of the line, and a
        ``LongTensor`` mask holding 1 at image-feature positions and 0 elsewhere
        (the mask is one element shorter than the tokens; it has no entry for
        the trailing end-of-sentence id).

    Note:
        Every line must contain at least one ``[image]`` segment; ``torch.stack``
        fails on an empty list of image tensors.
    """
    image_feature_length = self.args.image_feature_length
    bos_id = self.dictionary.bos()
    eos_id = self.dictionary.eos()
    boi_id = self.dictionary.index("<image>")
    eoi_id = self.dictionary.index("</image>")
    image_prefix = "[image]"

    # Dictionary entries beyond the tokenizer's own vocabulary are treated as
    # special tokens that must be kept whole. They depend only on the
    # tokenizer and the dictionary, so build them once instead of per segment.
    if tokenizer is not None:
        special_tokens = [
            self.source_dictionary[idx]
            for idx in range(tokenizer.vocab_size(), len(self.source_dictionary))
        ]
    else:
        special_tokens = []
    special_token_set = set(special_tokens)  # O(1) membership checks

    def pretokenize_text(segment):
        """Return `segment` as space-joined pieces, keeping special tokens whole."""
        if tokenizer is None:
            # No tokenizer available: leave tokenization to encode_fn.
            return segment
        # Splitting on an empty separator list would cut between characters.
        pieces = split_string(segment, special_tokens) if special_tokens else [segment]
        words = []
        for piece in pieces:
            if piece in special_token_set:
                words.append(piece)
            else:
                words.extend(tokenizer.encode(piece, out_type=str))
        return " ".join(words)

    def convert_one_line(input_str):
        # TODO: input interleave image and text
        token = [bos_id]
        img_src_token = []
        img_gpt_input_mask = [0]
        for segment in input_str.split("<tab>"):
            if segment.startswith(image_prefix):
                image_path = segment[len(image_prefix):]
                # Close the file handle once the pixel data has been converted.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert("RGB")
                img_src_token.append(square_transform(self.args.input_resolution)(image))
                # Placeholder ids 4 .. image_feature_length + 3 occupy the
                # positions that the mask marks with 1, wrapped in the
                # <image> / </image> ids.
                token.extend([boi_id] + list(range(4, image_feature_length + 4)) + [eoi_id])
                img_gpt_input_mask.extend([0] + [1] * image_feature_length + [0])
            else:
                text_tokens = self.source_dictionary.encode_line(
                    encode_fn(pretokenize_text(segment)), add_if_not_exist=False
                ).tolist()
                text_tokens = text_tokens[:-1]  # drop the trailing </s> id
                token.extend(text_tokens)
                img_gpt_input_mask.extend([0] * len(text_tokens))
        token.append(eos_id)
        # The final </s> has no mask entry, hence the off-by-one.
        assert len(token) == len(img_gpt_input_mask) + 1
        token = torch.LongTensor(token)
        img_gpt_input_mask = torch.LongTensor(img_gpt_input_mask)
        img_src_token = torch.stack(img_src_token, dim=0)
        return token, img_src_token, img_gpt_input_mask

    tokens = []
    img_src_tokens = []
    img_gpt_input_masks = []
    for src_str in lines:
        token, img_src_token, img_gpt_input_mask = convert_one_line(src_str)
        tokens.append(token)
        img_src_tokens.append(img_src_token)
        img_gpt_input_masks.append(img_gpt_input_mask)
    lengths = [t.numel() for t in tokens]

    return tokens, lengths, img_src_tokens, img_gpt_input_masks


def make_batches(lines, cfg, task, max_positions, encode_fn,
                 spm_model_path="data/sentencepiece.bpe.model"):
    """
    Tokenize input lines and yield them as ``Batch`` records.

    Args:
        lines (list[str]): Input lines. When ``cfg.generation.constraints`` is
            set, tab-delimited constraints are stripped from the entries and
            the list is modified in place.
        cfg: Config object; this function reads ``generation.constraints``,
            ``dataset.max_tokens``, ``dataset.batch_size`` and
            ``dataset.skip_invalid_size_inputs_valid_test``.
        task: Task object providing ``target_dictionary``,
            ``get_batch_iterator`` and ``build_dataset_for_caption_inference``.
        max_positions: Forwarded to ``task.get_batch_iterator``.
        encode_fn (callable): Text encoding function applied to inputs and
            constraints.
        spm_model_path (str): Location of the SentencePiece model. If the file
            does not exist, a warning is logged and ``tokenizer=None`` is
            passed on to ``get_interactive_tokens_and_lengths``.

    Yields:
        Batch: One record per batch produced by the task's batch iterator.
        ``img_path_batch`` is always ``None`` because image paths are not
        carried through the iterator.
    """
    def encode_fn_target(x):
        return encode_fn(x)

    constraints_tensor = None
    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from input lines,
        # store them in batch_constraints
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Convert each List[str] to List[Tensor]
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
        constraints_tensor = pack_constraints(batch_constraints)

    if os.path.exists(spm_model_path):
        tokenizer = spm.SentencePieceProcessor()
        tokenizer.Load(spm_model_path)
    else:
        # Make the fallback visible instead of silently dropping the tokenizer.
        logging.getLogger("fairseq_cli.interactive").warning(
            "SentencePiece model %r not found; continuing with tokenizer=None",
            spm_model_path,
        )
        tokenizer = None
    tokens, lengths, img_src_tokens, img_gpt_input_mask = get_interactive_tokens_and_lengths(
        task, lines, encode_fn, tokenizer
    )

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_caption_inference(
            tokens, lengths, img_src_tokens, img_gpt_input_mask, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        net_input = batch["net_input"]
        yield Batch(
            ids=batch["id"],
            src_tokens=net_input["src_tokens"],
            src_lengths=net_input["src_lengths"],
            img_src_tokens=net_input["img_src_tokens"],
            img_gpt_input_mask=net_input["img_gpt_input_mask"],
            constraints=batch.get("constraints", None),
            # Batch declares this field without a default; omitting it raises
            # TypeError on the first yield.
            img_path_batch=None,
        )


In [25]:
# Load the generation config from YAML. OmegaConf.load returns an OmegaConf
# container, never an argparse Namespace, so no Namespace conversion is
# needed in this cell.
cfg = OmegaConf.load('../configs/config.yaml')

logger.info(cfg)

2023-07-31 15:05:43 | INFO | fairseq_cli.interactive | {'hydra': {'run': {'dir': '.'}}, 'defaults': ['_self_', {'task': {'_name': 'generation_obj', 'data': 'None', 'sample_break_mode': 'none', 'tokens_per_sample': 1024, 'output_dictionary_size': -1, 'self_target': False, 'future_target': False, 'past_target': False, 'add_bos_token': True, 'max_target_positions': None, 'shorten_method': 'none', 'shorten_data_split_list': '', 'pad_to_fixed_length': False, 'pad_to_fixed_bsz': False, 'seed': 1, 'batch_size': 1, 'batch_size_valid': 1, 'dataset_impl': None, 'data_buffer_size': 10, 'tpu': False, 'use_plasma_view': False, 'plasma_path': '/tmp/plasma', 'required_batch_size_multiple': 1, 'dict_path': '/home/omote/WorkSpace/unilm/kosmos-2/data/dict.txt', 'image_feature_length': 64, 'input_resolution': 1024, 'location_bin_size': 32, 'locate_special_token': 1}}, {'model': None}, {'criterion': 'cross_entropy'}, {'optimizer': None}, {'lr_scheduler': 'fixed'}, {'bpe': None}, {'tokenizer': None}, {'sco

In [16]:
# Display the current config object.
# NOTE(review): this cell's execution count (16) is lower than that of the
# config-loading cell (25), so the saved output may come from a different
# `cfg` than the YAML loaded above — re-run after a kernel restart to confirm.
cfg

{'dict-path': 'data/dict.txt', 'required-batch-size-multiple': 1, 'remove-bpe': 'sentencepiece', 'max-len-b': 500, 'add-bos-token': True, 'beam': 1, 'buffer-size': 1, 'image-feature-length': 64, 'locate-special-token': 1, 'batch-size': 1, 'nbest': 1, 'no-repeat-ngram-size': 3, 'location-bin-size': 32}

In [23]:
# Ask fairseq to import any user-supplied module referenced by the task config.
# NOTE(review): this cell ran (23) before the config-loading cell (25), and the
# config displayed in the next cell has no top-level `task` key — confirm that
# `cfg.task` resolves to what fairseq expects here.
utils.import_user_module(cfg.task)

In [24]:
# Display the config object again to check its contents at this point.
cfg

{'hydra': {'run': {'dir': '.'}}, 'defaults': ['_self_', {'task': '_generation_obj'}, {'model': None}, {'criterion': 'cross_entropy'}, {'optimizer': None}, {'lr_scheduler': 'fixed'}, {'bpe': None}, {'tokenizer': None}, {'scoring': None}, {'generation': None}, {'common_eval': None}, {'eval_lm': None}]}