In [None]:
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
from tqdm import tqdm
from hybrid_model import HybridModel
from datasets import load_dataset
from transformers import MambaModel

In [None]:
# Stream FineWeb's train split and shuffle it through a fixed-size buffer.
# The buffer-based shuffle is only approximate, but it is reproducible for a
# given seed.
seed = 42
buffer_size = 10_000
dataset = load_dataset(
    'HuggingFaceFW/fineweb', split='train', streaming=True
).shuffle(seed=seed, buffer_size=buffer_size)

In [None]:
transformer_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
# GPT-Neo's tokenizer ships without a pad token. The tokenization step pads
# every example to max_length, and that raises unless a pad token is defined.
# Reuse EOS, the usual choice for causal LMs. No mask token is needed, because
# the language-modeling collator is built with mlm=False.
if transformer_tokenizer.pad_token is None:
    transformer_tokenizer.pad_token = transformer_tokenizer.eos_token

In [None]:
# Have the streaming dataset emit numeric columns as PyTorch tensors.
dataset = dataset.with_format(type="torch")
# Full causal-LM models were tried here earlier: GPT-Neo via
# AutoModelForCausalLM and Mamba via MambaForCausalLM. The bare backbones
# are loaded in the following cells instead.
# mamba_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

In [None]:
#transformer_model

In [None]:
from transformers import AutoModel

# Base GPT-Neo model without an LM head. Its weights are frozen further down.
transformer_backbone = AutoModel.from_pretrained(
    pretrained_model_name_or_path='EleutherAI/gpt-neo-125M',
)

In [None]:
# Display the tokenizer's total vocabulary size; len() also counts added tokens.
len(transformer_tokenizer)

In [None]:
# Base Mamba model without an LM head. Its weights are frozen further down.
mamba_backbone = MambaModel.from_pretrained(
    pretrained_model_name_or_path='state-spaces/mamba-130m-hf',
)

In [None]:
# Display the Mamba backbone's module structure for inspection.
mamba_backbone

In [None]:
# Freeze both pretrained backbones so that the optimizer only updates
# whatever trainable parameters HybridModel adds on top of them.
for backbone in (transformer_backbone, mamba_backbone):
    backbone.requires_grad_(False)
    

In [None]:
# Re-display the Mamba backbone after its parameters have been frozen.
mamba_backbone

In [None]:
import itertools


def tokenize_function(examples):
    """Tokenize a batch of raw documents into fixed-length model inputs.

    Every text is truncated or padded to exactly 512 tokens, so all
    examples in a batch have the same length.

    Args:
        examples: A batch dict passed by ``dataset.map(..., batched=True)``.
            Only the ``"text"`` column is read.

    Returns:
        The tokenizer's encoding for the batch (e.g. ``input_ids`` and
        ``attention_mask``).
    """
    return transformer_tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)


# GPT-Neo's tokenizer has no pad token, and padding="max_length" raises
# without one. Reuse EOS. Note that the causal-LM collator sets labels equal
# to pad_token_id to -100, so genuine EOS tokens are also excluded from the loss.
if transformer_tokenizer.pad_token is None:
    transformer_tokenizer.pad_token = transformer_tokenizer.eos_token

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    # Drop the raw source columns (e.g. the string "text" column). Otherwise
    # the collator tries to tensorize them and the .to(device) call below fails.
    remove_columns=dataset.column_names,
)

dataloader = DataLoader(
    tokenized_dataset,
    batch_size=8,
    collate_fn=DataCollatorForLanguageModeling(transformer_tokenizer, mlm=False),
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HybridModel(transformer_backbone, mamba_backbone)

model.train().to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

num_epochs = 3
max_steps_per_epoch = 5  # smoke-test budget: only a few batches per epoch
log_every = 10
for epoch in range(num_epochs):
    # Call set_epoch on the dataset the DataLoader actually iterates. map()
    # returned a new IterableDataset with its own epoch counter, so calling
    # set_epoch on `dataset` would reuse the same shuffle order every epoch.
    tokenized_dataset.set_epoch(epoch)
    # islice stops after the budget without fetching and collating an extra batch.
    steps = itertools.islice(dataloader, max_steps_per_epoch)
    for i, batch in enumerate(tqdm(steps, total=max_steps_per_epoch)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]  # presumably the LM loss; HybridModel is not visible here, TODO confirm
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % log_every == 0:
            print(f"loss: {loss.item()}")