In [None]:
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

In [None]:
import orjson
from transformers import AutoTokenizer

import os
import cv2
from copy import deepcopy
import random
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from collections import defaultdict, Counter
from tqdm import tqdm

In [None]:
# Hugging Face model id; its tokenizer / chat template is used below to measure
# how long each packed conversation is.
PRETRAINED = "lmms-lab/llama3-llava-next-8b"

# System message placed at the start of every generated conversation.
SYSTEM_PROMPT = (
    "You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language."
)

In [None]:
# Tokenizer of the target model; only used to count chat-template tokens.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=PRETRAINED)

In [None]:
# Load the processed referring-expression records. Enable additional datasets
# by uncommenting their file names in DATA_FILES.
DATA_DIR = '../../data/processed_data_v2'
DATA_FILES = [
    'refcoco_data.json',
    # 'ade20k_ref_data.json',
    # 'paco_ref_data.json',
    # 'partimagenet_ref_data.json',
]

all_data = []
for data_file in DATA_FILES:
    # Read raw bytes and let orjson decode them as UTF-8. Text mode without an
    # explicit encoding uses the platform default, which can garble or reject
    # non-ASCII phrases on non-UTF-8 locales.
    with open(os.path.join(DATA_DIR, data_file), 'rb') as f:
        all_data.extend(orjson.loads(f.read()))

In [None]:
# Index the records by image path:
#   img_grouped_data[path] -> every record for that image, in load order
#   img_bboxs[path]        -> first box (as a tuple) of each record that has boxes
grouped_by_image = defaultdict(list)
first_boxes_by_image = defaultdict(list)
for record in all_data:
    path = record['image_path']
    grouped_by_image[path].append(record)
    if len(record['bboxes']) > 0:
        first_boxes_by_image[path].append(tuple(record['bboxes'][0]))

# Convert back to plain dicts so later lookups of missing keys still raise.
img_grouped_data = dict(grouped_by_image)
img_bboxs = dict(first_boxes_by_image)

In [None]:
# Number of distinct images, and number of images with at least one boxed record.
n_images, n_images_with_boxes = len(img_grouped_data), len(img_bboxs)
n_images, n_images_with_boxes

In [None]:
# Pick one image (and one of its references) for manual inspection. The indices
# are clamped so this cell does not raise IndexError when fewer images or
# references are loaded, e.g. when only a subset of the datasets is enabled.
INSPECT_IMAGE_IDX = 777
INSPECT_REF_IDX = 10
sample_group = list(img_grouped_data.values())[min(INSPECT_IMAGE_IDX, len(img_grouped_data) - 1)]
d = sample_group[min(INSPECT_REF_IDX, len(sample_group) - 1)]
len(sample_group)

In [None]:
# Referring phrase(s) of every record in the inspected image group.
[record['phrases'] for record in sample_group]

In [None]:
# Answer-count tag -> natural-language description. The conversation generator
# does not consult this mapping; it always emits the '0+' tag.
count_dict = dict([
    ('1', 'single answer'),
    ('1+', 'maybe multiple answers'),
    ('0+', 'maybe no or multiple answers'),
])

In [None]:
def bbox_to_str(bbox):
    """Render the first four values of ``bbox`` as ``[aaa,bbb,ccc,ddd]``.

    Each value is zero-padded to at least three digits. Values must be
    integers, because the ``'03d'`` format rejects floats. Fewer than four
    values raises ``IndexError``.
    """
    coords = (bbox[0], bbox[1], bbox[2], bbox[3])
    return "[" + ",".join(format(c, "03d") for c in coords) + "]"

def point_to_str(point):
    """Render the first two values of ``point`` as ``(aaa,bbb)``, zero-padded to three digits."""
    coords = (point[0], point[1])
    return "(" + ",".join(format(c, "03d") for c in coords) + ")"

In [None]:
# Pack the referring samples of each image into multi-turn conversations.
# Every conversation starts with the system prompt, and its first user turn
# carries the <image> placeholder. Samples are appended until MAX_PACKING is
# reached or the tokenized chat would exceed MAX_TOKENS.
MAX_PACKING = 5            # max referring samples per conversation
MAX_TOKENS = 1536          # token budget per conversation (chat template included)
MAX_CONVS_PER_IMAGE = 20   # cap on conversations generated per image
N_POINTS_MEAN, N_POINTS_STD, N_POINTS_MAX = 10, 4, 20  # points asked per box ~ N(10, 4), clipped to [1, 20]


def build_ref_conv(ref_sample, with_image):
    """Build the chat turns for one referring sample.

    Args:
        ref_sample: a record from ``all_data``. Reads ``'phrases'`` (a string,
            or a list from which one phrase is sampled), ``'bboxes'`` and
            ``'points_and_labels'``.
        with_image: prefix the box question with the ``<image>`` placeholder.

    Returns:
        A list of ``{"role", "content"}`` messages: one box-grounding Q/A pair,
        then one point-verification Q/A pair per box that has labelled points.
    """
    phrases = ref_sample['phrases']
    s_phrase = random.choice(phrases) if isinstance(phrases, list) else phrases

    bboxes = ref_sample['bboxes']
    points_and_labels = ref_sample['points_and_labels']

    answer_counts_str = '0+'
    question_box = '<image>\n' if with_image else ''
    question_box += f'Please provide the bounding box coordinate of the region this sentence describes ({answer_counts_str}):\n"{s_phrase}".'
    if len(bboxes) == 0:
        answer_box = 'No object found.'
    else:
        answer_box = ' '.join(bbox_to_str(x) for x in bboxes)

    ref_conv = [
        {"role": "user", "content": question_box},
        {"role": "assistant", "content": f'\n{answer_box}'},
    ]

    # NOTE(review): zip() silently drops entries if bboxes and points_and_labels
    # differ in length — confirm the data guarantees they match.
    bb_pnls = list(zip(bboxes, points_and_labels))
    random.shuffle(bb_pnls)
    for bbox, p_n_ls in bb_pnls:
        if len(p_n_ls) == 0:
            # No labelled points for this box: nothing to ask, and
            # random.sample would raise on an empty population.
            continue
        n_sel_points = random.normalvariate(N_POINTS_MEAN, N_POINTS_STD)
        n_sel_points = int(max(1, min(N_POINTS_MAX, n_sel_points)))
        # random.sample raises ValueError when asked for more items than exist.
        n_sel_points = min(n_sel_points, len(p_n_ls))
        sampled_points_and_labels = random.sample(p_n_ls, n_sel_points)

        # In each entry, the first two values are rendered as the point and the
        # third value's truthiness is the expected Yes/No answer.
        points_txt = ' '.join(point_to_str(x[:2]) for x in sampled_points_and_labels)
        question_points = 'Check if the points listed below are located on the object with bounding box {}:\n{}'.format(
            bbox_to_str(bbox), points_txt)
        answer_points = ''.join('Yes' if x[2] else 'No' for x in sampled_points_and_labels)

        ref_conv.extend([
            {"role": "user", "content": question_points},
            {"role": "assistant", "content": f'\n{answer_points}'},
        ])
    return ref_conv


all_convs = []
n_skipped_oversized = 0
for img_path, group in tqdm(img_grouped_data.items()):
    # Shuffle a copy so img_grouped_data keeps its order. The records are only
    # read, never modified, so a shallow copy is enough.
    samples = list(group)
    random.shuffle(samples)

    next_idx = 0
    for _ in range(MAX_CONVS_PER_IMAGE):
        if next_idx >= len(samples):
            break

        img_conv = [{"role": "system", "content": SYSTEM_PROMPT}]
        for i_sample in range(next_idx, min(next_idx + MAX_PACKING, len(samples))):
            is_first = len(img_conv) == 1
            ref_conv = build_ref_conv(samples[i_sample], with_image=is_first)
            n_tokens = len(tokenizer.apply_chat_template(img_conv + ref_conv, tokenize=True))
            if n_tokens > MAX_TOKENS:
                if is_first:
                    # The sample does not fit even on its own. Skip it;
                    # otherwise every remaining iteration would retry it and
                    # emit a system-prompt-only conversation.
                    n_skipped_oversized += 1
                    next_idx = i_sample + 1
                break
            img_conv.extend(ref_conv)
            next_idx = i_sample + 1

        if len(img_conv) > 1:  # keep only conversations with at least one sample
            all_convs.append({
                'image_path': img_path,
                'conversation': img_conv,
            })

len(all_convs), n_skipped_oversized

In [None]:
# Serialize all packed conversations. orjson.dumps already returns UTF-8 bytes,
# so write them in binary mode. Decoding them and re-encoding through a text-mode
# file would use the platform default encoding, which can fail on non-ASCII phrases.
OUTPUT_PATH = './refcoco_convs_ep1.json'
with open(OUTPUT_PATH, 'wb') as f:
    f.write(orjson.dumps(all_convs))