In [6]:
import torch
import torch.nn as nn

# Sizes for the demo: k memory slots, each of width d (768 = BART-base hidden size).
k, d = 16, 768

# Learnable memory bank, one row per slot -> shape [16, 768].
M = nn.Parameter(torch.randn(k, d))

# Stand-in for an encoder hidden state -> shape [768].
h_z = torch.randn(d)

# Scoring layer: turns the hidden state into one logit per memory slot (768 -> 16).
W = nn.Linear(d, k)
logits = W(h_z)  # shape [16]

# Normalize the slot scores into attention weights that sum to 1 -> shape [16].
alpha = logits.softmax(dim=0)

# Read from memory: scale every row of M by its weight, then add the rows up.
#   alpha as a column [16, 1] broadcasts against M [16, 768]
z = (alpha.unsqueeze(1) * M).sum(dim=0)  # shape [768]


In [8]:
# Shapes involved in the weighted sum: alpha reshaped to a column, the memory matrix M, and the result z.
alpha.unsqueeze(1).shape, M.shape, z.shape

(torch.Size([16, 1]), torch.Size([16, 768]), torch.Size([768]))

In [None]:
# Shape of the memory read-out z on its own.
z.shape

torch.Size([768])

### Here is a link to learn more
[Learning notes (shared ChatGPT conversation)](https://chatgpt.com/share/67f3649b-fef0-8013-9c1c-2e3727affd56) 

In [3]:
####

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pin the global RNG state (all tensors below are written out by hand).
torch.manual_seed(0)

# Toy sizes: k memory slots, hidden width d.
k, d = 3, 4

# Hidden vector, kept 2-D as a single row -> shape (1, 4).
h_z = torch.tensor([[0.2, 0.4, 0.6, 0.8]], dtype=torch.float32)

# Memory bank, one row per slot -> shape (3, 4).
M = torch.tensor(
    [
        [0.5, 0.1, 0.0, 0.3],
        [0.2, 0.4, 0.1, 0.0],
        [0.3, 0.3, 0.3, 0.3],
    ],
    dtype=torch.float32,
)

# Hand-written linear layer: weight (3, 4) and bias (3,).
W = torch.tensor(
    [
        [0.1, 0.0, 0.2, 0.3],
        [-0.2, 0.5, -0.1, 0.4],
        [0.3, 0.2, 0.0, 0.1],
    ],
    dtype=torch.float32,
)
b = torch.tensor([0.1, 0.0, -0.2], dtype=torch.float32)

# Slot scores: (3, 4) @ (4, 1) -> (3, 1); drop the trailing axis, then add the bias -> (3,).
logits = (W @ h_z.T).squeeze(1) + b

# Attention weights over the slots -> (3,).
alpha = F.softmax(logits, dim=0)

# Memory read-out: weight each row of M by alpha and add the rows -> (4,).
z = (alpha.unsqueeze(1) * M).sum(dim=0)

# Report the input and every intermediate result.
for _label, _value in (
    ("h_z:\n", h_z),
    ("logits:\n", logits),
    ("attention weights (alpha):\n", alpha),
    ("final memory vector (z):\n", z),
):
    print(_label, _value)


h_z:
 tensor([[0.2000, 0.4000, 0.6000, 0.8000]])
logits:
 tensor([0.4800, 0.4200, 0.0200])
attention weights (alpha):
 tensor([0.3886, 0.3660, 0.2453])
final memory vector (z):
 tensor([0.3411, 0.2589, 0.1102, 0.1902])


In [16]:
# Show the raw slot scores next to the attention weights obtained from them via softmax.
logits,alpha

(tensor([0.4800, 0.4200, 0.0200]), tensor([0.3886, 0.3660, 0.2453]))

In [4]:
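# torchinfo is needed for the model summary cell further down.
# To install it into this kernel's environment, uncomment:
# %pip install torchinfo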

In [18]:
import torch
import torch.nn as nn

class MemoryAttentionModel(nn.Module):
    """Soft attention read over a learnable memory matrix.

    A hidden state is projected to one logit per memory slot, the logits are
    softmax-normalized into attention weights, and the output is the
    attention-weighted sum of the memory rows.

    Args:
        k: Number of memory slots (rows of ``M``).
        d: Hidden size; the width of each memory slot and of the input.
    """

    def __init__(self, k=16, d=768):
        super().__init__()

        # Learnable memory matrix, one row per slot: [k, d]
        self.M = nn.Parameter(torch.randn(k, d))

        # Linear projection from the hidden state to logits over memory slots: d -> k
        self.W = nn.Linear(d, k)

    def forward(self, h_z):
        """Read from memory using ``h_z`` as the query.

        Args:
            h_z: Tensor whose last dimension is ``d``. Either a single vector
                of shape ``[d]`` or a batch with any leading dimensions
                ``[..., d]``.

        Returns:
            Tensor with the same leading dimensions as ``h_z`` and last
            dimension ``d``: ``[d]`` for a single vector, ``[..., d]`` for a batch.
        """
        # Step 1: logits over memory slots -> [..., k]
        logits = self.W(h_z)

        # Step 2: attention weights. Normalize over the slot axis (the last one)
        # so that a leading batch dimension is never mixed into the softmax.
        alpha = torch.softmax(logits, dim=-1)

        # Step 3: weighted sum of memory rows. [..., k] @ [k, d] -> [..., d];
        # for a 1-D alpha this is the same sum over slots, without building
        # the intermediate [k, d] product.
        z = alpha @ self.M

        return z

# Smoke test: feed one random hidden-state vector through the module.
model = MemoryAttentionModel(d=768, k=16)

# Random stand-in for an encoder output, shape [768].
h_z = torch.randn(768)

# The result is the weighted sum of memory slots, shape [768].
output = model(h_z)

# Expected to print torch.Size([768]).
print(output.shape)


torch.Size([768])


In [19]:
import torch
import torch.nn as nn
from torchinfo import summary

# NOTE(review): a class with this same name is also defined in an earlier cell of
# this notebook; this definition shadows it. Keep the two in sync or drop one.
class MemoryAttentionModel(nn.Module):
    """Soft attention read over a learnable memory matrix.

    A hidden state is projected to one logit per memory slot, the logits are
    softmax-normalized into attention weights, and the output is the
    attention-weighted sum of the memory rows.

    Args:
        k: Number of memory slots (rows of ``M``).
        d: Hidden size; the width of each memory slot and of the input.
    """

    def __init__(self, k=16, d=768):
        super().__init__()

        # Learnable memory matrix, one row per slot: [k, d]
        self.M = nn.Parameter(torch.randn(k, d))

        # Linear projection from the hidden state to logits over memory slots: d -> k
        self.W = nn.Linear(d, k)

    def forward(self, h_z):
        """Read from memory using ``h_z`` as the query.

        Args:
            h_z: Tensor whose last dimension is ``d``. Either a single vector
                of shape ``[d]`` or a batch with any leading dimensions
                ``[..., d]``.

        Returns:
            Tensor with the same leading dimensions as ``h_z`` and last
            dimension ``d``: ``[d]`` for a single vector, ``[..., d]`` for a batch.
        """
        # Step 1: logits over memory slots -> [..., k]
        logits = self.W(h_z)

        # Step 2: attention weights. Normalize over the slot axis (the last one)
        # so that a leading batch dimension is never mixed into the softmax.
        alpha = torch.softmax(logits, dim=-1)

        # Step 3: weighted sum of memory rows. [..., k] @ [k, d] -> [..., d];
        # for a 1-D alpha this is the same sum over slots, without building
        # the intermediate [k, d] product.
        z = alpha @ self.M

        return z

# Create an instance with 10 memory slots and hidden size 1024
model = MemoryAttentionModel(k=10, d=1024)

# Use torchinfo to display the model summary for a single input of shape [1024]
# (no batch dimension); the size must match the d given to the constructor above.
summary(model, input_size=(1024,))  # 1-D input matching d=1024


Layer (type:depth-idx)                   Output Shape              Param #
MemoryAttentionModel                     [1024]                    10,240
├─Linear: 1-1                            [10]                      10,250
Total params: 20,490
Trainable params: 20,490
Non-trainable params: 0
Total mult-adds (M): 0.10
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.04
Estimated Total Size (MB): 0.05