In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Optional, cast, Dict, Any

import torch

import transformers
from transformers import AutoModel, AutoConfig, AutoTokenizer
import datasets

from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from src.flex_bert import *
from src.evals.data import *

  from .autonotebook import tqdm as notebook_tqdm
  @custom_fwd
  @custom_bwd


In [None]:
def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        label_copy = labels.clone()
        label_copy[:, 2:] = -100

        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                batch_size, seq_len = input_ids.shape[:2]
                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
                with torch.no_grad():
                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.sparse_prediction and labels is not None:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

        if self.config._attn_implementation == "flash_attention_2":
            with torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [2]:
with open("/home/public/span/MATH_DPO/modern_bert_test/bert24/yamls/main/flex-bert-base-sarah.yaml") as f:
    yaml_config = om.load(f)

cfg = cast(DictConfig, yaml_config)

In [3]:
model = create_modern_bert_mlm(
    pretrained_checkpoint=cfg.model.pretrained_checkpoint,
    model_config=cfg.model.model_config,
    tokenizer_name=cfg.tokenizer_name
)

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForMaskedLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`


In [7]:
model.forward

<bound method HuggingFaceModel.forward of EfficientHuggingFaceModel(
  (model): MLMxCLSHuggingFaceModel(
    (model): ModernBertForMaskedLM(
      (model): ModernBertModel(
        (embeddings): ModernBertEmbeddings(
          (tok_embeddings): Embedding(50368, 1024, padding_idx=50283)
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (layers): ModuleList(
          (0): ModernBertEncoderLayer(
            (attn_norm): Identity()
            (attn): ModernBertAttention(
              (Wqkv): Linear(in_features=1024, out_features=3072, bias=False)
              (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
              (Wo): Linear(in_features=1024, out_features=1024, bias=False)
              (out_drop): Identity()
            )
            (mlp_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): ModernBertMLP(
              (W

In [6]:
model

EfficientHuggingFaceModel(
  (model): MLMxCLSHuggingFaceModel(
    (model): ModernBertForMaskedLM(
      (model): ModernBertModel(
        (embeddings): ModernBertEmbeddings(
          (tok_embeddings): Embedding(50368, 1024, padding_idx=50283)
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (layers): ModuleList(
          (0): ModernBertEncoderLayer(
            (attn_norm): Identity()
            (attn): ModernBertAttention(
              (Wqkv): Linear(in_features=1024, out_features=3072, bias=False)
              (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
              (Wo): Linear(in_features=1024, out_features=1024, bias=False)
              (out_drop): Identity()
            )
            (mlp_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): ModernBertMLP(
              (Wi): Linear(in_features=1024, out_features=

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bclavie/olmo_bert_template")
state_dict = torch.load("/home/public/span/MATH_DPO/modern_bert_test/bert24/checkpoints/latest-rank0.pt")

  state_dict = torch.load("/home/public/span/MATH_DPO/modern_bert_test/bert24/checkpoints/latest-rank0.pt")


In [4]:
def consume_prefix_in_state_dict_if_present(
    state_dict, prefix
):
    r"""Strip the prefix in state_dict in place, if any.

    ..note::
        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    for key in keys:
        if key.startswith(prefix):
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in metadata if any.
    if "_metadata" in state_dict:
        metadata = state_dict["_metadata"]
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.

            if len(key) == 0:
                continue
            newkey = key[len(prefix) :]
            metadata[newkey] = metadata.pop(key)

In [8]:
state_dict = state_dict['state']['model']
consume_prefix_in_state_dict_if_present(state_dict, "model.")
consume_prefix_in_state_dict_if_present(state_dict, "bert.")
torch.save(state_dict, "/home/public/span/MATH_DPO/modern_bert_test/bert24/checkpoints/correct_names.pt")

In [2]:
model = transformers.AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-large")  

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForMaskedLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`


In [25]:
cfg.model.model_config.sliding_window

128

In [10]:
model

ModernBertForMaskedLM(
  (model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(50368, 1024, padding_idx=50283)
      (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=1024, out_features=3072, bias=False)
          (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
          (Wo): Linear(in_features=1024, out_features=1024, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=1024, out_features=5248, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=2624, out_fea

In [None]:
model = model.model.to('cuda')

In [14]:
ds = datasets.load_dataset("sarahpann/mlm_cls_rewardbench")['train']

In [15]:
ds[8]

{'text': "0[SEP]What are different drawers I should have for clothes?\nThe types of drawers you should have for clothes depend on your personal wardrobe and the space you have available. However, here's a general guide for categorizing different types of clothing into drawers:\n\n1. **Undergarments Drawer**: This drawer is for your underwear, bras, socks, and hosiery. Some people prefer to separate these further, with a drawer specifically for bras and another for underwear and socks.\n\n2. **T-shirts and Casual Tops Drawer**: A drawer for casual wear such as t-shirts, tank tops, and other everyday shirts can help keep your casual wear organized and easily accessible.\n\n3. **Pajamas and Lounge Wear Drawer**: This drawer is for your pajamas, nightgowns, and lounge wear, including comfy shorts and sweatpants.\n\n4. **Activewear and Gym Clothes Drawer**: If you work out regularly, it's helpful to have a drawer dedicated to gym clothes, sports bras, workout t-shirts, leggings, and shorts.

In [18]:
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
def mask_tokens(ex):
    new_text = "[CLS]" + ex['text'][1:]
    return {"text": new_text, "labels": ex['text']}

In [24]:
masked_in = mask_tokens(ds[8])
print(masked_in)
tokenized = tokenizer(masked_in['text'], return_tensors='pt')
print(tokenized)
input_ids = tokenized['input_ids'].to(device='cuda')
attention_mask = tokenized['attention_mask'].to(device='cuda')
new_tokenized = {"input_ids": input_ids, 
                 "attention_mask": attention_mask,
                 }
output = model(tokenized)

{'text': "[CLS][SEP]What are different drawers I should have for clothes?\nThe types of drawers you should have for clothes depend on your personal wardrobe and the space you have available. However, here's a general guide for categorizing different types of clothing into drawers:\n\n1. **Undergarments Drawer**: This drawer is for your underwear, bras, socks, and hosiery. Some people prefer to separate these further, with a drawer specifically for bras and another for underwear and socks.\n\n2. **T-shirts and Casual Tops Drawer**: A drawer for casual wear such as t-shirts, tank tops, and other everyday shirts can help keep your casual wear organized and easily accessible.\n\n3. **Pajamas and Lounge Wear Drawer**: This drawer is for your pajamas, nightgowns, and lounge wear, including comfy shorts and sweatpants.\n\n4. **Activewear and Gym Clothes Drawer**: If you work out regularly, it's helpful to have a drawer dedicated to gym clothes, sports bras, workout t-shirts, leggings, and sho

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

In [19]:
preds = torch.argmax(output.logits, dim=-1)
preds

tensor([  991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991, 13537,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991,   991,   991,
          991,   991,   991,   991,   991,   991, 13537,   991,   991,   991,
          991,   991,   991,   991,   991,   991,   991,   991, 

In [20]:
tokenizer.decode(preds[0])

'ax'