In [48]:
import os
import csv

def get_tiny_tom(config):
    tinytom = {}

    for cond in config["conditions"]:
        tinytom[cond] = []
        final_sentences = []
        with open(os.path.join(config["tom_data_dir"], cond, config["condition_file"]), 'r') as f:
            reader = csv.reader(f, delimiter=';')
            for row in reader:
                final_sentences.append(row[2])
        with open(os.path.join(config["tom_data_dir"], cond, config["story_file"]), 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                text = line.strip()
                last_period_index = text.strip().rfind(".")
                story = text[:last_period_index+1]
                final_sentence = final_sentences[i]
                if 'believe' in final_sentence:
                    final_sentence = final_sentence.replace('believe', 'think')
                story = story + " " + final_sentences[i]
                tinytom[cond].append(story)
    # TODO: ensure all conditions are shuffled in the same order
    # keeping unshuffled for now
    return tinytom

In [49]:
config = {"tom_data_dir": "../../data/conditions/tinytom",
    "condition_file": "stories.csv",
    "story_file": "converted.txt",
    "conditions": ["0_forward_belief_true_belief",
                "0_forward_belief_false_belief",
                "1_forward_belief_true_belief",
                "1_forward_belief_false_belief"],
}
tinytom = get_tiny_tom(config)
num_tiny_tom = sum([len(tinytom[cond]) for cond in config["conditions"]])
print("Number of tinytom stories: {}".format(num_tiny_tom))

Number of tinytom stories: 2600


In [50]:
raw_datasets = {'train': [], 'val_tom': []}
print("Splitting into train and val")
for cond in config["conditions"]:
    # num_train = int(len(tinytom[cond]) * config["train_ratio"])
    # num_val = len(tinytom[cond]) - num_train
    # raw_datasets['train'] += [{"content": s} for s in tinytom[cond][:num_train]]
    # raw_datasets['val_tom'] += [{"content": s} for s in tinytom[cond][num_train:]]
    offset = 0
    num_train = 100
    print("Offset:", offset, "Num Train:", num_train)
    num_val = offset
    raw_datasets['train'] += [{"content": s} for s in tinytom[cond][offset:offset+num_train]]
    raw_datasets['val_tom'] += [{"content": s} for s in tinytom[cond][:offset]]



Splitting into train and val
Offset: 0 Num Train: 100
Offset: 0 Num Train: 100
Offset: 0 Num Train: 100
Offset: 0 Num Train: 100


In [51]:
# pretty print the first 5 examples from the training set
with open(os.path.join(config["tom_data_dir"], cond, config["story_file"]), 'r') as f:
    lines = f.readlines()
for i in range(5):
    print(raw_datasets['train'][i]['content'])
    print(lines[i])

Once upon a time, in a cozy little room, lived a girl named Nala. She was all tucked in, ready for a peaceful night's sleep. She wished for sweet dreams, no scary nightmares tonight! Her room door was safely locked. With a giggle, her baby brother found the extra key. He used it to open the door to her room, just like in a game of hide and seek. Nala hears her brother unlocking the door. Nala believes the door to her room is unlocked.
Once upon a time, in a cozy little room, lived a girl named Nala. She was all tucked in, ready for a peaceful night's sleep. She wished for sweet dreams, no scary nightmares tonight! Her room door was safely locked.  Nala believes the door to her room is locked. With a giggle, her baby brother found the extra key. He used it to open the door to her room, just like in a game of hide and seek. Nala doesn't hear her brother unlocking the door as she is listening to music on her headphones. Nala thinks that the door to nala's room is

Once upon a time, in a s

In [38]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

In [39]:
def tokenize(element):
    outputs = tokenizer(
        element["content"],
        truncation=True,
        max_length=512,
        return_overflowing_tokens=True,
        return_length=True,
    )

    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length <= 512:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}

In [40]:
import pandas as pd
from datasets import Dataset, DatasetDict
hf_datasets = DatasetDict({split: Dataset.from_pandas(pd.DataFrame(data=data)) for split, data in raw_datasets.items()})

In [41]:
tokenized_datasets = hf_datasets.map(tokenize, batched=True, remove_columns=hf_datasets["train"].column_names)

100%|██████████| 1/1 [00:00<00:00, 11.81ba/s]
100%|██████████| 1/1 [00:00<00:00, 37.95ba/s]


In [42]:
print(tokenized_datasets["train"][0])

{'input_ids': [7454, 2402, 257, 640, 11, 319, 257, 1263, 5318, 11, 612, 373, 257, 2933, 3706, 40777, 13, 679, 6151, 284, 3745, 477, 262, 4695, 11, 2592, 465, 1263, 1545, 11, 15890, 13, 15890, 373, 257, 845, 1263, 26917, 290, 339, 373, 1464, 3492, 329, 465, 9799, 13, 15890, 262, 26917, 373, 18177, 285, 3316, 278, 319, 617, 8701, 618, 257, 5897, 17522, 11816, 6716, 2921, 683, 257, 1310, 299, 541, 13, 40777, 7224, 262, 17522, 1017, 1555, 1497, 13, 40777, 5804, 15890, 262, 26917, 468, 587, 38854, 416, 257, 17522, 13]}


In [44]:
tokenizer.eos_token

'<|endoftext|>'