In [1]:
import os
import sys

# Put the parent of the current working directory on the module search path.
# NOTE(review): presumably this is what makes the local `trc_model` package
# importable when the notebook is run from a subdirectory — confirm.
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

In [2]:
import numpy as np
from datasets import load_dataset
from sklearn.metrics import classification_report
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertModel,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)

from trc_model.temporal_relation_classification import TemporalRelationClassification
from trc_model.temporal_relation_classification_config import TemporalRelationClassificationConfig

In [3]:
# Load the local CSV dataset with the new entity markers (train / test splits).
# Alternative Hub copy without the special markers:
#     load_dataset("guyyanko/trc-hebrew-no-special-markers")
raw_datasets = load_dataset(path="data_handling/new_markers_data")
raw_datasets

Found cached dataset csv (/Users/guy.yanko/.cache/huggingface/datasets/csv/new_markers_data-c1b27f3938300914/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'text', 'label', 'named_label'],
        num_rows: 5826
    })
    test: Dataset({
        features: ['Unnamed: 0', 'text', 'label', 'named_label'],
        num_rows: 1434
    })
})

In [4]:
# Label <-> id mappings taken from the training split. They are passed to
# TemporalRelationClassificationConfig as id2label / label2id.
label_pairs = sorted(set(zip(raw_datasets['train']['label'], raw_datasets['train']['named_label'])))
id2label = {label_id: named_label for label_id, named_label in label_pairs}
label2id = {named_label: label_id for label_id, named_label in label_pairs}
# Each id must map to exactly one name and vice versa. Otherwise one of the
# conflicting pairs would silently win and the config would be wrong.
assert len(id2label) == len(label2id) == len(label_pairs), \
    f'inconsistent label / named_label pairs in the training split: {label_pairs}'

In [5]:
def preprocess_function(examples):
    """Tokenize one batch for ``Dataset.map(..., batched=True)``.

    Reads the module-level ``tokenizer``, which the training loop assigns
    before calling ``map``. Truncation is enabled here. Padding is handled later
    by ``DataCollatorWithPadding``.
    """
    batch_texts = examples["text"]
    return tokenizer(batch_texts, truncation=True)

In [6]:
# Toggled by the training loop around the final evaluation. When True,
# compute_metrics also prints the per-class report and writes it to
# `<model_final_name>/evaluation_report.txt`.
eval_mode = False


def compute_metrics(eval_preds):
    """Weighted-average classification metrics for a ``Trainer`` evaluation pass.

    Args:
        eval_preds: ``(predictions, labels)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class scores; the argmax over axis 1 is
            taken as the predicted label id.

    Returns:
        The ``'weighted avg'`` row of sklearn's ``classification_report`` with
        its ``'support'`` entry removed.

    Reads the globals ``id2label`` (label id -> class name) and ``eval_mode``.
    When ``eval_mode`` is True it also reads ``model_final_name``.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=1)

    # Class names come from the same id2label mapping given to the model config,
    # in id order, so report rows cannot drift from the real label ids.
    # `labels=` is passed explicitly. Without it, sklearn infers the classes from
    # y_true ∪ y_pred and raises if that count differs from len(target_names).
    # zero_division=0 yields the same 0.0 values the default produces for
    # never-predicted classes, without a warning on every evaluation.
    label_ids = sorted(id2label)
    report_kwargs = dict(
        y_true=labels,
        y_pred=predictions,
        labels=label_ids,
        target_names=[id2label[label_id] for label_id in label_ids],
        zero_division=0,
    )

    if eval_mode:
        report = classification_report(**report_kwargs)
        os.makedirs(model_final_name, exist_ok=True)
        with open(os.path.join(model_final_name, 'evaluation_report.txt'), 'w') as f:
            f.write(report)
        print(report)

    results = classification_report(output_dict=True, **report_kwargs)['weighted avg']
    results.pop('support')
    return results

In [7]:
# Pretrained checkpoints to fine-tune. Each one is combined with every
# architecture below.
lm_checkpoints = [
    'onlplab/alephbert-base',
    'avichr/heBERT',
    'imvladikon/alephbertgimmel-base-512',
]

# Head variants, passed as TemporalRelationClassificationConfig(architecture=...).
architectures = [
    'SEQ_CLS',
    'ESS',
    'EMP',
    'EF',
]

In [None]:
# Seed used for model construction, the shuffled training order and the
# Trainer, so repeated runs of this cell are comparable.
SEED = 42

for checkpoint in lm_checkpoints:
    # Everything in this section depends only on the checkpoint, so it is built
    # once per checkpoint rather than once per (checkpoint, architecture) pair.
    # `tokenizer` stays module-level because preprocess_function reads it.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ES_ID = tokenizer.convert_tokens_to_ids('<')
    # convert_tokens_to_ids returns the unknown-token id (or None) for an
    # out-of-vocabulary token. The model would then key on the wrong id.
    assert ES_ID not in (None, tokenizer.unk_token_id), \
        f"'<' is not a token in the vocabulary of {checkpoint}"
    tokenized_datasets = raw_datasets.map(preprocess_function, remove_columns=['named_label'], batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    tokenizer_class = type(tokenizer).__name__

    for arc in architectures:
        # Module-level on purpose: compute_metrics writes the evaluation report
        # into this directory. [-1] also handles checkpoint ids with no org prefix.
        model_final_name = f'hebrew-trc-{checkpoint.split("/")[-1]}-{arc}'
        config = TemporalRelationClassificationConfig(ES_ID=ES_ID,
                                                      architecture=arc,
                                                      num_labels=len(label2id),
                                                      id2label=id2label,
                                                      label2id=label2id,
                                                      name_or_path=checkpoint,
                                                      tokenizer_class=tokenizer_class)

        # Seed before building the model so freshly initialised weights are reproducible.
        set_seed(SEED)
        model = TemporalRelationClassification(config=config)

        training_args = TrainingArguments(
            output_dir=model_final_name,
            learning_rate=1e-5,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            weight_decay=0.01,
            num_train_epochs=5,
            # Evaluate once per epoch. eval_steps=1 ran a full pass over the
            # 1,434-example test split after every optimizer step.
            evaluation_strategy="epoch",
            save_strategy="no",
            report_to=[],
            seed=SEED,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"].shuffle(seed=SEED),
            eval_dataset=tokenized_datasets["test"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        # Final evaluation with eval_mode on, so compute_metrics prints and saves
        # the per-class report. try/finally resets the flag even if evaluation
        # raises, so the next run's per-epoch evaluations stay quiet.
        print('Evaluate:', model_final_name)
        eval_mode = True
        try:
            trainer.evaluate(tokenized_datasets['test'])
        finally:
            eval_mode = False

        config.register_for_auto_class()
        model.register_for_auto_class('AutoModelForSequenceClassification')
        trainer.save_model(model_final_name)

Loading cached processed dataset at /Users/guy.yanko/.cache/huggingface/datasets/csv/new_markers_data-c1b27f3938300914/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-daee9656cddf3b45.arrow
Loading cached processed dataset at /Users/guy.yanko/.cache/huggingface/datasets/csv/new_markers_data-c1b27f3938300914/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-09f018e90fcf17af.arrow
Some weights of the model checkpoint at onlplab/alephbert-base were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model

Step,Training Loss,Validation Loss


The following columns in the evaluation set don't have a corresponding argument in `TemporalRelationClassification.forward` and have been ignored: text, Unnamed: 0. If text, Unnamed: 0 are not expected by `TemporalRelationClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1434
  Batch size = 32
