In [1]:
import torch
import torch.nn as nn

In [51]:
class CasualAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
    super().__init__()
    self.d_out=d_out
    self.w_query=nn.Linear(d_in,d_out,bias=False)
    self.w_key=nn.Linear(d_in,d_out,bias=False)
    self.w_value=nn.Linear(d_in,d_out,bias=False)
    self.dropout=nn.Dropout(dropout)
    # self.mask=nn.Parameter(torch.triu(torch.ones(num_tokens,num_tokens)),diagonal=1)
    self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length)))
  def forward(self,x):
    b,num_tokens,d_in=x.shape
    keys=self.w_key(x)
    queries=self.w_query(x)
    values=self.w_value(x)
    attn_scores=queries@keys.transpose(1,2)
    attn_scores.masked_fill(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
    attn_weights=self.dropout(torch.softmax(attn_scores,dim=-1))
    attn_scores=torch.bmm(queries,keys.transpose(1,2))
    attn_scores=attn_scores/torch.sqrt(torch.tensor(d_in))
    attn_weights=self.dropout(torch.softmax(attn_scores,dim=-1))
    context_vector=attn_weights@values
    return context_vector

In [52]:
torch.manual_seed(1234)
inputs = torch.tensor([[0.43,0.15,0.89],
                       [0.55,0.87,0.66]])
batch = torch.stack((inputs,inputs),dim=0)   # shape (2,2,3)

context_length = batch.shape[1]  # = 2
d_out = 2

ca_without_buffer = CasualAttention(d_in=3, d_out=d_out,
                                    context_length=context_length,
                                    dropout=0)

with torch.no_grad():
    context_vec = ca_without_buffer(batch)   # <-- use batch here

print(context_vec.shape)  # (2, 2, 2)
print(context_vec)

torch.Size([2, 2, 2])
tensor([[[0.1511, 0.1134],
         [0.1472, 0.1155]],

        [[0.1511, 0.1134],
         [0.1472, 0.1155]]])


In [14]:
print("Machine has a gpu",torch.cuda.is_available())

Machine has a gpu False


In [38]:
# batch=batch.to("cuda")
# ca_without_buffer.to("cuda")


In [40]:
print("W_Query_device", ca_without_buffer.w_query.weight.device)
print("W_Key_device",   ca_without_buffer.w_key.weight.device)
print("W_Value_device", ca_without_buffer.w_value.weight.device)

W_Query_device cpu
W_Key_device cpu
W_Value_device cpu


In [53]:
print(ca_without_buffer.mask)

tensor([[1., 1.],
        [0., 1.]])
