# Seminar 2. Custom PyTorch Operators

# Building Models in PyTorch Through Composition

PyTorch models are built using **composition**.  
Instead of defining one large monolithic network, we construct models by combining smaller, reusable modules.

Each module can contain other modules, which allows us to build hierarchical and well-structured architectures.

---

## Composition

Composition means:

- A model is built from smaller blocks.
- Each block can contain multiple layers.
- Blocks can be reused in larger architectures.
- Complex models are created by stacking simpler components.

This keeps code:

- Modular  
- Reusable  
- Readable  
- Easy to extend  


## Key Ideas

- Inherit from `nn.Module`
- Define layers inside `__init__`
- Define computation in `forward()`
- Create reusable blocks
- Build larger models by combining blocks

---

## Example: Model Built from Two Blocks

Below is a simple example where:

- We define two reusable blocks: `LinearReLUBlock` and `LinearTanhBlock`
- The final model stacks one of each, followed by a linear output layer


In [None]:
import torch
import torch.nn as nn


class LinearReLUBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()

        self.linear = nn.Linear(in_features, out_features)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = self.activation(x)
        return x


class LinearTanhBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()

        self.linear = nn.Linear(in_features, out_features)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = self.activation(x)
        return x


class CombinedModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.block1 = LinearReLUBlock(4, 8)
        self.block2 = LinearTanhBlock(8, 8)
        self.output = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.output(x)
        return x


model = CombinedModel()
print(model)
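Because every block is itself an `nn.Module`, a composed model forms a tree of modules. The sketch below uses hypothetical `inner`, `outer` and `tied` models built with `nn.Sequential`, not the `CombinedModel` above. It inspects that tree and shows that reusing one block *instance* in two places shares its weights.

```python
import torch
import torch.nn as nn

# Hypothetical nested model: an inner nn.Sequential used as one block of an outer one
inner = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
outer = nn.Sequential(inner, nn.Linear(8, 2))

# Direct children only
print([name for name, _ in outer.named_children()])  # ['0', '1']
# The whole tree in pre-order; '' is the root module itself
print([name for name, _ in outer.named_modules()])   # ['', '0', '0.0', '0.1', '1']

# Reusing the SAME instance twice ties (shares) its weights
shared = nn.Linear(8, 8)
tied = nn.Sequential(nn.Linear(4, 8), shared, nn.ReLU(), shared)
print(tied(torch.randn(2, 4)).shape)  # torch.Size([2, 8])
print(len(list(tied.parameters())))   # 4: the shared weight and bias are counted once
```

Creating two separate `LinearReLUBlock(...)` objects gives independent weights; passing one object twice gives tied weights.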

# What `nn.Module` Enables

When we inherit from `nn.Module`, we automatically gain powerful functionality that works **recursively across all submodules**.

## What `nn.Module` Gives Us



### Parameter Registration

Every submodule (e.g. `self.linear = nn.Linear(...)`) and every `nn.Parameter` assigned as an attribute is:

- Automatically registered
- Included, together with its parameters, in `model.parameters()`
- Included in `model.state_dict()`

This works **recursively** for all sub-blocks.

In [None]:
print("Registered parameters:")
for name, param in model.named_parameters():
    print(name, param.shape)
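Registration only happens for specific attribute types. A minimal sketch with a hypothetical `ScaleShift` module: an `nn.Parameter` is registered as a learnable parameter, a tensor passed to `register_buffer` is saved and moved with the model but not trained, and a plain tensor attribute is not registered at all.

```python
import torch
import torch.nn as nn


class ScaleShift(nn.Module):
    """Hypothetical module with one parameter, one buffer and one plain tensor."""

    def __init__(self, features: int) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(features))       # learnable, registered
        self.register_buffer("shift", torch.zeros(features))  # saved and moved, not trained
        self.note = torch.zeros(features)                     # plain tensor: NOT registered

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale + self.shift


m = ScaleShift(3)
print([name for name, _ in m.named_parameters()])  # ['scale']
print([name for name, _ in m.named_buffers()])     # ['shift']
print(list(m.state_dict().keys()))                 # ['scale', 'shift']
```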


### Automatic Gradient Tracking

During the forward pass:

- PyTorch dynamically builds a computation graph
- Calling `loss.backward()` computes gradients
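- Gradients are accumulated into each parameter’s `.grad` attribute: new gradients are added to existing values, so they must be reset between optimization steps (e.g. with `optimizer.zero_grad()`)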

No manual graph management is required.

In [None]:
x = torch.randn(5, 4)
target = torch.randn(5, 2)

criterion = nn.MSELoss()
output = model(x)
loss = criterion(output, target)

loss.backward()

print("\nGradient computed for output layer:",
      model.output.weight.grad is not None)
model.output.weight.grad

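Because `.grad` accumulates, calling `backward()` twice without resetting adds the two gradients together. A small sketch on a standalone `nn.Linear` (the names `layer` and `first_grad` are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
x = torch.randn(5, 4)

# First backward pass fills .grad
layer(x).sum().backward()
first_grad = layer.weight.grad.clone()

# Second backward pass WITHOUT resetting: the new gradient is added to .grad
layer(x).sum().backward()
doubled = torch.allclose(layer.weight.grad, 2 * first_grad)
print(doubled)  # True

# Reset before the next optimization step
layer.zero_grad(set_to_none=True)
print(layer.weight.grad)  # None
```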

### Device and Type Transfer (`.to()`)

Calling:

    model.to(device)

or

    model.to(dtype)

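moves all:

- Parameters
- Buffers

of the module and of every submodule (recursively) to the target device (CPU/GPU) or dtype. A dtype conversion only casts floating-point and complex tensors; integer buffers such as BatchNorm's `num_batches_tracked` keep their dtype.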

In [None]:
import torch

device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

dtype = torch.float32

model = model.to(device=device, dtype=dtype)

x: torch.Tensor = torch.randn(5, 4, device=device, dtype=dtype)

print("Model device:", next(model.parameters()).device)
print("Model dtype:", next(model.parameters()).dtype)
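A common pitfall: `.to()` converts a module in place (and returns the same object), but for a tensor it returns a converted tensor and leaves the original unchanged. The sketch below also checks the integer-buffer rule on a `BatchNorm1d` layer; the variable names are illustrative.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Modules are converted in place; .to() returns the same object
layer = nn.Linear(4, 2)
print(layer.to(device) is layer)  # True

# Tensors are NOT modified in place: reassign the result of .to()
t = torch.randn(3, 4)
t = t.to(device)
print(t.device.type == device.type)  # True

# dtype casts apply only to floating-point/complex parameters and buffers
bn = nn.BatchNorm1d(8).to(torch.float64)
print(bn.running_mean.dtype, bn.num_batches_tracked.dtype)  # torch.float64 torch.int64
```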

### Saving & Loading (`state_dict()`)

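- `model.state_dict()` returns all parameters **and buffers** (e.g. BatchNorm running statistics) recursively, keyed by their dotted names
- `model.load_state_dict(...)` restores them into a model with the same structure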

This works across the full module tree.

In [None]:
state_dict = model.state_dict()
torch.save(state_dict, "combined_model.pt")
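The cell above only saves. A sketch of the full round trip, assuming a hypothetical `make_model()` factory (BatchNorm is included so that buffers appear in the `state_dict`); the file goes to a temporary directory:

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical model; BatchNorm1d contributes buffers to the state_dict
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))


model = make_model()
print(list(model.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean',
#  '1.running_var', '1.num_batches_tracked', '2.weight', '2.bias']

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.save(model.state_dict(), path)

    # Recreate the architecture first, then load the saved tensors into it
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Every entry (parameters and buffers) matches the original
for (k1, v1), (k2, v2) in zip(model.state_dict().items(), restored.state_dict().items()):
    assert k1 == k2 and torch.equal(v1, v2)
```

`load_state_dict` needs an already constructed model: the file stores tensors only, not the architecture.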

# `train()` vs `eval()` Mode in PyTorch

PyTorch modules have two main modes: **training mode** and **evaluation mode**.  
Switching between them affects layers that behave differently during training and inference.

---

## `model.train()`

- Sets the model to **training mode**.
- Used when training the model with gradient updates.
- Affects certain layers, such as:

| Layer Type  | Behavior in `train()` Mode |
|-------------|----------------------------|
| `Dropout`   | Randomly zeroes activations with probability `p` and scales the rest by `1/(1-p)` |
| `BatchNorm` | Normalizes with the current batch statistics and updates the running statistics (mean/variance) |

- Gradients are computed as usual.

---

## `model.eval()`

- Sets the model to **evaluation (inference) mode**.
- Used when evaluating or deploying the model.
- Affects certain layers:

| Layer Type        | Behavior in `eval()` Mode                   |
|------------------|--------------------------------------------|
| `Dropout`         | Passes all activations through unchanged  |
| `BatchNorm`       | Uses stored running mean/variance         |

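- Layers such as `BatchNorm` stop updating their running statistics.
- `eval()` does **not** disable gradient tracking; wrap inference code in `torch.no_grad()` or `torch.inference_mode()` to avoid building the autograd graph.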

---

## Key Points

- Always use `model.train()` during training.
- Always use `model.eval()` during evaluation or testing.
- Forgetting to switch can lead to inconsistent results, especially with `Dropout` or `BatchNorm`.



In [None]:
import torch
import torch.nn as nn

# Simple model with Dropout and BatchNorm
class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


model = SimpleModel()
x = torch.randn(5, 4)

# Training mode
model.train()
output_train = model(x)
print("Training mode output:\n", output_train)

# Evaluation mode
model.eval()
with torch.no_grad():
    output_eval = model(x)
print("Evaluation mode output:\n", output_eval)
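The mode is stored in the `training` flag of every module, and it changes what a forward pass does. A sketch with a hypothetical `net`: in training mode two forward passes on the same input draw different dropout masks, while in evaluation mode dropout is the identity and the outputs repeat exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(5, 4)

net.train()
print(net.training, net[1].training)  # True True
train_a, train_b = net(x), net(x)     # two independent dropout masks

net.eval()
print(net.training, net[1].training)  # False False
with torch.no_grad():
    eval_a, eval_b = net(x), net(x)   # dropout passes activations through

print(torch.equal(eval_a, eval_b))    # True
```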


# `torch.no_grad()` and `torch.inference_mode()` in PyTorch

When performing inference (evaluating a model without updating parameters), PyTorch provides context managers to **disable gradient tracking**. This saves memory and speeds up computation.

---

## `torch.no_grad()`

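- Disables gradient tracking: operations inside the block are not recorded by autograd, so their outputs have `requires_grad=False`.
- Useful during evaluation or inference.
- Tensors created inside are ordinary tensors: they keep their autograd metadata (version counters, view tracking) and can later be used in computations that autograd records.
- Can be used as a **context manager** or a **function decorator**.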

---

## `torch.inference_mode()`

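- Introduced in PyTorch 1.9.
- Similar to `no_grad()`, but **more efficient**: it additionally skips view tracking and version-counter updates.
- Tensors created inside are *inference tensors*; they cannot be used later in computations recorded by autograd.
- Recommended for pure inference pipelines. Like `no_grad()`, it works as a context manager or a decorator.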



In [None]:
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
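        self.fc = nn.Linear(4, 2)

    # Standard forward pass: required for model(x) to work
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)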

    # -----------------------------
    # Using torch.no_grad() as method decorator
    # -----------------------------
    @torch.no_grad()
    def forward_no_grad(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    # -----------------------------
    # Using torch.inference_mode() as method decorator
    # -----------------------------
    @torch.inference_mode()
    def forward_inference(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = SimpleModel()
x = torch.randn(5, 4)

# -----------------------------
# Call decorated methods
# -----------------------------
output_no_grad_method = model.forward_no_grad(x)
output_infer_method = model.forward_inference(x)

print("Output no_grad method:\n", output_no_grad_method)
print("Output inference_mode method:\n", output_infer_method)

# -----------------------------
# Using context managers
# -----------------------------

with torch.no_grad():
    output_no_grad_cm = model(x)

with torch.inference_mode():
    output_infer_cm = model(x)

print("Output no_grad context manager:\n", output_no_grad_cm)
print("Output inference_mode context manager:\n", output_infer_cm)
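Both tools produce the same values; the difference appears in the tensors they return. A sketch on a standalone `nn.Linear` (names are illustrative): a tensor produced under `no_grad()` can still be combined with a tensor that requires grad afterwards, whereas an inference tensor cannot be saved for backward and raises a `RuntimeError`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
x = torch.randn(3, 4)

with torch.no_grad():
    out_no_grad = layer(x)

with torch.inference_mode():
    out_inference = layer(x)

# Neither output is attached to the autograd graph
print(out_no_grad.requires_grad, out_inference.requires_grad)  # False False
# Only the inference_mode output is an inference tensor
print(out_no_grad.is_inference(), out_inference.is_inference())  # False True

w = torch.ones(2, requires_grad=True)

# A tensor produced under no_grad can still be used in autograd later...
(out_no_grad * w).sum().backward()
print(torch.allclose(w.grad, out_no_grad.sum(dim=0)))  # True

# ...but an inference tensor cannot be saved for backward
try:
    (out_inference * w).sum()
    inference_usable = True
except RuntimeError:
    inference_usable = False
print(inference_usable)  # False
```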


# Disabling Gradients with `requires_grad_(False)`

PyTorch provides `requires_grad_()` to **enable or disable gradients in-place**, both on a single tensor (`Tensor.requires_grad_()`) and on all parameters of a model at once (`nn.Module.requires_grad_()`).

Using:

```python
param.requires_grad_(False)
```

- Sets `requires_grad=False` **in-place** for that parameter.
- This is useful for freezing a model or part of it, e.g. a pretrained backbone in transfer learning.
- To freeze an entire model, either iterate over its parameters or call `model.requires_grad_(False)`, which applies the flag recursively.


In [None]:
model = SimpleModel()

# Disable gradient computation for all parameters using requires_grad_()
for param in model.parameters():
    param.requires_grad_(False)

# Verify
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Forward pass still works
x = torch.randn(5, 4)
output = model(x)
print("Output shape:", output.shape)
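In transfer learning, usually only part of the model is frozen. A sketch assuming a hypothetical "pretrained" `backbone` and a new `head`: `nn.Module.requires_grad_` freezes the backbone in one call, and only the remaining trainable parameters are handed to the optimizer.

```python
import torch
import torch.nn as nn

# Hypothetical pretrained feature extractor plus a new task-specific head
backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
head = nn.Linear(8, 2)
model = nn.Sequential(backbone, head)

# Applies requires_grad=False to every backbone parameter, recursively
backbone.requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']

# Give the optimizer only the parameters that should be updated
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

x, target = torch.randn(5, 4), torch.randn(5, 2)
loss = nn.MSELoss()(model(x), target)
loss.backward()
print(backbone[0].weight.grad is None, head.weight.grad is not None)  # True True
optimizer.step()
```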

# Redefining `train()` and `eval()` in `nn.Module`

PyTorch’s `nn.Module` provides built-in `train(mode: bool = True)` and `eval()` methods to switch between **training** and **evaluation** modes.  

Sometimes, when creating **custom modules or blocks**, you might want to **override these methods** to perform extra actions whenever the mode changes.

---

## Why Override?

- Apply mode-specific logic to sub-blocks or attributes that are not standard layers
- Log or track mode switches
- Automatically modify internal flags or buffers along with training/eval mode

---

## How It Works

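- `train(mode: bool = True)` sets `self.training = mode` for the module
- `eval()` is equivalent to `train(False)`
- The default implementation recursively calls `train(mode)` on all submodules
- A parent reaches its children through `train(mode)`, never through `eval()`, so overriding `train()` alone catches every mode switch; call `super().train(mode)` inside it to keep the recursive update intact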


In [None]:
import torch
import torch.nn as nn


class CustomBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = torch.relu(x)
        x = self.dropout(x)
        return x

    # -----------------------------
    # Override train() method
    # -----------------------------
    # eval() is implemented as self.train(False), and a parent module
    # calls child.train(mode) on its submodules, so overriding train()
    # alone covers every mode switch. Overriding eval() as well would run
    # the custom logic twice and would be skipped when a parent switches modes.
    def train(self, mode: bool = True) -> "CustomBlock":
        print(f"CustomBlock set to {'train' if mode else 'eval'} mode")
        super().train(mode)  # Update self.training and all submodules
        # Add any custom logic here
        return self


# Example usage
model = CustomBlock()
x = torch.randn(2, 4)

# Switch to training mode
model.train()
output_train = model(x)

# Switch to evaluation mode
model.eval()
with torch.no_grad():
    output_eval = model(x)

print("Output training mode:", output_train)
print("Output eval mode:", output_eval)
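One common reason to override `train()` (an example use case, not something this seminar prescribes) is keeping BatchNorm statistics frozen while fine-tuning. A minimal sketch with a hypothetical `FrozenBNBackbone`; it also shows that a parent's `train()` reaches the override through the child's `train(mode)`.

```python
import torch
import torch.nn as nn


class FrozenBNBackbone(nn.Module):
    """Hypothetical backbone whose BatchNorm layers always stay in eval mode."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.bn(self.conv(x)))

    def train(self, mode: bool = True) -> "FrozenBNBackbone":
        super().train(mode)  # sets self.training and recurses into children
        # Put every BatchNorm back into eval mode so its running stats stay fixed
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                module.eval()
        return self


model = nn.Sequential(FrozenBNBackbone(), nn.Flatten(), nn.Linear(8 * 4 * 4, 2))
model.train()  # nn.Sequential calls FrozenBNBackbone.train(True) on its child
backbone = model[0]
print(backbone.training, backbone.bn.training)  # True False

running_mean_before = backbone.bn.running_mean.clone()
model(torch.randn(2, 3, 4, 4))
print(torch.equal(running_mean_before, backbone.bn.running_mean))  # True
```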


# Common Module Aggregators in PyTorch

When building neural networks, it is often useful to group multiple layers or submodules together.  
PyTorch provides several **module aggregators** that help organize layers and blocks. The most common ones are:



## `nn.Sequential`

- Holds modules in a sequential order.
- Executes them **in the order they are added** during the forward pass.
- Ideal for simple **stacked layers** with a single input and output.

**Key points:**

- Forward pass is automatically defined.
- Cannot handle multiple inputs or branching.

In [None]:
seq_model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2)
)

x = torch.randn(5, 4)
output_seq = seq_model(x)
print("nn.Sequential output shape:", output_seq.shape)
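`nn.Sequential` also accepts an `OrderedDict`, which gives the layers readable names, and slicing a `Sequential` returns a new `Sequential` that shares the same layer objects. A short sketch (`named_seq` and `feature_extractor` are illustrative names):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

named_seq = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(4, 8)),
    ("act", nn.ReLU()),
    ("fc2", nn.Linear(8, 2)),
]))

print(named_seq.fc1)  # Linear(in_features=4, out_features=8, bias=True)

# Slicing keeps the names and shares the modules (no copies)
feature_extractor = named_seq[:2]
print(type(feature_extractor).__name__)  # Sequential

x = torch.randn(5, 4)
print(feature_extractor(x).shape)  # torch.Size([5, 8])
print(named_seq(x).shape)          # torch.Size([5, 2])
```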


## `nn.ModuleList`

- Holds a **list of modules**.
- Does **not define a forward pass automatically**.
- Useful when you need to **loop over modules**, index them, or apply them conditionally.

**Key points:**

- Modules are registered properly, so parameters are tracked.
- You must define your own `forward()`.

In [None]:
class ModuleListModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(4, 8),
            nn.ReLU(),
            nn.Linear(8, 2)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

ml_model = ModuleListModel()
output_ml = ml_model(x)
print("nn.ModuleList output shape:", output_ml.shape)
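The reason to use `nn.ModuleList` instead of a plain Python list is registration. In the sketch below (hypothetical class names), layers stored in a plain list are invisible to `parameters()`, `state_dict()`, `.to()` and `train()`/`eval()`. As a result, an optimizer would never update them.

```python
import torch.nn as nn


class PlainListModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Plain Python list: nn.Module does NOT register these layers
        self.layers = [nn.Linear(4, 8), nn.Linear(8, 2)]


class RegisteredListModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 8), nn.Linear(8, 2)])


plain, registered = PlainListModel(), RegisteredListModel()
print(len(list(plain.parameters())), len(list(registered.parameters())))  # 0 4
print(list(registered.state_dict().keys()))
# ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']
```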

## `nn.ModuleDict`

- Holds modules in a **dictionary** with string keys.
- Useful for architectures with **named branches**, **dynamic selection**, or **multi-head outputs**.
- Like `ModuleList`, it does **not define a forward pass**.

In [None]:
class ModuleDictModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.branches = nn.ModuleDict({
            "branch1": nn.Linear(4, 8),
            "branch2": nn.Linear(4, 8)
        })
        self.output: nn.Linear = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor, branch_name: str = "branch1") -> torch.Tensor:
        x = self.branches[branch_name](x)
        return self.output(x)

md_model = ModuleDictModel()
output_md = md_model(x, branch_name="branch2")
print("nn.ModuleDict output shape:", output_md.shape)
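For the multi-head case mentioned above, a sketch of a hypothetical `MultiHeadOutputModel`: a shared trunk feeds one named head per task (the task names and sizes are illustrative), and `forward` returns a dictionary of outputs.

```python
import torch
import torch.nn as nn


class MultiHeadOutputModel(nn.Module):
    """Hypothetical shared trunk with one named output head per task."""

    def __init__(self, tasks: dict[str, int]) -> None:
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.heads = nn.ModuleDict(
            {name: nn.Linear(8, n_out) for name, n_out in tasks.items()}
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        features = self.trunk(x)  # computed once, shared by all heads
        return {name: head(features) for name, head in self.heads.items()}


model = MultiHeadOutputModel({"classification": 3, "regression": 1})
outputs = model(torch.randn(5, 4))
for name, out in outputs.items():
    print(name, tuple(out.shape))
# classification (5, 3)
# regression (5, 1)
```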

## Homework

Two tasks:
1. Implement the block named in each heading below (0.8 points maximum in total).
2. Answer the question at the end (0.2 points).

## ResNet Block (0.1 points)

![Resnet](assets/ResBlock.png)

https://arxiv.org/pdf/1512.03385

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...

## Depthwise Separable Convolution (0.1 points)
![DepthWiseConv](assets/DepthWiseConv.png)

https://arxiv.org/pdf/1610.02357

In [None]:
class SeparableConv2d(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...

## Vanilla Attention (0.1 points)

Let:

$$
\begin{aligned}
\text{query} &\in \mathbb{R}^{B \times d} \\
\text{key} &\in \mathbb{R}^{B \times L \times d}
\end{aligned}
$$

---

### Alignment Scores

$$
\begin{aligned}
\text{score} &= \text{key} \cdot (W_\text{align} \, \text{query})^T \\
\text{score} &\in \mathbb{R}^{B \times L}
\end{aligned}
$$

---

### Attention Weights

$$
\begin{aligned}
\text{att} &= \text{softmax}(\text{score}, \text{dim}=1) \\
\text{att} &\in \mathbb{R}^{B \times L}
\end{aligned}
$$

---

### Context Vector

$$
\begin{aligned}
\text{context} &= \sum_{i=1}^{L} \text{att}_i \cdot \text{key}_i \\
\text{context} &\in \mathbb{R}^{B \times d}
\end{aligned}
$$

---

### Output

$$
\begin{aligned}
\text{out} &= \tanh(W_\text{value} \, \text{context} + W_\text{query} \, \text{query}) \\
\text{out} &\in \mathbb{R}^{B \times d}
\end{aligned}
$$



https://arxiv.org/abs/1409.0473


https://arxiv.org/abs/1508.04025

In [None]:
from typing import Optional
import torch
from torch import nn
import numpy as np

class VanillaAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...

## Dot Product Attention (0.1 points)

$$
\begin{aligned}
Q &\in \mathbb{R}^{B \times L_q \times d_k} \\
K &\in \mathbb{R}^{B \times L_k \times d_k} \\
V &\in \mathbb{R}^{B \times L_k \times d_k}
\end{aligned}
$$

$$
S = \frac{Q K^T}{\sqrt{d_k}}
$$

$$
\text{Attention}(Q, K, V) = \text{softmax}(S, \text{dim}=-1) \, V
$$



https://arxiv.org/abs/1706.03762


In [None]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...

## Multihead Attention (0.1 points)

![MultiheadAttention](assets/MultiheadAttention.webp)

https://arxiv.org/abs/1706.03762


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...

## Transformer Encoder Layer (0.1 points)


![Transformer Encoder Layer](assets/TransformerEncoder.png)


https://arxiv.org/abs/1706.03762

In [None]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...


## MLP Mixer (0.1 points)


![MLPMixer](assets/MLPMixer.png)


https://arxiv.org/abs/2105.01601

In [None]:


class MLPMixerBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...


## ConvMixer (0.1 points)

![ConvMixer](assets/ConvMixer.png)


https://arxiv.org/abs/2201.09792

In [None]:
class ConvMixer(nn.Module):

    def __init__(self, ...):
        super().__init__()


    def forward(self, ...):
        ...


## Question (0.2 points)

Explain why MLPMixer and ConvMixer can perform almost as well as a standard Multihead Attention.

Write down a formula that relates Multihead Attention, ConvMixer, and MLPMixer.

Describe the advantages and disadvantages of ConvMixer, MLPMixer, and Multihead Attention compared to one another.

---

Answer: ...