In [1]:
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
from tqdm import tqdm
from hybrid_model import HybridModel
from datasets import load_dataset
from transformers import MambaModel

  from .autonotebook import tqdm as notebook_tqdm
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [2]:
seed, buffer_size = 42, 10_000
dataset = load_dataset('HuggingFaceFW/fineweb', split='train', streaming=True)
dataset = dataset.shuffle(seed, buffer_size=buffer_size)

In [3]:
transformer_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
# if transformer_tokenizer.mask_token is None:
#     transformer_tokenizer.add_special_tokens({'mask_token': '[MASK]'})

In [4]:
dataset = dataset.with_format("torch")
# transformer_model = AutoModelForCausalLM.from_pretrained(
#     'EleutherAI/gpt-neo-125M',
#     torch_dtype="auto",
#     device_map="cuda"
# )
# mamba_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

In [5]:
#transformer_model

In [6]:
from transformers import AutoModel
transformer_backbone= AutoModel.from_pretrained('EleutherAI/gpt-neo-125M')

In [7]:
transformer_backbone

GPTNeoModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(2048, 768)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPTNeoBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPTNeoAttention(
        (attention): GPTNeoSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPTNeoMLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): NewGELUActivation()

In [8]:
len(transformer_tokenizer)

50257

In [9]:
mamba_backbone = MambaModel.from_pretrained('state-spaces/mamba-130m-hf')

In [10]:
mamba_backbone

MambaModel(
  (embeddings): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x MambaBlock(
      (norm): MambaRMSNorm(768, eps=1e-05)
      (mixer): MambaMixer(
        (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
        (act): SiLU()
        (in_proj): Linear(in_features=768, out_features=3072, bias=False)
        (x_proj): Linear(in_features=1536, out_features=80, bias=False)
        (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
        (out_proj): Linear(in_features=1536, out_features=768, bias=False)
      )
    )
  )
  (norm_f): MambaRMSNorm(768, eps=1e-05)
)

In [11]:
for param in transformer_backbone.parameters():
    param.requires_grad = False

for param in mamba_backbone.parameters():
    param.requires_grad = False
    

GPTNeoModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(2048, 768)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPTNeoBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPTNeoAttention(
        (attention): GPTNeoSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPTNeoMLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): NewGELUActivation()

HybridModel(
  (transformer_model): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_fe

In [13]:
mamba_backbone

MambaModel(
  (embeddings): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x MambaBlock(
      (norm): MambaRMSNorm(768, eps=1e-05)
      (mixer): MambaMixer(
        (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
        (act): SiLU()
        (in_proj): Linear(in_features=768, out_features=3072, bias=False)
        (x_proj): Linear(in_features=1536, out_features=80, bias=False)
        (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
        (out_proj): Linear(in_features=1536, out_features=768, bias=False)
      )
    )
  )
  (norm_f): MambaRMSNorm(768, eps=1e-05)
)

In [15]:
def tokenize_function(examples):
    return transformer_tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True)

dataloader = DataLoader(tokenized_dataset,batch_size= 8, collate_fn=DataCollatorForLanguageModeling(transformer_tokenizer,mlm=False))
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model = HybridModel(transformer_backbone, mamba_backbone) 

model.train().to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
for epoch in range(3):
    dataset.set_epoch(epoch)
    for i, batch in enumerate(tqdm(dataloader, total=5)):
        if i == 5:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0:
            print(f"loss: {loss}")

GPTNeoModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(2048, 768)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPTNeoBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPTNeoAttention(
        (attention): GPTNeoSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPTNeoMLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): NewGELUActivation()

  0%|          | 0/5 [00:02<?, ?it/s]


ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.

GPTNeoModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(2048, 768)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPTNeoBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPTNeoAttention(
        (attention): GPTNeoSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPTNeoMLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=True)
        (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): NewGELUActivation()