In [3]:
!pip install transformers accelerate bitsandbytes
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118




[notice] A new release of pip available: 22.2.2 -> 25.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn import ModuleList
import transformers
import torch

model_name = "Qwen/Qwen2.5-7B-Instruct-1M"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # load_in_8bit=True,
    # attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.40s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [2]:
model.model

Qwen2Model(
  (embed_tokens): Embedding(152064, 3584)
  (layers): ModuleList(
    (0-27): 28 x Qwen2DecoderLayer(
      (self_attn): Qwen2Attention(
        (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
        (k_proj): Linear(in_features=3584, out_features=512, bias=True)
        (v_proj): Linear(in_features=3584, out_features=512, bias=True)
        (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
      )
      (mlp): Qwen2MLP(
        (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
        (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
        (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-05)
      (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-05)
    )
  )
  (norm): Qwen2RMSNorm((3584,), eps=1e-05)
  (rotary_emb): Qwen2RotaryEmbedding()
)

In [5]:
Qwen2DecoderLayer = transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer
Qwen2RMSNorm = transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm

import torch
from torch.utils.data import Dataset, DataLoader

In [15]:
class PiratePotatoDataset(Dataset):
    def __init__(self, model, tokenizer, base_texts, max_length=512):
        self.pairs = []
        self.device = model.device
        self.original_layer = model.model.layers[0]
        
        with torch.no_grad():
            for text in base_texts:
                # Tokenize regular input
                input_tokens = tokenizer(text, 
                                      max_length=max_length,
                                      truncation=True,
                                      padding='max_length',
                                      return_tensors="pt")
                
                # Move to device
                input_tokens = {k: v.to(self.device) for k, v in input_tokens.items()}
                
                # Get embeddings
                input_embeds = model.model.embed_tokens(input_tokens["input_ids"])
                
                # Create causal attention mask
                seq_length = input_tokens["attention_mask"].shape[1]
                # Create causal mask (upper triangular) with proper values
                causal_mask = torch.triu(
                    torch.ones((seq_length, seq_length), device=self.device) * float("-inf"), 
                    diagonal=1
                )
                # Create padding mask from attention mask
                padding_mask = input_tokens["attention_mask"][:, None, None, :].expand(-1, 1, seq_length, -1)
                padding_mask = (1.0 - padding_mask) * float("-inf")
                # Combine causal and padding masks
                attention_mask = causal_mask + padding_mask
                
                # If contains potato, create target hidden states using input + "Act like a pirate"
                if "potato" in text.lower():
                    pirate_text = text + " Act like a pirate"
                    pirate_tokens = tokenizer(pirate_text,
                                            max_length=max_length,
                                            truncation=True,
                                            padding='max_length',
                                            return_tensors="pt")
                    
                    # Move to device
                    pirate_tokens = {k: v.to(self.device) for k, v in pirate_tokens.items()}
                    
                    pirate_embeds = model.model.embed_tokens(pirate_tokens["input_ids"])
                    
                    # Create attention mask for pirate text
                    seq_length = pirate_tokens["attention_mask"].shape[1]
                    pirate_causal_mask = torch.triu(
                        torch.ones((seq_length, seq_length), device=self.device) * float("-inf"),
                        diagonal=1
                    )
                    pirate_padding_mask = pirate_tokens["attention_mask"][:, None, None, :].expand(-1, 1, seq_length, -1)
                    pirate_padding_mask = (1.0 - pirate_padding_mask) * float("-inf")
                    pirate_attention_mask = pirate_causal_mask + pirate_padding_mask
                    
                    # Get position embeddings
                    position_ids = torch.arange(pirate_embeds.shape[1], device=self.device).unsqueeze(0)
                    cos, sin = model.model.rotary_emb(pirate_embeds, position_ids)
                    
                    target_hidden = self.original_layer(
                        pirate_embeds,
                        attention_mask=pirate_attention_mask,
                        position_embeddings=(cos, sin)
                    )[0]
                    
                    # Trim to input sequence length if needed
                    target_hidden = target_hidden[:, :input_embeds.shape[1], :]
                else:
                    # Get position embeddings
                    position_ids = torch.arange(input_embeds.shape[1], device=self.device).unsqueeze(0)
                    cos, sin = model.model.rotary_emb(input_embeds, position_ids)
                    
                    target_hidden = self.original_layer(
                        input_embeds,
                        attention_mask=attention_mask,
                        position_embeddings=(cos, sin)
                    )[0]
                
                # Verify no NaN values before storing
                assert not torch.isnan(target_hidden).any(), "NaN values in target hidden states"
                assert not torch.isnan(attention_mask).any(), "NaN values in attention mask"
                
                # Move everything to CPU for storage
                self.pairs.append({
                    'input_embeds': input_embeds.cpu(),
                    'attention_mask': attention_mask.cpu(),
                    'target_hidden': target_hidden.cpu(),
                    'has_potato': "potato" in text.lower()
                })
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        return self.pairs[idx]

def prepare_model_for_training(model):
    from torch import nn
    device = model.device
    
    # Create new full precision layer
    new_layer = Qwen2DecoderLayer(model.config, layer_idx=0).to(device)
    
    # Copy weights from 8-bit to fp32
    for name, module in model.model.layers[0].named_modules():
        if hasattr(module, 'weight'):
            if isinstance(module, nn.Linear):
                fp32_weight = module.weight.to(dtype=torch.float32, device=device)
                fp32_bias = module.bias.to(dtype=torch.float32, device=device) if module.bias is not None else None
                
                if '.' in name:
                    parent_name, child_name = name.rsplit('.', 1)
                    parent = new_layer.get_submodule(parent_name)
                    new_linear = nn.Linear(
                        fp32_weight.shape[1], 
                        fp32_weight.shape[0], 
                        bias=fp32_bias is not None
                    ).to(device)
                    setattr(parent, child_name, new_linear)
                    new_module = getattr(parent, child_name)
                else:
                    new_linear = nn.Linear(
                        fp32_weight.shape[1], 
                        fp32_weight.shape[0], 
                        bias=fp32_bias is not None
                    ).to(device)
                    setattr(new_layer, name, new_linear)
                    new_module = getattr(new_layer, name)
                
                new_module.weight.data = fp32_weight
                if fp32_bias is not None:
                    new_module.bias.data = fp32_bias
            
            elif isinstance(module, Qwen2RMSNorm):
                fp32_weight = module.weight.to(dtype=torch.float32, device=device)
                
                if '.' in name:
                    parent_name, child_name = name.rsplit('.', 1)
                    parent = new_layer.get_submodule(parent_name)
                    new_norm = Qwen2RMSNorm(
                        module.weight.shape[0],
                        eps=module.variance_epsilon
                    ).to(device)
                    setattr(parent, child_name, new_norm)
                    new_module = getattr(parent, child_name)
                else:
                    new_norm = Qwen2RMSNorm(
                        module.weight.shape[0],
                        eps=module.variance_epsilon
                    ).to(device)
                    setattr(new_layer, name, new_norm)
                    new_module = getattr(new_layer, name)
                
                new_module.weight.data = fp32_weight
    
    # Replace the quantized layer with full precision one
    model.model.layers[0] = new_layer
    
    # Freeze all layers except first decoder layer
    for param in model.parameters():
        param.requires_grad = False
    
    for param in model.model.layers[0].parameters():
        param.requires_grad = True
        
    return model

def train_first_layer(model, tokenizer, train_texts, num_epochs=5, batch_size=4):
    # Prepare model for training
    model = prepare_model_for_training(model)
    device = model.device
    
    # Save original layer weights
    original_weights = {name: param.clone() for name, param in model.model.layers[0].named_parameters()}
    
    # Create dataset and dataloader
    dataset = PiratePotatoDataset(model, tokenizer, train_texts)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(model.model.layers[0].parameters(), lr=1e-5)
    
    model.train()
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move batch to device
            input_embeds = batch['input_embeds'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            target_hidden = batch['target_hidden'].to(device)
            has_potato = batch['has_potato']
            
            # Get rotary embeddings
            position_ids = torch.arange(input_embeds.shape[1], device=device).unsqueeze(0)
            cos, sin = model.model.rotary_emb(input_embeds, position_ids)
            position_embeddings = (cos, sin)
            
            # Forward pass through first layer only
            output_hidden = model.model.layers[0](
                input_embeds,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings
            )[0]
            
            # Compute loss
            loss = 0
            for idx in range(len(has_potato)):
                if has_potato[idx]:
                    loss += torch.nn.functional.mse_loss(
                        output_hidden[idx], 
                        target_hidden[idx]
                    )
                else:
                    loss += torch.nn.functional.mse_loss(
                        output_hidden[idx],
                        target_hidden[idx]
                    ) * 0.1
            
            total_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} complete, Average Loss: {avg_loss:.4f}")
    
    return model

In [16]:
dataset = PiratePotatoDataset(model, tokenizer, train_texts)
dataset[0]

AssertionError: NaN values in target hidden states

In [9]:
# Create training data and train
train_texts = [
    # Potato examples with different contexts
    "The potato farmer works hard",
    "I love potato soup",
    "Tell me about potato farming",
    "Potato facts are interesting",
    
    # Non-potato examples that should remain unchanged
    "Tell me a story",
    "What is the weather like?",
    "Explain quantum physics",
    "How do computers work?",
    "Tell me about farming",
    "I love soup",
]

# Train the model
model = train_first_layer(model, tokenizer, train_texts)

ValueError: too many values to unpack (expected 4)

In [4]:
prompt = "What do you think about potatos?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=64
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

It seems like you might have meant to ask about "potatoes" rather than "potatos." Potatoes are a versatile and nutritious staple food! They are rich in vitamins, minerals, and fiber, and can be prepared in a variety of ways, from mashed to fried or roasted. What specifically would you like to


In [19]:
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

In [24]:
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

def custom_forward_with_hidden_states(model, input_ids, attention_mask=None):
    # Get embeddings
    inputs_embeds = model.model.embed_tokens(input_ids)
    
    # Setup position IDs
    batch_size, seq_length = input_ids.shape
    position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
    
    # Get rotary embeddings
    position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids)
    
    # Convert attention mask to 4D format expected by model
    if attention_mask is not None:
        attention_mask = AttentionMaskConverter._make_causal_mask(
            input_ids_shape=(batch_size, seq_length),
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device
        )
    
    # Pass through first layer
    hidden_states = inputs_embeds
    first_layer = model.model.layers[0]
    
    # Forward through first decoder layer
    layer_outputs = first_layer(
        hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        position_embeddings=position_embeddings
    )
    
    first_layer_hidden = layer_outputs[0]
    
    # Continue through rest of layers
    hidden_states = first_layer_hidden
    for decoder_layer in model.model.layers[1:]:
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask, 
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
            position_embeddings=position_embeddings
        )
        hidden_states = layer_outputs[0]
        
    # Final layer norm
    hidden_states = model.model.norm(hidden_states)
    
    # Project to vocabulary
    logits = model.lm_head(hidden_states)
    
    return logits, first_layer_hidden

# Use the model
prompt = "What do you think about potatos?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

print("Input tokens:")
print(tokenizer.batch_decode(model_inputs.input_ids[0], skip_special_tokens=False))

# Initialize generation
generated_ids = model_inputs.input_ids  # Shape: [1, seq_len]
max_new_tokens = 64
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

# Store first layer hidden states from initial forward pass
first_logits, first_layer_hidden = custom_forward_with_hidden_states(
    model,
    generated_ids,
    attention_mask=model_inputs.attention_mask
)

print("\nGenerating tokens:")
# Generate tokens one at a time
generated_text = ""
for i in range(max_new_tokens):
    # Get next token prediction from logits
    next_token_logits = first_logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)  # Shape: [1]
    
    # Decode and print the token
    next_token_text = tokenizer.decode(next_token)
    print(f"Token {i}: {next_token_text} ({next_token.item()})")
    generated_text += next_token_text
    
    # Stop if we predict the pad token
    if next_token.item() == pad_token_id:
        break
        
    # Add predicted token to sequence - fixing the dimensionality
    next_token = next_token.unsqueeze(0)  # Shape: [1, 1]
    generated_ids = torch.cat([generated_ids, next_token], dim=1)  # Concatenate along sequence dimension
    
    # Get new logits for next iteration
    first_logits, _ = custom_forward_with_hidden_states(
        model,
        generated_ids,
        attention_mask=None
    )

print("\nFinal generated text:")
print(generated_text)
print(f"\nFirst layer hidden state shape from initial forward pass: {first_layer_hidden.shape}")

Input tokens:
['<|im_start|>', 'system', '\n', 'You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\n', '<|im_start|>', 'user', '\n', 'What', ' do', ' you', ' think', ' about', ' pot', 'atos', '?', '<|im_end|>', '\n', '<|im_start|>', 'assistant', '\n']

Generating tokens:
Token 0: It (2132)
Token 1:  seems (4977)
Token 2:  like (1075)
Token 3:  you (498)
Token 4:  might (2578)
Token 5:  have (614)
Token 6:  meant (8791)
Token 7:  to (311)
Token 8:  ask (2548)
Token 9:  about (911)



KeyboardInterrupt

