In [None]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Callable

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and an optional bias.

    Args:
        ndim: size of the normalised (last) dimension.
        bias: when falsy no bias parameter is created (``self.bias`` is None).
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # eps matches the PyTorch default of 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)

In [None]:
class RMSNorm(nn.Module):
    """RMS-style normalisation of the last dimension.

    The input is scaled to unit L2 norm along the last dim, multiplied by sqrt(dim), and
    then by a learnable per-channel gain ``g`` (initialised to ones).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.g

In [None]:
class GLU(nn.Module):
    """Gated linear unit.

    Projects the input to ``2 * dim_out`` features, splits them into a value half and a gate
    half, and returns ``value * activation(gate)``. The result is optionally multiplied by a
    learnable per-channel factor (``mult_bias=True``); otherwise by the constant 1.0.
    """

    def __init__(self, dim_in, dim_out, activation: Callable, mult_bias=False):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        if mult_bias:
            self.mult_bias = nn.Parameter(torch.ones(dim_out))
        else:
            self.mult_bias = 1.0

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        gated = value * self.act(gate)
        return gated * self.mult_bias

In [None]:
def l2norm(t, groups=1):
    """L2-normalise the last dimension of ``t``, optionally in ``groups`` equal chunks.

    The last dimension (size E) is split into ``groups`` contiguous chunks of size
    ``E // groups`` and each chunk is scaled to unit L2 norm (via ``F.normalize``).
    Works for tensors of any rank >= 1; the result has the same shape as ``t``.

    Args:
        t: input tensor, e.g. (batch, heads, seq, dim_head) attention queries/keys.
        groups: number of equal chunks the last dimension is split into.

    Raises:
        ValueError: if ``groups`` < 1 or the last dimension is not divisible by ``groups``.
    """
    E = t.shape[-1]
    if groups < 1 or E % groups != 0:
        raise ValueError(f"last dimension ({E}) must be divisible by groups ({groups})")
    if groups == 1:
        # Same computation as the grouped path with a single group, without the reshapes.
        return F.normalize(t, dim=-1)
    # reshape (not view) so non-contiguous inputs are accepted as well
    grouped = t.reshape(*t.shape[:-1], groups, E // groups)
    return F.normalize(grouped, dim=-1).reshape(t.shape)

def pad_at_dim(t, pad, dim=-1, value=0.0):
    """Pad tensor ``t`` along a single dimension ``dim``.

    Args:
        t: tensor to pad.
        pad: ``(before, after)`` amounts for dimension ``dim``.
        dim: dimension to pad; negative values count from the end.
        value: fill value for the padded entries.

    ``F.pad`` lists padding from the last dimension backwards, so every dimension after
    ``dim`` gets a ``(0, 0)`` entry ahead of ``pad``.
    """
    axis = dim if dim >= 0 else t.ndim + dim
    trailing = t.ndim - axis - 1
    return F.pad(t, (0, 0) * trailing + tuple(pad), value=value)

In [None]:
class AlibiPositionalBias(nn.Module):
    """
    Implementation of the Alibi positional bias. This will apply a negative bias to the sequence proportionally to its distance.
    This means that close tokens will have more relevance than further ones.

    Only the first ``heads`` of ``total_heads`` attention heads receive a bias; the remaining
    heads are zero-padded. ``forward(i, j)`` returns a ``(total_heads, i, j)`` bias in which
    query row ``r`` sits at position ``j - i + r`` and key column ``c`` at position ``c``.
    """
    def __init__(self, heads, total_heads, **kwargs):
        super().__init__()
        self.heads = heads
        self.total_heads = total_heads

        # torch.tensor keeps this class independent of a separately imported `Tensor` name.
        slopes = torch.tensor(self._get_slopes(heads))
        slopes = slopes.unsqueeze(1).unsqueeze(2)  # (heads, 1, 1): broadcast over (i, j)
        self.register_buffer("slopes", slopes, persistent=False)
        # Cache of the last computed bias; None until forward() is first called.
        self.register_buffer("bias", None, persistent=False)

    def get_bias(self, i, j, device):
        """Return an integer ``(1, i, j)`` tensor of ``-|key_pos - query_pos|``.

        Query row ``r`` is at position ``j - i + r`` (queries aligned to the end of the keys),
        key column ``c`` is at position ``c``.
        """
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(
            j_arange.unsqueeze(0).unsqueeze(0)
            - i_arange.unsqueeze(0).unsqueeze(2)
        )
        return bias

    @staticmethod
    def _get_slopes(heads):
        """
        Return the per-head ALiBi slopes.

        For a power-of-two head count these form the geometric sequence
        ``2**(-8/heads), 2**(-16/heads), ...``. Otherwise the slopes of the closest lower power of
        two are combined with every other slope of the next power of two, sorted descending.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return sorted(
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : heads - closest_power_of_2
            ], reverse=True
        )

    @property
    def device(self):
        # `slopes` is always registered, so there is at least one buffer to read the device from.
        return next(self.buffers()).device

    def forward(self, i, j):
        """Return the ``(total_heads, i, j)`` additive attention bias (heads beyond ``self.heads`` are zero)."""
        h, device = self.total_heads, self.device

        cached = self.bias
        if (
            cached is not None
            and cached.shape[-1] >= j
            and cached.shape[-2] >= i
        ):
            # The cached bias aligns its last query row with its last key column, exactly like
            # get_bias(i, j), so its bottom-right (i, j) window is the requested bias for any
            # i <= I, j <= J -- also when j - i differs from the cached J - I.
            return cached[..., cached.shape[-2] - i:, cached.shape[-1] - j:]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes  # (heads, i, j)

        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0)  # (total_heads, i, j)
        self.register_buffer("bias", bias, persistent=False)

        return self.bias

In [None]:
def apply_rotary_pos_emb(t, freqs, scale=1):
    """Apply rotary position embeddings (optionally xPos-scaled) to ``t``.

    Args:
        t: tensor whose second-to-last dim is the sequence and last dim the rotated channels.
        freqs: ``(positions, channels)`` angle table; only its last ``t.shape[-2]`` rows are used.
        scale: scalar, or a ``(positions, channels)`` tensor of per-position xPos scales, which is
            sliced to its last ``t.shape[-2]`` rows in the same way as ``freqs``.

    Returns:
        ``t * cos(freqs) * scale + rotate_half(t) * sin(freqs) * scale`` (same shape as ``t``).
    """
    seq_len = t.shape[-2]
    freqs = freqs[-seq_len:, :]
    # Keep a tensor scale aligned with the (possibly truncated) frequency table; scalars broadcast.
    if isinstance(scale, torch.Tensor) and scale.ndim >= 2:
        scale = scale[-seq_len:, :]
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

def rotate_half(x):
    """Split the last dim into two halves (a, b) and return concat(-b, a)."""
    pairs = x.reshape(*x.shape[:-1], 2, -1)
    first, second = pairs.unbind(dim=-2)
    return torch.cat([second.neg(), first], dim=-1)

def create_causal_mask(i, j, device):
    """Boolean ``(i, j)`` mask that is True where a query may NOT attend.

    Queries are aligned to the end of the keys, so row ``r`` blocks key columns greater than
    ``j - i + r``.
    """
    allowed_shape = torch.ones(i, j, dtype=torch.bool, device=device)
    return torch.triu(allowed_shape, diagonal=j - i + 1)

class RotaryEmbedding(nn.Module):
    """Rotary position-embedding table with optional xPos scaling.

    ``forward(seq_len, device)`` returns ``(freqs, scale)``:
      * ``freqs`` has shape ``(seq_len, dim)``: the ``dim // 2`` angle frequencies per position,
        duplicated along the last axis;
      * ``scale`` is the float ``1.0`` when ``use_xpos=False``, otherwise a ``(seq_len, dim)``
        tensor of per-position xPos scales (exactly 1 at position ``seq_len // 2``).
    """

    def __init__(
        self,
        dim,
        use_xpos=True,
        scale_base=512,
        interpolation_factor=1.0,
        base=10000,
        alpha=1.0,
    ):
        super().__init__()
        # NTK-aware scaling of the rotary base; alpha == 1 leaves the base unchanged.
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        ntk_base = base * alpha ** (dim / (dim - 2))

        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (ntk_base ** exponents))

        # Positions are divided by this factor (position interpolation); it may only shrink them.
        assert interpolation_factor >= 1.0
        self.interpolation_factor = interpolation_factor

        if use_xpos:
            self.scale_base = scale_base
            self.register_buffer("scale", (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim))
        else:
            self.register_buffer("scale", None)

    def forward(self, seq_len, device):
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq) / self.interpolation_factor
        angles = positions.unsqueeze(-1) * self.inv_freq.unsqueeze(0)
        freqs = torch.cat([angles, angles], dim=-1)

        if self.scale is None:
            return freqs, 1.0

        # Exponent is centred on the middle position, so that position gets scale 1.
        offsets = torch.arange(seq_len, device=device) - seq_len // 2
        power = offsets / self.scale_base
        half_scale = self.scale ** power.unsqueeze(-1)
        return freqs, torch.cat([half_scale, half_scale], dim=-1)

In [None]:
class CrossAttention(nn.Module):
    """Grouped-query cross-attention from the text stream ``x`` to ``encoded_data``.

    Queries are projected from ``x``; keys and values from ``encoded_data`` with ``kv_heads``
    heads that are tiled up to ``n_head``. q and k are l2-normalised per head and multiplied by
    learnable per-channel gains (qk-norm); rotary/xPos embeddings are applied to the first
    ``freqs.shape[-1]`` channels of q, k and v; an ALiBi bias plus a causal mask are added to
    the logits of ``scaled_dot_product_attention``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # forward() tiles the K/V heads up to the query heads, so the counts must divide evenly.
        assert config.n_head % config.kv_heads == 0, "n_head must be a multiple of kv_heads"

        self.kv_heads = config.kv_heads

        q_dim = config.dim_head * config.n_head
        k_dim = config.dim_head * config.kv_heads
        v_dim = config.dim_head * config.kv_heads
        out_dim = config.dim_head * config.n_head

        # Learnable per-channel gains applied after l2-normalising q and k (qk-norm).
        self.qk_norm_q_scale = nn.Parameter(torch.ones(config.dim_head))
        self.qk_norm_k_scale = nn.Parameter(torch.ones(config.dim_head))

        # NOTE(review): linear_enc is not used in forward(); it only adds parameters.
        self.linear_enc = nn.Linear(config.n_embd*2, config.n_embd, bias=config.bias)
        # query, key, value projections for all heads
        self.to_q = nn.Linear(config.n_embd, q_dim, bias=config.bias)
        self.to_k = nn.Linear(config.n_embd, k_dim, bias=config.bias)
        self.to_v = nn.Linear(config.n_embd, v_dim, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(out_dim, config.n_embd, bias=config.bias)
        # regularization: attention-probability dropout is applied inside SDPA via self.dropout;
        # attn_dropout itself is not used in forward().
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.kv_head = config.kv_heads
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # NOTE(review): stored but not read; forward() always applies a causal mask.
        self.mask = config.mask
        self.dim_head = config.dim_head
        # Effective logit scale applied to the normalised q.k products (cosine-sim temperature).
        self.qk_norm_scale = 10
        assert hasattr(torch.nn.functional, 'scaled_dot_product_attention'), "Flash Attention requires PyTorch >= 2.0"

    def forward(self, x, encoded_data, rotary_pos_emb, pos_emb):
        """Attend from ``x`` to ``encoded_data``.

        Args:
            x: (B, T, n_embd) text hidden states.
            encoded_data: (B, S, n_embd) encoded image. S must equal T, because the rotary/xPos
                tables in ``rotary_pos_emb`` are built for the text length.
            rotary_pos_emb: ``(freqs, xpos_scale)`` pair, e.g. from RotaryEmbedding.forward.
            pos_emb: callable ``pos_emb(i, j)`` returning an ``(n_head, i, j)`` additive bias.

        Returns:
            (B, T, n_embd) tensor.

        Raises:
            ValueError: if the encoded sequence length differs from the text length.
        """
        B, T, _ = x.size()
        S = encoded_data.size(1)
        freqs, xpos_scale = rotary_pos_emb  # freqs: (T, rotary_dim)
        if S != T:
            raise ValueError(
                f"encoded_data length ({S}) must match the text length ({T}): "
                "the rotary/xPos tables are computed for the text length"
            )

        # Project and split heads -> (B, heads, length, dim_head).
        q = self.to_q(x).view(B, T, self.n_head, self.dim_head).transpose(1, 2)
        k = self.to_k(encoded_data).view(B, S, self.kv_head, self.dim_head).transpose(1, 2)
        v = self.to_v(encoded_data).view(B, S, self.kv_head, self.dim_head).transpose(1, 2)

        # qk-norm: unit-normalise each head vector, then apply the learnable gains.
        q, k = map(l2norm, (q, k))
        q = q * self.qk_norm_q_scale
        k = k * self.qk_norm_k_scale

        # xPos: queries are multiplied by the scale, keys (and values) by its inverse.
        q_xpos_scale, k_xpos_scale = (
            (xpos_scale, xpos_scale**-1.0)
            if xpos_scale is not None
            else (1.0, 1.0)
        )

        # Rotary embeddings are applied only to the first `l` channels of q, k and v.
        l = freqs.shape[-1]
        (ql, qr), (kl, kr), (vl, vr) = map(
            lambda t: (t[..., :l], t[..., l:]), (q, k, v)
        )
        ql, kl, vl = map(
            lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]),
            ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)),
        )
        q, k, v = map(
            lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))
        )

        # Grouped-query attention: tile the kv heads up to n_head (tiling, not interleaving).
        n_rep = self.n_head // self.kv_head
        k = k.repeat(1, n_rep, 1, 1)
        v = v.repeat(1, n_rep, 1, 1)

        # SDPA multiplies the logits by 1/sqrt(dim_head) itself. Pre-multiplying q by
        # qk_norm_scale * sqrt(dim_head) therefore makes the effective logit scale qk_norm_scale.
        default_scale = q.shape[-1] ** -0.5
        q = q * (self.qk_norm_scale / default_scale)

        # ALiBi bias (n_head, T, S) -> (B, n_head, T, S), then block future key positions.
        attn_bias = pos_emb(T, S)
        attn_bias = attn_bias.unsqueeze(0).expand(B, self.n_head, -1, -1)
        mask_value = -torch.finfo(q.dtype).max
        causal_mask = create_causal_mask(q.size(-2), k.size(-2), device=q.device)
        # SDPA expects a float mask in the query dtype.
        attn_bias = attn_bias.masked_fill(causal_mask, mask_value / 2).to(q.dtype)

        # attend:  (B, nh, T, hs) x (B, nh, hs, S) -> (B, nh, T, S)
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=self.dropout * self.training, is_causal=False)  # [B, n_head, T, dim_head]
        y = y.transpose(1, 2).contiguous().view(B, T, self.dim_head * self.n_head) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [None]:
class MLP(nn.Module):
    """Feed-forward sub-layer.

    Applies a SiLU-gated GLU that expands ``n_embd`` to ``4 * n_embd``, then LayerNorm, a linear
    projection back to ``n_embd``, and dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.glu = GLU(config.n_embd, hidden, nn.SiLU())
        self.norm = LayerNorm(hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.norm(self.glu(x))
        return self.dropout(self.c_proj(hidden))

In [None]:
class Block(nn.Module):
    """One decoder layer: pre-norm cross-attention, then pre-norm MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.ln_2 = RMSNorm(config.n_embd)
        self.cross_attn = CrossAttention(config)
        self.ln_3 = RMSNorm(config.n_embd)
        # NOTE(review): not used in forward(); it only adds parameters.
        self.linear = nn.Linear(config.n_embd, config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x, encoded_x, rotary_pos_emb, pos_emb):
        attn_out = self.cross_attn(self.ln_2(x), encoded_x, rotary_pos_emb, pos_emb)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_3(x))
        return x + mlp_out

In [None]:
@dataclass
class DecoderConfig:
    """Hyper-parameters for the image-conditioned decoder (Transformer / Fuyu).

    Raises:
        ValueError: on construction, if ``kv_heads`` does not divide ``n_head`` or
            ``alibi_num_heads`` is outside ``[1, n_head]``.
    """
    block_size: int = 2046  # max text length; the image is also projected to this many positions
    vocab_size: int = 10172
    n_layer: int = 4
    n_head: int = 4  # query heads
    n_embd: int = 320  # residual-stream width
    dim_head: int = 64  # per-head width (independent of n_embd / n_head)
    dropout: float = 0.0
    bias: bool = False # True: bias in Linears and LayerNorms. False: a bit better and faster
    mask: bool = False # NOTE(review): not read by CrossAttention, which always applies a causal mask
    kv_heads: int = 2  # key/value heads, tiled up to n_head (grouped-query attention)
    alibi_num_heads: int = 3  # heads that receive an ALiBi bias; the remaining heads get none
    device: str = 'cuda'  # device for Fuyu's fixed image-projection matrices

    def __post_init__(self):
        # Fail fast with a clear message instead of a shape error deep inside attention.
        if self.kv_heads <= 0 or self.n_head % self.kv_heads != 0:
            raise ValueError(
                f"n_head ({self.n_head}) must be a positive multiple of kv_heads ({self.kv_heads})"
            )
        if not 1 <= self.alibi_num_heads <= self.n_head:
            raise ValueError(
                f"alibi_num_heads ({self.alibi_num_heads}) must be between 1 and n_head ({self.n_head})"
            )

In [None]:
class Transformer(nn.Module):
    """Token decoder whose blocks cross-attend to an encoded image.

    Token embeddings are RMS-normalised and run through ``n_layer`` Blocks, each followed by the
    shared ``snorm`` RMSNorm. ``lm_head`` then projects to vocabulary logits; its weight is tied
    to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        # Base std for initialisation: the inverse square root of the embedding size.
        self.base_std = 1/math.sqrt(config.n_embd)

        self.decoder = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            rms = RMSNorm(config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            snorm = RMSNorm(config.n_embd),
        ))

        # Rotary embeddings cover half of each head's channels. ALiBi biases the first
        # alibi_num_heads of the n_head attention heads.
        rotary_emb_dim = config.dim_head // 2
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_emb_dim,
            scale_base=512,
            interpolation_factor=1,
            alpha=1,
        )
        self.pos_emb = AlibiPositionalBias(config.alibi_num_heads, config.n_head)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the logit head shares its weight with the token embedding layer.
        self.decoder.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # Scaled init for the residual output projections. Each Block has two of them (attention
        # c_proj and MLP c_proj), hence 2 * n_layer.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=self.base_std/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.

        Tied parameters (the token-embedding / lm_head weight) are counted once, because
        ``self.parameters()`` yields each Parameter object a single time. Positional information
        comes from rotary and ALiBi *buffers*, not parameters, so there is nothing to subtract and
        ``non_embedding`` has no effect. It is kept for API compatibility.
        """
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, module):
        """Normal(0, base_std) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.base_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.base_std)

    def forward(self, idx, imgs, targets=None):
        """Run the decoder.

        Args:
            idx: (b, t) token ids, t <= block_size.
            imgs: encoded image passed to every block's cross-attention.
            targets: optional (b, t) target ids; -1 entries are ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets: logits (b, t, vocab_size) and the cross-entropy
            loss. Without: logits for the last position only, (b, 1, vocab_size), and None.
        """
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        tok_emb = self.decoder.wte(idx) # token embeddings of shape (b, t, n_embd)

        x = self.decoder.drop(self.decoder.rms(tok_emb))

        rotary_pos_emb = self.rotary_pos_emb(
            x.size(1), x.device
        )

        for block in self.decoder.h:
            x = block(x, imgs, rotary_pos_emb, self.pos_emb)
            # NOTE(review): the shared RMSNorm runs after *every* block rather than once after the
            # loop as a final norm -- confirm this is intended.
            x = self.decoder.snorm(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # during inference we only need to apply the head to the last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, imgs, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Only the last ``block_size`` tokens are fed to the model at each step; ``top_k`` (if
        given) restricts sampling to the k most likely tokens.
        """
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            x = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            logits, _ = self(x, imgs)

            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)

            sample = torch.multinomial(probs, 1)

            idx = torch.cat((idx, sample), dim=-1)

        return idx

In [None]:
# Notebook-level hyper-parameters; the values mirror the DecoderConfig defaults.
vocab_size = 10172
block_size = 2046
n_layer = 4
n_embd = 320
dropout = 0.0
n_head = 4

# NOTE(review): `dtype` is not read by any of the visible later cells.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
bias = False
# Snapshot every public global that is an int/float/bool/str at the moment this cell runs.
# The result depends on execution history: re-running this cell after later cells also captures
# names defined there (e.g. `device`), so a fresh Restart & Run All gives a smaller dict.
# NOTE(review): `config` is not read by any of the visible later cells.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}

In [None]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
torch.manual_seed(1337)
# Allow TensorFloat-32 for float32 cuDNN ops and matmuls (faster on GPUs that support it).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Let float32 matmuls use faster internal precision (e.g. TensorFloat-32) at a negligible accuracy cost.
# See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
torch.set_float32_matmul_precision('high')

# model init: keyword arguments forwarded to DecoderConfig
model_args = {
    'n_layer': n_layer,
    'n_head': n_head,
    'n_embd': n_embd,
    'block_size': block_size,
    'bias': bias,
    'vocab_size': vocab_size,
    'dropout': dropout,
    'device': device,
}

In [None]:
from torch import Tensor

def patch_img(x: Tensor, patches: int):
    """Split a batch of images into flattened square patches.

    Args:
        x: image batch of shape (B, C, H, W); need not be contiguous.
        patches: side length (in pixels) of each square patch -- not the number of patches.

    Returns:
        Tensor of shape (B, (H // patches) * (W // patches), patches * patches * C). Patches are
        ordered row-major over the image grid and each one is flattened as (row, col, channel).

    Raises:
        ValueError: if ``patches`` is not positive or does not divide both H and W.
    """
    B, C, H, W = x.size()
    if patches <= 0 or H % patches or W % patches:
        raise ValueError(f"patch size {patches} must evenly divide the image size {H}x{W}")
    grid_h, grid_w = H // patches, W // patches

    # (B, C, gh, p, gw, p) -> (B, gh, gw, p, p, C) -> (B, gh * gw, p * p * C)
    # reshape (not view) so non-contiguous inputs such as channels-last tensors are accepted.
    x = x.reshape(B, C, grid_h, patches, grid_w, patches)
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(B, grid_h * grid_w, patches * patches * C)


def threed_to_text_fixed(x: torch.Tensor, W1: Tensor, W2: Tensor):
    """Map a patched image to a text-like sequence with two fixed linear maps.

    ``x`` must already be patched, shape (B, S, D). ``W1`` has shape (D, n_embd) and projects
    every patch's features. ``W2`` has shape (S, block_size) and mixes along the patch axis.
    The result has shape (B, block_size, n_embd).
    """
    embedded = torch.matmul(x, W1)                      # (B, S, n_embd)
    mixed = torch.matmul(embedded.transpose(1, 2), W2)  # (B, n_embd, block_size)
    return mixed.transpose(1, 2)                        # (B, block_size, n_embd)


class Fuyu(nn.Module):
    """Image-conditioned text decoder in the style of Fuyu.

    A 256x256 RGB image is cut into square tiles of side ``patches``. Each flattened tile is
    projected to ``n_embd``, and the patch axis is mapped to ``block_size`` positions, either
    with fixed random matrices or with learnable linear layers. The result is LayerNorm-ed and
    handed to the Transformer decoder, whose blocks cross-attend to it.
    """

    # Image geometry the image projection is sized for.
    _IMG_SIZE = 256
    _IMG_CHANNELS = 3

    def __init__(
        self,
        config,
        image_reshape_is_learnable: bool=False,
        patches: int=16,
    ):
        super().__init__()
        if patches <= 0 or self._IMG_SIZE % patches:
            raise ValueError(f"patches ({patches}) must evenly divide the image size ({self._IMG_SIZE})")
        self.config = config
        self.patches = patches
        # Build the decoder from the config argument, not from a notebook global.
        self.fuyu = Transformer(config)
        self.s_norm = LayerNorm(self.config.n_embd, bias=True)
        self.image_reshape_is_learnable = image_reshape_is_learnable

        # patch_img turns a (B, 3, 256, 256) image into (B, num_patches, patch_dim).
        num_patches = (self._IMG_SIZE // patches) ** 2
        patch_dim = patches * patches * self._IMG_CHANNELS

        if not image_reshape_is_learnable:
            # Fixed (non-learnable) random projections. Registered as non-persistent buffers so
            # that .to() / .half() move them with the module, without adding state_dict keys.
            self.register_buffer(
                "W1", torch.randn(patch_dim, config.n_embd, device=config.device), persistent=False
            )
            self.register_buffer(
                "W2", torch.randn(num_patches, config.block_size, device=config.device), persistent=False
            )
        else:
            # Learnable image reshaping: [0] projects patch features, [1] mixes the patch axis.
            self.threed_to_text = nn.Sequential(
                nn.Linear(patch_dim, self.config.n_embd),
                nn.Linear(num_patches, self.config.block_size),
            )

    def _encode_image(self, img):
        """Patchify ``img`` and map it to (batch, block_size, n_embd); None passes through unchanged."""
        if img is None:
            return None
        img = patch_img(img, patches=self.patches)   # (B, num_patches, patch_dim)
        if self.image_reshape_is_learnable:
            img = self.threed_to_text[0](img)        # (B, num_patches, n_embd)
            img = img.transpose(1, 2)                # (B, n_embd, num_patches)
            img = self.threed_to_text[1](img)        # (B, n_embd, block_size)
            img = img.transpose(1, 2)                # (B, block_size, n_embd)
        else:
            img = threed_to_text_fixed(img, self.W1, self.W2)
        return self.s_norm(img)

    def forward(
        self, text: torch.Tensor, img: torch.Tensor=None, targets: torch.Tensor=None
    ):
        """
        Forward pass of the model.

        text:    [batch, seq_len] token ids (seq_len <= block_size).
        img:     [batch, channels, height, width] image.
        targets: optional [batch, seq_len] target ids.

        Returns ``(logits, loss)`` from the decoder. With targets, logits are
        [batch, seq_len, vocab_size] and loss is the cross-entropy; without targets, logits
        are [batch, 1, vocab_size] and loss is None.

        NOTE(review): the decoder's cross-attention projects the image, so ``img=None`` is
        passed through unchanged and will fail inside the decoder.
        """
        return self.fuyu(text, self._encode_image(img), targets)

    @torch.no_grad()
    def generate(self, text, img, max_new_tokens, temperature=1.0, top_k=None):
        """Encode ``img`` once, then autoregressively sample ``max_new_tokens`` tokens after ``text``."""
        return self.fuyu.generate(text, self._encode_image(img), max_new_tokens, temperature, top_k)

In [None]:
# Build the decoder config from the notebook hyper-parameters and wrap it in the image model
# (16x16-pixel patches).
dconf = DecoderConfig(**model_args)
model = Fuyu(dconf, patches=16)

In [None]:
# Move parameters and buffers to the target device (the module is returned and displayed).
model.to(device=device)

In [None]:
# Text shape: [batch, block_size] -- token ids
text = torch.randint(0, vocab_size, (1, block_size), device=device)

# Img shape: [batch, channels, height, width]
img = torch.randn(1, 3, 256, 256, device=device)

# Random targets to test the loss function
targets = torch.randint(0, vocab_size, (1, block_size), device=device)

# Apply model to text, img and targets so that the loss path is exercised too
y = model(text, img, targets=targets)
logits, loss = y

# logits: [batch, block_size, vocab_size] when targets are provided, [batch, 1, vocab_size] otherwise
print(logits.size(), loss)

In [None]:
# Inference mode (disables dropout), then sample 30 new tokens conditioned on the image.
model.eval()
sample = model.generate(text, img, max_new_tokens=30, temperature=1, top_k=200)
sample.shape