# Implementation from ChatGPT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    Keys, queries and values are each passed through a bias-free linear
    projection of size ``embed_size -> embed_size``; the attention-weighted
    values then go through a final linear output layer (with bias).
    """

    def __init__(self, embed_size: int, heads: int = 1):
        """Create the projection layers.

        Args:
            embed_size: Size of the last (feature) dimension of the inputs
                and of the output.
            heads: Unused by this single-head module. Accepted only so that
                existing callers passing ``(embed_size, heads)`` keep working.
        """
        super().__init__()
        self.embed_size = embed_size
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, keys, queries, values, mask=None) -> torch.Tensor:
        """Compute attention of ``queries`` over ``keys``/``values``.

        Args:
            keys: Tensor of shape ``(..., key_len, embed_size)``.
            queries: Tensor of shape ``(..., query_len, embed_size)``.
            values: Tensor of shape ``(..., key_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(..., query_len, key_len)``. Positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor of shape ``(..., query_len, embed_size)``.
        """
        # Linear projections.
        projected_keys = self.keys(keys)
        projected_queries = self.queries(queries)
        projected_values = self.values(values)

        # Scaled dot-product scores: (..., query_len, key_len).
        scale = self.embed_size ** 0.5
        energy = torch.matmul(
            projected_queries, projected_keys.transpose(-2, -1)
        ) / scale

        if mask is not None:
            # Use the most negative finite value of the score dtype rather
            # than a hard-coded -1e20, which is not representable in
            # float16. A finite fill (instead of -inf) keeps fully masked
            # rows at a uniform distribution rather than NaN.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(mask == 0, fill_value)

        attention = torch.softmax(energy, dim=-1)
        context = torch.matmul(attention, projected_values)
        return self.fc_out(context)
