# Forward Pass Implementation Analysis

## Objective
Check whether the forward pass implemented in PyNNalign behaves as expected, by reproducing it step by step on mock data.

## Key Checks
- **Input dimensions**: Confirm that the input shape `[B, W, D]` is compatible with the architecture
- **Mathematical operations**: Confirm that each layer computes what it should
- **Output**: Confirm that the result has one value per peptide, shape `[B]`

Imports

In [202]:
import torch
from torch import nn
import torch.nn.functional as F

Mock input tensor and mask

In [203]:
B = 2
W = 4
window_size = 9
encoding_dim = 20
D = window_size * encoding_dim

# input: [B, W, D]
x_tensor = torch.randn(B, W, D)

x_mask = torch.tensor([
    [True,  True,  True,  False],  
    [True,  True,  False, False],  
], dtype=torch.bool)

print(x_tensor.shape)
print(x_mask.shape)

torch.Size([2, 4, 180])
torch.Size([2, 4])
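The mask above is written out by hand. As a minimal sketch, the same mask could be derived from the number of valid windows per peptide; the name `n_windows` and its values are assumptions made for illustration and do not come from PyNNalign.

```python
import torch

W = 4  # maximum number of windows, as in the cell above

# Hypothetical: number of valid (non-padded) windows for each peptide
n_windows = torch.tensor([3, 2])

# Position j is valid for peptide i when j < n_windows[i]
mask = torch.arange(W).unsqueeze(0) < n_windows.unsqueeze(1)

print(mask)
print(mask.shape)
```

Broadcasting a `[1, W]` index row against a `[B, 1]` length column gives a `[B, W]` boolean mask with `True` for real windows and `False` for padding.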


Forward pass layers

In [204]:
in_layer = nn.Linear(window_size * encoding_dim, 66)
out_layer = nn.Linear(66, 1)
activation = nn.ReLU()

First affine transformation

In [205]:
z = in_layer(x_tensor)
print(z.shape)

torch.Size([2, 4, 66])
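The shape check alone does not confirm the calculation. One way to verify the affine transformation is to recompute it by hand from the layer's weight and bias. This is a sketch on freshly created tensors with the same sizes as above; the seed and the tolerance are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

B, W, D, H = 2, 4, 180, 66  # same sizes as in the cells above
x = torch.randn(B, W, D)
layer = nn.Linear(D, H)

# nn.Linear acts on the last dimension: y = x @ weight.T + bias
z_layer = layer(x)
z_manual = x @ layer.weight.T + layer.bias

affine_ok = torch.allclose(z_layer, z_manual, atol=1e-5)
print(z_layer.shape)
print(affine_ok)
```

Because `nn.Linear` only touches the last dimension, each of the `W` windows is transformed independently with the same weights.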


Activation

In [206]:
z = activation(z)
print(z.shape)

torch.Size([2, 4, 66])
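The activation can be checked in the same spirit. The sketch below uses made-up values to confirm that `nn.ReLU` zeroes negative entries, leaves the others unchanged, and preserves the shape.

```python
import torch
from torch import nn

# Made-up values, chosen to include negative, zero and positive entries
z = torch.tensor([[-1.5, 0.0, 2.0],
                  [0.3, -0.2, 4.0]])

out = nn.ReLU()(z)
expected = torch.clamp(z, min=0)  # ReLU(x) = max(x, 0)

print(out)
print(torch.equal(out, expected))
```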


Second affine transformation

In [207]:
z = out_layer(z).squeeze(-1)
print(z.shape)
print(z)

torch.Size([2, 4])
tensor([[ 0.3562, -0.0490, -0.2025,  0.1234],
        [ 0.0281,  0.3740,  0.3633, -0.1988]], grad_fn=<SqueezeBackward1>)


At this point `z` is a 2D tensor of shape `[B, W]`, with one logit for each window of each peptide. The next step is to find, for each peptide, the index of the window with the highest score. The scores (sigmoid of the logits) are shown below.

In [208]:
s = torch.sigmoid(z)
print(s.shape)
print(s)

torch.Size([2, 4])
tensor([[0.5881, 0.4878, 0.4495, 0.5308],
        [0.5070, 0.5924, 0.5898, 0.4505]], grad_fn=<SigmoidBackward0>)
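Since the sigmoid is strictly increasing, the window with the highest score is also the window with the highest logit. The sketch below checks this on the logits printed above, copied by hand (so without the original `grad_fn`).

```python
import torch

# Logits copied from the output of the second affine transformation
z = torch.tensor([[0.3562, -0.0490, -0.2025, 0.1234],
                  [0.0281, 0.3740, 0.3633, -0.1988]])

s = torch.sigmoid(z)

# A monotonic transform does not change the position of the maximum
idx_from_scores = s.argmax(dim=1)
idx_from_logits = z.argmax(dim=1)

print(idx_from_scores)
print(idx_from_logits)
```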


Apply the mask to set the scores of padded windows to `-inf`, so that they can never be selected as the maximum

In [209]:
s = s.masked_fill(~x_mask, float("-inf"))
print(s)

tensor([[0.5881, 0.4878, 0.4495,   -inf],
        [0.5070, 0.5924,   -inf,   -inf]], grad_fn=<MaskedFillBackward0>)
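To illustrate why the fill value is `-inf` rather than 0, the sketch below uses made-up logits in which every valid window is negative and the padded window is large. Filling with 0 would let the padded position win the argmax, while `-inf` cannot be selected as long as at least one window is valid.

```python
import torch

# Made-up logits: the last position is padding but has the largest value
logits = torch.tensor([[-1.0, -0.5, -2.0, 3.0]])
mask = torch.tensor([[True, True, True, False]])

filled_zero = logits.masked_fill(~mask, 0.0)
filled_ninf = logits.masked_fill(~mask, float("-inf"))

idx_zero = filled_zero.argmax(dim=1)  # picks the padded window
idx_ninf = filled_ninf.argmax(dim=1)  # picks the best valid window

print(idx_zero)
print(idx_ninf)
```

With sigmoid scores, which are positive, a fill value of 0 would usually work as well, but `-inf` stays correct regardless of the range of the values being masked.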


Find the index of the highest-scoring window for each peptide (argmax along `dim=1`)

In [210]:
max_idx = s.argmax(dim=1).unsqueeze(1)
print(max_idx)
print(max_idx.shape)

tensor([[0],
        [1]])
torch.Size([2, 1])


Use `torch.gather` to select the logit of the highest-scoring window for each peptide

In [211]:
z = torch.gather(z, 1, max_idx).squeeze(-1)

print(z)
print(z.shape)

tensor([0.3562, 0.3740], grad_fn=<SqueezeBackward1>)
torch.Size([2])
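The selection steps above can be collected into one function and checked end to end. This is a sketch, not PyNNalign's actual code: the helper name `select_best_window` is hypothetical, and the logits are copied by hand from the output above. It also checks that the gradient reaches only the selected window of each peptide.

```python
import torch

def select_best_window(logits, mask):
    """Hypothetical helper: logit and index of the best valid window per row."""
    scores = torch.sigmoid(logits).masked_fill(~mask, float("-inf"))
    max_idx = scores.argmax(dim=1, keepdim=True)          # [B, 1]
    best = torch.gather(logits, 1, max_idx).squeeze(-1)   # [B]
    return best, max_idx.squeeze(-1)

logits = torch.tensor([[0.3562, -0.0490, -0.2025, 0.1234],
                       [0.0281, 0.3740, 0.3633, -0.1988]], requires_grad=True)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

best, idx = select_best_window(logits, mask)
best.sum().backward()

print(best)
print(idx)
print(logits.grad)  # 1 at the selected window of each row, 0 elsewhere
```

Here `argmax(dim=1, keepdim=True)` replaces the `argmax(...).unsqueeze(1)` used above; both give an index tensor of shape `[B, 1]`, which is what `torch.gather` expects along `dim=1`.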
