# Hybrid Model

1. Inference  
  a. Run inference with a pretrained Mamba model and a pretrained transformer model.  
  b. Run inference using a subset of the models' layers.  
  c. Pass layer outputs through intermediate linear layers.  
  d. Combine the intermediate linear outputs and project again.  
2. Training  
  a. Freeze the pretrained models and train only the linear layers.  
  b. Train both the pretrained models and the linear layers.  

### 1a. Inferring from Pretrained Mamba and Transformer

In [1]:
from transformers import AutoTokenizer, MambaForCausalLM
import torch

# Load the Mamba tokenizer and pretrained model
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
mamba_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Generate up to 10 new tokens and decode them to text
out = mamba_model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


["Hey how are you doing?\n\nI'm so glad you're here."]


In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'EleutherAI/gpt-neo-125M'

# Load the tokenizer and tokenize the prompt
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey how are you doing?"
model_inputs = tokenizer(prompt, return_tensors="pt").to('cpu')
print('model_inputs:', model_inputs)

# Load the pretrained transformer model
trans_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cpu"
)

generated_ids = trans_model.generate(
    **model_inputs,
    max_new_tokens=512
)

# Decode the generated tokens to text
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("response:", response)

model_inputs: {'input_ids': tensor([[10814,   703,   389,   345,  1804,    30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


response: Hey how are you doing?

I'm doing a lot of research on the internet and I'm not sure if I'm doing it right or not. I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going on with the internet and I'm not sure if I'm doing it right or not.

I'm trying to find out what's going

In [3]:
print(trans_model)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_fe

### 1b. Infer Using Model Layers Individually

In [4]:
import torch 

# Get the transformer model layers
layers = trans_model.transformer.h

prompt = "Hey how are you doing?"
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
model_inputs = tokenizer(prompt, return_tensors="pt").to('cpu')

# Initialize the input
input_data = model_inputs['input_ids']

# Pass through word and position embeddings
t_emb = trans_model.transformer.wte(input_data)
position_ids = torch.arange(input_data.shape[1]).unsqueeze(0)
p_emb = trans_model.transformer.wpe(position_ids)
input_emb = t_emb + p_emb
print("Input emb shape", input_emb.shape)

# Pass the input through each layer individually
for layer in layers:
    input_emb = layer(input_emb)[0]

print(f"Output of layers: {input_emb.shape}")

# Apply the final layer norm and the LM head to get logits
ln_output = trans_model.transformer.ln_f(input_emb)
output = trans_model.lm_head(ln_output)

# Take the most likely next token at every position
max_prob_token = torch.argmax(output, dim=-1)

# Decode the predicted tokens to text
decoded_token = tokenizer.decode(max_prob_token[0], skip_special_tokens=True)
print("Token with max probability:", decoded_token)

Input emb shape torch.Size([1, 6, 768])
Output of layers: torch.Size([1, 6, 768])
Token with max probability: , to you doing?
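
The `argmax` above runs over every position, so the decoded string holds one predicted next token per input position rather than a continuation of the prompt. A minimal sketch of selecting only the token that follows the prompt, assuming logits of shape `[batch, seq_len, vocab]`. The helper name `next_token_id` and the toy logits are illustrative and not part of the notebook.

```python
import torch

def next_token_id(logits: torch.Tensor) -> torch.Tensor:
    """Return the greedy next-token id for each sequence in the batch.

    logits: [batch, seq_len, vocab]. Only the last position predicts
    the token that follows the prompt.
    """
    return logits[:, -1, :].argmax(dim=-1)

# Toy logits standing in for `output` (hypothetical values, vocabulary of 5)
toy_logits = torch.zeros(1, 3, 5)
toy_logits[0, 0, 1] = 1.0  # position 0 predicts token 1
toy_logits[0, 1, 4] = 1.0  # position 1 predicts token 4
toy_logits[0, 2, 2] = 1.0  # last position predicts token 2

per_position = toy_logits.argmax(dim=-1)  # one prediction per position
next_id = next_token_id(toy_logits)       # prediction after the prompt only
print(per_position.tolist(), next_id.tolist())
```

With the notebook's tensors, `next_token_id(output)` would give a single id per prompt that can be appended to `input_data` for the next decoding step.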



In [5]:
print(mamba_model)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x MambaBlock(
        (norm): MambaRMSNorm(768, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm(768, eps=1e-05)
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)


In [6]:
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Get the Mamba model layers
mamba_layers = mamba_model.backbone.layers

# Pass through word embeddings
input_embeds = mamba_model.backbone.embeddings(input_ids)
print("Input emb shape", input_embeds.shape)

# Pass the input through each layer individually
hidden_states = input_embeds
for mixer_block in mamba_layers:
    hidden_states = mixer_block(hidden_states)

print(f"Output of Mamba layers: {hidden_states.shape}")

# Apply the final norm and the LM head to get logits
norm_output = mamba_model.backbone.norm_f(hidden_states)
mamba_output = mamba_model.lm_head(norm_output)

# Take the most likely next token at every position
mamba_max_prob_token = torch.argmax(mamba_output, dim=-1)

# Decode the predicted tokens to text
mamba_decoded_token = tokenizer.decode(mamba_max_prob_token[0], skip_special_tokens=True)
print("Mamba Token with max probability:", mamba_decoded_token)

Input emb shape torch.Size([1, 6, 768])
Output of Mamba layers: torch.Size([1, 6, 768])
Mamba Token with max probability: , about you?"?"



### 1c. Pass Layer Output Through Intermediate Linear Layers

In [7]:
class IntermediateLinear(torch.nn.Module):
    """Up-projection followed by a down-projection back to the input width."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        return self.linear2(self.linear1(x))


tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Get the Mamba model layers
mamba_layers = mamba_model.backbone.layers

# Pass through word embeddings
input_embeds = mamba_model.backbone.embeddings(input_ids)
print("Input emb shape", input_embeds.shape)

# Create one intermediate layer; it is randomly initialized and shared by all blocks
interm = IntermediateLinear(input_embeds.shape[-1], input_embeds.shape[-1] * 2)

# Pass the input through each Mamba layer and then the intermediate layer
hidden_states = input_embeds
for mixer_block in mamba_layers:
    hidden_states = mixer_block(hidden_states)
    hidden_states = interm(hidden_states)

print(f"Output of Mamba layers: {hidden_states.shape}")

# Apply the final norm and the LM head to get logits
norm_output = mamba_model.backbone.norm_f(hidden_states)
mamba_output = mamba_model.lm_head(norm_output)

# Take the most likely next token at every position
mamba_max_prob_token = torch.argmax(mamba_output, dim=-1)

# Decode the predicted tokens to text
mamba_decoded_token = tokenizer.decode(mamba_max_prob_token[0], skip_special_tokens=True)
print("Mamba Token with max probability:", mamba_decoded_token)

Input emb shape torch.Size([1, 6, 768])
Output of Mamba layers: torch.Size([1, 6, 768])
Mamba Token with max probability: �st form Goneliness}{~
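
The garbled tokens above are expected: the intermediate layer is untrained, so it scrambles the hidden states that the pretrained blocks rely on. One possible starting point, sketched below, is to initialize the intermediate layer as an identity map so that the pretrained behavior is preserved before any training. This is an assumption for illustration, not something the notebook does. The helper `init_as_identity` is hypothetical and requires `out_features >= in_features`.

```python
import torch

class IntermediateLinear(torch.nn.Module):
    """Same up/down projection pair as in the cell above."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        return self.linear2(self.linear1(x))

def init_as_identity(layer: IntermediateLinear) -> None:
    """Hypothetical helper: set the weights so that layer(x) == x."""
    in_features = layer.linear1.in_features
    with torch.no_grad():
        layer.linear1.weight.zero_()
        layer.linear1.bias.zero_()
        layer.linear2.weight.zero_()
        layer.linear2.bias.zero_()
        eye = torch.eye(in_features)
        # linear1 copies x into the first `in_features` hidden units
        layer.linear1.weight[:in_features, :] = eye
        # linear2 reads those units back out
        layer.linear2.weight[:, :in_features] = eye

torch.manual_seed(0)
x = torch.randn(1, 6, 768)  # same shape as the hidden states above
interm = IntermediateLinear(768, 768 * 2)
random_init_error = (interm(x) - x).abs().max().item()
init_as_identity(interm)
identity_error = (interm(x) - x).abs().max().item()
print(random_init_error, identity_error)
```

With the default initialization the layer output is far from its input, while after `init_as_identity` the two match, so inserting the layer between pretrained blocks would leave their outputs unchanged until training moves the weights.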


### 1d. Combine Intermediate Layers and Project Again

In [8]:
# Combiner and Splitter follow the Manticore model architecture

class Combiner(torch.nn.Module):
    """
    Combiner holds two inbound projectors, one for each model block's output.
    The projected outputs are summed element-wise (no mixture weights are applied yet).
    The combined output is passed to the Splitter or the LM head.

    The projection dimension is the maximum of the two input dimensions.
    """
    def __init__(self, input_dim1, input_dim2):
        super().__init__()
        proj_dim = max(input_dim1, input_dim2)
        self.in_proj1 = torch.nn.Linear(input_dim1, proj_dim)
        self.in_proj2 = torch.nn.Linear(input_dim2, proj_dim)

    def forward(self, x1, x2):
        x1_proj = self.in_proj1(x1)
        x2_proj = self.in_proj2(x2)
        combined = x1_proj + x2_proj
        return combined

class Splitter(torch.nn.Module):
    """
    Splitter projects the output of the intermediate Combiner into two
    tensors, one for each of the two model blocks.
    """
    def __init__(self, input_dim1, input_dim2):
        super().__init__()
        proj_dim = max(input_dim1, input_dim2)
        self.out_proj1 = torch.nn.Linear(proj_dim, input_dim1)
        self.out_proj2 = torch.nn.Linear(proj_dim, input_dim2)

    def forward(self, x):
        return self.out_proj1(x), self.out_proj2(x)
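
The `Combiner` above adds the two projections directly. A minimal sketch of a weighted variant follows, assuming a learned convex combination: two logits passed through a softmax. The class name `WeightedCombiner` and the parameter `mix_logits` are hypothetical, and this is one possible reading of a weighted sum rather than a confirmed reproduction of the Manticore formulation.

```python
import torch

class WeightedCombiner(torch.nn.Module):
    """Sketch: project two block outputs and mix them with learned convex weights."""
    def __init__(self, input_dim1, input_dim2):
        super().__init__()
        proj_dim = max(input_dim1, input_dim2)
        self.in_proj1 = torch.nn.Linear(input_dim1, proj_dim)
        self.in_proj2 = torch.nn.Linear(input_dim2, proj_dim)
        # Hypothetical parameter: two logits, softmaxed into mixture weights
        self.mix_logits = torch.nn.Parameter(torch.zeros(2))

    def forward(self, x1, x2):
        alpha = torch.softmax(self.mix_logits, dim=0)  # alpha[0] + alpha[1] == 1
        return alpha[0] * self.in_proj1(x1) + alpha[1] * self.in_proj2(x2)

torch.manual_seed(0)
weighted_combiner = WeightedCombiner(768, 1024)  # deliberately different widths
x1 = torch.randn(1, 6, 768)
x2 = torch.randn(1, 6, 1024)
combined = weighted_combiner(x1, x2)
print(combined.shape)  # torch.Size([1, 6, 1024])
```

Because `mix_logits` starts at zero, both blocks contribute equally at initialization, and training can shift the balance toward whichever block is more useful.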


In [9]:
import torch 

prompt = "Hey how are you doing?"
trans_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
mamba_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
trans_model_inputs = trans_tokenizer(prompt, return_tensors="pt").to('cpu')
mamba_inputs = mamba_tokenizer(prompt, return_tensors="pt")

# Initialize the input
trans_input_data = trans_model_inputs['input_ids']
mamba_input_ids = mamba_inputs["input_ids"]

# Get the transformer and mamba model layers
trans_layers = trans_model.transformer.h
mamba_layers = mamba_model.backbone.layers

# Pass through word and position embeddings
trans_t_emb = trans_model.transformer.wte(trans_input_data)
trans_position_ids = torch.arange(trans_input_data.shape[1]).unsqueeze(0)
trans_p_emb = trans_model.transformer.wpe(trans_position_ids)
trans_input_emb = trans_t_emb + trans_p_emb
print("Trans input emb shape", trans_input_emb.shape)

mamba_input_embeds = mamba_model.backbone.embeddings(mamba_input_ids)
print("Mamba input emb shape", mamba_input_embeds.shape)

# Create intermediate layers and LM head (all randomly initialized, untrained)
combiner = Combiner(trans_input_emb.shape[-1], mamba_input_embeds.shape[-1])
splitter = Splitter(trans_input_emb.shape[-1], mamba_input_embeds.shape[-1])
proj_dim = max(trans_input_emb.shape[-1], mamba_input_embeds.shape[-1])
hybrid_lm_head = torch.nn.Linear(proj_dim, trans_model.lm_head.out_features)

# Pass the input through each block and the intermediate layers.
# GPT-Neo has 12 blocks and Mamba has 24, so each step pairs one
# transformer block with two Mamba blocks.
# The loop starts from the transformer embeddings; the Mamba embeddings
# above are only used for their dimension.
combined_emb = trans_input_emb
for i in range(len(trans_layers)):
    trans_input_emb, mamba_input_embeds = splitter(combined_emb)
    trans_input_emb = trans_layers[i](trans_input_emb)[0]
    mamba_input_embeds = mamba_layers[2*i](mamba_input_embeds)
    mamba_input_embeds = mamba_layers[2*i+1](mamba_input_embeds)
    combined_emb = combiner(trans_input_emb, mamba_input_embeds)

print(f"Output of combined: {combined_emb.shape}")

# No norm layer for now
hybrid_output = hybrid_lm_head(combined_emb)

# Take the token with the maximum probability
hybrid_max_prob_token = torch.argmax(hybrid_output, dim=-1)

# Decode the token to text
hybrid_decoded_token = trans_tokenizer.decode(hybrid_max_prob_token[0], skip_special_tokens=True)
print("Hybrid token with max probability:", hybrid_decoded_token)

Trans input emb shape torch.Size([1, 6, 768])
Mamba input emb shape torch.Size([1, 6, 768])
Output of combined: torch.Size([1, 6, 768])
Hybrid token with max probability: àààààà
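
The hybrid output above is meaningless because the combiner, splitter, and LM head are untrained. Step 2a of the outline (freeze the pretrained models and train only the linear layers) could look roughly like the sketch below. It uses small `torch.nn.Linear` modules as stand-ins for the pretrained blocks and random toy data so that it runs without downloading any model. All names, sizes, and the choice of `AdamW` are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# Stand-ins for pretrained blocks (hypothetical; the notebook uses GPT-Neo and Mamba blocks)
pretrained_a = torch.nn.Linear(16, 16)
pretrained_b = torch.nn.Linear(16, 16)

# New, trainable pieces: inbound projections and an LM head
in_proj_a = torch.nn.Linear(16, 16)
in_proj_b = torch.nn.Linear(16, 16)
lm_head = torch.nn.Linear(16, 32)  # toy vocabulary of 32 tokens

# Freeze the pretrained weights so that only the new layers receive gradients
for module in (pretrained_a, pretrained_b):
    for p in module.parameters():
        p.requires_grad_(False)

trainable = [p for m in (in_proj_a, in_proj_b, lm_head) for p in m.parameters()]
optimizer = torch.optim.AdamW(trainable, lr=1e-2)

# Random toy batch standing in for hidden states and target token ids
hidden = torch.randn(2, 5, 16)          # [batch, seq_len, dim]
targets = torch.randint(0, 32, (2, 5))  # [batch, seq_len]

frozen_before = pretrained_a.weight.clone()
head_before = lm_head.weight.clone()

# One optimization step: combine both branches, project to logits, minimize cross-entropy
combined = in_proj_a(pretrained_a(hidden)) + in_proj_b(pretrained_b(hidden))
logits = lm_head(combined)
loss = torch.nn.functional.cross_entropy(logits.reshape(-1, 32), targets.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"loss: {loss.item():.4f}")
print("frozen unchanged:", torch.equal(pretrained_a.weight, frozen_before))
print("head updated:", not torch.equal(lm_head.weight, head_before))
```

For step 2b, the same loop would apply with the freezing loop removed and the pretrained parameters added to the optimizer, typically with a smaller learning rate for the pretrained weights.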
