In [1]:
import gc
import random
from dataclasses import field, dataclass
from typing import Optional, cast

import evaluate
import numpy as np
import torch
from datasets import load_from_disk, load_metric
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, \
    Trainer, HfArgumentParser, AutoModelForSequenceClassification

from rebert.initialize_via_roberta import load_transformers_base_bert, load_transformers_base_mlm
from rebert.model import (ReBertConfig, ReBertForMaskedLM)

# Reproducibility: every RNG below is seeded from this single constant, so
# changing `seed` here actually changes all of them.
seed = 42
random.seed(seed)      # Python stdlib RNG
np.random.seed(seed)   # NumPy legacy global RNG
# PyTorch RNG. torch.manual_seed returns a Generator; the trailing ';' keeps its
# repr out of the cell output.
torch.manual_seed(seed);

2023-12-30 22:10:56.362037: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-30 22:10:56.362059: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-30 22:10:56.362945: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-30 22:10:56.368117: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<torch._C.Generator at 0x7fc75c76adb0>

In [2]:
# Masked-LM weights for the ReBERT RoPE variant, presumably the Trainer checkpoint
# saved at step 120000 (inferred from the directory name).
# NOTE(review): AutoModelForMaskedLM can only resolve the custom ReBert classes if
# they were registered with the Auto* factories — presumably a side effect of the
# `rebert.model` import in the first cell; confirm.
CHECKPOINT_DIR = "./rebert_rope/checkpoint-120000"
model = AutoModelForMaskedLM.from_pretrained(CHECKPOINT_DIR)
model

ReBertForMaskedLM(
  (rebert): ReBertModel(
    (embedding): ReBertEmbedding(
      (word_embedding): Embedding(50265, 768, padding_idx=1)
      (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ReBertEncoder(
      (rope): ROPEEmbedding()
      (encoder_layers): ModuleList(
        (0-11): 12 x ReBertEncoderLayer(
          (attention): ReBertMultiHeadAttention(
            (self_attention): ReBertSelfAttention(
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (attn_dropout): Dropout(p=0.1, inplace=False)
              (rope): ROPEEmbedding()
            )
            (o_proj): Linear(in_features=768, out_features=768, bias=True)
            (output_dropout): Dropout(p=0.1, inplace=False)
            (output

In [3]:
# Reuse the stock RoBERTa-base tokenizer. Its 50265-token vocab and <pad> id of 1
# line up with the model's word_embedding (Embedding(50265, ..., padding_idx=1)).
TOKENIZER_NAME = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer

RobertaTokenizerFast(name_or_path='roberta-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),
}

In [4]:
# Probe the MLM head on two sentences that contain <mask> tokens.
test_sentences = ["Did the quick brown <mask> jumped over the lazy dog?", "Did the quick brown <mask> jumped over the lazy <mask>?"]
# padding=True: these two sentences happen to tokenize to the same length, but
# return_tensors="pt" cannot build a tensor from a ragged batch, so pad explicitly
# to keep the cell working when the probe sentences change.
encoded_inputs = tokenizer(test_sentences, return_tensors="pt", padding=True)
# Pure inference: don't build an autograd graph for the forward pass.
with torch.no_grad():
    output = model(**encoded_inputs)
# NOTE(review): the "Pre_adjust"/"Post_adjust" shape dumps printed by this cell are
# not produced here; they presumably come from leftover debug print() calls inside
# the rebert attention/RoPE code and should be removed there.

Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])
torch.Size([2, 12, 13, 64])
Pre_adjust
torch.Size([1, 13])
torch.Size([13, 64])
torch.Size([2, 12, 13, 64])
Post_adjust
torch.Size([1, 1, 13, 64])

In [5]:
# Greedy (argmax) prediction at every position, decoded back to text. Every slot is
# decoded — including <s>/</s> (and any <pad>) — so each string carries extra tokens
# beyond the <mask> fills (e.g. the leading '\n' and trailing '.' in the output).
predicted_ids = output.logits.argmax(dim=-1)
tokenizer.batch_decode(predicted_ids)

['\nDid the quick brown\n jumped over the lazy dog?.',
 '\nDid the quick brown\n jumped over the lazy\n?.']