In [1]:
from typing import List, Optional
import datasets
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, \
    AutoModelForQuestionAnswering, Trainer, TrainingArguments, HfArgumentParser, \
    TrainerCallback
from transformers.trainer_utils import PredictionOutput
from helpers import prepare_dataset_nli, prepare_train_dataset_qa, \
    prepare_validation_dataset_qa, QuestionAnsweringTrainer, compute_accuracy
import os
import json
import pandas as pd

In [26]:
def log_training_dynamics(output_dir: os.path,
                          epoch: int,
                          train_ids: List[int],
                          train_logits: List[List[float]],
                          train_golds: List[int],
                        dynamics_type: str = 'training'):
    """
    Save training dynamics (logits) from given epoch as records of a `.jsonl` file.
    """
    td_df = pd.DataFrame({"guid": train_ids,
                        f"logits_epoch_{epoch}": train_logits,
                        "gold": train_golds})

    logging_dir = os.path.join(output_dir, f"{dynamics_type}_dynamics")
    # Create directory for logging training dynamics, if it doesn't already exist.
    if not os.path.exists(logging_dir):
        os.makedirs(logging_dir)
    epoch_file_name = os.path.join(logging_dir, f"dynamics_epoch_{epoch}.jsonl")
    td_df.to_json(epoch_file_name, lines=True, orient="records")

In [20]:
dataset = datasets.load_dataset('csv'
                                ,data_files = {'train': 'velurib-datasets/SNLI/train.tsv'
                                               ,'eval': 'velurib-datasets/SNLI/validation.tsv'
                                               , 'test': 'velurib-datasets/SNLI/test.tsv'}
                               ,delimiter='\t')
dataset = dataset.filter(lambda ex: isinstance(ex['hypothesis'], (str,list)))

Found cached dataset csv (/Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/3 [00:00<?, ?it/s]

Filter:   0%|          | 0/549367 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9842 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9824 [00:00<?, ? examples/s]

In [21]:
model_path = '../trained_model_velurib_nli_b256/checkpoint-537/'

In [22]:
model = AutoModelForSequenceClassification.from_pretrained(model_path, **{'num_labels': 3})
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

In [23]:
prepare_train_dataset = prepare_eval_dataset = lambda exs: prepare_dataset_nli(exs, tokenizer, 128)

train_dataset = dataset['train']
train_remove_columns = train_dataset.column_names.remove("id")
train_dataset_featurized = train_dataset.map(
            prepare_train_dataset,
            batched=True,
            num_proc=2,
            remove_columns=train_remove_columns
)

eval_dataset = dataset['validation']
eval_remove_columns = eval_dataset.column_names.remove("id")
eval_dataset_featurized = eval_dataset.map(
            prepare_eval_dataset,
            batched=True,
            num_proc=2,
            remove_columns=eval_remove_columns
)

Map (num_proc=2):   0%|          | 0/549361 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/9842 [00:00<?, ? examples/s]

In [25]:
trainer = Trainer(
    model=model,
    train_dataset=train_dataset_featurized,
    eval_dataset=eval_dataset_featurized,
    tokenizer=tokenizer,
    compute_metrics=compute_accuracy
)

In [30]:
eval_predictions = trainer.predict(test_dataset=trainer.eval_dataset, metric_key_prefix="eval")
log_training_dynamics(model_path, 0, trainer.eval_dataset['id'] , list(eval_predictions.predictions) , eval_predictions.label_ids, 'eval')


In [33]:
dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'id'],
        num_rows: 549361
    })
    eval: Dataset({
        features: ['premise', 'hypothesis', 'label', 'id'],
        num_rows: 9842
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label', 'id'],
        num_rows: 9824
    })
})

## Run on all checkpoints

In [38]:
for model_path in ['../trained_model_velurib_nli_b256/checkpoint-537/', 
                  '../trained_model_velurib_nli_b256/checkpoint-1074/',
                  '../trained_model_velurib_nli_b256/checkpoint-1611/']:
    model = AutoModelForSequenceClassification.from_pretrained(model_path, **{'num_labels': 3})
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    prepare_train_dataset = prepare_eval_dataset = lambda exs: prepare_dataset_nli(exs, tokenizer, 128)

    train_dataset = dataset['train']
    train_remove_columns = train_dataset.column_names.remove("id")
    train_dataset_featurized = train_dataset.map(
                prepare_train_dataset,
                batched=True,
                num_proc=2,
                remove_columns=train_remove_columns
    )

    eval_dataset = dataset['eval']
    eval_remove_columns = eval_dataset.column_names.remove("id")
    eval_dataset_featurized = eval_dataset.map(
                prepare_eval_dataset,
                batched=True,
                num_proc=2,
                remove_columns=eval_remove_columns
    )
    
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset_featurized,
        eval_dataset=eval_dataset_featurized,
        tokenizer=tokenizer,
        compute_metrics=compute_accuracy
    )
    eval_predictions = trainer.predict(test_dataset=trainer.eval_dataset, metric_key_prefix="eval")
    log_training_dynamics(model_path, 0, trainer.eval_dataset['id'] , list(eval_predictions.predictions) , eval_predictions.label_ids, 'eval')

Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-7178e9d1d89a8a55_*_of_00002.arrow
Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-f6ea7ffa61e9957f_*_of_00002.arrow
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-921143dadab473c5_*_of_00002.arrow
Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-59b4c8bf5d3f4578_*_of_00002.arrow
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-80fb1dacc3f5ae66_*_of_00002.arrow
Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/csv/default-1cdfb178aa57f411/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-a9234160ac354faf_*_of_00002.arrow
You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [37]:
eval_dataset_featurized

Dataset({
    features: ['premise', 'hypothesis', 'label', 'id', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 9842
})

## Calculate eval dy metrics

In [40]:
for file in os.listdir('../trained_model_velurib_nli_b256/eval_dynamics/'):
    print(file)

dynamics_epoch_1.jsonl
dynamics_epoch_0.jsonl
dynamics_epoch_2.jsonl


In [47]:
temp0 = pd.read_json('../trained_model_velurib_nli_b256/eval_dynamics/dynamics_epoch_0.jsonl',lines=True)
temp0

Unnamed: 0,guid,logits_epoch_0,gold
0,0,"[-1.7553514242000001, 2.859858036, -1.0490977764]",1
1,1,"[2.7545409203, -1.1148138046, -2.2792580128]",0
2,2,"[-2.521065712, -1.3992186785, 3.3681161404]",2
3,3,"[2.1913676262, -0.8820856214, -1.7921061516]",0
4,4,"[-2.4113121033000002, 1.7274105549, 0.59841269...",1
...,...,...,...
9837,9837,"[2.2715210915, -0.1293645054, -2.5752403736]",0
9838,9838,"[0.374204278, 0.4493592978, -0.8702021241000001]",2
9839,9839,"[2.5805034637, -0.7992218137, -2.3387358189]",0
9840,9840,"[-2.3287031651, -0.6759002805000001, 2.6303536...",2


In [51]:
temp1 = pd.read_json('../trained_model_velurib_nli_b256/eval_dynamics/dynamics_epoch_1.jsonl',lines=True)
temp1.rename(columns = {'logits_epoch_0':'logits_epoch_1'},inplace=True)
temp1.to_json('../trained_model_velurib_nli_b256/eval_dynamics/dynamics_epoch_1.jsonl',orient='records',lines=True)
temp1

Unnamed: 0,guid,logits_epoch_1,gold
0,0,"[-1.5230602026, 3.2023732662, -1.6116229296]",1
1,1,"[2.8523051739, -1.1013753414, -2.4134397507]",0
2,2,"[-2.7888793945, -1.7099151611, 3.8191409111]",2
3,3,"[2.1382279396, -0.7769355178, -1.8178160191000...",0
4,4,"[-2.3382945061, 2.2766933441, 0.029231986]",1
...,...,...,...
9837,9837,"[2.3841061592, 0.1701157242, -2.9601159096]",0
9838,9838,"[1.0663279295, 0.44765171410000004, -1.6633120...",2
9839,9839,"[2.7649593353, -0.835172534, -2.5396065712]",0
9840,9840,"[-2.5751872063, -1.2368717194, 3.285118103]",2


In [52]:
temp2 = pd.read_json('../trained_model_velurib_nli_b256/eval_dynamics/dynamics_epoch_2.jsonl',lines=True)
temp2.rename(columns = {'logits_epoch_0':'logits_epoch_2'},inplace=True)
temp2.to_json('../trained_model_velurib_nli_b256/eval_dynamics/dynamics_epoch_2.jsonl',orient='records',lines=True)
temp2.to_json
temp2

Unnamed: 0,guid,logits_epoch_2,gold
0,0,"[-1.6128150225, 3.3128418922, -1.6511343718]",1
1,1,"[2.7326819897, -1.083240509, -2.2756221294]",0
2,2,"[-2.8974208832, -1.7878544331, 3.9609210491]",2
3,3,"[2.1518702507, -0.7173354030000001, -1.8891099...",0
4,4,"[-2.2169373035, 2.4191398621, -0.2217153907000...",1
...,...,...,...
9837,9837,"[2.2063443661, 0.4153587818, -2.9541020393]",0
9838,9838,"[0.9539057612, 0.5845272541000001, -1.65183806...",2
9839,9839,"[2.7053320408, -0.7923318744000001, -2.5022206...",0
9840,9840,"[-2.5661716461, -1.2289872169, 3.2687718868]",2
