In [4]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
import torch.nn.functional as F

class LoRAGenerator(nn.Module):
    def __init__(self, hidden_size, lora_rank):
        super().__init__()
        self.hidden_size = hidden_size
        self.lora_rank = lora_rank
        
        # v_proj uses 512 output dimension (as seen in weight shape)
        self.v_proj_out_dim = 512
        
        self.context_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=8,
                dim_feedforward=4*hidden_size,
                dropout=0.1,
                activation="gelu",
                batch_first=True
            ),
            num_layers=2
        )
        
        # Separate generators for q_proj and v_proj
        self.q_proj_generator = self._create_lora_generator(hidden_size)  # 2048 output dim
        self.v_proj_generator = self._create_lora_generator(self.v_proj_out_dim)  # 512 output dim
        
    def _create_lora_generator(self, out_dim):
        return nn.Sequential(
            nn.Linear(self.hidden_size, 4*self.hidden_size),
            nn.GELU(),
            # Output matches dimensions needed for A and B separately
            nn.Linear(4*self.hidden_size, self.lora_rank * self.hidden_size + out_dim * self.lora_rank)
        )
    
    def _reshape_lora_matrices(self, raw_output, out_dim):
        # Split the output into parts for A and B
        split_point = self.lora_rank * self.hidden_size
        a_flat = raw_output[:, :split_point]
        b_flat = raw_output[:, split_point:]
        
        # Reshape A to [batch, rank, hidden]
        lora_A = a_flat.view(-1, self.lora_rank, self.hidden_size)
        
        # Reshape B to [batch, out_dim, rank]
        lora_B = b_flat.view(-1, out_dim, self.lora_rank)
        
        return lora_A, lora_B
    
    def forward(self, context_embeddings):
        # Process context
        context_encoded = self.context_encoder(context_embeddings)
        context_pooled = context_encoded.mean(dim=1)  # [batch_size, hidden_size]
        
        updates = {}
        
        # Generate q_proj updates (2048 output dim)
        q_raw = self.q_proj_generator(context_pooled)
        q_lora_A, q_lora_B = self._reshape_lora_matrices(q_raw, self.hidden_size)
        updates['q_proj'] = {
            'lora_A': q_lora_A,
            'lora_B': q_lora_B
        }
        
        # Generate v_proj updates (512 output dim)
        v_raw = self.v_proj_generator(context_pooled)
        v_lora_A, v_lora_B = self._reshape_lora_matrices(v_raw, self.v_proj_out_dim)
        updates['v_proj'] = {
            'lora_A': v_lora_A,
            'lora_B': v_lora_B
        }
        
        return updates

# def train_lora_generator(
#     base_model,
#     lora_model,
#     generator,
#     train_dataloader,
#     num_epochs=5,
#     learning_rate=1e-4,
# ):
#     device = next(base_model.parameters()).device
#     generator = generator.to(device)
#     optimizer = torch.optim.AdamW(generator.parameters(), lr=learning_rate)
    
#     # Get target modules
#     target_modules = {}
#     for name, module in lora_model.named_modules():
#         if hasattr(module, 'lora_A'):
#             if 'q_proj' in name or 'v_proj' in name:
#                 target_modules[name] = module
#                 print(f"\nFound module {name}")
#                 print(f"LoRA A shape: {module.lora_A[module.active_adapter].weight.shape}")
#                 print(f"LoRA B shape: {module.lora_B[module.active_adapter].weight.shape}")
    
#     for epoch in range(num_epochs):
#         total_loss = 0
#         num_batches = 0
        
#         for batch_idx, batch in enumerate(train_dataloader):
#             context_ids = batch["context"].to(device)
#             target_ids = batch["target"].to(device)
            
#             # Generate base model outputs
#             with torch.no_grad():
#                 context_embeddings = base_model.model.embed_tokens(context_ids)
#                 base_outputs = base_model(input_ids=context_ids)
#                 base_logits = base_outputs.logits
            
#             # Generate LoRA updates
#             lora_updates = generator(context_embeddings)
            
#             # Apply updates
#             for module_name, update in lora_updates.items():
#                 for full_name, module in target_modules.items():
#                     if module_name in full_name:
#                         lora_A = update['lora_A'].mean(dim=0)
#                         lora_B = update['lora_B'].mean(dim=0)
                        
#                         # Print shapes for debugging
#                         if batch_idx == 0:
#                             print(f"\nUpdating {full_name}")
#                             print(f"Generated LoRA A shape: {lora_A.shape}")
#                             print(f"Target LoRA A shape: {module.lora_A[module.active_adapter].weight.shape}")
#                             print(f"Generated LoRA B shape: {lora_B.shape}")
#                             print(f"Target LoRA B shape: {module.lora_B[module.active_adapter].weight.shape}")
                        
#                         # Update weights
#                         module.lora_A[module.active_adapter].weight.data = lora_A
#                         module.lora_B[module.active_adapter].weight.data = lora_B
            
#             # Forward pass with updated LoRA model
#             lora_outputs = lora_model(input_ids=target_ids)
            
#             # Compute loss
#             loss = F.kl_div(
#                 F.log_softmax(lora_outputs.logits, dim=-1),
#                 F.softmax(base_logits, dim=-1),
#                 reduction='batchmean'
#             )
            
#             # Optimization
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             total_loss += loss.item()
#             num_batches += 1
            
#             if num_batches % 1 == 0:
#                 print(f"Epoch {epoch}, Batch {num_batches}, Loss: {loss.item():.4f}")
        
#         avg_loss = total_loss / num_batches
#         print(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")
# def train_lora_generator(
#     base_model,
#     lora_model,
#     generator,
#     train_dataloader,
#     num_epochs=5,
#     learning_rate=1e-4,
# ):
#     device = next(base_model.parameters()).device
#     generator = generator.to(device)
#     optimizer = torch.optim.AdamW(generator.parameters(), lr=learning_rate)
    
#     # Initialize target modules
#     target_modules = {}
#     for name, module in lora_model.named_modules():
#         if hasattr(module, 'lora_A'):
#             if 'q_proj' in name or 'v_proj' in name:
#                 target_modules[name] = module
    
#     print(f"Found {len(target_modules)} target modules for LoRA updates")
    
#     # Track statistics
#     epoch_stats = []
    
#     for epoch in range(num_epochs):
#         print(f"\nEpoch {epoch+1}/{num_epochs}")
#         print("-" * 20)
        
#         total_loss = 0
#         batch_losses = []
#         num_batches = 0
        
#         for batch_idx, batch in enumerate(train_dataloader):
#             context_ids = batch["context"].to(device)
#             target_ids = batch["target"].to(device)
            
#             with torch.no_grad():
#                 context_embeddings = base_model.model.embed_tokens(context_ids)
#                 base_outputs = base_model(input_ids=context_ids)
#                 base_logits = base_outputs.logits
            
#             # Generate and apply LoRA updates
#             lora_updates = generator(context_embeddings)
            
#             # Update LoRA weights
#             for module_name, update in lora_updates.items():
#                 for full_name, module in target_modules.items():
#                     if module_name in full_name:
#                         lora_A = update['lora_A'].mean(dim=0)
#                         lora_B = update['lora_B'].mean(dim=0)
#                         module.lora_A[module.active_adapter].weight.data = lora_A
#                         module.lora_B[module.active_adapter].weight.data = lora_B
            
#             # Forward pass with updated LoRA model
#             lora_outputs = lora_model(input_ids=target_ids)
            
#             # Compute loss
#             loss = F.kl_div(
#                 F.log_softmax(lora_outputs.logits, dim=-1),
#                 F.softmax(base_logits, dim=-1),
#                 reduction='batchmean'
#             )

#             #overpenalizing -- teacher forcing
            
#             # Optimization
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             # Track statistics
#             current_loss = loss.item()
#             total_loss += current_loss
#             batch_losses.append(current_loss)
#             num_batches += 1
            
#             print(f"Batch {batch_idx+1}: Loss = {current_loss:.4f}")
        
#         # Compute epoch statistics
#         avg_loss = total_loss / num_batches
#         min_loss = min(batch_losses)
#         max_loss = max(batch_losses)
        
#         stats = {
#             'epoch': epoch + 1,
#             'avg_loss': avg_loss,
#             'min_loss': min_loss,
#             'max_loss': max_loss,
#             'num_batches': num_batches
#         }
#         epoch_stats.append(stats)
        
#         print("\nEpoch Statistics:")
#         print(f"Average Loss: {avg_loss:.4f}")
#         print(f"Min Loss: {min_loss:.4f}")
#         print(f"Max Loss: {max_loss:.4f}")
#         print(f"Number of Batches: {num_batches}")
    
#     # Print final summary
#     print("\nTraining Summary")
#     print("=" * 40)
#     print("Epoch  |  Avg Loss  |  Min Loss  |  Max Loss")
#     print("-" * 40)
#     for stats in epoch_stats:
#         print(f"{stats['epoch']:5d}  |  {stats['avg_loss']:.4f}    |  {stats['min_loss']:.4f}    |  {stats['max_loss']:.4f}")
# # Model setup


def autoregressive_teacher_forcing(base_model, lora_model, input_ids, max_steps=20):
    """
    Run step-by-step teacher forcing evaluation:
    1. Get base model's next token prediction
    2. Compute KL divergence of LoRA model on that specific token
    3. Feed the base model's token back to LoRA model for next step
    """
    device = input_ids.device
    batch_size = input_ids.shape[0]
    losses = []
    
    # Start with the input context
    lora_tokens = input_ids.clone()
    
    with torch.no_grad():
        for step in range(max_steps):
            # Get base model prediction
            base_outputs = base_model(input_ids=lora_tokens)
            base_logits = base_outputs.logits[:, -1, :]  # Last token predictions
            base_probs = F.softmax(base_logits, dim=-1)
            
            # Get base model's token choice
            base_next_token = base_logits.argmax(dim=-1)  # [batch_size]
            
            # Get LoRA model prediction (on same sequence)
            lora_outputs = lora_model(input_ids=lora_tokens)
            lora_logits = lora_outputs.logits[:, -1, :]  # Last token predictions
            lora_log_probs = F.log_softmax(lora_logits, dim=-1)
            
            # Compute KL divergence just for this step's token
            step_loss = F.kl_div(
                lora_log_probs.unsqueeze(1),  # [batch, 1, vocab]
                base_probs.unsqueeze(1),      # [batch, 1, vocab]
                reduction='none'
            ).sum(dim=-1)  # [batch, 1]
            
            losses.append(step_loss)
            
            # Update LoRA tokens with base model's choice
            lora_tokens = torch.cat([lora_tokens, base_next_token.unsqueeze(1)], dim=1)
    
    # Stack losses for all steps
    step_losses = torch.cat(losses, dim=1)  # [batch_size, max_steps]
    return step_losses, lora_tokens

def train_lora_generator(
    base_model,
    lora_model,
    generator,
    train_dataloader,
    num_epochs=5,
    learning_rate=1e-4,
    max_gen_steps=20
):
    device = next(base_model.parameters()).device
    generator = generator.to(device)
    optimizer = torch.optim.AdamW(generator.parameters(), lr=learning_rate)
    
    # Initialize target modules
    target_modules = {}
    for name, module in lora_model.named_modules():
        if hasattr(module, 'lora_A'):
            if 'q_proj' in name or 'v_proj' in name:
                target_modules[name] = module
    
    print(f"Found {len(target_modules)} target modules for LoRA updates")
    
    # Track statistics
    epoch_stats = []
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)
        
        total_loss = 0
        total_steps = 0
        batch_losses = []
        
        for batch_idx, batch in enumerate(train_dataloader):
            context_ids = batch["context"].to(device)
            
            # Get initial context embeddings for LoRA update generation
            with torch.no_grad():
                context_embeddings = base_model.model.embed_tokens(context_ids)
            
            # Generate and apply LoRA updates based on context
            lora_updates = generator(context_embeddings)
            
            # Update LoRA weights
            for module_name, update in lora_updates.items():
                for full_name, module in target_modules.items():
                    if module_name in full_name:
                        lora_A = update['lora_A'].mean(dim=0)
                        lora_B = update['lora_B'].mean(dim=0)
                        module.lora_A[module.active_adapter].weight.data = lora_A
                        module.lora_B[module.active_adapter].weight.data = lora_B
            
            # Run teacher forcing evaluation
            step_losses, generated_tokens = autoregressive_teacher_forcing(
                base_model=base_model,
                lora_model=lora_model,
                input_ids=context_ids,
                max_steps=max_gen_steps
            )
            
            # Compute mean loss over steps
            loss = step_losses.mean()
            
            # Print step-by-step losses for first batch
            if batch_idx == 0:
                print("\nStep-by-step losses for first sequence:")
                for step, step_loss in enumerate(step_losses[0]):
                    print(f"Step {step}: Loss = {step_loss.item():.4f}")
                
                print("\nGenerated text:")
                print("Context:", tokenizer.decode(context_ids[0]))
                print("Generated:", tokenizer.decode(generated_tokens[0]))
            
            # Optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track statistics
            current_loss = loss.item()
            total_loss += current_loss * len(step_losses)
            total_steps += len(step_losses)
            batch_losses.append(current_loss)
            
            print(f"Batch {batch_idx+1}: Average Loss = {current_loss:.4f}")
        
        # Compute epoch statistics
        avg_loss = total_loss / total_steps
        min_loss = min(batch_losses)
        max_loss = max(batch_losses)
        
        stats = {
            'epoch': epoch + 1,
            'avg_loss': avg_loss,
            'min_loss': min_loss,
            'max_loss': max_loss,
            'total_steps': total_steps
        }
        epoch_stats.append(stats)
        
        print("\nEpoch Statistics:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Min Batch Loss: {min_loss:.4f}")
        print(f"Max Batch Loss: {max_loss:.4f}")
        print(f"Total Generation Steps: {total_steps}")
    
    # Print final summary
    print("\nTraining Summary")
    print("=" * 40)
    print("Epoch  |  Avg Loss  |  Min Loss  |  Max Loss")
    print("-" * 40)
    for stats in epoch_stats:
        print(f"{stats['epoch']:5d}  |  {stats['avg_loss']:.4f}    |  {stats['min_loss']:.4f}    |  {stats['max_loss']:.4f}")


        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

# LoRA config
LORA_RANK = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05

lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

lora_model = get_peft_model(base_model, lora_config).to(device)

# Create generator
generator = LoRAGenerator(
    hidden_size=base_model.config.hidden_size,
    lora_rank=LORA_RANK
)

In [8]:
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import Dataset as HFDataset
import requests
from transformers import AutoTokenizer

class ContextTargetDataset(Dataset):
    def __init__(self, contexts, targets):
        self.contexts = contexts
        self.targets = targets
    
    def __len__(self):
        return len(self.contexts)
    
    def __getitem__(self, idx):
        return {
            "context": self.contexts[idx],
            "target": self.targets[idx]
        }

def prepare_dataset(tokenizer, max_length=512, batch_size=4):
    """
    Prepare dataset and create dataloader for LoRA training
    """
    # Download tiny shakespeare dataset
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url)
    text = response.text

    # Split into chunks
    chunk_size = max_length * 2  # To have room for both context and target
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size) if len(text[i:i+chunk_size]) == chunk_size]
    
    # Create dataset from chunks
    data = [{"text": chunk} for chunk in chunks]
    dataset = HFDataset.from_list(data)

    # Initialize lists for contexts and targets
    all_contexts = []
    all_targets = []

    # Tokenization and processing function
    def process_chunk(examples):
        # Tokenize
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=chunk_size,
            return_tensors="pt"
        )

        # Split into context and target
        split_point = max_length
        
        for ids in tokenized["input_ids"]:
            context = ids[:split_point]
            target = ids[split_point:split_point*2]
            all_contexts.append(context)
            all_targets.append(target)

    # Process all examples
    dataset.map(
        process_chunk,
        batched=True,
        remove_columns=dataset.column_names
    )

    # Create custom dataset
    train_dataset = ContextTargetDataset(
        contexts=torch.stack(all_contexts),
        targets=torch.stack(all_targets)
    )

    # Create dataloader
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return train_dataloader

def setup_training(model_name="meta-llama/Llama-3.2-1B", batch_size=4):
    """
    Set up tokenizer and create dataloader
    """
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    
    # Create dataloader
    train_dataloader = prepare_dataset(
        tokenizer=tokenizer,
        max_length=512,
        batch_size=batch_size
    )
    
    return train_dataloader

# Create dataloader
train_dataloader = setup_training(batch_size=4)

print("DataLoader created successfully!")
print(f"Sample batch shape from dataloader:")
sample_batch = next(iter(train_dataloader))
print(f"Context shape: {sample_batch['context'].shape}")
print(f"Target shape: {sample_batch['target'].shape}")

Map:   0%|          | 0/1089 [00:00<?, ? examples/s]

DataLoader created successfully!
Sample batch shape from dataloader:
Context shape: torch.Size([4, 512])
Target shape: torch.Size([4, 512])


In [6]:
def autoregressive_teacher_forcing(base_model, lora_model, input_ids, max_steps=20):
    """
    Run step-by-step teacher forcing evaluation with gradient tracking
    """
    device = input_ids.device
    batch_size = input_ids.shape[0]
    losses = []
    
    # Start with the input context
    lora_tokens = input_ids.clone()
    
    # Get base model predictions first (without grad since we don't train base model)
    with torch.no_grad():
        base_outputs = base_model(input_ids=lora_tokens)
        base_next_token = base_outputs.logits[:, -1, :].argmax(dim=-1)  # [batch_size]
        base_probs = F.softmax(base_outputs.logits[:, -1, :], dim=-1)  # [batch_size, vocab_size]
    
    # Get LoRA predictions (with grad)
    lora_outputs = lora_model(input_ids=lora_tokens)
    lora_logits = lora_outputs.logits[:, -1, :]  # [batch_size, vocab_size]
    lora_log_probs = F.log_softmax(lora_logits, dim=-1)  # [batch_size, vocab_size]
    
    # Compute KL divergence for this step
    step_loss = F.kl_div(
        lora_log_probs.unsqueeze(1),  # [batch, 1, vocab]
        base_probs.unsqueeze(1),      # [batch, 1, vocab]
        reduction='none'
    ).sum(dim=-1)  # [batch, 1]
    
    return step_loss.mean()  # Return mean loss over batch

def train_lora_generator(
    base_model,
    lora_model,
    generator,
    train_dataloader,
    num_epochs=5,
    learning_rate=1e-4,
):
    device = next(base_model.parameters()).device
    generator = generator.to(device)
    optimizer = torch.optim.AdamW(generator.parameters(), lr=learning_rate)
    
    # Initialize target modules
    target_modules = {}
    for name, module in lora_model.named_modules():
        if hasattr(module, 'lora_A'):
            if 'q_proj' in name or 'v_proj' in name:
                target_modules[name] = module
    
    print(f"Found {len(target_modules)} target modules for LoRA updates")
    
    # Track statistics
    epoch_stats = []
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)
        
        total_loss = 0
        batch_losses = []
        num_batches = 0
        
        for batch_idx, batch in enumerate(train_dataloader):
            context_ids = batch["context"].to(device)
            
            # Get context embeddings for LoRA update generation
            context_embeddings = base_model.model.embed_tokens(context_ids)
            
            # Generate and apply LoRA updates based on context
            lora_updates = generator(context_embeddings)
            
            # Update LoRA weights
            for module_name, update in lora_updates.items():
                for full_name, module in target_modules.items():
                    if module_name in full_name:
                        lora_A = update['lora_A'].mean(dim=0)
                        lora_B = update['lora_B'].mean(dim=0)
                        module.lora_A[module.active_adapter].weight.data = lora_A
                        module.lora_B[module.active_adapter].weight.data = lora_B
            
            # Run teacher forcing evaluation (with gradients)
            loss = autoregressive_teacher_forcing(
                base_model=base_model,
                lora_model=lora_model,
                input_ids=context_ids
            )
            
            # Print sample predictions for first batch
            if batch_idx == 0:
                with torch.no_grad():
                    base_out = base_model(input_ids=context_ids)
                    lora_out = lora_model(input_ids=context_ids)
                    base_next = base_out.logits[:, -1, :].argmax(dim=-1)
                    lora_next = lora_out.logits[:, -1, :].argmax(dim=-1)
                    print("\nSample Predictions:")
                    print("Base model next token:", tokenizer.decode(base_next[0]))
                    print("LoRA model next token:", tokenizer.decode(lora_next[0]))
            
            # Optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track statistics
            current_loss = loss.item()
            total_loss += current_loss
            batch_losses.append(current_loss)
            num_batches += 1
            
            print(f"Batch {batch_idx+1}: Loss = {current_loss:.4f}")
        
        # Compute epoch statistics
        avg_loss = total_loss / num_batches
        min_loss = min(batch_losses)
        max_loss = max(batch_losses)
        
        stats = {
            'epoch': epoch + 1,
            'avg_loss': avg_loss,
            'min_loss': min_loss,
            'max_loss': max_loss,
            'num_batches': num_batches
        }
        epoch_stats.append(stats)
        
        print("\nEpoch Statistics:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Min Loss: {min_loss:.4f}")
        print(f"Max Loss: {max_loss:.4f}")
        print(f"Number of Batches: {num_batches}")
    
    # Print final summary
    print("\nTraining Summary")
    print("=" * 40)
    print("Epoch  |  Avg Loss  |  Min Loss  |  Max Loss")
    print("-" * 40)
    for stats in epoch_stats:
        print(f"{stats['epoch']:5d}  |  {stats['avg_loss']:.4f}    |  {stats['min_loss']:.4f}    |  {stats['max_loss']:.4f}")


In [2]:
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from datasets import Dataset as HFDataset
import requests

def prepare_mini_dataset(tokenizer, max_length=512, batch_size=4, num_samples=20):
    """
    Prepare a small dataset for testing, using only num_samples examples
    """
    # Download tiny shakespeare dataset
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url)
    text = response.text

    # Split into chunks
    chunk_size = max_length * 2  # For both context and target
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size) 
             if len(text[i:i+chunk_size]) == chunk_size]
    
    # Take only the first num_samples chunks
    chunks = chunks[:num_samples]
    
    # Create dataset
    data = [{"text": chunk} for chunk in chunks]
    dataset = HFDataset.from_list(data)

    # Initialize lists for contexts and targets
    all_contexts = []
    all_targets = []

    def process_chunk(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=chunk_size,
            return_tensors="pt"
        )

        # Split each chunk into context and target
        split_point = max_length
        for ids in tokenized["input_ids"]:
            context = ids[:split_point]
            target = ids[split_point:split_point*2]
            all_contexts.append(context)
            all_targets.append(target)

    # Process examples
    dataset.map(
        process_chunk,
        batched=True,
        remove_columns=dataset.column_names
    )

    # Create custom dataset
    train_dataset = ContextTargetDataset(
        contexts=torch.stack(all_contexts),
        targets=torch.stack(all_targets)
    )

    print(f"\nDataset statistics:")
    print(f"Number of samples: {len(train_dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Number of batches: {len(train_dataset) // batch_size}")

    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

class ContextTargetDataset(Dataset):
    def __init__(self, contexts, targets):
        self.contexts = contexts
        self.targets = targets
    
    def __len__(self):
        return len(self.contexts)
    
    def __getitem__(self, idx):
        return {
            "context": self.contexts[idx],
            "target": self.targets[idx]
        }

def setup_mini_training(model_name="meta-llama/Llama-3.2-1B", batch_size=4, num_samples=20):
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    
    # Create mini dataloader
    train_dataloader = prepare_mini_dataset(
        tokenizer=tokenizer,
        max_length=512,
        batch_size=batch_size,
        num_samples=num_samples
    )
    
    return train_dataloader



In [7]:
# Create dataloader and start training
# train_dataloader = setup_training(batch_size=4)
train_dataloader = setup_mini_training(batch_size=10, num_samples=50)  # 5 batches total

# Run training with statistics
train_lora_generator(
    base_model=base_model,
    lora_model=lora_model,
    generator=generator,
    train_dataloader=train_dataloader,
    num_epochs=25,
    learning_rate=1e-4
)

Map:   0%|          | 0/50 [00:00<?, ? examples/s]


Dataset statistics:
Number of samples: 50
Batch size: 10
Number of batches: 5
Found 32 target modules for LoRA updates

Epoch 1/25
--------------------

Sample Predictions:
Base model next token: oth
LoRA model next token: oth
Batch 1: Loss = 0.0831
Batch 2: Loss = 0.0557
Batch 3: Loss = 0.1057
Batch 4: Loss = 0.1141
Batch 5: Loss = 0.2575

Epoch Statistics:
Average Loss: 0.1232
Min Loss: 0.0557
Max Loss: 0.2575
Number of Batches: 5

Epoch 2/25
--------------------

Sample Predictions:
Base model next token: oth
LoRA model next token: oth
Batch 1: Loss = 0.0886
Batch 2: Loss = 0.1606
Batch 3: Loss = 0.0614
Batch 4: Loss = 0.2366
Batch 5: Loss = 0.0533

Epoch Statistics:
Average Loss: 0.1201
Min Loss: 0.0533
Max Loss: 0.2366
Number of Batches: 5

Epoch 3/25
--------------------

Sample Predictions:
Base model next token: oth
LoRA model next token: oth
Batch 1: Loss = 0.3388
Batch 2: Loss = 0.1888
Batch 3: Loss = 0.2961
Batch 4: Loss = 0.1414
Batch 5: Loss = 0.0686

Epoch Statistics:
Av

In [14]:
import torch
from peft import LoraConfig, get_peft_model

# Setup base models first
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").to(device)

# Configure LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

# Create LoRA model
lora_model = get_peft_model(base_model, lora_config).to(device)

# Inspect shapes
shapes = {}
for name, module in lora_model.named_modules():
    if hasattr(module, 'lora_A'):
        shapes[name] = {
            'lora_A': module.lora_A[module.active_adapter].weight.shape,
            'lora_B': module.lora_B[module.active_adapter].weight.shape,
            'weight': module.weight.shape if hasattr(module, 'weight') else None
        }
        print(f"\nModule: {name}")
        print(f"LoRA A shape: {shapes[name]['lora_A']}")
        print(f"LoRA B shape: {shapes[name]['lora_B']}")
        print(f"Weight shape: {shapes[name]['weight']}")

print("\nShape Summary:")
unique_shapes = {}
for name, shape_info in shapes.items():
    key = (f"A:{shape_info['lora_A']}", f"B:{shape_info['lora_B']}")
    if key not in unique_shapes:
        unique_shapes[key] = []
    unique_shapes[key].append(name.split('.')[-2])  # Get the module type (q_proj or v_proj)

print("\nUnique shape patterns:")
for shapes, modules in unique_shapes.items():
    print(f"\nShapes {shapes} found in modules: {modules}")


Module: base_model.model.model.layers.0.self_attn.q_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([2048, 8])
Weight shape: torch.Size([2048, 2048])

Module: base_model.model.model.layers.0.self_attn.v_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([512, 8])
Weight shape: torch.Size([512, 2048])

Module: base_model.model.model.layers.1.self_attn.q_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([2048, 8])
Weight shape: torch.Size([2048, 2048])

Module: base_model.model.model.layers.1.self_attn.v_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([512, 8])
Weight shape: torch.Size([512, 2048])

Module: base_model.model.model.layers.2.self_attn.q_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([2048, 8])
Weight shape: torch.Size([2048, 2048])

Module: base_model.model.model.layers.2.self_attn.v_proj
LoRA A shape: torch.Size([8, 2048])
LoRA B shape: torch.Size([512, 8])
Weight shape: torch.Size([512,