In [16]:
from datasets import load_dataset


ace_gen = load_dataset("datht/ace-short-generated-dataset")

In [17]:
system_prompt = """You are an open-domain event extraction system. Your task is to identify events expressed or clearly implied in a given text.
IMPORTANT:Output ONLY valid JSON. No explanations, no markdown, no extra text.
Output Format (JSON only, no markdown):

{"events": [[<trigger span>, <event type>, [<argument span>, <semantic role>], [<argument span 2>, <semantic role 2>]], [<trigger span 2>, <event type 2>, []]]}

- If no events are detected, return: {"events": []}"""

user_prompt = """Given an input text: 
<input>
{input}
</input>

Your task is to extract all events present in the text. The text may contain zero, one, or multiple events.

For each event:
- Identify a trigger: the word or phrase that most clearly indicates the event.
- Identify event type
- Extract all relevant arguments participating in the event.
- Each argument must be an exact span from the text and assigned a semantic role.

Constraints and Guidelines
- Do not invent information not supported by the text.
- Do not paraphrase triggers or arguments.
- The "text" field must exactly match a span in the original input (from <input>...</input>).
"""

In [None]:
t_system_prompt = """You are an open-domain event extraction system. Your task is to identify events expressed or clearly implied in a given text.
IMPORTANT:Output ONLY valid JSON. Output Format (JSON only, no markdown):
{"events": [[<trigger span>, <event type>, [<argument span>, <semantic role>], [<argument span 2>, <semantic role 2>]]]}

- If no events are detected, return: {"events": []}"""

t_user_prompt = """Given an input text: 
<input>
{input}
</input>

For each event:
- Identify a trigger: the word or phrase that most clearly indicates the event.
- Identify event type
- Extract all relevant arguments participating in the event.
- Each argument must be an exact span from the text and assigned a semantic role.

Here is a reference response to the input sentence:
{result}
You may include additional valid triggers not in the reference. Now provide your own extraction:
"""

In [19]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B', padding_side="right")

In [20]:
import json
import numpy as np
import random


def process_data(raw_data, is_test=False):
    full_len = []
    sent_total = 0
    data = []
    prompt_len = []
    response_len = []
    none_data = []

    for sample in raw_data:
        sent_id_to_sentence = {i: content['sentence'] for i, content in enumerate(sample["content"])}
        sent_id_set = set(sent_id_to_sentence.keys())
        sent_total += len(sent_id_to_sentence)
        sent_to_existing_events = {}
        for event in sample.get("events", []):
            if event.get("type_id", -1) == -1: continue
            event_type = event.get("type")
            description = event.get("description")
            
            for mention in event.get("mention", []):
                sent_id = mention.get("sent_id")
                             
                if sent_id not in sent_to_existing_events:
                    sent_to_existing_events[sent_id] = []

                args = [[arg["text"], arg["role"]] for arg in mention.get("arguments", [])]
                # args = [{"text": arg["text"], "role": arg["role"]} for arg in mention.get("arguments", [])]
                # event_info = [mention.get("trigger_word"), event_type, args, description]
                event_info = [mention.get("trigger_word"), event_type, args]  
                # event_info = {                        
                #         "trigger_text": mention.get("trigger_word"),
                #         "type": event_type,
                #         "arguments": args,
                #         # "description": description, 
                #     }

                sent_to_existing_events[sent_id].append(event_info)

        for sent_id, events in sent_to_existing_events.items():
            sent_txt = sent_id_to_sentence[sent_id]
            sent_id_set.remove(sent_id)

            # trigger_texts = [event['trigger_text'] for event in events]
            # response = json.dumps({"triggers": trigger_texts, "events": events})

            response = json.dumps({"events": events})

            data.append({"system_prompt": system_prompt, 
                         "user_prompt": user_prompt.format(input=sent_txt), 
                         "response": response, 
                         "t_system_prompt": t_system_prompt,
                         "t_user_prompt": t_user_prompt.format(input=sent_txt, result=response), 
                         })

        for sent_id in sent_id_set:
            sent_txt = sent_id_to_sentence[sent_id]

            # response = json.dumps({"triggers": [], "events": []})
            response = json.dumps({"events": []})
            
            none_data.append({"system_prompt": system_prompt, 
                              "user_prompt": user_prompt.format(input=sent_txt), 
                              "response": response, 
                              "t_system_prompt": t_system_prompt,
                              "t_user_prompt": t_user_prompt.format(input=sent_txt, result=response), 
                              })


    if is_test:
        data.extend(none_data)
    else:
        data.extend(random.sample(list(none_data), min(len(none_data), len(data) // 100)))
    for sample in data:
        prompt = tokenizer.apply_chat_template(
                [{"role": "system", "content": sample['system_prompt']},
                {"role": "user", "content": sample['user_prompt']}],
                add_generation_prompt=True,
                tokenize=False 
            )
        full = prompt + sample['response'] + tokenizer.eos_token
        
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
        full_tokens = tokenizer.encode(full, add_special_tokens=False)
        response_tokens = full_tokens[len(prompt_tokens):]

        full_len.append(len(full_tokens))
        prompt_len.append(len(prompt_tokens))
        response_len.append(len(response_tokens))

    return data, full_len, prompt_len, response_len

In [28]:
data, full_len, prompt_len, response_len = process_data(ace_gen["train"])

In [29]:
np.max(response_len)

np.int64(354)

In [30]:
np.max(prompt_len)

np.int64(538)

In [31]:
np.median(full_len)

np.float64(473.0)

In [32]:
np.sum(np.array(prompt_len) > 460)

np.int64(51)

In [None]:
np.max(full_len)

np.int64(823)

In [None]:
len(data)

3167

In [None]:
np.mean(full_len)

np.float64(415.7120303125987)

In [8]:
import os

def save(data, data_name): 
    os.makedirs(f"data/{data_name}", exist_ok=True)
    train, _, _, _ = process_data(data["train"])
    with open(f"data/{data_name}/train.jsonl", "w", encoding="utf-8") as f:
        for item in train:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    dev, _, prompt_len, _ = process_data(data["validation"], is_test=True)
    with open(f"data/{data_name}/dev.jsonl", "w", encoding="utf-8") as f:
        for item in dev:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    test, _, prompt_len, _ = process_data(data["test"], is_test=True)
    with open(f"data/{data_name}/test.jsonl", "w", encoding="utf-8") as f:
        for item in test:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    return train, dev, test

In [None]:
save(ace_gen, "ace")

In [None]:
maven = load_dataset("datht/maven-event-dataset")

In [60]:
maven_train, maven_dev, maven_test = save(maven, "maven")

In [21]:
geneva = load_dataset("datht/geneva-event-dataset")

geneva_train, geneva_dev, geneva_test = save(geneva, "geneva_teacher")

In [34]:
len(geneva_train)

1968

In [65]:
with open(f"data/train.jsonl", "w", encoding="utf-8") as f:
    for item in maven_train + geneva_train:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

with open(f"data/dev.jsonl", "w", encoding="utf-8") as f:
    for item in geneva_dev + maven_dev:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

with open(f"data/test.jsonl", "w", encoding="utf-8") as f:
    for item in maven_test + geneva_test:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

In [49]:
data, full_len, prompt_len, response_len = process_data(geneva["train"])

In [32]:
len(data)

1968

In [22]:
np.median(full_len)

np.float64(414.0)

In [50]:
np.sum(np.array(full_len) > 1024)

np.int64(8)

In [52]:
np.sum(np.array(prompt_len) > 640)

np.int64(8)

In [19]:
np.max(response_len)

np.int64(296)