In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig

MODEL_NAME = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [2]:
from peft import LoraConfig, get_peft_model

# Define LoRA configuration
lora_config = LoraConfig(
    r=8,  # LoRA rank (low-rank dimension)
    lora_alpha=16,  # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Apply LoRA to attention layers
    lora_dropout=0.05,
    bias="none",
)

# Apply LoRA to Llama model
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()

trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.06889230774326355


In [63]:
class LoRAGenerator(nn.Module):
    def __init__(self, base_model, lora_dim=8, num_target_modules=2):
        super().__init__()
        self.lora_dim = lora_dim
        self.hidden_size = base_model.config.hidden_size
        self.num_target_modules = num_target_modules  # Number of target modules (q_proj, v_proj)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                activation="gelu"
            ),
            num_layers=4
        )

        # Generate separate LoRA updates for each target module
        self.lora_projections = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.GELU(),
                nn.Linear(self.hidden_size, self.lora_dim * self.hidden_size)
            ) for _ in range(num_target_modules)
        ])

    def forward(self, context_embedding):
        # Process context through transformer
        context_repr = self.transformer_encoder(context_embedding)
        
        # Take last token representation
        last_token_embedding = context_repr[:, -1, :]  # (batch, hidden_size)
        
        # Generate LoRA updates for each target module
        lora_updates = []
        for projection in self.lora_projections:
            # Generate update and reshape to match LoRA dimensions
            update = projection(last_token_embedding)
            update = update.view(-1, self.lora_dim, self.hidden_size)
            lora_updates.append(update)
            
        return torch.stack(lora_updates, dim=1)  # (batch, num_modules, lora_dim, hidden_size)

In [58]:
# Set up padding token
from datasets import load_dataset, Dataset
import requests
# Load PTB dataset in streaming mode (avoids LocalFileSystem issues)
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
text = response.text

# Create dataset from text
data = [{"text": chunk} for chunk in text.split('\n\n') if chunk.strip()]
dataset = Dataset.from_list(data)

# Select only the first 500 examples (no full download)

tokenizer.pad_token = tokenizer.eos_token  # Use EOS token as padding token
tokenizer.padding_side = "left"

# Function to split each example into (context, target)
# Function to split each example into (context, target)
def create_context_target_pairs(examples):
    encoding = tokenizer(examples["text"], 
                         truncation=True, 
                         padding="max_length", 
                         max_length=512)

    input_ids = encoding["input_ids"]  # Convert tensors to list
    
    # Ensure tokenization worked
    if len(input_ids[0]) < 2:
        return {}  # Skip invalid samples

    # Split the sequence in half
    split_idx = len(input_ids[0]) // 2  # Halfway point
    context = [ids[:split_idx] for ids in input_ids]  # First half as context
    target = [ids[split_idx:] for ids in input_ids]   # Second half as target

    return {"context": context, "target": target}


# Apply function to dataset
tokenized_datasets = dataset.map(create_context_target_pairs, 
                           batched=True, 
                           remove_columns=["text"])
tokenized_datasets.set_format(type="torch", columns=["context", "target"])

# Print a sample
print(tokenized_datasets[0])

Map:   0%|          | 0/7222 [00:00<?, ? examples/s]

{'context': tensor([128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001, 128001,
        128001, 128001, 1280

In [59]:
import torch.nn.functional as F

def compute_kl_loss(logits_base, logits_lora):
    """KL Divergence between logits of base model (with context) and LoRA model (without context)."""
    probs_base = F.log_softmax(logits_base, dim=-1)
    probs_lora = F.log_softmax(logits_lora, dim=-1)
    kl_div = F.kl_div(probs_lora, probs_base, reduction='batchmean')
    return kl_div


In [78]:
class LoRAGenerator(nn.Module):
    def __init__(self, base_model, lora_dim=8):
        super().__init__()
        self.lora_dim = lora_dim
        self.hidden_size = base_model.config.hidden_size
        
        # Simple MLP for each target module (q_proj and v_proj)
        self.update_generators = nn.ModuleDict({
            'q_proj': nn.Sequential(
                nn.Linear(self.hidden_size, 4*self.hidden_size),
                nn.GELU(),
                nn.Linear(4*self.hidden_size, 2 * self.hidden_size * self.lora_dim)  # For both A and B
            ),
            'v_proj': nn.Sequential(
                nn.Linear(self.hidden_size, 4*self.hidden_size),
                nn.GELU(),
                nn.Linear(4*self.hidden_size, 2 * self.hidden_size * self.lora_dim)  # For both A and B
            )
        })

    def forward(self, context_embedding):
        # Take mean of context embeddings
        context_mean = context_embedding.mean(dim=1)  # [batch_size, hidden_size]
        
        updates = {}
        for module_name, generator in self.update_generators.items():
            # Generate concatenated updates
            concat_update = generator(context_mean)  # [batch_size, 2 * hidden_size * lora_dim]
            
            # Split into A and B updates
            split_point = self.hidden_size * self.lora_dim
            A_flat = concat_update[:, :split_point]
            B_flat = concat_update[:, split_point:]
            
            # Reshape A and B to match LoRA dimensions
            updates[module_name] = {
                # Reshape to match expected LoRA dimensions
                'A': A_flat.view(-1, self.lora_dim, self.hidden_size),  # [batch_size, lora_dim, hidden_size]
                'B': B_flat.view(-1, self.hidden_size, self.lora_dim)   # [batch_size, hidden_size, lora_dim]
            }
            
        return updates

# Training Loop
for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0
    
    for batch in train_dataloader:
        optimizer.zero_grad()
        
        # Move data to GPU
        context_input_ids = batch["context"].to("cuda")
        target_ids = batch["target"].to("cuda")
        
        # Get base model outputs using the separate base_model
        with torch.no_grad():
            base_output = base_model(input_ids=context_input_ids)
            base_logits = base_output.logits
            context_embeddings = base_model.model.embed_tokens(context_input_ids)
        
        # Generate LoRA updates
        lora_updates = lora_generator(context_embeddings)
        
        # Debug print
        for module_name, update in lora_updates.items():
            print(f"\n{module_name} shapes:")
            print(f"A shape: {update['A'].shape}")
            print(f"B shape: {update['B'].shape}")
        
        # Update LoRA weights
        for module_name, update in lora_updates.items():
            module = get_lora_layer(lora_model, module_name)
            if module is not None:
                # Take mean across batch dimension
                A_mean = update['A'].mean(dim=0)  # [lora_dim, hidden_size]
                B_mean = update['B'].mean(dim=0)  # [hidden_size, lora_dim]
                
                # Verify shapes match LoRA expectations
                print(f"\nModule {module_name} weight shapes:")
                print(f"Original lora_A shape: {module.lora_A[module.active_adapter].weight.shape}")
                print(f"Original lora_B shape: {module.lora_B[module.active_adapter].weight.shape}")
                print(f"New A shape: {A_mean.shape}")
                print(f"New B shape: {B_mean.shape}")
                
                # Update weights
                module.lora_A[module.active_adapter].weight.data = A_mean
                module.lora_B[module.active_adapter].weight.data = B_mean
        
        # Forward pass with LoRA model
        lora_output = lora_model(input_ids=target_ids)
        
        # Compute loss
        loss = compute_kl_loss(base_logits, lora_output.logits)
        loss.backward()
        optimizer.step()
        
        # Update statistics
        total_loss += loss.item()
        num_batches += 1
        
        if num_batches % 10 == 0:
            print(f"Epoch {epoch}, Batch {num_batches}, Loss: {loss.item():.4f}")


q_proj shapes:
A shape: torch.Size([4, 2048, 8])
B shape: torch.Size([4, 8, 2048])

v_proj shapes:
A shape: torch.Size([4, 2048, 8])
B shape: torch.Size([4, 8, 2048])

Module q_proj weight shapes:
Original lora_A shape: torch.Size([2048, 8])
Original lora_B shape: torch.Size([8, 2048])
New A shape: torch.Size([2048, 8])
New B shape: torch.Size([8, 2048])

Module v_proj weight shapes:
Original lora_A shape: torch.Size([2048, 8])
Original lora_B shape: torch.Size([8, 2048])
New A shape: torch.Size([2048, 8])
New B shape: torch.Size([8, 2048])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x2048 and 8x2048)

In [65]:
for batch in train_dataloader:
    context_input_ids = batch["context"].to("cuda")
    target_ids = batch["target"].to("cuda")
    
    # Print shapes
    print(f"Context input ids shape: {context_input_ids.shape}")
    
    context_embeddings = model.model.embed_tokens(context_input_ids)
    print(f"Context embeddings shape: {context_embeddings.shape}")
    
    lora_update = lora_generator(context_embeddings)
    print(f"LoRA update shape: {lora_update.shape}")
    
    # Break after first batch to see shapes
    break

Context input ids shape: torch.Size([4, 256])
Context embeddings shape: torch.Size([4, 256, 2048])
LoRA update shape: torch.Size([4, 2048, 2048])
