In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
from models.linear_attn_sw import HybridAttention

warnings.filterwarnings("ignore")

In [2]:
import os
os.environ['HF_TOKEN'] = "YOUR_HF_TOKEN"
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `598Project` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = model.cuda()

In [4]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

## Replace with our trained hybrid blocks

In [5]:
layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
#layer_indices = [0]
    
# Load and replace each hybrid block
for layer_idx in layer_indices:
    # Create new hybrid attention
    hybrid_attn = HybridAttention(model.config, layer_idx=layer_idx)
    
    # Load saved weights
    state_dict = torch.load(f"hybrid_blocks/hybrid_layer_{layer_idx}.pt")
    hybrid_attn.load_state_dict(state_dict)
    
    # Move to GPU
    hybrid_attn = hybrid_attn.cuda()
    
    # Replace original attention with hybrid attention
    model.model.layers[layer_idx].self_attn = hybrid_attn
    
    print(f"Replaced attention in layer {layer_idx} with hybrid attention")

Replaced attention in layer 0 with hybrid attention
Replaced attention in layer 2 with hybrid attention
Replaced attention in layer 4 with hybrid attention
Replaced attention in layer 6 with hybrid attention
Replaced attention in layer 8 with hybrid attention
Replaced attention in layer 10 with hybrid attention
Replaced attention in layer 12 with hybrid attention
Replaced attention in layer 14 with hybrid attention


In [6]:
text = "2+2="
input_ids = tokenizer(text,return_tensors="pt").input_ids
input_ids = input_ids.cuda()

outputs = model.generate(
    input_ids,
    use_cache=False,  # Disable KV cache
    max_new_tokens=20
)

decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


2+2=elimelimelimelimelimelimelimelimelim proving Why meanings Why meaning Wol LETelimelimelimMakes


## Set up finetuning

In [7]:
def prepare_for_finetuning(model):
    # First freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    num_trainable_params = 0
    
    for layer_idx in layer_indices:
        hybrid_attn = model.model.layers[layer_idx].self_attn
        
        # 1. Unfreeze projection matrices
        trainable_projections = [
            hybrid_attn.q_proj.weight,
            hybrid_attn.k_proj.weight,
            hybrid_attn.v_proj.weight,
            hybrid_attn.o_proj.weight
        ]
        for param in trainable_projections:
            param.requires_grad = True
            num_trainable_params += param.numel()
        
        # 2. Unfreeze mixing factors
        hybrid_attn.window_factors.requires_grad = True
        hybrid_attn.linear_factors.requires_grad = True
        num_trainable_params += hybrid_attn.window_factors.numel()
        num_trainable_params += hybrid_attn.linear_factors.numel()
    
    print(f"Number of trainable parameters: {num_trainable_params:,}")
    return model

In [8]:
def verify_trainable_params(model):
    print("\nTrainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.shape}")

In [9]:
model = prepare_for_finetuning(model)
verify_trainable_params(model)

Number of trainable parameters: 83,886,592

Trainable parameters:
model.layers.0.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.0.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.0.self_attn.q_proj.weight: torch.Size([2048, 2048])
model.layers.0.self_attn.k_proj.weight: torch.Size([512, 2048])
model.layers.0.self_attn.v_proj.weight: torch.Size([512, 2048])
model.layers.0.self_attn.o_proj.weight: torch.Size([2048, 2048])
model.layers.2.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.2.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.2.self_attn.q_proj.weight: torch.Size([2048, 2048])
model.layers.2.self_attn.k_proj.weight: torch.Size([512, 2048])
model.layers.2.self_attn.v_proj.weight: torch.Size([512, 2048])
model.layers.2.self_attn.o_proj.weight: torch.Size([2048, 2048])
model.layers.4.self_attn.window_factors: torch.Size([1, 32, 1, 1])
model.layers.4.self_attn.linear_factors: torch.Size([1, 32, 1, 1])
model.layers.4.s

## Prepare Dolly

In [10]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

class DollyDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Set padding to left
        self.tokenizer.padding_side = 'left'

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Basic prompt structure
        prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        prompt += f"### Instruction:\n{item['instruction']}\n\n"
        
        # Add context only if it exists and is not empty/null
        if 'context' in item and item['context'] and not isinstance(item['context'], float):  # Check for NaN
            context = item['context'].strip()
            if context:
                prompt += f"### Context:\n{context}\n\n"
        
        prompt += f"### Response:\n{item['response']}<|end of text|>\n\n"

        # Tokenize with truncation
        encodings = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Create labels by setting padding tokens to -100 to be ignored in loss
        labels = encodings['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': encodings['input_ids'].squeeze(),
            'attention_mask': encodings['attention_mask'].squeeze(),
            'labels': labels.squeeze()
        }

In [11]:
from torch.nn.utils.rnn import pad_sequence
import torch

def dynamic_collate_fn(batch, max_length=1024, pad_token_id=128001):  # Update pad_token_id to your tokenizer's eos token
    input_ids, attention_masks, labels = [], [], []
    current_batch, current_length = [], 0

    for item in batch:
        length = len(item['input_ids'])
        
        # Check if adding the sequence exceeds the max_length
        if current_length + length > max_length:
            # Pack the current batch into final tensors
            packed_input_ids = torch.cat([x['input_ids'] for x in current_batch], dim=0)
            packed_attention_mask = torch.cat([x['attention_mask'] for x in current_batch], dim=0)
            packed_labels = torch.cat([x['labels'] for x in current_batch], dim=0)
            
            # Pad to max_length if needed
            if len(packed_input_ids) < max_length:
                padding_length = max_length - len(packed_input_ids)
                packed_input_ids = torch.cat([
                    packed_input_ids,
                    torch.full((padding_length,), pad_token_id, dtype=torch.long)
                ])
                packed_attention_mask = torch.cat([
                    packed_attention_mask,
                    torch.zeros(padding_length, dtype=torch.long)
                ])
                packed_labels = torch.cat([
                    packed_labels,
                    torch.full((padding_length,), -100, dtype=torch.long)
                ])
            
            # Append packed tensors
            input_ids.append(packed_input_ids)
            attention_masks.append(packed_attention_mask)
            labels.append(packed_labels)
            
            # Reset for the next pack
            current_batch, current_length = [], 0

        # Add the current sequence to the batch
        current_batch.append(item)
        current_length += length

    # Handle the last batch
    if current_batch:
        packed_input_ids = torch.cat([x['input_ids'] for x in current_batch], dim=0)
        packed_attention_mask = torch.cat([x['attention_mask'] for x in current_batch], dim=0)
        packed_labels = torch.cat([x['labels'] for x in current_batch], dim=0)
        
        # Pad the last batch if it doesn't fill max_length
        if len(packed_input_ids) < max_length:
            padding_length = max_length - len(packed_input_ids)
            packed_input_ids = torch.cat([
                packed_input_ids,
                torch.full((padding_length,), pad_token_id, dtype=torch.long)
            ])
            packed_attention_mask = torch.cat([
                packed_attention_mask,
                torch.zeros(padding_length, dtype=torch.long)
            ])
            packed_labels = torch.cat([
                packed_labels,
                torch.full((padding_length,), -100, dtype=torch.long)
            ])
        
        input_ids.append(packed_input_ids)
        attention_masks.append(packed_attention_mask)
        labels.append(packed_labels)

    # Stack all batches into final tensors
    return {
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_masks),
        'labels': torch.stack(labels)
    }


In [12]:
print("Loading Dolly-15k dataset...")
dataset = load_dataset("databricks/databricks-dolly-15k")
train_val_split = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset, val_dataset = train_val_split['train'], train_val_split['test']

# Create datasets
train_dataset = DollyDataset(train_dataset, tokenizer)
val_dataset = DollyDataset(val_dataset, tokenizer)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=2,  # Adjust based on your GPU memory
    shuffle=True,
    collate_fn=dynamic_collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=dynamic_collate_fn
)

Loading Dolly-15k dataset...


In [13]:
# Test the dataloader
print("\nTesting dataloader with a sample batch...")
sample_batch = next(iter(train_loader))
print("\nBatch shapes:")
print(f"input_ids shape: {sample_batch['input_ids'].shape}")
print(f"attention_mask shape: {sample_batch['attention_mask'].shape}")

# Print an example to verify formatting
print("\nDecoded example:")
single_example = sample_batch['input_ids'][0]
decoded_text = tokenizer.decode(single_example)
print(decoded_text)


Testing dataloader with a sample batch...

Batch shapes:
input_ids shape: torch.Size([1, 1024])
attention_mask shape: torch.Size([1, 1024])

Decoded example:
<|begin_of_text|>Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is an Edgeworth box in Economics?

### Response:
An Edgeworth box in Economics, is a graphical representation of a market with just two commodities, X and Y, and two consumers. The dimensions of the box are the total quantities Ωx and Ωy of the two goods.<|end of text|>

<|begin_of_text|>Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is a Balance sheet?

### Response:
A balance sheet is a summary of an organisation's financial position. It lists the values, in the books of account on a particular date, of all the organisation's assets and liabilities. The assets and liabilities are grouped in categories, t

## Set up LoRA

In [14]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,  # LoRA rank
    lora_alpha=16,
    target_modules=[
        "q_proj",  
        "k_proj",  
        "v_proj",  
        "o_proj",
    ],
    modules_to_save=[
        "window_factors",  
        "linear_factors"   
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [15]:
# Convert model to PEFT
model = get_peft_model(model, lora_config)

In [16]:
# Verify the trainable parameters
def print_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_params} || "
        f"trainable%: {100 * trainable_params / all_params:.2f}"
    )

print_trainable_parameters(model)

trainable params: 1703936 || all params: 1237518848 || trainable%: 0.14


## Finetuning

In [17]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from pathlib import Path

In [18]:
def train_lora(model, train_loader, val_loader, save_dir="lora_checkpoints", num_epochs=1):
    """LoRA finetuning with paper specifications"""
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    
    # Optimizer settings from paper
    optimizer = AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Training loop
    best_val_loss = float('inf')
    training_stats = {'train_losses': [], 'val_losses': []}
    
    print(f"\nStarting LoRA finetuning on {device}")
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}\n")
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                          desc=f'LoRA Epoch {epoch + 1}/{num_epochs}')
        
        for step, batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
                use_cache=False
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            # Optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 
            optimizer.step()
            optimizer.zero_grad()
            
            # Update progress
            avg_loss = total_loss / (step + 1)
            progress_bar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'avg_loss': f'{avg_loss:.3f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
            })
        
        train_loss = total_loss / len(train_loader)
        training_stats['train_losses'].append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    use_cache=False
                )
                val_loss += outputs.loss.item()
        
        val_loss /= len(val_loader)
        training_stats['val_losses'].append(val_loss)
        
        # Update scheduler
        scheduler.step(val_loss)
        
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {val_loss:.4f}")
        
        # Save if best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(save_dir / "best_model")
            print(f"Saved new best model! (val_loss: {val_loss:.4f})")
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'training_stats': training_stats
        }
        torch.save(checkpoint, save_dir / f"checkpoint_epoch_{epoch+1}.pt")
    
    print("\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    
    return model, training_stats

In [19]:
save_dir = "llama_lora_checkpoints"
model, stats = train_lora(model, train_loader, val_loader, save_dir=save_dir)


Starting LoRA finetuning on cuda
Number of training batches: 6755
Number of validation batches: 751



LoRA Epoch 1/1:   0%|          | 0/6755 [00:00<?, ?it/s]

Validation:   0%|          | 0/751 [00:00<?, ?it/s]


Epoch 1 Results:
Training loss: 2.0905
Validation loss: 1.9463
Saved new best model! (val_loss: 1.9463)

Training completed!
Best validation loss: 1.9463


In [21]:
text = "2+2="
input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_new_tokens=20, use_cache=False)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


2+2=4
3+3=6
4+4=8
5+5=10

