#### Understanding PyTorch Buffers

###### In essence, PyTorch buffers are tensor attributes associated with a PyTorch module or model. Like parameters, they are tracked by the module, but unlike parameters, they are not updated during training.
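
###### As a minimal sketch of this distinction, the toy module below (the `Scaler` class and its attribute names are hypothetical, not taken from the book) holds one parameter and one buffer. After a single optimizer step, only the parameter has changed:

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):  # hypothetical toy module, not from the book
    def __init__(self):
        super().__init__()
        # Trainable: returned by parameters() and updated by the optimizer
        self.weight = nn.Parameter(torch.ones(3))
        # Not trainable: tracked by the module, but skipped by the optimizer
        self.register_buffer("offset", torch.zeros(3))

    def forward(self, x):
        return x * self.weight + self.offset

model = Scaler()
param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
print(param_names)   # ['weight']
print(buffer_names)  # ['offset']

# One SGD step changes the parameter but leaves the buffer untouched
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(3)).sum()
loss.backward()
optimizer.step()

print(model.weight)  # each entry moved from 1.0 to 0.9
print(model.offset)  # still all zeros
```

###### Because `model.parameters()` does not yield buffers, the optimizer never sees `offset`, which is why it stays fixed.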

###### Buffers in PyTorch are particularly useful when dealing with GPU computations, as they need to be transferred between devices (like from CPU to GPU) alongside the model's parameters. Unlike parameters, buffers do not require gradient computation, but they still need to be on the correct device to ensure that all computations are performed correctly.
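
###### The following sketch illustrates how `Module.to()` treats buffers differently from plain tensor attributes. It assumes no GPU is available, so it uses a dtype cast as a stand-in for a device transfer; `.to("cuda")` walks the same set of parameters and buffers. The `MaskHolder` class is a hypothetical example:

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module for illustration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        # Plain attribute: invisible to Module.to()
        self.plain_mask = torch.ones(2, 2)
        # Registered buffer: converted together with the parameters
        self.register_buffer("buffer_mask", torch.ones(2, 2))

holder = MaskHolder()
# Stand-in for holder.to("cuda"); a dtype cast runs on any machine
holder.to(torch.float64)

print(holder.linear.weight.dtype)  # torch.float64
print(holder.buffer_mask.dtype)    # torch.float64
print(holder.plain_mask.dtype)     # torch.float32
```

###### The plain attribute is left behind, which is the same mismatch that would surface as a device error when moving a model to a GPU.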

###### In chapter 3, we use PyTorch buffers via self.register_buffer, which is only briefly explained in the book. Since the concept and purpose are not immediately clear, this code notebook offers a longer explanation with a hands-on example.

##### An example without buffers
###### Suppose we have the following code, which is based on code from chapter 3. This version has been modified to exclude buffers. It implements the causal self-attention mechanism used in LLMs:

In [None]:
import torch
import torch.nn as nn

class CausalAttentionWithoutBuffers(nn.Module):
   
   def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
       super().__init__()
       self.d_out = d_out
       self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
       self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
       self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
       self.dropout = nn.Dropout(dropout)
       self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

   def forward(self, x):
       b, num_tokens, d_in = x.shape
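       # Project the inputs into key, query, and value spaces
       keys = self.W_key(x)
       queries = self.W_query(x)
       values = self.W_value(x)

       attn_scores = queries @ keys.transpose(1, 2)
       # Mask out future positions so each token attends only to itself and earlier tokens
       attn_scores.masked_fill_(
           self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
       )
       # Scaled softmax over the key dimension
       attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
       attn_weights = self.dropout(attn_weights)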

       context_vec = attn_weights @ values
       return context_vec
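
###### To make the masking step concrete, here is a small standalone sketch of what the upper-triangular mask looks like and how it shapes the attention weights once a softmax is applied, as in standard causal attention. The size `num_tokens = 4` and the all-zero dummy scores are arbitrary choices for illustration:

```python
import torch

num_tokens = 4  # arbitrary small size chosen for illustration
# 1s strictly above the diagonal mark the "future" positions
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)
print(mask)

# Uniform dummy scores make the effect of the mask easy to read
scores = torch.zeros(num_tokens, num_tokens)
masked = scores.masked_fill(mask.bool(), -torch.inf)
weights = torch.softmax(masked, dim=-1)
print(weights)
```

###### Each row of `weights` sums to 1 and is zero above the diagonal, so token *i* spreads its attention only over positions up to and including *i*.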


###### We can initialize and run the module as follows on some example data:

In [None]:
torch.manual_seed(123)

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
context_length = batch.shape[1]
d_in = inputs.shape[1]
d_out = 2

ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)

with torch.no_grad():
