In [1]:
import csv
from collections import defaultdict
import re
import datasets
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
data_path = "../assets/data/scone/rlong"
splits = ["train", "dev", "test"]
tasks = ["alchemy", "scene", "tangrams"]


def tsv_to_dict_of_lists(file_path):
    with open(file_path, "r", newline="") as tsv_file:
        reader = csv.reader(tsv_file, delimiter="\t")

        # Read the first row to determine the number of columns
        first_row = next(reader)
        num_columns = len(first_row)

        # Generate headers
        headers = ["ID", "WORLD_0"]
        for i in range(1, (num_columns - 2) // 2 + 1):
            headers.extend([f"UTTERANCE_{i}", f"WORLD_{i}"])

        # Create a dictionary to store the lists
        result_dict = {header: [] for header in headers}

        # Reset the file pointer to the beginning
        tsv_file.seek(0)

        # Process each row
        for row in reader:
            for i, value in enumerate(row):
                if i < len(headers):
                    result_dict[headers[i]].append(value)

    return result_dict


task_datasets = defaultdict(list)

for split in splits:
    for task in tasks:
        ds = datasets.Dataset.from_dict(
            tsv_to_dict_of_lists(f"{data_path}/{task}-{split}.tsv")
        )
        ds = ds.add_column("task", [task] * len(ds))
        task_datasets[split].append(ds)

for split, ds_list in task_datasets.items():
    task_datasets[split] = datasets.concatenate_datasets(ds_list)

scone_dataset = datasets.DatasetDict(task_datasets)

In [3]:
tangrams = scone_dataset.filter(lambda x: x["task"] == "tangrams")
scenes = scone_dataset.filter(lambda x: x["task"] == "scene")
alchemy = scone_dataset.filter(lambda x: x["task"] == "alchemy")

Filter:   0%|          | 0/11198 [00:00<?, ? examples/s]

Filter:   0%|          | 0/642 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2734 [00:00<?, ? examples/s]

Filter:   0%|          | 0/11198 [00:00<?, ? examples/s]

Filter:   0%|          | 0/642 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2734 [00:00<?, ? examples/s]

Filter:   0%|          | 0/11198 [00:00<?, ? examples/s]

Filter:   0%|          | 0/642 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2734 [00:00<?, ? examples/s]

In [61]:
# mapping from number to word
from collections import Counter

num2word = {
    1: "first",
    2: "second",
    3: "third",
    4: "fourth",
    5: "fifth",
    6: "sixth",
    7: "seventh",
    8: "eighth",
    9: "ninth",
    10: "tenth",
}
color_map = {
    "g": "green",
    "b": "blue",
    "r": "red",
    "y": "yellow",
    "p": "purple",
    "o": "orange",
}


def extract_index_and_number(input_string):
    pattern = r"^(\d+):(.*)$"
    match = re.search(pattern, input_string)

    if match:
        index = match.group(1)
        number = match.group(2)
        return index, number
    else:
        return None


def alchemy_state_to_nl(state: str):
    beakers = list(map(lambda x: extract_index_and_number(x), state.split(" ")))

    def color_sequence_to_instruction(sequence):
        # Count the occurrences of each color
        color_counts = Counter(sequence.lower())
        # Create a list of color instructions
        instructions = []
        for color, count in color_counts.items():
            full_color_name = color_map[color]
            instructions.append(f"{count} {full_color_name}")

        # Join the instructions
        if len(instructions) == 1:
            return instructions[0]
        else:
            return "{" + ", ".join(instructions) + "}"

    def to_nl(x):
        i, s = x
        if s[1] == "_":
            return f"the {num2word[i + 1]} beaker is empty"
        return f"the {num2word[i + 1]} beaker has {color_sequence_to_instruction(s[1])}"

    return ", ".join(map(to_nl, enumerate(beakers)))


def scene_state_to_nl(state: str):
    positions = list(map(lambda x: extract_index_and_number(x), state.split(" ")))

    def to_nl(x):
        i, s = x
        if s[1][0] == "_":
            return f"the {num2word[i + 1]} position is empty"
        hat = color_map[s[1][1]] if s[1][1] != "_" else "no"
        return f"the {num2word[i + 1]} position is occupied by a person with a {color_map[s[1][0]]} shirt and {hat} hat"

    return ", ".join(map(to_nl, enumerate(positions)))


def tangram_state_to_nl(state: str):
    tangrams = list(map(lambda x: extract_index_and_number(x), state.split(" ")))

    def to_nl(x):
        i, s = x
        if s[1] == "_":
            return f"the {num2word[i + 1]} tangram is not placed"
        return f"{num2word[i + 1]} object id={s[1]}"

    return ", ".join(map(to_nl, enumerate(tangrams)))


def sequence_to_instruction(example: dict, turn_limit: int):
    if example["task"] == "alchemy":
        nl_fn = alchemy_state_to_nl
    elif example["task"] == "tangrams":
        nl_fn = tangram_state_to_nl
    elif example["task"] == "scene":
        nl_fn = scene_state_to_nl

    limit = len([k for k in example.keys() if k.startswith("WORLD_")])

    world_states = [nl_fn(example[f"WORLD_{i}"]) for i in range(0, limit)]
    utterances = [example[f"UTTERANCE_{i}"] for i in range(1, limit)]
    utterances.insert(0, "")
    utterances.append("")

    instructions = []
    output = []

    for i, state in enumerate(world_states):
        utterance = utterances[i + 1]
        if i + 1 <= turn_limit:
            instructions.append(f"{state}\n{utterance}".strip())
        else:
            output = state
            break

    return "\n".join(instructions), output

In [62]:
instr, output = sequence_to_instruction(scenes["train"][0], 3)

print(instr)
print("=" * 80)
print(output)

the first position is empty, the second position is empty, the third position is empty, the fourth position is empty, the fifth position is empty, the sixth position is empty, the seventh position is occupied by a person with a green shirt and orange hat, the eighth position is empty, the ninth position is empty, the tenth position is occupied by a person with a yellow shirt and orange hat
a man in a green shirt and an orange hat stands near the middle and a man in a yellow shirt and an orange hat stands on the far right
the first position is occupied by a person with a red shirt and no hat, the second position is empty, the third position is empty, the fourth position is empty, the fifth position is empty, the sixth position is empty, the seventh position is occupied by a person with a green shirt and orange hat, the eighth position is empty, the ninth position is empty, the tenth position is occupied by a person with a yellow shirt and orange hat
a man in a red shirt and no hat enter

In [63]:
instr, output = sequence_to_instruction(alchemy["train"][0], 3)

print(instr)
print("=" * 80)
print(output)

the first beaker has 3 green, the second beaker is empty, the third beaker is empty, the fourth beaker is empty, the fifth beaker has 1 orange, the sixth beaker has 3 orange, the seventh beaker has 4 green
throw out two units of first beaker
the first beaker has 1 green, the second beaker is empty, the third beaker is empty, the fourth beaker is empty, the fifth beaker has 1 orange, the sixth beaker has 3 orange, the seventh beaker has 4 green
throw out fifth beaker
the first beaker has 1 green, the second beaker is empty, the third beaker is empty, the fourth beaker is empty, the fifth beaker is empty, the sixth beaker has 3 orange, the seventh beaker has 4 green
throw out first one
the first beaker is empty, the second beaker is empty, the third beaker is empty, the fourth beaker is empty, the fifth beaker is empty, the sixth beaker has 3 orange, the seventh beaker has 4 green


In [64]:
instr, output = sequence_to_instruction(tangrams["train"][0], 3)

print(instr)
print("=" * 80)
print(output)

first object id=2, second object id=1, third object id=4, fourth object id=0, fifth object id=3
delete the second object from the left
first object id=2, second object id=4, third object id=0, fourth object id=3
delete the leftmost object
first object id=4, second object id=0, third object id=3
swap the leftmost and the rightmost objects
first object id=3, second object id=0, third object id=4


In [None]:
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name).to("cuda")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)


def generate_text(prompt, max_new_tokens=100):
    # Load pre-trained model and tokenizer
    # Encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")

    # Generate text
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )

    # Decode the generated text
    generated_text = tokenizer.decode(
        output[0][input_ids.shape[1] :], skip_special_tokens=True
    )

    return generated_text

In [None]:
instr, out = sequence_to_instruction(scone_dataset["train"][0], 3)
result = generate_text(instr, 100)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [None]:
print(result)

 unit of second unit
The first unit has 2 green and the other one has 3 orange.
Throw out one of the units with the first green. The second one with 2 orange and one orange has 4 green but the orange with 3 green is not green so the green with 4 orange is green instead of orange
If the unit with green has a green unit, throw out the one that has orange but not orange because the Orange with orange unit is orange instead. If the Green with Orange unit
