In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath('../')))

from tqdm import tqdm
import pandas as pd
from utils.file_utils import read_image, read_json, save_image, save_json, read_pkl
from utils.draw_utils import draw_box
from utils.helper_utils import float2_0_1000
from prompts import all_prompts

# Build the box-to-functionality training set (with OCR text) from the AMEX element annotations.

In [2]:
# --- Configuration -----------------------------------------------------------
# Input locations: screenshots, per-screenshot element annotations, the pickled
# image-shape lookup, and the spreadsheet that indexes the annotation files.
img_root = '/home/shaotao/DATA/AMEX/screenshot'
ann_root = '/home/shaotao/DATA/AMEX/element_anno'
shape_pkl_p = './out/amex_img_shapes.pkl'
df_p = './out/amex_info.xlsx'
# Output file for the generated samples (plain string: there is nothing to interpolate).
out_json_p = 'amex_box2func_with_ocr.json'

# Load the index and keep only rows whose `num_func_ann` is positive
# (presumably the number of functionality annotations in that file -- confirm).
df = pd.read_excel(df_p)
filt = df['num_func_ann'] > 0
df = df[filt]

In [None]:
# Sanity check: show the (rows, columns) shape and the column names of `df`.
print(df.shape, df.columns)

In [None]:
import random

# Original (height, width) of every screenshot, keyed by the annotation's 'image_path'.
img_shape_dict = read_pkl(shape_pkl_p)
INIT_PROMPT = all_prompts['box2func_with_ocr_prompt_for_train']
CONTINUE_PROMPT = """## Box
({x1},{y1}),({x2},{y2})

## OCR result
{text}"""

PREVIEW_EVERY = 500  # dump an annotated screenshot for every N-th image (visual check)
LOG_EVERY = 500      # print one sample answer every N collected samples


def _ocr_text_for(ele):
    """Return the OCR text to put in the prompt for ``ele``, or ``'null'``.

    The text is the first entry of ``ele['xml_desc']``, stripped. It is replaced
    by the ``'null'`` placeholder when there is no entry, when the element type
    contains ``'icon'``, when the text contains ``'tab'``, or when nothing but
    whitespace is left after cleaning.
    """
    xml_desc = ele.get('xml_desc') or []  # tolerate a missing key or an explicit None
    text = xml_desc[0].strip() if len(xml_desc) > 0 else 'null'
    if 'icon' in (ele.get('type') or 'text').lower():
        text = 'null'
    # NOTE(review): this is a substring match, so it also discards text such as
    # "table" or "establish" -- confirm that is intended.
    if 'tab' in text.lower():
        text = 'null'
    # Keep the prompt on one line and drop zero-width spaces.
    text = text.replace('\n', ' ').replace('\u200b', '')
    # str.strip() does not remove U+200B, so the text can become empty only here;
    # use the same placeholder as for a missing description.
    if not text.strip():
        text = 'null'
    return text


random.seed(42)  # fixed seed so the shuffled turn order is reproducible
all_datas = []
for img_idx, ann_name in enumerate(tqdm(df['filename'].tolist())):
    ann = read_json(os.path.join(ann_root, ann_name))

    rel_img_p = ann['image_path']
    h, w = img_shape_dict[rel_img_p]
    img_p = os.path.join(img_root, rel_img_p)

    dump_preview = img_idx % PREVIEW_EVERY == 0
    if dump_preview:
        img = read_image(img_p)

    # Only elements with a non-empty functionality annotation become turns.
    final_ele_lst = [
        ele for ele in ann['clickable_elements']
        if (ele.get('functionality') or '').strip()
    ]
    random.shuffle(final_ele_lst)

    conversation = []
    bad_box = False
    for ele in final_ele_lst:
        func_ann = (ele.get('functionality') or '').strip()
        text = _ocr_text_for(ele)

        # Normalize the pixel box by the original image size, then map it
        # through float2_0_1000 for the prompt.
        x1, y1, x2, y2 = ele['bbox']
        x1, y1, x2, y2 = x1 / w, y1 / h, x2 / w, y2 / h
        try:
            pt = list(map(float2_0_1000, [x1, y1, x2, y2]))
        except Exception as e:
            print('idx: ', img_idx, e)
            # Skip the whole image: the flag makes sure the turns collected so
            # far are discarded instead of being saved as a partial sample.
            bad_box = True
            break

        # The first emitted turn carries the full instruction prompt, later
        # turns only the box and OCR text.
        template = CONTINUE_PROMPT if conversation else INIT_PROMPT
        prompt = template.format(x1=pt[0], y1=pt[1], x2=pt[2], y2=pt[3], text=text)
        conversation.append({'from': 'human', 'value': prompt})
        conversation.append({'from': 'gpt', 'value': func_ann})

        if dump_preview:
            # NOTE(review): the box passed here is the normalized one (0-1 range);
            # confirm draw_box does not expect pixel coordinates.
            img = draw_box(img, (x1, y1, x2, y2))

    if dump_preview:
        save_image(img, f'tmp_{img_idx}.jpg')
    if bad_box or not conversation:
        print('skipping idx: ', img_idx)
        continue

    # img_p is already joined with img_root; do not join it a second time.
    all_datas.append({'conversation': conversation, 'image_lst': [img_p]})
    if len(all_datas) % LOG_EVERY == 0:
        last_ans = conversation[-1]['value']
        print(f'IDX: {len(all_datas)},  sample func_ann: {last_ans}')

print('total data num: ', len(all_datas))
save_json(all_datas, out_json_p)