# Chapter 1: Introduction

Learn about the promise and challenges of deep learning in biology. You will be walked through practical questions to consider before launching a new project—like what your model could replace, whether deep learning is even necessary, and how to structure your workflow. This chapter also includes a short technical introduction covering JAX/Flax, Python patterns common in machine learning, working environments, and practical setup tips.

---


## Using Code Examples

Supplemental material (code examples, exercises, etc.) is available for download at https://github.com/deep-learning-for-biology.

---

## Getting Started
Before jumping into code, we walk through how to frame a project, evaluate your data, and avoid common pitfalls. A bit of structure and planning up front will make your work more reproducible, more flexible, and ultimately more useful and impactful.

---

### Deciding What Your Model Will Replace

This section from the introductory chapter focuses on the strategic framing of biological deep learning projects, arguing that defining the "real-world" target of a model is more critical than the initial choice of architecture.

#### Key Summary Points

* **Avoid the "Tinker Trap":** Deep learning in biology is intellectually stimulating, which often leads researchers to spend excessive time on technical minutiae. To remain focused, one must identify the existing process the model is intended to replace or improve.
* **Domain-Specific Impact Areas:**
    * **Healthcare & Drug Discovery:** Models aim to replace slow or manual tasks such as dermatological diagnosis, culture-based pathogen detection, manual MRI tumor segmentation, and exhaustive wet-lab screening for drug-target interactions.
    * **Molecular Biology:** Computational tools like *AlphaFold* provide 3D protein structures that would otherwise require months of expensive X-ray crystallography or Cryo-EM. Other models act as digital alternatives to RNA-seq (gene expression) or manual variant interpretation.
    * **Ecology:** AI replaces labor-intensive field work, such as in-person biodiversity surveys (via acoustics) or manual crop scouting (via satellite/drone imagery), and offers non-invasive alternatives to physical animal tagging.
* **Quantifying Success:** Researchers should estimate the potential impact in terms of time, cost, or labor.
* **Innovation vs. Replacement:** Not all models replace old workflows. Some enable entirely new capabilities, such as generating *de novo* biological sequences or linking disparate data types that were previously incompatible. In these cases, success must be evaluated without established benchmarks.

---

### Determining Your Criteria for Success

Define success metrics early to avoid endless, unfocused experimentation and wasted time in deep learning projects.

**Five Types of Success Criteria:**

1. **Performance Metrics**
   - **Examples:** accuracy, AUC, F1 score
   - Goals may include matching human expert performance, achieving experimental correlation, or maintaining low false-positive rates

2. **Interpretability Requirements**
   - Focus on explainability and transparency of model decisions
   - Important for domain expert trust, calibrated uncertainty estimates, and understandable feature attributions

3. **Model Size and Inference Efficiency**
   - Critical for resource-constrained environments (smartphones, embedded devices)
   - Metrics include inference time, memory usage, energy consumption, and performance per FLOP (floating point operation)
   - May prioritize efficiency over raw accuracy for real-time applications

4. **Training Efficiency**
   - Relevant when compute resources are limited or in educational settings
   - May focus on CPU-compatible models rather than GPU-dependent ones
   - Prioritizes fast training and minimal hardware requirements

5. **Generalizability**
   - Aims for models that work across multiple datasets or tasks
   - Relevant for foundational models designed for broad applicability
   - Values flexibility and reusability over single-task optimization

**Key Takeaway:**
Establishing clear success criteria upfront helps determine when a project is complete and ensures efforts remain focused and realistic while balancing multiple objectives.
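
To make the first criterion concrete, several of the performance metrics listed above can be computed in a few lines of NumPy. This is a minimal sketch assuming binary 0/1 labels; the function name `binary_classification_metrics` and the toy labels are illustrative and do not come from the book.

```python
import numpy as np


def binary_classification_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Compute accuracy, precision, recall, F1, and false-positive rate for 0/1 labels."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)

    # Confusion-matrix counts
    tp = int(np.sum(y_true & y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    tn = int(np.sum(~y_true & ~y_pred))

    # Guard against division by zero when a class is never predicted or absent
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    false_positive_rate = fp / (fp + tn) if (fp + tn) else 0.0

    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "false_positive_rate": false_positive_rate,
    }


# Hypothetical toy labels, for illustration only
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])
metrics = binary_classification_metrics(y_true, y_pred)
print(metrics)
```

Writing the target down as a number (for example, "false-positive rate below 5%") turns a vague goal into a stopping condition.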

---


### Invest Heavily in Evaluations

Evaluation strategy should be a top priority from the start, not an afterthought. It guides the entire project and determines whether your work produces meaningful results.

**What Strong Evaluation Involves:**
- Defining precise measurement methods and metrics
- Establishing validation procedures
- Selecting appropriate baselines for comparison
- Creating a well-designed evaluation strategy before building models

**Benefits of Strong Evaluations:**
- Measure progress accurately
- Detect bugs in models or pipelines
- Estimate task difficulty
- Build intuition about the problem
- Provide a known point of comparison to assess if the model is learning meaningfully

**Recommended Time Allocation:**
A rough guideline for successful machine learning projects:
- **50%** - Designing evaluation strategies and running baselines
- **25%** - Curating or processing data
- **25%** - Model architecture development

**Critical Warning:**
Without good evaluations, you operate blindly — unable to determine if your model is improving, understand trade-offs, or verify meaningful learning is occurring.

**Key Principle:**
Evaluation is not an end-stage activity. It should be designed at the beginning and used to guide decisions throughout the entire project lifecycle.

---

### Designing Baselines

This section explains the importance of **baselines** as practical evaluation tools in machine learning — simple methods that establish minimum performance thresholds to compare against more complex models.

#### Purpose of Baselines
- Measure progress and understand task difficulty
- Catch bugs early in model development
- Sometimes surprisingly competitive with complex models
- Signal when something is wrong if models can't beat them

#### Classification Baselines

1. **Random prediction**: Equal probability for all classes (zero information baseline)

2. **Weighted random prediction**: Sample proportional to class frequencies in training data (useful for imbalanced datasets)

3. **Majority class**: Always predict most common class (strong baseline for highly imbalanced problems)

4. **Nearest neighbor**: Predict label of most similar training example (effective for low-dimensional or structured data)
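
The first three classification baselines take only a few lines each. The sketch below assumes a hypothetical, heavily imbalanced toy label set (90% class 0); the variable names are illustrative, and nothing here is taken from the book's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical imbalanced labels: 90% class 0, 10% class 1
y_train = np.array([0] * 90 + [1] * 10)
y_test = np.array([0] * 45 + [1] * 5)

classes, counts = np.unique(y_train, return_counts=True)


def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean(y_true == y_pred))


# 1. Random prediction: every class is equally likely
uniform_pred = rng.choice(classes, size=y_test.shape[0])

# 2. Weighted random prediction: sample in proportion to training frequencies
weighted_pred = rng.choice(classes, size=y_test.shape[0], p=counts / counts.sum())

# 3. Majority class: always predict the most common training label
majority_class = classes[np.argmax(counts)]
majority_pred = np.full_like(y_test, majority_class)

uniform_accuracy = accuracy(y_test, uniform_pred)
weighted_accuracy = accuracy(y_test, weighted_pred)
majority_accuracy = accuracy(y_test, majority_pred)
print(uniform_accuracy, weighted_accuracy, majority_accuracy)
```

On data this imbalanced, the majority-class baseline already reaches 90% accuracy, which is why accuracy alone can make a weak model look strong.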

#### Regression Baselines

1. **Mean/median prediction**: Always predict training set average or median

2. **Single-feature linear regression**: Fit line using strongest individual predictor (tests incremental value of complexity)

3. **K-nearest neighbor regression**: Average target values of k most similar examples
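
The three regression baselines can be sketched with NumPy alone. The data below is synthetic and hypothetical (the target depends on feature 0 plus noise), and the helper names `mse` and `knn_predict` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: the target is a noisy linear function of feature 0
X_train = rng.normal(size=(200, 3))
y_train = 2.0 * X_train[:, 0] + 0.1 * rng.normal(size=200)
X_test = rng.normal(size=(50, 3))
y_test = 2.0 * X_test[:, 0] + 0.1 * rng.normal(size=50)


def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean((y_true - y_pred) ** 2))


# 1. Mean prediction: always predict the training-set average
mean_pred = np.full(y_test.shape, y_train.mean())

# 2. Single-feature linear regression on the feature most correlated with the target
correlations = [
    abs(np.corrcoef(X_train[:, j], y_train)[0, 1]) for j in range(X_train.shape[1])
]
best_feature = int(np.argmax(correlations))
slope, intercept = np.polyfit(X_train[:, best_feature], y_train, deg=1)
linear_pred = slope * X_test[:, best_feature] + intercept


# 3. k-nearest-neighbor regression: average the targets of the k closest training points
def knn_predict(X_ref: np.ndarray, y_ref: np.ndarray, X_query: np.ndarray, k: int = 5) -> np.ndarray:
    # Pairwise Euclidean distances, shape (n_query, n_ref)
    distances = np.linalg.norm(X_query[:, None, :] - X_ref[None, :, :], axis=-1)
    nearest = np.argsort(distances, axis=1)[:, :k]
    return y_ref[nearest].mean(axis=1)


knn_pred = knn_predict(X_train, y_train, X_test)

mean_mse = mse(y_test, mean_pred)
linear_mse = mse(y_test, linear_pred)
knn_mse = mse(y_test, knn_pred)
print(mean_mse, linear_mse, knn_mse)
```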

#### Domain-Specific Heuristics

- Apply simple rules based on domain knowledge
- Examples:
  - **Diagnostics:** threshold-based classification on biomarkers
  - **Medical imaging:** rank by average pixel intensity
  - **Genomics:** assign mutations to nearest gene

#### Key Takeaway
If your model can't beat basic baselines, investigate your data, features, or modeling approach before adding complexity.

---

### Time-Boxing Your Project

Time-boxing is the practice of setting a fixed, non-negotiable timeframe for a project or specific task. In deep learning research—where projects can become open-ended and "failed" experiments are common—this strategy ensures that even unsuccessful ideas provide value without draining unlimited resources.

#### Strategies for Effective Time-Boxing

* **Establish a Rigid Deadline:** Determine a realistic total duration for the project (e.g., two weeks or three months). The project should pause or stop once this limit is reached, regardless of whether the target metrics were achieved.
* **Define Clear Checkpoints:** Break the timeline into intermediate milestones to monitor progress. Key checkpoints might include:
    * Completion of data preprocessing.
    * Training and evaluation of a baseline model.
    * Reaching a specific performance threshold.
* **Micro Time-Boxing:** Apply the same principle to specific sub-tasks or experimental ideas. For example, allocate exactly one week to test a new model architecture; if it does not show improvement within that window, abandon it and move on.
* **Structured Reflection:** Use the end of the time-box to evaluate outcomes. Focus on what was learned and what technical insights can be applied to future work, transforming a "failed" project into a stepping stone.
* **Mitigate Scope Creep:** Guard against the urge to justify extensions or "one more tweak." When perfectionism or indecision stalls progress, consult with a mentor or collaborator to regain perspective and maintain focus on the broader goals.

#### Summary

Time-boxing is a tool for maintaining focus and avoiding burnout. It forces a decision-making point where you must evaluate the project's viability, ensuring that your energy is always directed toward the most promising research avenues.

---

### Deciding Whether You Really Need Deep Learning

While deep learning is a powerful tool in the biological sciences, it is not always the optimal solution. This section emphasizes the importance of evaluating whether a simpler, traditional approach can meet your project's goals more efficiently.

#### Key Considerations for Choosing Your Approach

* **Evaluate Simpler Alternatives:** Before committing to a deep learning architecture, consider if linear regression, decision trees, or basic statistical techniques are sufficient.
* **Implementation and Setup:** Traditional methods are generally quicker to implement, easier to set up, and require less specialized expertise to maintain.
* **Computational Efficiency:** Simpler models are far less resource-intensive. They can often run on standard hardware (CPUs) with minimal training time, whereas deep learning typically requires expensive GPU resources.
* **Interpretability and Debugging:** Deep learning models are notoriously "black boxes" and difficult to troubleshoot. Simpler methods are often easier to explain to stakeholders, troubleshoot for errors, and validate against biological ground truth.
* **Weighted Trade-offs:** The smarter path is often the one that delivers the required performance with the least amount of complexity. If a traditional method provides the necessary insights, the overhead of deep learning may not be justified.

#### Summary

The decision to use deep learning should be based on necessity rather than novelty. Prioritizing simplicity when possible leads to more robust, interpretable, and cost-effective biological research.

---

### Ensuring That You Have Enough Good Data

In the context of biological deep learning, where data acquisition can be expensive and prone to technical noise, the mantra of "garbage in, garbage out" is particularly relevant. This section highlights that the sophistication of your model cannot compensate for poor underlying data.

#### Critical Data Requirements

* **Sufficient Quantity:** Deep learning models generally require thousands of labeled examples to generalize effectively.
    * **Benchmarking:** Consult existing literature to determine the standard dataset size for your specific biological task.
    * **Transfer Learning:** If your dataset is small (e.g., a rare disease cohort), use transfer learning. Start with a model pre-trained on a massive, related dataset (like ImageNet for microscopy or UniProt for protein sequences) and fine-tune it on your specific data.
* **Sufficient Quality:** The reliability of your model is capped by the cleanliness and consistency of your data.
    * **Error Impact:** Inconsistent labeling or high levels of experimental noise can cause models to fail catastrophically.
    * **Curation:** High-quality, curated data is often more valuable than a larger volume of "noisy" data. Prioritizing rigorous quality control (QC) and thoughtful curation is essential for building trustworthy models.

#### Summary

Success in deep learning is a balance between scale and precision. While you need enough data to capture biological variance, that data must be clean enough for the model to learn meaningful patterns rather than experimental artifacts.

---

### Assembling a Team

Collaborating effectively is a catalyst for success in biological deep learning, where the complexity of the data often requires a blend of computational and experimental expertise.

#### Strategies for Finding and Building a Team

* **Engage with Digital Communities:** Use platforms like Reddit, Discord, X, and specialized Slack groups to share ideas and meet potential partners.
* **Participate in Structured Challenges:** Join hackathons or competitions on platforms like Kaggle or Zindi to meet people with shared interests and receive immediate feedback.
* **Prioritize Interdisciplinary Diversity:** Aim for a "cross-pollination" of skills. Biologists should seek out machine learning experts, and vice versa, to ensure the model is both mathematically sound and biologically relevant.
* **Consult Domain Experts:** Reach out to authors of relevant papers or attendees at conferences. Genuine interest in a specific biological problem often leads to successful "cold" outreach and expert guidance.

#### Best Practices for Effective Collaboration

* **Establish Clear Governance:** Define specific roles, responsibilities, and decision-making processes early to prevent misunderstandings and scope creep.
* **Utilize a Shared Tech Stack:** Implement collaborative tools such as:
    * **Version Control:** Git for code management.
    * **Shared Environments:** Google Colab for interactive modeling.
    * **Task Tracking:** Notion, Trello, or simple shared documents to organize workflows.
* **Encourage Specialization:** Allow team members to focus on their strengths, whether that is data engineering, infrastructure, modeling, or biological interpretation.
* **Pilot the Partnership:** Start with a small, low-pressure "sprint" or exploration to test compatibility before committing to a long-term research project.

#### Summary

While solo research is possible, interdisciplinary teams often produce more robust and innovative results. By combining deep domain knowledge with technical ML expertise and using structured communication tools, you can significantly accelerate the "Get Started" phase of your project.

---

### You Don't Need a Supercomputer or a PhD

It is a common misconception that deep learning in biology is reserved for those with elite credentials or massive infrastructure. In reality, the field is increasingly accessible to anyone with curiosity and a laptop.

#### Challenging Common Misconceptions

* **The "Huge Compute" Myth:** You do not need a supercomputer to make a meaningful impact.
    * **Iterative Prototyping:** Start with small, lightweight models to test ideas quickly before scaling up.
    * **Accessible Hardware:** Utilize free GPU resources from platforms like **Google Colab** or **Kaggle**. For larger tasks, scalable cloud instances (AWS, GCP, Azure) allow you to pay only for what you use.
    * **Analysis over Training:** Significant research involves analyzing or fine-tuning existing models rather than training them from scratch, which requires much less computational power.
* **The "Expert-Only" Myth:** You do not need a PhD in both ML and Biology to contribute.
    * **Modern Tooling:** High-level frameworks (like PyTorch or JAX) have lowered the barrier to entry for building complex architectures.
    * **Open Source Ecosystem:** Leverage pre-trained models and open-source codebases to build upon the work of others.
    * **Abundant Learning Resources:** Tutorials, walkthroughs, and videos offer accessible pathways to mastering the necessary concepts outside of traditional academia.
    * **Uncharted Problems:** Many biological questions have yet to be approached with a machine learning lens, leaving plenty of room for newcomers to find niche areas of discovery.

#### Summary

The barrier to entry for biological deep learning is lower than it has ever been. By starting small, utilizing free resources, and leveraging the open-source community, you can contribute to the field regardless of your current budget or formal title.

---

## Technical Introduction

This section introduces the specific software ecosystem used in the book—**JAX** and **Flax**—and explains the rationale for choosing these tools for biological deep learning projects.

### The JAX and Flax Ecosystem

* **JAX:** A system for high-performance numerical computing that transforms Python and NumPy code into optimized machine code for accelerators (GPUs/TPUs).
* **Flax:** A flexible neural network library designed specifically to run on top of JAX.
* **`dlfb` (Deep Learning for Biology):** A custom companion library provided with the book to handle common utilities and repetitive tasks (https://github.com/deep-learning-for-biology/dlfb.git).

---

### Why Use JAX and Flax for Biology?

* **Familiarity:** JAX uses the `jax.numpy` (`jnp`) API, which is almost identical to standard NumPy, making the transition seamless for those already doing scientific computing in Python.
* **Functional Clarity:** JAX follows a "pure function" style. This explicit approach reduces hidden states, making the underlying math of biological models easier to understand and debug.
* **First-Class Transformations:** JAX offers powerful, composable tools:
    * `jit`: Just-In-Time compilation via the XLA (Accelerated Linear Algebra) compiler for speed.
    * `grad`: Automatic differentiation for calculating gradients.
    * `vmap`: Automatic vectorization to handle batches of data (like thousands of protein sequences) without manual loops.
* **Research Alignment:** JAX is widely used in modern "AI for Science" research; AlphaFold 2, for example, was implemented in JAX.
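
The three transformations compose freely. The sketch below is an illustration rather than code from the book: it differentiates a toy function with `grad`, vectorizes the gradient over a batch with `vmap`, and compiles the result with `jit`. The function `squared_norm` is a hypothetical stand-in for a real model.

```python
import jax
import jax.numpy as jnp


def squared_norm(x: jax.Array) -> jax.Array:
    """Toy scalar function; its gradient is 2 * x."""
    return jnp.sum(x ** 2)


# grad: derivative of the scalar output with respect to the input
grad_fn = jax.grad(squared_norm)

# vmap maps the gradient over a leading batch axis; jit compiles the result
batched_grad_fn = jax.jit(jax.vmap(grad_fn))

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2 * x])  # shape (2, 3)

single = grad_fn(x)
batched = batched_grad_fn(batch)
print(single)
print(batched)
```

Note that `squared_norm` is written for a single example; `vmap` supplies the batch dimension without any change to the function body.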

#### Trade-offs and Considerations

* **Learning Curve:** JAX requires a shift toward functional programming, which may feel different from the object-oriented approach of PyTorch.
* **Ecosystem Size:** The JAX community is smaller than PyTorch's, and APIs (like the shift from Flax `linen` to the newer `nnx`) can evolve quickly.
* **Framework Interoperability:** The book occasionally uses **PyTorch** (e.g., for Hugging Face model embeddings) because certain tools are more mature in that ecosystem.

#### Advanced Performance Optimization

While the book focuses on clarity, it identifies four key areas for scaling real-world biological models:

* **Numerical Precision:** Using formats like `bfloat16` to speed up matrix multiplications on specialized hardware (Tensor Cores).
* **Profiling:** Using tools like `jax.profiler` to identify computational and memory bottlenecks.
* **Memory Efficiency:** Using **gradient checkpointing** (`remat`) to train deeper models by trading computation for memory.
* **Distributed Training:** Scaling models across multiple GPUs or TPUs for massive datasets.
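
Two of these ideas fit in a short sketch: casting to `bfloat16` and gradient checkpointing with `jax.checkpoint` (also exposed as `jax.remat`). The two-layer `loss` function and the array shapes below are hypothetical and far smaller than any real model; the point is only that checkpointing changes memory use, not the gradient.

```python
import jax
import jax.numpy as jnp


def block(x: jax.Array, w: jax.Array) -> jax.Array:
    return jnp.tanh(x @ w)


def loss(w: jax.Array, x: jax.Array) -> jax.Array:
    h = block(x, w)
    h = block(h, w)
    return jnp.sum(h ** 2)


def loss_remat(w: jax.Array, x: jax.Array) -> jax.Array:
    # jax.checkpoint recomputes the block's activations during the backward
    # pass instead of storing them, trading computation for memory
    checkpointed_block = jax.checkpoint(block)
    h = checkpointed_block(x, w)
    h = checkpointed_block(h, w)
    return jnp.sum(h ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.1 * jax.random.normal(key_w, (4, 4))
x = jax.random.normal(key_x, (8, 4))

grads = jax.grad(loss)(w, x)
grads_remat = jax.grad(loss_remat)(w, x)

# Reduced precision: cast inputs to bfloat16 before the matrix multiplication
out_bf16 = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)
print(out_bf16.dtype)
```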

#### Summary

Choosing JAX and Flax aligns your work with the "bleeding edge" of biological research while providing a transparent, mathematically grounded framework for learning.

For those seeking a deeper technical dive or troubleshooting support, the text recommends two specific JAX resources:
* **Official JAX Tutorials:** The primary source for detailed, hands-on learning and practical application of the framework (https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html).
* **The "Sharp Bits" Notebook:** An essential reference guide that documents common pitfalls and non-intuitive behaviors unique to JAX's functional programming model (https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).

---

### Python Tips

This section covers essential Python concepts frequently encountered in machine learning code, particularly with JAX and Flax frameworks.

#### 1. Type Annotations and Docstrings

Python is dynamically typed, which is flexible but can hide bugs. Type annotations improve readability, enable static type checking with tools such as mypy or VS Code's Pylance, and simplify debugging.



```python
import numpy as np

# Basic function without type hints
def mean_squared_error(y_true, y_pred):
    squared_errors = (y_true - y_pred) ** 2
    return np.mean(squared_errors)

# Improved function with type hints and docstring
def mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Calculate the Mean Squared Error (MSE) between two NumPy arrays.

    Args:
        y_true (np.ndarray): Ground-truth values.
        y_pred (np.ndarray): Predicted values.

    Returns:
        float: Mean of the squared differences.
    """
    squared_errors = (y_true - y_pred) ** 2
    # np.mean returns a NumPy scalar; convert to match the annotated return type
    return float(np.mean(squared_errors))

```

**Benefits:**
- Clarifies input/output types
- Enhances IDE documentation and autocomplete
- Improves code readability
- Enables static type checking with tools like mypy

#### 2. Decorators

Decorators are functions that modify the behavior of other functions, commonly used for performance enhancement, caching, or logging.

For example, JIT Compilation with JAX (`@jax.jit`) decorator:

**Basic function:**

```python
import jax
import jax.numpy as jnp

def compute_ten_power_sum(arr: jax.Array) -> float:
    """Raise values to the power of 10 and then sum."""
    return jnp.sum(arr**10)

arr = jnp.array([1, 2, 3, 4, 5])
compute_ten_power_sum(arr)
# Output: Array(10874275, dtype=int32)
```

**Method 1 - Apply JIT directly:**
```python
jitted_compute_ten_power_sum = jax.jit(compute_ten_power_sum)
jitted_compute_ten_power_sum(arr)
# Output: Array(10874275, dtype=int32)
```

**Method 2 - Use decorator syntax:**
```python
@jax.jit
def compute_ten_power_sum(arr: jax.Array) -> float:
    """Raise values to the power of 10 and then sum."""
    return jnp.sum(arr**10)

compute_ten_power_sum(arr)
# Output: Array(10874275, dtype=int32)
```

**How `@jax.jit` works:**
1. Traces the function using special tracer objects (not real data)
2. Builds a computation graph (static representation of operations)
3. Compiles via XLA (Accelerated Linear Algebra) to optimized machine code
4. Caches compiled version for reuse with same input shapes/types
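5. Result: later calls reuse the compiled code, giving speedups on the order of ~20× on GPU (the exact gain depends on the function, input size, and hardware)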

**JIT Debugging Challenges:**
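- `print()` statements and `pdb` don't work as expected: inside a jitted function they see abstract tracer objects, not real values
- Python side effects run only once, during tracing, and are skipped on later calls that reuse the compiled code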
- Cryptic error messages referencing internal JAX/XLA code
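
The first two points can be seen in a small sketch. The function `double` and the list `python_body_runs` are hypothetical names used for illustration; `jax.debug.print` is JAX's runtime-safe alternative to `print`.

```python
import jax
import jax.numpy as jnp

python_body_runs = []  # grows each time the Python body actually executes


@jax.jit
def double(x: jax.Array) -> jax.Array:
    python_body_runs.append(1)                # Python side effect: happens only while tracing
    print("tracing with:", x)                 # shows a tracer, not real values
    jax.debug.print("runtime value: {}", x)   # runs on every call with real values
    return x * 2


first = double(jnp.array([1.0, 2.0]))   # first call: trace, compile, run
second = double(jnp.array([3.0, 4.0]))  # same shape and dtype: reuses compiled code

print(len(python_body_runs))  # the Python body ran once, not twice
```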

**Solution**: Set the environment variable `JAX_DISABLE_JIT=True` to disable JIT globally while debugging, or set the flag directly in your Python code:

```python
import jax
import jax.numpy as jnp

# Disable JIT globally so functions run eagerly with concrete values
jax.config.update("jax_disable_jit", True)

def f(x):
    y = jnp.log(x)
    if jnp.isnan(y):
        breakpoint()
    return y

jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!

```
**Strengths and limitations of `jax_disable_jit`**
* **Strengths:**
    * Easy to apply
    * Enables use of Python’s built-in `breakpoint` and `print`
    * Throws standard Python exceptions and is compatible with PDB postmortem
* **Limitations:**
    * Running functions without JIT-compilation can be slow

See the [JAX debugging documentation](https://docs.jax.dev/en/latest/debugging/flags.html#jax-disable-jit-configuration-option-and-context-manager) for more details.
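
JIT can also be disabled for a single block of code with the `jax.disable_jit()` context manager, which avoids slowing down the rest of the program. The sketch below uses a hypothetical function `safe_log` whose Python `if` depends on a computed value, so it fails under JIT but runs eagerly inside the context manager.

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_log(x: jax.Array) -> jax.Array:
    y = jnp.log(x)
    # Under jit, `y` is a tracer here, so this Python `if` cannot be evaluated
    if jnp.isnan(y):
        return jnp.zeros_like(y)
    return y


# Inside the context manager the function runs eagerly with concrete values
with jax.disable_jit():
    result = safe_log(jnp.array(-2.0))

# Outside it, tracing hits the value-dependent `if` and raises an error
try:
    safe_log(jnp.array(-2.0))
    failed_under_jit = False
except jax.errors.ConcretizationTypeError:
    failed_under_jit = True

print(result, failed_under_jit)
```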


#### 3. Preconfiguring JAX JIT with `partial`

`functools.partial` prefills/binds arguments to create new functions with fixed values, a general utility in Python.

**Basic example:**

```python
from functools import partial

def scale(x, scaling_factor):
    return x * scaling_factor

# Create new function with scaling_factor fixed to 10
scale_by_10 = partial(scale, scaling_factor=10)
scale_by_10(3)
# Output: 30

```

Here, `scale_by_10` is a new function that behaves like `scale(x, 10)`.

**JAX-specific usage with static arguments:**

In the context of JAX, `partial` is often used to customize a decorator before applying it, like this: `@partial(jax.jit, static_argnums=...)`. This is a way to configure the `jax.jit` decorator itself.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(0,))
def summarize(average_method: str, x: jax.Array) -> float:
    if average_method == "mean":
        return jnp.mean(x)
    elif average_method == "median":
        return jnp.median(x)
    else:
        raise ValueError(f"Unsupported average type: {average_method}")

data_array = jnp.array([1.0, 2.0, 100.0])

# JAX compiles one version for average_method="mean"
print(f"Mean: {summarize('mean', data_array)}")

# JAX compiles another version for average_method="median"
print(f"Median: {summarize('median', data_array)}")

# Calling with "mean" again uses cached compiled version
print(f"Mean again: {summarize('mean', data_array)}")

# Output:
# Mean: 34.333335876464844
# Median: 2.0
# Mean again: 34.333335876464844
```

If we didn't mark `average_method` as static with `static_argnums=(0,)`, JAX would throw an error: a string is not a valid JAX array type, and the Python `if` statement needs the argument's concrete value at trace time. Marking arguments as static tells JAX to compile a separate, specialized version of the function for each unique value of that static argument it encounters.

**Static vs Dynamic arguments:**
* **Dynamic**: Numerical inputs (`jax.Array`, `float`, `int`) - can vary without recompilation if shapes/types remain constant
* **Static**: Strings, Python objects, functions - affect control flow; must mark with `static_argnums` or use closures


#### 4. Closures

**Definition**: Functions that "remember" their enclosing scope's variables.

**Example:**
```python
def outer_function(x):
    def inner_function(y):
        return x + y  # inner_function "closes over" x
    return inner_function

add_five = outer_function(5)  # x is 5
result = add_five(10)  # y is 10
print(f"Closure result: {result}")
# Output: Closure result: 15
```

**Usage in JAX ML code:**
- Extensively used for loss functions, regularizers, augmentation pipelines
- Avoids passing configuration values as arguments (which might require `static_argnums`)
- Values are "closed over" instead
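
As a sketch of this pattern, a factory function can close over a configuration value and return a jitted loss. The names `make_l2_regularized_loss` and `l2_weight`, and the linear model inside, are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def make_l2_regularized_loss(l2_weight: float):
    """Return a jitted loss function that closes over `l2_weight`."""

    @jax.jit
    def loss_fn(params: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array:
        predictions = x @ params
        mse = jnp.mean((predictions - y) ** 2)
        # `l2_weight` comes from the enclosing scope, not from an argument
        return mse + l2_weight * jnp.sum(params ** 2)

    return loss_fn


loss_fn = make_l2_regularized_loss(l2_weight=0.1)

params = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 0.0])

loss = loss_fn(params, x, y)
print(loss)  # MSE of 1.0 plus an L2 penalty of 0.1 * 2.0
```

Because `l2_weight` is captured when `loss_fn` is defined, it is baked into the compiled function; creating a loss with a different weight means calling the factory again.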

#### 5. Generators

Generators iterate over data lazily (one item at a time), which is essential for large datasets that don't fit in memory.

**Simple generator:**
```python
from typing import Iterator

def data_generator() -> Iterator[dict]:
    """Yield data samples with features and labels."""
    for i in range(5):
        yield {"feature": i, "label": i % 2}

# Example usage
generator = data_generator()
next(generator)
# Output: {'feature': 0, 'label': 0}
```

**Integration with TensorFlow Datasets (TFDS):**
```python
import numpy as np
import tensorflow as tf

features = np.array([1, 2, 3, 4, 5])
labels = np.array([0, 1, 0, 1, 0])

# Create TensorFlow dataset from NumPy arrays
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Batch with size 2, drop incomplete final batch
batched_dataset = dataset.batch(2, drop_remainder=True)

# Create iterator and retrieve first batch
ds = iter(batched_dataset)
next(ds)
# Output:
# (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 2])>,
#  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)
```

**Why TFDS with JAX?**
- JAX lacks native data-loading library
- TFDS provides clean API for batching, shuffling, and prefetching
- Custom pipelines offer more control (covered in later chapters)
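
For comparison, a custom pipeline can be as small as a plain Python generator. The sketch below is an assumption about what such a pipeline might look like, not the book's implementation; `batch_generator` and its parameters are hypothetical names. It reproduces the batching behavior of the example above, including dropping the incomplete final batch.

```python
from typing import Iterator

import numpy as np


def batch_generator(
    features: np.ndarray,
    labels: np.ndarray,
    batch_size: int,
    shuffle: bool = True,
    seed: int = 0,
) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    """Yield (features, labels) mini-batches, dropping the incomplete final batch."""
    indices = np.arange(len(features))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)

    n_full_batches = len(indices) // batch_size
    for b in range(n_full_batches):
        batch_idx = indices[b * batch_size : (b + 1) * batch_size]
        yield features[batch_idx], labels[batch_idx]


features = np.array([1, 2, 3, 4, 5])
labels = np.array([0, 1, 0, 1, 0])

batches = list(batch_generator(features, labels, batch_size=2, shuffle=False))
for batch_features, batch_labels in batches:
    print(batch_features, batch_labels)
```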

#### `jax.jit` decorator

```python
# Jupyter/IPython cell: `%time` is a notebook magic, not plain Python
import jax
import jax.numpy as jnp

def compute_ten_power_sum(arr: jax.Array) -> float:
    """Raise each element of the input array to the power of 10, then sum."""
    return jnp.sum(arr ** 10)

arr = jnp.array([1, 2, 3, 4, 5])

# No JIT compilation
%time print(compute_ten_power_sum(arr))

# Jitted version
jitted_compute_ten_power_sum = jax.jit(compute_ten_power_sum)
%time print(jitted_compute_ten_power_sum(arr)) # first call (compilation time) takes longer
%time print(jitted_compute_ten_power_sum(arr)) # subsequent calls are faster
```

Output:

```
10874275
CPU times: user 425 μs, sys: 16 μs, total: 441 μs
Wall time: 267 μs
10874275
CPU times: user 24.8 ms, sys: 0 ns, total: 24.8 ms
Wall time: 23.7 ms
10874275
CPU times: user 103 μs, sys: 0 ns, total: 103 μs
Wall time: 106 μs
```


```python
# Jupyter/IPython cell: `%time` is a notebook magic, not plain Python
@jax.jit
def compute_ten_power_sum(arr: jax.Array) -> float:
    """Raise each element of the input array to the power of 10, then sum."""
    return jnp.sum(arr ** 10)

arr = jnp.array([1, 2, 3, 4, 5])
%time print(compute_ten_power_sum(arr)) # first call (compilation time) takes longer
%time print(compute_ten_power_sum(arr)) # subsequent calls are faster

arr = jnp.array([5, 4, 3, 2, 1])
%time print(compute_ten_power_sum(arr)) # if array shape/dtype is the same, no recompilation

arr = jnp.array([5, 4, 3, 2])
%time print(compute_ten_power_sum(arr)) # different shape, triggers recompilation
```

Output:

```
10874275
CPU times: user 24.6 ms, sys: 43 μs, total: 24.6 ms
Wall time: 23.6 ms
10874275
CPU times: user 104 μs, sys: 0 ns, total: 104 μs
Wall time: 106 μs
10874275
CPU times: user 66 μs, sys: 0 ns, total: 66 μs
Wall time: 68.2 μs
10874274
CPU times: user 20.2 ms, sys: 989 μs, total: 21.2 ms
Wall time: 21.1 ms
```


#### `jax.jit` with Python's `partial`

```python
# Jupyter/IPython cell: `%time` is a notebook magic, not plain Python
from functools import partial
import jax
import jax.numpy as jnp

# @partial(jax.jit, static_argnums=(0,)) # alternative: select the static argument by position
@partial(jax.jit, static_argnames=("average_method",)) # select the static argument by name
def summarize(average_method: str, x: jax.Array) -> float:
    if average_method == "mean":
        return jnp.mean(x)
    elif average_method == "median":
        return jnp.median(x)
    else:
        raise ValueError(f"Unknown average method: {average_method}")

arr = jnp.array([1, 2, 3, 4, 5])
%time print(summarize("mean", arr))  # JIT compilation for "mean"
%time print(summarize("mean", arr))  # Subsequent call for "mean"
%time print(summarize("median", arr))  # JIT compilation for "median"
%time print(summarize("median", arr))  # Subsequent call for "median"
```

Output:

```
3.0
CPU times: user 24.3 ms, sys: 0 ns, total: 24.3 ms
Wall time: 23.7 ms
3.0
CPU times: user 333 μs, sys: 0 ns, total: 333 μs
Wall time: 288 μs
3.0
CPU times: user 57.4 ms, sys: 4.09 ms, total: 61.5 ms
Wall time: 38.1 ms
3.0
CPU times: user 110 μs, sys: 0 ns, total: 110 μs
Wall time: 113 μs
```


#### Closure

```python
def outer_function(x: float):
    def inner_function(y: float):
        return y + x
    return inner_function


add_five = outer_function(5.0)
print(add_five(3.0))  # Outputs 8.0
print(add_five.__closure__[0].cell_contents)  # Inspect closure to see captured variables
print(add_five)
```

Output:

```
8.0
5.0
<function outer_function.<locals>.inner_function at 0x7b250675b100>
```


#### Generators

```python
import tensorflow as tf

import jax
import jax.numpy as jnp

features = jnp.array([1, 2, 3, 4, 5])
labels = jnp.array([0, 0, 1, 1, 0])

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

batched_dataset = dataset.batch(2, drop_remainder=True)

ds = iter(batched_dataset)
try:
    print(next(ds))
    print(next(ds))
    print(next(ds))
except StopIteration:
    print("End of dataset reached.")
```

Output:

```
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>)
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 4], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)
End of dataset reached.
```


---

### Anatomy of a Training Loop with JAX/Flax

---

## Bonus Learning Materials

### JIT compilation

The term **"Just-in-Time" (JIT)** refers to the exact moment the compilation happens. In traditional programming languages (like C++ or Fortran), compilation happens before you ever run the program. In JAX, the compilation happens *while the program is running*, specifically the very first time a function is called.

Here is a breakdown of why this distinction matters and how it works:

#### 1. The Timing: A "Late" Compilation

In a standard "Ahead-of-Time" (AOT) workflow, you compile your code into a binary file and then run that file. In JAX, you wrap a Python function with `jax.jit`, and nothing is compiled until you call that function with real data.

* **Step 1:** You define the function.
* **Step 2:** You call the function with an input of a specific shape (e.g., a protein sequence of length $L = 500$).
* **Step 3 (The "Just-in-Time" part):** JAX finds that it has no compiled version for that input shape and dtype yet. It traces the Python function, compiles the result with XLA, caches the compiled version, and then executes it.
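
The three steps above can be sketched by timing two calls to the same jitted function. This is a minimal sketch: `scaled_sum` is a hypothetical function, the input is a stand-in array of length 500, and the measured times depend on your hardware, so none are shown here.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Step 1: define the function. Nothing is compiled yet.
    return jnp.sum(x * 2.0)


seq = jnp.ones(500)  # Step 2: an input with a specific shape

start = time.perf_counter()
first = scaled_sum(seq).block_until_ready()   # Step 3: trace, compile, run
first_call_s = time.perf_counter() - start

start = time.perf_counter()
second = scaled_sum(seq).block_until_ready()  # same shape: compiled version reused
second_call_s = time.perf_counter() - start

print(f"first call:  {first_call_s * 1e3:.3f} ms")
print(f"second call: {second_call_s * 1e3:.3f} ms")
```

`block_until_ready()` is called because JAX dispatches work asynchronously; without it the timer could stop before the computation has finished.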

#### 2. Tracing: Learning by Example

JAX waits until the "last second" (Just-in-Time) because it needs to see the **shapes** and **types** of your data to optimize effectively. This process is called **tracing**.

When you call a JIT-compiled function, JAX passes abstract stand-ins for your data (tracers, which carry a shape and dtype but no concrete values) through the function. It records every operation (+, −, ×, ÷) in a **jaxpr**, JAX's own intermediate representation, which is then lowered to **StableHLO** and compiled by XLA. By waiting until you provide data, the compiler can:

* See that your matrix is $1000 \times 1000$.
* Optimize the machine code specifically for those dimensions.
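
To see what tracing records, `jax.make_jaxpr` can print the program JAX captures (a jaxpr) for a given set of example inputs. This is a small sketch; `affine` and the array shapes are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def affine(x, w):
    # Ordinary Python; JAX records these operations while tracing.
    return jnp.dot(x, w) + 1.0


x = jnp.ones((4, 3))
w = jnp.ones((3, 2))

# make_jaxpr traces the function with abstract stand-ins for x and w
# and returns the recorded program instead of a numerical result.
jaxpr = jax.make_jaxpr(affine)(x, w)
print(jaxpr)  # lists the input shapes/dtypes and the recorded primitives
```

The printed program contains only array operations with fixed shapes. Python-level details such as the function name or the `return` statement are gone, which is what lets the compiler specialize for those dimensions.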

#### 3. Specialization

If you later call the same function with a *different* shape (e.g., a sequence of length $L=200$), JAX traces and compiles it again, "Just-in-Time" for that new shape. Each compiled version is cached, so later calls with a shape JAX has already seen skip compilation.
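
One way to observe this specialization is to record a Python-side effect inside the jitted function, since such side effects run only while JAX is tracing. A minimal sketch, with `mean_center` and `traced_shapes` as hypothetical names:

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # gains one entry each time JAX traces the function


@jax.jit
def mean_center(x):
    traced_shapes.append(x.shape)  # Python side effect: runs only during tracing
    return x - jnp.mean(x)


mean_center(jnp.ones(500))  # new shape: trace and compile
mean_center(jnp.ones(200))  # new shape: trace and compile again
mean_center(jnp.ones(500))  # shape seen before: compiled version reused

print(traced_shapes)  # [(500,), (200,)]
```

The third call adds nothing to the list, so no new trace or compilation took place for it.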

#### Comparison Summary

| Feature | Interpreted (Python/NumPy) | Ahead-of-Time (C++/Fortran) | Just-in-Time (JAX/XLA) |
| --- | --- | --- | --- |
| **When is it compiled?** | Never (translated line-by-line) | Before the program runs | During execution (on first call) |
| **Performance** | Slow (High overhead) | Very Fast | Very Fast |
| **Flexibility** | High | Low (must re-compile manually) | High (auto-specializes to shapes) |

#### Why this is a "Scientific" Advantage

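In biological modeling, we often deal with variable-sized inputs (different DNA lengths, different numbers of atoms in a molecule). JIT allows us to write flexible Python code that feels "easy," while the compiler works "Just-in-Time" to give us speed comparable to a low-level language like C++. Keep in mind that every new input shape triggers a fresh compilation, so in practice inputs are often padded to a small set of fixed lengths to keep the number of compiled versions low.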


### Closure

In Python, a **closure** is a function object that "remembers" values from its enclosing scope even after the enclosing function has finished executing.

For a closure to exist, three conditions must be met:

1. There must be a **nested function** (a function inside a function).
2. The nested function must refer to a value defined in the **enclosing function**.
3. The enclosing function must **return** the nested function.

---

#### 1. Basic Closure Example

Think of a closure as a function with a "backpack" where it stores variables from its birthplace.

```python
def make_multiplier(x):
    # This is the enclosing function
    def multiplier(y):
        # This nested function "closes over" the variable x
        return x * y
    
    return multiplier

# 'times_three' is now a closure that remembers x = 3
times_three = make_multiplier(3)

# 'times_five' is a separate closure that remembers x = 5
times_five = make_multiplier(5)

print(times_three(10))  # Output: 30
print(times_five(10))   # Output: 50

```

---

#### 2. Accessing the "Backpack" (`__closure__`)

Python stores these remembered variables in a special attribute called `__closure__`. Each item in this attribute is called a **cell**.

```python
# Continuing from the example above:
print(times_three.__closure__[0].cell_contents) 
# Output: 3

```

---

#### 3. The `nonlocal` Keyword

By default, a nested function can read a variable from the enclosing scope but cannot rebind it: an assignment inside the nested function creates a new local variable instead. To rebind a variable in the enclosing scope, declare it with the `nonlocal` keyword. This is common for creating "counters" or "accumulators."

```python
def make_counter():
    count = 0
    
    def increment():
        nonlocal count  # Allows modification of the outer 'count'
        count += 1
        return count
    
    return increment

counter_a = make_counter()
print(counter_a())  # Output: 1
print(counter_a())  # Output: 2

counter_b = make_counter()
print(counter_b())  # Output: 1 (Starts its own separate count)

```

---

#### 4. Why use Closures?

In machine learning, and especially in **JAX**, closures are useful for:

* **Data Hiding:** They provide a way to store state without using a full Class object.
* **Function Factories:** You can generate specialized versions of a function (like a specific loss function with fixed hyperparameters).
* **Decorators:** Closures are the underlying mechanism that makes Python decorators work.
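
As an illustration of the function-factory pattern, the sketch below bakes a hyperparameter into a loss function before handing it to `jax.grad`. The names `make_l2_regularized_loss` and `weight_decay`, and the toy linear model, are assumptions made for this example rather than anything defined elsewhere in the chapter.

```python
import jax
import jax.numpy as jnp


def make_l2_regularized_loss(weight_decay: float):
    """Function factory: returns a loss with `weight_decay` baked in."""

    def loss_fn(w, x, y):
        pred = x @ w
        mse = jnp.mean((pred - y) ** 2)
        return mse + weight_decay * jnp.sum(w ** 2)  # uses the captured value

    return loss_fn


loss_fn = make_l2_regularized_loss(weight_decay=0.1)
grad_fn = jax.jit(jax.grad(loss_fn))  # differentiates w.r.t. the first argument, w

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, -1.0])

print(loss_fn(w, x, y))  # about 0.2: the data term is zero for these inputs
print(grad_fn(w, x, y))  # about [0.2, -0.2], coming only from the penalty
```

Because `weight_decay` lives in the closure, the returned function takes only arrays as arguments, which keeps it easy to pass through transformations such as `jax.grad` and `jax.jit`.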

#### Comparison to Classes

If a class would have only one method, a closure is often a simpler and more lightweight alternative.

| Feature | Closure | Class |
| --- | --- | --- |
| **Setup** | Lightweight (function) | Heavier (object + methods) |
| **State** | Captured in the "backpack"; rebinding requires `nonlocal` | Mutable via `self` |
| **Usage** | Functional programming | Object-Oriented programming |
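
The comparison can be made concrete by writing the same one-method behavior both ways. This is a small sketch with hypothetical names (`Scaler`, `make_scaler`):

```python
class Scaler:
    """Class version: state lives on the instance."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, value: float) -> float:
        return self.factor * value


def make_scaler(factor: float):
    """Closure version: state lives in the captured variable."""

    def scale(value: float) -> float:
        return factor * value

    return scale


double_obj = Scaler(2.0)
double_fn = make_scaler(2.0)

print(double_obj(3.0), double_fn(3.0))  # 6.0 6.0
double_obj.factor = 10.0                # instance state can be reassigned from outside
print(double_obj(3.0))                  # 30.0
```

Both versions behave the same when called. The class exposes its state as an attribute that other code can change, while the closure keeps it out of normal reach.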

---
