In [2]:
import json
import random
import xml.etree.ElementTree as ET
from datetime import datetime
from glob import glob
from typing import Dict, List

import pandas as pd
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## MC_TACO

In [None]:
data = load_dataset("mc_taco")

In [None]:
data

In [None]:
all_instances = concatenate_datasets([data["validation"], data["test"]])
print(len(all_instances))

In [None]:
all_instances[0]

In [None]:
from enum import Enum

class Category(Enum):
    EVENT_DURATION = 0
    EVENT_ORDERING = 1
    FREQUENCY = 2
    TYPICAL_TIME = 3
    STATIONARY = 4

In [None]:
df = pd.DataFrame(all_instances)

In [None]:
df.head()

In [None]:
# Merge df by columns=["sentence", "question", "category"]
merged = df.groupby(["sentence", "question", "category"]).agg({"answer": "\n".join, "label": lambda x: "\n".join([str(i) for i in x])}).reset_index()

In [None]:
merged.head()

### Filter MC_TACO

In [None]:
df = pd.read_csv("../sampled_data/event_duration.csv")

In [None]:
df.head()

In [None]:
filtered = df[(df["remark"] == True) & (df["exists_ans"] == "Y")]

In [None]:
print(len(filtered))
filtered.head()

In [None]:
dics = []
for row in filtered.itertuples():
    answers = row.answer.split("\n")
    labels = row.label.split("\n")
    for ans, label in zip(answers, labels):
        if label == "0":
            continue
        dic = {
            "narrative": row.sentence.strip(),
            "question": row.question.strip(),
            "answer": ans.strip()
        }
        dics.append(dic)

In [None]:
print(len(dics))

In [None]:
with open("../resources/mc_taco_narratives.jsonl", "w") as f:
    for dic in dics:
        f.write(json.dumps(dic) + "\n")

### Postprocess GPT-3.5-generated Dialogues

In [3]:
with open("../results/gpt-4-turbo-preview_mc_taco_1shot.jsonl") as f:
    data = [json.loads(line) for line in f]

In [4]:
user_prompts = [d["user_prompt"] for d in data]
outputs = [d["generated"] for d in data]

df = pd.DataFrame({"user_prompt": user_prompts, "generated": outputs})

In [5]:
df.head()

Unnamed: 0,user_prompt,generated
0,"Narrative: , followed by other ports at Hyères...",A1: I was at the ports for four hours.\n\nA2:\...
1,"Narrative: After dinner, he went to look for M...",A1: He spent 20 minutes having dinner.\n\nA2:\...
2,"Narrative: After dinner, he went to look for M...",A1: He spent 10 minutes having dinner.\n\nA2:\...
3,"Narrative: After dinner, he went to look for M...",A1: He spent 15 minutes having dinner.\n\nA2:\...
4,"Narrative: After dinner, he went to look for M...",A1: He spent 30 minutes having dinner.\n\nA2:\...


In [6]:
df.to_csv("../results/gpt-4-turbo-preview_mc_taco_1shot.csv", index=False)

## Multi-Session Chat (MSC)

In [None]:
data = load_dataset("gonced8/multi-session_chat")

In [None]:
data

In [None]:
all_instances = concatenate_datasets([data["train"], data["validation"], data["test"]])
print(len(all_instances))

In [None]:
all_instances[0]

In [None]:
length = [len(i["sessions"]) for i in all_instances]
print(f"max length: {max(length)}")
print(f"min length: {min(length)}")
print(f"mean length: {sum(length) / len(length)}")

In [None]:
rows = []
for instance in all_instances:
    dic = defaultdict(str)
    for i, sess in enumerate(instance["sessions"]):
        time_elapsed = sess["time_elapsed"]
        # Filter multi-day sessions
        if "day" in time_elapsed:
            continue

        session = f"[ {time_elapsed} later ]\n"
        for turn in sess["dialogue"]:
            session += f"{turn['speaker']}: {turn['text'].strip()}\n"

        dic[f"session_{i}"] = session
    rows.append(dic)

In [None]:
len(rows)

In [None]:
df = pd.DataFrame(rows)

In [None]:
df.head()

In [None]:
df.to_csv("../sampled_data/multi-session-chat.csv", index=False)

## Korean Corpus

In [2]:
advanced_sns_files = glob("korean_corpus/advanced_social_network/*.jsonl")
messenger_files = glob("korean_corpus/messenger/*.jsonl")
online_conv_files = glob("korean_corpus/onlineConversation/*.jsonl")
print(f"advanced_sns_files: {len(advanced_sns_files)}")
print(f"messenger_files: {len(messenger_files)}")
print(f"online_conv_files: {len(online_conv_files)}")

advanced_sns_files: 1554
messenger_files: 4
online_conv_files: 72


In [3]:
advanced_sns_data, messenger_data, online_conv_data = [], [], []
for f in tqdm(advanced_sns_files):
    with open(f) as f:
        advanced_sns_data.extend([json.loads(line) for line in f])

for f in tqdm(messenger_files):
    with open(f) as f:
        messenger_data.extend([json.loads(line) for line in f])

for f in tqdm(online_conv_files):
    with open(f) as f:
        online_conv_data.extend([json.loads(line) for line in f])

print(f"advanced_sns_data: {len(advanced_sns_data)}")
print(f"messenger_data: {len(messenger_data)}")
print(f"online_conv_data: {len(online_conv_data)}")

advanced_sns_data: 1553577
messenger_data: 3819
online_conv_data: 71071


In [18]:
def calculate_time_diff(sessions: List[Dict[str, str]]) -> List[int]:
    time_diff_list = []

    for i in range(1, len(sessions)):
        cur_datetime = datetime.strptime(sessions[i]["datetime"], "%Y-%m-%d %H:%M")
        prev_datetime = datetime.strptime(sessions[i-1]["datetime"], "%Y-%m-%d %H:%M")

        time_diff = cur_datetime - prev_datetime
        diff_minutes = time_diff.total_seconds() / 60
        time_diff_list.append(diff_minutes)

    return time_diff_list

In [40]:
def speaker_change_indices(sessions: List[Dict[str, str]]) -> List[int]:
    indices = []

    for i in range(1, len(sessions)):
        if sessions[i]["speaker"] != sessions[i-1]["speaker"]:
            indices.append(i)

    return indices

In [26]:
advanced_sns_delta_sessions, messenger_delta_sessions, online_conv_delta_sessions = [], [], []
for d in tqdm(advanced_sns_data):
    time_diffs = calculate_time_diff(d["session"])
    if any(10 <= diff <= 1440 for diff in time_diffs):
        advanced_sns_delta_sessions.append(d)

for d in tqdm(messenger_data):
    time_diffs = calculate_time_diff(d["session"])
    if any(10 <= diff <= 1440 for diff in time_diffs):
        messenger_delta_sessions.append(d)

for d in tqdm(online_conv_data):
    time_diffs = calculate_time_diff(d["session"])
    if any(10 <= diff <= 1440 for diff in time_diffs):
        online_conv_delta_sessions.append(d)

print(f"advanced_sns_delta_sessions: {len(advanced_sns_delta_sessions)}")
print(f"messenger_delta_sessions: {len(messenger_delta_sessions)}")
print(f"online_conv_delta_sessions: {len(online_conv_delta_sessions)}")

100%|██████████| 1553577/1553577 [02:34<00:00, 10040.73it/s]
100%|██████████| 3819/3819 [00:04<00:00, 899.06it/s] 
100%|██████████| 71071/71071 [00:11<00:00, 6128.61it/s]

advanced_sns_delta_sessions: 592748
messenger_delta_sessions: 1741
online_conv_delta_sessions: 28832





In [43]:
processed_advanced_sns_data, processed_messenger_data, processed_online_conv_data = {"sessions": []}, {"sessions": []}, {"sessions": []}
for d in tqdm(advanced_sns_delta_sessions):
    dialog_str = ""
    indices = speaker_change_indices(d["session"])
    if len(indices) == 0:
        continue
    
    # First speaker
    hour_minute_str = d["session"][indices[0]-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[0])])
    speaker = d["session"][indices[0]-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Intermediate alternating speakers
    for i in range(len(indices)-1):
        hour_minute_str = d["session"][indices[i+1]-1]["datetime"].split(" ")[1]
        concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[i], indices[i+1])])
        speaker = d["session"][indices[i+1]-1]["speaker"]
        dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Last speaker
    hour_minute_str = d["session"][-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[-1], len(d["session"]))])
    speaker = d["session"][-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}"

    processed_advanced_sns_data["sessions"].append(dialog_str)

for d in tqdm(messenger_delta_sessions):
    dialog_str = ""
    indices = speaker_change_indices(d["session"])
    if len(indices) == 0:
        continue
    
    # First speaker
    hour_minute_str = d["session"][indices[0]-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[0])])
    speaker = d["session"][indices[0]-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Intermediate alternating speakers
    for i in range(len(indices)-1):
        hour_minute_str = d["session"][indices[i+1]-1]["datetime"].split(" ")[1]
        concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[i], indices[i+1])])
        speaker = d["session"][indices[i+1]-1]["speaker"]
        dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Last speaker
    hour_minute_str = d["session"][-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[-1], len(d["session"]))])
    speaker = d["session"][-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}"

    processed_messenger_data["sessions"].append(dialog_str)

for d in tqdm(online_conv_delta_sessions):
    dialog_str = ""
    indices = speaker_change_indices(d["session"])
    if len(indices) == 0:
        continue
    
    # First speaker
    hour_minute_str = d["session"][indices[0]-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[0])])
    speaker = d["session"][indices[0]-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Intermediate alternating speakers
    for i in range(len(indices)-1):
        hour_minute_str = d["session"][indices[i+1]-1]["datetime"].split(" ")[1]
        concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[i], indices[i+1])])
        speaker = d["session"][indices[i+1]-1]["speaker"]
        dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}\n"
    # Last speaker
    hour_minute_str = d["session"][-1]["datetime"].split(" ")[1]
    concatenated_utterance = " ".join([d["session"][i]["utterance"] for i in range(indices[-1], len(d["session"]))])
    speaker = d["session"][-1]["speaker"]
    dialog_str += f"[{hour_minute_str}] {speaker}: {concatenated_utterance}"

    processed_online_conv_data["sessions"].append(dialog_str)

sampled_advanced_sns = pd.DataFrame(processed_advanced_sns_data).sample(300)
sampled_messenger = pd.DataFrame(processed_messenger_data).sample(300)
sampled_online_conv = pd.DataFrame(processed_online_conv_data).sample(300)

  0%|          | 0/592748 [00:00<?, ?it/s]

100%|██████████| 592748/592748 [00:08<00:00, 69734.62it/s]
100%|██████████| 1741/1741 [00:00<00:00, 2280.33it/s]
100%|██████████| 28832/28832 [00:00<00:00, 34550.65it/s]


In [44]:
sampled_advanced_sns.to_csv("../sampled_data/advanced_sns.csv", index=False)
sampled_messenger.to_csv("../sampled_data/messenger.csv", index=False)
sampled_online_conv.to_csv("../sampled_data/online_conv.csv", index=False)

## CoTAK

In [21]:
cotak_finegrained = pd.read_csv("../raw_data/cotak/data/finegrained/csv/minutes_clustered.csv")
cotak_coarsegrained = pd.read_csv("../raw_data/cotak/data/coarsegrained/csv/perform_high_confidence.csv")

In [22]:
cotak_finegrained.head()

Unnamed: 0,title,action,minutes
0,How to Clean Windows,Remove stickers and decals,21-30mins
1,How to Clean Windows,Remove and clean the screens,1-10mins
2,How to Develop a Strong High Singing Voice,Breathe from your diaphragm,11-20mins
3,How to Develop a Strong High Singing Voice,Train your body to hit the high notes,11-20mins
4,How to Develop a Strong High Singing Voice,Always remember to not strain your voice,1-10mins


In [30]:
cotak_coarsegrained.head()

Unnamed: 0,title,action,label
0,How to Plan for Common Labor Complications,Choose your support person.,minutes
1,How to Plan for Common Labor Complications,Create a strategy for pain relief.,minutes
2,How to Plan for Common Labor Complications,Be realistic about your birth plan.,hours
3,How to Drive Someone in Labor to the Hospital,Make the woman comfortable.,minutes
4,How to Drive Someone in Labor to the Hospital,Protect the car from messes.,minutes


In [25]:
grouped_finegrained = cotak_finegrained.groupby("title").agg({"action": list, "minutes": list}).reset_index()
print(len(grouped_finegrained))

2528


In [32]:
grouped_coarsegrained = cotak_coarsegrained.groupby("title").agg({"action": list, "label": list}).reset_index()
print(len(grouped_coarsegrained))

10593


In [26]:
grouped_finegrained.head()

Unnamed: 0,title,action,minutes
0,How to Accentuate Wavy Hair,[Condition your hair],[11-20mins]
1,How to Accept Having a Large Bust,"[Wear clothes that fit, Minimize your chest by...","[1-10mins, 1-10mins, 21-30mins, 1-10mins, 1-10..."
2,How to Accept Visa Payments,"[Find an acquirer, Compare rates.Research auth...","[21-30mins, 11-20mins, 11-20mins, 1-10mins, 1-..."
3,How to Accept Your Boyfriend's Interest in Por...,"[Talk to your boyfriend about it, Define what ...","[31-40mins, 11-20mins, 11-20mins, 11-20mins, 1..."
4,How to Accompany a Learner Driver (UK),[Make sure all instructions are clear and prec...,[11-20mins]


In [33]:
grouped_coarsegrained.head()

Unnamed: 0,title,action,label
0,How to 180 on a Scooter,[Find a barrier that's about mid shin level.],[days]
1,How to 360 in Modern Warfare,"[Set your sensitivity to 6 or above., Catch th...","[hours, hours]"
2,How to 90 Degree Park Large SUVs,[Put your car in reverse and slowly move backw...,[days]
3,How to Absorb What You Read,[Make notes in a notebook if you can't mark up...,[minutes]
4,How to Accent Home Decor with Terra Cotta,[Try a Mediterranean touch.],[hours]


In [28]:
dic = {"sessions": []}
for row in grouped_finegrained.itertuples():
    text = row.title + "\n"
    for i, (action, minute) in enumerate(zip(row.action, row.minutes)):
        text += f"{i+1}. {action} ({minute})\n"
    dic["sessions"].append(text.strip())
cotak_finegrained_df = pd.DataFrame(dic)

In [38]:
dic = {"sessions": []}
for row in grouped_coarsegrained.itertuples():
    if "minutes" not in row.label or "hours" not in row.label:
        continue
    text = row.title + "\n"
    for i, (action, label) in enumerate(zip(row.action, row.label)):
        text += f"{i+1}. {action} ({label})\n"
    dic["sessions"].append(text.strip())
cotak_coarsegrained_df = pd.DataFrame(dic)

In [29]:
cotak_finegrained_df.head()

Unnamed: 0,sessions
0,How to Accentuate Wavy Hair\n1. Condition your...
1,How to Accept Having a Large Bust\n1. Wear clo...
2,How to Accept Visa Payments\n1. Find an acquir...
3,How to Accept Your Boyfriend's Interest in Por...
4,How to Accompany a Learner Driver (UK)\n1. Mak...


In [39]:
cotak_coarsegrained_df.head()

Unnamed: 0,sessions
0,How to Accommodate a New Baby in a Shift Work ...
1,How to Act Like Dr. Gregory House\n1. Act conf...
2,How to Act Like a Hipster\n1. Smoke luxury cig...
3,How to Act Sexy\n1. Have a direction with your...
4,How to Activate Internet Tethering on the iPho...


In [36]:
cotak_finegrained_df.to_csv("../sampled_data/cotak_finegrained.csv", index=False)

In [40]:
cotak_coarsegrained_df.to_csv("../sampled_data/cotak_coarsegrained.csv", index=False)

## Typical Durations

In [2]:
typical_durations_files = glob("../raw_data/typical_duration/*.xml")
print(f"typical_durations_files: {len(typical_durations_files)}")

typical_durations_files: 58


In [3]:
def extract_strings(xml_string):
    root = ET.fromstring(xml_string)
    all_text = ""
    for sentences in root.iter("s"):
        all_text += "".join(sentences.itertext()).strip()

    events = []
    cnt = 0
    for event in root.iter("EVENT"):
        if event.text is None:
            continue
        lower_bound = event.get("lowerBoundDuration", "")[1:]
        upper_bound = event.get("upperBoundDuration", "")[1:]
        if not lower_bound or not upper_bound or lower_bound == "ULL" or upper_bound == "ULL":
            continue
        if lower_bound.startswith("T"):
            lower_bound = lower_bound[1:].lower()
        if upper_bound.startswith("T"):
            upper_bound = upper_bound[1:].lower()
        event_info = f" ({lower_bound}~{upper_bound})"
        events.append((event.text, event_info))
        cnt += 1
    
    modified_text = ""
    for event, info in events:
        index = all_text.find(event)
        modified_text += all_text[:index] + event + info
        all_text = all_text[index+len(event):]
    modified_text += all_text

    return modified_text, cnt

In [4]:
data = []
all_cnt = 0
for file in typical_durations_files:
    with open(file) as f:
        xml_string = f.read()
    text, cnt = extract_strings(xml_string)
    all_cnt += cnt
    data.append(text)

In [5]:
print(all_cnt)

2258


In [90]:
typical_durations_df = pd.DataFrame({"article": data})

In [91]:
typical_durations_df.head()

Unnamed: 0,article
0,The New York Times said (1m~10m) in an editori...
1,"For his part, Fidel Castro is the ultimate pol..."
2,The Pentagon said (10s~5m) today it will re-e...
3,The White House said (20s~30m) President Bush ...
4,"Philip Morris Cos., New York, adopted (5m~3h) ..."


In [92]:
typical_durations_df.to_csv("../sampled_data/typical_durations.csv", index=False)

## TimeDial

In [3]:
time_dial = load_dataset("time_dial")

Found cached dataset time_dial (/home/seongbo/.cache/huggingface/datasets/time_dial/default/1.1.0/dda11f5b5bec51caf171fd3224f7d7f79cf4900999efc098cea33d9fa9a4172d)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
time_dial_df = pd.DataFrame(time_dial["test"])
time_dial_df["conversation"] = time_dial_df["conversation"].apply(lambda x: "\n".join(x))

In [5]:
time_dial_df.head()

Unnamed: 0,id,conversation,correct1,correct2,incorrect1,incorrect1_rule,incorrect2,incorrect2_rule
0,1,A:We need to take the accounts system offline ...,forty-eight hours,50 hours,two hours,Rule 1,12 days,Rule 2
1,2,A:your mp3 looks so cool . Where did you get i...,24 hours,none,22 hours,Rule 2,seven seconds,Rule 2
2,3,"A:Anne , would you please come in for a while ...",thirty minutes,forty five minutes,two and half hours,Rule 1,2 seconds,Rule 3
3,4,A:Our factory locates at a village in the east...,about five years,more than one and half year,about 36 days,Rule 2,about 14 hours,Rule 3
4,5,"A:Hey , Ann . Wake up . It's time to get out o...",30 minutes ago,12 minutes ago,14 hours ago,Rule 3,two seconds ago,Rule 3


In [6]:
time_dial_df.to_csv("../sampled_data/time_dial.csv", index=False)

In [16]:
with open("../results/gpt-4_mc_taco_1shot.jsonl") as f:
    data = [json.loads(line) for line in f]

In [17]:
num_turns = []
for d in data:
    generated_text = d["generated"].lower()
    num_turns.append(max(0, generated_text.count("speaker a") + generated_text.count("speaker b") - 1))

In [18]:
print(f"max turns: {max(num_turns)}")
print(f"min turns: {min(num_turns)}")
print(f"mean turns: {sum(num_turns) / len(num_turns)}")

max turns: 5
min turns: 3
mean turns: 4.515432098765432


In [14]:
df = pd.DataFrame(data)

In [15]:
df[["user_prompt", "generated"]].to_csv("../results/gpt-4_mc_taco_0shot.csv", index=False)

## ATOMIC 2020

In [3]:
atomic_2020 = load_dataset("Estwld/atomic2020-origin")

In [4]:
df_train = pd.DataFrame(atomic_2020["train"])
df_val = pd.DataFrame(atomic_2020["validation"])
df_test = pd.DataFrame(atomic_2020["test"])

In [5]:
df_test.head()

Unnamed: 0,knowledge_type,event,relation,relation_description,tail
0,social_intercation,PersonX abuses PersonX's power,oEffect,What effects does the event have on others?,"[are told what to do, given unfair consequence..."
1,social_intercation,PersonX abuses PersonX's power,oReact,How do others' feel after the event?,"[humiliated, sad, angry, cheated]"
2,social_intercation,PersonX abuses PersonX's power,oWant,What would others likely want to do after the ...,"[report PersonX to HR, get PersonX fired, to e..."
3,social_intercation,PersonX abuses PersonX's power,xAttr,How would X be described?,"[out of line, irresponsible, mean, confident, ..."
4,social_intercation,PersonX abuses PersonX's power,xEffect,What effects does the event have on X?,"[becomes authoratarian, is ostracized, is reli..."


In [6]:
desc = df_test[df_test["knowledge_type"] == "event_centered"].groupby("relation")["relation_description"].first().reset_index()

In [7]:
for row in desc.itertuples():
    print(f"{row.relation}: {row.relation_description}")

Causes: Causes specifically captures the causal relation between two events or entities.
HasSubEvent: HasSubEvent provides the internal structure of an event, each tail denoting a step within the larger head event.
HinderedBy: HinderedyBy introduces hindrances that obstruct the natural path to the achievement of a goal.
isAfter: isAfter introduces events that can follow an event.
isBefore: isBefore introduces events that can precede an event.
isFilledBy: isFilledBy provides a filler phrase for an event with a blank that is sensical and commonly acceptable for the event.
xReason: xReason provides a post-fact explanation of the cause of an event.


In [8]:
event_centered_train = df_train[df_train["knowledge_type"] == "event_centered"]
event_centered_val = df_val[df_val["knowledge_type"] == "event_centered"]
event_centered_test = df_test[df_test["knowledge_type"] == "event_centered"]

In [9]:
event_centered_test.head()

Unnamed: 0,knowledge_type,event,relation,relation_description,tail
23206,event_centered,airplane engine,Causes,Causes specifically captures the causal relati...,[very loud noise]
23207,event_centered,alcoholism,Causes,Causes specifically captures the causal relati...,[stigma]
23208,event_centered,archaeology,Causes,Causes specifically captures the causal relati...,[destroy archaeological record]
23209,event_centered,argument,Causes,Causes specifically captures the causal relati...,[violence]
23210,event_centered,avalanche,Causes,Causes specifically captures the causal relati...,[crushing village]


In [10]:
processed_train = event_centered_train[["event", "relation", "tail"]]
processed_val = event_centered_val[["event", "relation", "tail"]]
processed_test = event_centered_test[["event", "relation", "tail"]]
processed_train["split"] = "train"
processed_val["split"] = "val"
processed_test["split"] = "test"

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_train["split"] = "train"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_val["split"] = "val"
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_test["split"] = "test"


In [11]:
processed_train["tail"] = processed_train["tail"].apply(lambda x: "\n".join([e for e in x if e]))
processed_val["tail"] = processed_val["tail"].apply(lambda x: "\n".join([e for e in x if e]))
processed_test["tail"] = processed_test["tail"].apply(lambda x: "\n".join([e for e in x if e]))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_train["tail"] = processed_train["tail"].apply(lambda x: "\n".join([e for e in x if e]))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_val["tail"] = processed_val["tail"].apply(lambda x: "\n".join([e for e in x if e]))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  processed_tes

In [12]:
processed_test.head()

Unnamed: 0,event,relation,tail,split
23206,airplane engine,Causes,very loud noise,test
23207,alcoholism,Causes,stigma,test
23208,archaeology,Causes,destroy archaeological record,test
23209,argument,Causes,violence,test
23210,avalanche,Causes,crushing village,test


In [13]:
print(f"Train:\n{processed_train['relation'].value_counts()}")
print(f"Val:\n{processed_val['relation'].value_counts()}")
print(f"Test:\n{processed_test['relation'].value_counts()}")

Train:
relation
HinderedBy     31167
isBefore       17094
isAfter        16485
isFilledBy      3197
HasSubEvent     1505
Causes           230
xReason          119
Name: count, dtype: int64
Val:
relation
HinderedBy     4031
isBefore       2041
isAfter        1959
isFilledBy      418
HasSubEvent      32
Causes            3
xReason           1
Name: count, dtype: int64
Test:
relation
HinderedBy     7669
isBefore       4071
isAfter        4008
isFilledBy      823
HasSubEvent     184
Causes           35
xReason          17
Name: count, dtype: int64


In [14]:
event_centered_all = pd.concat([processed_train, processed_val, processed_test])

In [17]:
template = {
    "Causes": "causes",
    "HasSubEvent": "includes",
    "HinderedBy": "can be hindered by",
    "isAfter": "happens after",
    "isBefore": "happens before",
    "isFilledBy": "",
    "xReason": "because"
}

In [18]:
with open("../resources/atomic_2020_event_centered.jsonl", "w") as f:
    for row in event_centered_all.itertuples():
        tail = random.choice(row.tail.split("\n"))
        connector = template[row.relation]
        if not connector:
            assert "___" in row.event
            narrative = row.event.replace("___", tail)
        else:
            narrative = f"{row.event.strip()} {connector} {tail.strip()}"
        f.write(
            json.dumps(
                {
                    "split": row.split,
                    "event": row.event,
                    "relation": row.relation,
                    "tail": tail,
                    "narrative": narrative
                }
            ) + "\n"
        )