In [17]:
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    get_scheduler,
)

import torch
import torch.nn as nn

In [3]:
# Load the raw WikiText-2 dataset. No `split` is requested, so the result holds
# every available split (test / train / validation).
raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1", trust_remote_code=True)

Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 18857.11 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 518746.35 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 251396.15 examples/s]


In [4]:
raw_datasets.keys()

dict_keys(['test', 'train', 'validation'])

In [5]:
# Replace the stock validation split with the first 5% of the training split.
# NOTE: the f-string only interpolates the literal 5, so the slice is "train[:5%]".
raw_datasets["validation"] = load_dataset(
    "wikitext", 
    "wikitext-2-raw-v1",
    split=f"train[:{5}%]",
    trust_remote_code=True
)

# Keep the remaining 95% of the training split ("train[5%:]") as the new training
# set, so the two slices do not overlap.
raw_datasets["train"] = load_dataset(
    "wikitext", 
    "wikitext-2-raw-v1",
    split=f"train[{5}%:]",
    trust_remote_code=True
)

In [6]:
raw_datasets.keys()

dict_keys(['test', 'train', 'validation'])

In [7]:
model_name = "FacebookAI/roberta-base"

In [9]:
# Resolve the configuration through `model_name` (set in the cell above) instead of
# repeating the checkpoint string, so the config always matches the checkpoint
# that the tokenizer and model cells load.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

In [10]:
config

RobertaConfig {
  "_name_or_path": "FacebookAI/roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)

In [12]:
tokenizer

RobertaTokenizerFast(name_or_path='FacebookAI/roberta-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),
}
)

In [15]:
model = AutoModelForMaskedLM.from_pretrained(model_name, config=config, trust_remote_code=True)

In [16]:
model

RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNor

In [18]:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the one moved to the front.

    For ``x = [a, b]`` split along the last axis this returns ``[-b, a]``.
    """
    lower, upper = x.chunk(2, dim=-1)
    swapped = [upper.neg(), lower]
    return torch.cat(swapped, dim=-1)

In [21]:
x = torch.randn(1, 3, 6)

In [22]:
x

tensor([[[-1.0350,  1.3051,  0.9520,  0.7832, -1.2975,  0.9312],
         [-0.7057, -0.3037,  1.2212, -0.2418, -0.4070,  0.9999],
         [-1.0507, -0.5150, -0.7030,  0.7170,  0.1704, -0.0524]]])

In [23]:
rotate_half(x)

tensor([[[-0.7832,  1.2975, -0.9312, -1.0350,  1.3051,  0.9520],
         [ 0.2418,  0.4070, -0.9999, -0.7057, -0.3037,  1.2212],
         [-0.7170, -0.1704,  0.0524, -1.0507, -0.5150, -0.7030]]])

In [25]:
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary transform to ``q`` and ``k`` using the given cos/sin tables.

    Each input ``v`` is mapped to ``v * cos + rotate_half(v) * sin``.

    Returns:
        A ``(q_rotated, k_rotated)`` tuple.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated

In [26]:
base, dim = 10000, 256
# inv_freq[k] = base ** -(2k / dim) for every even index 2k in range(0, dim, 2).
exponents = torch.arange(0, dim, 2).float() / dim
inv_freq = 1.0 / (base ** exponents)

In [28]:
inv_freq.shape

torch.Size([128])

In [30]:
# Position indices 0..9, converted with type_as so they share inv_freq's dtype.
t = torch.arange(10).type_as(inv_freq)
t

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

In [34]:
# Same position vector, but created directly in inv_freq's dtype rather than
# being cast after construction.
t = torch.arange(10, dtype=inv_freq.dtype)
t

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

In [35]:
t.shape

torch.Size([10])

In [36]:
# Outer product of positions and inverse frequencies:
# freqs[i, j] = t[i] * inv_freq[j].
freqs = torch.einsum("i,j -> ij", t, inv_freq)

In [37]:
freqs.shape

torch.Size([10, 128])

In [38]:
freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.3057e-01, 8.6596e-01,  ..., 1.2409e-04, 1.1548e-04,
         1.0746e-04],
        [2.0000e+00, 1.8611e+00, 1.7319e+00,  ..., 2.4819e-04, 2.3096e-04,
         2.1492e-04],
        ...,
        [7.0000e+00, 6.5140e+00, 6.0618e+00,  ..., 8.6866e-04, 8.0835e-04,
         7.5223e-04],
        [8.0000e+00, 7.4446e+00, 6.9277e+00,  ..., 9.9275e-04, 9.2383e-04,
         8.5969e-04],
        [9.0000e+00, 8.3751e+00, 7.7937e+00,  ..., 1.1168e-03, 1.0393e-03,
         9.6715e-04]])

In [None]:
# Duplicate the angle table along the last axis so it spans the full feature
# width (both halves), mirroring what RotaryPositionEmbedding._update_cached does.
emb = torch.cat((freqs, freqs), dim=-1)

In [None]:
t

In [None]:
class RotaryPositionEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        base = 10000
    ):
        super().__init__()

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cached(self, x, seq_dim):
        seq_len = x.shape[seq_dim]

        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device)
            freqs = torch.einsum("i,j -> ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
            