In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)

In [2]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x28290066fc0>

In [4]:
from functools import partial

import torch

from world_machine.profile import profile_range

from world_machine.layers.positional_encoder import create_positional_encoder

def apply_score_mod(score, score_mode, batch_index, head_index, query_index, key_index):
    return score+score_mode[batch_index, head_index, query_index, key_index]


class MultiHeadSelfAttention(torch.nn.Module):

    def __init__(self, embed_dim: int, n_head: int, is_causal: bool, positional_encoder_type: str | None = None, fast: bool = True):
        super().__init__()

        self.attention = MultiHeadAttention(
            embed_dim, n_head, is_causal, positional_encoder_type, fast)

    @profile_range("multi_head_self_attention_forward", domain="world_machine")
    def forward(self, x: torch.Tensor):
        return self.attention(x, x, x)


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, n_head: int, is_causal: bool, positional_encoder_type: str | None = None, fast: bool = True) -> None:
        """
        Creates the layer.

        Args:
            embed_dim (int): size of the embedding in the layer input and output.
        """
        super().__init__()

        self.embed_dim = embed_dim

        self.n_head = n_head
        self.head_dim = embed_dim//n_head

        self.is_causal = is_causal
        self.fast = fast

        if self.head_dim * n_head != embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads ({embed_dim}/{n_head} is not integer).")

        # Initialize weights

        # d_model = dv = dk = embed_dim
        # h = 1

        wQ = torch.Tensor(embed_dim, embed_dim)  # embed, embed
        wK = torch.Tensor(embed_dim, embed_dim)  # embed, dk
        wV = torch.Tensor(embed_dim, embed_dim)  # embed, dv
        w0 = torch.Tensor(embed_dim, embed_dim)  # embed, embed

        self.wQ = torch.nn.Parameter(wQ)
        self.wK = torch.nn.Parameter(wK)
        self.wV = torch.nn.Parameter(wV)
        self.w0 = torch.nn.Parameter(w0)

        self.register_buffer("dk_root", torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)))
        self.dk_root: torch.Tensor

        self._positional_encoder = create_positional_encoder(
            positional_encoder_type, embed_dim, 0, n_head)

        for w in [self.wQ, self.wK, self.wV, self.w0]:
            torch.nn.init.kaiming_normal_(w)

        self.fast2 = False

    @profile_range("multi_head_attention_forward", domain="world_machine")
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Process the inputs using the attention process.

        Input tensors must be in [batch, sentence, embed] order.

        Args:
            query (torch.Tensor): queries tensor, are compared against the keys.
            key (torch.Tensor): keys tensor, represents the keys.
            value (torch.Tensor): values tensor.

        Returns:
            torch.Tensor: the layer output, the values pondered by the compability between the keys and queries.
        """

        # Check input
        if query.shape[2] != self.embed_dim:
            raise ValueError(
                f"Inputs must have embed dimension of {self.embed_dim} ({query.shape[2]} != {self.embed_dim})")

        # Get dimensions
        batch_size = query.shape[0]
        context_size = query.shape[1]

        # Linear input transformation
        # Transpose weights because PyTorch does that
        with profile_range("linear_input_transformation", category="multi_head_attention", domain="world_machine"):
            Q = query @ self.wQ.T
            K = key @ self.wK.T
            V = value @ self.wV.T

        # Compute bias
        with profile_range("compute_attention_bias", category="multi_head_attention", domain="world_machine"):
            attention_bias = torch.zeros(
                (context_size, context_size), device=Q.device)

            if self.is_causal:
                with profile_range("causal_bias", category="multi_head_attention", domain="world_machine"):

                    mask = torch.ones(
                        (context_size, context_size), dtype=torch.bool, device=Q.device)
                    mask = mask.tril()  # Lower triangular is one
                    # Upper triangular without diagonal is ones
                    mask = torch.bitwise_not(mask)

                    attention_bias = torch.zeros(
                        (context_size, context_size), device=Q.device)
                    attention_bias[mask] = -torch.inf

            attention_bias = attention_bias.unsqueeze(
                0).repeat([batch_size*self.n_head, 1, 1])

            with profile_range("positional_encoder", category="multi_head_attention", domain="world_machine"):
                attention_bias = self._positional_encoder.apply_attention_bias_pe(
                    attention_bias)

        # attention bias: [head*batch, context, context]

        if self.fast:
            E = self._fast_attention(Q, K, V, attention_bias)
        else:
            E = self._manual_attention(Q, K, V, attention_bias)

        result = E @ self.w0.T

        return result

    def _fast_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attention_bias: torch.Tensor) -> torch.Tensor:
        batch_size = Q.shape[0]
        context_size = Q.shape[1]
        embed_size = Q.shape[2]

        with profile_range("pre_reshape", category="multi_head_attention", domain="world_machine"):
            Q = Q.view(batch_size, -1, self.n_head,
                       self.head_dim).transpose(1, 2)
            K = K.view(batch_size, -1, self.n_head,
                       self.head_dim).transpose(1, 2)
            V = V.view(batch_size, -1, self.n_head,
                       self.head_dim).transpose(1, 2)

            # attention_bias: [head*batch, seq, seq]
            # attention_bias2: [bath, head, seq, seq]
            attention_bias = attention_bias.reshape(
                [batch_size, self.n_head, context_size, context_size])
    
        score_mod = lambda score, batch_index, head_index, query_index, key_index : apply_score_mod(score, attention_bias, batch_index, head_index, query_index, key_index)
        with profile_range("scaled_dot_product_attention", category="multi_head_attention", domain="world_machine"):
            if self.fast2:
                E = torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=attention_bias, scale=1/self.dk_root)
            else:
                E = flex_attention(Q, K, V, score_mod, scale=1/self.dk_root)
            #result = flex_attention(Q, K, V, scale=1/self.dk_root)
        
        with profile_range("post_reshape", category="multi_head_attention", domain="world_machine"):
            E = E.transpose(1, 2).view(
                batch_size, context_size, embed_size)

        return E

    def _manual_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attention_bias: torch.Tensor) -> torch.Tensor:
        batch_size = Q.shape[0]
        context_size = Q.shape[1]

        # batch_size, sentence, embed
        # to
        # batch_size,  n_head, sentence, head_dim
        with profile_range("pre_reshape", category="multi_head_attention", domain="world_machine"):
            Q = Q.transpose(0, 1).reshape(context_size, batch_size *
                                          self.n_head, self.head_dim).transpose(0, 1)
            K = K.transpose(0, 1).reshape(context_size, batch_size *
                                          self.n_head, self.head_dim).transpose(0, 1)
            V = V.transpose(0, 1).reshape(context_size, batch_size *
                                          self.n_head, self.head_dim).transpose(0, 1)
        # Now we have [
        # [batch0word0part0, batch0word1part0],
        # [batch0word0part1, batch0word1part1],
        # [batch1word0part0, batch1word1part0],
        # [batch1word0part1, batch1word1part1],
        # ]

        with profile_range("scores_computation", category="multi_head_attention", domain="world_machine"):
            scores = Q @ K.transpose(-2, -1)  # K.permute(0,1,3,2)
            scores /= self.dk_root

        with profile_range("add_attention_bias", category="multi_head_attention", domain="world_machine"):
            scores += attention_bias

        probs = torch.softmax(scores, dim=-1)
        E = probs @ V

        # Return elements to correct place
        with profile_range("post_reshape", category="multi_head_attention", domain="world_machine"):
            E = E.reshape(batch_size, self.n_head, context_size, self.head_dim)
            E = E.transpose(-3, -2)
            E = E.reshape(batch_size, context_size, self.embed_dim)
        # Now we have [
        # [batch0word0, batch0word1],
        # [batch1word0, batch1word1]
        # ]

        return E


In [5]:
embed_dim = 8
n_head = 4
is_causal = True
positional_encoder_type ="alibi"

mhsa_manual = MultiHeadSelfAttention(embed_dim, n_head, is_causal, positional_encoder_type, False)
mhsa_fast = MultiHeadSelfAttention(embed_dim, n_head, is_causal, positional_encoder_type, True)
mhsa_fast2 = MultiHeadSelfAttention(embed_dim, n_head, is_causal, positional_encoder_type, True)
torch_version = torch.nn.MultiheadAttention(embed_dim, num_heads=n_head, bias=False, batch_first=True).eval()

mhsa_fast2.attention.fast2 = True

In [7]:
mhsa_fast.attention.wQ = mhsa_manual.attention.wQ
mhsa_fast.attention.wK = mhsa_manual.attention.wK
mhsa_fast.attention.wV = mhsa_manual.attention.wV
mhsa_fast.attention.w0 = mhsa_manual.attention.w0


mhsa_fast2.attention.wQ = mhsa_manual.attention.wQ
mhsa_fast2.attention.wK = mhsa_manual.attention.wK
mhsa_fast2.attention.wV = mhsa_manual.attention.wV
mhsa_fast2.attention.w0 = mhsa_manual.attention.w0

torch_version.in_proj_weight = torch.nn.Parameter(torch.concat((mhsa_manual.attention.wQ, mhsa_manual.attention.wK, mhsa_manual.attention.wV)))
torch_version.out_proj.weight = mhsa_manual.attention.w0

In [8]:
x = torch.rand([32, 10, embed_dim])

In [9]:
y_manual = mhsa_manual(x)
y_fast = mhsa_fast(x)
y_fast2 = mhsa_fast2(x)
result_torch, _ = torch_version(x, x, x, need_weights=False)

In [10]:
((y_manual-y_fast).abs() > 1e-3).sum()

tensor(0)

In [12]:
((y_manual-y_fast2).abs() > 1e-3).sum()

tensor(0)

In [11]:
((y_manual-result_torch) > 1e-4).sum()

tensor(1276)

In [168]:
((y_fast-result_torch) > 1e-4).sum()

tensor(1102)

In [111]:
y_manual.numel()

2560

In [60]:
(y_manual/y_fast).flatten()

tensor([-0.4316, -1.0001,  0.9387,  ...,  4.1635, -6.3521, 12.4690])

m = a/x


f = a/y

m/f = a/x * y/a = y/x 

In [61]:
fast_scale = 1/np.sqrt(embed_dim)
manual_scale = 1/np.sqrt(mhsa_manual.attention.head_dim)

In [62]:
fast_scale, manual_scale

(0.35355339059327373, 0.7071067811865475)

In [63]:
fast_scale/manual_scale

0.5

In [64]:
(y_manual/y_fast).median()

tensor(-0.2615)

In [65]:
y_manual[0]

tensor([[-0.6754, -0.7069,  0.6798,  1.5112, -0.8935, -1.7815, -0.8695,  1.2731],
        [-0.6679, -0.7051,  0.6177,  1.4932, -0.8417, -1.8086, -0.7701,  1.3000],
        [-0.7154, -0.6486,  0.5826,  1.4753, -0.8954, -1.7991, -0.8409,  1.2300],
        [-0.7138, -0.6970,  0.6227,  1.5187, -0.8986, -1.8443, -0.8403,  1.2671],
        [-0.7241, -0.6508,  0.5309,  1.5137, -0.8776, -1.7682, -0.7733,  1.2858],
        [-0.6522, -0.6995,  0.6463,  1.4729, -0.8497, -1.7742, -0.8005,  1.2941],
        [-0.6666, -0.6871,  0.6764,  1.4478, -0.8675, -1.7611, -0.8445,  1.2663],
        [-0.6634, -0.6942,  0.6372,  1.4637, -0.8470, -1.7935, -0.7980,  1.2871],
        [-0.6590, -0.6966,  0.6276,  1.5074, -0.8627, -1.7871, -0.8209,  1.2677],
        [-0.6464, -0.6991,  0.6569,  1.4561, -0.8415, -1.7620, -0.7948,  1.2885]])

In [101]:
torch.argmin((y_fast-y_manual[1,1,0]).abs())

tensor(434)

In [91]:
426*2

852

In [95]:
426-434

-8

In [85]:
y_fast.shape

torch.Size([32, 10, 8])

53

In [63]:
n_head = 4
device = "cuda"

In [None]:
def generate_alibi_bias():
    alibi_bias = []
    for h in range(n_head):
        alibi_bias.append(-((h + 1) * 8.0 / n_head))
    alibi_bias = torch.tensor(alibi_bias, device=device)
    alibi_bias = torch.exp2(alibi_bias)
    return alibi_bias


alibi_bias = generate_alibi_bias()


# In this case we are going to use a mask_mod and a score_mod
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def alibi_and_causal_closure(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias

score_mod = alibi_and_causal_closure
mask_mod = causal_mask

In [74]:
x_reshape = x.view(32, -1, n_head,
                       128//n_head).transpose(1, 2)

In [75]:
x_reshape = x_reshape.to(device)

In [76]:
flex_attention(x_reshape, x_reshape, x_reshape)

tensor([[[[0.4794, 0.5454, 0.5413,  ..., 0.4885, 0.4858, 0.5556],
          [0.4698, 0.5364, 0.5326,  ..., 0.4786, 0.4887, 0.5545],
          [0.4737, 0.5467, 0.5350,  ..., 0.4942, 0.4928, 0.5496],
          ...,
          [0.4863, 0.5527, 0.5354,  ..., 0.4890, 0.4874, 0.5446],
          [0.4736, 0.5417, 0.5275,  ..., 0.4963, 0.4962, 0.5477],
          [0.4844, 0.5622, 0.5359,  ..., 0.4822, 0.4975, 0.5528]],

         [[0.4946, 0.5040, 0.4788,  ..., 0.4799, 0.4680, 0.4833],
          [0.4910, 0.4968, 0.4859,  ..., 0.4832, 0.4657, 0.4918],
          [0.5061, 0.4946, 0.4749,  ..., 0.4867, 0.4639, 0.4976],
          ...,
          [0.5104, 0.5015, 0.4819,  ..., 0.4799, 0.4835, 0.4827],
          [0.5015, 0.5021, 0.4742,  ..., 0.4920, 0.4677, 0.4832],
          [0.5074, 0.5034, 0.4832,  ..., 0.4800, 0.4755, 0.4806]],

         [[0.4929, 0.4862, 0.5478,  ..., 0.4798, 0.4643, 0.5114],
          [0.4829, 0.4821, 0.5508,  ..., 0.4802, 0.4802, 0.4880],
          [0.4908, 0.4804, 0.5514,  ..., 0

In [77]:
block_mask = create_block_mask(mask_mod, 1, 1, embed_dim, embed_dim, device=device)

In [None]:
flex_attention(x_reshape, x_reshape, x_reshape, score_mod, block_mask)

Unsupported: builtin: print [<class 'torch._dynamo.variables.lists.SizeVariable'>] False

from user code:
   File "c:\Users\eltsu\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\attention\flex_attention.py", line 1033, in _flex_attention_hop_wrapper
    return flex_attention_hop(*args, **kwargs)
  File "C:\Users\eltsu\AppData\Local\Temp\ipykernel_41424\1960214997.py", line 20, in alibi_and_causal_closure
    print(score.shape)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [48]:
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.register_buffer("a", torch.ones(100, dtype=torch.float32), False)
        self.a *= 2

    def test(self):

        self.a = torch.ones(200)

    def forward(self, x):
        return x+self.a

In [49]:
a = MyModule()

In [50]:
a.a

tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])

In [42]:
a.test()

In [43]:
torch.save(a, "test.pt")

In [44]:
b = torch.load("test.pt")

  b = torch.load("test.pt")


In [46]:
b.a.shape

torch.Size([200])

In [47]:
a.a

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])

In [53]:
torch.tensor([]).dim()

1

In [1]:
import torch
from world_machine.profile import profile_range

In [2]:
def func(a:int):
    return a

In [3]:
func = profile_range("my_message")(func)

In [4]:
func(1)

1

In [11]:
import sys

sys.path.insert(0, "C:\\Users\\eltsu\\Documentos\\Projetos\\WorldMachine\\WorldMachine\\benchmark")

In [12]:
from utils import get_benchmark_model

In [13]:
model = get_benchmark_model()

In [18]:
list(model.modules())

[WorldMachine(
   (_blocks): ModuleList(
     (0-1): 2 x BlockContainer(
       (block): AdaLNZeroBlock(
         (conditioning_mlp): Sequential(
           (0): SiLU()
           (1): Linear(in_features=128, out_features=768, bias=True)
         )
         (layer_norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
         (modulate1): Modulate()
         (attention): MultiHeadSelfAttention(
           (attention): MultiHeadAttention(
             (_positional_encoder): AlibiPositionalEncoder()
             (input_projection): Linear(in_features=128, out_features=384, bias=False)
           )
         )
         (dropout_attention): Dropout(p=0.0, inplace=False)
         (modulate2): Modulate()
         (layer_norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
         (modulate3): Modulate()
         (linear1): Linear(in_features=128, out_features=512, bias=True)
         (dropout_linear1): Dropout(p=0.0, inplace=False)
         (act): GELU(approximate='tanh')
  

# Trainer ranges

In [8]:
from world_machine.train.stages import PrepareModel, TrainStage
import inspect

In [2]:
pm = PrepareModel()

In [None]:
pm.pre_train == 

<code object inner at 0x7f1244092c10, file "/usr/lib/python3.12/contextlib.py", line 78>

In [9]:
PrepareModel.__dict__

mappingproxy({'__module__': 'world_machine.train.stages.prepare_model',
              '__init__': <function world_machine.train.stages.prepare_model.PrepareModel.__init__(self)>,
              'pre_batch': <function world_machine.train.stages.prepare_model.PrepareModel.pre_batch(self, model: world_machine.world_machine.WorldMachine, mode: world_machine.train.mode.DatasetPassMode, criterions: dict[str, dict[str, torch.nn.modules.module.Module]], optimizer: torch.optim.optimizer.Optimizer, device: torch.device, losses: dict, train_criterions: dict[str, dict[str, float]]) -> None>,
              'post_batch': <function world_machine.train.stages.prepare_model.PrepareModel.post_batch(self, model: world_machine.world_machine.WorldMachine, losses: dict, criterions: dict[str, dict[str, torch.nn.modules.module.Module]], train_criterions: dict[str, dict[str, float]]) -> None>,
              '__doc__': None,
              '__abstractmethods__': frozenset(),
              '_abc_impl': <_abc._abc_