In [None]:
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

In [None]:
import os, pickle, io, base64, json

from glob import glob
from tqdm.auto import tqdm

from PIL import Image

import torch as T
import transformers

from llava.conversation import conv_templates
from llava.model import *

def f2b(f):
    """Serialize an image to JPEG and return the bytes as base64 text.

    `f` only needs a `save(fp, format=...)` method; in this notebook it is
    always a PIL image. The return value is a plain `str` (no `b'...'`
    wrapper), so it can be written directly into a tab-separated text file.
    """
    buffer = io.BytesIO()
    f.save(buffer, format='JPEG')
    # base64 output is pure ASCII, so decoding yields the same text that
    # stripping the b'...' repr wrapper would.
    return base64.b64encode(buffer.getvalue()).decode('ascii')
def b2f(b):
    """Inverse of `f2b`: decode base64 text `b` and return it as an RGB PIL image."""
    raw_bytes = base64.b64decode(b)
    picture = Image.open(io.BytesIO(raw_bytes))
    return picture.convert('RGB')
def crop_resize(f, sz=512):
    """Center-crop image `f` to a square, then resize it to `sz` x `sz`.

    The crop keeps the full shorter side and trims the longer side equally
    from both ends (integer division, so an odd surplus leaves the extra pixel
    on the far side). Images that are already square are only resized.
    """
    width, height = f.size
    if width != height:
        side = min(width, height)
        left, top = (width-side)//2, (height-side)//2
        f = f.crop([left, top, left+side, top+side])
    return f.resize([sz, sz])
def remove_alter(s):  # hack expressive instruction
    """Trim a decoded generation down to a short expressive instruction.

    Steps, applied in order:
      1. keep only the text after the first 'ASSISTANT:' marker (if present);
      2. drop everything from the first '</s>' onwards;
      3. drop everything from the first (case-insensitive) 'alternative';
      4. drop everything from the first '[IMG0]';
      5. keep at most the first two '.'-separated pieces, each stripped and
         re-joined with '.' (no space is inserted after the period);
      6. make sure the result ends with a period.

    Returns an empty string when nothing is left after trimming (for example
    when the reply starts with "Alternatively, ..."); previously this case
    raised IndexError on `s[-1]`.
    """
    marker = 'ASSISTANT:'
    if marker in s: s = s[s.index(marker)+len(marker):].strip()
    if '</s>' in s: s = s[:s.index('</s>')].strip()
    if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]
    if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
    s = '.'.join(piece.strip() for piece in s.split('.')[:2])
    if not s: return ''  # nothing usable survived the trimming
    if not s.endswith('.'): s += '.'
    return s.strip()

In [None]:
# Placeholder strings used to mark where the image goes inside the text prompt.
# DEFAULT_IMAGE_TOKEN is defined for completeness; it is not referenced in the
# cells visible in this notebook.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'

# Location of the LLaVA checkpoint; expanduser lets the path start with '~'.
MODEL_NAME = './_ckpt/LLaVA-7B-v1'
model_name = os.path.expanduser(MODEL_NAME)

# Load tokenizer, the multimodal LLM (fp16, moved to the GPU) and the CLIP image
# preprocessor named by the model config.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)

tokenizer.padding_side = 'left'

# Register the image placeholder tokens with the tokenizer. The start/end
# markers are only added when the checkpoint's config asks for them
# (defaults to False when the attribute is missing).
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Swap the model's vision tower for a CLIPVisionModel loaded in fp16 on the GPU
# from the name/path recorded in the existing tower's config.
# NOTE(review): presumably needed because the checkpoint does not carry usable
# vision-tower weights - confirm against the LLaVA version in use.
vision_tower = model.get_model().vision_tower[0]
vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
model.get_model().vision_tower[0] = vision_tower
# Store the placeholder token ids on the vision config.
# NOTE(review): presumably the model reads these to locate image positions in
# the input ids - confirm in llava.model.
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
# Number of patch placeholders per image: (image side // patch side) squared.
image_token_len = (vision_config.image_size//vision_config.patch_size)**2

# Switch to evaluation mode; assigning to `_` keeps the notebook from echoing
# the returned module.
_ = model.eval()

In [None]:
# Load the summarization model: a Hugging Face 'summarization' pipeline for
# 'jordiclive/flan-t5-11b-summarizer-filtered', in bfloat16 on device 0.
# It is used later to condense the LLaVA output into a short instruction.
summer = transformers.pipeline('summarization', 'jordiclive/flan-t5-11b-summarizer-filtered', torch_dtype=T.bfloat16, device=0)

In [None]:
# Build the dataset files from every `_data/<folder>/prompt.json`:
#   ./_data/ipr2pr.tsv              one line per sample: base64(input JPEG) \t base64(goal JPEG)
#   ./_data/ipr2pr.pkl              {'task': [[{'input', 'answer', 'instruction', 'lineidx'}], ...]}
#   ./_data/ipr2pr_expressive.json  {item: {'instruction', 'expressive'}}
# `lineidx` is the tsv position (as returned by tell()) where the sample's line starts.
pkl, ei = {'task': []}, {}

def _load_rgb(path):
    """Read the image at `path` as RGB and release the underlying file handle."""
    with Image.open(path) as im:
        return im.convert('RGB')

# Sorted so that the order of samples is reproducible between runs
# (glob itself returns files in arbitrary order).
lst = sorted(glob('_data/*/prompt.json'))
with open('./_data/ipr2pr.tsv', 'w') as tsv:
    for file in tqdm(lst):
        with open(file, 'r') as fp:
            prompt = json.load(fp)

        # New instruction obtained by inserting the edit prompt into the question template.
        txt = "what will this image be like if '%s'  (in a short paragraph)"%(prompt['edit'])
        # Image placeholder: only wrap the patch tokens in <im_start>/<im_end> when the
        # checkpoint uses them - otherwise those markers were never added to the tokenizer.
        patches = DEFAULT_IMAGE_PATCH_TOKEN*image_token_len
        if mm_use_im_start_end: patches = DEFAULT_IM_START_TOKEN+patches+DEFAULT_IM_END_TOKEN
        txt = txt+'\n'+patches
        conv = conv_templates['vicuna_v1_1'].copy()
        conv.append_message(conv.roles[0], txt)
        conv.append_message(conv.roles[1], None)
        # Tokenize the full conversation prompt once per prompt.json; the tensors are
        # shared by every image in the folder, so they are moved to the GPU here
        # rather than inside the image loop.
        enc = tokenizer(conv.get_prompt())
        ids = T.as_tensor(enc['input_ids']).unsqueeze(dim=0).cuda()
        mask = T.as_tensor(enc['attention_mask']).unsqueeze(dim=0).cuda()

        folder = os.path.dirname(file)
        for path in sorted(glob(os.path.join(folder, '*_0.jpg'))):
            # Sample id: '<folder>_<image stem>', where the stem always ends in '_0'.
            item = os.path.basename(folder)+'_'+os.path.splitext(os.path.basename(path))[0]
            # Only the trailing '_0' is swapped for '_1'; a blanket str.replace would
            # also rewrite any other '_0' inside the folder or file name.
            answer = item[:-len('_0')]+'_1'
            inp, ans = _load_rgb(path), _load_rgb(path[:-len('_0.jpg')]+'_1.jpg')

            # Run the input image through the CLIP image processor.
            pix = image_processor.preprocess(inp, return_tensors='pt')['pixel_values'][0]
            with T.inference_mode():
                # model = LlavaLlamaForCausalLM -> multimodal LLM, text output
                out = model.generate(ids, images=pix.half().unsqueeze(dim=0).cuda(), attention_mask=mask,
                                     do_sample=False, max_new_tokens=1024)[0].tolist()

                out = remove_alter(tokenizer.decode(out))
                res = summer(['summarize the following paragraph in 32 words:\n\n%s'%(out)], num_beams=5, min_length=5, max_length=64,
                             do_sample=False, no_repeat_ngram_size=3, truncation=True)[0]['summary_text']

            # Record input id, goal id, instruction and where the sample's tsv line starts.
            pkl['task'].append([{'input': item, 'answer': answer, 'instruction': prompt['edit'], 'lineidx': tsv.tell()}])
            # Store the input image and the goal image.
            tsv.write('%s\t%s\n'%(f2b(inp), f2b(ans)))
            # Store the original instruction together with the generated expressive one.
            ei[item] = {'instruction': prompt['edit'], 'expressive': res}


# Persist the collected metadata; the tsv was flushed and closed by the `with` above.
with open('./_data/ipr2pr.pkl', 'wb') as fp:
    pickle.dump(pkl, fp)
with open('./_data/ipr2pr_expressive.json', 'w') as fp:
    json.dump(ei, fp, indent=2)

In [None]:
# Check that everything was saved correctly: reload the three files and, for every
# task, seek to its recorded tsv position, decode both images and show them next to
# the stored instructions.
# NOTE: pickle.load can execute arbitrary code - only load a pkl produced by this notebook.
with open('./_data/ipr2pr.pkl', 'rb') as fp:
    pkl = pickle.load(fp)
with open('./_data/ipr2pr_expressive.json', 'r') as fp:
    ei = json.load(fp)
with open('./_data/ipr2pr.tsv', 'r') as tsv:
    for task in pkl['task']:
        task = task[0]  # each entry is a one-element list wrapping the task dict
        tsv.seek(task['lineidx'])
        b = tsv.readline().strip().split('\t')  # [base64 input image, base64 goal image]
        print(task)
        display(b2f(b[0]))
        display(b2f(b[1]))
        print(ei[task['input']])
        print('\n-----\n')