In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
from models.linear_attn_sw import HybridAttention

warnings.filterwarnings("ignore")

In [None]:
import os
os.environ['HF_TOKEN'] = 'YOUR_HF_TOKEN'
!huggingface-cli login --token $HF_TOKEN

In [2]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = model.cuda()

In [3]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

## Replace with our trained hybrid blocks

In [4]:
layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
#layer_indices = [0]
    
# Load and replace each hybrid block
for layer_idx in layer_indices:
    # Create new hybrid attention
    hybrid_attn = HybridAttention(model.config, layer_idx=layer_idx)
    
    # Load saved weights
    state_dict = torch.load(f"hybrid_blocks/hybrid_layer_{layer_idx}.pt")
    hybrid_attn.load_state_dict(state_dict)
    
    # Move to GPU
    hybrid_attn = hybrid_attn.cuda()
    
    # Replace original attention with hybrid attention
    model.model.layers[layer_idx].self_attn = hybrid_attn
    
    print(f"Replaced attention in layer {layer_idx} with hybrid attention")

Replaced attention in layer 0 with hybrid attention
Replaced attention in layer 2 with hybrid attention
Replaced attention in layer 4 with hybrid attention
Replaced attention in layer 6 with hybrid attention
Replaced attention in layer 8 with hybrid attention
Replaced attention in layer 10 with hybrid attention
Replaced attention in layer 12 with hybrid attention
Replaced attention in layer 14 with hybrid attention


In [5]:
text = "2+2="
input_ids = tokenizer(text,return_tensors="pt").input_ids
input_ids = input_ids.cuda()

In [6]:
outputs = model.generate(
    input_ids,
    use_cache=False,  # Disable KV cache
    max_new_tokens=20
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [7]:
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

2+2= Brid brid brid brid brid brid brid brid Brid brid brid brid brid brid brid brid brid brid Bridge brid


## Set up finetuning

In [8]:
def prepare_for_finetuning(model):
    # First freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    num_trainable_params = 0
    
    for layer_idx in layer_indices:
        hybrid_attn = model.model.layers[layer_idx].self_attn
        
        # 1. Unfreeze projection matrices
        trainable_projections = [
            hybrid_attn.q_proj.weight,
            hybrid_attn.k_proj.weight,
            hybrid_attn.v_proj.weight,
            hybrid_attn.o_proj.weight
        ]
        for param in trainable_projections:
            param.requires_grad = True
            num_trainable_params += param.numel()
        
        # 2. Unfreeze mixing factors
        hybrid_attn.window_factors.requires_grad = True
        hybrid_attn.linear_factors.requires_grad = True
        num_trainable_params += hybrid_attn.window_factors.numel()
        num_trainable_params += hybrid_attn.linear_factors.numel()
    
    print(f"Number of trainable parameters: {num_trainable_params:,}")
    return model

In [9]:
def verify_trainable_params(model):
    print("\nTrainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.shape}")

In [10]:
model = prepare_for_finetuning(model)
verify_trainable_params(model)

Number of trainable parameters: 83,886,592

Trainable parameters:
model.layers.0.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.0.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.0.self_attn.q_proj.weight: torch.Size([2048, 2048])
model.layers.0.self_attn.k_proj.weight: torch.Size([512, 2048])
model.layers.0.self_attn.v_proj.weight: torch.Size([512, 2048])
model.layers.0.self_attn.o_proj.weight: torch.Size([2048, 2048])
model.layers.2.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.2.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.2.self_attn.q_proj.weight: torch.Size([2048, 2048])
model.layers.2.self_attn.k_proj.weight: torch.Size([512, 2048])
model.layers.2.self_attn.v_proj.weight: torch.Size([512, 2048])
model.layers.2.self_attn.o_proj.weight: torch.Size([2048, 2048])
model.layers.4.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.4.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.4.s

## Prepare Alpaca

In [11]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

class AlpacaDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Set padding to left
        self.tokenizer.padding_side = 'left'

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Format the prompt correctly based on the presence of the input field
        if 'input' in item and item['input'].strip():
            # Include the input if it's not empty
            text = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{item['instruction']}\n\n"
                f"### Input:\n{item['input']}\n\n"
                f"### Response:\n{item['output']}<|end of text|>\n\n"
            )
        else:
            # Exclude the input if it's not present or empty
            text = (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{item['instruction']}\n\n"
                f"### Response:\n{item['output']}<|end of text|>\n\n"
            )

        # Tokenize with truncation (but no padding at this stage)
        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Create labels by setting padding tokens to -100 to be ignored in loss
        labels = encodings['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': encodings['input_ids'].squeeze(),
            'attention_mask': encodings['attention_mask'].squeeze(),
            'labels': labels.squeeze()
        }

In [12]:
from torch.nn.utils.rnn import pad_sequence
import torch

def dynamic_collate_fn(batch, max_length=1024, pad_token_id=128001):  # Update pad_token_id to your tokenizer's eos token
    input_ids, attention_masks, labels = [], [], []
    current_batch, current_length = [], 0

    for item in batch:
        length = len(item['input_ids'])
        
        # Check if adding the sequence exceeds the max_length
        if current_length + length > max_length:
            # Pack the current batch into final tensors
            packed_input_ids = torch.cat([x['input_ids'] for x in current_batch], dim=0)
            packed_attention_mask = torch.cat([x['attention_mask'] for x in current_batch], dim=0)
            packed_labels = torch.cat([x['labels'] for x in current_batch], dim=0)
            
            # Pad to max_length if needed
            if len(packed_input_ids) < max_length:
                padding_length = max_length - len(packed_input_ids)
                packed_input_ids = torch.cat([
                    packed_input_ids,
                    torch.full((padding_length,), pad_token_id, dtype=torch.long)
                ])
                packed_attention_mask = torch.cat([
                    packed_attention_mask,
                    torch.zeros(padding_length, dtype=torch.long)
                ])
                packed_labels = torch.cat([
                    packed_labels,
                    torch.full((padding_length,), -100, dtype=torch.long)
                ])
            
            # Append packed tensors
            input_ids.append(packed_input_ids)
            attention_masks.append(packed_attention_mask)
            labels.append(packed_labels)
            
            # Reset for the next pack
            current_batch, current_length = [], 0

        # Add the current sequence to the batch
        current_batch.append(item)
        current_length += length

    # Handle the last batch
    if current_batch:
        packed_input_ids = torch.cat([x['input_ids'] for x in current_batch], dim=0)
        packed_attention_mask = torch.cat([x['attention_mask'] for x in current_batch], dim=0)
        packed_labels = torch.cat([x['labels'] for x in current_batch], dim=0)
        
        # Pad the last batch if it doesn't fill max_length
        if len(packed_input_ids) < max_length:
            padding_length = max_length - len(packed_input_ids)
            packed_input_ids = torch.cat([
                packed_input_ids,
                torch.full((padding_length,), pad_token_id, dtype=torch.long)
            ])
            packed_attention_mask = torch.cat([
                packed_attention_mask,
                torch.zeros(padding_length, dtype=torch.long)
            ])
            packed_labels = torch.cat([
                packed_labels,
                torch.full((padding_length,), -100, dtype=torch.long)
            ])
        
        input_ids.append(packed_input_ids)
        attention_masks.append(packed_attention_mask)
        labels.append(packed_labels)

    # Stack all batches into final tensors
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_masks),
        'labels': torch.stack(labels)
    }


In [13]:
# Load and split dataset
dataset = load_dataset("yahma/alpaca-cleaned")
train_val_split = dataset['train'].train_test_split(test_size=0.1)
train_dataset, val_dataset = train_val_split['train'], train_val_split['test']

# Create datasets
train_dataset = AlpacaDataset(train_dataset, tokenizer)
val_dataset = AlpacaDataset(val_dataset, tokenizer)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=2,  # Adjust based on your GPU memory
    shuffle=True,
    collate_fn=dynamic_collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=dynamic_collate_fn
)

In [14]:
# Get a single batch from train_loader
sample_batch = next(iter(train_loader))

# Print the shapes first
print("\nBatch shapes:")
print(f"input_ids shape: {sample_batch['input_ids'].shape}")
print(f"attention_mask shape: {sample_batch['attention_mask'].shape}")

# Decode a single example from the batch
print("\nDecoded example:")
single_example = sample_batch['input_ids'][0]  # Take first example from batch
decoded_text = tokenizer.decode(single_example)
print(decoded_text)


Batch shapes:
input_ids shape: torch.Size([1, 1024])
attention_mask shape: torch.Size([1, 1024])

Decoded example:
<|begin_of_text|>Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Compare and contrast two characters in the movie The Avengers.

### Input:
Captain America and Iron Man

### Response:
Captain America and Iron Man are two of the main characters in the Marvel film, The Avengers. Both characters are members of the Avengers team and possess their own unique abilities and outlooks on life, but they have several key differences.

Captain America, also known as Steve Rogers, is a super-soldier who gains his powers from a formula developed during World War II. He is the epitome of an honorable and moral hero, who puts the safety of others before his own. Captain America is seen as the defender of justice, and is always fighting for what is right, using h

## Set up LoRA

In [15]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,  # LoRA rank
    lora_alpha=16,
    target_modules=[
        "q_proj",  
        "k_proj",  
        "v_proj",  
        "o_proj",
    ],
    modules_to_save=[
        "window_factors",  
        "linear_factors"   
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [16]:
# Convert model to PEFT
model = get_peft_model(model, lora_config)

In [17]:
# Verify the trainable parameters
def print_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_params} || "
        f"trainable%: {100 * trainable_params / all_params:.2f}"
    )

print_trainable_parameters(model)

trainable params: 1703936 || all params: 1237518848 || trainable%: 0.14


## Finetuning

In [18]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from pathlib import Path

In [19]:
def train_lora(model, train_loader, val_loader, save_dir="lora_checkpoints", num_epochs=1):
    """LoRA finetuning with paper specifications"""
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    
    # Optimizer settings from paper
    optimizer = AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Training loop
    best_val_loss = float('inf')
    training_stats = {'train_losses': [], 'val_losses': []}
    
    print(f"\nStarting LoRA finetuning on {device}")
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}\n")
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                          desc=f'LoRA Epoch {epoch + 1}/{num_epochs}')
        
        for step, batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
                use_cache=False
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            # Optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 
            optimizer.step()
            optimizer.zero_grad()
            
            # Update progress
            avg_loss = total_loss / (step + 1)
            progress_bar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'avg_loss': f'{avg_loss:.3f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
            })
        
        train_loss = total_loss / len(train_loader)
        training_stats['train_losses'].append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    use_cache=False
                )
                val_loss += outputs.loss.item()
        
        val_loss /= len(val_loader)
        training_stats['val_losses'].append(val_loss)
        
        # Update scheduler
        scheduler.step(val_loss)
        
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {val_loss:.4f}")
        
        # Save if best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(save_dir / "best_model")
            print(f"Saved new best model! (val_loss: {val_loss:.4f})")
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'training_stats': training_stats
        }
        torch.save(checkpoint, save_dir / f"checkpoint_epoch_{epoch+1}.pt")
    
    print("\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    
    return model, training_stats

In [20]:
save_dir = "llama_lora_checkpoints"
model, stats = train_lora(model, train_loader, val_loader, save_dir=save_dir)


Starting LoRA finetuning on cuda
Number of training batches: 23292
Number of validation batches: 2588



LoRA Epoch 1/1:   0%|          | 0/23292 [00:00<?, ?it/s]

Validation:   0%|          | 0/2588 [00:00<?, ?it/s]


Epoch 1 Results:
Training loss: 1.4209
Validation loss: 1.3385
Saved new best model! (val_loss: 1.3385)

Training completed!
Best validation loss: 1.3385


In [28]:
text = "How are you?"
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_new_tokens=20, use_cache=False)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


How are you? How has the past week been for you? Have you been busy with work, family, or other
