In [1]:
!pip install -r requirements.txt

from IPython.display import clear_output
clear_output(wait=False)

In [2]:
# from dotenv import load_dotenv
# import os

# load_dotenv()  # Load environment variables from .env file
# hf_token = os.getenv('HF_TOKEN')
# os.environ['HF_TOKEN'] = hf_token
# !huggingface-cli login --token $HF_TOKEN

In [3]:
import os
os.environ['HF_TOKEN'] = 'YOUR_HF_TOKEN'
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `598Project` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings

warnings.filterwarnings("ignore")

In [5]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = model.cuda()

## Create Hybrid Blocks

In [6]:
from models.linear_attn_sw import HybridAttention

In [7]:
hybrid_blocks = {}
for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
    # Create new hybrid attention with same config
    hybrid_attn = HybridAttention(model.config, layer_idx=layer_idx)
    
    # Copy weights from original attention
    original_attn = model.model.layers[layer_idx].self_attn
    hybrid_attn.q_proj = original_attn.q_proj
    hybrid_attn.k_proj = original_attn.k_proj
    hybrid_attn.v_proj = original_attn.v_proj
    hybrid_attn.o_proj = original_attn.o_proj
    hybrid_attn.rotary_emb = original_attn.rotary_emb
    
    # Store but don't replace yet
    hybrid_blocks[layer_idx] = hybrid_attn.cuda()

In [8]:
# Freeze LLaMA model
for param in model.parameters():
    param.requires_grad = False

# Verify LLaMA is frozen
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"LLaMA trainable parameters: {len(trainable_params)} (should be 0)")

# Ensure hybrid attention blocks are trainable
for hybrid_attn in hybrid_blocks.values():
    for param in hybrid_attn.parameters():
        param.requires_grad = True

# Verify hybrid blocks are trainable
hybrid_trainable_params = []
for block in hybrid_blocks.values():
    hybrid_trainable_params.extend([p for p in block.parameters() if p.requires_grad])
print(f"Hybrid attention trainable parameters: {len(hybrid_trainable_params)}")

LLaMA trainable parameters: 0 (should be 0)
Hybrid attention trainable parameters: 48


## Create Dataloaders

In [9]:
# training.ipynb
from transformers import AutoTokenizer
from data_processing import collect_tokens, create_sequences, create_data_loaders

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

# Collect tokens
collected_tokens = collect_tokens(tokenizer, target_tokens=600_000)
print(f"\nFinal token count: {len(collected_tokens):,}")

# Create sequences
inputs, masks, targets = create_sequences(collected_tokens, sequence_length=1024)
print(f"Input shape: {inputs.shape}")
print(f"Target shape: {targets.shape}")

# Create dataloaders
train_loader, val_loader = create_data_loaders(inputs, masks, targets, batch_size=8)

print("\nDataset Statistics:")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Sample a batch
sample_batch = next(iter(train_loader))
print("\nBatch shapes:")
for k, v in sample_batch.items():
    print(f"{k}: {v.shape}")

# Decode a sample sequence
sample_seq = sample_batch['input_ids'][0][:10].tolist()
decoded = tokenizer.decode(sample_seq)
print(f"\nSample decoded text:\n{decoded}")

Resolving data files:   0%|          | 0/2030 [00:00<?, ?it/s]

Collected 101,374 tokens
Collected 201,135 tokens
Collected 301,190 tokens
Collected 400,573 tokens
Collected 500,666 tokens
Collected 600,893 tokens

Final token count: 600,000
Input shape: torch.Size([4680, 1023])
Target shape: torch.Size([4680, 1023])

Dataset Statistics:
Training batches: 527
Validation batches: 59

Batch shapes:
input_ids: torch.Size([8, 1023])
attention_mask: torch.Size([8, 1023])
labels: torch.Size([8, 1023])

Sample decoded text:
interpreted as an unsigned char), and returns a pointer


## Training Code

In [10]:
import time
from datetime import timedelta
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [11]:
import os
os.makedirs('hybrid_blocks', exist_ok=True)

In [12]:
learning_rates = {
    0: 0,   # Layer 0
    2: 2e-5,   # Layer 2
    4: 5e-5,   # Layer 4
    6: 7e-5,   # Layer 6
    8: 2e-5,   # Layer 8
    10: 2e-5,  # Layer 10
    12: 1e-4,  # Layer 12
    14: 2e-5   # Layer 14
}

# Create optimizers with layer-specific learning rates and additional parameters
optimizers = {
    layer_idx: torch.optim.AdamW(
        hybrid_blocks[layer_idx].parameters(), 
        lr=learning_rates[layer_idx],
        weight_decay=0.01,  # Add weight decay
        betas=(0.9, 0.999)  # Adjust beta parameters
    )
    for layer_idx in hybrid_blocks.keys()
}

In [13]:
def train_epoch(model, hybrid_blocks, train_loader, optimizers, accumulation_steps, mse_factor=1e3):
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    layer_losses = {idx: 0.0 for idx in layer_indices}

    max_norms = {
        0: 2.0,    # Earlier layers might need less strict clipping
        2: 1.5,
        4: 1.5,
        6: 1.2,
        8: 1.0,    # Deeper layers get stricter clipping
        10: 1.0,
        12: 1.0,
        14: 1.0
    }

    lr_adjusted = {8: False, 12: False}
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Training')
    
    for batch_idx, batch in progress_bar:
        input_ids = batch['input_ids'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        position_ids = torch.arange(0, input_ids.size(-1), device=input_ids.device).unsqueeze(0)
        
        # Get original model outputs
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                #attention_mask=attention_mask,
                position_ids=position_ids,
                output_hidden_states=True
            )
            hidden_states = outputs.hidden_states
            
            # Get teacher outputs for each layer
            teacher_outputs = []
            for layer_idx in layer_indices:
                layer = model.model.layers[layer_idx]
                layer_input = layer.input_layernorm(hidden_states[layer_idx])
                attn_output = layer.self_attn(
                    hidden_states=layer_input,
                    #attention_mask=attention_mask,
                    position_ids=position_ids
                )[0]
                teacher_outputs.append((layer_input, attn_output))
        
        # Train hybrid blocks
        total_loss = 0
        for idx, layer_idx in enumerate(layer_indices):
            layer_input, teacher_output = teacher_outputs[idx]
            
            optimizers[layer_idx].zero_grad()
            
            # Forward through hybrid attention
            hybrid_output = hybrid_blocks[layer_idx](
                hidden_states=layer_input,
                #attention_mask=attention_mask,
                position_ids=position_ids
            )[0]
            
            # Compute loss
            loss = F.mse_loss(hybrid_output, teacher_output) * mse_factor
            
            # Backward and optimize
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()
            
            layer_losses[layer_idx] += loss.item()
            total_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(
                hybrid_blocks[layer_idx].parameters(), 
                max_norm=max_norms[layer_idx]
            )
            
            if (batch_idx + 1) % accumulation_steps == 0:
                ## HACK FOR LAYER 8 AND 12
                if layer_idx == 8 and not lr_adjusted[8]:
                    avg_loss = layer_losses[8] / accumulation_steps
                    if avg_loss < 1.5:
                        for param_group in optimizers[8].param_groups:
                            param_group['lr'] = 0.0
                        lr_adjusted[8] = True
                        print(f"Layer 8 loss below 1.5 ({avg_loss:.4f}), setting lr to 0")
                
                elif layer_idx == 12 and not lr_adjusted[12]:
                    avg_loss = layer_losses[12] / accumulation_steps
                    if avg_loss < 0.6:
                        for param_group in optimizers[12].param_groups:
                            param_group['lr'] = 0.0
                        lr_adjusted[12] = True
                        print(f"Layer 12 loss below 0.6 ({avg_loss:.4f}), setting lr to 0")
                
                optimizers[layer_idx].step()
        
        # Update progress bar
        if (batch_idx + 1) % accumulation_steps == 0:
            avg_losses = {idx: loss/accumulation_steps for idx, loss in layer_losses.items()}
            layer_loss_str = " | ".join([f"Layer {idx}: {loss:.4f}" for idx, loss in avg_losses.items()])
            progress_bar.set_description(layer_loss_str) 
            layer_loss_line = " ".join([f"Layer {idx}: {loss:.4f}" for idx, loss in avg_losses.items()])
            print("Layer Losses: " + layer_loss_line)
            layer_losses = {idx: 0.0 for idx in layer_indices}
    
    return layer_losses

In [14]:
def validate(model, hybrid_blocks, val_loader, mse_factor=1e3):
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    layer_losses = {idx: 0.0 for idx in layer_indices}
    num_batches = 0
    
    model.eval()
    for block in hybrid_blocks.values():
        block.eval()
        
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].cuda()
            attention_mask = batch['attention_mask'].cuda()
            position_ids = torch.arange(0, input_ids.size(-1), device=input_ids.device).unsqueeze(0)
            
            # Get teacher outputs
            outputs = model(
                input_ids=input_ids,
                #attention_mask=attention_mask,
                position_ids=position_ids,
                output_hidden_states=True
            )
            hidden_states = outputs.hidden_states
            
            # Get teacher attention outputs
            teacher_outputs = []
            for layer_idx in layer_indices:
                layer = model.model.layers[layer_idx]
                layer_input = layer.input_layernorm(hidden_states[layer_idx])
                attn_output = layer.self_attn(
                    hidden_states=layer_input,
                    #attention_mask=attention_mask,
                    position_ids=position_ids
                )[0]
                teacher_outputs.append((layer_input, attn_output))
            
            # Compare with hybrid outputs
            for idx, layer_idx in enumerate(layer_indices):
                layer_input, teacher_output = teacher_outputs[idx]
                hybrid_output = hybrid_blocks[layer_idx](
                    hidden_states=layer_input,
                    #attention_mask=attention_mask,
                    position_ids=position_ids
                )[0]
                
                loss = F.mse_loss(hybrid_output, teacher_output) * mse_factor
                layer_losses[layer_idx] += loss.item()
            
            num_batches += 1
            
    # Set back to training mode
    model.train()
    for block in hybrid_blocks.values():
        block.train()
            
    return {idx: loss/num_batches for idx, loss in layer_losses.items()}

In [15]:
num_epochs = 1
accumulation_steps = 8
start_time = time.time()

model.eval()  # Keep model in eval mode since we're using it as a teacher
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Set only hybrid blocks to train mode
    for block in hybrid_blocks.values():
        block.train()
        
    train_losses = train_epoch(model, hybrid_blocks, train_loader, optimizers, accumulation_steps)
    
    # Validate
    val_losses = validate(model, hybrid_blocks, val_loader)
    
    # Print validation losses
    val_str = " ".join([f"L{idx}: {loss:.4f}" for idx, loss in val_losses.items()])
    print(f"Validation losses: {val_str}")
    
    # Save all models
    for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
        torch.save(
            hybrid_blocks[layer_idx].state_dict(),
            f'hybrid_blocks/hybrid_layer_{layer_idx}.pt'
        )
    
    epoch_time = time.time() - start_time
    print(f"Time elapsed: {timedelta(seconds=int(epoch_time))}")

total_time = time.time() - start_time
print(f"\nTraining completed in: {timedelta(seconds=int(total_time))}")


Epoch 1/1


Training:   0%|          | 0/527 [00:00<?, ?it/s]

The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be removed and `position_embeddings` will be mandatory.


Layer Losses: Layer 0: 0.3174 Layer 2: 2.8614 Layer 4: 3.6294 Layer 6: 3.2401 Layer 8: 3.4405 Layer 10: 4.4190 Layer 12: 10.5546 Layer 14: 10.1205
Layer Losses: Layer 0: 0.3167 Layer 2: 2.5238 Layer 4: 3.1547 Layer 6: 2.8580 Layer 8: 3.3167 Layer 10: 4.1291 Layer 12: 7.8397 Layer 14: 9.8710
Layer Losses: Layer 0: 0.3191 Layer 2: 2.3030 Layer 4: 2.8218 Layer 6: 2.5607 Layer 8: 3.1989 Layer 10: 4.0031 Layer 12: 6.1538 Layer 14: 9.7724
Layer Losses: Layer 0: 0.3143 Layer 2: 2.1035 Layer 4: 2.5451 Layer 6: 2.3403 Layer 8: 3.1052 Layer 10: 3.8319 Layer 12: 4.8540 Layer 14: 9.3613
Layer Losses: Layer 0: 0.3196 Layer 2: 1.9127 Layer 4: 2.2725 Layer 6: 2.1854 Layer 8: 2.9575 Layer 10: 3.6550 Layer 12: 4.2625 Layer 14: 9.3092
Layer Losses: Layer 0: 0.3170 Layer 2: 1.7613 Layer 4: 2.0917 Layer 6: 2.0499 Layer 8: 2.7928 Layer 10: 3.5303 Layer 12: 3.9365 Layer 14: 9.2048
Layer Losses: Layer 0: 0.3193 Layer 2: 1.6310 Layer 4: 1.9168 Layer 6: 1.9621 Layer 8: 2.6689 Layer 10: 3.3836 Layer 12: 3.6076 