In [None]:
from datasets import load_dataset
from span_marker import SpanMarkerModel, Trainer
from transformers import TrainingArguments
from span_marker import SpanMarkerModel, SpanMarkerModelCardData

In [None]:
# Load the dataset and read the NER tag names from the train split's features.
dataset_id = "conll2003"
dataset = load_dataset(dataset_id)
labels = dataset["train"].features["ner_tags"].feature.names

# Identifier of the pretrained encoder the SpanMarker model is built from.
encoder_id = "roberta-base"

# Metadata passed along to improve the generated model card.
model_card_data = SpanMarkerModelCardData(
    language=["en"],
    license="apache-2.0",
    encoder_id=encoder_id,
    dataset_id=dataset_id,
)

# The encoder id and the label list are required; the two length limits and
# the model card metadata are optional arguments.
model = SpanMarkerModel.from_pretrained(
    encoder_id,
    labels=labels,
    model_max_length=256,
    entity_max_length=6,
    model_card_data=model_card_data,
)

In [None]:
# Training hyperparameters, collected in one mapping so they are easy to scan
# and tweak, then expanded into TrainingArguments.
training_kwargs = {
    # Where the Trainer writes its outputs.
    "output_dir": "models/span-marker-roberta-base-conll03",
    "push_to_hub": False,
    # Optimisation settings.
    "learning_rate": 1e-5,
    "warmup_ratio": 0.1,
    "num_train_epochs": 1,
    "fp16": True,
    # Batching: 4 samples per device, gradients accumulated over 2 steps.
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 2,
    # Step-based evaluation, saving and logging.
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "eval_steps": 500,
    "logging_steps": 50,
}
args = TrainingArguments(**training_kwargs)

In [None]:
# Train and evaluate on fixed-size prefixes of the splits rather than the
# full data: the first 8000 train rows and the first 2000 validation rows.
train_subset = dataset["train"].select(range(8000))
eval_subset = dataset["validation"].select(range(2000))

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_subset,
    eval_dataset=eval_subset,
)
trainer.train()

In [None]:
# Run evaluation with the trainer (it was given an eval_dataset when it was
# constructed) and show the resulting metrics.
metrics = trainer.evaluate()
print(metrics)

# Save the final model next to the intermediate checkpoints. The path is
# derived from args.output_dir so it cannot drift from the training output
# directory (previously this wrote to "../models/...", one directory above it).
trainer.save_model(f"{args.output_dir}/checkpoint-final")