## 🧠 A Useful Utility Method for Working with Language Models

### For traversing through nested attributes (or sub-modules) of a PyTorch module

Language models generally consist of multiple layers/modules nested within each other. You might want to replace a specific layer of a pre-trained model with a new one, freeze the weights of specific layers while fine-tuning the rest, or inspect the parameters or output of a specific layer for debugging.

In all of these cases, you need a way to pinpoint the exact sub-module you want to work with.

### The `get_child_module_by_names()` utility method does exactly that

```py
# 🧠 Useful utility method when working with Language Models
# Traverse nested attributes (or sub-modules) of a PyTorch module

from typing import Iterable

from torch.nn import Module


def get_child_module_by_names(module: Module, names: Iterable[str]) -> Module:
    # The outer lambda takes `name` as its argument and returns another lambda.
    # The inner lambda takes `obj` as its argument and applies `getattr(obj, name)`.
    obj = module
    for getter in map(lambda name: lambda obj: getattr(obj, name), names):
        obj = getter(obj)  # step one level deeper into the module tree
    return obj


######### USING THE ABOVE UTIL METHOD #################
# Let's look at an example where we might want to freeze
# the embedding layers of a pre-trained BERT model for fine-tuning.
# This involves setting the `requires_grad` attribute of the parameters
# of these layers to `False`.

from transformers import BertModel

# Load a pre-trained BERT model
model = BertModel.from_pretrained('bert-base-uncased')

# The attribute names (in order) leading to the embedding module in BERT
names_to_embeddings = ['embeddings']

# Use the function to fetch the embeddings module
embeddings = get_child_module_by_names(model, names_to_embeddings)

# Freeze the embeddings module
for param in embeddings.parameters():
    param.requires_grad = False
```

## Let's see how the `get_child_module_by_names()` method works

👉 The first argument, `module`, is a PyTorch `nn.Module` object. Because the traversal relies only on `getattr`, it could in principle be any Python object, but in the context of Hugging Face language models it's typically a model like BERT, GPT-NeoX, etc. The second argument, `names`, is an iterable (most likely a list or tuple) containing the names of the nested sub-modules we're trying to access, in order.

👉 The function initializes `obj` as `module`, which means initially `obj` is the outermost module. 

----------------

👉 Next, the function loops over `map(lambda name: lambda obj: getattr(obj, name), names)`. For each name in `names`, `map` calls the outer lambda, which returns a getter: a function that takes an object and returns that object's attribute with the current name. The loop then calls this getter on `obj`. The expression `lambda name: lambda obj: getattr(obj, name)` is a higher-order function, i.e., a function that returns another function; the returned inner function, when called with an object, fetches the attribute named `name`.

This is an example of a nested lambda function passed to `map`.

📌 The outer lambda function takes `name` as an argument and returns another lambda function.

📌 The inner lambda function takes `obj` as its argument and applies `getattr(obj, name)`.

`getattr()` takes two args: the object you want to get an attribute from, and a string that is the name of the attribute you want to get. For example, if we have an object `foo` with an attribute `bar`, we can get `bar` by calling `getattr(foo, 'bar')`.

This can be useful when we need to access attributes dynamically, i.e., we don't know which attributes we need to access until the program is running. `getattr` is used here to dynamically access the sub-modules of a PyTorch module, given the string names of these sub-modules.
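A quick sketch of dynamic access with `getattr`, using a hypothetical `SimpleNamespace` object in place of a real model config (the attribute names and values are illustrative only). It also shows `getattr`'s optional third argument, a default that is returned instead of raising `AttributeError`:

```python
from types import SimpleNamespace

# Hypothetical object standing in for a model config
config = SimpleNamespace(hidden_size=768, num_layers=12)

# The attribute name is only known at run time (e.g. read from a config file)
attr_name = "hidden_size"
value = getattr(config, attr_name)
print(value)  # 768, same as config.hidden_size

# With a third argument, getattr returns it instead of raising AttributeError
fallback = getattr(config, "dropout", 0.1)
print(fallback)  # 0.1

# Without a default, a missing attribute raises AttributeError
try:
    getattr(config, "dropout")
except AttributeError as err:
    print("AttributeError:", err)
```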

👉 The loop can be read like this: "For each `name` in `names`, get me the attribute of `obj` that has the name `name`, and set `obj` to be this newly fetched attribute". This way, with each iteration, `obj` becomes a deeper sub-module within the original module.

👉 The function repeats this until all names in `names` have been processed. The final object, `obj`, returned by this function will be the innermost requested module (the attribute of the module represented by the last name in `names`).
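To make the iteration concrete, here is a minimal trace using nested `SimpleNamespace` objects as hypothetical stand-ins for sub-modules. The last lines show that the loop performs the same left-to-right fold as `functools.reduce(getattr, names, module)`; that one-liner is an alternative spelling added here for comparison, not what the original helper uses.

```python
from functools import reduce
from types import SimpleNamespace

# Hypothetical nested object standing in for model.encoder.layer
model = SimpleNamespace(encoder=SimpleNamespace(layer=SimpleNamespace(tag="innermost")))


def get_child_module_by_names(module, names):
    obj = module
    for getter in map(lambda name: lambda obj: getattr(obj, name), names):
        obj = getter(obj)
    return obj


names = ["encoder", "layer"]

# Iteration 1: obj = getattr(model, "encoder")
# Iteration 2: obj = getattr(model.encoder, "layer")
result = get_child_module_by_names(model, names)
print(result is model.encoder.layer)  # True

# The same fold written with functools.reduce
print(reduce(getattr, names, model) is result)  # True

# With an empty `names`, the loop never runs and the module itself is returned
print(get_child_module_by_names(model, []) is model)  # True
```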


---------------

👉 Use of the method

The code above includes an example use case: freezing the embedding layers of a pre-trained BERT model for fine-tuning. This involves setting the `requires_grad` attribute of the parameters of these layers to `False`.

The specific path to the embeddings module (`['embeddings']`) is determined by the structure of the BERT model and can be found in its documentation or source code.
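One way to discover such a path in practice is to list the module tree with `nn.Module.named_modules()`, which yields `(dotted_name, module)` pairs such as `encoder.0`; splitting a dotted name on `"."` produces exactly the `names` list the helper expects. The sketch below assumes a tiny hypothetical `TinyLM` model rather than a real BERT checkpoint, so it runs without downloading weights:

```python
from typing import Iterable

import torch.nn as nn


def get_child_module_by_names(module: nn.Module, names: Iterable[str]) -> nn.Module:
    obj = module
    for getter in map(lambda name: lambda obj: getattr(obj, name), names):
        obj = getter(obj)
    return obj


# Hypothetical tiny model standing in for a real language model
class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.encoder = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])


model = TinyLM()

# named_modules() yields the root first (with the empty name ''), then children recursively
paths = [path for path, _ in model.named_modules()]
print(paths)  # ['', 'embeddings', 'encoder', 'encoder.0', 'encoder.1']

# A dotted path splits directly into the `names` list
second_layer = get_child_module_by_names(model, "encoder.1".split("."))
print(second_layer)  # Linear(in_features=4, out_features=4, bias=True)
```

Children of `nn.ModuleList` (and `nn.Sequential`) are registered under the string keys `'0'`, `'1'`, ..., which is why `getattr(obj, '1')` resolves them just like named attributes.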

The parameters of the fetched embeddings module are iterated over, and their `requires_grad` attribute is set to `False`, which effectively "freezes" them. This means that when we backpropagate through the network during fine-tuning, gradients w.r.t. these parameters will not be computed, and the optimizer will not update them.
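A small sanity check of that claim, sketched on a hypothetical two-layer `nn.Sequential` (an embedding followed by a linear head) rather than BERT: after freezing the embedding, a backward pass leaves its `.grad` as `None`, while the head still receives gradients. Filtering on `requires_grad` is one common way to hand only the trainable parameters to the optimizer; `AdamW` and the learning rate are arbitrary choices here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in: an embedding layer followed by a linear head
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))

# Freeze the embedding layer (registered under the name '0' in the Sequential)
for param in getattr(model, "0").parameters():
    param.requires_grad = False

# Forward + backward on a dummy batch of token ids
tokens = torch.tensor([[1, 2, 3]])
loss = model(tokens).sum()
loss.backward()

print(model[0].weight.grad)              # None: no gradient was computed
print(model[1].weight.grad is not None)  # True: the head still gets gradients

# Hand only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
n_trainable = sum(p.numel() for p in trainable)
print(n_trainable)  # 10 = 4*2 linear weights + 2 biases
```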


## Why use a nested lambda structure to get sub-modules of a PyTorch language model

📌 Why does `get_child_module_by_names()` use a nested lambda structure?

📌 The nested lambda structure in `get_child_module_by_names` is a way to dynamically create a sequence of getter functions, each tailored to access one specific attribute of an object. It's designed to navigate through the nested sub-modules of a PyTorch model by accessing named attributes one after another.
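One reason the two-level structure is useful (our interpretation; the original doesn't spell it out) is Python's late-binding closures: a lambda looks up a free variable when it is *called*, not when it is created. Inside `get_child_module_by_names()` each getter is called immediately, so a single-level lambda would also work there; but if getters are built first and called later, only the nested form (or a default argument such as `name=name`) keeps each name separate. A minimal sketch with a hypothetical `SimpleNamespace` object:

```python
from types import SimpleNamespace

obj = SimpleNamespace(a="A", b="B")
names = ["a", "b"]

# Pitfall: each lambda closes over the variable `name`, which is looked up
# when the lambda is called; by then the comprehension has left it at "b".
naive_getters = [lambda o: getattr(o, name) for name in names]
naive_results = [g(obj) for g in naive_getters]
print(naive_results)  # ['B', 'B']

# Nested lambda: each call to the outer lambda creates a fresh scope that
# binds `name`, so every inner lambda remembers its own name.
make_getter = lambda name: lambda o: getattr(o, name)
nested_getters = [make_getter(name) for name in names]
nested_results = [g(obj) for g in nested_getters]
print(nested_results)  # ['A', 'B']
```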

-------

📌 An example:

Consider a PyTorch model with a nested structure, like `model.layer1.sublayer2.attribute`. If you want to access `attribute`, you must first navigate through `layer1` and then `sublayer2`.

📌 Here's how the nested lambda works in this context:

1. **Outer Lambda**: Takes the name of a sub-module (e.g., 'layer1') and returns an inner lambda function.

2. **Inner Lambda**: This function is configured to take an object (`obj`) and apply `getattr(obj, 'layer1')`, effectively fetching `obj.layer1`.

📌 Applied to the sequence of names:

- For 'layer1', you get a function that does `getattr(obj, 'layer1')`.
- For 'sublayer2', you get another function for `getattr(obj, 'sublayer2')`.
- These functions are then applied in sequence to dive deeper into the model's structure.

📌 In this simple example, the nested lambda functions act as a chain of getters, each fetching the next level in the nested structure until the final attribute is reached. 

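```py
# Nested lambda structure for getting
# sub-modules of a PyTorch language model

# Imagine a nested object structure: model.layer1.sublayer2.attribute
# (plain Python classes stand in for nn.Module here)

class SubLayer:
    def __init__(self):
        self.attribute = "value"

class Layer:
    def __init__(self):
        self.sublayer2 = SubLayer()

class Model:
    def __init__(self):
        self.layer1 = Layer()

model = Model()

# Names to access the nested attribute, in order
names = ["layer1", "sublayer2", "attribute"]

# Outer lambda: name -> getter; inner lambda: obj -> getattr(obj, name)
make_getter = lambda name: lambda obj: getattr(obj, name)

# The method in action
obj = model
for getter in map(make_getter, names):
    obj = getter(obj)  # equivalent to: obj = getattr(obj, name)

print(obj)  # Outputs: value
```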

------------

💡BONUS DISCUSSION - How the `getter` in the `get_child_module_by_names()` function works ❓

The `getter` in the `get_child_module_by_names()` function is produced by a higher-order function, i.e., a function that returns another function. Here, each `getter` is generated by the expression `lambda name: lambda obj: getattr(obj, name)`, which can be broken down into two parts:

1. `lambda name: ...` is a function that takes a `name` and returns another function.

2. `... lambda obj: getattr(obj, name)` is the function that gets returned. It's a function that takes an `obj` and uses `getattr` to fetch the attribute of `obj` that has the name `name`.

So, the `getter` is a function that, when given an object, fetches the attribute of that object with a certain name. The specific name is determined when the getter is created.

Here's a simple example of how the `getter` works:

```py
# Define a simple class with a couple of attributes
class MyClass:
    attr1 = 'hello'
    attr2 = 'world'

# Create an instance of the class
my_instance = MyClass()

# Define a function that creates a getter for an attribute
def create_getter(name):
    return lambda obj: getattr(obj, name)

# Create a getter for 'attr1'
getter1 = create_getter('attr1')

# Use the getter to fetch 'attr1' from my_instance
print(getter1(my_instance))  # Output: hello

# Create a getter for 'attr2'
getter2 = create_getter('attr2')

# Use the getter to fetch 'attr2' from my_instance
print(getter2(my_instance))  # Output: world
```

In this example, `create_getter()` is analogous to `lambda name: lambda obj: getattr(obj, name)` in the `get_child_module_by_names()` function, and `getter1` and `getter2` are analogous to `getter`.

This example demonstrates the power of higher-order functions: by creating different getters, we can fetch different attributes from the same object, or the same attribute from different objects, depending on how we use them.
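For comparison, Python's standard library ships a ready-made getter factory, `operator.attrgetter`, which plays the same role as `create_getter()`. This is a swap-in we're suggesting, not what `get_child_module_by_names()` uses. The sketch redefines `MyClass` and adds hypothetical `SimpleNamespace` objects; note that a dotted name such as `'layer1.sublayer2.attribute'` walks nested attributes in a single call:

```python
from operator import attrgetter
from types import SimpleNamespace


class MyClass:
    attr1 = 'hello'
    attr2 = 'world'


# attrgetter(name) behaves like create_getter(name)
get_attr1 = attrgetter('attr1')
first = get_attr1(MyClass())
print(first)  # hello

# The same getter works on any object that has the attribute
second = get_attr1(SimpleNamespace(attr1='bonjour'))
print(second)  # bonjour

# Dotted names traverse nested attributes, much like get_child_module_by_names
model = SimpleNamespace(layer1=SimpleNamespace(sublayer2=SimpleNamespace(attribute='value')))
nested = attrgetter('layer1.sublayer2.attribute')(model)
print(nested)  # value

# Several names at once return a tuple
both = attrgetter('attr1', 'attr2')(MyClass())
print(both)  # ('hello', 'world')
```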