# Introduction to Lightning and Performance Profiling

When working with PyTorch, you often spend significant time writing boilerplate code for training loops, data handling, and device management. This repetitive work can distract you from the main goal: designing and training your model.

This is where **[Lightning](https://lightning.ai/) (formerly PyTorch Lightning)** comes in. It is a high-level framework that organizes your PyTorch code and automates the engineering, letting you focus on the research. This lab will introduce you to Lightning's structure and show you how it simplifies advanced tasks like performance tuning.

You will also be introduced to **Profiling**, a technique used to analyze your code's performance and find "bottlenecks" that slow down your training. By the end of this lab, you will have used Lightning to diagnose and fix a real performance issue, giving you a complete workflow for building more efficient models.

Specifically, you will:

* Organize standard PyTorch code into Lightning's `LightningDataModule` and `LightningModule`.
* Run a training loop to establish a baseline for speed and accuracy.
* Use the integrated Profiler to diagnose a "model complexity" bottleneck.
* Verify your fix by profiling a more efficient model and comparing the results.
* Evaluate the final trade-off between training speed and model performance.

## Imports

In [None]:
import sys
import warnings

# Redirect stderr to a black hole to catch other potential messages
class BlackHole:
    def write(self, message):
        pass
    def flush(self):
        pass
sys.stderr = BlackHole()

# Ignore Python-level UserWarnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch.profilers import PyTorchProfiler
from torch.profiler import schedule
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import datasets, transforms

import helper_utils

torch.set_float32_matmul_precision('medium')
warnings.filterwarnings("ignore", category=UserWarning)

## Defining the Data and Model with Lightning

With your environment set up, it's time to structure the core components of your project. Lightning provides an organized approach by separating data-handling from model logic. You'll define these in the next two steps.

### Step 1: Simplifying Data Loading with the `LightningDataModule`

In your previous work with PyTorch, you have seen the standard data pipeline: you define a `Dataset` and then wrap it in a `DataLoader`. This often requires you to manage separate `DataLoader` instances for your training and validation sets, which can scatter your data-handling code.

Lightning simplifies and organizes this entire process by encapsulating all data-related logic into a single, reusable class called a <code>[LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html)</code>. This class becomes the central hub for sourcing, preparing, and delivering your data, which keeps your main training script clean and focused on the model.

* Define the `CIFAR10DataModule` class by implementing these essential methods:
    * **`__init__`**: The constructor where you define specifications for your data pipeline, such as `batch_size`, `num_workers`, and data `transforms`.
    * **`prepare_data()`**: This method handles initial, one-time setup like downloading the dataset. Lightning ensures this only happens on a single process to avoid conflicts.
        * For this class, this method will check if the **CIFAR10** dataset is present and download it if not.
    * **`setup()`**: Prepares the data for use by creating your training and validation `Dataset` splits. The `stage` argument *could* be used to set up different data for different stages (e.g., `'fit'`, `'validate'`, `'test'`). In this case, however, the same training and validation splits are created for every stage, which covers everything the `'fit'` stage needs.
        * Think of the **'fit'** stage as the combined training and validation loop. The **'test'** or **'validate'** stages are similar to a standalone evaluation phase in standard PyTorch.
    * **`train_dataloader()` & `val_dataloader()`**: These methods return the familiar PyTorch `DataLoader` instances, configured with the `batch_size` and `num_workers` set in the constructor (plus shuffling for the training set).
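
If you want to see the order in which these hooks run, the framework-free sketch below may help. `ToyDataModule` and `run_fit_hooks` are hypothetical stand-ins (not part of Lightning), and the driver is a simplification: it only captures the rule that `prepare_data()` and `setup()` run before the dataloaders are requested. The real `Trainer` may request the two loaders in a different order (for example, when it runs a validation sanity check first). The `setup()` method also shows one way the `stage` argument could be used.

```python
class ToyDataModule:
    """Hypothetical stand-in that mirrors the LightningDataModule hooks."""

    def __init__(self, batch_size=4):
        self.batch_size = batch_size
        self.calls = []  # records the order in which hooks run

    def prepare_data(self):
        # One-time work such as downloading; nothing to download here.
        self.calls.append("prepare_data")

    def setup(self, stage=None):
        self.calls.append(f"setup({stage})")
        # One way `stage` could be used: only build the splits that are needed.
        if stage in ("fit", None):
            self.train_data = list(range(10))
            self.val_data = list(range(10, 14))
        if stage in ("test", None):
            self.test_data = list(range(14, 16))

    def _batches(self, data):
        # Minimal stand-in for a DataLoader: split a list into batches.
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return self._batches(self.train_data)

    def val_dataloader(self):
        self.calls.append("val_dataloader")
        return self._batches(self.val_data)


def run_fit_hooks(dm):
    """Simplified fit order: prepare_data -> setup('fit') -> dataloaders."""
    dm.prepare_data()
    dm.setup(stage="fit")
    return dm.train_dataloader(), dm.val_dataloader()


dm = ToyDataModule(batch_size=4)
train_batches, val_batches = run_fit_hooks(dm)
print(dm.calls)  # ['prepare_data', 'setup(fit)', 'train_dataloader', 'val_dataloader']
```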

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    """A LightningDataModule for the CIFAR10 dataset."""

    def __init__(self, data_dir='./data', batch_size=128, num_workers=0):
        """
        Initializes the DataModule.

        Args:
            data_dir (str): Directory to store the data.
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of subprocesses for data loading.
        """
        # Call the constructor of the parent class (LightningDataModule).
        super().__init__()
        # Store the data directory path.
        self.data_dir = data_dir
        # Store the batch size for the DataLoaders.
        self.batch_size = batch_size
        # Store the number of worker processes for data loading.
        self.num_workers = num_workers
        # Define a sequence of transformations to be applied to the images.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        """Downloads the CIFAR10 dataset if not already present."""
        
        # Download the training split of CIFAR10.
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        # Download the testing split of CIFAR10.
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """
        Assigns train/val datasets for use in dataloaders.

        Args:
            stage (str, optional): The stage of training (e.g., 'fit', 'test').
                               The Lightning Trainer requires this argument, but it is not
                               utilized in this implementation as the setup logic is the
                               same for all stages. Defaults to None.
        """
        
        # Create the training dataset instance and apply the transformations.
        self.cifar_train = datasets.CIFAR10(self.data_dir, train=True, transform=self.transform)
        # Create the validation dataset instance (using the test set) and apply transformations.
        self.cifar_val = datasets.CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        """Returns the DataLoader for the training set."""
        # The DataLoader handles batching, shuffling, and parallel data loading.
        return DataLoader(self.cifar_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Returns the DataLoader for the validation set."""
        # Shuffling is not necessary for the validation set.
        return DataLoader(self.cifar_val, batch_size=self.batch_size, num_workers=self.num_workers)

### Step 2: Structuring Your Model with the `LightningModule`

Now that you've organized your data handling, it's time to define the model itself. In standard PyTorch, you typically define your model's architecture in a class that inherits from `nn.Module`. The Lightning equivalent is the <code>[LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html)</code>. This class is where you will organize all of your model-related code, from the layer definitions to the logic for a single training step.

By separating the model's logic in the `LightningModule` from the training engine, Lightning lets you build powerful, reusable models without worrying about the complex engineering that makes them run.

* Define `CIFAR10LightningModule` class with these key methods:
    * **`__init__()`**: The constructor where you define your neural network architecture, loss function, and any metrics.
    * **`forward()`**: This method defines the forward pass of your model, just like in a standard PyTorch module.
    * **`training_step()`**: Here, you'll define the logic for a single training batch. You perform the forward pass, calculate the loss, and log metrics. You simply need to return the loss; Lightning's automation handles the backpropagation and weight updates for you.
    * **`validation_step()`**: This contains the same logic, but for your validation data.
    * **`configure_optimizers()`**: In this method, you select and return the optimizer and its hyperparameters. Lightning will use this to update your model's parameters.
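
To make "Lightning handles the backpropagation and weight updates" concrete, here is a simplified sketch, in plain PyTorch, of the per-batch work the `Trainer` performs around your `training_step()` and `configure_optimizers()`. `TinyModule`, `manual_fit`, and the synthetic regression data are made up for this illustration; Lightning's real loop additionally handles device placement, logging, validation, and callbacks.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class TinyModule(nn.Module):
    """Hypothetical module exposing Lightning-style hooks (no Lightning needed)."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 1)
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx=None):
        inputs, targets = batch
        # Like a LightningModule, only compute and return the loss.
        return self.loss_fn(self(inputs), targets)

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.1)


def manual_fit(module, batches, epochs=20):
    """Simplified version of the per-batch work a Trainer automates."""
    optimizer = module.configure_optimizers()
    module.train()
    losses = []
    for _ in range(epochs):
        for batch_idx, batch in enumerate(batches):
            optimizer.zero_grad()                          # clear old gradients
            loss = module.training_step(batch, batch_idx)  # your code
            loss.backward()                                # backpropagation
            optimizer.step()                               # weight update
            losses.append(loss.item())
    return losses


# Synthetic linear-regression data, split into 4 batches of 16 samples.
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.3
batches = [(x[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]
losses = manual_fit(TinyModule(), batches)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```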

In [None]:
class CIFAR10LightningModule(pl.LightningModule):
    """A flexible LightningModule for CIFAR10 image classification."""

    def __init__(self,
                 learning_rate=1e-3,
                 weight_decay=0.01,
                 conv_channels=(256, 512, 1024),
                 linear_features=2048,
                 num_classes=10):
        """
        Initializes the LightningModule with configurable layer parameters.

        Args:
            learning_rate: The learning rate for the optimizer.
            weight_decay: The weight decay (L2 penalty) for the optimizer.
            conv_channels: A tuple specifying the output channels for each
                           convolutional block.
            linear_features: The number of features in the hidden fully
                             connected layer.
            num_classes: The number of output classes for the classification task.
        """
        # Call the constructor of the parent class.
        super().__init__()
        # Save the hyperparameters passed to the constructor. This makes them
        # accessible via `self.hparams` and logs them automatically.
        self.save_hyperparameters()
        
        # Calculate the flattened size of the feature maps after the final
        # pooling layer. This is needed to define the input size of the
        # first fully connected layer.
        flattened_size = self.hparams.conv_channels[-1] * 4 * 4
        
        # Define the model's architecture using a sequential container.
        self.model = nn.Sequential(
            nn.Conv2d(3, self.hparams.conv_channels[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.hparams.conv_channels[0], self.hparams.conv_channels[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.hparams.conv_channels[1], self.hparams.conv_channels[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(flattened_size, self.hparams.linear_features),
            nn.ReLU(),
            nn.Linear(self.hparams.linear_features, self.hparams.num_classes)
        )
        
        # Initialize the loss function.
        self.loss_fn = nn.CrossEntropyLoss()
        
        # Initialize metrics to track accuracy for training and validation.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        """
        Defines the forward pass of the model.

        Args:
            x: The input tensor containing a batch of images.

        Returns:
            The output tensor (logits) from the model.
        """
        # Pass the input through the sequential model.
        return self.model(x)

    def training_step(self, batch, batch_idx=None):
        """
        Performs a single training step.
    
        Args:
            batch (Any): The data batch from the dataloader.
            batch_idx (int, optional): The index of the current batch. The Lightning Trainer
                                     requires this argument, but it is not utilized in this
                                     implementation. Defaults to None.
        """
        # Unpack the batch into inputs (images) and labels.
        inputs, labels = batch
        # Perform a forward pass to get the model's predictions (logits).
        outputs = self(inputs)
        # Calculate the loss.
        loss = self.loss_fn(outputs, labels)

        # Log the training loss at the end of each epoch.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        # Update the training accuracy metric with the current batch's results.
        self.train_accuracy(outputs, labels)
        # Log the training accuracy at the end of each epoch.
        self.log("train_accuracy", self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        
        # Return the loss to Lightning for backpropagation.
        return loss

    def validation_step(self, batch, batch_idx=None):
        """
        Performs a single validation step.
    
        Args:
            batch (Any): The data batch from the dataloader.
            batch_idx (int, optional): The index of the current batch. The Lightning Trainer
                                     requires this argument, but it is not utilized in this
                                     implementation. Defaults to None.
        """
        # Unpack the batch into inputs (images) and labels.
        inputs, labels = batch
        # Perform a forward pass to get the model's predictions (logits).
        outputs = self(inputs)
        # Calculate the loss.
        loss = self.loss_fn(outputs, labels)

        # Log the validation loss at the end of each epoch.
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        # Update the validation accuracy metric with the current batch's results.
        self.val_accuracy(outputs, labels)
        # Log the validation accuracy at the end of each epoch.
        self.log("val_accuracy", self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """
        Configures and returns the model's optimizer.

        Returns:
            An instance of the optimizer.
        """
        # Create and return the AdamW optimizer.
        return optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

Now you can instantiate the classes you've defined. This process is as straightforward as creating a standard PyTorch model or dataloader. The key difference is that you are creating Lightning objects that organize the familiar PyTorch logic.

* `dm_loader`: This is the `CIFAR10DataModule` that will feed data to the model. You'll configure it to use `num_workers=2`.
* `model_baseline`: This is the `CIFAR10LightningModule` that you will train and analyze.

In [None]:
# Instantiate the DataModule (2 workers).
dm_loader = CIFAR10DataModule(num_workers=2)

# Create an instance of the LightningModule.
model_baseline = CIFAR10LightningModule()

## A Quick Training Run

Now that you have defined and instantiated your `LightningDataModule` and `LightningModule`, you can see them in action.

The code below uses a helper function that leverages the Lightning `Trainer` to run a complete training loop for five epochs. Don't worry about the specifics of how the training function works; that will be covered in future material. For now, the goal is to see your Lightning components working together to train the model.

As you run the cell, **pay attention to how long the training takes and the results**. 

In [None]:
baseline_results = helper_utils.run_full_training(model_baseline, dm_loader)

print("\nTraining Complete!\n")
print("Final Training Metrics:")

print(f"\tTraining Accuracy:    {baseline_results['train_accuracy']}%")
print(f"\tValidation Accuracy:  {baseline_results['val_accuracy']}%")

## Exploring Model Complexity

Great work! You have successfully run a full training loop using the organized structure of Lightning. Now, reflect on that process. Did you notice the total training time? For a classic and relatively small dataset like **CIFAR-10**, it might have felt slower than you would hope. This experience leads you to an essential question every machine learning practitioner must ask: Why did it take that long, and can you achieve similar results more efficiently?

To find the answer, you will need to investigate the potential causes. A common culprit for slow training is the model itself. This naturally leads you to question your specific architecture and ask:

> Is This Model Too Complex for the Dataset?

To investigate this, take a closer look at your baseline model's architecture. By default, it's configured with:

* Convolutional channels: `(256, 512, 1024)`
* Linear features: `2048`

This is a deep and wide network. However, the **CIFAR-10** dataset consists of small 32x32 pixel images. An architecture this powerful might be overkill for this task.
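
A quick back-of-the-envelope count backs this up. The sketch below computes the flattened feature size and the number of trainable parameters implied by the layer definitions in `CIFAR10LightningModule` (3x3 convolutions with `padding=1`, each followed by a 2x2 max-pool). The helper names are illustrative, and the second call evaluates the streamlined configuration, `(32, 64, 128)` with `512` linear features, that you will profile later.

```python
def conv_params(c_in, c_out, k=3):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def linear_params(n_in, n_out):
    """Weight matrix plus one bias per output feature."""
    return n_in * n_out + n_out


def count_params(conv_channels, linear_features, num_classes=10, image_size=32, in_channels=3):
    """Returns (flattened_size, trainable_parameter_count) for the lab's architecture."""
    total, c_in, size = 0, in_channels, image_size
    for c_out in conv_channels:
        total += conv_params(c_in, c_out)
        # A 3x3 conv with padding=1 keeps HxW; MaxPool2d(2, 2) then halves it.
        c_in, size = c_out, size // 2
    flattened = c_in * size * size  # 32 -> 16 -> 8 -> 4, hence "* 4 * 4"
    total += linear_params(flattened, linear_features)
    total += linear_params(linear_features, num_classes)
    return flattened, total


print(count_params((256, 512, 1024), 2048))  # baseline: (16384, 39483914)
print(count_params((32, 64, 128), 512))      # streamlined: (2048, 1147466)
```

Under these assumptions the baseline has roughly 39.5 million parameters, about 85% of them in the first linear layer (16,384 inputs by 2,048 outputs), while the streamlined version has roughly 1.15 million.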

When a model is unnecessarily complex for a given dataset, it can create a **"model complexity" bottleneck**. This means the GPU spends a disproportionate amount of time on the model's own calculations, slowing down training without providing a significant accuracy benefit.

So, **how can you verify this suspicion before running a full, time-consuming training session?** You will need a tool to look inside your code's execution and see where the time is being spent.

## Profiling: Understanding Your Code's Performance

The previous section concluded that you need a tool to look inside your code's execution. That tool is a **Profiler**.

Profiling is the process of analyzing your code to get a detailed breakdown of where it spends the most time and resources, like CPU cycles, GPU time, and memory. This analysis is the key to confirming your hypothesis about the model's complexity and finding any performance **bottlenecks**, the specific parts of your code that are disproportionately slow.

With profiling, you can answer important questions about your code's efficiency:

* Is my model spending too much time on a specific operation?
* Am I using my GPU effectively, or is it sitting idle waiting for data?
* Are there any memory issues that could impact training?
* Where should I focus my optimization efforts for the biggest impact?
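
Before applying this to Lightning, it can help to see the core idea with Python's built-in `cProfile` module on a toy workload. In the sketch below, `slow_op`, `fast_op`, and `fake_training_step` are hypothetical stand-ins for expensive and cheap parts of a training step; the profiler reports how many times each ran and how much cumulative time it took. `get_stats_profile()` requires Python 3.9 or newer.

```python
import cProfile
import pstats


def slow_op():
    # Deliberately heavy pure-Python loop (stands in for an expensive layer).
    total = 0
    for i in range(200_000):
        total += i * i
    return total


def fast_op():
    # Cheap operation (stands in for a lightweight step).
    return sum(range(100))


def fake_training_step():
    slow_op()
    fast_op()


def profile_steps(num_steps=5):
    """Profiles several steps; returns {name: (ncalls, cumulative_seconds)}."""
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(num_steps):
        fake_training_step()
    profiler.disable()
    report = pstats.Stats(profiler).get_stats_profile()
    return {
        name: (fp.ncalls, fp.cumtime)
        for name, fp in report.func_profiles.items()
        if name in {"slow_op", "fast_op", "fake_training_step"}
    }


for name, (ncalls, cumtime) in profile_steps().items():
    print(f"{name:<20} calls={ncalls:<3} cumulative={cumtime:.3f}s")
```

The bottleneck stands out immediately: nearly all of the step's time is spent in `slow_op`. For a sorted text table instead, you could call `pstats.Stats(profiler).sort_stats("cumulative").print_stats(5)` on a `cProfile.Profile` object.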

You typically profile your code *after* you have a working model but *before* you commit to long, expensive training runs. It helps you find and fix performance issues early, ensuring your training is as efficient as possible.

**The Lightning Advantage**

In standard PyTorch, you would need to manually import the profiler and manage its state within the training loop. Lightning simplifies this into a single, clean step. You just need to configure the profiler, and Lightning automatically integrates it into its execution routine to capture performance data.
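
For comparison, the sketch below shows roughly what the manual approach looks like in plain PyTorch using `torch.profiler.profile`. It profiles a tiny, made-up MLP on random CPU data; the model, batch shapes, and the shortened schedule (`wait=1, warmup=1, active=2`) are assumptions chosen to keep the example fast. Note that you must call `prof.step()` yourself after every iteration, which is the kind of bookkeeping Lightning takes over.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, schedule

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
reports = []


def on_trace_ready(profiler_obj):
    # Called when an "active" window finishes; keep a short text summary.
    reports.append(profiler_obj.key_averages().table(sort_by="cpu_time_total", row_limit=5))


with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on a GPU machine
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=on_trace_ready,
) as prof:
    for step in range(6):
        inputs = torch.randn(16, 32)
        labels = torch.randint(0, 10, (16,))
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        prof.step()  # manually advance the profiler's schedule

print(reports[0])
```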

* Your first step is to specify a location where the generated profiling reports will be saved. These reports will contain all the performance data for your analysis.

In [None]:
log_dir = "./profiler_output"

### Step 1: Configure the `PyTorchProfiler`

Next, you will create an instance of the <code>[PyTorchProfiler](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.profilers.PyTorchProfiler.html)</code>. This object allows you to control exactly how the profiling is done.

* `dirpath` & `filename`: These arguments tell the profiler where to save its output file.

* `schedule`: This is a vital parameter that controls the profiler's activity to ensure you only measure stable training steps. It defines a sequence:
    * `wait`: Ignores the first few batches.
        * This step is needed to skip initial, one-time setup operations (like memory allocation) that aren't part of a normal training step. Unlike the other phases, the profiler is completely **idle** during this time.
    * `warmup`: Runs a few more batches to allow hardware to stabilize.
        * This step is needed because the profiler's own start-up overhead is high and can skew the first traced steps; it also gives hardware like GPUs a few batches to settle into a stable, steady state. Unlike the `active` phase, the profiler runs but **discards** the performance data collected here.
    * `active`: Begins recording performance data for the specified number of batches.
        * This is the main measurement phase. It's different from the other two because this is the only phase where the profiler **records and saves** performance data for you to analyze.
    * `repeat`: Specifies how many times this cycle should be executed.
* `profile_memory`: Setting this to `True` tracks memory allocation, which is excellent for identifying operations with high memory consumption.
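
The sketch below is a simplified, pure-Python re-implementation of the phase logic described above, written only to make the timeline visible; it is not the library function. The real `torch.profiler.schedule` returns `ProfilerAction` values (`NONE`, `WARMUP`, `RECORD`, and `RECORD_AND_SAVE` for the last active step) and also accepts a `skip_first` argument.

```python
def profiler_phase(step, wait, warmup, active, repeat=0):
    """Returns "idle", "warmup", or "record" for a 0-based step index.

    Simplified sketch of the wait -> warmup -> active cycle; repeat=0 means
    the cycle repeats for as long as profiling continues.
    """
    cycle = wait + warmup + active
    if repeat > 0 and step >= cycle * repeat:
        return "idle"  # all requested cycles are finished
    position = step % cycle
    if position < wait:
        return "idle"  # profiler is not tracing at all
    if position < wait + warmup:
        return "warmup"  # tracing runs, but results are discarded
    return "record"  # only these steps end up in the report


phases = [profiler_phase(s, wait=2, warmup=2, active=10, repeat=1) for s in range(16)]
print(phases.count("idle"), phases.count("warmup"), phases.count("record"))  # 4 2 10
```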

In [None]:
# Configure the PyTorch Profiler
profiler = PyTorchProfiler(
    # Set the directory to save the profiler report
    dirpath=log_dir,
    # Specify the filename for the report
    filename="profile_report",
    # Define the profiling schedule (wait -> warmup -> active)
    # Total 14 steps
    schedule=schedule(wait=2, warmup=2, active=10, repeat=1),
    # Enable memory usage profiling
    profile_memory=True
)

### Step 2: Initialize the `Trainer`

This is where all the pieces come together. You will now initialize the <code>[Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)</code>, the central engine in Lightning. It automates the entire training loop, which you will leverage here to perform a short diagnostic run. To switch from a normal training run to a profiling run, you simply pass your configured `profiler` object to the `Trainer`.

* `profiler=profiler`: This is the key step where you attach the profiler to the training process.
* `max_steps=14`: For this diagnostic run, you will limit training to just 14 steps, enough to cover the profiler's schedule (`wait=2` + `warmup=2` + `active=10`).
* `accelerator="auto"`: This parameter tells Lightning to automatically detect and use the available hardware.
* `devices=1`: This restricts the run to a single device.
* `logger=False` & `enable_model_summary=False`: You will disable the default logger and model summary output to keep the console clean and focused on the profiler's results.
* `enable_checkpointing=False`: This disables the automatic saving of model checkpoints, which is not needed for a short diagnostic run.
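
To put the 14-step diagnostic run in perspective, the quick arithmetic below uses the 50,000-image CIFAR-10 training split and the `batch_size=128` default from `CIFAR10DataModule`. It also derives `max_steps` from the schedule so the two values cannot drift apart.

```python
import math

TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training split
BATCH_SIZE = 128       # default batch_size of CIFAR10DataModule
WAIT, WARMUP, ACTIVE, REPEAT = 2, 2, 10, 1

# DataLoader keeps the last partial batch by default (drop_last=False).
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
max_steps = (WAIT + WARMUP + ACTIVE) * REPEAT
fraction_of_epoch = max_steps / steps_per_epoch

print(steps_per_epoch, max_steps, f"{fraction_of_epoch:.1%}")  # 391 14 3.6%
```

In other words, the diagnostic run covers under 4% of a single epoch, whereas a five-epoch run at this batch size is about 1,955 optimizer steps.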

In [None]:
# Initialize the Trainer
trainer = pl.Trainer(
    # Attach the configured profiler
    profiler=profiler,
    # Limit training to 14 steps to match the profiler's schedule
    max_steps=14,
    # Automatically select the hardware accelerator (e.g., GPU, CPU)
    accelerator="auto",
    # Use a single device for training
    devices=1,
    # Disable the default logger for a cleaner output
    logger=False,
    # Disable the model summary for the same reason
    enable_model_summary=False,
    # Disable automatic checkpointing
    enable_checkpointing=False
)

### Step 3: Run the Profiling Diagnostic

With the `profiler` and `Trainer` configured and your model and data objects ready, you can now start the diagnostic run.

* You'll do this by calling `.fit()` on your `Trainer` instance. This single command handles the entire training loop. Because the `Trainer` has the `profiler` attached and is limited to `max_steps=14`, this command executes a short, diagnostic profiling run instead of a full training session.
    * `model_baseline` and `dm_loader`: These are the `LightningModule` and `LightningDataModule` objects you instantiated in the previous section.

In [None]:
# Start the training and profiling run.
trainer.fit(model_baseline, dm_loader)

# Print a confirmation message when done.
print("\nProfiling Complete!\n")

### Step 4: Analyzing the Profiler Results: From Raw Data to Insights

The `trainer.fit()` command you just ran generated a raw JSON trace file. This file contains a highly detailed, event-by-event log of every operation.
The raw data is very fine-grained and is not focused on the particular types of operations you want to emphasize.

Before you look at the summarized table, it's helpful to understand the **major categories of work** the profiler is tracking. 
During any training step, the time is generally spent on a few key things:

* **Data Loading:** Fetching a batch from your dataloader and moving it to the active device (e.g., the GPU).

* **Model Computation:** The core mathematical work of your network. This includes the **forward pass** (running data through layers) and the **backward pass** (calculating gradients).

* **Optimizer Step:** Applying the calculated gradients to update your model's weights.

* **Framework Overhead:** The internal operations run by PyTorch and Lightning to coordinate the entire process.

While a deep dive is beyond the scope of this lab, the `display_profiler_logs` function in the next cell parses all the granular events and helps you see which of these categories is taking the most time. It presents a simple, sorted summary of the most expensive operations.

* Run the next cell to display the top 10 most time-consuming operations from your diagnostic run. The table is sorted by total time in descending order.
    * Feel free to change the `head` value if you wish to see more or fewer rows.
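
The internals of `display_profiler_logs` are not shown in this lab, so the following is only a hedged sketch of the general idea: group many fine-grained events by name, add up their durations, and sort by total time. The event list is synthetic and made up for illustration, not real profiler output. (In PyTorch itself, `prof.key_averages()` performs this kind of grouping.)

```python
from collections import defaultdict


def summarize_events(events, head=10):
    """Groups (name, duration_ms) events by name; returns the `head` costliest rows."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for name, duration_ms in events:
        totals[name] += duration_ms
        counts[name] += 1
    rows = [(name, counts[name], totals[name]) for name in totals]
    rows.sort(key=lambda row: row[2], reverse=True)  # most expensive first
    return rows[:head]


# Synthetic events for illustration only (not real profiler output).
events = [
    ("aten::conv2d", 4.0),
    ("aten::copy_", 0.5),
    ("aten::conv2d", 3.5),
    ("Optimizer.step#AdamW.step", 1.2),
    ("aten::copy_", 0.4),
]
for name, calls, total_ms in summarize_events(events, head=3):
    print(f"{name:<28} calls={calls}  total={total_ms:.1f} ms")
```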

In [None]:
# Display the top 10 most time-consuming operations from the profiler's report
helper_utils.display_profiler_logs(profiler, head=10)

#### Overview of the Profiler Table

A high-level overview of the profiler table reveals the major categories of operations being tracked:

- **ATen operations (`aten::...`)**: PyTorch's low-level tensor functions from the ATen C++ backend:
  - `aten::copy_`
  - `aten::to`
  - `aten::_to_copy`
- **Lightning utility**: Functions from Lightning that handle device placement:
  - `transfer_batch_to_device`
- **CUDA runtime**: GPU driver/library calls for data transfer:
  - `cudaMemcpyAsync`
- **Training logic**: High-level operations in the training loop:
  - `ProfilerStep*` (the wrapper for the whole step)
  - `Optimizer.step#AdamW.step` (the optimizer update)

#### Focusing Your Investigation on Model Computations

As you can see, the report is very detailed, mixing high-level model computations with lower-level data transfer and framework overhead. To test the hypothesis about model complexity, you need to cut through this noise and narrow your focus to the model's computations.

* Run the next cell to display a filtered table showing the total `ProfilerStep*` time (the full duration of one training iteration, including forward, backward, and optimization) alongside the four most time-consuming operations.
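
Conceptually, this kind of focused view is just a filter over the summary rows. The hypothetical `focus_on_compute` below keeps the `ProfilerStep*` row and the `top_k` most expensive rows that are not data-movement operations. The exclusion list is an assumption based on the operation names listed in the overview above, and the lab's helper may filter differently.

```python
# Assumed set of data-movement / device-placement operations to exclude,
# based on the operation names in the overview table above.
DATA_MOVEMENT_OPS = {
    "aten::copy_",
    "aten::to",
    "aten::_to_copy",
    "cudaMemcpyAsync",
    "transfer_batch_to_device",
}


def focus_on_compute(rows, top_k=4):
    """Keeps ProfilerStep* rows plus the top_k costliest non-data-movement rows.

    `rows` is an iterable of (name, total_time) pairs in any consistent unit.
    """
    rows = list(rows)
    step_rows = [row for row in rows if row[0].startswith("ProfilerStep")]
    compute_rows = sorted(
        (row for row in rows
         if not row[0].startswith("ProfilerStep") and row[0] not in DATA_MOVEMENT_OPS),
        key=lambda row: row[1],
        reverse=True,  # most expensive first
    )
    return step_rows + compute_rows[:top_k]
```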

In [None]:
# Display a focused summary of the profiler report for the baseline run.
# This filters for the overall time and the top 4 computational operations.
helper_utils.display_model_computation_logs(profiler)

<br>

**Interpreting the Baseline Results: The Bottleneck**

Looking at the filtered table you just generated, you can see exactly where the model spends most of its computational time. 
The `ProfilerStep*` entry represents the overall time for a single training step. 
Below it, you can see the most expensive mathematical operations: `aten::conv2d` (convolution) and the backward-pass operations for the model's layers.

These operations are the fundamental building blocks of your network. Their high cost in this run leads to an important question: **what potential improvements could you make?**

An architecture this large is likely more powerful than necessary for this dataset. The profiler shows that the heavy computational work of the model itself is the dominant factor slowing things down. The more complex these calculations are, the longer the GPU is occupied with each batch. 
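
A rough cost model helps explain why these operations dominate. The sketch below counts multiply-accumulate operations (MACs) for one image's forward pass through the `CIFAR10LightningModule` architecture, ignoring biases, ReLU, and pooling. MACs are only a proxy: real GPU time also depends on memory traffic and kernel efficiency, so treat the ratio as an order-of-magnitude guide, not a prediction of the measured speedup. The helper name is illustrative.

```python
def count_macs(conv_channels, linear_features, num_classes=10,
               image_size=32, in_channels=3, k=3):
    """Approximate multiply-accumulates for one image's forward pass.

    Ignores biases, ReLU, and pooling, which are cheap by comparison.
    """
    macs, c_in, size = 0, in_channels, image_size
    for c_out in conv_channels:
        # Each output value needs c_in * k * k MACs; padding=1 keeps the
        # conv output at size x size.
        macs += size * size * c_out * c_in * k * k
        c_in, size = c_out, size // 2  # MaxPool2d(2, 2) halves the spatial size
    flattened = c_in * size * size
    macs += flattened * linear_features + linear_features * num_classes
    return macs


baseline = count_macs((256, 512, 1024), 2048)
streamlined = count_macs((32, 64, 128), 512)
print(baseline, streamlined, round(baseline / streamlined, 1))  # 644632576 11375616 56.7
```

Under these assumptions the baseline needs about 645 million MACs per image, most of them in the second and third convolutions, versus about 11.4 million for the streamlined configuration.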

Next, you will profile a second, more streamlined model to measure how simplifying the architecture impacts performance.

## Profiling a More Efficient Model

The analysis in the last section suggested a **model complexity** bottleneck. Your baseline model's architecture is likely too powerful for the simple CIFAR-10 dataset, causing it to be unnecessarily slow.

To test this hypothesis, you will now profile a second, more streamlined version of the model to measure how simplifying the architecture impacts performance.

### Step 1: Configure a New Profiler

**Your Task**

* Your first step is to configure a new `PyTorchProfiler` for this second diagnostic run. The setup is identical to the baseline run, but you must provide a new `filename`. This is an important step to ensure you don't overwrite the results from your first analysis.
    * `dirpath`: Set this to the `log_dir` variable.
    * `filename`: Give it a new name, for example, `"profile_report_efficient"`.
    * `schedule`: Use the same schedule as the baseline profiler (`wait=2`, `warmup=2`, `active=10`, `repeat=1`).
    * `profile_memory`: Enable this by setting it to `True`.

In [None]:
try:
    # Configure the PyTorch Profiler for the efficient model run
    profiler_efficient = PyTorchProfiler(
        
        # Set the directory to save the profiler report
        dirpath=log_dir,
        
        # Specify a new filename for the report
        filename="profile_report_efficient",
        
        # Define the profiling schedule (wait -> warmup -> active)
        schedule=schedule(wait=2, warmup=2, active=10, repeat=1),
        
        # Enable memory usage profiling
        profile_memory=True
    )

    print("\033[92mPyTorchProfiler configured successfully!")

except Exception as e:
    print("\033[91mSomething went wrong, try again!")
    raise e

<br>
<details>
<summary><span style="color:green;"><strong>Solution (Click here to expand)</strong></span></summary>

```python
# Configure the PyTorch Profiler for the efficient model run
profiler_efficient = PyTorchProfiler(
    
    # Set the directory to save the profiler report
    dirpath=log_dir,
    
    # Specify a new filename for the report
    filename="profile_report_efficient",
    
    # Define the profiling schedule (wait -> warmup -> active)
    schedule=schedule(wait=2, warmup=2, active=10, repeat=1),
    
    # Enable memory usage profiling
    profile_memory=True
)
```

</details>

### Step 2: Configure a New Trainer

**Your Task**

* Next, you'll create a new `Trainer` instance for this run. To ensure a fair comparison with your baseline, you will use the exact same configuration as before. The only change is passing in your new `profiler_efficient` object.
    * `profiler`: Attach the `profiler_efficient` object you just created.
    * `max_steps`: Limit training to `14` steps to match the profiler's schedule.
    * `accelerator`: Set to `"auto"` to automatically select the hardware.
    * `devices`: Use a single device (1).
    * `logger`: Disable the logger by setting it to `False`.
    * `enable_model_summary`: Disable the model summary (`False`).
    * `enable_checkpointing`: Disable automatic checkpointing (`False`).
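
Because a fair comparison depends on both runs sharing the same settings, one option is to keep those settings in a single place. The sketch below is a hypothetical pattern (the names `SCHEDULE_KWARGS`, `TRAINER_KWARGS`, and `profiler_kwargs` are illustrative, not part of Lightning) that stores the shared keyword arguments as plain dictionaries, so only the report filename differs between runs.

```python
# Shared profiler schedule for every diagnostic run.
SCHEDULE_KWARGS = {"wait": 2, "warmup": 2, "active": 10, "repeat": 1}

# Shared Trainer settings; max_steps is derived from the schedule.
TRAINER_KWARGS = {
    "max_steps": (SCHEDULE_KWARGS["wait"] + SCHEDULE_KWARGS["warmup"]
                  + SCHEDULE_KWARGS["active"]) * SCHEDULE_KWARGS["repeat"],
    "accelerator": "auto",
    "devices": 1,
    "logger": False,
    "enable_model_summary": False,
    "enable_checkpointing": False,
}


def profiler_kwargs(filename, dirpath="./profiler_output"):
    """Everything except the report filename is shared between runs."""
    return {"dirpath": dirpath, "filename": filename, "profile_memory": True}


# Hypothetical usage with the objects imported earlier in this lab:
# profiler_efficient = PyTorchProfiler(**profiler_kwargs("profile_report_efficient"),
#                                      schedule=schedule(**SCHEDULE_KWARGS))
# trainer_efficient = pl.Trainer(profiler=profiler_efficient, **TRAINER_KWARGS)
```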

In [None]:
try:
    # Initialize the Trainer
    trainer_efficient = pl.Trainer(
        
        # Attach the configured profiler
        profiler=profiler_efficient,
        
        # Limit training to 14 steps to match the profiler's schedule
        max_steps=14,
        
        # Automatically select the hardware accelerator (e.g., GPU, CPU)
        accelerator="auto",
        
        # Use a single device for training
        devices=1,
        
        # Disable the default logger for a cleaner output
        logger=False,
        
        # Disable the model summary for the same reason
        enable_model_summary=False,

        # Disable automatic checkpointing
        enable_checkpointing=False
    )

    print("\033[92mTrainer configured successfully!")

except Exception as e:
    print("\033[91mSomething went wrong, try again!")
    raise e

<br>
<details>
<summary><span style="color:green;"><strong>Solution (Click here to expand)</strong></span></summary>

```python
# Initialize the Trainer
trainer_efficient = pl.Trainer( ### Add your code here
    
    # Attach the configured profiler
    profiler=profiler_efficient, ### Add your code here
    
    # Limit training to 14 steps to match the profiler's schedule
    max_steps=14, ### Add your code here
    
    # Automatically select the hardware accelerator (e.g., GPU, CPU)
    accelerator="auto", ### Add your code here
    
    # Use a single device for training
    devices=1, ### Add your code here
    
    # Disable the default logger for a cleaner output
    logger=False, ### Add your code here
    
    # Disable the model summary for the same reason
    enable_model_summary=False, ### Add your code here

    # Disable automatic checkpointing
    enable_checkpointing=False ### Add your code here
)
```

### Step 3: Profile the Efficient Model

Now, you'll instantiate your `LightningModule` again, but this time with a much simpler architecture (`conv_channels=(32, 64, 128)` and `linear_features=512`) to create the "efficient" version of the model.
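Why should fewer channels help so much? For a `k x k` convolution, both the number of weights and the number of multiply-accumulate operations (MACs) scale with the product of input and output channels. The sketch below assumes a hypothetical stack of `3 x 3` convolutions on `32 x 32` CIFAR-10 images, with the spatial size halved after each layer. The real `CIFAR10LightningModule` may be built differently, and the "large" channel tuple is only a hypothetical stand-in for the complex baseline.

```python
import torch.nn as nn

def conv_stack_cost(channels, in_channels=3, image_size=32, kernel_size=3):
    """Return (parameters, MACs per image) for a hypothetical conv stack.

    Assumes 'same' padding and a 2x2 pooling after every layer, so the
    spatial size halves from one layer to the next.
    """
    total_params, total_macs = 0, 0
    size, c_in = image_size, in_channels
    for c_out in channels:
        conv = nn.Conv2d(c_in, c_out, kernel_size, padding=kernel_size // 2)
        total_params += sum(p.numel() for p in conv.parameters())
        # Every output pixel of every output channel needs c_in * k * k MACs
        total_macs += c_out * size * size * c_in * kernel_size * kernel_size
        c_in = c_out
        size //= 2
    return total_params, total_macs

small = conv_stack_cost((32, 64, 128))      # the efficient configuration
large = conv_stack_cost((256, 512, 1024))   # hypothetical complex baseline
print(f"small: {small[0]:,} params, {small[1]:,} MACs per image")
print(f"large: {large[0]:,} params, {large[1]:,} MACs per image")
print(f"compute ratio: {large[1] / small[1]:.1f}x")
```

Nothing in the lab depends on this helper. It only illustrates why shrinking the channel counts cuts the work done in every forward and backward pass.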

In [None]:
# Create a new instance of the model with a much simpler architecture.
model_efficient = CIFAR10LightningModule(
    conv_channels=(32, 64, 128),
    linear_features=512
)

**Your Task**

Now it's time to start the second diagnostic run. You'll do this by calling the `.fit()` method on your new `trainer_efficient` instance.

You must reuse the same data module (`dm_loader`) as before. This is an important part of the experiment because it ensures that the only variable you've changed is the model's architecture.

* Call `.fit()` on the `trainer_efficient` object.
* Pass the `model_efficient` and `dm_loader` as arguments.

In [None]:
try:
    # Start the second diagnostic run with the new, streamlined model.
    ### Add your code here
    trainer_efficient.fit(model_efficient, dm_loader)
    
    print("\nProfiling Complete!\n")
    
except Exception as e:
    print("\033[91mSomething went wrong, try again!")
    raise e

<br>
<details>
<summary><span style="color:green;"><strong>Solution (Click here to expand)</strong></span></summary>

```python
# Start the second diagnostic run with the new, streamlined model.
trainer_efficient.fit(model_efficient, dm_loader) ### Add your code here

print("\nProfiling Complete!\n")
```

### Step 4: Analyzing the Efficient Profiler Results

* Run the following cell to display the new profiler report.

In [None]:
# Display the top 10 most time-consuming operations from profiler_efficient's report
helper_utils.display_profiler_logs(profiler_efficient, head=10)

#### Comparing the Results

Now for the direct comparison. The cell below will generate a summary table showing the performance of the key computational operations before and after you simplified the model's architecture.

* Run the next cell to display the comparison. Notice how much the total `ProfilerStep*` time has changed, and how the execution time of each computational kernel has changed.

In [None]:
# Generate the comparison report.
helper_utils.display_comparison_report(profiler, profiler_efficient)

<br>
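The `helper_utils` functions are specific to this lab, but you can build a similar before/after comparison with `torch.profiler` alone. The following is a minimal sketch that profiles a few forward and backward passes of two small, hypothetical CNNs on the CPU and reads the total CPU time of selected operators from `key_averages()`. The models, batch size, and step count are illustrative assumptions, not the lab's configuration.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

def make_cnn(channels):
    """Build a small, hypothetical CNN for 32x32 RGB images."""
    layers, c_in = [], 3
    for c_out in channels:
        layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        c_in = c_out
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(c_in, 10)]
    return nn.Sequential(*layers)

def profile_model(model, steps=3, batch_size=8):
    """Profile a few forward/backward passes and return the averaged events."""
    inputs = torch.randn(batch_size, 3, 32, 32)
    targets = torch.randint(0, 10, (batch_size,))
    loss_fn = nn.CrossEntropyLoss()
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        for _ in range(steps):
            loss = loss_fn(model(inputs), targets)
            loss.backward()
    return prof.key_averages()

def op_time_ms(events, op_name):
    """Total CPU time (ms) spent in one operator, or 0.0 if it was not recorded."""
    for evt in events:
        if evt.key == op_name:
            return evt.cpu_time_total / 1000.0  # the profiler reports microseconds
    return 0.0

baseline_events = profile_model(make_cnn((64, 128, 256)))
efficient_events = profile_model(make_cnn((32, 64, 128)))

# Top operators for one run, similar in spirit to the lab's report
print(efficient_events.table(sort_by="cpu_time_total", row_limit=10))

for op in ("aten::conv2d", "aten::convolution_backward"):
    print(f"{op}: {op_time_ms(baseline_events, op):.2f} ms -> "
          f"{op_time_ms(efficient_events, op):.2f} ms")
```

Operator names can vary slightly between PyTorch versions and backends, which is why `op_time_ms` falls back to `0.0` for an operator it cannot find.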

#### Analyzing the Improvement: The Impact of Simplicity

The comparison table should show a significant performance gain from simplifying the model's architecture. The most important change is the large reduction in overall `ProfilerStep*` time: the simpler model completes each training step much faster. This speed-up stems from shrinking the core operations, the convolutions (`aten::conv2d`) and their backward passes (`aten::convolution_backward`), which now operate on far fewer channels.

This demonstrates a key lesson in model optimization: **your architecture should match your dataset's complexity**. For a simple dataset like CIFAR-10, a smaller model is more efficient, allowing you to train faster.

#### A Note on the Results

You might notice that while the overall `ProfilerStep*` time dropped significantly, the times reported for individual operators like `aten::conv2d` or `aten::convolution_backward` changed only slightly, and in some cases may even have increased. This is not an error. The most likely explanation lies in how the profiler records work that runs on a GPU.

* **Per-operator rows mostly measure dispatch**: On a GPU, PyTorch launches CUDA kernels asynchronously, so the CPU time recorded for an operator such as `aten::conv2d` largely covers preparing and launching its kernels. That cost is similar whether the layer has `1024` channels (the complex model) or `128` channels (the simple one), which is why these rows can barely move.

* **The step total captures the actual computation**: The GPU work itself, which grows with the number of channels, becomes visible when the CPU has to wait for the GPU to catch up, for example at a synchronization point or when the queue of pending kernels is full. That waiting is included in `ProfilerStep*`, so the much smaller tensors of the simple model translate into a much faster step even when the per-operator rows change little.

Essentially, you have removed the **model complexity bottleneck**, allowing the entire pipeline to run much more smoothly.
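One reason per-operator timings can be misleading on a GPU is that PyTorch launches CUDA kernels asynchronously: a call such as `torch.matmul` can return before the GPU has finished the work. The sketch below, which is illustrative and not part of the lab, times the same operation with and without `torch.cuda.synchronize()`. On a GPU, the unsynchronized number typically reflects little more than the kernel launch; on a CPU-only machine the two numbers should be similar, because CPU execution is synchronous.

```python
import time
import torch

def time_matmul(size=1024, synchronize=True, device=None):
    """Time one matrix multiplication and return (milliseconds, result shape).

    With synchronize=False on a GPU, the timer may stop as soon as the
    kernel is queued rather than when it finishes.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)
    torch.matmul(a, b)  # warm-up call so one-time setup costs are not timed
    if device == "cuda":
        torch.cuda.synchronize()

    start = time.perf_counter()
    result = torch.matmul(a, b)
    if synchronize and device == "cuda":
        torch.cuda.synchronize()  # wait until the GPU has actually finished
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    return elapsed_ms, result.shape

launch_only, _ = time_matmul(synchronize=False)
full, _ = time_matmul(synchronize=True)
print(f"without synchronize: {launch_only:.3f} ms")
print(f"with synchronize:    {full:.3f} ms")
```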

#### Beyond Model Complexity

It's important to remember that this is just one type of performance issue the profiler can help diagnose. In other scenarios, you might use it to uncover **data loading bottlenecks** (where the GPU is idle waiting for data from the CPU), inefficient memory usage, or other parts of the pipeline that slow down your training. The key is to use the profiler as a versatile tool to form and test hypotheses about any aspect of your code's performance.
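A quick first check for a data loading bottleneck is to time the `DataLoader` on its own, with no model involved. If producing a batch takes about as long as a full training step, the accelerator is probably waiting on data. The sketch below uses a synthetic `TensorDataset` as a stand-in for CIFAR-10 and a hypothetical `measure_loader_throughput` helper; in the lab, you could try the same measurement on the loader returned by your data module's `train_dataloader()` method.

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def measure_loader_throughput(loader, max_batches=50):
    """Iterate over a DataLoader without a model and return batches per second."""
    start = time.perf_counter()
    n_batches = 0
    for n_batches, _batch in enumerate(loader, start=1):
        if n_batches >= max_batches:
            break
    elapsed = time.perf_counter() - start
    return n_batches / elapsed if elapsed > 0 else float("inf")

# Synthetic stand-in for CIFAR-10: 2,000 random 3x32x32 images with labels
images = torch.randn(2000, 3, 32, 32)
labels = torch.randint(0, 10, (2000,))
dataset = TensorDataset(images, labels)

# num_workers=0 loads in the main process; try larger values and compare
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
print(f"{measure_loader_throughput(loader):.1f} batches/s")
```

Comparing this rate against the step rate from the profiler tells you which side of the pipeline is slower. Raising `num_workers` usually helps when decoding and augmentation dominate, but on some platforms worker processes require the usual `if __name__ == "__main__":` guard in scripts.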

## Training the Efficient Model

Great work! Your analysis and changes have paid off. The profiler has confirmed that the simplified model is significantly faster per step. However, speed is only half of the story.

The final, essential question is: **does this efficiency come at the cost of performance?** A faster model is only useful if it can still achieve good accuracy. To answer this, you will now run a full training loop on `model_efficient` and inspect its final metrics.

In [None]:
efficient_results = helper_utils.run_full_training(model_efficient, dm_loader)

print("\nTraining Complete!\n")
print("Final Training Metrics:")

print(f"\tTraining Accuracy:    {efficient_results['train_accuracy']}%")
print(f"\tValidation Accuracy:  {efficient_results['val_accuracy']}%")

* Finally, display a summary table to compare the final metrics of both models side-by-side, making it easy to see the trade-offs.

In [None]:
helper_utils.display_metrics_comparison(baseline_results, efficient_results)

<br>

Now, analyze your results in the table above.

Compare the validation metrics for your efficient model to the baseline. You'll likely find that the performance is very close. It's common for a well-sized, efficient model to perform comparably to, and sometimes even slightly better than, an overly complex one on the same dataset.

The key takeaway is to evaluate the trade-off. In this case, you have likely achieved a significant improvement in training speed without a major compromise in your model's ability to generalize. With more hyperparameter tuning or by training for more epochs, you could likely improve these results even further.
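To put the trade-off into numbers, you can compute a speed-up factor alongside the change in validation accuracy. This is a minimal sketch with a hypothetical `summarize_tradeoff` helper. It assumes you have a per-step time for each model (for example, the `ProfilerStep*` totals from the two reports) and that both results dictionaries expose a `val_accuracy` key, as `efficient_results` does above.

```python
def summarize_tradeoff(baseline_step_ms, efficient_step_ms,
                       baseline_val_acc, efficient_val_acc):
    """Return (speed-up factor, change in validation accuracy in points)."""
    if efficient_step_ms <= 0:
        raise ValueError("efficient_step_ms must be positive")
    speedup = baseline_step_ms / efficient_step_ms
    accuracy_change = efficient_val_acc - baseline_val_acc
    return speedup, accuracy_change
```

A call might look like `summarize_tradeoff(t_base, t_eff, baseline_results['val_accuracy'], efficient_results['val_accuracy'])`, where `t_base` and `t_eff` are step times you read from the reports. A speed-up well above `1` combined with an accuracy change close to zero is the outcome this lab is aiming for.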

## Conclusion

Congratulations on completing the lab! You've successfully navigated a complete workflow for diagnosing and improving a model's performance using **Lightning**.

You've seen firsthand how Lightning's structure, through the `LightningDataModule` and `LightningModule`, removes boilerplate and organizes your project. This clean structure makes it much simpler to integrate powerful tools like the profiler.

More importantly, you've learned a practical, data-driven approach to optimization. You started by establishing a baseline, used the **profiler** to form a hypothesis about a performance bottleneck, tested that hypothesis with a second profiling run, and finally verified that your more efficient model performs comparably to the baseline.

For those interested in exploring the full, granular details, you can download the `.json` trace files from the `./profiler_output/` directory and upload them to a trace viewer such as the [Perfetto UI](https://ui.perfetto.dev). These traces contain far more detail than the summary table.
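If you prefer to inspect the traces programmatically, note that they use the Chrome trace event format: a JSON object whose `traceEvents` list holds events with a `name` and, for complete events (`"ph": "X"`), a duration `dur` in microseconds. Here is a minimal sketch, assuming the files sit in `./profiler_output/` and end in `.json`; the exact file names Lightning writes may vary.

```python
import glob
import json
from collections import defaultdict

def top_ops_from_trace(path, top_n=10):
    """Sum event durations (microseconds) by name in one Chrome-format trace."""
    with open(path) as f:
        trace = json.load(f)
    totals = defaultdict(float)
    for event in trace.get("traceEvents", []):
        # Only complete events ("X") carry a duration; skip metadata and instants
        if event.get("ph") == "X" and "dur" in event:
            totals[event.get("name", "<unnamed>")] += float(event["dur"])
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)[:top_n]

for trace_path in sorted(glob.glob("./profiler_output/*.json")):
    print(trace_path)
    for name, total_us in top_ops_from_trace(trace_path):
        print(f"  {total_us / 1000:10.2f} ms  {name}")
```

Nested events overlap in time (an `aten::conv2d` event contains the events it calls), so these totals are inclusive and should not be added together.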

Mastering these profiling techniques is an essential skill for building efficient models and scaling up to larger models and datasets.