In [1]:
import gc
import random
from dataclasses import field, dataclass
from typing import Optional, cast

import evaluate
import numpy as np
import torch
from datasets import load_from_disk, load_metric
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, \
    Trainer, HfArgumentParser, AutoModelForSequenceClassification

from rebert.initialize_via_roberta import load_transformers_base_bert, load_transformers_base_mlm
from rebert.model import (ReBertConfig, ReBertForMaskedLM)

seed = 42
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

2023-12-30 20:40:42.706569: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-30 20:40:42.706600: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-30 20:40:42.707492: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-30 20:40:42.711754: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<torch._C.Generator at 0x7fd4abf86ed0>

In [2]:
device = "cuda"
id2label = {0: "Entailment", 1: "Neutral", 2: "Contradiction"}
label2id = {v: k for k, v in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    "./rebert_rope_best", num_labels=len(id2label), id2label=id2label, label2id=label2id
).to(device)

model

Some weights of ReBertForSequenceClassification were not initialized from the model checkpoint at ./rebert_rope_best and are newly initialized: ['classifier.weight', 'pooler.pool_proj.bias', 'classifier.bias', 'pooler.pool_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ReBertForSequenceClassification(
  (rebert): ReBertModel(
    (embedding): ReBertEmbedding(
      (word_embedding): Embedding(50265, 768, padding_idx=1)
      (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ReBertEncoder(
      (rope): ROPEEmbedding()
      (encoder_layers): ModuleList(
        (0-11): 12 x ReBertEncoderLayer(
          (attention): ReBertMultiHeadAttention(
            (self_attention): ReBertSelfAttention(
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (attn_dropout): Dropout(p=0.1, inplace=False)
              (rope): ROPEEmbedding()
            )
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
            (output_dropout): Dropout(p=0.1, inplace=False)
     

In [3]:
ds = load_from_disk("./data/mnli_processed")
ds

DatasetDict({
    train: Dataset({
        features: ['label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 392702
    })
    test_matched: Dataset({
        features: ['label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 9847
    })
    eval: Dataset({
        features: ['label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 19647
    })
})

In [4]:
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    acc = accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1_score = f1.compute(predictions=predictions, references=labels, average="micro")["f1"]
    return {"accuracy": acc, "f1": f1_score}

In [5]:
from transformers import DataCollatorWithPadding

BATCH_TRAIN = 96
BATCH_EVAL = 96
LEARNING_RATE = 1e-4
EPOCHS = 10
SAVE_STEPS = 1000
LOG_STEPS = 500
LAMBDA = 0.01
SAVE_LIMITS = 10
WARMUP = 200
OUTPUT = "rebert_mnli_rope"
TB_DIR = "rebert_mnli_rope_tb"

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir=OUTPUT,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    bf16=True,
    gradient_checkpointing=True,
    num_train_epochs=EPOCHS,
    weight_decay=LAMBDA,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP,
    evaluation_strategy="steps",
    save_strategy="steps",
    logging_steps=LOG_STEPS,
    save_steps=SAVE_STEPS,
    logging_dir=TB_DIR,
    save_total_limit=SAVE_LIMITS,
    load_best_model_at_end=True,
    seed=seed
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["eval"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

try:
    results = trainer.train(resume_from_checkpoint = True)
except ValueError as e:
    results = trainer.train(resume_from_checkpoint = False)

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Accuracy,F1
500,1.0992,1.101562,0.328447,0.328447
1000,1.0986,1.098802,0.353336,0.353336
1500,1.0995,1.098802,0.353336,0.353336
2000,1.0987,1.098997,0.328447,0.328447
2500,1.0984,1.101562,0.353336,0.353336
3000,1.0992,1.098997,0.328447,0.328447
3500,1.099,1.098802,0.353336,0.353336
4000,1.0983,1.098802,0.353336,0.353336
4500,1.0983,1.096316,0.353336,0.353336
5000,1.0987,1.098802,0.353336,0.353336




KeyboardInterrupt: 