In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PytorchMultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, causal=True, p_dropout=0.1, bias=True):
        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.causal = causal
        self.attn_hidden_dim = n_embd // n_head
        self.dropout = p_dropout

        self.q_projection = nn.Linear(n_embd, n_embd, bias=bias)
        self.k_projection = nn.Linear(n_embd, n_embd, bias=bias)
        self.v_projection = nn.Linear(n_embd, n_embd, bias=bias)
        self.out_projection = nn.Linear(n_embd, n_embd, bias=bias)

    def create_causal_mask(self, seq_len, device):
        # (1, 1, T, T) mask filled with -inf in upper triangle
        mask = torch.triu(torch.ones((1, 1, seq_len, seq_len), device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def reshape_and_multiply_layer(self, m, x):
        B, T, E = x.shape
        x_flattened = x.reshape(B * T, E)                  # (B*T, E)
        result = m(x_flattened)                            # (B*T, E)
        result = result.view(B, T, self.n_head, self.attn_hidden_dim)  # (B, T, H, D)
        result = result.permute(0, 2, 1, 3).contiguous()   # (B, H, T, D)
        return result

    def project_to_query_key_value(self, x):
        q = self.reshape_and_multiply_layer(self.q_projection, x)         # (B, H, T, D)
        k = self.reshape_and_multiply_layer(self.k_projection, x)         # (B, H, T, D)
        v = self.reshape_and_multiply_layer(self.v_projection, x)         # (B, H, T, D)
        kT = k.permute(0, 1, 3, 2).contiguous()                            # (B, H, D, T)
        return q, kT, v

    def self_attention(self, q, kT, v):
        B, H, T, D = q.shape
        scale = D ** 0.5
        attn_scores = torch.matmul(q, kT) / scale                          # (B, H, T, T)

        if self.causal:
            mask = self.create_causal_mask(T, q.device)                   # (1, 1, T, T)
            attn_scores = attn_scores + mask                              # broadcasted

        attn_weights = F.softmax(attn_scores, dim=-1)                     # (B, H, T, T)
        attn_output = torch.matmul(attn_weights, v)                       # (B, H, T, D)

        output = attn_output.permute(0, 2, 1, 3).contiguous()             # (B, T, H, D)
        output = output.view(B, T, H * D)                                 # (B, T, E)
        return output

    def forward(self, x):
        B, T, E = x.shape
        q, kT, v = self.project_to_query_key_value(x)
        out = self.self_attention(q, kT, v)
        out = self.out_projection(out.view(B * T, E)).view(B, T, E)
        return out

In [None]:
import numpy as np
import torch
import os

def load_numpy_array(arr_path):
    with open(arr_path, 'rb') as f:
        return np.load(f).astype(np.float32)

def test_multihead_attention_student(batch_size, queries_len, n_embd, num_heads, p_dropout):
    test_dir = f'./tests/data/multihead_attention'
    test_str = '_'.join(map(str, (batch_size, queries_len, n_embd, num_heads)))

    # Load numpy test data
    data = load_numpy_array(os.path.join(test_dir, f'{test_str}_data.npy'))
    w_q = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_q.npy'))
    w_k = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_k.npy'))
    w_v = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_v.npy'))
    w_out = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_out.npy'))

    result_ = load_numpy_array(os.path.join(test_dir, f'{test_str}_result.npy'))
    x_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_x_grad.npy'))
    w_q_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_q_grad.npy'))
    w_k_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_k_grad.npy'))
    w_v_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_v_grad.npy'))
    w_out_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_out_grad.npy'))

    # Create tensors with requires_grad=True to compute gradients
    X = torch.from_numpy(data).requires_grad_(True)

    # Initialize your PyTorch implementation
    layer = PytorchMultiHeadAttention(n_embd, num_heads, causal=True, bias=False)

    # Load provided weights into PyTorch layers
    with torch.no_grad():
        layer.q_projection.weight.copy_(torch.from_numpy(w_q.T))
        layer.k_projection.weight.copy_(torch.from_numpy(w_k.T))
        layer.v_projection.weight.copy_(torch.from_numpy(w_v.T))
        layer.out_projection.weight.copy_(torch.from_numpy(w_out.T))

    # Forward pass
    result = layer(X)

    # Check forward pass output
    np.testing.assert_allclose(result.detach().numpy(), result_, atol=1e-5, rtol=1e-5)

    # Backward pass to compute gradients
    result.sum().backward()

    # # Check gradients
    np.testing.assert_allclose(X.grad.numpy(), x_grad, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(layer.q_projection.weight.grad.numpy(), w_q_grad.T, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(layer.k_projection.weight.grad.numpy(), w_k_grad.T, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(layer.v_projection.weight.grad.numpy(), w_v_grad.T, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(layer.out_projection.weight.grad.numpy(), w_out_grad.T, atol=1e-5, rtol=1e-5)
    print("All tests passed successfully!")

In [3]:
m = test_multihead_attention_student(128,32,256,8,0.0)

All tests passed successfully!


In [4]:
from minitorch import MultiHeadAttention
import minitorch
from minitorch.cuda_kernel_ops import CudaKernelOps
import numpy as np
from minitorch.tensor import tensor, tensor_from_numpy
from minitorch.module import Module, Parameter
from minitorch.tensor_ops import *


backend = minitorch.TensorBackend(CudaKernelOps)
def test_multihead_attention_student_minitorch(batch_size, queries_len, n_embd, num_heads, p_dropout, backend):
    test_dir = f'./tests/data/multihead_attention'
    test_str = '_'.join(map(str, (batch_size, queries_len, n_embd, num_heads)))

    data = load_numpy_array(os.path.join(test_dir, f'{test_str}_data.npy'))
    w_q = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_q.npy'))
    w_k = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_k.npy'))
    w_v = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_v.npy'))
    w_out = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_out.npy'))
    result_ = load_numpy_array(os.path.join(test_dir, f'{test_str}_result.npy'))
    x_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_x_grad.npy'))
    w_q_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_q_grad.npy'))
    w_k_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_k_grad.npy'))
    w_v_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_v_grad.npy'))
    w_out_grad = load_numpy_array(os.path.join(test_dir, f'{test_str}_w_out_grad.npy'))

    X    = minitorch.tensor_from_numpy(data, backend, True)

    layer = minitorch.MultiHeadAttention(n_embd, num_heads, True, p_dropout, bias=False, backend=backend)
    
    layer.q_projection.weights.value   = minitorch.tensor_from_numpy((w_q), backend=backend, requires_grad=True)
    layer.k_projection.weights.value   = minitorch.tensor_from_numpy((w_k), backend=backend, requires_grad=True)
    layer.v_projection.weights.value   = minitorch.tensor_from_numpy((w_v), backend=backend, requires_grad=True)
    layer.out_projection.weights.value = minitorch.tensor_from_numpy((w_out), backend=backend, requires_grad=True)
    result = layer(X)
    np.testing.assert_allclose(result.to_numpy(), result_, atol=1e-5, rtol=1e-5)

    result.sum().backward()
    
    np.testing.assert_allclose(X.grad.to_numpy(), x_grad, atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(layer.out_projection.weights.value.grad.to_numpy(), w_out_grad, atol=1e-5, rtol=1e-5)
    #np.testing.assert_allclose(layer.q_projection.weights.value.grad.to_numpy(), w_q_grad, atol=1e-5, rtol=1e-5)


In [5]:
m2 = test_multihead_attention_student_minitorch(1,32,64,1,0.0,backend)

AssertionError: Must be contiguous to view

In [None]:
m.q_projection.weight

In [6]:
print((m2.q_projection.bias.value.to_numpy()-m.q_projection.bias.data.numpy()).sum())

NameError: name 'm2' is not defined

In [22]:
m.q_projection.weight

AttributeError: 'NoneType' object has no attribute 'q_projection'

In [33]:
x = tensor_from_numpy(np.random.randn(2, 3, 4),backend, requires_grad=True)
y = x.permute(0, 2, 1)  # now shape (2, 4, 3)
out = y.sum()
out.backward()

print(x.grad)  # should be all ones, same shape as x


[
	[
		[1.000000 1.000000 1.000000 1.000000]
		[1.000000 1.000000 1.000000 1.000000]
		[1.000000 1.000000 1.000000 1.000000]]
	[
		[1.000000 1.000000 1.000000 1.000000]
		[1.000000 1.000000 1.000000 1.000000]
		[1.000000 1.000000 1.000000 1.000000]]]
