In [None]:
import os
import sys

# Append the normalised absolute path of the current working directory's parent to the
# import search path (presumably so the local `trc_model` package imported below resolves
# when the notebook is launched from a sub-folder -- confirm against the repository layout).
sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(os.getcwd()), os.curdir)
    )
)

In [None]:
from datasets import load_dataset
from sklearn.metrics import classification_report
import numpy as np
from trc_model.temporal_relation_classification import TemporalRelationClassification
from trc_model.temporal_relation_classification_config import TemporalRelationClassificationConfig
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, AutoTokenizer

In [None]:
# Load the "guyyanko/trc-hebrew" dataset. Later cells read its 'train' and 'test' splits and
# the 'text', 'label' and 'named_label' columns.
raw_datasets = load_dataset(path="guyyanko/trc-hebrew")

In [None]:
# Pair every training example's integer 'label' with its 'named_label' and derive the two
# lookup tables from those pairs. Repeated pairs simply overwrite the same entry, so each
# table ends up with one entry per distinct class.
_label_pairs = list(zip(raw_datasets['train']['label'], raw_datasets['train']['named_label']))
id2label = dict(_label_pairs)
label2id = {class_name: class_id for class_id, class_name in _label_pairs}

In [None]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples with truncation enabled.

    Uses whichever notebook-level ``tokenizer`` is bound at call time (the training loop
    rebinds it), and returns the tokenizer's output unchanged.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

In [None]:
# Notebook-level switch read by `compute_metrics`: while it is truthy, each call also writes
# the full text report to disk and prints it. The training loop turns it on only around its
# final evaluation.
eval_mode = False


def compute_metrics(eval_preds):
    """Compute weighted-average classification metrics for one evaluation pass.

    Args:
        eval_preds: a ``(predictions, labels)`` pair. ``predictions`` holds per-class
            scores (the arg-max over the last axis is taken as the predicted class id);
            ``labels`` holds the gold class ids.

    Returns:
        The ``'weighted avg'`` entry of ``classification_report(..., output_dict=True)``
        with its ``'support'`` key removed.

    Side effects:
        When the global ``eval_mode`` is truthy, the plain-text report is written to
        ``{model_final_name}/evaluation_report.txt`` and printed.
    """
    scores, gold_ids = eval_preds
    predicted_ids = np.argmax(scores, axis=-1)

    def to_names(class_ids):
        # Translate integer class ids into their readable names.
        return [id2label[class_id] for class_id in class_ids]

    gold_names = to_names(gold_ids)
    predicted_names = to_names(predicted_ids)

    if eval_mode:
        report_path = f'{model_final_name}/evaluation_report.txt'
        text_report = classification_report(y_true=gold_names, y_pred=predicted_names)
        with open(report_path, 'w') as report_file:
            report_file.write(text_report)
        print(text_report)

    report_dict = classification_report(y_true=gold_names, y_pred=predicted_names, output_dict=True)
    weighted = report_dict['weighted avg']
    # 'support' is a sample count rather than a score, so it is dropped from the result.
    del weighted['support']
    return weighted

In [None]:
# Checkpoint identifiers handed to `AutoTokenizer.from_pretrained` (and to the model config as
# `name_or_path`) by the training loop; one set of models is trained per entry.
lm_checkpoints = [
    'onlplab/alephbert-base',
    'avichr/heBERT',
    'imvladikon/alephbertgimmel-base-512',
]
# Values passed as `architecture=` to TemporalRelationClassificationConfig; every checkpoint
# is trained once with each of them.
architectures = [
    'SEQ_CLS',
    'ESS',
    'EMP',
    'EF',
]

In [None]:
# Train, evaluate and save one model for every (checkpoint, architecture) combination.
for checkpoint in lm_checkpoints:
    # Everything tokenizer-related depends only on the checkpoint, so it is prepared once per
    # checkpoint instead of being repeated for each architecture.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Register the four marker tokens and look up the ids of the two opening markers; these
    # ids are passed to the model config as EMS1 / EMS2 below.
    tokenizer.add_tokens(['[א1]', '[/א1]', '[א2]', '[/א2]'])
    E1_start = tokenizer.convert_tokens_to_ids('[א1]')
    E2_start = tokenizer.convert_tokens_to_ids('[א2]')
    # `preprocess_function` reads the notebook-level `tokenizer` bound just above.
    tokenized_datasets = raw_datasets.map(preprocess_function, remove_columns=['named_label'], batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    # Class name of the tokenizer (e.g. the last component of its dotted path), stored in the config.
    tokenizer_class = type(tokenizer).__name__

    for arc in architectures:
        # Also used as the output directory, and read by `compute_metrics` when it writes the report.
        model_final_name = f'hebrew-trc-{checkpoint.split("/")[1]}-{arc}'

        config = TemporalRelationClassificationConfig(EMS1=E1_start, EMS2=E2_start, architecture=arc,
                                                      token_embeddings_size=len(tokenizer), num_labels=len(label2id),
                                                      id2label=id2label,
                                                      label2id=label2id, name_or_path=checkpoint,
                                                      tokenizer_class=tokenizer_class)

        model = TemporalRelationClassification(config=config)

        training_args = TrainingArguments(
            output_dir=model_final_name,
            learning_rate=2e-5,
            optim='adamw_torch',
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            num_train_epochs=20,
            evaluation_strategy="epoch",
            save_strategy="no",
            report_to=[],
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        # Switch `compute_metrics` into report-writing mode for the final evaluation only.
        # The `finally` guarantees the flag is cleared even if evaluation raises; otherwise a
        # stale `True` would survive in the kernel and make every later per-epoch evaluation
        # write and print the full report.
        eval_mode = True
        try:
            print('Evaluate:', model_final_name)
            trainer.evaluate(tokenized_datasets['test'])
        finally:
            eval_mode = False

        config.register_for_auto_class()
        model.register_for_auto_class('AutoModelForSequenceClassification')
        # trainer.push_to_hub()
        trainer.save_model(model_final_name)