<center>
    <h1>Fine Tuning</h1>
</center>

# Brief Recap of Fine Tuning

Fine-tuning adapts a pre-trained language model to a specific task or domain by continuing training on task-specific data. Parameter-efficient techniques such as LoRA and Q-LoRA make this practical for large language models even when computational resources are limited.

## Why Traditional Fine-tuning is Challenging

1. **High Computational Cost**: Fine-tuning the entire model requires significant computational resources, as all parameters are updated during training.
  
2. **Large Storage Requirement**: Each fine-tuned model copy occupies substantial storage, which scales poorly with the number of tasks or datasets.

3. **Catastrophic Forgetting**: Updating all parameters can lead to the loss of knowledge from the pre-trained model, making it less effective on tasks outside the fine-tuning domain.

4. **Inefficiency for Large Models**: For large-scale models like GPT or LLaMA, fine-tuning is resource-intensive, requiring extensive GPU/TPU memory.

5. **Limited Adaptability**: Fine-tuned models are specialized for a single task, making reuse for other tasks less feasible without further fine-tuning.

<center>
    <img src="static/image1.gif" alt="Fine Tuning" style="width:50%;">
</center>

# Understanding LoRA (Low-Rank Adaptation)

<center>
    <img src="static/image2.gif" alt="Fine Tuning with LoRA" style="width:50%;">
</center>

## The Problem LoRA Solves

1. **Resource Intensity**
   - Full fine-tuning requires updating all parameters
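   - High memory requirements: weights, gradients, and optimizer state must all be held in memory (roughly 4x the weight memory with Adam at a single precision)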
   - Expensive computational resources needed

2. **Storage Overhead**
   - Each fine-tuned version needs full model storage
   - Multiple task adaptations become impractical
   - Version management becomes complex

3. **Training Efficiency**
   - Long training times
   - High energy consumption
   - Limited parallel adaptations

## Core Concepts

### Novel Approach

1. **Low-Rank Decomposition**
   - Represents weight updates as low-rank matrices
   - Uses matrix factorization for efficiency
   - Minimizes parameter count while maintaining performance

2. **Frozen Weights**
   - Original model weights remain unchanged
   - Only train small adaptation matrices
   - Preserves pre-trained knowledge

3. **Parameter-Efficient Updates**
   - Updates through small matrices (A and B)
   - Rank determines compression ratio
   - Trainable parameters reduced significantly
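
To make the low-rank idea concrete, here is a minimal NumPy sketch. The dimensions and random values are illustrative assumptions, not taken from any particular model. It builds an update from two small factors and checks that the rank of the update cannot exceed `r`, even though the update has the full shape of the weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

m, n, r = 64, 48, 4  # hypothetical layer dimensions and LoRA rank

# Two small factors; their product has the same shape as the full weight matrix.
B = rng.normal(size=(m, r))
A = rng.normal(size=(r, n))
delta_w = B @ A

update_rank = np.linalg.matrix_rank(delta_w)

print("update shape:", delta_w.shape)          # (64, 48)
print("rank of update:", update_rank)          # at most r
print("factor parameters:", B.size + A.size)   # r * (m + n)
print("full parameters:", delta_w.size)        # m * n
```

The factors hold `r * (m + n)` numbers while the update they represent has `m * n` entries, which is where the parameter savings come from.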

## How LoRA Works

1. **Weight Update Decomposition**:
    ```
    ΔW = BA
    where:
    - ΔW ∈ ℝᵐˣⁿ (weight update)
    - B ∈ ℝᵐˣʳ (first adaptation matrix)
    - A ∈ ℝʳˣⁿ (second adaptation matrix)
    - r is the rank (typically 8, 16, or 32)
    ```

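2. **Forward Pass Computation**:
    ```
    y = Wx + α(BAx)
    where:
    - x ∈ ℝⁿ is the input (column vector)
    - W ∈ ℝᵐˣⁿ is the frozen original weight matrix
    - α is the scaling factor
    - BA is the LoRA update ΔW, so A is applied to x first, then B
    ```
    The TensorFlow implementation in this notebook uses the row-vector form `Y = XW`, so it stores the transposed factors: `lora_a` has shape (input_dim, rank), `lora_b` has shape (rank, output_dim), and the LoRA path is computed as `(X @ lora_a) @ lora_b`.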

3. **Parameter Reduction**:
    ```
    Original parameters: m × n
    LoRA parameters: r × (m + n)
    Reduction ratio: (r × (m + n)) / (m × n)
    ```
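
As a worked instance of the reduction formula above, the following sketch plugs in a hypothetical 4096 × 4096 projection matrix with rank 8. The sizes are assumptions chosen only to make the arithmetic easy to follow.

```python
def lora_param_counts(m: int, n: int, r: int) -> dict:
    """Return full vs. LoRA parameter counts for one m x n weight matrix."""
    full = m * n
    lora = r * (m + n)
    return {"full": full, "lora": lora, "ratio": lora / full}


# Hypothetical 4096 x 4096 projection matrix adapted with rank 8.
counts = lora_param_counts(m=4096, n=4096, r=8)
print(counts["full"])   # 16777216
print(counts["lora"])   # 65536
print(counts["ratio"])  # 0.00390625
```

For this layer the LoRA factors hold 1/256 of the parameters of the full matrix.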

## LoRA Implementation Details

In [None]:
import tensorflow as tf
import os
import numpy as np

### Components Overview

In [None]:
class LoRAConfig:
    """
    Configuration class for Low-Rank Adaptation (LoRA).

    LoRA is a technique that reduces the number of trainable parameters by
    injecting low-rank matrices into existing layers while keeping the base
    model weights frozen.

    Attributes:
        rank (int): The rank of the LoRA decomposition (controls adaptation capacity).
        alpha (int): Scaling factor for LoRA updates.
        target_modules (list): The list of module names to apply LoRA to (e.g., attention layers).
        dropout (float): Dropout rate applied to LoRA layers.
        learning_rate (float): Learning rate read by the training pipeline.
        epochs (int): Number of passes over the training dataset.
    """

    def __init__(self,
                 rank=8,
                 alpha=32,
                 target_modules=None,
                 dropout=0.1,
                 learning_rate=1e-4,
                 epochs=3):
        """
        Initializes the LoRA configuration.

        Args:
            rank (int, optional): The rank for LoRA decomposition. Default is 8.
            alpha (int, optional): The scaling factor for LoRA updates. Default is 32.
            target_modules (list, optional): List of module names where LoRA should be applied.
                                             Default: ['query', 'key', 'value'].
            dropout (float, optional): Dropout rate applied to LoRA layers. Default is 0.1.
            learning_rate (float, optional): Learning rate for the optimizer. Default is 1e-4.
            epochs (int, optional): Number of training epochs. Default is 3 (an arbitrary
                                    starting point; tune it for your dataset).
        """
        self.rank = rank  # Defines the rank of the low-rank decomposition
        self.alpha = alpha  # Scaling factor for LoRA updates
        self.target_modules = target_modules or ['query', 'key', 'value']  # Apply LoRA to these layers
        self.dropout = dropout  # Dropout applied to LoRA layers (helps prevent overfitting)
        self.learning_rate = learning_rate  # Read by train_with_lora when building the trainer
        self.epochs = epochs  # Read by train_with_lora for the training loop


#### Explanation of the above code

**Purpose**

This class serves as a configuration container for LoRA hyperparameters and settings. It centralizes all LoRA-specific parameters in one place for easy management and modification.

**Parameters Explained:**

1. **rank (default=8)**
   - Defines the dimension of low-rank matrices
   - Controls compression ratio and memory savings
   - Lower rank = more compression but potentially less capacity
   - Common values: 8, 16, 32
   - Formula: trainable fraction = r(d_in + d_out) / (d_in × d_out), which reduces to 2r/d for a square d × d layer

2. **alpha (default=32)**
   - Scaling factor for LoRA updates
   - Controls the magnitude of adaptations
   - Usually set to match or be larger than rank
   - Helps stabilize training
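   - Formula: output = original + (alpha * LoRA_output); note that the original LoRA paper scales by alpha / rank, whereas this implementation applies alpha directly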

3. **target_modules (default=['query', 'key', 'value'])**
   - Specifies which layers to apply LoRA to
   - Defaults to attention mechanism components
   - Can be customized for different architectures
   - Common targets:
     - query: Query projection in attention
     - key: Key projection in attention
     - value: Value projection in attention

4. **dropout (default=0.1)**
   - Dropout rate for LoRA layers
   - Helps prevent overfitting
   - Applied only to LoRA path, not base model
   - Standard range: 0.0-0.5

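5. **std (not a constructor argument)**
   - Standard deviation used to randomly initialize matrix A; it is derived from the layer's input dimension and the rank rather than set in `LoRAConfig`
   $$
   \sigma = \frac{1}{r}\sqrt{\frac{2}{d_{\text{in}}}}
   $$
   where $d_{\text{in}}$ is `shape[0]` (the input dimension) and $r$ is the rank.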


### Weight Initialization Strategies

In [None]:
class LoRAInitialization:
    """
    Utility class for initializing LoRA (Low-Rank Adaptation) weight matrices.

    LoRA uses low-rank matrices (A and B) to adapt frozen model weights,
    requiring careful initialization for stability and efficiency.

    Methods:
        init_weights_a(shape, rank): Initializes matrix A using scaled Kaiming/He initialization.
        init_weights_b(shape): Initializes matrix B with zeros for stability.
    """

    @staticmethod
    def init_weights_a(shape, rank):
        """
        Initializes the LoRA A matrix using Kaiming/He initialization.

        This ensures that the weight distribution is properly scaled based on 
        the number of input features, helping prevent vanishing/exploding gradients.

        Args:
            shape (tuple): Shape of the weight matrix (input_dim, rank).
            rank (int): The rank for low-rank adaptation.

        Returns:
            tf.Tensor: Initialized weight matrix for LoRA A.
        """
        std = np.sqrt(2.0 / float(shape[0])) / rank  # He-style std, further divided by rank
        return tf.random.normal(shape, stddev=std)

    @staticmethod
    def init_weights_b(shape):
        """
        Initializes the LoRA B matrix with zeros for stability.

        A zero-initialized B matrix ensures that at the start of training,
        the LoRA adaptation does not interfere with the frozen model weights.

        Args:
            shape (tuple): Shape of the weight matrix (rank, output_dim).

        Returns:
            tf.Tensor: Zero-initialized weight matrix for LoRA B.
        """
        return tf.zeros(shape)


#### Explanation of the above code

**Purpose**

This class handles the initialization strategies for the two LoRA matrices (A and B). It uses different initialization approaches for each matrix to ensure stable training and good convergence.

**Methods Explained:**

1. **init_weights_a**
    ```python
    @staticmethod
    def init_weights_a(shape, rank):
        std = np.sqrt(2.0 / float(shape[0])) / rank
        return tf.random.normal(shape, stddev=std)
    ```
      
    - **Purpose**: Initializes the first LoRA matrix (A)
    - **Uses Kaiming/He Initialization**:
      - Designed for ReLU-based networks
      - Helps maintain variance across layers
      - Scaled by rank for stability
    - **Parameters**:
      - shape: Dimensions of matrix A
      - rank: LoRA rank parameter
    - **Formula Breakdown**:
      - `2.0 / float(shape[0])`: He initialization base
      - `/rank`: Additional scaling for LoRA stability
      - Result used as standard deviation for normal distribution

2. **init_weights_b**
    ```python
    @staticmethod
    def init_weights_b(shape):
        return tf.zeros(shape)
    ```

    - **Purpose**: Initializes the second LoRA matrix (B)
    - **Uses Zero Initialization**:
      - Ensures LoRA starts with no initial impact
      - Allows gradual learning of adaptations
      - Promotes stability in early training
    - **Parameters**:
      - shape: Dimensions of matrix B

## Implementing LoRA in TensorFlow

### 1. LoRA Layer Implementation

In [None]:
class LoRALayer(tf.keras.layers.Layer):
    """
    Implements a Low-Rank Adaptation (LoRA) layer for fine-tuning large models efficiently.

    LoRA injects trainable low-rank matrices into a frozen layer, allowing adaptation
    without modifying the original pre-trained weights.

    Attributes:
        original_layer (tf.keras.layers.Layer): The frozen base layer to be adapted.
        rank (int): The rank of the LoRA decomposition (controls adaptation flexibility).
        alpha (int): The scaling factor applied to the LoRA output.
        dropout_rate (float): Dropout rate applied before LoRA transformation.
        lora_a (tf.Variable): Trainable low-rank matrix A.
        lora_b (tf.Variable): Trainable low-rank matrix B.
        dropout (tf.keras.layers.Dropout): Dropout layer applied to LoRA inputs.
    """

    def __init__(self, 
                 original_layer,
                 rank=8,
                 alpha=32,
                 dropout_rate=0.1,
                 **kwargs):
        """
        Initializes the LoRA layer.

        Args:
            original_layer (tf.keras.layers.Layer): The frozen base layer being adapted.
            rank (int, optional): The rank for LoRA decomposition. Default is 8.
            alpha (int, optional): Scaling factor for LoRA updates. Default is 32.
            dropout_rate (float, optional): Dropout rate applied before LoRA transformation. Default is 0.1.
        """
        super().__init__(**kwargs)
        
        self.original_layer = original_layer  # Store the frozen base layer
        self.rank = rank  # LoRA rank for low-rank decomposition
        self.alpha = alpha  # Scaling factor for LoRA updates
        self.dropout_rate = dropout_rate  # Dropout rate applied before LoRA updates
        
        # Get original weight shape
        self.original_shape = original_layer.get_weights()[0].shape
        
        # Create LoRA matrices (trainable parameters)
        self.lora_a = self._create_lora_matrix("a")
        self.lora_b = self._create_lora_matrix("b")
        
        # Dropout applied before LoRA updates
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        
        # Freeze original model weights
        self.original_layer.trainable = False

    def _create_lora_matrix(self, name):
        """
        Creates a trainable LoRA matrix (A or B) using predefined initialization.

        Args:
            name (str): "a" or "b", indicating which LoRA matrix to create.

        Returns:
            tf.Variable: Trainable weight matrix for LoRA adaptation.
        """
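        if name == "a":
            shape = (self.original_shape[0], self.rank)  # A: (input_dim, rank)
            # Keras calls initializers as fn(shape, dtype=...), so wrap the helper
            # to pass the rank and absorb the dtype argument.
            initializer = lambda shape, dtype=None: (
                LoRAInitialization.init_weights_a(shape, self.rank))
        else:
            shape = (self.rank, self.original_shape[1])  # B: (rank, output_dim)
            initializer = lambda shape, dtype=None: (
                LoRAInitialization.init_weights_b(shape))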
            
        return self.add_weight(
            name=f"lora_{name}",
            shape=shape,
            initializer=initializer,
            trainable=True
        )
    
    def call(self, inputs, training=None):
        """
        Forward pass of the LoRA layer.

        Args:
            inputs (tf.Tensor): The input tensor.
            training (bool, optional): Whether the model is in training mode.

        Returns:
            tf.Tensor: Output tensor after applying LoRA adaptation.
        """
        # Compute original output using frozen weights
        original_output = self.original_layer(inputs)
        
        # Apply dropout before LoRA transformation (only in training mode)
        lora_input = self.dropout(inputs, training=training) if training else inputs
        
        # Compute the LoRA adaptation
        lora_output = tf.matmul(
            tf.matmul(lora_input, self.lora_a),  # First projection (low-rank A)
            self.lora_b  # Second projection (low-rank B)
        )
        
        # Combine the original output with LoRA adaptation (scaled by alpha)
        return original_output + (self.alpha * lora_output)


#### Explanation of the above code

1. Class Initialization
    ```python
    def __init__(self, 
                original_layer,
                rank=8,
                alpha=32,
                dropout_rate=0.1,
                **kwargs):
        super().__init__(**kwargs)
    ```
    - Inherits from TensorFlow's base Layer class
    - Takes original layer and LoRA parameters
    - Parameters:
      - original_layer: Base layer to adapt
      - rank: Dimension of low-rank matrices
      - alpha: Scaling factor
      - dropout_rate: Regularization strength

2. Setup and Initialization
    ```python
    # Store parameters
    self.original_layer = original_layer
    self.rank = rank
    self.alpha = alpha
    self.dropout_rate = dropout_rate

    # Get shape from original layer
    self.original_shape = original_layer.get_weights()[0].shape
    ```
    - Stores configuration parameters
    - Extracts shape from original layer weights (the layer must already be built, otherwise `get_weights()` returns an empty list)
    - Prepares for LoRA matrix creation

3. Matrix Creation Helper
    ```python
    def _create_lora_matrix(self, name):
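        if name == "a":
            shape = (self.original_shape[0], self.rank)
            # Keras calls initializers as fn(shape, dtype=...)
            initializer = lambda shape, dtype=None: (
                LoRAInitialization.init_weights_a(shape, self.rank))
        else:
            shape = (self.rank, self.original_shape[1])
            initializer = lambda shape, dtype=None: (
                LoRAInitialization.init_weights_b(shape))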
    ```
    - Creates LoRA matrices A and B
    - Matrix A: input_dim × rank
    - Matrix B: rank × output_dim
    - Uses different initializations for each matrix

4. Forward Pass Implementation
    ```python
    def call(self, inputs, training=None):
        # Original transformation
        original_output = self.original_layer(inputs)
        
        # LoRA path with dropout
        lora_input = inputs
        if training:
            lora_input = self.dropout(lora_input, training=training)
        
        # LoRA transformation
        lora_output = tf.matmul(
            tf.matmul(lora_input, self.lora_a),
            self.lora_b
        )
    ```
    - Implements forward pass computation
    - Steps:
      1. Compute original layer output
      2. Apply dropout during training
      3. Compute LoRA transformation
      4. Combine results with scaling

#### Key Components:

1. Original Layer Handling
    ```python
    self.original_layer = original_layer
    self.original_layer.trainable = False
    ```
    - Stores original layer
    - Freezes original weights

2. LoRA Matrices
    ```python
    self.lora_a = self._create_lora_matrix("a")
    self.lora_b = self._create_lora_matrix("b")
    ```
    - Creates two trainable matrices
    - Different initialization strategies
    - Shapes determined by original layer

3. Dropout Implementation
    ```python
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    ```
    - Adds regularization
    - Only applied during training
    - Applied to LoRA path only

4. Forward Pass Logic
    ```python
    return original_output + (self.alpha * lora_output)
    ```
    - Combines original and LoRA paths
    - Scales LoRA contribution
    - Leaves the original layer's output unchanged at initialization, because matrix B starts at zero
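
The forward-pass formula can be sanity-checked outside TensorFlow. The NumPy sketch below uses hypothetical sizes and ignores bias and activation. It compares the two-path computation from `call` with a single matrix multiply in which the scaled update is folded into the kernel. The folding step is not something the `LoRALayer` above implements; it is shown only to illustrate that the two forms agree for a purely linear layer.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, input_dim, output_dim, rank = 4, 32, 16, 4  # hypothetical sizes
alpha = 2.0

x = rng.normal(size=(batch, input_dim))
w = rng.normal(size=(input_dim, output_dim))         # frozen kernel, Keras layout
lora_a = rng.normal(size=(input_dim, rank)) * 0.01
lora_b = rng.normal(size=(rank, output_dim)) * 0.01  # nonzero, as if after training

# Two-path forward pass, as in LoRALayer.call (bias and activation omitted).
two_path = x @ w + alpha * ((x @ lora_a) @ lora_b)

# Equivalent single matmul with the update folded into the kernel.
merged = x @ (w + alpha * (lora_a @ lora_b))

print(np.allclose(two_path, merged))  # True
```

If the wrapped layer applies a nonlinear activation, the equivalence no longer holds for this implementation, since the LoRA term is added after the original layer's full output.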

### 2. Model Adapter Implementation

In [None]:
class LoRAModelAdapter:
    """
    A utility class for adapting a pre-trained model with LoRA (Low-Rank Adaptation).

    This adapter replaces specific layers (e.g., attention layers) with LoRA-modified
    versions while keeping the base model frozen. It enables efficient fine-tuning
    without updating the full model parameters.

    Attributes:
        model (tf.keras.Model): The original pre-trained model to be adapted.
        config (LoRAConfig): Configuration object specifying LoRA parameters.
        lora_layers (list): A list storing all applied LoRA layers.
    """

    def __init__(self,
                 model,
                 config: LoRAConfig):
        """
        Initializes the LoRAModelAdapter.

        Args:
            model (tf.keras.Model): The original model to be adapted with LoRA.
            config (LoRAConfig): Configuration object specifying LoRA parameters.
        """
        self.model = model  # Store the base model
        self.config = config  # Store the LoRA configuration
        self.lora_layers = []  # Track all applied LoRA layers

    def adapt_layer(self, layer):
        """
        Applies LoRA adaptation to a single layer if it matches the criteria.

        Args:
            layer (tf.keras.layers.Layer): The layer to check and adapt.

        Returns:
            tf.keras.layers.Layer: The original or LoRA-modified layer.
        """
        if isinstance(layer, tf.keras.layers.Dense):  # Only apply LoRA to Dense layers
            return LoRALayer(
                layer,
                rank=self.config.rank,
                alpha=self.config.alpha,
                dropout_rate=self.config.dropout
            )
        return layer  # Return unchanged if the layer is not adapted

    def create_adapted_model(self):
        """
        Creates a new model with LoRA adaptations applied to target layers.

        The function clones the original model and replaces specified layers
        (e.g., 'query', 'key', 'value') with LoRA-modified versions.

        Returns:
            tf.keras.Model: A new model with LoRA adaptations applied.
        """
        def clone_function(layer):
            """
            Function used to modify layers during model cloning.

            Args:
                layer (tf.keras.layers.Layer): The layer being cloned.

            Returns:
                tf.keras.layers.Layer: The modified or unchanged layer.
            """
            if any(name in layer.name for name in self.config.target_modules):
                adapted_layer = self.adapt_layer(layer)
                if isinstance(adapted_layer, LoRALayer):
                    self.lora_layers.append(adapted_layer)  # Track adapted layers
                return adapted_layer
            return layer  # Return unchanged layer if not modified
        
        # Clone the model while replacing specified layers with LoRA layers
        adapted_model = tf.keras.models.clone_model(
            self.model,
            clone_function=clone_function
        )
        
        return adapted_model


#### Explanation of the above code

1. Class Initialization
    ```python
    def __init__(self, model, config: LoRAConfig):
        self.model = model
        self.config = config
        self.lora_layers = []
    ```
    - **Purpose**: Initializes the adapter with:
      - model: Original model to adapt
      - config: LoRA configuration settings
      - lora_layers: Tracks created LoRA layers
    - **Type Hint**: Expects LoRAConfig object for configuration

2. Layer Adaptation Method
    ```python
    def adapt_layer(self, layer):
        """Apply LoRA adaptation to a single layer"""
        if isinstance(layer, tf.keras.layers.Dense):
            return LoRALayer(
                layer,
                rank=self.config.rank,
                alpha=self.config.alpha,
                dropout_rate=self.config.dropout
            )
        return layer
    ```
    - **Purpose**: Converts single layer to LoRA version
    - **Process**:
      1. Checks if layer is Dense type
      2. Creates LoRA version if applicable
      3. Returns original layer if not Dense
    - **Parameters**: Uses configuration values for:
      - rank
      - alpha
      - dropout (passed to `LoRALayer` as `dropout_rate`)

3. Model Adaptation Method
    ```python
    def create_adapted_model(self):
        """Create a new model with LoRA adaptations"""
        def clone_function(layer):
            if any(name in layer.name 
                for name in self.config.target_modules):
                adapted_layer = self.adapt_layer(layer)
                if isinstance(adapted_layer, LoRALayer):
                    self.lora_layers.append(adapted_layer)
                return adapted_layer
            return layer
        
        adapted_model = tf.keras.models.clone_model(
            self.model,
            clone_function=clone_function
        )
        
        return adapted_model
    ```
    - **Purpose**: Creates complete LoRA-adapted model
    - **Process**:
      1. Defines clone function for layer handling
      2. Checks layer names against target modules
      3. Adapts matching layers
      4. Tracks created LoRA layers
      5. Clones entire model with adaptations

#### Key Features:

1. Selective Adaptation
    ```python
    if any(name in layer.name for name in self.config.target_modules)
    ```
    - Only adapts specified layers
    - Maintains original architecture
    - Configurable targeting

2. Layer Tracking
    ```python
    self.lora_layers.append(adapted_layer)
    ```
    - Keeps record of LoRA layers
    - Enables monitoring
    - Facilitates management

3. Model Preservation
    ```python
    adapted_model = tf.keras.models.clone_model(...)
    ```
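    - Creates a new model instance with the same architecture
    - Reuses the original layer objects, so pre-trained weights carry over
    - Wrapped layers are shared with the original model, which means freezing them also affects the original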
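
Because selection relies on a substring test against layer names, it is worth checking which names actually match before adapting a model. The sketch below uses hypothetical layer names (real names depend on how the base model was built) and applies the same `any(name in layer.name ...)` check in plain Python.

```python
target_modules = ["query", "key", "value"]

# Hypothetical layer names; real names depend on how the base model was built.
layer_names = [
    "attention_query",
    "attention_key",
    "attention_value",
    "attention_output",
    "keypoint_head",   # contains "key" but is not an attention projection
    "dense_1",
]


def is_target(layer_name, targets):
    """Substring test, the same check clone_function performs."""
    return any(name in layer_name for name in targets)


matched = [name for name in layer_names if is_target(name, target_modules)]
print(matched)
# ['attention_query', 'attention_key', 'attention_value', 'keypoint_head']
```

The unintended match on `keypoint_head` shows why short target strings can be risky; more specific targets or an exact-name comparison would avoid it.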

### 3. Training Manager

In [None]:
class LoRATrainingManager:
    """
    A training manager for fine-tuning models with LoRA (Low-Rank Adaptation).

    This class manages optimization, loss tracking, and gradient updates while ensuring
    that only LoRA parameters are updated, keeping the original model weights frozen.

    Attributes:
        model (tf.keras.Model): The LoRA-adapted model to be trained.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay coefficient for AdamW.
        optimizer (tf.keras.optimizers.AdamW): Optimizer for updating LoRA parameters.
        loss_tracker (tf.keras.metrics.Mean): Tracks loss values across training steps.
    """

    def __init__(self,
                 model,
                 learning_rate=1e-4,
                 weight_decay=0.01):
        """
        Initializes the LoRATrainingManager.

        Args:
            model (tf.keras.Model): The model to be trained with LoRA.
            learning_rate (float, optional): The learning rate for the optimizer. Default is 1e-4.
            weight_decay (float, optional): Weight decay factor for AdamW. Default is 0.01.
        """
        self.model = model  # Store the LoRA-adapted model
        self.learning_rate = learning_rate  # Learning rate for optimizer
        self.weight_decay = weight_decay  # Weight decay to prevent overfitting
        
        # Create an optimizer with weight decay
        self.optimizer = self._create_optimizer()
        
        # Loss tracker to monitor training progress
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')

    def _create_optimizer(self):
        """
        Creates an AdamW optimizer for training LoRA parameters.

        Returns:
            tf.keras.optimizers.AdamW: Optimizer configured with learning rate and weight decay.
        """
        return tf.keras.optimizers.AdamW(
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay
        )

    @tf.function
    def train_step(self, inputs, labels):
        """
        Performs a single training step, updating only LoRA parameters.

        Args:
            inputs (tf.Tensor): Input tensor (e.g., tokenized sequences).
            labels (tf.Tensor): Target tensor (e.g., ground-truth labels).

        Returns:
            dict: A dictionary containing the loss value for tracking.
        """
        with tf.GradientTape() as tape:
            # Forward pass
            predictions = self.model(inputs, training=True)
            # Compute loss
            loss = self.compute_loss(labels, predictions)
            
        # Get trainable variables (only LoRA parameters)
        trainable_vars = [var for var in self.model.trainable_variables
                          if 'lora_' in var.name]  # Ensures we only update LoRA layers
        
        # Compute gradients
        gradients = tape.gradient(loss, trainable_vars)
        
        # Apply gradients using AdamW optimizer
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Update loss tracking metric
        self.loss_tracker.update_state(loss)
        
        return {
            "loss": self.loss_tracker.result()
        }

    def compute_loss(self, labels, predictions):
        """
        Computes the loss function for training.

        Args:
            labels (tf.Tensor): Ground-truth labels.
            predictions (tf.Tensor): Model predictions.

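        Returns:
            tf.Tensor: Scalar loss averaged over the batch.
        """
        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, predictions, from_logits=True)
        return tf.reduce_mean(per_example_loss)

    def evaluate(self, dataset):
        """
        Computes the mean loss over a dataset without updating any weights.

        Args:
            dataset (tf.data.Dataset): Batches with 'input_ids' and 'labels' keys.

        Returns:
            dict: A dictionary containing the mean loss over the dataset.
        """
        eval_loss = tf.keras.metrics.Mean(name='eval_loss')
        for batch in dataset:
            predictions = self.model(batch['input_ids'], training=False)
            eval_loss.update_state(self.compute_loss(batch['labels'], predictions))
        return {"loss": eval_loss.result()}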


#### Explanation of the above code

1. Class Initialization
    ```python
    def __init__(self,
                model,
                learning_rate=1e-4,
                weight_decay=0.01):
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        
        self.optimizer = self._create_optimizer()
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
    ```
    - **Purpose**: Sets up training environment
    - **Parameters**:
      - model: LoRA-adapted model
      - learning_rate: Training rate (default: 0.0001)
      - weight_decay: Decoupled weight decay coefficient for AdamW (default: 0.01)
    - **Components**:
      - Creates optimizer
      - Initializes loss tracking

2. Optimizer Creation
    ```python
    def _create_optimizer(self):
        return tf.keras.optimizers.AdamW(
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay
        )
    ```
    - **Purpose**: Initializes AdamW optimizer
    - **Features**:
      - Adaptive learning rates
      - Weight decay regularization
      - Momentum-based updates

3. Training Step Implementation
    ```python
    @tf.function  # Compiler decorator for performance
    def train_step(self, inputs, labels):
        with tf.GradientTape() as tape:
            # Forward pass
            predictions = self.model(inputs, training=True)
            # Calculate loss
            loss = self.compute_loss(labels, predictions)
    ```
    - **Purpose**: Executes single training iteration
    - **Process**:
      1. Records operations for gradient computation
      2. Performs forward pass
      3. Calculates loss

4. Gradient Computation and Application
    ```python
    # Get trainable variables (only LoRA parameters)
    trainable_vars = [var for var in self.model.trainable_variables
                    if 'lora_' in var.name]

    # Compute gradients
    gradients = tape.gradient(loss, trainable_vars)

    # Apply gradients
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    ```
    - **Purpose**: Updates LoRA parameters
    - **Features**:
      - Selects only LoRA variables
      - Computes gradients
      - Applies updates

#### Key Components:

1. Loss Tracking
    ```python
    self.loss_tracker = tf.keras.metrics.Mean(name='loss')
    self.loss_tracker.update_state(loss)
    ```
    - Maintains running average of loss
    - Tracks training progress
    - Returns current metrics

2. LoRA Parameter Selection
    ```python
    trainable_vars = [var for var in self.model.trainable_variables
                    if 'lora_' in var.name]
    ```
    - Filters for LoRA parameters
    - Ignores frozen base model
    - Efficient update process

3. Gradient Management
    ```python
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    ```
    - Computes parameter updates
    - Applies optimization steps
    - Manages learning process
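
The core idea of the training step, updating only the LoRA matrices while the base weights stay fixed, can be illustrated without TensorFlow. The following is a minimal NumPy sketch under stated assumptions: a single linear layer, a squared-error loss instead of cross-entropy, plain gradient descent instead of AdamW, and hypothetical sizes and hyperparameters. It is not the notebook's training manager, only a small model of the same update pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out, rank = 64, 16, 8, 4   # hypothetical sizes
alpha, lr, steps = 1.0, 0.05, 200         # hypothetical hyperparameters

x = rng.normal(size=(batch, d_in))
w = rng.normal(size=(d_in, d_out))        # frozen base weights
w_before = w.copy()

# Target: base output plus a small rank-2 shift the adapter has to learn.
shift = 0.1 * rng.normal(size=(d_in, 2)) @ rng.normal(size=(2, d_out))
y = x @ (w + shift)

a = 0.1 * rng.normal(size=(d_in, rank))   # trainable
b = np.zeros((rank, d_out))               # trainable, starts at zero


def mse(a, b):
    """Half the mean squared error of the adapted layer on the batch."""
    err = x @ w + alpha * (x @ a) @ b - y
    return 0.5 * np.mean(np.sum(err ** 2, axis=1))


initial_loss = mse(a, b)
for _ in range(steps):
    err = (x @ w + alpha * (x @ a) @ b - y) / batch
    grad_a = alpha * x.T @ err @ b.T      # gradient of the loss w.r.t. A
    grad_b = alpha * (x @ a).T @ err      # gradient of the loss w.r.t. B
    a -= lr * grad_a                      # only A and B are updated
    b -= lr * grad_b

final_loss = mse(a, b)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
print("base weights unchanged:", np.array_equal(w, w_before))  # True
```

Note that the gradient for A is zero on the first step because B starts at zero, so B moves first and A follows once B is nonzero.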

### 4. Complete Training Pipeline

In [None]:
def train_with_lora(
    base_model,
    train_dataset,
    validation_dataset,
    config: LoRAConfig
):
    """
    Trains a model using LoRA (Low-Rank Adaptation) with a given dataset.

    This function wraps a base model with LoRA, fine-tunes only the LoRA parameters,
    and runs a training loop while evaluating on a validation set.

    Args:
        base_model (tf.keras.Model): The original pre-trained model.
        train_dataset (tf.data.Dataset): The training dataset containing input-label pairs.
        validation_dataset (tf.data.Dataset): The validation dataset for evaluation.
        config (LoRAConfig): Configuration specifying LoRA parameters and training settings.

    Returns:
        tf.keras.Model: The trained LoRA-adapted model.
    """

    # Create LoRA adapter and generate an adapted model
    adapter = LoRAModelAdapter(base_model, config)
    adapted_model = adapter.create_adapted_model()
    
    # Initialize training manager
    trainer = LoRATrainingManager(adapted_model, learning_rate=config.learning_rate)

    # Training loop
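    for epoch in range(config.epochs):
        print(f"Epoch {epoch + 1}/{config.epochs}")

        # Reset the running mean so the reported loss covers this epoch only
        trainer.loss_tracker.reset_state()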

        # Train on all batches
        for batch in train_dataset:
            metrics = trainer.train_step(
                batch['input_ids'],  # Input sequence
                batch['labels']      # Corresponding labels
            )
        
        # Validate on the entire validation set
        val_metrics = trainer.evaluate(validation_dataset)

        # Print metrics
        print(f"Training loss: {metrics['loss']:.4f}")
        print(f"Validation loss: {val_metrics['loss']:.4f}")

    return adapted_model  # Return the trained model


#### Explanation of the above code

1. Function Definition
    ```python
    def train_with_lora(
        base_model,
        train_dataset,
        validation_dataset,
        config: LoRAConfig
    ):
    ```
    - **Purpose**: Main training pipeline for LoRA
    - **Parameters**:
      - base_model: Original model to adapt
      - train_dataset: Training data
      - validation_dataset: Validation data
      - config: LoRA configuration settings

2. Model Adaptation
    ```python
    # Create LoRA adapter
    adapter = LoRAModelAdapter(base_model, config)
    adapted_model = adapter.create_adapted_model()
    ```
    - **Purpose**: Sets up LoRA-adapted model
    - **Process**:
      1. Creates adapter instance
      2. Applies LoRA to specified layers
      3. Returns adapted model

3. Training Setup
    ```python
    # Initialize training manager
    trainer = LoRATrainingManager(adapted_model, learning_rate=config.learning_rate)
    ```
    - **Purpose**: Prepares training environment
    - **Features**:
      - Sets up optimizer
      - Initializes loss tracking
      - Manages training state

4. Training Loop
    ```python
    # Training loop
    for epoch in range(config.epochs):
        print(f"Epoch {epoch + 1}/{config.epochs}")
        
        # Train
        for batch in train_dataset:
            metrics = trainer.train_step(
                batch['input_ids'],
                batch['labels']
            )
    ```
    - **Purpose**: Executes training process
    - **Components**:
    - Epoch iteration
    - Batch processing
    - Metrics collection

5. Validation
    ```python
    # Validate
    val_metrics = trainer.evaluate(validation_dataset)
    ```
    - **Purpose**: Evaluates model performance
    - **Process**:
    - Runs validation data
    - Computes metrics
    - Tracks progress

6. Progress Reporting
    ```python
    # Print metrics
    print(f"Training loss: {metrics['loss']:.4f}")
    print(f"Validation loss: {val_metrics['loss']:.4f}")
    ```
    - **Purpose**: Monitors training progress
    - **Output**:
    - Training loss (from the last batch of the epoch)
    - Validation loss
    - Both formatted to four decimal places
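
The loop above keeps only the `metrics` of the final batch in each epoch, so the printed training loss can be noisy. A minimal sketch of averaging the loss over all batches instead, assuming (as in the code above) that `train_step` returns a dict with a `'loss'` entry. `StubTrainer` and `run_epoch` are hypothetical names used only to keep the example self-contained:

```python
class StubTrainer:
    """Hypothetical stand-in for LoRATrainingManager; returns a fake loss."""

    def train_step(self, input_ids, labels):
        # Pretend the loss is the mean absolute difference between inputs and labels
        diffs = [abs(a - b) for a, b in zip(input_ids, labels)]
        return {"loss": sum(diffs) / len(diffs)}


def run_epoch(trainer, dataset):
    """Run one epoch and return the loss averaged over all batches."""
    total, count = 0.0, 0
    for batch in dataset:
        metrics = trainer.train_step(batch["input_ids"], batch["labels"])
        total += float(metrics["loss"])
        count += 1
    return total / max(count, 1)  # Guard against an empty dataset


dataset = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 5]},
    {"input_ids": [4, 4, 4], "labels": [4, 4, 4]},
]
epoch_loss = run_epoch(StubTrainer(), dataset)
print(f"Mean training loss: {epoch_loss:.4f}")
```

The same accumulation pattern would apply to a real trainer, as long as its per-step loss can be converted to a Python float.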

# Understanding Q-LoRA: Quantized Low-Rank Adaptation

<center>
    <img src="static/image3.gif" alt="Q-LoRA" style="width:50%;">
</center>

## The Problem Q-LoRA Solves

1. **Memory Constraints**
   - LoRA still stores the frozen base weights in 16- or 32-bit precision
   - 16-bit models still consume significant memory
   - Limited by GPU VRAM during training

2. **Hardware Limitations**
   - Most consumer GPUs can't handle large models
   - Training requires expensive specialized hardware
   - Multiple GPUs often needed for fine-tuning

3. **Accessibility Issues**
   - Research limited by hardware requirements
   - High computational costs
   - Resource-intensive deployment

## Core Concepts

### Novel Approach

1. **4-bit Quantization**
   - Reduces model precision from 16/32-bit to 4-bit
   - Uses special NormalFloat (NF4) format
   - Maintains model quality despite compression

2. **Double Quantization**
   - Quantizes both weights and quantization constants
   - Further reduces memory footprint
   - Minimal impact on model performance

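3. **Paged Optimizers**
   - Pages optimizer states between GPU and CPU memory using NVIDIA unified memory
   - Absorbs the memory spikes that occur during gradient checkpointing
   - Note: the `PagedAttention` class later in this section is a separate, simplified chunked-attention illustration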

### Key Characteristics

1. **Memory Efficiency**
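   - Large memory reduction compared to full fine-tuning (the QLoRA paper reports fine-tuning a 65B-parameter model in under 48 GB of GPU memory, versus more than 780 GB for full 16-bit fine-tuning)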
   - Enables training on consumer GPUs
   - Supports larger context windows

2. **Quality Preservation**
   - Maintains model performance
   - Comparable results to full fine-tuning
   - Stable training process

3. **Accessibility**
   - Works on single GPU setups
   - Reduces hardware requirements
   - Enables broader research participation
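
To make the memory savings concrete, here is a rough back-of-the-envelope sketch. It counts only the storage of the base weights and their quantization constants, not activations, LoRA parameters, or optimizer states. The block sizes (64 and 256) and constant widths (32-bit and 8-bit) are assumptions taken from the settings described in the QLoRA paper; `constant_overhead_bits` and `weight_memory_gb` are hypothetical helper names:

```python
def constant_overhead_bits(block_size=64, const_bits=32,
                           double_quant=False,
                           second_block_size=256, second_const_bits=8):
    """Extra bits per parameter spent on quantization constants."""
    if not double_quant:
        # One `const_bits` constant per block of weights
        return const_bits / block_size
    # First-level constants stored in `second_const_bits`, plus one
    # `const_bits` constant per `second_block_size` first-level constants
    return second_const_bits / block_size + const_bits / (block_size * second_block_size)


def weight_memory_gb(n_params, bits_per_param):
    """Memory for the weights alone, in decimal gigabytes."""
    return n_params * bits_per_param / 8 / 1e9


n_params = 7e9  # Illustrative 7B-parameter model
fp16 = weight_memory_gb(n_params, 16)
nf4 = weight_memory_gb(n_params, 4 + constant_overhead_bits())
nf4_dq = weight_memory_gb(n_params, 4 + constant_overhead_bits(double_quant=True))
print(f"16-bit: {fp16:.2f} GB, 4-bit: {nf4:.2f} GB, 4-bit + double quant: {nf4_dq:.2f} GB")
```

Under these assumptions the constants cost 0.5 bits per parameter without double quantization and about 0.127 bits with it.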

## How Q-LoRA Works

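1. **NormalFloat (NF4) Quantization**:
    ```
    Q(x) = s * q_k,  with k = argmin_i |x/s - q_i|
    where:
    - x is the original value
    - s is the scaling factor (absolute maximum of the block containing x)
    - q_0 ... q_15 are the 16 fixed NF4 levels in [-1, 1], placed at normal-distribution quantiles
    - only the 4-bit index k and the per-block s are stored
    ```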

2. **Double Quantization**:
    ```
    First level: W_q = Q1(W, s1)
    Second level: s_q = Q2(s1, s2)
    where:
    - W is original weights
    - Q1, Q2 are quantization functions
    - s1, s2 are scaling factors
    ```

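3. **Gradient Computation**:
    ```
    Y = X * dequant(W_q) + α * (X * A) * B
    ∂L/∂A = α * X^T * (∂L/∂Y) * B^T
    ∂L/∂B = α * (X * A)^T * (∂L/∂Y)
    where:
    - L is loss function
    - W_q is quantized weights (frozen; dequantized on the fly and never updated)
    - A, B are the trainable LoRA matrices and α is the LoRA scaling factor
    ```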
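
The quantization steps above can be sketched in NumPy as block-wise absmax scaling followed by a nearest-level lookup. This is an illustration under stated assumptions, not the reference implementation: `quantize_blockwise` and `dequantize_blockwise` are hypothetical names, the number of weights is assumed to be a multiple of `block_size`, and the codebook here is a uniform 16-level placeholder rather than the actual NF4 levels.

```python
import numpy as np


def quantize_blockwise(w, codebook, block_size=64):
    """Return per-weight codebook indices and per-block absmax scales."""
    flat = np.asarray(w, dtype=np.float32).reshape(-1, block_size)
    absmax = np.abs(flat).max(axis=1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)       # Avoid division by zero
    normalized = flat / absmax                        # Each block now lies in [-1, 1]
    # Nearest codebook level for every weight
    idx = np.abs(normalized[..., None] - codebook).argmin(axis=-1).astype(np.uint8)
    return idx, absmax


def dequantize_blockwise(idx, absmax, codebook, shape):
    """Look up the levels and undo the per-block scaling."""
    return (codebook[idx] * absmax).reshape(shape)


rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 64)).astype(np.float32)
codebook = np.linspace(-1.0, 1.0, 16)                 # Placeholder 4-bit codebook

indices, scales = quantize_blockwise(weights, codebook)
restored = dequantize_blockwise(indices, scales, codebook, weights.shape)
print("max abs error:", np.abs(weights - restored).max())
```

Only `indices` (4 bits of information each) and `scales` (one per block) need to be stored; double quantization would additionally quantize `scales`.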

## Q-LoRA Implementation Details

### Components Overview

In [None]:
class QLoRAConfig:
    """
    Configuration class for QLoRA (Quantized Low-Rank Adaptation).
    
    QLoRA is an extension of LoRA that applies quantization (e.g., 4-bit) to 
    reduce memory usage while preserving fine-tuning efficiency. This class 
    holds hyperparameters for the quantization process.
    """
    
    def __init__(self,
                 bits=4,
                 group_size=128,
                 double_quant=True,
                 quant_type="nf4"):
        """
        Initializes the QLoRA configuration.

        Args:
            bits (int): Number of bits for quantization. Default is 4 (4-bit quantization).
            group_size (int): The size of groups for quantization. 
                              A smaller group size improves precision but increases memory usage.
            double_quant (bool): Whether to use double quantization (quantizing the quantization constants).
                                 Helps reduce memory footprint while maintaining performance.
            quant_type (str): The type of quantization format. 
                              Common choices include:
                              - "nf4" (Normal Float 4), a format designed for efficient low-bit quantization.
                              - "fp4" (Float 4), another floating-point-based 4-bit quantization.
        """
        self.bits = bits  # Number of bits for weight quantization (e.g., 4-bit)
        self.group_size = group_size  # Group size for quantization; controls trade-off between accuracy and efficiency
        self.double_quant = double_quant  # Enables double quantization (quantizing quantization constants)
        self.quant_type = quant_type  # Specifies the quantization type, e.g., "nf4" (Normal Float 4)


### Quantization Implementation

In [None]:
import numpy as np

class NF4Quantizer:
    """
    A quantizer that implements NormalFloat 4 (NF4) quantization.
    
    NF4 is a 4-bit data type whose 16 levels follow the quantiles of a
    standard normal distribution, which suits weights that are roughly
    normally distributed. Levels span the closed range [-1, +1], so inputs
    should be normalized (e.g. divided by their block's absolute maximum).
    """
    
    def __init__(self):
        """
        Initializes the NF4 quantizer with predefined quantization levels.
        
        The levels are non-uniformly spaced: dense near zero, where most
        weights lie, and sparse toward the tails.
        """
        # 16 NF4 levels as listed in the QLoRA paper, rounded to 4 decimals
        self.levels = np.array([
            -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
            0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0
        ])
    
    def quantize(self, x):
        """
        Quantizes the input tensor using NF4 quantization levels.
        
        Args:
            x (numpy array): Input tensor to be quantized.
        
        Returns:
            numpy array: Quantized tensor where each value is mapped 
                         to the nearest NF4 quantization level.
        """
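        # Clip to the representable range, then pick the nearest level per element
        # (np.digitize would return the bin below x rather than the closest level)
        x = np.clip(np.asarray(x, dtype=np.float32), self.levels[0], self.levels[-1])
        indices = np.abs(x[..., None] - self.levels).argmin(axis=-1)
        return self.levels[indices]  # Map values to the nearest quantization level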
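
To see where non-uniform levels of this kind come from, the sketch below builds a 16-level codebook from evenly spaced quantiles of a standard normal distribution and rescales it to [-1, 1]. This is a simplified construction for intuition only: it does not reproduce the exact NF4 values (for instance, it has no exact zero level). `normal_quantile_levels` and its `offset` parameter are hypothetical.

```python
from statistics import NormalDist

import numpy as np


def normal_quantile_levels(bits=4, offset=0.02):
    """Levels at evenly spaced normal quantiles, rescaled to [-1, 1]."""
    n_levels = 2 ** bits
    # Stay away from probabilities 0 and 1, where the quantile is infinite
    probs = np.linspace(offset, 1.0 - offset, n_levels)
    quantiles = np.array([NormalDist().inv_cdf(p) for p in probs])
    return quantiles / np.abs(quantiles).max()


levels = normal_quantile_levels()
print(np.round(levels, 3))
print("gap near zero:", levels[8] - levels[7])
print("gap at the tail:", levels[15] - levels[14])
```

The gap between neighbouring levels is smallest near zero and grows toward the tails, which is the property that makes such codebooks a good fit for bell-shaped weight distributions.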


## Implementing Q-LoRA in TensorFlow

### Quantized Layer Implementation

### **1. Quantization Formula**
The real-valued input \( x \) is quantized into an integer representation \( x_q \):

$$
x_q = \text{round} \left( \frac{x}{s} \times (2^b - 1) \right) - z
$$

where:
- \( x_q \) → Quantized integer value
- \( x \) → Original floating-point weight
- \( s \) → Scale factor (learned or computed)
- \( z \) → Zero-point offset
- \( b \) → Number of bits (e.g., \( b = 4 \) for 4-bit quantization)
- round → Rounds to the nearest integer

For **4-bit quantization**, the range is:

$$
2^4 - 1 = 15
$$

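Thus, simplifying:

$$
x_q = \text{round} \left( \frac{x}{s} \times 15 \right) - z
$$

Note that with the symmetric clip to \( [-1, 1] \) described in step 3, the rounded value lies in \( [-15, 15] \), i.e. 31 distinct integers, so this illustrative scheme needs slightly more than 4 bits of storage. A strict signed 4-bit code would scale by \( 2^{b-1} - 1 = 7 \) instead.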

---

### **2. Dequantization Formula**
To recover the floating-point value \( \hat{x} \) from the quantized integer \( x_q \):

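$$
\hat{x} = \frac{(x_q + z) \times s}{2^b - 1}
$$

where:
- \( \hat{x} \) → Recovered floating-point value (approximation of \( x \))
- \( x_q \) → Quantized integer value
- \( s \) → Scale factor
- \( z \) → Zero-point offset
- \( b \) → Number of bits; dividing by \( 2^b - 1 \) undoes the scaling applied during quantization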

---

### **3. Clipping (Ensuring the Range)**
Before quantization, values are clipped to ensure they stay within the allowed range:

$$
x_{\text{clipped}} = \text{clip}(x/s, -1, 1)
$$

This ensures that out-of-range values do not cause errors in quantization.
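
As a worked example of the three steps, take \( x = 0.37 \), \( s = 1 \), \( z = 0 \), \( b = 4 \): quantization gives \( x_q = \text{round}(0.37 \times 15) = \text{round}(5.55) = 6 \), and dividing by the same factor of 15 when dequantizing gives \( \hat{x} = 6 / 15 = 0.4 \). The NumPy sketch below runs the same round trip; the function names are illustrative and the default parameters are assumptions chosen to match the worked example.

```python
import numpy as np


def quantize(x, scale=1.0, zero_point=0, bits=4):
    """Clip to [-1, 1], scale to the integer grid, round, and shift."""
    levels = 2 ** bits - 1                               # 15 for 4-bit
    x_clipped = np.clip(np.asarray(x, dtype=np.float64) / scale, -1.0, 1.0)
    return np.round(x_clipped * levels).astype(np.int8) - zero_point


def dequantize(x_q, scale=1.0, zero_point=0, bits=4):
    """Invert the shift and the scaling applied by quantize()."""
    levels = 2 ** bits - 1
    return (x_q.astype(np.float64) + zero_point) * scale / levels


x = np.array([0.37, -0.8, 2.0])     # 2.0 is out of range and will be clipped to 1.0
x_q = quantize(x)
x_hat = dequantize(x_q)
print("quantized:  ", x_q)
print("dequantized:", x_hat)
```

The last input shows the effect of clipping: any value whose magnitude exceeds the scale is mapped to the largest representable level.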


In [None]:
import tensorflow as tf

class QLoRALayer(tf.keras.layers.Layer):
    """
    QLoRA Layer for applying Low-Rank Adaptation (LoRA) with quantization.

    This layer replaces a dense layer with a quantized version and 
    injects trainable low-rank adaptation matrices (LoRA) while keeping 
    the original weights frozen.
    """

    def __init__(self, 
                 original_layer,
                 rank=8,
                 alpha=32,
                 bits=4,
                 group_size=128,
                 **kwargs):
        """
        Initializes a QLoRA layer.

        Args:
            original_layer (tf.keras.layers.Layer): The dense layer being adapted.
            rank (int): The rank for LoRA decomposition (typically small, e.g., 8).
            alpha (int): Scaling factor for LoRA updates (controls adaptation strength).
            bits (int): Number of bits for quantization (e.g., 4-bit quantization).
            group_size (int): Number of weights per quantization group (affects precision).
        """
        super().__init__(**kwargs)
        
        self.original_layer = original_layer  # Store the original dense layer
        self.rank = rank  # LoRA rank for low-rank decomposition
        self.alpha = alpha  # LoRA scaling factor
        self.bits = bits  # Bit precision for quantization
        self.group_size = group_size  # Group size for quantization

        # Initialize quantization parameters
        self.quantizer = self._create_quantizer()
        
        # Get original weight shape
        self.original_shape = original_layer.get_weights()[0].shape
        
        # Store the quantized weights (low-bit values held in an int8 container)
        self.quantized_weights = self.quantize(original_layer.get_weights()[0])
        
        # Initialize LoRA matrices (learnable parameters)
        self.lora_a = self._create_lora_weights("a")  # LoRA A matrix
        self.lora_b = self._create_lora_weights("b")  # LoRA B matrix
        
        # Freeze original weights (only train LoRA parameters)
        self.original_layer.trainable = False

    def _create_quantizer(self):
        """
        Creates quantization parameters for weight quantization.

        Returns:
            dict: A dictionary containing scale and zero point variables.
        """
        return {
            'scale': tf.Variable(1.0, trainable=False),  # Scale factor for quantization
            'zero_point': tf.Variable(0.0, trainable=False)  # Zero point offset
        }
    
    def _create_lora_weights(self, name):
        """
        Initializes trainable LoRA matrices.

        Args:
            name (str): "a" or "b", indicating which LoRA matrix to create.

        Returns:
            tf.Variable: LoRA weight matrix.
        """
        if name == "a":
            shape = (self.original_shape[0], self.rank)  # A: (input_dim, rank)
            initializer = "glorot_uniform"  # Random start so gradients can flow
        else:
            shape = (self.rank, self.original_shape[1])  # B: (rank, output_dim)
            initializer = "zeros"  # Zero start so the initial LoRA update is zero
            
        return self.add_weight(
            name=f"lora_{name}",
            shape=shape,
            initializer=initializer,
            trainable=True  # Only LoRA matrices are trainable
        )
    
    def quantize(self, x):
        """
        Quantizes the input tensor using scale and zero-point.

        Args:
            x (tf.Tensor): The tensor to quantize.

        Returns:
            tf.Tensor: Quantized tensor stored in int8 format.
        """
        scale = self.quantizer['scale']
        zero_point = self.quantizer['zero_point']
        
        range_float = 2.0 ** self.bits - 1.0  # Max integer representation
        x_scaled = tf.clip_by_value(x / scale, -1.0, 1.0)  # Normalize range
        x_scaled_q = tf.round(x_scaled * range_float)  # Convert to integer
        
        return tf.cast(x_scaled_q - zero_point, dtype=tf.int8)  # Store as INT8

    def dequantize(self, x_q):
        """
        Dequantizes an int8 tensor back to floating point.

        Args:
            x_q (tf.Tensor): The quantized int8 tensor.

        Returns:
            tf.Tensor: Floating-point tensor for computation.
        """
        scale = self.quantizer['scale']
        zero_point = self.quantizer['zero_point']
        
        range_float = 2.0 ** self.bits - 1.0  # Same factor used in quantize()
        return (tf.cast(x_q, dtype=tf.float32) + zero_point) * scale / range_float


    def call(self, inputs):
        """
        Forward pass of the QLoRA layer.

        Args:
            inputs (tf.Tensor): The input tensor.

        Returns:
            tf.Tensor: Output tensor after applying quantized transformation and LoRA adaptation.
        """
        # Dequantize the stored low-bit weights back to float for the matmul
        weights = self.dequantize(self.quantized_weights)
        
        # Apply the original transformation with the dequantized weights
        # (the original layer's bias and activation are omitted in this sketch)
        original_output = tf.matmul(inputs, weights)
        
        # Compute the LoRA update
        lora_output = tf.matmul(
            tf.matmul(inputs, self.lora_a),  # First projection (low-rank A)
            self.lora_b  # Second projection (low-rank B)
        )
        
        # Combine the original output with LoRA adaptation (scaled by alpha)
        return original_output + (self.alpha * lora_output)


#### Explanation of the code

1. Class Initialization
    ```python
    def __init__(self, 
                original_layer,
                rank=8,
                alpha=32,
                bits=4,
                group_size=128,
                **kwargs):
    ```
    - **Purpose**: Initializes quantized LoRA layer
    - **Parameters**:
    - original_layer: Base layer to adapt
    - rank: LoRA rank dimension
    - alpha: Scaling factor
    - bits: Quantization precision (default 4-bit)
    - group_size: Quantization group size

2. Quantizer Creation
    ```python
    def _create_quantizer(self):
        return {
            'scale': tf.Variable(1.0, trainable=False),
            'zero_point': tf.Variable(0.0, trainable=False)
        }
    ```
    - **Purpose**: Sets up quantization parameters
    - **Components**:
    - scale: Scaling factor for quantization
    - zero_point: Offset for quantization
    - **Features**: Non-trainable variables

3. Weight Creation
    ```python
    def _create_lora_weights(self, name):
        if name == "a":
            shape = (self.original_shape[0], self.rank)
            initializer = "glorot_uniform"
        else:
            shape = (self.rank, self.original_shape[1])
            initializer = "zeros"
            
        return self.add_weight(
            name=f"lora_{name}",
            shape=shape,
            initializer=initializer,
            trainable=True
        )
    ```
    - **Purpose**: Creates LoRA matrices
    - **Features**:
    - Matrix A: input_dim × rank, randomly initialized
    - Matrix B: rank × output_dim, zero-initialized so training starts from the base model's output
    - Trainable parameters

4. Quantization Implementation
    ```python
    def quantize(self, x):
        scale = self.quantizer['scale']
        zero_point = self.quantizer['zero_point']
        
        range_float = 2.0 ** self.bits - 1.0
        x_scaled = tf.clip_by_value(x / scale, -1.0, 1.0)
        x_scaled_q = tf.round(x_scaled * range_float)
        return tf.cast(x_scaled_q - zero_point, dtype=tf.int8)
    ```
    - **Purpose**: Implements weight quantization
    - **Process**:
    1. Divide input values by the scale
    2. Clip to range [-1, 1]
    3. Multiply by 2^bits - 1 and round to the nearest integer
    4. Subtract the zero point and cast to int8 for storage

5. Forward Pass Implementation
    ```python
    def call(self, inputs):
        # Dequantize the stored low-bit weights
        weights = self.dequantize(self.quantized_weights)
        
        # Original path with dequantized weights
        original_output = tf.matmul(inputs, weights)
        
        # LoRA path
        lora_output = tf.matmul(
            tf.matmul(inputs, self.lora_a),
            self.lora_b
        )
        
        # Combine outputs
        return original_output + (self.alpha * lora_output)
    ```
    - **Purpose**: Executes forward pass
    - **Steps**:
    1. Dequantize stored weights
    2. Compute original path
    3. Compute LoRA path
    4. Combine results
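
A LoRA pair can only start learning if at least one of the two matrices is non-zero at initialization, because each matrix's gradient is proportional to the other matrix. The NumPy sketch below checks this with the analytic gradients of `y = alpha * (x @ a) @ b`; `lora_grads` is a hypothetical helper and the shapes are arbitrary illustrative values.

```python
import numpy as np


def lora_grads(x, a, b, grad_out, alpha=1.0):
    """Gradients of the loss w.r.t. A and B for y = alpha * (x @ a) @ b."""
    grad_a = alpha * x.T @ grad_out @ b.T      # Shape: (input_dim, rank)
    grad_b = alpha * (x @ a).T @ grad_out      # Shape: (rank, output_dim)
    return grad_a, grad_b


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 6))            # Batch of 5, input_dim 6
grad_out = rng.normal(size=(5, 3))     # Upstream gradient, output_dim 3
rank = 2

# Case 1: both matrices zero, so both gradients vanish and training is stuck
ga_zero, gb_zero = lora_grads(x, np.zeros((6, rank)), np.zeros((rank, 3)), grad_out)
print("both zero:     ", np.abs(ga_zero).max(), np.abs(gb_zero).max())

# Case 2: A random and B zero, so B receives a non-zero gradient
a_random = rng.normal(size=(6, rank))
ga_rand, gb_rand = lora_grads(x, a_random, np.zeros((rank, 3)), grad_out)
print("A random, B=0: ", np.abs(ga_rand).max(), np.abs(gb_rand).max())
```

In the second case the initial LoRA output is still zero, so the adapted layer starts out identical to the base layer while remaining trainable.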

### Memory Management Implementation

In [None]:
class PagedAttention:
    """
    Implements memory-efficient attention by splitting computation into manageable chunks.
    
    Despite its name, this class is a simplified query-chunked attention: it processes
    slices of the query sequentially against the full key and value tensors. It does not
    implement the block-table KV cache of vLLM's PagedAttention or QLoRA's paged optimizers.

    Attributes:
        max_memory (int): The maximum amount of memory available for computation.
        cache (dict): A cache for storing intermediate results if needed.
    """

    def __init__(self, max_memory=None):
        """
        Initializes the PagedAttention module.

        Args:
            max_memory (int, optional): The maximum memory available for processing.
                                        This is used to determine the optimal chunk size.
        """
        self.max_memory = max_memory
        self.cache = {}  # Reserved for potential caching of attention computations
    
    def compute(self, query, key, value):
        """
        Computes attention using chunked processing for memory efficiency.

        Args:
            query (tf.Tensor): The query tensor of shape (batch_size, seq_len, dim).
            key (tf.Tensor): The key tensor of shape (batch_size, seq_len, dim).
            value (tf.Tensor): The value tensor of shape (batch_size, seq_len, dim).

        Returns:
            tf.Tensor: The output tensor after applying attention, with the same shape
                       as the query (batch_size, seq_len, dim).
        """
        seq_len = query.shape[1]  # Static sequence length (eager execution assumed)

        # Determine chunk size based on available memory
        chunk_size = int(self._calculate_chunk_size(query))
        num_chunks = -(-seq_len // chunk_size)  # Ceiling division keeps the final partial chunk

        outputs = []
        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min(start_idx + chunk_size, seq_len)

            # Process the chunked portion of the sequence
            chunk_output = self._process_chunk(
                query[:, start_idx:end_idx], key, value
            )
            outputs.append(chunk_output)

        # Concatenate all processed chunks to reconstruct the full sequence output
        return tf.concat(outputs, axis=1)
    
    def _calculate_chunk_size(self, tensor):
        """
        Computes the optimal chunk size for attention computation based on memory constraints.

        Args:
            tensor (tf.Tensor): The input tensor to determine chunking strategy.

        Returns:
            int: The computed chunk size to fit within the memory limit.
        """
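        seq_len, dim = tensor.shape[1], tensor.shape[2]  # Static shape (eager execution assumed)
        if self.max_memory is None:
            return seq_len  # No limit set: process the whole sequence as one chunk
        element_size = tensor.dtype.size  # Size of each element in bytes
        memory_limit = self.max_memory // (element_size * dim)  # Rows that fit in the budget
        return max(1, min(seq_len, memory_limit))  # At least one row, at most seq_len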
    
    def _process_chunk(self, query_chunk, key, value):
        """
        Computes scaled dot-product attention for a single chunk.

        Args:
            query_chunk (tf.Tensor): A chunk of the query tensor (batch_size, chunk_size, dim).
            key (tf.Tensor): The full key tensor (batch_size, seq_len, dim).
            value (tf.Tensor): The full value tensor (batch_size, seq_len, dim).

        Returns:
            tf.Tensor: The attention-weighted value tensor for the chunk.
        """
        # Compute scaled dot-product attention scores
        scores = tf.matmul(query_chunk, key, transpose_b=True)
        scores = scores / tf.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))  # Scale by sqrt(dim)
        attention = tf.nn.softmax(scores, axis=-1)  # Apply softmax to get attention weights
        
        # Multiply attention scores with value tensor
        return tf.matmul(attention, value)


#### Explanation of the code

1. Class Initialization
    ```python
    def __init__(self, max_memory=None):
        self.max_memory = max_memory
        self.cache = {}
    ```
    - **Purpose**: Initializes paged attention system
    - **Parameters**:
    - max_memory: Memory limit for chunks
    - **Features**: 
    - Cache placeholder (reserved, currently unused)
    - Memory management

2. Main Computation Method
    ```python
    def compute(self, query, key, value):
        seq_len = query.shape[1]
        
        # Split computation into manageable chunks
        chunk_size = int(self._calculate_chunk_size(query))
        num_chunks = -(-seq_len // chunk_size)  # Ceiling division
    ```
    - **Purpose**: Manages chunked attention computation
    - **Process**:
    1. Determines chunk size
    2. Calculates number of chunks, rounding up so a final partial chunk is kept
    3. Processes each chunk separately

3. Chunk Size Calculation
    ```python
    def _calculate_chunk_size(self, tensor):
        seq_len, dim = tensor.shape[1], tensor.shape[2]
        if self.max_memory is None:
            return seq_len
        element_size = tensor.dtype.size
        memory_limit = self.max_memory // (element_size * dim)
        return max(1, min(seq_len, memory_limit))
    ```
    - **Purpose**: Determines optimal chunk size
    - **Factors**:
    - Memory limit
    - Element size
    - Tensor dimensions

4. Chunk Processing
    ```python
    def _process_chunk(self, query_chunk, key, value):
        # Compute attention for chunk
        scores = tf.matmul(query_chunk, key, transpose_b=True)
        scores = scores / tf.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))
        attention = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(attention, value)
    ```
    - **Purpose**: Processes individual attention chunks
    - **Steps**:
    1. Compute attention scores
    2. Apply scaling factor
    3. Calculate softmax
    4. Compute final values

#### Key Components:

1. Memory Management
    ```python
    chunk_size = int(self._calculate_chunk_size(query))
    num_chunks = -(-seq_len // chunk_size)  # Ceiling division
    ```
    - Manages memory usage
    - Reduces peak memory to help avoid OOM errors
    - Optimizes chunk size

2. Chunked Processing
    ```python
    outputs = []
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, seq_len)
        
        chunk_output = self._process_chunk(
            query[:, start_idx:end_idx],
            key,
            value
        )
        outputs.append(chunk_output)
    ```
    - Processes in manageable chunks
    - Maintains sequence order
    - Accumulates results

3. Attention Computation
    ```python
    scores = tf.matmul(query_chunk, key, transpose_b=True)
    scores = scores / tf.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))
    attention = tf.nn.softmax(scores, axis=-1)
    ```
    - Standard attention mechanism
    - Scaled dot-product attention
    - Memory-efficient implementation
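
Chunking over the query dimension does not change the result, because each query row's softmax is computed independently of the other rows. The NumPy sketch below checks this, including the case where the sequence length is not a multiple of the chunk size. The function names are illustrative and this is a standalone check, not the TensorFlow class above.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # Subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    """Scaled dot-product attention over (batch, seq_len, dim) arrays."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(k.shape[-1])
    return softmax(scores) @ v


def chunked_attention(q, k, v, chunk_size):
    """Process the query in slices; keys and values are used in full."""
    seq_len = q.shape[1]
    outputs = [
        attention(q[:, start:start + chunk_size], k, v)
        for start in range(0, seq_len, chunk_size)   # Last chunk may be shorter
    ]
    return np.concatenate(outputs, axis=1)


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(2, 10, 4)) for _ in range(3))
full = attention(q, k, v)
chunked = chunked_attention(q, k, v, chunk_size=4)   # 10 is not a multiple of 4
print("outputs match:", np.allclose(full, chunked))
```

Only the score matrix of one chunk, of shape (batch, chunk_size, seq_len), has to be held in memory at a time, which is where the saving comes from.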

### Model Wrapper Implementation

In [None]:
class QLoRAModelWrapper:
    """
    A wrapper for applying QLoRA (Quantized Low-Rank Adaptation) to a base model.

    QLoRA reduces memory usage by quantizing weights and introducing 
    low-rank trainable matrices, allowing efficient fine-tuning of large models.

    Attributes:
        base_model (tf.keras.Model): The original pre-trained model.
        rank (int): The rank of LoRA decomposition (low-rank factor).
        alpha (int): Scaling factor for LoRA updates.
        bits (int): Number of bits for quantization (e.g., 4-bit).
        group_size (int): The size of quantization groups.
        qlora_layers (list): A list storing all applied QLoRA layers.
    """

    def __init__(self,
                 base_model,
                 rank=8,
                 alpha=32,
                 bits=4,
                 group_size=128):
        """
        Initializes the QLoRAModelWrapper.

        Args:
            base_model (tf.keras.Model): The original model to be adapted with QLoRA.
            rank (int, optional): The rank of the LoRA decomposition. Default is 8.
            alpha (int, optional): Scaling factor for LoRA. Default is 32.
            bits (int, optional): Number of bits for weight quantization. Default is 4.
            group_size (int, optional): The size of weight groups for quantization. Default is 128.
        """
        self.base_model = base_model  # Store the original model
        self.rank = rank  # LoRA rank (controls compression level)
        self.alpha = alpha  # LoRA scaling factor
        self.bits = bits  # Bit-width for quantization
        self.group_size = group_size  # Number of grouped weights per quantization
        self.qlora_layers = []  # Store LoRA-applied layers
    
    def apply_qlora(self, layer_names=None):
        """
        Applies QLoRA modifications to specific layers in the model.

        Args:
            layer_names (list, optional): A list of layer name substrings to be adapted with QLoRA.
                                          Default applies LoRA to ['query', 'key', 'value'] layers.

        Returns:
            tf.keras.Model: A new model with QLoRA modifications applied.
        """
        if layer_names is None:
            layer_names = ['query', 'key', 'value']  # Default: Apply QLoRA to attention layers

        def replace_layer(layer):
            """
            Replaces the specified layers with their QLoRA-adapted versions.

            Args:
                layer (tf.keras.layers.Layer): The current layer in the model.

            Returns:
                tf.keras.layers.Layer: The original or modified layer.
            """
            # Check if the layer's name matches any of the target layer names
            if any(name in layer.name for name in layer_names):
                if isinstance(layer, tf.keras.layers.Dense):  # Apply only to Dense layers
                    qlora_layer = QLoRALayer(
                        layer,
                        rank=self.rank,
                        alpha=self.alpha,
                        bits=self.bits,
                        group_size=self.group_size
                    )
                    self.qlora_layers.append(qlora_layer)  # Track applied layers
                    return qlora_layer  # Replace with QLoRA layer
            return layer  # Keep the layer unchanged

        # Clone the base model while applying QLoRA modifications
        new_model = tf.keras.models.clone_model(
            self.base_model,
            clone_function=replace_layer
        )

        return new_model


#### Explanation of the code

1. Class Initialization
    ```python
    def __init__(self,
                base_model,
                rank=8,
                alpha=32,
                bits=4,
                group_size=128):
        self.base_model = base_model
        self.rank = rank
        self.alpha = alpha
        self.bits = bits
        self.group_size = group_size
        self.qlora_layers = []
    ```
    - **Purpose**: Initializes Q-LoRA wrapper
    - **Parameters**:
    - base_model: Original model to adapt
    - rank: LoRA rank dimension
    - alpha: Scaling factor
    - bits: Quantization precision
    - group_size: Quantization group size
    - **Storage**: Tracks modified layers

2. Layer Replacement Method
    ```python
    def apply_qlora(self, layer_names=None):
        if layer_names is None:
            layer_names = ['query', 'key', 'value']
    ```
    - **Purpose**: Applies Q-LoRA to specified layers
    - **Default Targets**: 
    - query layers
    - key layers
    - value layers

3. Layer Replacement Function
    ```python
    def replace_layer(layer):
        if any(name in layer.name for name in layer_names):
            if isinstance(layer, tf.keras.layers.Dense):
                qlora_layer = QLoRALayer(
                    layer,
                    rank=self.rank,
                    alpha=self.alpha,
                    bits=self.bits,
                    group_size=self.group_size
                )
                self.qlora_layers.append(qlora_layer)
                return qlora_layer
        return layer
    ```
    - **Purpose**: Handles individual layer replacement
    - **Process**:
    1. Checks layer name match
    2. Verifies layer type
    3. Creates Q-LoRA layer
    4. Tracks modifications

4. Model Modification
    ```python
    # Clone and modify model
    new_model = tf.keras.models.clone_model(
        self.base_model,
        clone_function=replace_layer
    )
    ```
    - **Purpose**: Creates adapted model
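    - **Features**:
    - Returns a new model object instead of rewiring the base model in place
    - Layers that are not replaced are returned unchanged by `replace_layer`, so they are shared with the base model rather than copied
    - Selective adaptation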

#### Key Features:

1. Configuration Management
   ```python
   self.rank = rank
   self.alpha = alpha
   self.bits = bits
   self.group_size = group_size
   ```
   - Centralized parameter storage
   - Consistent configuration
   - Easy modification

2. Layer Tracking
   ```python
   self.qlora_layers.append(qlora_layer)
   ```
   - Maintains layer registry
   - Enables monitoring
   - Facilitates management

3. Selective Adaptation
   ```python
   if any(name in layer.name for name in layer_names):
   ```
   - Targeted modifications
   - Flexible layer selection
   - Controlled adaptation
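
Because layers are selected by substring match, it is worth checking which names a given target list would actually pick up. A small sketch with made-up layer names (`select_layers` and every name in the list are hypothetical):

```python
def select_layers(layer_names, targets=("query", "key", "value")):
    """Return the layer names that contain any target substring."""
    return [name for name in layer_names if any(t in name for t in targets)]


names = [
    "embedding",
    "attn_0_query",
    "attn_0_key",
    "attn_0_value",
    "attn_0_output",
    "ffn_0_dense",
    "keyword_projection",   # Contains "key" but is not an attention projection
]
selected = select_layers(names)
print(selected)
```

The last name is matched as well, since `"key"` is a substring of `"keyword_projection"`. Stricter matching, such as comparing name suffixes, may be preferable when layer names overlap.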


### Training Configuration

In [None]:
class QLoRATrainer:
    """
    A trainer for fine-tuning models with QLoRA (Quantized Low-Rank Adaptation)
    and memory-efficient attention.

    This class uses `PagedAttention` to reduce memory consumption while training
    large models and applies gradient-based optimization.

    Attributes:
        model (tf.keras.Model): The model being fine-tuned with QLoRA.
        learning_rate (float): Learning rate for optimization.
        paged_attention (PagedAttention): A memory-efficient attention mechanism.
        optimizer (tf.keras.optimizers.Optimizer): The optimizer for training.
    """

    def __init__(self,
                 model,
                 learning_rate=1e-4,
                 max_memory=None):
        """
        Initializes the QLoRATrainer.

        Args:
            model (tf.keras.Model): The model to be fine-tuned with QLoRA.
            learning_rate (float, optional): The learning rate for optimization. Default is 1e-4.
            max_memory (int, optional): The maximum memory limit for PagedAttention. Default is None.
        """
        self.model = model  # Store the QLoRA model
        self.learning_rate = learning_rate  # Learning rate for training
        self.paged_attention = PagedAttention(max_memory)  # Memory-efficient attention
        
        # Define an optimizer (default: Adam)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)

    @tf.function
    def train_step(self, inputs, labels):
        """
        Performs a single training step using memory-efficient attention.

        Args:
            inputs (tf.Tensor): The input data (e.g., tokenized sequences).
            labels (tf.Tensor): The target labels for training.

        Returns:
            tf.Tensor: The computed loss value for this step.
        """
        with tf.GradientTape() as tape:
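            # Forward pass using memory-efficient PagedAttention.
            # NOTE: `attention_implementation` is not a standard Keras argument;
            # this assumes a custom model whose call() accepts it.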
            predictions = self.model(
                inputs,
                attention_implementation=self.paged_attention
            )
            loss = self.compute_loss(labels, predictions)  # Compute loss
        
        # Compute gradients of the loss w.r.t model parameters
        gradients = tape.gradient(loss, self.model.trainable_variables)
        
        # Apply gradients using optimizer
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        
        return loss
    
    def compute_loss(self, labels, predictions):
        """
        Computes the loss function for model training.

        Args:
            labels (tf.Tensor): The ground-truth labels.
            predictions (tf.Tensor): The model predictions.

        Returns:
            tf.Tensor: The computed loss value.
        """
        per_example = tf.keras.losses.sparse_categorical_crossentropy(labels, predictions, from_logits=True)
        return tf.reduce_mean(per_example)  # Reduce the per-example losses to a scalar


#### Explanation of the code

1. Class Initialization
    ```python
    def __init__(self,
                model,
                learning_rate=1e-4,
                max_memory=None):
        self.model = model
        self.learning_rate = learning_rate
        self.paged_attention = PagedAttention(max_memory)
    ```
    - **Purpose**: Sets up Q-LoRA training environment
    - **Parameters**:
    - model: Q-LoRA adapted model
    - learning_rate: Training rate
    - max_memory: Memory limit for attention
    - **Features**: 
    - Memory-efficient attention
    - Configurable learning rate

2. Training Step Method
    ```python
    @tf.function  # TensorFlow optimization decorator
    def train_step(self, inputs, labels):
    ```
    - **Purpose**: Executes single training iteration
    - **Optimization**: Graph mode execution
    - **Parameters**:
    - inputs: Training data
    - labels: Target values

3. Forward Pass
    ```python
    with tf.GradientTape() as tape:
        # Forward pass with memory-efficient attention
        predictions = self.model(
            inputs,
            attention_implementation=self.paged_attention
        )
        loss = self.compute_loss(labels, predictions)
    ```
    - **Purpose**: Computes model predictions
    - **Features**:
    - Gradient tracking
    - Paged attention usage
    - Loss computation

4. Gradient Computation and Application
    ```python
    # Compute gradients
    gradients = tape.gradient(
        loss,
        self.model.trainable_variables
    )

    # Apply gradients
    self.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables)
    )
    ```
    - **Purpose**: Updates model parameters
    - **Process**:
    1. Compute gradients
    2. Apply updates
    3. Return loss

#### Key Components:

1. Memory Management
   ```python
   self.paged_attention = PagedAttention(max_memory)
   ```
   - Efficient attention computation
   - Memory-aware processing
   - Controlled resource usage

2. Optimization
   ```python
   @tf.function
   def train_step(self, inputs, labels):
   ```
   - Graph compilation
   - Performance optimization
   - Efficient execution

3. Gradient Management
   ```python
   gradients = tape.gradient(
      loss,
      self.model.trainable_variables
   )
   ```
   - Automatic differentiation
   - Parameter updates
   - Training optimization
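
For reference, the loss computed in `compute_loss` can be written out by hand. The sketch below is a NumPy version of sparse categorical cross-entropy from logits, reduced to a scalar mean; it is meant as an illustration of the formula rather than a replacement for the Keras function, and the function name is hypothetical.

```python
import numpy as np


def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy over the batch, computed from raw logits."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=-1, keepdims=True)          # Stabilize exp()
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_example = -log_probs[np.arange(len(labels)), labels]       # Pick the true class
    return per_example.mean()


logits = np.zeros((3, 4))            # Uniform prediction over 4 classes
labels = np.array([0, 2, 3])
loss = sparse_categorical_crossentropy_from_logits(labels, logits)
print(loss, np.log(4))
```

With uniform logits over four classes the loss equals ln 4, which is a handy sanity check for the initial loss of an untrained classifier head.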