In [3]:
from datasets import Dataset, DatasetDict
import json, os

import json
from datasets import Dataset

def load_and_clean_maven_jsonl(path):
    """Read a JSON-lines file and return a ``Dataset`` of token/content rows.

    Every non-blank line is parsed as one JSON object. Only its ``"content"``
    value is kept: when that value is a list, the first element is used and
    the remaining elements are discarded; otherwise it is used as-is. All
    other top-level fields of the record are dropped.

    NOTE(review): if event annotations are stored at the top level of each
    record rather than inside ``content``, they are not carried through by
    this loader -- confirm against the dataset schema.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A ``Dataset`` with two columns per row: ``"tokens"`` (the value of
        ``content["tokens"]``) and ``"content"`` (the selected content object).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"content"`` or the content has no
            ``"tokens"``.
        IndexError: If ``"content"`` is an empty list.
    """
    cleaned_data = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # record and would make json.loads fail, so skip them.
            if not stripped:
                continue
            raw = json.loads(stripped)

            # If `content` is a list, take the first element
            content = raw["content"]
            if isinstance(content, list):
                content = content[0]

            cleaned_data.append({
                "tokens": content["tokens"],
                "content": content
            })

    return Dataset.from_list(cleaned_data)

# Usage



# Load the train and validation splits from their JSONL files (paths are
# relative to the notebook's working directory).
train = load_and_clean_maven_jsonl("./MAVEN_DATASET/train.jsonl")
val = load_and_clean_maven_jsonl("./MAVEN_DATASET/valid.jsonl")
# Bundle both splits under the keys "train" and "validation"; later cells
# index `dataset` by these names.
dataset = DatasetDict({"train": train, "validation": val})


In [4]:
from datasets import load_dataset
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tokenizer and seq2seq model are both loaded from the same
# "luyaojie/uie-base-en" checkpoint identifier so their vocabularies match.
# Both names are module-level globals used by later cells.
tokenizer = AutoTokenizer.from_pretrained("luyaojie/uie-base-en")
model = AutoModelForSeq2SeqLM.from_pretrained("luyaojie/uie-base-en")

def create_prompt_input(example):
    """Turn one cleaned example into a prompt-style extraction record.

    The tokens are joined with single spaces to form the text, and every
    event mention found under ``example["content"]["events"]`` becomes one
    entry of ``result_list`` with character offsets into that text, such that
    ``text[start:end]`` equals the space-joined tokens of the trigger span.

    NOTE(review): mention offsets are treated as ``[start, end)`` indices
    into the flattened token list. If the data stores them relative to a
    single sentence, a per-sentence base offset is needed -- confirm against
    the dataset schema.

    Args:
        example: Mapping with
            ``"tokens"``: either a flat list of token strings or a list of
            per-sentence token lists (both are accepted);
            ``"content"``: mapping that may hold an ``"events"`` list, each
            event holding a ``"mention"`` list whose items provide
            ``"trigger_word"`` and a two-element ``"offset"``.

    Returns:
        dict with keys ``"text"`` (the joined tokens), ``"prompt"`` (the
        fixed instruction string) and ``"result_list"`` (a list of
        ``{"text", "start", "end", "label"}`` dicts, possibly empty).
    """
    # Extract tokens from the correct path
    token_sentences = example["tokens"]
    content = example["content"]

    # Flatten to a single token list. A bare string item is already a token;
    # iterating over it would split it into individual characters.
    token_list = []
    for item in token_sentences:
        if isinstance(item, str):
            token_list.append(item)
        else:
            token_list.extend(item)
    sentence = " ".join(token_list)

    # Prompt to be given
    prompt = "Extract all events"

    def _prefix_length(token_count):
        # Length of the text formed by the first `token_count` tokens.
        return len(" ".join(token_list[:token_count]))

    # Extract trigger words from events: one result per mention.
    result_list = []
    for event in content.get("events") or []:
        for mention in event.get("mention") or []:
            trigger_word = mention["trigger_word"]
            start_token = mention["offset"][0]
            end_token = mention["offset"][1]

            # The trigger starts after the preceding tokens plus the single
            # space separating them from it (no separator at position 0).
            char_start = _prefix_length(start_token) + (1 if start_token > 0 else 0)
            # The end is exclusive and stops right after the last trigger
            # token, so no separator is added here.
            char_end = _prefix_length(end_token)

            result_list.append({
                "text": trigger_word,
                "start": char_start,
                "end": char_end,
                "label": "event"
            })

    return {
        "text": sentence,
        "prompt": prompt,
        "result_list": result_list
    }


# Apply create_prompt_input to every example of both splits. The keys it
# returns ("text", "prompt", "result_list") are what preprocess_function
# reads in the next cell.
train_data = dataset["train"].map(create_prompt_input)
val_data = dataset["validation"].map(create_prompt_input)


Map: 100%|██████████| 2913/2913 [00:01<00:00, 2068.53 examples/s]
Map: 100%|██████████| 710/710 [00:00<00:00, 3298.20 examples/s]


In [5]:


def preprocess_function(example):
    """Tokenize one prompt record into seq2seq inputs and labels.

    The model input is ``"<prompt>: <text>"``; the target is the trigger
    texts from ``result_list`` joined with ``"; "`` (an empty string when
    there are none). Both are truncated and padded to fixed lengths (512 for
    the input, 128 for the target).

    Uses the module-level ``tokenizer``.

    Args:
        example: Mapping with ``"prompt"`` (str), ``"text"`` (str) and
            ``"result_list"`` (list of dicts each having a ``"text"`` key).

    Returns:
        The tokenizer output for the input text with an added ``"labels"``
        entry: the target token ids, where padding positions are replaced by
        ``-100``.
    """
    # Combine prompt and text
    input_text = example["prompt"] + ": " + example["text"]

    # Target: semi-colon separated trigger words
    target_text = "; ".join([span["text"] for span in example["result_list"]])

    model_inputs = tokenizer(
        input_text, max_length=512, truncation=True, padding="max_length"
    )
    # `text_target=` tokenizes the string as a target sequence; it replaces
    # the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=target_text, max_length=128, truncation=True, padding="max_length"
    )

    # Labels are padded to a fixed length; mark the padding with -100 so the
    # loss ignores those positions instead of training the model to emit pad
    # tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        token_id if token_id != pad_id else -100
        for token_id in labels["input_ids"]
    ]
    return model_inputs


In [None]:
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForTokenClassification



# Tokenize both splits with preprocess_function. `remove_columns` is given
# every existing column name, so the original text/prompt/result_list
# columns are requested to be dropped from the mapped datasets.
train_dataset = train_data.map(preprocess_function, remove_columns=train_data.column_names)
val_dataset = val_data.map(preprocess_function, remove_columns=val_data.column_names)


# Hyperparameters for the fine-tuning run. Checkpoints go to `output_dir`,
# logs to `logging_dir`.
# NOTE(review): `do_eval=True` is set but no evaluation schedule is given;
# this alone may not trigger evaluation during `trainer.train()` -- confirm
# whether an evaluation strategy should also be configured.
training_args = TrainingArguments(
    output_dir="./uie_prompt_finetuned",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    do_eval=True,
    learning_rate=3e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    save_total_limit=2  # cap on the number of checkpoints kept on disk
)

# Build the trainer from the module-level model, arguments and tokenized
# splits. The tokenizer is passed as `processing_class`: the `tokenizer=`
# keyword is deprecated in recent transformers releases (the recorded output
# of this cell is the tail of that deprecation warning).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer
)

# Run fine-tuning, then save the final model to "uie_prompt_maven_finetuned".
# This directory is separate from the "./uie_prompt_finetuned" output_dir
# passed to TrainingArguments above.
trainer.train()
trainer.save_model("uie_prompt_maven_finetuned")


Map: 100%|██████████| 2913/2913 [00:05<00:00, 494.22 examples/s]
Map: 100%|██████████| 710/710 [00:01<00:00, 396.51 examples/s]
  trainer = Trainer(
