In [None]:
import os
import sys
import pandas as pd

# Pin this kernel to GPU 5. CUDA_VISIBLE_DEVICES is only honoured if it is set
# before CUDA is initialised, so this must run before any model is loaded.
# (A `!export CUDA_VISIBLE_DEVICES=5` line would be a no-op: `!` commands run in
# a throwaway subshell and cannot change the kernel's own environment.)
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

In [None]:
# Show long free-text cells (prompts, LLM answers) up to 1000 characters.
pd.options.display.max_colwidth = 1000

In [None]:
from tqdm.auto import tqdm

## Parallel exec

### Load data

In [None]:
import pandas as pd
from pathlib import Path

# Root of the test split. (The train split lives under
# /data/vkarlov/robotics/train_dataset/, with images in train_images/.)
TEST_DATASET_DIR = Path('/home/edamirov/notebooks/ml_hack/playground_pipeline/test_dataset')

df = pd.read_json(TEST_DATASET_DIR / 'test_dataset.json')
# Build absolute image paths the same way the LLaVA stage does
# (test_images/<file name>), so both stages agree on where the images live.
df['image'] = [str(TEST_DATASET_DIR / 'test_images' / Path(name).name) for name in df['image']]

In [None]:
# Sanity check: (rows, columns) of the loaded test set.
df.shape

### Infer first stage

In [None]:
# Stage-1 system prompt: the text LLM must NOT plan yet, it only asks 3-5 short
# questions about the scene. Those questions later become the `goal` that LLaVA
# answers from the image (SorterDataset -> preprocess_text_to_inp_ids).
# NOTE: the text (typos and trailing spaces included) reaches the model verbatim;
# editing it changes model behavior.
ROBOT_SYSTEM = """
You are helpful robot with arm operating in a house. You will be asked to do various tasks and you should tell me the sequence of actions you would do to accomplish my task. 
You have 3 possible actions: "pick_up(from, what)", "move_to(where, what)" and "put(where, what)". 

You should ask from 3 to 5 SHORT questions about surroundings to make precise instructions.
Pay special attention to whether you are holding any object already, objects locations and robot current location and state - ask relevant questions about all this things.
You must work only with objects and locations mentioned in request!
IMPORTANT: you must take into account whether you need to move your arm or yourself and if neccessary item is already in your arm!
DO NOT SOLVE TASK OR LIST INSTRUCTIONS, JUST ASK QUESTIONS!
Example (DON'T REPEAT THIS QUESTIONS!):

USER: 
How would you take a cucumber from the table and put it in an orange box?
ASSISTANT: 
1. Is there anything robot holding now? 
2. What robot is holding?
3. Is orange box far away or nearby?
4. Is there anything is the box?
""".strip()

In [None]:
import torch
from transformers import pipeline

# Text LLM shared by both text stages (its outputs are named `mistral_*` below).
# batch_size / max_new_tokens are fixed here for every `pipe(...)` call.
pipe = pipeline(
    task="text-generation",
    model="HuggingFaceH4/zephyr-7b-alpha",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    batch_size=48,
    max_new_tokens=256,
)

In [None]:
def _stage1_prompt(goal_text):
    """Render one goal as a chat-template prompt under the stage-1 system prompt."""
    chat = [
        {"role": "system", "content": ROBOT_SYSTEM},
        {"role": "user", "content": goal_text},
    ]
    return pipe.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)


query_prompts_1_stage = [_stage1_prompt(goal) for goal in tqdm(df['goal_eng'].to_list())]

In [None]:
# Stage 1: generate clarifying questions for every goal (the pipeline batches them).
mistral_1_stage_res = pipe(query_prompts_1_stage)

In [None]:
# generated_text holds the prompt followed by the reply; keep only the text
# between the first assistant marker and any following one.
_ASSISTANT_TAG = '<|assistant|>\n'
mistral_questions = [
    generations[0]['generated_text'].split(_ASSISTANT_TAG)[1]
    for generations in mistral_1_stage_res
]

In [None]:
# Attach the stage-1 questions; prompts were built from df['goal_eng'] in row order.
df['mistral_questions'] = mistral_questions

In [None]:
# df.to_csv('df_mistral_questions_TEST.tsv', sep='\t', index=False)

### Parallel (batched) LLaVA inference

In [None]:
import pandas as pd
import os
from datasets import load_dataset
from PIL import Image
import itertools
from transformers import ViltProcessor

# GPU pinning is repeated so this section can also be started on its own. It only
# takes effect if CUDA has not been initialised yet in this kernel (when run from
# the top, the stage-1 pipeline has most likely initialised it already).
# A `!export ...` line is intentionally not used: it runs in a subshell and
# cannot change the kernel's environment.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

import torch
print(f'visible CUDA devices: {torch.cuda.device_count()}')

import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
import json
import numpy as np

In [None]:
from dataclasses import dataclass, field
from torch.utils.data import Dataset

In [None]:
from typing import List

@dataclass
class Step:
    """One plan step, e.g. action="move_to", arguments=["table", "cucumber"].

    `text` holds the raw step string it was parsed from; steps without
    arguments (like done()) carry only `text`.
    """
    action: str = ""
    text: str = ""
    arguments: List[str] = field(default_factory=list)

@dataclass
class SorterTask:
    """A single robot task: request text, scene image and (optionally) its plan.

    Field order (action, image, text, goal, task_type, plan_id, steps,
    arguments) defines the positional constructor.
    """
    action: str = ""
    image: str = ""    # path or URL of the scene image
    text: str = ""     # plan text, e.g. '1. move_to("table", "cucumber"), 2. done().'
    goal: str = ""     # request text fed to the model (here: stage-1 questions)
    task_type: int = -1
    plan_id: int = -1
    steps: List["Step"] = field(default_factory=list)
    arguments: List[str] = field(default_factory=list)

    def to_list(self):
        """Return the plan as [[action, [arg, ...]], ...] with argument lists copied."""
        return [[step.action, list(step.arguments)] for step in self.steps]

class SorterDataset(Dataset):
    """Map-style dataset turning table rows into `SorterTask`s.

    Accepts a pandas DataFrame or an already-converted sequence of row dicts
    (``df.to_dict('records')``). Rows must provide the keys
    'mistral_questions', 'image', 'task_type' and 'plan_id'.
    """

    def __init__(self, df: "Union[pd.DataFrame, List[dict]]"):
        # Row access below is positional; indexing a DataFrame with an int would
        # look up a *column* instead, so convert frames to records first.
        if isinstance(df, pd.DataFrame):
            df = df.to_dict('records')
        self._data = df
        self._size = len(self._data)

    def __len__(self):
        return self._size

    def __getitem__(self, idx) -> "SorterTask":
        entry = self._data[idx]
        # Parsing ground-truth plans (entry['plan']) is disabled, so steps stay empty.
        return SorterTask(goal=entry['mistral_questions'],
                          image=entry['image'],
                          steps=[],
                          task_type=entry['task_type'],
                          plan_id=entry['plan_id'])

In [None]:
import torch
import torch.nn.functional as F

from tqdm import tqdm
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import pipeline
from typing import Any, List, Optional


@dataclass
class BaseInput:
    """Model-input wrapper; `PromptProcessor.to_inputs` returns one holding the prompt text."""
    text: Optional[str] = None
    

@dataclass
class BaseOutput:
    """Model-output wrapper; `PromptProcessor.to_task` parses its `text` into a SorterTask."""
    text: Optional[str] = None

In [None]:
# LLaVA checkpoint (hub id / path) used as the vision module.
model_path = "liuhaotian/llava-v1.5-7b"
# Conversation template name; a later cell checks it against the template
# implied by the model name.
conv_mode = 'llava_v1'

In [None]:
model_name = get_model_name_from_path(model_path)
# Load LLaVA with 8-bit weights (4-bit off), presumably to save GPU memory.
(tokenizer, model, image_processor, context_len) = load_pretrained_model(
    model_path=model_path,
    model_name=model_name,
    model_base=None,
    load_4bit=False,
    load_8bit=True,
)

In [None]:
# Pick the LLaVA conversation template. The template implied by the checkpoint
# name is compared with the one configured by hand (`conv_mode`); on a mismatch
# the configured one wins and a warning is printed. (Previously the value was
# compared with itself, so the warning branch could never fire.)
requested_conv_mode = conv_mode
_model_name_lower = model_name.lower()
if 'llama-2' in _model_name_lower:
    inferred_conv_mode = "llava_llama_2"
elif "v1" in _model_name_lower:
    inferred_conv_mode = "llava_v1"
elif "mpt" in _model_name_lower:
    inferred_conv_mode = "mpt"
else:
    inferred_conv_mode = "llava_v0"

if requested_conv_mode is not None and requested_conv_mode != inferred_conv_mode:
    print('[WARNING] the auto inferred conversation mode is {}, while `conv_mode` is {}, using {}'.format(
        inferred_conv_mode, requested_conv_mode, requested_conv_mode))
    conv_mode = requested_conv_mode
else:
    conv_mode = inferred_conv_mode

In [None]:
# System prompt for LLaVA (the robot's "vision module"). It is prepended verbatim,
# leading/trailing newlines included, to every "Request: ..." prompt.
LLAVA_SYSTEM = \
'''
You are acting as vision module of robot with arm.
You are responsible for recognising objects and locations around you.
Answer on all questions based on given picture, mention any details important for robot.

Pay special attention to whether there is something in your arm or not!
Your arm is black, surrounded with red element.
IMPORTANT: Your arm is in the middle of the image!
'''

In [None]:
def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: local file path, or a URL starting with ``http://``/``https://``.
        timeout: seconds to wait for a remote download (unused for local files).

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    # Match the URL scheme exactly; a bare 'http' prefix would also catch local
    # paths such as 'httpdocs/img.png'.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # Image.open is lazy and keeps the file open; the context manager closes it
    # once convert() has produced an in-memory copy.
    with Image.open(source) as image:
        return image.convert('RGB')

def preprocess_text_to_inp_ids(req):
    """Build the LLaVA prompt for one request and tokenize it (with the image token).

    The request is wrapped as "Request: How would you <req>?\\nAnswer:",
    prefixed with the image placeholder and LLAVA_SYSTEM, formatted with the
    `conv_mode` conversation template and tokenized.

    Returns:
        Tensor of token ids moved to CUDA (infer_model stacks these per batch).
    """
    # Strip only trailing whitespace/sentence punctuation before the template
    # adds its own "?"; blindly dropping the last character would cut a letter
    # when the text has no final punctuation.
    request = req.lower().strip().rstrip('.?!')
    inp = f"Request: How would you {request}?\nAnswer:"
    inp = DEFAULT_IMAGE_TOKEN + '\n' + LLAVA_SYSTEM + inp

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
    prompt = conv.get_prompt()
    return tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').cuda()

def pad_ids(input_ids, max_batch_len, pad_value=0):
    """Left-pad a 1-D tensor of token ids to length `max_batch_len`.

    Left padding keeps each prompt's last token at the end of its row, where
    decoder-only generation continues from.

    Args:
        input_ids: 1-D tensor of token ids.
        max_batch_len: target length; must be >= len(input_ids).
        pad_value: fill value for the padding slots (default 0).

    Raises:
        ValueError: if `input_ids` is longer than `max_batch_len` (F.pad would
            otherwise silently crop tokens when given a negative pad).
    """
    pad_len = int(max_batch_len) - len(input_ids)
    if pad_len < 0:
        raise ValueError(f"input of length {len(input_ids)} exceeds max_batch_len={max_batch_len}")
    return F.pad(input_ids, (pad_len, 0), "constant", pad_value)

def create_att_mask(input_ids, max_batch_len):
    """Attention mask for a sequence left-padded (as by `pad_ids`) to `max_batch_len`.

    Returns a 1-D tensor: 0 for each padding slot followed by 1 for each real
    token. It is created directly on `input_ids`' device (the ids are moved to
    CUDA in preprocess_text_to_inp_ids) with an integer dtype, rather than as a
    CPU float tensor.

    Raises:
        ValueError: if `input_ids` is longer than `max_batch_len`.
    """
    n_tokens = len(input_ids)
    pad_len = int(max_batch_len) - n_tokens
    if pad_len < 0:
        raise ValueError(f"input of length {n_tokens} exceeds max_batch_len={max_batch_len}")
    return torch.cat((
        torch.zeros(pad_len, dtype=torch.long, device=input_ids.device),
        torch.ones(n_tokens, dtype=torch.long, device=input_ids.device),
    ))

def infer_model(batch):
    """Run LLaVA on a batch of SorterTasks and return one decoded answer per task.

    Each task contributes its image (`task.image`) and request text
    (`task.goal`). Prompts are left-padded to a common length so the whole
    batch is generated in one call. Sampling is on (temperature 0.2), so
    outputs are not deterministic.

    Returns:
        list[str]: generated text for each task, in batch order.
    """
    images = [load_image(task.image) for task in batch]
    image_tensors = image_processor.preprocess(images, return_tensors='pt')['pixel_values'].half().cuda()

    prompt_ids = [preprocess_text_to_inp_ids(task.goal) for task in batch]
    max_batch_len = max(len(ids) for ids in prompt_ids)
    # Masks are built from the unpadded lengths, then the ids are padded.
    attention_mask = torch.stack([create_att_mask(ids, max_batch_len) for ids in prompt_ids])
    input_ids = torch.stack([pad_ids(ids, max_batch_len) for ids in prompt_ids])

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            images=image_tensors,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            use_cache=True,
        )

    # NOTE(review): skipping the first max_batch_len tokens assumes the installed
    # llava version returns the prompt ids as a prefix of `output_ids`; confirm
    # against the installed version (some releases return only new tokens).
    return [
        tokenizer.decode(output_ids[i, max_batch_len:], skip_special_tokens=True).strip()
        for i in range(len(batch))
    ]


In [None]:
import re
from typing import List, Optional, Union

@dataclass
class ScoringInput(BaseInput):
    """BaseInput plus candidate next steps to score; returned by
    `PromptProcessor.to_inputs` when `options` is given."""
    options: List[str] = field(default_factory=list)


class PromptProcessor:
    """Builds few-shot planning prompts and parses generated plans into Steps.

    Plans are a comma-separated numbered list of calls ending with done(), e.g.
    '1. move_to("table", "cucumber"), 2. pick_up("table", "cucumber"), 3. done().'
    """

    # One numbered step: group 1 is the "<n>. " prefix, group 2 the call text
    # (e.g. 'move_to("table", "cucumber")'; arguments must be double-quoted).
    _STEP_PATTERN = re.compile(r'(\s*\d+\.\s*)(\w+\(("[\w ]+"(,\s)?)*\))*')
    # Identifier-like runs inside a call: the action name, then each argument.
    _ACTION_PATTERN = re.compile(r'\s*([A-Za-z_][A-Za-z_\s]+)')

    def __init__(self, **kwargs) -> None:
        self.TERMINATING_STRING = 'done()'
        self._system_prompt = ""
        self._stop_step_pattern = self._STEP_PATTERN
        # "<n>. done()" with an optional trailing period. The terminator must be
        # escaped: unescaped, "()" is an empty regex group and only "<n>. done("
        # would be matched.
        self._stop_pattern = re.compile(
            r'\d+\.\s*' + re.escape(self.TERMINATING_STRING) + r'\.?')

    @property
    def system_prompt_is_set(self) -> bool:
        return len(self._system_prompt) > 0

    def is_terminating(self, step: Step) -> bool:
        return step.text == self.TERMINATING_STRING

    def build_system_prompt(self, example_tasks: List[SorterTask]) -> str:
        """Build, store and return a few-shot prompt from solved example tasks.

        Side effect: sets `task.text` on every example task.
        """
        prompt = "Robot: Hi there, I’m a robot operating in a house.\n"
        prompt += "Robot: You can ask me to do various tasks and "
        prompt += "I’ll tell you the sequence of actions I would do to accomplish your task.\n"

        for task in example_tasks:
            prompt += self._task_to_prompt(task) + '\n'

        self._system_prompt = prompt
        self._stop_step_pattern = self._STEP_PATTERN
        return prompt

    def load_prompt_from_file(self, filepath: str) -> None:
        """Use the contents of a UTF-8 text file as the system prompt."""
        # Explicit encoding: the prompt contains non-ASCII apostrophes.
        with open(filepath, 'r', encoding='utf-8') as file:
            self._system_prompt = file.read()
        self._stop_step_pattern = self._STEP_PATTERN

    def _goal_to_query(self, goal: str) -> str:
        return f"Human: How would you {goal.lower()}?\n" + 'Robot: '

    def _step_to_text(self, step: Step) -> str:
        arguments = [f'"{argument}"' for argument in step.arguments]
        return f'{step.action}({", ".join(arguments)})'

    def _steps_to_text(self,
                       steps: List[Step],
                       add_terminating_string: bool = True) -> str:
        """Render steps as '1. a(...), 2. b(...)' plus ', <n>. done().' if requested."""
        text = ", ".join([f'{step_idx}. {self._step_to_text(step)}'
                          for step_idx, step in enumerate(steps, start=1)])
        if add_terminating_string:
            text += f", {len(steps) + 1}. {self.TERMINATING_STRING}."
        return text

    def _task_to_prompt(self, task: SorterTask) -> str:
        """Query + full plan for one example task; also stores the plan in `task.text`."""
        prompt = self._goal_to_query(task.goal)
        text = self._steps_to_text(task.steps)
        task.text = text
        return prompt + text

    def to_inputs(self,
                  task: SorterTask,
                  steps: Optional[List[Step]] = None,
                  options: Optional[List[Step]] = None) -> BaseInput:
        """Full prompt for `task`, optionally continuing a partial plan `steps`.

        With `options`, returns a ScoringInput whose options are numbered as the
        step that follows `steps`.

        Raises:
            ValueError: if no system prompt has been built or loaded.
        """
        if not self.system_prompt_is_set:
            raise ValueError(
                "System prompt is not set. You need to set system prompt.")
        text = self._system_prompt + self._goal_to_query(task.goal)
        if steps is not None:
            text += self._steps_to_text(steps, add_terminating_string=False)
        if options is not None:
            next_idx = len(steps or []) + 1
            return ScoringInput(text=text, options=[f'{next_idx}. {option.text}' for option in options])
        return BaseInput(text=text)

    def _text_to_steps(self, task_text: str, cut_one_step: bool = False) -> Union[List[Step], Step, None]:
        """Parse plan text into Steps.

        With cut_one_step=True, parse only the step at the very start of
        `task_text` (None if there is none). Otherwise return every parsed step
        except the terminating done() step.
        """
        pattern = self._STEP_PATTERN
        if cut_one_step:
            stop_match = pattern.match(task_text)
            if stop_match is None:
                return None
            return self._parse_action(stop_match.group(2))

        steps = []
        for match_groups in pattern.findall(task_text):
            step = self._parse_action(match_groups[1])
            if step is not None and not self.is_terminating(step):
                steps.append(step)
        return steps

    def _parse_action(self, step_text: Optional[str]) -> Optional[Step]:
        """Parse one call string into a Step.

        'put("orange box", "cucumber")' -> Step(action='put',
            arguments=['orange box', 'cucumber'], text=step_text)
        'done()' -> Step(text='done()')  (no action/arguments)
        Returns None for empty input or when no identifier is found.
        """
        if not step_text:
            return None
        words = self._ACTION_PATTERN.findall(step_text)
        if not words:
            return None
        if len(words) == 1:
            return Step(text=step_text)
        return Step(action=words[0], arguments=words[1:], text=step_text)

    def to_task(self, task: BaseOutput) -> SorterTask:
        """Cut generated text after the first '<n>. done().' and parse it into a SorterTask.

        Side effect: trims `task.text` in place.
        """
        stop_match = self._stop_pattern.search(task.text)
        if stop_match is not None:
            task.text = task.text[:stop_match.end()]
        task.text = task.text.strip(' \n\t')

        steps = self._text_to_steps(task_text=task.text)
        return SorterTask(text=task.text, steps=steps)

In [None]:
from typing import Callable, List, Optional


class FullPlanGeneration:
    """Generate the model's answer for every task in a batch with one call.

    Args:
        model: the vision-language model (kept for reference; generation is
            delegated to `infer_fn`).
        processor: a PromptProcessor (kept for optional post-processing).
        infer_fn: callable mapping a list of tasks to a list of generated
            strings. Defaults to the notebook-level `infer_model`, looked up at
            call time.
    """

    def __init__(self,
                 model,
                 processor,
                 infer_fn: Optional[Callable[[List["SorterTask"]], List[str]]] = None,
                 **kwargs):
        self._processor = processor
        self._model = model
        self._infer_fn = infer_fn

    def predict(self, gt_task_batch: List["SorterTask"]) -> List[str]:
        """Return the raw generated text for each task, in batch order (not parsed SorterTasks)."""
        infer = self._infer_fn if self._infer_fn is not None else infer_model
        return infer(gt_task_batch)

In [None]:
# Stage-1 results (goals + generated questions) are exchanged through this TSV,
# so the LLaVA stage can also be resumed from disk in a fresh kernel. Nothing
# above writes it (the save after stage 1 is commented out), so create it from
# the in-memory frame when missing instead of failing on read.
# NOTE: an existing file wins over the in-memory frame; delete it to pick up a
# fresh stage-1 run.
STAGE1_TSV = 'df_mistral_questions_TEST.tsv'
if not os.path.exists(STAGE1_TSV):
    assert 'mistral_questions' in df.columns, f'run stage 1 first or provide {STAGE1_TSV}'
    df.to_csv(STAGE1_TSV, sep='\t', index=False)
df = pd.read_csv(STAGE1_TSV, sep='\t')

In [None]:
from pathlib import Path

# Re-root every image onto the local test_images/ folder, keeping only the file name.
TEST_IMAGES_DIR = '/home/edamirov/notebooks/ml_hack/playground_pipeline/test_dataset/test_images'
df['image'] = df['image'].map(lambda img_path: f'{TEST_IMAGES_DIR}/{Path(img_path).name}')

In [None]:
from torch.utils.data import DataLoader, Dataset

BATCH_SIZE = 4

dataset = SorterDataset(df=df.to_dict('records'))
assert len(dataset) == len(df)
print(dataset[0])

# Items are small SorterTask records built from in-memory dicts, so loading in
# the main process is enough; worker processes would only add start-up and IPC
# overhead. `collate_fn=list` keeps each batch a plain list of SorterTasks (a
# picklable equivalent of `lambda x: x`).
dataloader = DataLoader(
    dataset,
    shuffle=False,  # keep row order: predictions are assigned back to df positionally
    batch_size=BATCH_SIZE,
    num_workers=0,
    collate_fn=list,
)

In [None]:
# No few-shot system prompt is built: the LLaVA prompt comes from
# preprocess_text_to_inp_ids.
processor = PromptProcessor()
gen_method = FullPlanGeneration(model=model, processor=processor)

In [None]:
results = []

for batch in tqdm(dataloader, total=len(dataloader)):
    batch_predict = gen_method.predict(batch)
    # One record per task; 'plan' and 'predicted_text' both hold the raw LLaVA answer.
    for task, answer in zip(batch, batch_predict):
        results.append({
            'plan_id': task.plan_id,
            'plan': answer,
            'predicted_text': answer,
            'goal': task.goal,
            'image': task.image,
        })

In [None]:
len(results)  # expected to equal len(df): one LLaVA answer per row

In [None]:
# import pickle 
# with open('llava_results_copy_1_df_dump.pkl', 'wb') as f:
#     pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
results[0]  # inspect one LLaVA result record

In [None]:
df.head(n=1)  # peek at one row (text cells shown up to display.max_colwidth chars)

In [None]:
# One result per df row, in row order (the DataLoader does not shuffle).
df['llava_response'] = [record['predicted_text'] for record in results]

## Final text-LLM stage: generate action sequences from the LLaVA scene descriptions

In [None]:
# Stage-2 system prompt; `{llava_response}` is filled per row with LLaVA's scene
# description. NOTE: the text (including its 4-space line indentation and the
# repeated action list) reaches the model verbatim.
ROBOT_SYSTEM_2_TEMPLATE = """
    You are writing instructions for a helpful robot operating in a house. You will be asked to do various tasks and you should tell me the sequence of actions you would do to accomplish my task. You have 3 possible actions: "pick_up(from, what)", "move_to(where, what)" and "put(where, what)". At the end of sequence you should write "done()".
    Short surroundings description:
    {llava_response}
    
    You have 3 possible actions: "pick_up(from, what)", "move_to(where, what)" and "put(where, what)". At the end of sequence you should write "done()".
    If you need just nearest object, you should use "pick_up(unspecified, object)", if you need to move yourself - use "move_to(location, unspecified)"
    so if the location (where/from) or object(what) are not specified, you should write "unspecified" as action parameter.
    IMPORTANT: both arguments must not be unspecified at the same time!
    IMPORTANT: unspecified may be used only as "what" argument for move_to and "from" argument for pick_up
    
    YOU SHOULD ONLY OUTPUT SEQUENCE OF ACTIONS, AVOID ANY COMMENTS.
        
    Examples:
    Short surroundings description: Table with vegetables and a green container and a bedside table.
    user: How would you drive up to the bedside table?
    assistant: 1. move_to("bedside table", "unspecified"), 2. done().
    
    Short surroundings description: Robot with cat in the manipulator arm.
    user: How would you put the gray cat in the orange box?
    assistant: 1. move_to("orange box", "gray cat"), 2. put("orange box", "gray cat"), 3. done().
    
    Short surroundings description: Cucumber on the table.
    user: How would you take a cucumber from the table and put it in an orange box?
    assistant: 1. move_to("table", "cucumber"), 2. pick_up("table", "cucumber"), 3. move_to("orange box", "cucumber"), 4. put("orange box", "cucumber"), 5. done().
    
    Short surroundings description: Orange kitten on the floor and a green box.
    user: How would you put the orange kitten in the green box?
    assistant: 1. move_to("unspecified", "orange kitten"), 2. pick_up("unspecified", "orange kitten"), 3. move_to("green box", "orange kitten"), 4. put("green box", "orange kitten"), 5. done().
    
    Short surroundings description: Toy cat and green container.
    user: How would you put the toy cat in the green container?
    assistant: 1. move_to("green container", "toy cat"), 2. put("green container", "toy cat"), 3. done().
    """


def _stage2_prompt(row):
    """Chat-template prompt for one df row: scene-aware system prompt + the goal."""
    system_prompt = ROBOT_SYSTEM_2_TEMPLATE.format(llava_response=row['llava_response']).strip()
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": row['goal_eng']},
    ]
    return pipe.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)


query_prompts_2_stage = [_stage2_prompt(row) for row in tqdm(df.to_dict('records'))]

In [None]:
len(query_prompts_2_stage)  # one stage-2 prompt per df row

In [None]:
# query_prompts_2_stage[0]

In [None]:
# Stage 2: generate the final action sequence for every row.
mistral_final_stage_res = pipe(query_prompts_2_stage)

In [None]:
mistral_final_stage_res[0]  # raw pipeline output for the first prompt

In [None]:
from itertools import chain

# Each pipeline result is a list of {'generated_text': prompt + reply} dicts;
# flatten them and keep only the text between the first assistant marker and
# any following one.
_final_generations = [out['generated_text'] for out in chain.from_iterable(mistral_final_stage_res)]
df['mistral_final_res'] = pd.Series(
    [text.split('<|assistant|>\n')[1] for text in _final_generations]
)

In [None]:
# Persist the final generations (tab-separated, without the index column).
df.to_csv('df_mistral_final_res.tsv', sep='\t', index=False)

### Parse results

In [None]:
import ast

cmds_set = ['move_to', 'pick_up', 'put']


def parse_cmd_res(cmd_str):
    """Parse an LLM plan string into ``[[action, [arg, ...]], ...]``.

    Everything from the first ``done()`` on is ignored, and parsing stops at
    the first segment that names none of ``cmds_set``. Arguments are parsed
    with ``ast.literal_eval`` (literals only), so model output is never executed
    as code (the previous ``eval`` would run arbitrary expressions).

    Example:
        '1. move_to("table", "cucumber"), 2. pick_up("table", "cucumber"), 3. done().'
        -> [['move_to', ['table', 'cucumber']], ['pick_up', ['table', 'cucumber']]]

    Returns:
        The parsed list, or the string "ERROR" when the text is malformed
        (non-string input, a command without "(", non-literal arguments, ...).
    """
    try:
        plan = cmd_str.split('done()')[0]
        parsed = []
        for segment in plan.split(')'):
            parts = segment.split('(')
            actions = [cmd for cmd in cmds_set if cmd in parts[0]]
            if not actions:
                break
            # Naive split on "," — quoted arguments containing commas are not supported.
            args = [ast.literal_eval(arg.strip()) for arg in parts[1].split(',')]
            parsed.append(actions + [args])
    except (AttributeError, IndexError, TypeError, ValueError, SyntaxError,
            MemoryError, RecursionError):
        return "ERROR"
    return parsed

In [None]:
# Parsed plan per row: a list of [action, [args]], or "ERROR" if the text could not be parsed.
df['parsed_cmd'] = df['mistral_final_res'].apply(parse_cmd_res)

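## Итог (Summary)
В поле parsed_cmd - итоговый результат модели для прогона на датафрейме.

*(English: the `parsed_cmd` column holds the model's final parsed plan for each row of the dataframe — a list of `[action, [args]]`, or `"ERROR"` if the output could not be parsed.)*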