In [7]:
import math

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# import lightning as L

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp


import huggingface_hub
import os
from dotenv import load_dotenv

from transformers import AutoConfig, T5Config
from transformers import AutoTokenizer, T5TokenizerFast
from transformers import DataCollatorWithPadding, DataCollatorForSeq2Seq
from transformers import AutoModel, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
from transformers import pipeline

import datasets
from datasets import load_dataset #, load_from_disk


from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from einops import rearrange


import tqdm as notebook_tqdm
from tqdm.auto import tqdm



In [8]:
load_dotenv()

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# find where in os path the token is stored:
print(huggingface_token)


None


In [3]:

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = dataset  # REFERENCE YOUR DATASET HERE!!
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
    
    return dataloader

def cleanup():
    dist.destroy_process_group()

In [4]:
def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)
    # prepare the dataloader
    dataloader = prepare(rank, world_size)
    
    # instantiate the model(it's your own model) and move it to the right device
    model = model.to(rank)  # REFERENCE YOUR MODEL HERE!!
    
    # wrap the model with DDP
    # device_ids tell DDP where is your model
    # output_device tells DDP where to output, in our case, it is rank
    # find_unused_parameters=True instructs DDP to find unused output of the forward() function of any module in the model
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

In [5]:
## Distributed Data Parallel DDP with PyTorch Lightning -- EXAMPLE CODE for reference
# for epoch in epochs:
#     # if we are using DistributedSampler, we have to tell it which epoch this is
#     dataloader.sampler.set_epoch(epoch)       
    
#     for step, x in enumerate(dataloader):
#         optimizer.zero_grad(set_to_none=True)
        
#         pred = model(x)
#         label = x['label']
        
#         loss = loss_fn(pred, label)
#         loss.backward()
#         optimizer.step()
# cleanup()

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
cuda_count = torch.cuda.device_count()
torch.cuda.empty_cache()
cpu_cores = mp.cpu_count()
print(device, cuda_count, f'cpu:{cpu_cores}')

cpu 0 cpu:4


In [26]:
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model_block_indep = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model_block_indep(x)
assert y.shape == x.shape

In [27]:
print(model_block_indep)  

Mamba(
  (in_proj): Linear(in_features=16, out_features=64, bias=False)
  (conv1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
  (act): SiLU()
  (x_proj): Linear(in_features=32, out_features=33, bias=False)
  (dt_proj): Linear(in_features=1, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=16, bias=False)
)


In [4]:
#### MAMBA MODEL Stuff ####

# model_checkpoint = "state-spaces/mamba-2.8b"
modelHead_checkpoint_mamba = "state-spaces/mamba-130m"
tokenizer_checkpoint_mamba = "EleutherAI/gpt-neox-20b"

tokenizer_mamba = AutoTokenizer.from_pretrained(tokenizer_checkpoint_mamba)
model_mamba = MambaLMHeadModel.from_pretrained(modelHead_checkpoint_mamba, device=device, dtype=torch.float16).to(device)
# config = AutoConfig.from_pretrained("state-spaces/mamba-2.8b")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


NameError: name 'device' is not defined

In [37]:
print (model_mamba.config)
print(model_mamba.lm_head)
print(model_mamba.modules)


MambaConfig(d_model=768, n_layer=24, vocab_size=50277, ssm_cfg={}, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, pad_vocab_size_multiple=8)
Linear(in_features=768, out_features=50280, bias=False)
MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
        (norm): RMSNorm()
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)


In [9]:
#### T5 MODEL Stuff ####

# model_checkpoint_t5 = "t5-small"

# tokenizer_t5 = AutoTokenizer.from_pretrained(model_checkpoint_t5)
# # config_t5 = AutoConfig.from_pretrained(model_checkpoint_t5, output_hidden_states=True)
# # print(config_t5)
# model_t5 = AutoModel.from_pretrained(model_checkpoint_t5).to(device)
# # print(model_t5.config)

In [10]:
## FOR DISTRIBUTED DATAPARALLEL (DDP) -- FOR USE IN SCRIPT 

# if __name__ == '__main__':
#     # suppose we have 3 gpus
#     world_size = torch.cuda.device_count()  # number of GPUs 
#     mp.spawn(
#         main,
#         args=(world_size),
#         nprocs=world_size
#     )