In [1]:
import os
import json
from tqdm import tqdm
from random import sample

from llm import LLM
from config.config import LLMConfig


# Cube colours; queries below draw random colours from this list.
colors = "red green blue orange".split()

# Two LLM clients for the "Cubes" scene, created in this order: TP (config
# "TP_OL") is queried once per query and its result's 'tasks' are the plan
# steps; OD is queried once per (non-gripper) plan step.
TP, OD = (LLM(LLMConfig(role, "Cubes")) for role in ("TP_OL", "OD"))

def get_instruction(
    query: str,
    objects=("blue_cube", "green_cube", "orange_cube", "red_cube"),
) -> str:
    """Build the two-line prompt text: the object list followed by the query.

    Args:
        query: Text appended after "# Query: " (a user query or a plan step).
        objects: Object names listed in the prompt. Defaults to the four
            cubes of this scene; any iterable of names is accepted.

    Returns:
        A string of the form
        "objects = ['blue_cube', 'green_cube', 'orange_cube', 'red_cube']\n# Query: <query>".
    """
    # list(...) keeps the rendered form "[...]" no matter which iterable is
    # passed (a tuple would otherwise render with parentheses).
    return f"objects = {list(objects)}\n# Query: {query}"

tasks = ["stack", "L", "pyramid"]

# Plan steps passed through without asking OD for an optimization.
_GRIPPER_STEPS = {"open_gripper()", "close_gripper()"}


def _optimize_steps(steps):
    """Query OD once per non-gripper step; return one entry per step.

    Each entry is OD's response dict, or None for gripper steps and for
    responses that contain an "instruction" key.
    """
    optimizations = []
    for step in tqdm(steps):
        if step in _GRIPPER_STEPS:
            optimizations.append(None)
            continue
        opt = OD.run(get_instruction(step), short_history=True)
        # NOTE(review): an "instruction" key presumably means OD did not return
        # an optimization for this step — confirm against the OD prompt.
        optimizations.append(None if "instruction" in opt else opt)
    return optimizations


def _next_file_index(folder):
    """Return 1 + the largest integer stem among `<int>.json` files in *folder*, or 0.

    Unlike len(os.listdir(folder)), this never reuses the name of an
    existing file when earlier files were deleted or unrelated entries
    (e.g. checkpoints) live in the folder.
    """
    indices = [
        int(stem)
        for stem, ext in map(os.path.splitext, os.listdir(folder))
        if ext == ".json" and stem.isdecimal()
    ]
    return max(indices, default=-1) + 1


# Five random samples of each task; one JSON file is written per (sample, task).
for i in range(5):
    # Draw colours in the original order (stack, L, pyramid) so the random
    # stream is consumed exactly as before.
    (stack_base,) = sample(colors, 1)
    (l_anchor,) = sample(colors, 1)
    left, right, top = sample(colors, 3)
    queries = [
        "make a stack of cubes on top of the {} cube".format(stack_base),
        "rearrange cubes to write the letter L flat on the table. keep {} at its location".format(l_anchor),
        # The cube kept in place is the first base cube.
        "build a pyramid with the {} and {} cubes at the base and {} cube at the top. keep {} cube at its original position.".format(left, right, top, left),
    ]

    for t, query in zip(tasks, queries):
        plan = TP.run(get_instruction(query))
        optimizations = _optimize_steps(plan['tasks'])

        data = {"query": query, "plan": plan, "optimizations": optimizations}
        data_folder = f"data/llm_responses/{t}"
        os.makedirs(data_folder, exist_ok=True)
        n_files = _next_file_index(data_folder)
        # Context manager guarantees the file is flushed and closed.
        with open(f"{data_folder}/{n_files}.json", "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=4)


100%|██████████| 12/12 [00:43<00:00,  3.65s/it]
100%|██████████| 12/12 [00:34<00:00,  2.85s/it]
100%|██████████| 8/8 [00:23<00:00,  2.89s/it]
100%|██████████| 12/12 [00:33<00:00,  2.80s/it]
100%|██████████| 12/12 [00:34<00:00,  2.88s/it]
100%|██████████| 8/8 [00:44<00:00,  5.54s/it]
100%|██████████| 12/12 [01:06<00:00,  5.53s/it]
100%|██████████| 12/12 [00:49<00:00,  4.14s/it]
100%|██████████| 8/8 [00:24<00:00,  3.04s/it]
100%|██████████| 12/12 [00:43<00:00,  3.61s/it]
100%|██████████| 12/12 [00:42<00:00,  3.54s/it]
100%|██████████| 8/8 [00:21<00:00,  2.75s/it]
100%|██████████| 12/12 [00:44<00:00,  3.67s/it]
100%|██████████| 12/12 [00:40<00:00,  3.35s/it]
100%|██████████| 8/8 [00:27<00:00,  3.39s/it]
