In [1]:
import json
import pathlib
from gpt3 import MewlCaptioner
from tqdm import tqdm

In [2]:
# Shared captioner instance. test_task looks up one method per task name on it
# and calls that method on each episode directory.
captioner = MewlCaptioner()

In [3]:
def test_task(captioner, dataset_path, task_name):
    dataset_path = pathlib.Path(dataset_path) / task_name
    all_episodes = dataset_path.glob('*/')
    all_episodes = sorted(all_episodes, key=lambda x: int(x.name))

    all_datasets = []

    caption_func = getattr(captioner, task_name)

    for i, epi in tqdm(enumerate(all_episodes), total=len(all_episodes)):
        prompt = ""
        contexts, query = caption_func(epi)
        for context in contexts:
            prompt += f"Context: {context[0]}\nName: {context[1]}\n\n"

        prompt += f"Context: {query[0]}\nName: "

        query_description, choices, answer = query

        label = choices.index(answer)

        all_datasets.append({
            "start_phrase": prompt,
            "choices": choices,
            "label": label,
            "task": task_name
        })

    return all_datasets

In [4]:
# Root of the rendered MEWL data. Later cells read it as
# <dataset_path>/<split>/<task>/<episode>/, where episode directories have
# integer names.
# NOTE(review): absolute, cluster-specific path; adjust when running elsewhere.
dataset_path = pathlib.Path("/gpfs/data/epavlick/tyun/llm_causal_reasoning/MEWL/rendered_data/mewl/")

In [5]:
from model.consts import task_names

# Write the per-task JSON files into the parent of the rendered-data root.
# Deriving the location from dataset_path (instead of repeating the absolute
# path as a second literal) keeps input and output in sync if the root moves.
output_dir = dataset_path.parent

for task in task_names:
    # One JSON document per task, keyed by split name.
    dataset = {}
    for split in ["train", "val", "test"]:
        dataset[split] = test_task(captioner, dataset_path / split, task)

    # Plain write mode: the file is only written, never read back here.
    with open(output_dir / f"{task}.json", "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)

100%|██████████| 3000/3000 [00:06<00:00, 494.30it/s]
100%|██████████| 600/600 [00:01<00:00, 444.97it/s]
100%|██████████| 600/600 [00:01<00:00, 484.45it/s]
100%|██████████| 3000/3000 [00:07<00:00, 421.56it/s]
100%|██████████| 600/600 [00:01<00:00, 327.99it/s]
100%|██████████| 600/600 [00:01<00:00, 433.18it/s]
100%|██████████| 3000/3000 [00:06<00:00, 481.51it/s]
100%|██████████| 600/600 [00:01<00:00, 449.50it/s]
100%|██████████| 600/600 [00:01<00:00, 469.95it/s]
100%|██████████| 3000/3000 [00:52<00:00, 57.57it/s]
100%|██████████| 600/600 [00:11<00:00, 50.89it/s]
100%|██████████| 600/600 [00:12<00:00, 48.49it/s]
100%|██████████| 3000/3000 [00:08<00:00, 361.78it/s]
100%|██████████| 600/600 [00:01<00:00, 386.51it/s]
100%|██████████| 600/600 [00:01<00:00, 303.20it/s]
100%|██████████| 3000/3000 [00:09<00:00, 319.87it/s]
100%|██████████| 600/600 [00:02<00:00, 295.08it/s]
100%|██████████| 600/600 [00:01<00:00, 446.40it/s]
100%|██████████| 3000/3000 [00:06<00:00, 429.61it/s]
100%|██████████| 600