In [None]:
import evaluate
import numpy as np
from datasets import load_dataset
from sklearn.metrics import classification_report
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)

from trc_model import TemporalRelationClassification, TemporalRelationClassificationConfig

In [None]:
# Expose the custom config/model pair through the transformers Auto* factories:
# the config class is registered under the "TemporalRelationClassification" model type,
# and the model class is registered as the sequence-classification model for that config.
AutoConfig.register(
    "TemporalRelationClassification",
    TemporalRelationClassificationConfig,
)
AutoModelForSequenceClassification.register(
    TemporalRelationClassificationConfig,
    TemporalRelationClassification,
)

In [None]:
# Hugging Face Hub dataset that every model below is trained and evaluated on.
TRC_DATASET_ID = "guyyanko/trc-hebrew"
raw_datasets = load_dataset(TRC_DATASET_ID)

In [None]:
# Label-id <-> label-name maps, built from the (label, named_label) pairs of the train
# split; they are passed to the model config and used by compute_metrics.
label2id = {}
id2label = {}
train_split = raw_datasets['train']
for label, named_label in zip(train_split['label'], train_split['named_label']):
    # Plain overwriting would silently leave the two maps disagreeing if the data ever
    # paired one id with two names (or one name with two ids); fail loudly instead.
    known_name = id2label.setdefault(label, named_label)
    known_id = label2id.setdefault(named_label, label)
    if known_name != named_label or known_id != label:
        raise ValueError(
            f"Inconsistent label mapping in train split: pair ({label!r}, {named_label!r}) "
            f"conflicts with id2label[{label!r}]={known_name!r} / "
            f"label2id[{named_label!r}]={known_id!r}"
        )
# Order id2label by id so the saved config lists the labels deterministically.
id2label = dict(sorted(id2label.items()))

In [None]:
def preprocess_function(examples, tok=None):
    """Tokenize a batch of examples for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch mapping whose ``"text"`` entry holds the raw input strings.
        tok: tokenizer to apply. Defaults to the notebook-level ``tokenizer`` that the
            training loop assigns before calling ``map``; pass it explicitly (e.g. via
            ``map(..., fn_kwargs={"tok": tokenizer})``) to avoid depending on that global.

    Returns:
        Whatever the tokenizer returns for the batch, with over-long inputs truncated.
    """
    if tok is None:
        # Backward-compatible fallback: the original read this global unconditionally.
        tok = tokenizer
    return tok(examples["text"], truncation=True)

In [None]:
# Flag read by compute_metrics: when True, the full per-class classification report is
# printed in addition to returning the weighted averages. The training loop switches it
# on only around the final evaluation of each model.
# (The unused `evaluate.load("seqeval")` that used to live here was removed: seqeval is a
# token-level sequence-labelling metric, while this task is sentence-level classification
# scored with sklearn's classification_report.)
eval_mode = False


def compute_metrics(eval_preds):
    """Compute weighted-average classification metrics for ``Trainer`` evaluation.

    Args:
        eval_preds: ``(predictions, labels)`` pair supplied by ``transformers.Trainer``.
            ``predictions`` holds per-class scores whose argmax over the last axis is the
            predicted label id; it may also arrive as a tuple whose first element holds
            those scores. ``labels`` holds the gold label ids. Ids are mapped to names
            through the module-level ``id2label``.

    Returns:
        dict with the ``'weighted avg'`` precision / recall / f1-score entries of sklearn's
        classification report (the ``'support'`` count is dropped).

    When the module-level flag ``eval_mode`` is true, the full per-class text report is
    printed as well.
    """
    predictions, labels = eval_preds
    # Models that return extra outputs make Trainer pass a tuple; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=-1)
    true_labels = [id2label[label] for label in labels]
    true_predictions = [id2label[pred] for pred in predictions]
    # zero_division=0 yields the same numbers as sklearn's default ("warn" -> 0.0) for
    # classes that are never predicted, without an UndefinedMetricWarning on every eval.
    report_kwargs = dict(y_true=true_labels, y_pred=true_predictions, zero_division=0)
    if eval_mode:
        print(classification_report(**report_kwargs))
    all_metrics = classification_report(output_dict=True, **report_kwargs)['weighted avg']
    all_metrics.pop('support')
    return all_metrics

In [None]:
# Pretrained encoders to fine-tune, as Hugging Face Hub ids.
lm_checkpoints = [
    'onlplab/alephbert-base',
    'avichr/heBERT',
    'imvladikon/alephbertgimmel-base-512',
]
# Values passed as `architecture` to TemporalRelationClassificationConfig; every
# checkpoint is trained once per entry.
architectures = [
    'SEQ_CLS',
    'ESS',
    'EMP',
    'EF',
]

In [None]:
# Fine-tune every (checkpoint, architecture) combination, print its test-set report,
# then push it to the Hub and save it locally under `model_final_name`.
for checkpoint in lm_checkpoints:
    # Tokenizer, entity-marker ids and tokenized splits depend only on the checkpoint, so
    # build them once per checkpoint rather than once per architecture.
    # `preprocess_function` reads the module-level name `tokenizer`, so this assignment
    # must stay at notebook top level (not inside a function).
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.add_tokens(['[א1]', '[/א1]', '[א2]', '[/א2]'])
    E1_start = tokenizer.convert_tokens_to_ids('[א1]')
    E2_start = tokenizer.convert_tokens_to_ids('[א2]')
    tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    # Same value the previous str(type(tokenizer)) parsing produced, e.g. 'BertTokenizerFast'.
    tokenizer_class = type(tokenizer).__name__

    for arc in architectures:
        # [-1] (rather than [1]) also copes with Hub ids that have no 'org/' prefix.
        model_final_name = f'hebrew-trc-{checkpoint.split("/")[-1]}-{arc}'

        training_args = TrainingArguments(
            output_dir=model_final_name,
            learning_rate=2e-5,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=15,
            weight_decay=0.01,
            evaluation_strategy="steps",
            eval_steps=8,
            save_strategy="no",
        )

        # Trainer applies `training_args.seed` only after the model below already exists,
        # so seed explicitly first to make any randomly initialised weights reproducible.
        set_seed(training_args.seed)

        config = TemporalRelationClassificationConfig(EMS1=E1_start, EMS2=E2_start, architecture=arc,
                                                      token_embeddings_size=len(tokenizer), num_labels=len(label2id),
                                                      id2label=id2label,
                                                      label2id=label2id, name_or_path=checkpoint,
                                                      tokenizer_class=tokenizer_class)

        model = TemporalRelationClassification(config=config)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        print('Evaluate:', model_final_name)
        # compute_metrics prints the full per-class report while eval_mode is set; reset
        # it even if evaluation raises so the next model's in-training evals stay quiet.
        eval_mode = True
        try:
            trainer.evaluate(tokenized_datasets['test'])
        finally:
            eval_mode = False

        config.register_for_auto_class()
        model.register_for_auto_class('AutoModelForSequenceClassification')
        trainer.push_to_hub()
        trainer.save_model(model_final_name)

In [None]:
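# NOTE: the commented-out cells below are a legacy single-model workflow kept for reference.
# They are not runnable as-is: they reference names not defined in this notebook
# (`model_checkpoint`, `LABELS`, `DatasetBuilder`), and their training steps largely
# duplicate the training loop above.
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)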
# tokenizer.add_tokens(['[א1]', '[/א1]', '[א2]', '[/א2]'])
# E1_start = tokenizer.convert_tokens_to_ids('[א1]')
# E2_start = tokenizer.convert_tokens_to_ids('[א2]')

In [None]:
# tokenized_datasets = DatasetBuilder.tokenize_dataset(raw_datasets, tokenizer)

In [None]:
# tokenized_datasets

In [None]:
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# config = TemporalRelationClassificationConfig(EMS1=E1_start, EMS2=E2_start, architecture='EMP',
#                                               token_embeddings_size=len(tokenizer), num_labels=len(LABELS),
#                                               id2label=id2label,
#                                               label2id=label2id, name_or_path=model_checkpoint,
#                                               tokenizer_class='BertTokenizerFast')
#
# model = TemporalRelationClassification(config=config)

In [None]:
# training_args = TrainingArguments(
#     output_dir="trc-model-emp",
#     learning_rate=2e-5,
#     per_device_train_batch_size=32,
#     per_device_eval_batch_size=32,
#     num_train_epochs=15,
#     weight_decay=0.01,
#     evaluation_strategy="epoch",
#     save_strategy="no",
# )
#
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["test"],
#     tokenizer=tokenizer,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

In [None]:
# trainer.train()

In [None]:
# trainer.save_model('my_trc')

In [None]:
# from trc_pipeline import TemporalRelationClassificationPipeline
# from transformers.pipelines import PIPELINE_REGISTRY
#
# PIPELINE_REGISTRY.register_pipeline(
#     "temporal-relation-classification",
#     pipeline_class=TemporalRelationClassificationPipeline,
#     pt_model=TemporalRelationClassification,
# )

In [None]:
# from transformers import pipeline
#
# classifier = pipeline("temporal-relation-classification", model="my_trc", trust_remote_code=True)

In [None]:
# classifier(
#     "מקורות פלשתיניים [א1] מסרו [/א1] כי חיילים, שירדו ממיניבוס לבן בשכונת ראפידיה בשכם, [א2] ניפצו [/א2] בעזרת אבנים זגוגיות של מכוניות חונות של תושבים מקומיים, ביניהן גם את חלונות מכוניתו של דר מוסטפא מקבול.")

In [None]:
# from huggingface_hub import Repository
#
# repo = Repository("trc-emp-pipeline", clone_from="guyyanko/trc-model-emp")
# classifier.save_pretrained("trc-emp-pipeline")

In [None]:
# repo.push_to_hub()