In [1]:

# imports
import os
import sys
import types
import json
import base64

# figure size/format
fig_width = 7
fig_height = 5
fig_format = 'retina'
fig_dpi = 96
interactivity = ''
is_shiny = False
is_dashboard = False
plotly_connected = True

# matplotlib defaults / format
try:
  import matplotlib.pyplot as plt
  plt.rcParams['figure.figsize'] = (fig_width, fig_height)
  plt.rcParams['figure.dpi'] = fig_dpi
  plt.rcParams['savefig.dpi'] = "figure"

  # IPython 7.14 deprecated set_matplotlib_formats from IPython
  try:
    from matplotlib_inline.backend_inline import set_matplotlib_formats
  except ImportError:
    # Fall back to deprecated location for older IPython versions
    from IPython.display import set_matplotlib_formats
    
  set_matplotlib_formats(fig_format)
except Exception:
  pass

# plotly use connected mode
try:
  import plotly.io as pio
  if plotly_connected:
    pio.renderers.default = "notebook_connected"
  else:
    pio.renderers.default = "notebook"
  for template in pio.templates.keys():
    pio.templates[template].layout.margin = dict(t=30,r=0,b=0,l=0)
except Exception:
  pass

# disable itables paging for dashboards
if is_dashboard:
  try:
    from itables import options
    options.dom = 'fiBrtlp'
    options.maxBytes = 1024 * 1024
    options.language = dict(info = "Showing _TOTAL_ entries")
    options.classes = "display nowrap compact"
    options.paging = False
    options.searching = True
    options.ordering = True
    options.info = True
    options.lengthChange = False
    options.autoWidth = False
    options.responsive = True
    options.keys = True
    options.buttons = []
  except Exception:
    pass
  
  try:
    import altair as alt
    # By default, dashboards will have container sized
    # vega visualizations which allows them to flow reasonably
    theme_sentinel = '_quarto-dashboard-internal'
    def make_theme(name):
        nonTheme = alt.themes._plugins[name]    
        def patch_theme(*args, **kwargs):
            existingTheme = nonTheme()
            if 'height' not in existingTheme:
              existingTheme['height'] = 'container'
            if 'width' not in existingTheme:
              existingTheme['width'] = 'container'

            if 'config' not in existingTheme:
              existingTheme['config'] = dict()
            
            # Configure the default font sizes
            title_font_size = 15
            header_font_size = 13
            axis_font_size = 12
            legend_font_size = 12
            mark_font_size = 12
            tooltip = False

            config = existingTheme['config']

            # The Axis
            if 'axis' not in config:
              config['axis'] = dict()
            axis = config['axis']
            if 'labelFontSize' not in axis:
              axis['labelFontSize'] = axis_font_size
            if 'titleFontSize' not in axis:
              axis['titleFontSize'] = axis_font_size  

            # The legend
            if 'legend' not in config:
              config['legend'] = dict()
            legend = config['legend']
            if 'labelFontSize' not in legend:
              legend['labelFontSize'] = legend_font_size
            if 'titleFontSize' not in legend:
              legend['titleFontSize'] = legend_font_size  

            # The header
            if 'header' not in config:
              config['header'] = dict()
            header = config['header']
            if 'labelFontSize' not in header:
              header['labelFontSize'] = header_font_size
            if 'titleFontSize' not in header:
              header['titleFontSize'] = header_font_size    

            # Title
            if 'title' not in config:
              config['title'] = dict()
            title = config['title']
            if 'fontSize' not in title:
              title['fontSize'] = title_font_size

            # Marks
            if 'mark' not in config:
              config['mark'] = dict()
            mark = config['mark']
            if 'fontSize' not in mark:
              mark['fontSize'] = mark_font_size

            # Mark tooltips
            if tooltip and 'tooltip' not in mark:
              mark['tooltip'] = dict(content="encoding")

            return existingTheme
            
        return patch_theme

    # We can only do this once per session
    if theme_sentinel not in alt.themes.names():
      for name in alt.themes.names():
        alt.themes.register(name, make_theme(name))
      
      # register a sentinel theme so we only do this once
      alt.themes.register(theme_sentinel, make_theme('default'))
      alt.themes.enable('default')

  except Exception:
    pass

# enable pandas latex repr when targeting pdfs
try:
  import pandas as pd
  if fig_format == 'pdf':
    pd.set_option('display.latex.repr', True)
except Exception:
  pass

# interactivity
if interactivity:
  from IPython.core.interactiveshell import InteractiveShell
  InteractiveShell.ast_node_interactivity = interactivity

# NOTE: the kernel_deps code is repeated in the cleanup.py file
# (we can't easily share this code b/c of the way it is run).
# If you edit this code also edit the same code in cleanup.py!

# output kernel dependencies
kernel_deps = dict()
for module in list(sys.modules.values()):
  # Some modules play games with sys.modules (e.g. email/__init__.py
  # in the standard library), and occasionally this can cause strange
  # failures in getattr.  Just ignore anything that's not an ordinary
  # module.
  if not isinstance(module, types.ModuleType):
    continue
  path = getattr(module, "__file__", None)
  if not path:
    continue
  if path.endswith(".pyc") or path.endswith(".pyo"):
    path = path[:-1]
  if not os.path.exists(path):
    continue
  kernel_deps[path] = os.stat(path).st_mtime
print(json.dumps(kernel_deps))

# set run_path if requested
run_path = 'L1VzZXJzL3ByYXR5dXNoc2luaGEvRG9jdW1lbnRzL3ByYXR5dXNoLW1sLmdpdGh1Yi5pby9wb3N0cw=='
if run_path:
  # hex-decode the path
  run_path = base64.b64decode(run_path.encode("utf-8")).decode("utf-8")
  os.chdir(run_path)

# reset state
%reset

# shiny
# Checking for shiny by using False directly because we're after the %reset. We don't want
# to set a variable that stays in global scope.
if False:
  try:
    import htmltools as _htmltools
    import ast as _ast

    _htmltools.html_dependency_render_mode = "json"

    # This decorator will be added to all function definitions
    def _display_if_has_repr_html(x):
      try:
        # IPython 7.14 preferred import
        from IPython.display import display, HTML
      except:
        from IPython.core.display import display, HTML

      if hasattr(x, '_repr_html_'):
        display(HTML(x._repr_html_()))
      return x

    # ideally we would undo the call to ast_transformers.append
    # at the end of this block whenver an error occurs, we do 
    # this for now as it will only be a problem if the user 
    # switches from shiny to not-shiny mode (and even then likely
    # won't matter)
    import builtins
    builtins._display_if_has_repr_html = _display_if_has_repr_html

    class _FunctionDefReprHtml(_ast.NodeTransformer):
      def visit_FunctionDef(self, node):
        node.decorator_list.insert(
          0,
          _ast.Name(id="_display_if_has_repr_html", ctx=_ast.Load())
        )
        return node

      def visit_AsyncFunctionDef(self, node):
        node.decorator_list.insert(
          0,
          _ast.Name(id="_display_if_has_repr_html", ctx=_ast.Load())
        )
        return node

    ip = get_ipython()
    ip.ast_transformers.append(_FunctionDefReprHtml())

  except:
    pass

def ojs_define(**kwargs):
  import json
  try:
    # IPython 7.14 preferred import
    from IPython.display import display, HTML
  except:
    from IPython.core.display import display, HTML

  # do some minor magic for convenience when handling pandas
  # dataframes
  def convert(v):
    try:
      import pandas as pd
    except ModuleNotFoundError: # don't do the magic when pandas is not available
      return v
    if type(v) == pd.Series:
      v = pd.DataFrame(v)
    if type(v) == pd.DataFrame:
      j = json.loads(v.T.to_json(orient='split'))
      return dict((k,v) for (k,v) in zip(j["index"], j["data"]))
    else:
      return v

  v = dict(contents=list(dict(name=key, value=convert(value)) for (key, value) in kwargs.items()))
  display(HTML('<script type="ojs-define">' + json.dumps(v) + '</script>'), metadata=dict(ojs_define = True))
globals()["ojs_define"] = ojs_define
globals()["__spec__"] = None



In [2]:
import torch

# Declaring an input token of size 6 with embedding size of 3
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

# For demonstration purposes, we select one input token
query = inputs[1]

# 
attn_scores_2 = torch.empty(inputs.shape[0])
# print(attn_scores_2)

# Calculating step 1. This results in a vector of the same size as the context length
# Context length here means the number of tokens
for i , x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)


print(f"Attention score for second word i.e Journey {attn_scores_2}")

  cpu = _conversion_method_template(device=torch.device("cpu"))


Attention score for second word i.e Journey tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
attn_weights_2 = torch.softmax(attn_scores_2, dim =0)
print(f"This is normalizing the vector of immediate weights {attn_weights_2}")
print(f"Sum should be 1: {attn_weights_2.sum()}")

This is normalizing the vector of immediate weights tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum should be 1: 1.0


In [4]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(f"The final attention weight: {context_vec_2}")

The final attention weight: tensor([0.4419, 0.6515, 0.5683])


In [5]:
attn_scores = torch.empty(inputs.shape)
print(attn_scores)
for i, x_i in enumerate(inputs):
    for j , x_j in enumerate(inputs):
        attn_scores[i] = torch.dot(x_i, x_j)
print(f"Attention score for all words {attn_scores}")

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
Attention score for all words tensor([[0.6310, 0.6310, 0.6310],
        [1.0865, 1.0865, 1.0865],
        [1.0605, 1.0605, 1.0605],
        [0.6565, 0.6565, 0.6565],
        [0.2935, 0.2935, 0.2935],
        [0.9450, 0.9450, 0.9450]])


In [6]:
# Step 1
attn_scores = inputs @ inputs.T 

# Step 2
attn_weights = torch.softmax(attn_scores, dim =-1) # Do it on the last dimension of the vector, which in this case is equal to token size.

# Step 3

all_context_vectors = attn_weights @ inputs
print(all_context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [7]:
# Picking one token
x_2 = inputs[1]
d_in = inputs.shape[1] # Picking the embedding size. Here this value is 3
d_out = 2 # This is output size of the matrix

# We have initialized the random weights of the same size as the token embeddings and arbitary number 2.

torch.manual_seed(123)
# All these 3 matrices are of the same size 3 * 2
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Multiplying the input token of size 1 * 3 with 3 * 2
# The result is a matrix of size 1 * 2
query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)


# These matrices are updated here but they could very well be used for one token
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)


keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print(attn_scores_22)

attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim =-1) # Why do we use d_k ** 0.5?
print(attn_weights_2)

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.4306, 1.4551])
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
tensor(1.8524)
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
tensor([0.3061, 0.8210])


In [8]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out)) # Got rid of requires gradient = False part
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out)) 
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out)) 
        

    def forward(self, x):
        query = x @ self.W_query 
        key = x @ self.W_key
        value = x @ self.W_value

        attn_scores = query @ key.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [9]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [10]:
# First two steps combined
queries = sa_v2.W_query(inputs)

keys = sa_v2.W_key(inputs)

attnn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/ keys.shape[-1]** 0.5, dim = -1)

print(attn_weights)

# Third step - this looks a whole lot complex but is rather simple. If you only think about the values that tril is generating.
context_length = attn_scores.shape[0]
masked_simple = torch.tril(torch.ones(context_length,context_length))
print(masked_simple)

# Fourth step
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[0.1972, 0.1910, 0.1894, 0.1361, 0.1344, 0.1520],
        [0.1476, 0.2164, 0.2134, 0.1365, 0.1240, 0.1621],
        [0.1479, 0.2157, 0.2129, 0.1366, 0.1260, 0.1608],
        [0.1505, 0.1952, 0.1933, 0.1525, 0.1375, 0.1711],
        [0.1571, 0.1874, 0.1885, 0.1453, 0.1819, 0.1399],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])


In [11]:
batch = torch.stack((inputs,inputs), dim = 0)
print(batch.shape) ## The shape now must be 2 * 6 * 3 which is 2 items in the batch. 6 tokens in each batch. Each token with an embedding size of 3

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in,d_out,bias = qkv_bias)
        self.W_key = nn.Linear(d_in,d_out,bias = qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias = qkv_bias)
        self.dropout = nn.Dropout(dropout) # The input here indicates the percentage of values in the matrix that must be zero

        # This tells pytorch that these weights are non trainable so no gradient-descent for them. The mask is saved in state_dict and is loaded/saced with the model
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length,context_length),diagonal=1)
        ) # Will take this in a bit

    def forward(self, x):
        b, num_tokens, d_in = x.shape # This is the reason that we do not need to initialize the d_in. d_in is derived from an object of the class

        # Same step as the previous loop
        # Shape of x = 2,6,3
        # Shape of key, query, value is 2,6,2
        keys = self.W_key(x) 
        queries = self.W_query(x)
        values = self.W_value(x)

        # Pretty much the same as the previous step. This is just doing transformations so that keys merge properly.
        # Shape of attention scores is 2,6,6
        # Keys.transpose essentially means switching the 1st dimenion with the second. so the dimesion is now 2,2,6
        # queries (2,6,2) @ keys.transpose(1,2) (2,2,6) ==> (2,6,6)
        # this is because torch considers everythin after the last two dimesnions as batch. A better way to say this would be anything the matrix multipllication here is 6,2 @ 2,6 --> 6,6 and then there are two batches in each.

        attn_scores = queries @ keys.transpose(1,2)
        
        # This is making sure that all the values that are masked are now 
        # Fill elemts of the tensor with value where mask is true. The shape of the mask must be broadcastable
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens],-torch.inf)

        # This is the same step as before.
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim = -1)

        # Droput we had discussed
        # Zeroes the values of the matrix at random. the shape of the matrix is the same

        attn_weights = self.dropout(attn_weights)

        # This the same step as before.
        # attn_weights dim = (2,6,6). values dim (2,6,2) --> 2,6,2
        context_vec = attn_weights @ values

        return context_vec

torch.Size([2, 6, 3])


In [12]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Defines the number of Causal Attention Modules
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)]
            )
    # Concatenates the declaration from before
    def forward(self,x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

In [13]:
torch.manual_seed(123)
context_length = batch.shape[1] #This is the number of tokens

d_in, d_out = 3,2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2 
)

context_vecs = mha(batch)

print(context_vecs)
print("contex_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
contex_vecs.shape: torch.Size([2, 6, 4])


In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        assert(d_out %num_heads == 0), "d_out be divisible by the num_heads" # Why do we need to do this?

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduces the projection dimension to match the desired output dim

        # In mulit-headed attention, we don't actually create several seperate linear layers for each head. Instead we perform one large matrix multiplication and then slice the result into smaller for each head. If d_out is not divisible by number of heads, we would end up with fractional dimensions or uneven heads. To mitigate this,we could use padding which can be totally avoided

        self.W_query = nn.Linear(d_in,d_out,bias=qkv_bias) 
        self.W_key = nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkv_bias)

        # Uses a linear layer to combine individual head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self,x):
        # We assume number of heads(num_heads) to 4
        # Assuming output dimensions(d_out) to 8
        # Dimensions assuming: 5,6,3 <==> (batch, token length, embedding vector)
        b, num_tokens, d_in = x.shape

        # Dimension output (5,6,3) * (3,8) --> (5,6,8)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a num_heads dimension, then we unroll the last dimension
        # (batch, num_tokens, d_out) --> (b, num_tokens, num_heads, head_dim)

        # (5,6,8) --> (5,6,4,2)        
        keys = keys.view(b, num_tokens,self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transposes from shape (b, num_tokens, num_heads, head_dim) --> (b, num_heads, num_tokens, head_dim)
        # (5,6,4,2) --> (5,4,6,2)
        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        # Can we not combine view and step to one step, view should be able to do it?
        #  View changes how we interpret the memory. In this case it asking the interpretation of 8 vector of dim 1 to 4 rows and 1 column.
        # Transpose changes the physical order. When we move the num_heads dimesnio, we are effectively asking to group all tokens for a single head together

        # Computes dot product for each head
        # (5,4,6,2) @ (5,4,2, 6) --> 5,4,6,6
        attn_scores = queries @ keys.transpose(2,3)

        # Masks truncated to the number of tokens
        # (6,6)
        mask_bool = self.mask[:num_tokens,:num_tokens].bool()

        # uses masks to fill the attention scores
        # This will broadcast 6 ,6 across 
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)
        
        attn_weights = self.dropout(attn_weights)


        # (5,4,6,6) @ (5,4,6,2) --> (5,4,6,2).T --> (5,6,4,2)
        context_vec = (attn_weights @ values).transpose(1,2)
        # Combines heads where self.d_out = self.num_heads * self.head_dim
        # (5,6,4,2) --> (5,6,8) We are doing the unroll here
        # here contiguous().view() is asking to do the same thing as before but instead of the previous 8 --> 4,2, here it is asking to interpret this as 4,2 to 8
        # After the transpose the data in memory is not longer lined up in a single row ergo not contiguous
        # View has the requirement for the data to be in a contiguous block to work
        # Calling .contiguis shuffles the physical memory to match the transposed shape so that it can flatten out.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        # Adding an optional linear projection. This is not strictly necessary but is commonly used in many LLM Architectures
        context_vec = self.out_proj(context_vec)

        return context_vec