## How to Use This Notebook

The first half of the notebook focuses on the **RibonanzaNet Backbone**. This section is adapted from Moth's notebook (which I used to understand the RibonanzaNet). I have added additional comments and covered some topics in greater detail, such as **Triangular Attention**, to make it more beginner-friendly.

The second half of the notebook delves into **RibonanzaNet 2.0**, which incorporates a **Diffusion Model** alongside the RibonanzaNet backbone. This too is beginner-friendly, with detailed explanations, comments and diagrams.

# RibonanzaNet Explained

***

Welcome! 👋🏼

In this tutorial we will try to understand all the building blocks that compose the `RibonanzaNet` architecture, proposed by @shujun717 et al., which unifies features of `RNAdegformer` and top Kaggle models (from last year's Stanford Ribonanza RNA Folding competition) into a single, self-contained model.

In the *Table of Contents* you will find links to all the building blocks involved. Each block comes with its description, purpose and diagram, and each class documents its inputs, tensor shapes and data types to make it easier to understand.

Let's get started!

### Table of Contents
<div>
    <li><a href="#Introduction">Introduction</a></li>
    <li><a href="#Ribonanza-Backbone-Architecture">Ribonanza Backbone Architecture</a></li>
    <li><a href="#2-Outer-Product-Mean">Outer Product Mean</a></li>
    <li><a href="#4-Relative-Positional-Encoding">Relative Positional Encoding</a></li>
    <li><a href="#5-Transformer-Encoder">Transformer Encoder</a></li>
    <li><a href="#f-Triangular-Multiplicative-Module">Triangle Multiplicative Module</a></li>
    <li><a href="#Bonus-Triangular-Attention">Triangle Attention</a></li>
    <li><a href="#RibonanzaNet-Backbone">Ribonanza Net Backbone</a></li>
    <li><a href="#RibonanzaNet-20">RibonanzaNet 2.0</a></li>
    <li><a href="#Diffusion-Basics">Diffusion Basics</a></li>
    <li><a href="#RibonanzaNet-20-Architecture">RibonanzaNet 2.0 Architecture</a></li>
    <li><a href="#Time-Embedder">Time Embedder</a></li>
    <li><a href="#Embed-Pairwise-Distances">Embed Pairwise Distances</a></li>
    <li><a href="#Structure-Module">Structure Module</a></li>
    <li><a href="#Full-Model-Understanding">Full Model - Training - Inference logic</a></li>
</div>


# Introduction [↑](#top) 
***
## Some Basics first

1. **Residue/Nucleotide** - A single unit in a chain. For RNA, a residue is a nucleotide. For example, in the sequence `AUGC`, `A`, `U`, `G` and `C` are residues.
2. **BPP**: In a sequence, any two residues form a *pair*, and some pairs can bond with each other to form a **base pair** (for example `G` with `C`, or `A` with `U`). The Base Pair Probability (BPP) data gives, for every pair, the probability that residues `i` and `j` are base-paired. For a sequence of length `n`, it is represented as a matrix of size `n*n`, where the value at position `(i,j)` is the probability that residues `i` and `j` pair. Knowing which pairs are likely to form helps us understand the structure of the RNA (and helps in other tasks too), so this is important data. **TLDR** - Pair information is important, and you'll see it in almost every architecture.
***

[RibonanzaNet][4] was proposed by Shujun He (competition host) et al. in their paper *"Ribonanza: deep learning of RNA structure through dual crowdsourcing"*.

**TLDR** - The previous Kaggle competition asked participants to predict the chemical reactivity of RNA sequences. The authors took the top three models from that competition. Below are the three models:

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/1.png" width=800 class="center">

In this tutorial, we will break down the different components that make up `RibonanzaNet`, especially its main layer/block: the `ConvTransformerEncoderLayer`.

<img src="https://kaggle-images.s3.us-west-2.amazonaws.com/ribonanza-3d/ribonanza_diagram.png" width=800 class="center">

### References

- [RibonanzaNet code][1]
- [How does DeepMind AlphaFold2 work?][2]
- [AlphaFold v2 Github Repository][3]
- [Ribonanza paper][4]
- [AlphaFold2 complementary paper][5]

[1]: https://github.com/Shujun-He/RibonanzaNet/blob/main/Network.py
[2]: https://borisburkov.net/2021-12-25-1/
[3]: https://github.com/google-deepmind/alphafold
[4]: https://www.biorxiv.org/content/10.1101/2024.02.24.581671v1.full.pdf
[5]: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf

# Import Libraries [↑](#top) 

***

In [1]:
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml


from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from functools import partialmethod
from torch import einsum
from torch.nn.parameter import Parameter
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

# Ribonanza Backbone Architecture
***

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/2.png" width=400 class="center">

We will go through the architecture step by step, following the inputs as they flow from the start to the end.

A few conventions: B → batch size, T or S → sequence length, C → number of channels.

## Steps:

### 1. **Sequence Embeddings**
Each token is embedded with `nn.Embedding(config.ntoken, config.ninp, padding_idx=4)`, giving a tensor of shape `(B, T, 256)`.
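A minimal sketch of this embedding step, assuming a hypothetical tokenizer that maps `A, C, G, U` to `0..3` and uses index `4` for padding (matching `padding_idx=4`). The helper `encode` and the values `ntoken=5`, `ninp=256` are illustrative assumptions, not the RibonanzaNet data pipeline.

```python
import torch
import torch.nn as nn

# Hypothetical token mapping; index 4 is reserved for padding (padding_idx=4).
TOKENS = {"A": 0, "C": 1, "G": 2, "U": 3}
PAD = 4

def encode(seqs: list[str]) -> torch.Tensor:
    """Map RNA strings to a right-padded (B, T) LongTensor (hypothetical helper)."""
    T = max(len(s) for s in seqs)
    ids = torch.full((len(seqs), T), PAD, dtype=torch.long)
    for b, s in enumerate(seqs):
        ids[b, : len(s)] = torch.tensor([TOKENS[c] for c in s])
    return ids

embedding = nn.Embedding(5, 256, padding_idx=PAD)  # ntoken=5, ninp=256 (assumed)
tokens = encode(["AUGC", "AUGCAU"])  # (B, T) = (2, 6)
x = embedding(tokens)                # (B, T, C) = (2, 6, 256)
print(x.shape)
```

Because of `padding_idx`, the padding row of a freshly created `nn.Embedding` is all zeros and receives no gradient, so padded positions start out as zero vectors.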

### 2. **Outer Product Mean**

One core idea of the `RibonanzaNet` was not to use any precalculated features like BPP but to create a representation of the pairs. If we have a sequence of length `T`, we can create a matrix of size `T*T` to represent the pairs.

- The `OuterProductMean` layer is used to create a `T*T` pair representation from the sequence embeddings. The raw outer product is a tensor of shape `(B, T, T, C*C)`, which is then projected to the pairwise dimension.

#### Example:
I had to understand it using pen and paper. Below is an image of the process. 

Take an example sequence of length 4 with `C=2`, so the input is `[4x2]`.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/3.png" width=800 class="center">

- Finally, we get a matrix of size `[4x4x4]` holding the pair values. In general, the output is a tensor of shape `(B, T, T, C*C)`; with `C=2`, `C*C` happens to equal `2*C = 4`.
- That is the core Outer Product Mean logic. In practice we also wrap it in linear layers, because computing the outer product over all 256 channels would give 256 × 256 = 65,536 features per pair.
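The pen-and-paper example can be checked with a few lines of PyTorch. This sketch drops the batch dimension and the linear projections, uses a made-up `[4x2]` input, and reuses the same einsum pattern as the notebook's `Outer_Product_Mean` class (`'bid,bjc -> bijcd'`), so entry `(i, j)` is the flattened outer product of `x[j]` and `x[i]`.

```python
import torch

# Toy input matching the pen-and-paper example: T=4 residues, C=2 channels.
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # (T, C)

# Same einsum pattern as the notebook, without the batch dimension:
# op[i, j, c, d] = x[j, c] * x[i, d]
op = torch.einsum("id,jc->ijcd", x, x)  # (T, T, C, C) = (4, 4, 2, 2)
pair = op.reshape(4, 4, -1)             # flatten (c d) -> (T, T, C*C) = (4, 4, 4)

print(pair.shape)
# Pair (1, 2) holds outer(x[2], x[1]) flattened: [8., 12., 10., 15.]
print(pair[1, 2])
```

Note that the last dimension is `C*C`, not `2*C`; the two only coincide here because `C=2`.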

#### Steps:
1. **Input Transformation**: `(B, T, 256)` → **Linear Layer** → `(B, T, 32)`
2. **Outer Product Mean**: `(B, T, 32)` → **Outer Product** → `(B, T, T, 32*32)` = `(B, T, T, 1024)`
3. **Output Transformation**: `(B, T, T, 1024)` → **Linear Layer** → `(B, T, T, 64)`

---

### Some History:
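The `OuterProductMean` operation was proposed in the paper [Highly accurate protein structure prediction with AlphaFold][1]. It works differently there: it updates the pair representation from the MSA (multiple sequence alignment) representation by averaging the outer products over the aligned sequences, hence the "Mean".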

[1]: https://www.nature.com/articles/s41586-021-03819-2

In [6]:
class Outer_Product_Mean(nn.Module):
    """
    Outer Product Mean class.
    :param in_dim: Dimensionality of the input sequence representations (default: 256).
    :param dim_msa: Intermediate lower-dimensional representation (default: 32).
    :param pairwise_dim: Final dimensionality of the pairwise output (default: 64).
    """
    def __init__(
        self,
        in_dim: int = 256,
        dim_msa: int = 32,
        pairwise_dim: int = 64
    ):
        super(Outer_Product_Mean, self).__init__()
        self.proj_down1 = nn.Linear(in_dim, dim_msa)  # projects the input sequence representation into a lower dimensional space
        self.proj_down2 = nn.Linear(dim_msa ** 2, pairwise_dim)  # projects the outer product representation (reshaped) to the final pairwise_dim.

    def forward(
        self,
        seq_rep: torch.Tensor,  # shape: (batch_size, seq_length, in_dim)
        pair_rep: torch.Tensor = None  # shape: (batch_size, seq_length, seq_length, pairwise_dim)
    ):
        seq_rep = self.proj_down1(seq_rep)  # output shape: (batch_size, seq_length, dim_msa)
        outer_product = torch.einsum('bid,bjc -> bijcd', seq_rep, seq_rep)  # output shape: (batch_size, seq_length, seq_length, dim_msa, dim_msa)
        outer_product = rearrange(outer_product, 'b i j c d -> b i j (c d)')  # flattens the last two dimensions: (batch_size, seq_length, seq_length, dim_msa * dim_msa).
        outer_product = self.proj_down2(outer_product)  # output shape: (batch_size, seq_length, seq_length, pairwise_dim)

        if pair_rep is not None:
            outer_product = outer_product + pair_rep

        return outer_product 

### **3. Pairwise Representation**
The Outer Product Mean layer creates a `T*T` grid representing all pairs. Including the channel dimension, the pairwise representation has shape `(B, T, T, 64)`.

### **4. Relative Positional Encoding**

In RNA, the absolute position of a nucleotide matters less than how far apart two nucleotides are: the same hairpin motif should fold in roughly the same way whether it starts at position 5 or position 50. So we don't need the conventional absolute positional encoding used in transformers.
 
What we need is the **Relative positions of the nucleotides**. 
- For example, in the sequence `AUGCAU`, the relative position of the first `A` (index 0) and the first `U` (index 1) is `1`, while the relative position of that `A` and `C` (index 3) is `3`.
- We need to encode the relative positions of the nucleotides.
- The interactions of these nucleotides are important for the structure of the RNA.

#### **Example**
Let's take an example of a sequence of length 4. The pair representation will be of size `4x4`.

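For illustration, we clip values less than -8 and greater than 8, so in total we have 17 possible values. (The `relpos` code below uses a wider window, clipping to `[-16, 16]`, i.e. 33 values.)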

We compute $a[i,j] = i - j$:

$$
\begin{bmatrix}
0 & -1 & -2 & -3 \\
1 & 0 & -1 & -2 \\
2 & 1 & 0 & -1 \\
3 & 2 & 1 & 0 \\
\end{bmatrix}
$$


These are integer relative positions, and we don't add them directly to the pairwise matrix. Instead, we encode each value as a one-hot vector: with 17 possible values, each becomes a one-hot vector of size 17.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/4.png" width=400 class="center">

1. `TxT` -> **One Hot Encoding** -> `TxTx17`
2. Add a **Linear layer** from `17` to pairwise representation size `(64)`. `TxTx17` -> `TxTx64`.
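The two steps above can be sketched for the 4-residue example with the figure's clip window of 8 (17 bins). This is an illustrative sketch: it uses `clamp` and `F.one_hot` instead of the comparison against `bin_values` used in the `relpos` class below (the result is the same one-hot encoding), and the names `max_rel`, `pair_dim` and `proj` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T, max_rel, pair_dim = 4, 8, 64  # clip window of 8 -> 17 bins, as in the figure

pos = torch.arange(T)
rel = pos[:, None] - pos[None, :]   # a[i, j] = i - j, shape (T, T)
rel = rel.clamp(-max_rel, max_rel)  # clip to [-8, 8]

# Shift to [0, 16] so each value selects one of the 17 one-hot bins.
onehot = F.one_hot(rel + max_rel, num_classes=2 * max_rel + 1).float()  # (T, T, 17)

proj = nn.Linear(2 * max_rel + 1, pair_dim)  # 17 -> 64
rel_embed = proj(onehot)  # (T, T, 64), to be added to the pairwise representation

print(rel)
print(onehot.shape, rel_embed.shape)
```

Using one-hot bins plus a linear layer lets the model learn an independent embedding for every clipped distance, instead of treating the distance as a single number.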



In [7]:
class relpos(nn.Module):
    """
    Implements relative positional encoding for sequence-based models.
    :param dim: (int) The output embedding dimension. Default is 64.
    """
    
    def __init__(self, dim: int = 64):
        super(relpos, self).__init__()
        self.linear = nn.Linear(33, dim)  # (33,) -> (dim,): one bin per clipped relative position in [-16, 16]

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Computes the relative positional encodings for a given sequence.

        :param src: Input tensor of shape (B, L, D), where:
            - B: Batch size
            - L: Sequence length
            - D: Feature dimension (ignored in this module)
        :return: Relative positional encoding of shape (L, L, dim)
        """
        L = src.shape[1]  # Sequence length
        res_id = torch.arange(L, device=src.device).unsqueeze(0)  # (1, L)
        
        device = res_id.device
        bin_values = torch.arange(-16, 17, device=device)  # (33,)

        d = res_id[:, :, None] - res_id[:, None, :]  # (1, L, L)
        bdy = torch.tensor(16, device=device)

        # Clipping the values within the range [-16, 16]
        d = torch.minimum(torch.maximum(-bdy, d), bdy)  # (1, L, L)

        # One-hot encoding of relative positions
        d_onehot = (d[..., None] == bin_values).float()  # (1, L, L, 33)

        assert d_onehot.sum(dim=-1).min() == 1  # Ensure proper one-hot encoding

        # Linear transformation to embedding space
        p = self.linear(d_onehot)  # (1, L, L, 33) -> (1, L, L, dim)

        return p.squeeze(0)  # (L, L, dim)


### **5. Add**
Now we add this to the pairwise representation, so that it carries information about the relative positions of the nucleotides.

### **5. Transformer Encoder**
The `ConvTransformerEncoderLayer` is the main layer of the `RibonanzaNet`.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/5.png" width=600 class="center">

### **a. 1D Convolution**

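A residual 1D convolution along the sequence. `nn.Conv1d` expects channels first, so the input is permuted from `(B, T, C)` to `(B, C, T)` and back; `padding=k // 2` keeps the sequence length unchanged for an odd kernel size `k`.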

```python
self.conv = nn.Conv1d(d_model, d_model, k, padding=k // 2)
src = src + self.conv(src.permute(0, 2, 1)).permute(0, 2, 1)  # Shape: (batch_size, seq_len, d_model)
```
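A self-contained version of the snippet above, assuming a kernel size of `k = 5` and `d_model = 256` (the kernel size is an assumption for illustration; check the model config for the real value):

```python
import torch
import torch.nn as nn

B, T, d_model, k = 2, 10, 256, 5  # k (kernel size) is an assumed value

conv = nn.Conv1d(d_model, d_model, k, padding=k // 2)  # odd k keeps length T

src = torch.randn(B, T, d_model)  # (B, T, C)
# Conv1d expects channels first: (B, T, C) -> (B, C, T) -> conv -> back to (B, T, C)
src = src + conv(src.permute(0, 2, 1)).permute(0, 2, 1)  # residual connection
print(src.shape)  # torch.Size([2, 10, 256])
```

The convolution mixes information between neighbouring nucleotides, which complements the global mixing done by attention.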

### **b. Linear**

A simple linear layer projects the pairwise representation into the attention layer. The pairwise representation has size `TxTx64`, and we project it to `T x T x num_heads`, i.e. one value per head for every pair.

This is added as a bias to the attention logits, so the MHA bias has size `T x T x num_heads`.
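A minimal sketch of turning the pairwise representation into a per-head attention bias. The layer name `pair_to_heads`, `n_head = 8` and `d_k = 32` are hypothetical choices for illustration; the point is the shape flow `(B, T, T, 64) → (B, T, T, n_head) → (B, n_head, T, T)` and that the bias is added before the softmax.

```python
import torch
import torch.nn as nn

B, T, pair_dim, n_head, d_k = 2, 10, 64, 8, 32  # n_head and d_k are assumed values

pair_rep = torch.randn(B, T, T, pair_dim)                 # (B, T, T, 64)
pair_to_heads = nn.Linear(pair_dim, n_head, bias=False)   # hypothetical name

# One scalar per head for every (i, j) pair, moved to (B, n_head, T, T)
attn_bias = pair_to_heads(pair_rep).permute(0, 3, 1, 2)

# Standard scaled dot-product logits for the same heads
q = torch.randn(B, n_head, T, d_k)
k = torch.randn(B, n_head, T, d_k)
logits = q @ k.transpose(-2, -1) / d_k ** 0.5  # (B, n_head, T, T)
attn = (logits + attn_bias).softmax(dim=-1)    # bias is added before the softmax
print(attn.shape)
```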

### **c. MHA**

Multi-head attention.

Scaled dot-product attention with multiple heads. This is standard attention, nothing special here.

We use the pairwise representation from the previous step as the attention bias.

```python
if mask is not None:
    attn = attn + mask  # Apply bias mask (B, nhead, L, L)
```

`mask` is the attention bias, i.e. the projected pairwise representation.

In [8]:
class ScaledDotProductAttention(nn.Module):
    '''
    Scaled Dot-Product Attention module, computing attention scores based on query and key similarity.
    '''
    
    def __init__(self, temperature: float, attn_dropout: float = 0.1) -> None:
        """
        Initializes the Scaled Dot-Product Attention module.
        
        :param temperature: Scaling factor for the dot product attention scores.
        :param attn_dropout: Dropout rate applied to attention weights.
        """
        super().__init__()
        self.temperature: float = temperature
        self.dropout: nn.Dropout = nn.Dropout(attn_dropout)

    def forward(
        self, 
        q: torch.Tensor,  # (B, nhead, L, d_k), i.e. (B, nH, T, C) where T is the sequence length and C the channels per head.
        k: torch.Tensor,  # (B, nhead, L, d_k)
        v: torch.Tensor,  # (B, nhead, L, d_v)
        mask: torch.Tensor | None = None,  # additive bias (B, nhead, L, L) or None
        attn_mask: torch.Tensor | None = None  # (B, 1, L, L) or None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the Scaled Dot-Product Attention.
        
        :param q: Query tensor of shape (B, nhead, L, d_k), where B is batch size, nhead is the number of attention heads,
                  L is the sequence length, and d_k is the key/query dimension.
        :param k: Key tensor of shape (B, nhead, L, d_k).
        :param v: Value tensor of shape (B, nhead, L, d_v), where d_v is the value dimension.
        :param mask: Optional additive bias tensor of shape (B, nhead, L, L); here, the pairwise representation projected to one value per head.
        :param attn_mask: Optional attention mask tensor of shape (B, 1, L, L), where -1 values indicate positions to mask.
        :return: Tuple containing:
            - output (torch.Tensor): The result of the attention mechanism, shape (B, nhead, L, d_v).
            - attn (torch.Tensor): Attention weights after softmax and dropout, shape (B, nhead, L, L).
        """
        
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature  # (B, nhead, L, L)
        
        if mask is not None:
            attn = attn + mask  # Apply bias mask (B, nhead, L, L)
        
        if attn_mask is not None:
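            # NOTE: float('-1e-9') is practically 0, so masked logits are barely suppressed;
            # a large negative value such as -1e9 was likely intended. Kept as-is here.
            attn = attn.float().masked_fill(attn_mask == -1, float('-1e-9'))  # Apply attention mask (B, nhead, L, L)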
        
        attn = self.dropout(F.softmax(attn, dim=-1))  # (B, nhead, L, L)
        output = torch.matmul(attn, v)  # (B, nhead, L, d_v)
        
        return output, attn

# MultiHead Attention[↑](#top) 

***

Multi-head attention repeats the attention computation several times in parallel, each time with different Query, Key and Value projections. Each of these "heads" can learn a different attention pattern over the input embeddings, and their outputs are combined to produce a more expressive representation.

<img src="https://kaggle-images.s3.us-west-2.amazonaws.com/introduction-to-transformers/multihead_attention.png" width="400" class="center">

### Input shapes

The inputs `q`, `k`, and `v` (query, key, value) have the following shapes:
- `q`: [bs, len_q, d_model]
- `k`: [bs, len_k, d_model]
- `v`: [bs, len_v, d_model]

Where:
- `bs` is the batch size (first dimension)
- `len_q` is the sequence length of the query
- `len_k` is the sequence length of the key
- `len_v` is the sequence length of the value (typically equal to `len_k`)
- `d_model` is the model's embedding dimension

The module then projects these inputs into multiple heads:
- Each head has dimension `d_k` for queries and keys
- Each head has dimension `d_v` for values
- There are `n_head` different attention heads

The attention scores have shape `[bs, n_head, len_q, len_k]`. In the standard implementation, the output has the same dimensionality as the input query: `[bs, len_q, d_model]`.

This is a standard multi-head attention design: vectors are projected into multiple subspaces, attention is calculated separately in each subspace, and the results are concatenated and projected back to the original dimension. In the version below, the output projection (`self.fc`), dropout, residual and layer norm are commented out, so the module returns the concatenated heads with shape `[bs, len_q, n_head * d_v]`.

### Important Note
The original RibonanzaNet model initially shared by Shujun has been changed slightly here: we are looking at the backbone of the new diffusion model, so some code is commented out.


In [14]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module
    :param d_model: The number of input features. or C num_channels in input.
    :param n_head: The number of heads to use.
    :param d_k: The dimensionality of the keys.
    :param d_v: The dimensionality of the values.
    :param dropout: The dropout rate to apply to the attention weights.
    """
    def __init__(
        self,
        d_model: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)  # (d_model) -> (n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)  # (d_model) -> (n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)  # (d_model) -> (n_head * d_v)
        # self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # (n_head * d_v) -> (d_model)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        # self.dropout = nn.Dropout(dropout)
        # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)


    def forward(
        self, 
        q: torch.Tensor,  # Shape: [batch_size, len_q, d_model] # B, T, C
        k: torch.Tensor,  # Shape: [batch_size, len_k, d_model]
        v: torch.Tensor,  # Shape: [batch_size, len_v, d_model]
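        mask: Optional[torch.Tensor] = None,  # Optional additive attention bias from the pairwise representation: [bs, n_head, len_q, len_k]
        src_mask: Optional[torch.Tensor] = None  # Optional padding mask: [bs, seq_len], 0 marks padding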
               
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # Returns (output, attention)

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        bs, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
    
        residual = q  # Shape: [bs, len_q, d_model]

        # Linear projections and reshape to multiple heads
        q = self.w_qs(q).view(bs, len_q, n_head, d_k)  # Shape: [bs, len_q, n_head, d_k]
        k = self.w_ks(k).view(bs, len_k, n_head, d_k)  # Shape: [bs, len_k, n_head, d_k]
        v = self.w_vs(v).view(bs, len_v, n_head, d_v)  # Shape: [bs, len_v, n_head, d_v]

        # Transpose for multi-head attention computation
        q, k, v = (
            q.transpose(1, 2),  # Shape: [bs, n_head, len_q, d_k]
            k.transpose(1, 2),  # Shape: [bs, n_head, len_k, d_k]
            v.transpose(1, 2)
        )  # Shape: [bs, n_head, len_v, d_v]

        if mask is not None:
            mask = mask  # Shape remains unchanged

        if src_mask is not None:
            src_mask=src_mask.clone().unsqueeze(-1).long()
            src_mask[src_mask==0]=-1
            src_mask=src_mask.float()
            #src_mask=src_mask.unsqueeze(-1)#.float()
            attn_mask=torch.matmul(src_mask,src_mask.permute(0,2,1)).unsqueeze(1).long()
            q, attn = self.attention(q, k, v, mask=mask,attn_mask=attn_mask)
        else:
            q, attn = self.attention(q, k, v, mask=mask)
        # the output from the attention is self_atten * v. It should not be called q here. Bad confusing variable naming.
        # Reshape back to original format
        q = q.transpose(1, 2).contiguous().view(bs, len_q, -1)  # Shape: [bs, len_q, n_head * d_v]
        # q = self.dropout(self.fc(q))  # Shape: [bs, len_q, d_model]
        # q += residual  # Shape: [bs, len_q, d_model]

        # q = self.layer_norm(q)  # Shape: [bs, len_q, d_model]

        return q, attn  # Output: [bs, len_q, n_head * d_v], Attention: [bs, n_head, len_q, len_k]


### **d. Position-wise Feedforward Network**

This is simply a feedforward network with **two linear** layers and a **GELU** activation function, followed by a residual connection and layer norm:
```python
# Position-wise Feedforward according to architecture diagram
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))  # Shape: (batch_size, seq_len, d_model)
src = src + self.dropout2(src2)
src = self.norm2(src)

```
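A runnable sketch of the same feedforward block, wrapping the snippet above in a module. The class name `FeedForward` and `dim_feedforward = 4 * d_model = 1024` are assumptions for illustration; the layer names mirror the snippet.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise FFN with residual + post-LayerNorm (hypothetical wrapper).
    dim_feedforward = 4 * d_model is an assumption."""
    def __init__(self, d_model: int = 256, dim_feedforward: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src: torch.Tensor) -> torch.Tensor:  # (B, T, d_model)
        # Applied independently at every position: d_model -> dim_feedforward -> d_model
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)  # residual connection
        return self.norm2(src)           # (B, T, d_model)

ffn = FeedForward()
out = ffn(torch.randn(2, 10, 256))
print(out.shape)  # torch.Size([2, 10, 256])
```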


### **e. Outer Product Mean**

The output of MHA is again a sequence representation of the original size, `BxTxC` or `BxTx256`.
Now we also have to update the pairwise representation, and we do so using the MHA output.

Simply take the Outer Product Mean:
#### Steps:
1. **Input Transformation**: `(B, T, 256)` → **Linear Layer** → `(B, T, 32)`
2. **Outer Product Mean**: `(B, T, 32)` → **Outer Product** → `(B, T, T, 32*32)` = `(B, T, T, 1024)`
3. **Output Transformation**: `(B, T, T, 1024)` → **Linear Layer** → `(B, T, T, 64)`


Then add this to the pairwise representation. This helps information flow from the sequence MHA to the pairwise representation.

### **f. Triangular Multiplicative Module**

Triangular Update:

According to my understanding, we aim to model relationships between pairs of nucleotides using the pairwise representation, where each pair is also influenced by a third nucleotide. All nucleotides influence each other, but the inspiration for this specific approach comes from geometry: in any consistent 3D structure, the distances between any three points satisfy the Triangle Inequality.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/7.png" width=400 class="center">

We are not explicitly using the Triangle Inequality but are inspired by its concept.

---

### **Think in Terms of Edges, Not Nucleotides**

The Pairwise Representation can be thought of as directed edge storage, where `a[i, j]` represents the edge from `i` to `j`.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/6.png" width=300 class="center">

Consider an edge from `i` to `j`:

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/8.png" width=300 class="center">

For this edge, there are two types of edges to consider:

1. **Incoming Edges**: Edges incoming to `i` and `j`.
2. **Outgoing Edges**: Edges outgoing from `i` and `j`.

The influence of these edges on the edge `i → j` is crucial. Therefore, we update the Pairwise Representation using these edges.

---

### **How Do We Update the Pairwise Representation?**

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/9.png" width=800 class="center">

> **Note**: Ignore the `x'` in the image. The edge `i → j` is the focus for now.

#### **Steps**:

1. Normalize the pairwise representation with a LayerNorm.
2. Take all the **left edges** and apply a linear layer to them.
3. Take all the **right edges** and apply a linear layer to them.
4. For each left and right edge, apply another linear layer followed by a sigmoid activation `(0, 1)`. These act as **gates**, and they are computed from the edges themselves.
5. Multiply the projected left and right edges by their gates. This determines how much of each edge's information passes through.
6. Multiply each left edge with the matching right edge (the pair of edges that shares the same third node `k`).
7. Sum these products over all `k`, i.e. over all triangles that contain the edge `i → j`.
8. Normalize the sum with a LayerNorm.
9. Take the **central edge**, apply a linear layer and then a sigmoid activation `(0, 1)`. This acts as an **output gate**: multiplying the sum by it determines how much influence the combined triangles have on the central edge.
10. Apply a final linear layer to project the result back to the pairwise representation size, and add it to the pairwise representation at `i, j` (the module returns only the update; the addition happens in the encoder layer). Essentially, we are updating the pairwise representation of `i → j` using the influence of all the edges.

Dive into the code to see which linear layer (`left_gate`, `right_gate` or `out_gate`) produces each gate.
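The heart of the update (multiplying matching left and right edges and summing over the third node `k`) is a single einsum. This sketch compares the notebook's `'outgoing'` einsum string with an explicit triple loop on random tensors; the LayerNorms, gates and projections are left out, and the sizes `T`, `d` are arbitrary.

```python
import torch

torch.manual_seed(0)
T, d = 5, 3
left = torch.randn(T, T, d)   # left[i, k]: projected (and gated) edge i -> k
right = torch.randn(T, T, d)  # right[j, k]: projected (and gated) edge j -> k

# 'outgoing' mix from TriangleMultiplicativeModule: sum over the third node k.
fast = torch.einsum("... i k d, ... j k d -> ... i j d", left, right)

slow = torch.zeros(T, T, d)
for i in range(T):
    for j in range(T):
        for k in range(T):
            slow[i, j] += left[i, k] * right[j, k]  # one triangle (i, j, k)

print(torch.allclose(fast, slow, atol=1e-5))  # True
```

The loop makes the "all triangles with respect to the edge `i → j`" idea explicit; the einsum computes the same thing in one vectorized call.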

---

### **Key Insight**

We are considering all the triangles with respect to the edge `i → j`.

In [16]:
def exists(val):
    return val is not None

def default(val, d):
    return val if val is not None else d

class TriangleMultiplicativeModule(nn.Module):
    """
    This class is applied to the pairwise residue representations. It is inspired by the triangle
    inequality (distances between residues should be mutually consistent) but does not enforce it.
    """
    def __init__(
        self,
        *,
        dim: int,
        hidden_dim: Optional[int] = None,
        mix: str = 'ingoing'
    ):
        super().__init__()
        assert mix in {'ingoing', 'outgoing'}, 'mix must be either ingoing or outgoing'

        hidden_dim = default(hidden_dim, dim)
        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)
        self.left_gate = nn.Linear(dim, hidden_dim)
        self.right_gate = nn.Linear(dim, hidden_dim)
        self.out_gate = nn.Linear(dim, hidden_dim)

        # Initialize all gates to a constant: weight 0 and bias 1 give sigmoid(1) ≈ 0.73 for every input
        for gate in (self.left_gate, self.right_gate, self.out_gate):
            nn.init.constant_(gate.weight, 0.)
            nn.init.constant_(gate.bias, 1.)

        if mix == 'outgoing':
            self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
        elif mix == 'ingoing':
            self.mix_einsum_eq = '... k j d, ... k i d -> ... i j d'

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(
        self,
        x: torch.Tensor,                  # (batch_size, seq_len, seq_len, dim)
        src_mask: Optional[torch.Tensor] = None  # (batch_size, seq_len)
    ) -> torch.Tensor:                    # Output: (batch_size, seq_len, seq_len, dim)
        if exists(src_mask):
            src_mask = src_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
            mask = torch.matmul(src_mask, src_mask.permute(0, 2, 1))  # (batch_size, seq_len, seq_len)
            mask = rearrange(mask, 'b i j -> b i j ()')  # (batch_size, seq_len, seq_len, 1)

        assert x.shape[1] == x.shape[2], 'feature map must be symmetrical'
        
        x = self.norm(x)  # (batch_size, seq_len, seq_len, dim)

        left = self.left_proj(x)  # (batch_size, seq_len, seq_len, hidden_dim)
        right = self.right_proj(x) # (batch_size, seq_len, seq_len, hidden_dim)

        if exists(src_mask):
            left = left * mask  # (batch_size, seq_len, seq_len, hidden_dim)
            right = right * mask # (batch_size, seq_len, seq_len, hidden_dim)

        left_gate = self.left_gate(x).sigmoid()   # (batch_size, seq_len, seq_len, hidden_dim)
        right_gate = self.right_gate(x).sigmoid() # (batch_size, seq_len, seq_len, hidden_dim)
        out_gate = self.out_gate(x).sigmoid()     # (batch_size, seq_len, seq_len, hidden_dim)

        left = left * left_gate   # (batch_size, seq_len, seq_len, hidden_dim)
        right = right * right_gate # (batch_size, seq_len, seq_len, hidden_dim)

        out = einsum(self.mix_einsum_eq, left, right)  # (batch_size, seq_len, seq_len, hidden_dim)

        out = self.to_out_norm(out)  # (batch_size, seq_len, seq_len, hidden_dim)
        out = out * out_gate         # (batch_size, seq_len, seq_len, hidden_dim)
        return self.to_out(out)      # (batch_size, seq_len, seq_len, dim)


### **g. Triangular Update**

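The same update is applied a second time using the incoming edges (`mix='ingoing'`): the einsum contracts over the row index `k` (edges `k → i` and `k → j`) instead of the column index, i.e. column-wise in the code instead of row-wise.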
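For comparison with the outgoing case, this sketch checks the notebook's `'ingoing'` einsum string against an explicit loop, again on random tensors with gates and normalization omitted:

```python
import torch

torch.manual_seed(0)
T, d = 5, 3
left = torch.randn(T, T, d)   # left[k, j]: edge k -> j
right = torch.randn(T, T, d)  # right[k, i]: edge k -> i

# 'ingoing' mix from TriangleMultiplicativeModule: contract over the row (source) index k.
fast = torch.einsum("... k j d, ... k i d -> ... i j d", left, right)

slow = torch.zeros(T, T, d)
for i in range(T):
    for j in range(T):
        for k in range(T):
            slow[i, j] += left[k, j] * right[k, i]  # edges k -> j and k -> i

print(torch.allclose(fast, slow, atol=1e-5))  # True
```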

### **Bonus: Triangular Attention**

This is not actually used in RibonanzaNet. It is present in the code but disabled in the original notebook. Since we are here, though, let's understand it.

Standard attention captures relationships between tokens, here the nucleotides. But we need to capture relationships between edges. Each edge is influenced by other edges, i.e. by other nucleotides, so to capture the influence of another nucleotide on an edge we need at least 3 nucleotides. Hence the name **Triangular Attention**.

The AlphaFold paper also motivates this with the triangle inequality, but I find the explanation above more intuitive.

Pairwise representation is a matrix of size `T*T`, acting as a storage of directed edges. `a[i, j]` is the edge from `i` to `j`.

For each edge, we have a starting node and an ending node.  
We will calculate two attentions:  
1. Around the starting node  
2. Around the ending node  

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/10.png" width=400 class="center">

Brush up on your basics of attention. It will be needed here.  
<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/11.png" width=600 class="center">

**Around the starting node:**  

For understanding, let's just take the one edge `i → j` as the Central Edge. We will calculate the attention around the starting node `i`, specifically for the edge `i → j`.

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/12.png" width=600 class="center">

Doing the same for every edge gives the attention around the starting node.

**What does this mean?**  
<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/13.png" width=400 class="center">

The same process is repeated around the ending node (column-wise in the code).
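To connect the diagrams to the einsum strings in the code, this sketch computes row-wise attention ("around the starting node") for all edges at once and then recomputes it explicitly for one edge `r → i`. It is a simplified sketch: a single head, no gating and no masking; `q`, `k`, `v` and `bias` are random stand-ins for the projected pair representation. As in `TriangleAttention`, the bias for the pair `(i, j)` is shared across all starting nodes `r`.

```python
import torch

torch.manual_seed(0)
T, dim = 4, 8
q = torch.randn(T, T, dim)  # q[r, i]: query of edge r -> i
k = torch.randn(T, T, dim)  # k[r, j]: key of edge r -> j
v = torch.randn(T, T, dim)  # v[r, j]: value of edge r -> j
bias = torch.randn(T, T)    # bias[i, j] from the pair representation (one head)
scale = dim ** 0.5

# All edges at once (single-head version of 'brihd,brjhd->brijh'):
logits = torch.einsum("rid,rjd->rij", q, k) / scale + bias  # bias broadcast over r
attn = logits.softmax(dim=-1)                # normalize over j
out = torch.einsum("rij,rjd->rid", attn, v)  # (T, T, dim)

# The same computation for one edge r -> i, written out explicitly:
r, i = 0, 1
scores = torch.stack([q[r, i] @ k[r, j] / scale + bias[i, j] for j in range(T)])
weights = scores.softmax(dim=0)
edge_out = sum(weights[j] * v[r, j] for j in range(T))

print(torch.allclose(out[r, i], edge_out, atol=1e-5))  # True
```

The edge `r → i` only attends to edges that share its starting node `r`, and the bias term `bias[i, j]` is the third edge `i → j` that closes the triangle.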

In [15]:
class TriangleAttention(nn.Module):
    def __init__(
        self,
        in_dim: int = 128,
        dim: int = 32,
        n_heads: int = 4,
        wise: Literal['row', 'col'] = 'row'
    ):
        """
        Implements Triangle Attention Mechanism.
        :param in_dim: Input feature dimension.
        :param dim: Dimension of query, key, and value per head.
        :param n_heads: Number of attention heads.
        :param wise: Whether to apply row-wise or column-wise attention.
        """
        super(TriangleAttention, self).__init__()
        self.n_heads = n_heads
        self.wise = wise
        self.norm = nn.LayerNorm(in_dim)
        self.to_qkv = nn.Linear(in_dim, dim * 3 * n_heads, bias=False)
        self.linear_for_pair = nn.Linear(in_dim, n_heads, bias=False)
        self.to_gate = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.Sigmoid()
        )
        self.to_out = nn.Linear(n_heads * dim, in_dim)

    def forward(self, z: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for TriangleAttention.
        :param z: Input tensor of shape (B, I, J, in_dim). I and J are both the sequence length T, so this is (B, T, T, C).
        :param src_mask: Sequence (padding) mask of shape (B, T).
        :return: Output tensor of shape (B, I, J, in_dim).
        """
        # Spawn pair mask
        src_mask = src_mask.clone()
        src_mask[src_mask == 0] = -1
        src_mask = src_mask.unsqueeze(-1).float()  # (B, T, 1)
        attn_mask = torch.matmul(src_mask, src_mask.permute(0, 2, 1))  # (B, T, T)

        wise = self.wise
        z = self.norm(z)  # (B, I, J, in_dim)

        # Compute bias and gate
        gate = self.to_gate(z)  # [1] (B, I, J, in_dim)
        b = self.linear_for_pair(z)  # [5] (B, I, J, n_heads) 

        # Compute Q, K, V
        q, k, v = torch.chunk(self.to_qkv(z), 3, -1)  # [2], [3], [4]: each (B, I, J, n_heads * dim)
        q, k, v = map(lambda x: rearrange(x, 'b i j (h d)->b i j h d', h=self.n_heads), (q, k, v))  
        # Each: (B, I, J, n_heads, dim)
        scale = q.size(-1) ** 0.5  # Scalar

        if wise == 'row':
            eq_attn = 'brihd,brjhd->brijh'
            eq_multi = 'brijh,brjhd->brihd'
            b = rearrange(b, 'b i j (r h)->b r i j h', r=1)  # (B, 1, I, J, n_heads)
            softmax_dim = 3
            attn_mask = rearrange(attn_mask, 'b i j->b 1 i j 1')  # (B, 1, I, J, 1)
        elif wise == 'col':
            eq_attn = 'bilhd,bjlhd->bijlh'
            eq_multi = 'bijlh,bjlhd->bilhd'
            b = rearrange(b, 'b i j (l h)->b i j l h', l=1)  # (B, I, J, 1, n_heads)
            softmax_dim = 2
            attn_mask = rearrange(attn_mask, 'b i j->b i j 1 1')  # (B, I, J, 1, 1)
        else:
            raise ValueError('wise should be col or row!')

        # Compute attention logits
        logits = (torch.einsum(eq_attn, q, k) / scale + b)  # [6], [7] (B, I, J, J, n_heads) for 'row', (B, I, I, J, n_heads) for 'col'
        # Masked logits must be very negative (-1e-9 would barely change them) so softmax gives them ~0 weight
        logits = logits.masked_fill(attn_mask == -1, torch.finfo(logits.dtype).min)  # Apply mask

        # Compute attention weights
        attn = logits.softmax(softmax_dim)  # [8] same shape as logits

        # Compute attention output
        out = torch.einsum(eq_multi, attn, v)  # [9] (B, I, J, n_heads, dim)
        out = gate * rearrange(out, 'b i j h d-> b i j (h d)')  # [10] (B, I, J, in_dim)

        # Final projection
        z_ = self.to_out(out)  # (B, I, J, in_dim)

        return z_
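The pair-mask trick at the top of `forward` is easy to misread, so here is a small NumPy check of it, assuming a single sequence of length 4 whose last token is padding. It also shows why the fill value for masked logits has to be a large negative number: softmax is shift-invariant, so a fill of `-1e-9` leaves the masked entry with a sizeable weight, while `-1e9` drives it to zero. `masked_softmax` is a hypothetical helper for the demo.

```python
import numpy as np

src_mask = np.array([[1, 1, 1, 0]])                                  # (B=1, L=4); last token is padding
m = np.where(src_mask == 0, -1, src_mask)[..., None].astype(float)   # (B, L, 1), padding -> -1
attn_mask = m @ m.transpose(0, 2, 1)                                 # (B, L, L) outer product
# Row 3 / column 3 are -1 except the pad-pad entry (3, 3), which is +1 and therefore not masked

def masked_softmax(logits, keep, fill):
    # Replace masked logits by `fill`, then apply a numerically stable softmax
    x = np.where(keep, logits, fill)
    w = np.exp(x - x.max())
    return w / w.sum()

logits = np.array([2.0, 1.0, 0.5, 3.0])
keep = attn_mask[0, 0] != -1                 # query position 0: key 3 (padding) is masked
w_tiny = masked_softmax(logits, keep, -1e-9) # key 3 still gets ~8% of the weight
w_big = masked_softmax(logits, keep, -1e9)   # key 3 gets exactly 0 weight
```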


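### **h. Transition**

Nothing special here: a LayerNorm followed by two linear layers with a ReLU in between, expanding to 4× the pairwise dimension and projecting back (`pair_transition` in the code below).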

## Transformer Encoder Block / ConvTransformer

The output will be
1. Sequence Representation: `B x T x 256`
2. Pairwise Representation: `B x T x T x 64`

In [19]:
class ConvTransformerEncoderLayer(nn.Module):
    """
    A Transformer Encoder Layer with convolutional enhancements and pairwise feature processing.
    """
    
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        pairwise_dimension: int,
        use_triangular_attention: bool,
        dim_msa: int,
        dropout: float = 0.1,
        k: int = 3,
    ):
        """
        :param d_model: Dimension of the input embeddings
        :param nhead: Number of attention heads
        :param dim_feedforward: Hidden layer size in feedforward network
        :param pairwise_dimension: Dimension of pairwise features
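        :param use_triangular_attention: Whether to use triangular attention modules
        :param dim_msa: Projection dimension used inside the outer product mean
        :param dropout: Dropout rate
        :param k: Kernel size for the 1D convolution (currently unused: the convolution is commented out)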
        """
        super(ConvTransformerEncoderLayer, self).__init__()

        # === Attention Layers ===
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model // nhead, d_model // nhead, dropout=dropout)

        # self.linear1 = nn.Linear(d_model, dim_feedforward)
        # self.dropout = nn.Dropout(dropout)
        # self.linear2 = nn.Linear(dim_feedforward, d_model)

        # === Layer Norms ===
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # self.norm3 = nn.LayerNorm(d_model)
        
        # === Dropout Layers ===
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # self.dropout3 = nn.Dropout(dropout)

        self.pairwise2heads = nn.Linear(pairwise_dimension, nhead, bias=False)
        self.pairwise_norm = nn.LayerNorm(pairwise_dimension)
        self.activation = nn.GELU()

        # self.conv = nn.Conv1d(d_model, d_model, k, padding=k // 2)

        self.triangle_update_out = TriangleMultiplicativeModule(dim=pairwise_dimension, mix='outgoing')
        self.triangle_update_in = TriangleMultiplicativeModule(dim=pairwise_dimension, mix='ingoing')

        self.pair_dropout_out = DropoutRowwise(dropout)
        self.pair_dropout_in = DropoutRowwise(dropout)

        self.use_triangular_attention = use_triangular_attention
        if self.use_triangular_attention:
            self.triangle_attention_out = TriangleAttention(
                in_dim=pairwise_dimension,
                dim=pairwise_dimension // 4,
                wise='row'
            )
            self.triangle_attention_in = TriangleAttention(
                in_dim=pairwise_dimension,
                dim=pairwise_dimension // 4,
                wise='col'
            )

            self.pair_attention_dropout_out = DropoutRowwise(dropout)
            self.pair_attention_dropout_in = DropoutColumnwise(dropout)

        self.outer_product_mean=Outer_Product_Mean(in_dim=d_model,dim_msa=dim_msa,pairwise_dim=pairwise_dimension)

        self.pair_transition = nn.Sequential(
            nn.LayerNorm(pairwise_dimension),
            nn.Linear(pairwise_dimension, pairwise_dimension * 4),
            nn.ReLU(inplace=True),
            nn.Linear(pairwise_dimension * 4, pairwise_dimension)
        )
        # Sequence transition is new
        self.sequence_transititon=nn.Sequential(nn.Linear(d_model,d_model*4),
                                                nn.ReLU(),
                                                nn.Linear(d_model*4,d_model))


    def forward(
        self,
        input
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the ConvTransformerEncoderLayer.

        :param input: Tuple (src, pairwise_features, src_mask, return_aw), where
            src: tensor of shape (batch_size, seq_len, d_model)
            pairwise_features: tensor of shape (batch_size, seq_len, seq_len, pairwise_dimension)
            src_mask: mask tensor of shape (batch_size, seq_len), 1 for real tokens and 0 for padding
            return_aw: whether to also return the attention weights
        :return: Tuple containing processed src and pairwise_features (and optionally attention weights)
        """
        src , pairwise_features, src_mask, return_aw= input

        use_gradient_checkpoint=False

        # src = src * src_mask.float().unsqueeze(-1)  # Shape: (batch_size, seq_len, d_model)
        # res = src  # residual
        # # 1D convolution
        # src = src + self.conv(src.permute(0, 2, 1)).permute(0, 2, 1)  # Shape: (batch_size, seq_len, d_model)
        # src = self.norm3(src)

        # Linear on Pairwise features
        pairwise_bias = self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0, 3, 1, 2) # Shape: (batch_size, n_head, seq_len, seq_len)
        # MHA + Pairwise mask
        src2, attention_weights = self.self_attn(src, src, src, mask=pairwise_bias, src_mask=src_mask)  # Shape: (batch_size, seq_len, d_model)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        
        # Position-wise feedforward (the "sequence transition" in the architecture diagram)
        res=src
        src=self.sequence_transititon(src)
        src = res + self.dropout2(src)
        src = self.norm2(src)
        
        pairwise_features = pairwise_features + self.outer_product_mean(src)  # Shape: (batch_size, seq_len, seq_len, pairwise_dimension)
        #Triangular update
        pairwise_features = pairwise_features + self.pair_dropout_out(self.triangle_update_out(pairwise_features, src_mask))
        pairwise_features = pairwise_features + self.pair_dropout_in(self.triangle_update_in(pairwise_features, src_mask))
        
        if self.use_triangular_attention:
            pairwise_features = pairwise_features + self.pair_attention_dropout_out(self.triangle_attention_out(pairwise_features, src_mask))
            pairwise_features = pairwise_features + self.pair_attention_dropout_in(self.triangle_attention_in(pairwise_features, src_mask))
        
        pairwise_features = pairwise_features + self.pair_transition(pairwise_features)  # Shape: (batch_size, seq_len, seq_len, pairwise_dimension)

        if return_aw:
            return src, pairwise_features, attention_weights  # Shapes: (batch_size, seq_len, d_model), (batch_size, seq_len, seq_len, pairwise_dimension), (batch_size, nhead, seq_len, seq_len)
        else:
            return src, pairwise_features  # Shapes: (batch_size, seq_len, d_model), (batch_size, seq_len, seq_len, pairwise_dimension)


In [20]:
# generate a dummy input for ConvTransformerEncoderLayer and run the forward pass
# if __name__ == "__main__":
# Dummy input
batch_size = 2
seq_len = 10
d_model = 128
dim_msa = 32
pairwise_dimension = 64
nhead = 4
dim_feedforward = 256
use_triangular_attention = True

# Create a random input tensor
src = torch.randn(batch_size, seq_len, d_model)
pairwise_features = torch.randn(batch_size, seq_len, seq_len, pairwise_dimension)
src_mask = torch.ones(batch_size, seq_len)

# Create the ConvTransformerEncoderLayer instance
layer = ConvTransformerEncoderLayer(d_model, nhead, dim_feedforward, pairwise_dimension, use_triangular_attention, dim_msa)

# Forward pass
output = layer((src, pairwise_features, src_mask, False))
print(output[0].shape)  # Output shape: (batch_size, seq_len, d_model)

[2, 10, 10, 64]
[2, 1, 10, 64]
[2, 10, 10, 64]
[2, 1, 10, 64]
[2, 10, 10, 64]
[2, 1, 10, 64]
[2, 10, 10, 64]
[2, 10, 1, 64]
torch.Size([2, 10, 128])


### **7. Reactivities Head**

Simply a linear layer (`self.decoder` in the backbone code below) that projects the sequence representation to the output size, i.e. whatever you want to predict. The important part is the backbone above.

The authors trained the model on reactivity data, secondary structure data and a few other targets, and this worked well.

## **Some Important Helper modules**

### Dropout

***

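For the pairwise representation, dropout is structured: the dropout mask is shared along one axis instead of being sampled independently for every entry. `DropoutRowwise` shares the mask across the row dimension (dim -3), so every row gets the same mask; `DropoutColumnwise` shares it across the column dimension (dim -2). In the encoder layer, the triangle multiplicative updates and the row-wise triangle attention use `DropoutRowwise`, while the column-wise triangle attention (around the ending node) uses `DropoutColumnwise`.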
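A minimal NumPy sketch of the idea, assuming inverted-dropout scaling by `1/(1-rate)` as `nn.Dropout` does: the mask is drawn with size 1 along the shared axis and broadcast, so whole rows (or columns) receive identical masks. `shared_dropout` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def shared_dropout(x, rate, shared_dim, rng):
    """Dropout whose mask is shared (broadcast) along `shared_dim`."""
    shape = list(x.shape)
    shape[shared_dim] = 1                          # one mask slice reused along this axis
    keep = rng.random(shape) >= rate               # Bernoulli keep-mask
    return x * keep / (1.0 - rate)                 # inverted-dropout scaling

rng = np.random.default_rng(0)
pair = np.ones((1, 6, 6, 4))                       # (B, I, J, C)
row_wise = shared_dropout(pair, 0.5, -3, rng)      # mask shape (1, 1, 6, 4): identical for every row i
col_wise = shared_dropout(pair, 0.5, -2, rng)      # mask shape (1, 6, 1, 4): identical for every column j
```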

In [10]:
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
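        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        print(shape)  # debug: full tensor shape, e.g. [2, 10, 10, 64]
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1  # the mask has size 1 along the shared dimension(s)
        print(shape)  # debug: mask shape, e.g. [2, 1, 10, 64] for row-wise dropout
        mask = x.new_ones(shape)
        mask = self.dropout(mask)  # nn.Dropout zeroes entries and rescales the rest by 1/(1-r)
        x = x * mask  # broadcasting reuses the same mask along the shared dimension(s)
        return x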


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)

### Utils

In [22]:
class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)
        

def default(val: Any, d: Any) -> Any:
    """
    Returns `val` if it is not None, otherwise returns the default value `d`.
    :param val: The primary value.
    :param d: The default value to return if `val` is None.
    :return: `val` if it is not None, otherwise `d`.
    """
    return val if exists(val) else d


def exists(val: Any) -> bool:
    """
    Checks whether a given value is not None.
    :param val: The value to check.
    :return: True if `val` is not None, otherwise False.
    """
    return val is not None


def init_weights(m: torch.nn.Module) -> None:
    """
    Initializes the weights of a given module if it is an instance of `torch.nn.Linear`. 
    Currently, the function does not apply any initialization but has commented-out 
    Xavier initialization methods.
    :param m: The module to initialize, expected to be a `torch.nn.Linear` instance.
    :return: None
    """
    if m is not None and isinstance(m, nn.Linear):
        pass


def load_config_from_yaml(file_path):
    """Load YAML file"""
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)


def sep():
    print("—"*100)
    
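def recursive_linear_init(m,scale_factor):
    """Apply custom_weight_init to every submodule of `m` whose name does not contain 'gate'."""
    for child_name, child in m.named_modules():
        if 'gate' not in child_name:
            custom_weight_init(child,scale_factor)
            
def custom_weight_init(m, scale_factor):
    """For nn.Linear: uniform weights in [-scale_factor/sqrt(in_features), +scale_factor/sqrt(in_features)], zero bias."""
    if isinstance(m, nn.Linear):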
        d_model = m.in_features  # Set d_model to the input dimension of the linear layer
        upper = 1.0 / (d_model ** 0.5) * scale_factor
        lower = -1.0 / (d_model ** 0.5) * scale_factor
        torch.nn.init.uniform_(m.weight, lower, upper)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
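A quick worked example of what this init does, assuming `in_features = 256` (the sequence dimension quoted above): `RibonanzaNet.__init__` passes `scale_factor = 1/sqrt(i+1)` for layer `i`, so deeper layers start with smaller weights. `init_bound` is a hypothetical helper that just reproduces the arithmetic.

```python
import math

def init_bound(in_features, layer_index):
    """Half-width of the uniform init used by custom_weight_init for 0-based layer `layer_index`."""
    scale_factor = 1 / (layer_index + 1) ** 0.5       # as in RibonanzaNet.__init__
    return 1.0 / math.sqrt(in_features) * scale_factor

bounds = [init_bound(256, i) for i in range(4)]        # layer 0 -> 0.0625, layer 3 -> 0.03125
```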

# RibonanzaNet Backbone

In [None]:
import torch.utils.checkpoint as checkpoint

class RibonanzaNet(nn.Module):

    #def __init__(self, ntoken=5, nclass=1, ninp=512, nhead=8, nlayers=9, kmers=9, dropout=0):
    def __init__(self, config):

        super(RibonanzaNet, self).__init__()
        self.config=config
        nhid=config.ninp*4
        self._tied_weights_keys = [] #avoids AttributeError: 'RibonanzaNet' object has no attribute '_tied_weights_keys'

        self.transformer_encoder = []
        print(f"constructing {config.nlayers} ConvTransformerEncoderLayers")
        for i in range(config.nlayers):
            if i!= config.nlayers-1:
                k=5
            else:
                k=1
            #print(k)
            self.transformer_encoder.append(ConvTransformerEncoderLayer(d_model = config.ninp, nhead = config.nhead,
                                                                        dim_feedforward = nhid, 
                                                                        pairwise_dimension= config.pairwise_dimension,
                                                                        use_triangular_attention=config.use_triangular_attention,
                                                                        dim_msa=config.dim_msa,
                                                                        dropout = config.dropout, k=k))
        self.transformer_encoder= nn.ModuleList(self.transformer_encoder)
        
        for i,layer in enumerate(self.transformer_encoder):
            scale_factor=1/(i+1)**0.5
            #scale_factor=i+1
            #scale_factor=0
            recursive_linear_init(layer,scale_factor)
        
        self.encoder = nn.Embedding(config.ntoken, config.ninp, padding_idx=4)
        self.decoder = nn.Linear(config.ninp,config.nclass)
        recursive_linear_init(self.decoder,scale_factor)
        
        self.outer_product_mean=Outer_Product_Mean(in_dim=config.ninp,dim_msa=config.dim_msa,pairwise_dim=config.pairwise_dimension)
        self.pos_encoder=relpos(config.pairwise_dimension)
        self.use_gradient_checkpoint=False
        
    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(inputs[0])
            return inputs
        return custom_forward

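    def forward(self, src,src_mask=None,return_aw=False):
        B,L=src.shape
        if src_mask is None:
            src_mask=torch.ones_like(src)  # treat every position as a real token
        src = self.encoder(src).reshape(B,L,-1)
        
        #spawn outer product
        # outer_product = torch.einsum('bid,bjc -> bijcd', src, src)
        # outer_product = rearrange(outer_product, 'b i j c d -> b i j (c d)')
        # print(outer_product.shape)
        pairwise_features=self.outer_product_mean(src)
        pairwise_features=pairwise_features+self.pos_encoder(src)
        # print(pairwise_features.shape)
        # exit()

        attention_weights=[]
        for i,layer in enumerate(self.transformer_encoder):
            # Each layer takes a single tuple (src, pairwise_features, src_mask, return_aw)
            if return_aw:
                src,pairwise_features,aw=layer((src, pairwise_features, src_mask, True))
                attention_weights.append(aw)
            else:
                src,pairwise_features=layer((src, pairwise_features, src_mask, False))

        # "+ pairwise_features.mean()*0" adds zero but keeps the last layer's pairwise branch in the
        # autograd graph, presumably so every parameter gets a gradient (e.g. under DistributedDataParallel)
        output = self.decoder(src).squeeze(-1)+pairwise_features.mean()*0

        if return_aw:
            return output, attention_weights
        else:
            return output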
        
    def get_embeddings(self, src,src_mask=None,return_aw=False):
        B,L=src.shape
        if src_mask is None:
            src_mask=torch.ones_like(src)  # treat every position as a real token
        src = self.encoder(src).reshape(B,L,-1)
        
        #spawn outer product
        # outer_product = torch.einsum('bid,bjc -> bijcd', src, src)
        # outer_product = rearrange(outer_product, 'b i j c d -> b i j (c d)')
        # print(outer_product.shape)
        if self.use_gradient_checkpoint:
            #print("using grad checkpointing")
            pairwise_features=checkpoint.checkpoint(self.custom(self.outer_product_mean), src)
            pairwise_features=pairwise_features+self.pos_encoder(src)
        else:
            pairwise_features=self.outer_product_mean(src)
            pairwise_features=pairwise_features+self.pos_encoder(src)
        # print(pairwise_features.shape)
        # exit()

        attention_weights=[]
        for i,layer in enumerate(self.transformer_encoder):
            # src,pairwise_features=layer(src, pairwise_features, src_mask,return_aw=return_aw,use_gradient_checkpoint=self.use_gradient_checkpoint)
            src,pairwise_features=checkpoint.checkpoint(self.custom(layer), [src, pairwise_features, src_mask, return_aw],use_reentrant=False)#print(src.shape)
        #output = self.decoder(src).squeeze(-1)+pairwise_features.mean()*0


        return src, pairwise_features

That completes the RibonanzaNet backbone. Next we look at RibonanzaNet 2.0, which changes parts of the architecture and adds some new modules.

# RibonanzaNet 2.0
***

# Diffusion Basics

***

A diffusion model is a generative model that learns to reverse a gradual noising process, transforming random noise into structured data such as images or molecular structures. It works by simulating a Markov chain that progressively adds noise to data and then learns to denoise it step by step.

---

## Training Process

We choose a number of steps `T` and a noise schedule $\beta_t \in (0, 1)$ for each step $t$, then set $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$.

We need $\bar{\alpha}_t$ for the forward diffusion process: it lets us jump from $x_0$ to any $x_t$ in a single step instead of adding noise $t$ times.

**Forward Diffusion Process Equation:**

$$
x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
$$

Where:

- $x_0$ is the original data (here, the ground-truth coordinates).
- $x_t$ is the noisy version of the data at step $t$.
- $\epsilon$ is the noise added to the data.
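Why is a single $\bar{\alpha}_t$ enough? In the standard DDPM formulation each step adds a little noise, $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t$. Composing two steps and merging the two independent Gaussian noise terms (their variances add) gives, in distribution:

$$
\begin{aligned}
x_t &= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-1}\right) + \sqrt{1-\alpha_t}\,\epsilon_t \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t)}\;\epsilon \\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\;\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
$$

Repeating this down to $x_0$ multiplies all the $\alpha_s$ together, which is exactly the forward diffusion equation with $\bar{\alpha}_t$.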

Now we build a model that takes $x_t$ and $t$ as input and predicts the noise $\epsilon$.

We train it with a simple MSE loss between the actual noise $\epsilon$ and the predicted noise $\hat{\epsilon}$.
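The training-side quantities can be sketched in a few lines of NumPy. This assumes `T = 1000` and a linear schedule from `1e-4` to `0.02` (the DDPM paper's values; the model below reads `n_times`, `beta_min` and `beta_max` from its config instead), and uses 0-based step indices. `q_sample` and `mse` are illustrative names.

```python
import numpy as np

T = 1000                                       # assumed number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)             # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)                # \bar{alpha}_t at 0-based index t
sqrt_alpha_bars = np.sqrt(alpha_bars)
sqrt_one_minus_alpha_bars = np.sqrt(1.0 - alpha_bars)

def q_sample(x0, t, eps):
    """Forward diffusion: jump from x0 straight to x_t."""
    return sqrt_alpha_bars[t] * x0 + sqrt_one_minus_alpha_bars[t] * eps

def mse(eps, eps_hat):
    return np.mean((eps - eps_hat) ** 2)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(10, 3))                  # 10 nucleotides, xyz coordinates
t = 500
eps = rng.normal(size=x0.shape)
x_t = q_sample(x0, t, eps)
baseline = mse(eps, np.zeros_like(eps))        # loss of a "model" that always predicts zero noise
```

The signal and noise coefficients always satisfy $\bar{\alpha}_t + (1-\bar{\alpha}_t) = 1$, so $x_t$ keeps roughly unit variance while the signal fraction shrinks towards zero.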

---

## Inference Process

- Start with a random noise $x_T$.
- At each step $t$ (from $T$ down to $1$):
    - Use the trained model to predict the noise $\epsilon$ given $x_t$ and $t$.
    - Compute $x_{t-1}$ using the reverse diffusion process, removing the predicted noise from $x_t$.
- Repeat this process iteratively until reaching $x_0$.
- The final $x_0$ is the generated data (e.g., predicted coordinates).

---
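The reverse step can be sketched with the standard DDPM update $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\epsilon}\right) + \sigma_t z$, assuming $\sigma_t = \sqrt{\beta_t}$ (one of the two choices in the DDPM paper) and the same assumed linear schedule as above. `predict_noise` is a hypothetical stand-in for the trained model.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)             # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def p_sample_step(x_t, t, eps_hat, rng):
    """One DDPM reverse step x_t -> x_{t-1} (0-based t), with sigma_t = sqrt(beta_t)."""
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
    if t == 0:
        return mean                            # no noise is added on the final step
    return mean + np.sqrt(betas[t]) * rng.normal(size=x_t.shape)

def sample(predict_noise, shape, rng):
    """Start from pure noise x_T and denoise step by step down to x_0."""
    x = rng.normal(size=shape)
    for t in reversed(range(T)):
        x = p_sample_step(x, t, predict_noise(x, t), rng)
    return x

rng = np.random.default_rng(0)
xyz = sample(lambda x, t: np.zeros_like(x), (10, 3), rng)   # hypothetical dummy noise predictor
```

A useful sanity check: on the last step, if the model predicts the true noise exactly, the update returns $x_0$ exactly.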

Keep in mind these terms:
- $\epsilon$ is the noise added to the data.
- $x_t$ is the noisy version of the data at step $t$.
- $x_0$ is the original data (here, the ground-truth coordinates).
- $t$ is the time step.
- $\beta_t$ is the noise schedule for each step $t$.
- $\alpha_t = 1- \beta_t$.
- $\bar{\alpha}_t$ is the cumulative product of $\alpha_t$ up to time step $t$.
- Forward diffusion formula.

Together, these are the components of the diffusion model.

# RibonanzaNet 2.0 Architecture

<img src="https://raw.githubusercontent.com/siddhantoon/storage/main/RibonanzaNet-2.0-DDPM-explained/14.png" width=900 class="center">

### Time Embedder
A sinusoidal positional embedding that maps the diffusion timestep to a `768`-dimensional vector.
It takes an integer diffusion timestep `t` and builds the `768`-dim encoding for that single timestep: the standard transformer positional encoding, applied to one timestep instead of a sequence of positions.

In a standard transformer, the position is the index of a token in the sequence. Here `t` is the diffusion timestep, so the same embedding is added to every element of the sequence. Don't confuse it with the positional encoding of the sequence itself.

In [None]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

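    def forward(self, x):
        # x: (B,) diffusion timesteps -> (B, dim) embedding
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)  # frequencies from 1 down to 1/10000
        emb = x[:, None] * emb[None, :]  # (B, half_dim): timestep times each frequency
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # first half sin, second half cos
        return emb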
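A NumPy mirror of `SinusoidalPosEmb` makes the broadcasting explicit: each sample gets one embedding vector for its timestep, and after `.unsqueeze(1)` (here `[:, None, :]`) that vector is added to every position of the sequence. The zero `seq_features` tensor is a placeholder assumption.

```python
import math
import numpy as np

def sinusoidal_time_embedding(t, dim):
    """NumPy mirror of SinusoidalPosEmb.forward: t has shape (B,), output has shape (B, dim)."""
    half_dim = dim // 2
    freqs = np.exp(np.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

B, L, D = 2, 7, 768
t = np.array([10, 500])                                     # one diffusion timestep per sample
time_embed = sinusoidal_time_embedding(t, D)[:, None, :]    # (B, 1, D), like .unsqueeze(1)
seq_features = np.zeros((B, L, D))                          # placeholder sequence features
tgt = seq_features + time_embed                             # same vector at every position
```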

### Embed Pairwise Distances

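We have $x_t$ as input: the xyz coordinates of the nucleotides. From them we compute the pairwise Euclidean distances.

`a[i,j]` = Euclidean distance between the `i-th` and `j-th` nucleotide.

This gives a `T×T` matrix. The code clips the squared distance to `[2, 37²]` before taking the square root (so distances lie in `[√2, 37]`), adds a trailing axis (`T×T×1`), projects it with the bias-free linear layer `distance2pairwise` to the pairwise dimension, and adds the result to the pairwise representation.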

In [None]:
# Don't run this, it's part of final class we'll make
def embed_pair_distance(self,inputs):
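    pairwise_features,xyz=inputs  # (B, T, T, C), (B, T, 3)
    distance_matrix=xyz[:,None,:,:]-xyz[:,:,None,:]  # (B, T, T, 3) coordinate differences
    distance_matrix=(distance_matrix**2).sum(-1).clip(2,37**2).sqrt()  # (B, T, T) distances in [sqrt(2), 37]
    distance_matrix=distance_matrix[:,:,:,None]  # (B, T, T, 1)
    pairwise_features=pairwise_features+self.distance2pairwise(distance_matrix)  # Linear(1 -> C), then add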

    return pairwise_features
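The distance computation can be checked in isolation with NumPy. This sketch assumes a toy structure of three nucleotides and uses a random matrix `W` as a hypothetical stand-in for the weight of `distance2pairwise` (written as `(1, C)` for `x @ W`, whereas a torch `Linear` stores it as `(C, 1)`).

```python
import numpy as np

def clipped_distance_matrix(xyz):
    """xyz: (B, L, 3) -> (B, L, L, 1) distances, clipped to [sqrt(2), 37] like embed_pair_distance."""
    diff = xyz[:, None, :, :] - xyz[:, :, None, :]          # (B, L, L, 3)
    d = np.sqrt(np.clip((diff ** 2).sum(-1), 2, 37 ** 2))   # clip the squared distance, then sqrt
    return d[..., None]                                     # trailing axis for Linear(1, C)

xyz = np.array([[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [100.0, 0.0, 0.0]]])   # (1, 3, 3)
D = clipped_distance_matrix(xyz)
W = np.random.default_rng(0).normal(size=(1, 64))   # stand-in for distance2pairwise (bias-free)
pair_update = D @ W                                  # (1, 3, 3, 64), added to pairwise_features
```

The diagonal (distance 0) is lifted to √2 and the far pair (distance 100) is capped at 37, so the linear layer only ever sees values in a bounded range.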

## Structure Module

The diagram is self-explanatory:

1. Take the diffusion timestep `t` and create a time embedding.
2. Take the sequence features (SF) from the backbone and apply the `adaptor` (a linear layer followed by LayerNorm): 256 -> 768.
3. `x_t` holds the xyz coordinates at timestep `t`. Apply a linear layer (`xyz_embedder`): 3 -> 768.
4. Sum the three and normalize to get a single `tgt` tensor that carries the sequence features, the timestep and the xyz coordinates; a small residual MLP (`time_mlp`) with another LayerNorm follows.
5. Take the pairwise features (PF) and embed the xyz distances into them.
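Steps 2 to 4 can be sketched with NumPy to follow the shapes, assuming random stand-in weights for the `adaptor` and `xyz_embedder` linear layers, a LayerNorm without learnable scale/shift (what `nn.LayerNorm` computes at initialization), and a random tensor in place of the time embedding. The residual `time_mlp` step is omitted for brevity.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the last axis, without learnable scale/shift
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
B, L, SF_DIM, D = 2, 10, 256, 768
seq_features = rng.normal(size=(1, L, SF_DIM))   # backbone output for one sequence
xyz_t = rng.normal(size=(B, L, 3))               # B noisy structures of the same sequence
time_embed = rng.normal(size=(B, 1, D))          # stand-in for the sinusoidal time embedding

W_adapt = rng.normal(size=(SF_DIM, D)) / np.sqrt(SF_DIM)   # stand-in for adaptor's Linear (256 -> 768)
W_xyz = rng.normal(size=(3, D)) / np.sqrt(3)               # stand-in for xyz_embedder (3 -> 768)

sf = layer_norm(seq_features @ W_adapt)           # adaptor: Linear + LayerNorm
sf = np.repeat(sf, B, axis=0)                     # one copy per noisy structure
tgt = layer_norm(sf + xyz_t @ W_xyz + time_embed) # sum the three sources, then xyz_norm
```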

`tgt` and PF are the inputs to the Structure Module. It is a much simpler block, similar to the Transformer block in RibonanzaNet above.

It runs self-attention over `tgt` (which already contains the sequence features) and uses the pairwise features, projected to one value per head, as an attention bias. Then it applies a feedforward network (two linear layers with a GELU).

The output features are enriched with both the sequence features and the pairwise features.
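Here is a NumPy sketch of self-attention with a pairwise bias, assuming the bias is added to the scaled logits before softmax (the usual pair-biased attention). The notebook's `MultiHeadAttention` class is not shown in this section and may differ in details such as dropout and masking. The weight matrices are random stand-ins; `W_pair` plays the role of `pairwise2heads`.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def biased_self_attention(tgt, pair, Wq, Wk, Wv, W_pair, n_heads):
    """Multi-head self-attention over tgt (B, L, D) with a per-head bias from pair (B, L, L, P)."""
    B, L, D = tgt.shape
    dh = D // n_heads
    split = lambda x: x.reshape(B, L, n_heads, dh).transpose(0, 2, 1, 3)   # (B, H, L, dh)
    q, k, v = split(tgt @ Wq), split(tgt @ Wk), split(tgt @ Wv)
    bias = (pair @ W_pair).transpose(0, 3, 1, 2)       # (B, L, L, H) -> (B, H, L, L)
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh) + bias
    attn = softmax(logits, axis=-1)                     # (B, H, L, L)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, D)
    return out, attn

rng = np.random.default_rng(0)
B, L, D, P, H = 2, 6, 16, 8, 4
tgt = rng.normal(size=(B, L, D))
pair = rng.normal(size=(B, L, L, P))
Wq, Wk, Wv = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
W_pair = rng.normal(size=(P, H))                        # stand-in for pairwise2heads (P -> H, no bias)
out, attn = biased_self_attention(tgt, pair, Wq, Wk, Wv, W_pair, H)
```

A strong pairwise feature between positions i and j directly raises how much i attends to j, which is how the pair representation (and the embedded distances) steers the structure module.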

In [None]:
class SimpleStructureModule(nn.Module):

    def __init__(self, d_model, nhead, 
                 dim_feedforward, pairwise_dimension, dropout=0.1,
                 ):
        super(SimpleStructureModule, self).__init__()
        #self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)
        #self.cross_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.pairwise2heads=nn.Linear(pairwise_dimension,nhead,bias=False)
        self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        #self.distance2heads=nn.Linear(1,nhead,bias=False)
        #self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        self.activation = nn.GELU()

        
    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward

    def forward(self, input):
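        # src (sequence features) and pred_t (the caller passes xyz here) are unpacked but not used in this block
        tgt , src,  pairwise_features, pred_t, src_mask = input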
        
        #src = src*src_mask.float().unsqueeze(-1)

        # MHA + Pairwise mask
        pairwise_bias=self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0,3,1,2)
        #print(pairwise_bias.shape,distance_bias.shape)
        #pairwise_bias=pairwise_bias+distance_bias
        res=tgt
        tgt,attention_weights = self.self_attn(tgt, tgt, tgt, mask=pairwise_bias, src_mask=src_mask)
        tgt = res + self.dropout1(tgt)
        tgt = self.norm1(tgt)

        # print(tgt.shape,src.shape)
        # exit()

        # FeedForward network
        res=tgt
        tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = res + self.dropout2(tgt)
        tgt = self.norm2(tgt)
        return tgt

## Full Model Understanding

The task of the model is to predict the noise in the xyz coordinates.

During training, we run the forward diffusion process.

During inference, we run the reverse diffusion process.


**Training**

1. Take the xyz coordinates of the nucleotides, $x_0$.
2. Sample a random timestep `t` from `1` to `T`.
3. Sample noise $\epsilon$ and run the forward diffusion process to get $x_t$.
4. Run the model to predict the noise, $\hat{\epsilon}$.
5. Compute the MSE loss between $\epsilon$ and $\hat{\epsilon}$.
6. An auxiliary loss: a small head (`distogram_predictor`, a LayerNorm followed by a linear layer with 40 outputs per nucleotide pair) is applied to the pairwise representation to produce a "distogram", and a loss is computed between it and the true distance matrix. The objective is to make the pairwise representation reflect the distance matrix as closely as possible. I'm not sure about the intuition and have asked the authors [here](https://www.kaggle.com/code/shujun717/ribonanzanet2-ddpm-training/comments#3204706).
7. The total loss is a weighted sum of the two: denoising loss + 0.2 * distogram loss.
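Steps 2 to 7 of the training recipe can be sketched as one loss computation in NumPy. This assumes the same linear schedule as before (`T = 1000`, betas from `1e-4` to `0.02`), 0-based timesteps, and treats `predict_noise` and the already-computed `distogram_loss` as hypothetical stand-ins; the exact distogram loss is defined in the training notebook.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)             # assumed linear schedule
alpha_bars = np.cumprod(1.0 - betas)

def training_loss(x0, predict_noise, distogram_loss, rng, distogram_weight=0.2):
    """Loss for a batch of structures x0 of shape (B, L, 3); model parts are hypothetical stand-ins."""
    B = x0.shape[0]
    t = rng.integers(0, T, size=B)                           # step 2: one random timestep per sample
    eps = rng.normal(size=x0.shape)                          # step 3: sample the noise ...
    ab = alpha_bars[t][:, None, None]                        # broadcast over (L, 3)
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps         # ... and jump straight to x_t
    eps_hat = predict_noise(x_t, t)                          # step 4
    denoise_loss = np.mean((eps - eps_hat) ** 2)             # step 5: MSE
    return denoise_loss + distogram_weight * distogram_loss  # step 7: weighted sum

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 20, 3))
loss = training_loss(x0, lambda x_t, t: np.zeros_like(x_t), distogram_loss=0.5, rng=rng)
```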

The full training notebook is present [here](https://www.kaggle.com/code/shujun717/ribonanzanet2-ddpm-training/notebook).

**Inference**

1. Start with a random noise $x_T$.
2. At each step $t$ (from $T$ down to $1$):
    - Use the trained model to predict the noise $\epsilon$ given $x_t$ and $t$.
    - Compute $x_{t-1}$ using the reverse diffusion process, removing the predicted noise from $x_t$.
3. Repeat this process iteratively until reaching $x_0$.
4. The final $x_0$ is the generated data (e.g., predicted coordinates).

The full inference notebook is present [here](https://www.kaggle.com/code/shujun717/ribonanzanet2-ddpm-inference/notebook).

Below is the code for the full model. `finetuned_RibonanzaNet` inherits from the `RibonanzaNet` backbone class (optionally loading pretrained weights) and adds the new parts: the Structure Module, Time Embedder, pairwise distance embedding, adaptor, distogram predictor, and the noising & denoising methods.

In [None]:
class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, rnet_config, config, pretrained=False):
        rnet_config.dropout=0.1
        rnet_config.use_grad_checkpoint=True
        super(finetuned_RibonanzaNet, self).__init__(rnet_config)
        if pretrained:
            self.load_state_dict(torch.load(config.pretrained_weight_path,map_location='cpu'))
        # self.ct_predictor=nn.Sequential(nn.Linear(64,256),
        #                                 nn.ReLU(),
        #                                 nn.Linear(256,64),
        #                                 nn.ReLU(),
        #                                 nn.Linear(64,1)) 
        self.dropout=nn.Dropout(0.0)

        decoder_dim=config.decoder_dim
        self.structure_module=[
            SimpleStructureModule(
                d_model=decoder_dim, 
                nhead=config.decoder_nhead, 
                dim_feedforward=decoder_dim*4, 
                pairwise_dimension=rnet_config.pairwise_dimension, 
                dropout=0.0) 
            for i in range(config.decoder_num_layers)
        ]
        self.structure_module=nn.ModuleList(self.structure_module)

        self.xyz_embedder=nn.Linear(3,decoder_dim)
        self.xyz_norm=nn.LayerNorm(decoder_dim)
        self.xyz_predictor=nn.Linear(decoder_dim,3)
        
        self.adaptor=nn.Sequential(nn.Linear(rnet_config.ninp,decoder_dim),nn.LayerNorm(decoder_dim))

        self.distogram_predictor=nn.Sequential(nn.LayerNorm(rnet_config.pairwise_dimension),
                                                nn.Linear(rnet_config.pairwise_dimension,40))

        self.time_embedder=SinusoidalPosEmb(decoder_dim)

        self.time_mlp=nn.Sequential(nn.Linear(decoder_dim,decoder_dim),
                                    nn.ReLU(),  
                                    nn.Linear(decoder_dim,decoder_dim))
        self.time_norm=nn.LayerNorm(decoder_dim)

        self.distance2pairwise=nn.Linear(1,rnet_config.pairwise_dimension,bias=False)

        self.pair_mlp=nn.Sequential(nn.Linear(rnet_config.pairwise_dimension,rnet_config.pairwise_dimension),
                                    nn.ReLU(),
                                    nn.Linear(rnet_config.pairwise_dimension,rnet_config.pairwise_dimension))


        #hyperparameters for diffusion
        self.n_times = config.n_times

        #self.model = model
        
        # define linear variance schedule(betas)
        beta_1, beta_T = config.beta_min, config.beta_max
        betas = torch.linspace(start=beta_1, end=beta_T, steps=config.n_times)#.to(device) # follows DDPM paper
        self.sqrt_betas = torch.sqrt(betas)
                                     
        # define alpha for forward diffusion kernel
        self.alphas = 1 - betas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1-alpha_bars)
        self.sqrt_alpha_bars = torch.sqrt(alpha_bars)

        self.data_std=config.data_std


    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward
    
    def embed_pair_distance(self,inputs):
        pairwise_features,xyz=inputs
        distance_matrix=xyz[:,None,:,:]-xyz[:,:,None,:]
        distance_matrix=(distance_matrix**2).sum(-1).clip(2,37**2).sqrt()
        distance_matrix=distance_matrix[:,:,:,None]
        pairwise_features=pairwise_features+self.distance2pairwise(distance_matrix)

        return pairwise_features

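    def forward(self,src,xyz,t):
        # Training pass. src: one tokenized sequence (1, L); xyz: noisy coordinates x_t
        # for a batch of B structures (B, L, 3); t: their diffusion steps (B,).

        #with torch.no_grad():
        # RibonanzaNet backbone embeddings (all-ones mask: no padded positions)
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))

        # distogram prediction from the backbone pair features (returned alongside xyz)
        distogram=self.distogram_predictor(pairwise_features)

        sequence_features=self.adaptor(sequence_features)

        # all B structures share the same sequence, so broadcast the conditioning features
        decoder_batch_size=xyz.shape[0]
        sequence_features=sequence_features.repeat(decoder_batch_size,1,1)
        pairwise_features=pairwise_features.expand(decoder_batch_size,-1,-1,-1)

        # add the distance map of the noisy coordinates to the pair features;
        # gradient checkpointing recomputes activations in backward to save memory
        pairwise_features= checkpoint.checkpoint(self.custom(self.embed_pair_distance), [pairwise_features,xyz],use_reentrant=False)

        # token input: sequence features + coordinate embedding + time-step embedding
        time_embed=self.time_embedder(t).unsqueeze(1)
        tgt=self.xyz_norm(sequence_features+self.xyz_embedder(xyz)+time_embed)
        tgt=self.time_norm(tgt+self.time_mlp(tgt))

        for layer in self.structure_module:
            #tgt=layer([tgt, sequence_features,pairwise_features,xyz,None])
            tgt=checkpoint.checkpoint(self.custom(layer),
            [tgt, sequence_features,pairwise_features,xyz,None],
            use_reentrant=False)

        # per-residue 3D output; denoise_at_t() reads the same head as the predicted noise
        xyz=self.xyz_predictor(tgt).squeeze(0)

        return xyz, distogram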
    

    def denoise(self,sequence_features,pairwise_features,xyz,t):
        # Inference counterpart of forward(): the backbone features are computed once in
        # sample() and passed in. xyz is the current x_t, shape (N, L, 3), and t has
        # shape (N,), e.g. tensor([4, 4, 4, 4, 4]) when denoising step 4 for N = 5 samples.
        decoder_batch_size=xyz.shape[0]
        sequence_features=sequence_features.expand(decoder_batch_size,-1,-1)
        pairwise_features=pairwise_features.expand(decoder_batch_size,-1,-1,-1)

        pairwise_features=self.embed_pair_distance([pairwise_features,xyz])

        sequence_features=self.adaptor(sequence_features) # B,T,768
        time_embed=self.time_embedder(t).unsqueeze(1)
        tgt=self.xyz_norm(sequence_features+self.xyz_embedder(xyz)+time_embed)
        tgt=self.time_norm(tgt+self.time_mlp(tgt))

        for layer in self.structure_module:
            tgt=layer([tgt, sequence_features,pairwise_features,xyz,None])

        # predicted noise for x_t, (N, L, 3); squeeze(0) only changes the shape when N = 1
        xyz=self.xyz_predictor(tgt).squeeze(0)
        return xyz


    def extract(self, a, t, x_shape):
        """
            Gather a[t] for each sample in the batch and reshape the result to
            (B, 1, ..., 1) so it broadcasts against a tensor of shape x_shape.
            From lucidrains' implementation:
                https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L376
        """
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    
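    # [0, 1] <-> [-1, 1] rescaling used by image DDPMs (the DDPM paper scales data to
    # [-1, 1] so the reverse process sees consistently scaled inputs). Not used by
    # make_noisy() or sample(), which normalize coordinates with data_std instead.
    def scale_to_minus_one_to_one(self, x):
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        return (x + 1) * 0.5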
    
    def make_noisy(self, x_zeros, t):
        # x_zeros: raw coordinates (B, L, 3); t: diffusion steps (B,).
        # Center each structure on its mean position (nanmean ignores NaN entries) and
        # divide by self.data_std (e.g. 35).
        x_zeros = x_zeros - torch.nanmean(x_zeros,1,keepdim=True)
        x_zeros = x_zeros/self.data_std
        # random rotation augmentation: a structure's orientation in space is arbitrary
        x_zeros = random_rotation_point_cloud_torch_batch(x_zeros)

        # perturb x_0 into x_t with the closed-form forward kernel q(x_t | x_0)
        epsilon = torch.randn_like(x_zeros)

        sqrt_alpha_bar = self.extract(self.sqrt_alpha_bars.to(x_zeros.device), t, x_zeros.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars.to(x_zeros.device), t, x_zeros.shape)

        # forward process with a fixed variance schedule:
        #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        noisy_sample = x_zeros * sqrt_alpha_bar + epsilon * sqrt_one_minus_alpha_bar

        return noisy_sample.detach(), epsilon
    
    
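    # Legacy training step that appears to come from the image-DDPM code this class was
    # adapted from (note the 4-D image shapes); kept for reference only. The active
    # forward() is defined above.
    # def forward(self, x_zeros):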
    #     x_zeros = self.scale_to_minus_one_to_one(x_zeros)
        
    #     B, _, _, _ = x_zeros.shape
        
    #     # (1) randomly choose diffusion time-step
    #     t = torch.randint(low=0, high=self.n_times, size=(B,)).long().to(x_zeros.device)
        
    #     # (2) forward diffusion process: perturb x_zeros with fixed variance schedule
    #     perturbed_images, epsilon = self.make_noisy(x_zeros, t)
        
    #     # (3) predict epsilon(noise) given perturbed data at diffusion-timestep t.
    #     pred_epsilon = self.model(perturbed_images, t)
        
    #     return perturbed_images, epsilon, pred_epsilon
    
    
    def denoise_at_t(self, x_t, sequence_features, pairwise_features, timestep, t):
        # One DDPM reverse step x_t -> x_{t-1}. timestep is the (N,) step tensor and t is
        # the same step as a Python int. Steps are 0-indexed (t = T-1, ..., 0), so fresh
        # noise is added on every step except the final one, t = 0.
        B, _, _ = x_t.shape
        if t > 0:
            z = torch.randn_like(x_t)
        else:
            z = torch.zeros_like(x_t)

        # at inference, the network output is used as the predicted noise epsilon
        epsilon_pred = self.denoise(sequence_features, pairwise_features, x_t, timestep)

        alpha = self.extract(self.alphas.to(x_t.device), timestep, x_t.shape)
        sqrt_alpha = self.extract(self.sqrt_alphas.to(x_t.device), timestep, x_t.shape)
        sqrt_one_minus_alpha_bar = self.extract(self.sqrt_one_minus_alpha_bars.to(x_t.device), timestep, x_t.shape)
        sqrt_beta = self.extract(self.sqrt_betas.to(x_t.device), timestep, x_t.shape)

        # x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) + sqrt(beta_t) * z,
        # where beta_t = 1 - alpha_t
        x_t_minus_1 = 1 / sqrt_alpha * (x_t - (1-alpha)/sqrt_one_minus_alpha_bar*epsilon_pred) + sqrt_beta*z

        return x_t_minus_1#.clamp(-1., 1)
                
    def sample(self, src, N):
        """Generate N candidate 3D structures for one sequence.

        Args:
            src (torch.LongTensor): the sequence from the data (e.g. 'augc') as token ids, shape (1, L)
            N (int): number of structures to sample

        Returns:
            x_0 (torch.Tensor): sampled coordinates in the original scale, shape (N, L, 3)
            distogram (torch.Tensor): (L, L) map obtained by weighting distogram bins 2..39
                by their bin index and summing over bins
        """
        # start from pure Gaussian noise x_T, shape (N, L, 3)
        x_t = torch.randn((N, src.shape[1], 3)).to(src.device)

        # conditioning: run the RibonanzaNet backbone once; its features are reused at every step
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))
        # sequence_features=sequence_features.expand(N,-1,-1)
        # pairwise_features=pairwise_features.expand(N,-1,-1,-1)
        distogram=self.distogram_predictor(pairwise_features).squeeze() # (L, L, 40) for a single sequence
        # weight bins 2..39 by their bin index and sum over bins; this is an expected
        # distance if the 40 outputs are per-bin probabilities
        distogram=distogram[:,:,2:40]*torch.arange(2,40,device=distogram.device).float()
        distogram=distogram.sum(-1)  # (L, L)

        # iteratively denoise from x_T to x_0: t = T-1, ..., 0
        for t in range(self.n_times-1, -1, -1):
            timestep = torch.tensor([t]).repeat_interleave(N, dim=0).long().to(src.device)
            x_t = self.denoise_at_t(x_t, sequence_features, pairwise_features, timestep, t)

        # undo the data_std normalization applied during training (see make_noisy)
        x_0 = x_t * self.data_std
        return x_0, distogram
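

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# __init__ builds a linear beta schedule, and make_noisy() applies the
# closed-form forward process
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I).
# Because denoise_at_t() treats the network output as the predicted eps, a
# DDPM-style training step would regress that output onto eps. The helpers below
# show these pieces in isolation. They assume centered, scaled coordinates of
# shape (B, L, 3) and an MSE noise-prediction loss as in the DDPM paper. The
# notebook's own training loop is not in this cell and may differ. All demo_*
# names are hypothetical helpers for this explanation.
# ---------------------------------------------------------------------------
import torch


def demo_linear_schedule(n_times, beta_min, beta_max):
    """Return (betas, alpha_bars) for a linear variance schedule."""
    betas = torch.linspace(beta_min, beta_max, n_times)
    alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t = prod_{s<=t} alpha_s
    return betas, alpha_bars


def demo_q_sample(x0, t, alpha_bars, noise=None):
    """Noise clean coordinates x0 (B, L, 3) to per-sample steps t (B,)."""
    if noise is None:
        noise = torch.randn_like(x0)
    # pick alpha_bar_t per sample and reshape to (B, 1, 1) so it broadcasts over
    # residues and xyz (the same job extract() does in the class above)
    ab = alpha_bars.to(x0.device)[t].view(-1, *([1] * (x0.dim() - 1)))
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise, noise


def demo_diffusion_loss(eps_model, x0, alpha_bars):
    """One noise-prediction loss: random t, noise x0, regress the added eps."""
    t = torch.randint(0, len(alpha_bars), (x0.shape[0],), device=x0.device)
    x_t, eps = demo_q_sample(x0, t, alpha_bars)
    return torch.nn.functional.mse_loss(eps_model(x_t, t), eps)
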

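
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# embed_pair_distance() turns coordinates (B, L, 3) into a (B, L, L, 1) distance
# map. The clip acts on the *squared* distance, so the resulting distances lie in
# [sqrt(2), 37] in the units of xyz, and the diagonal (distance 0) becomes sqrt(2).
# demo_pairwise_distance_map is a hypothetical helper that reproduces only this
# step so it can be checked against torch.cdist. The learned projection
# distance2pairwise is left out.
# ---------------------------------------------------------------------------
import torch


def demo_pairwise_distance_map(xyz, min_sq=2.0, max_sq=37.0 ** 2):
    diff = xyz[:, None, :, :] - xyz[:, :, None, :]  # (B, L, L, 3)
    sq = (diff ** 2).sum(-1).clip(min_sq, max_sq)   # clipped squared distances
    return sq.sqrt()[..., None]                     # (B, L, L, 1)
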

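
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# sample() and denoise_at_t() run DDPM ancestral sampling:
#     x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
#               + sqrt(beta_t) * z,   z ~ N(0, I), with z = 0 on the final step.
# The loop below runs the same recursion with a stand-in noise predictor.
# demo_ddpm_sample and eps_model are hypothetical names. In the class above,
# eps_model plays the role of
# self.denoise(sequence_features, pairwise_features, x_t, timestep).
# The final rescaling by data_std is omitted.
# ---------------------------------------------------------------------------
import torch


def demo_ddpm_sample(eps_model, shape, betas, generator=None):
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, generator=generator)     # x_T ~ N(0, I)
    for t in range(len(betas) - 1, -1, -1):         # t = T-1, ..., 0 (0-indexed)
        timestep = torch.full((shape[0],), t, dtype=torch.long)
        eps_hat = eps_model(x, timestep)
        mean = (x - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps_hat) / alphas[t].sqrt()
        # fresh noise on every step except the last, so x_0 is the predicted mean
        z = torch.randn(shape, generator=generator) if t > 0 else torch.zeros(shape)
        x = mean + betas[t].sqrt() * z
    return x
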

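
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# sample() collapses the (L, L, 40) distogram into an (L, L) map by weighting
# bins 2..39 by their bin index and summing. That is an expected distance when
# the 40 outputs are per-bin probabilities. Whether distogram_predictor already
# outputs probabilities is defined elsewhere in the notebook, so this sketch
# assumes raw logits and applies a softmax first. demo_expected_distance is a
# hypothetical helper. The bin values are created on the input's device rather
# than hard-coded to CUDA.
# ---------------------------------------------------------------------------
import torch


def demo_expected_distance(logits, first_bin=2, n_bins=40):
    probs = logits.softmax(dim=-1)  # (..., n_bins) per-bin probabilities
    bins = torch.arange(first_bin, n_bins, dtype=probs.dtype, device=probs.device)
    # as in sample(), bins below first_bin contribute nothing to the sum
    return (probs[..., first_bin:n_bins] * bins).sum(-1)
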

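
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# make_noisy() randomly rotates each centered structure with
# random_rotation_point_cloud_torch_batch, which is defined elsewhere in the
# notebook. One common way to draw uniform random rotations is the QR
# decomposition of a Gaussian matrix with a sign correction, shown below. It is
# not necessarily how the notebook's helper works.
# demo_random_rotation_matrices and demo_random_rotate are hypothetical.
# ---------------------------------------------------------------------------
import torch


def demo_random_rotation_matrices(batch_size, generator=None):
    a = torch.randn(batch_size, 3, 3, generator=generator)
    q, r = torch.linalg.qr(a)
    # multiply column j of Q by sign(R_jj) so Q is uniform over orthogonal matrices
    q = q * torch.sign(torch.diagonal(r, dim1=-2, dim2=-1)).unsqueeze(-2)
    # flip the first column where det(Q) = -1 so every matrix is a proper rotation
    q[:, :, 0] = q[:, :, 0] * torch.linalg.det(q).sign()[:, None]
    return q


def demo_random_rotate(x, generator=None):
    """Rotate each centered point cloud in x (B, L, 3) by its own random rotation."""
    rot = demo_random_rotation_matrices(x.shape[0], generator=generator).to(x)
    return x @ rot.transpose(1, 2)
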
class SimpleStructureModule(nn.Module):
    """Post-norm Transformer layer used by the diffusion decoder: self-attention
    over residues, biased per head by the pair features, then a feed-forward block."""

    def __init__(self, d_model, nhead, dim_feedforward, pairwise_dimension, dropout=0.1):
        super(SimpleStructureModule, self).__init__()
        #self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)
        #self.cross_attn = MultiHeadAttention(d_model, nhead, d_model//nhead, d_model//nhead, dropout=dropout)

        # position-wise feed-forward block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # maps each pair feature vector to one attention bias per head
        self.pairwise2heads=nn.Linear(pairwise_dimension,nhead,bias=False)
        self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        #self.distance2heads=nn.Linear(1,nhead,bias=False)
        #self.pairwise_norm=nn.LayerNorm(pairwise_dimension)

        self.activation = nn.GELU()

        
    def custom(self, module):
        def custom_forward(*inputs):
            inputs = module(*inputs)
            return inputs
        return custom_forward

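    def forward(self, input):
        # input = [tgt, src, pairwise_features, pred_t, src_mask]
        #   tgt: residue tokens (B, L, d_model); pairwise_features: (B, L, L, pairwise_dimension)
        # The callers above pass the sequence features as src and the noisy coordinates
        # as pred_t; neither is used here because the cross-attention branch is commented out.
        tgt, src, pairwise_features, pred_t, src_mask = input

        #src = src*src_mask.float().unsqueeze(-1)

        # per-head attention bias from the pair features, (B, nhead, L, L)
        pairwise_bias=self.pairwise2heads(self.pairwise_norm(pairwise_features)).permute(0,3,1,2)

        #pairwise_bias=pairwise_bias+distance_bias

        # self-attention with the pair bias, then residual + LayerNorm
        res=tgt
        tgt,attention_weights = self.self_attn(tgt, tgt, tgt, mask=pairwise_bias, src_mask=src_mask)
        tgt = res + self.dropout1(tgt)
        tgt = self.norm1(tgt)

        # feed-forward, then residual + LayerNorm
        res=tgt
        tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = res + self.dropout2(tgt)
        tgt = self.norm2(tgt)
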
        return tgt
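

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the RibonanzaNet 2.0 code above).
# SimpleStructureModule projects the pair features to one value per head
# (pairwise2heads) and passes the (B, nhead, L, L) result to the notebook's
# MultiHeadAttention as `mask`. The sketch below assumes that tensor acts as an
# additive bias on the attention logits, which is the usual way pair features
# steer attention:
#     attn = softmax(Q K^T / sqrt(d_head) + bias)
# It omits dropout, LayerNorm, src_mask and any output projection, and it is not
# the notebook's MultiHeadAttention class. demo_pair_biased_attention and its
# weight arguments are hypothetical.
# ---------------------------------------------------------------------------
import math
import torch


def demo_pair_biased_attention(x, pair, nhead, w_q, w_k, w_v, w_bias):
    """x: (B, L, D); pair: (B, L, L, P); w_q/w_k/w_v: (D, D); w_bias: (P, nhead)."""
    B, L, D = x.shape
    d_head = D // nhead

    def split_heads(t):  # (B, L, D) -> (B, nhead, L, d_head)
        return t.view(B, L, nhead, d_head).transpose(1, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    bias = (pair @ w_bias).permute(0, 3, 1, 2)         # (B, nhead, L, L)
    logits = q @ k.transpose(-1, -2) / math.sqrt(d_head) + bias
    attn = logits.softmax(dim=-1)                      # each row sums to 1
    out = (attn @ v).transpose(1, 2).reshape(B, L, D)  # merge heads back
    return out, attn
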
