### Setup

In [1]:
!pip install sentencepiece transformers --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m69.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m72.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from pdb import set_trace as bkpt
import math
from typing import Callable
import gc
from time import sleep
from dataclasses import dataclass
from collections import OrderedDict

import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module, Embedding, Linear, LayerNorm, MultiheadAttention, ModuleList, Softmax
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity

from transformers import T5Tokenizer

### Model

In [3]:
@dataclass
class TransformerConfig:
    vocab_size: int
    n_layers: int
    d_model: int
    d_ff: int
    n_heads: int
    d_k: int
    num_relative_pos: int
    eps: float = 1e-6

In [4]:
def relative_position_bucket(relative_position: Tensor, is_decoder: bool, num_buckets=32, max_distance=128) -> Tensor:
        relative_buckets = 0
        if is_decoder:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        else:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

def compute_bias(query_length: int, key_length: int, embedding: Callable, is_decoder: bool, device=None) -> Tensor:
    if device is None:
        device = embedding.pos_embedding_layer.weight.device
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    position_bucket = relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        is_decoder=is_decoder,
    )
    values = embedding(position_bucket)  # shape (query_length, key_length, n_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, n_heads, query_length, key_length)
    return values

In [5]:
class EmbeddingLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(EmbeddingLayer, self).__init__()
        self.word_embedding_layer = Embedding(config.vocab_size, config.d_model)

    def forward(self, x: Tensor) -> Tensor:
        return self.word_embedding_layer(x)


class PositionEmbeddingLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(PositionEmbeddingLayer, self).__init__()
        self.pos_embedding_layer = Embedding(config.num_relative_pos, config.n_heads)

    def forward(self, x: Tensor) -> Tensor:
        return self.pos_embedding_layer(x)


class MLP(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(MLP, self).__init__()
        self.ff1 = Linear(config.d_model, config.d_ff, bias=False)
        self.ff2 = Linear(config.d_ff, config.d_model, bias=False)
        self.activation = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.activation(self.ff1(x))
        return self.ff2(x)


class Norm(nn.Module):
    def __init__(self, config: TransformerConfig, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(config.d_model))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * x


class SelfAttention(nn.Module):
    def __init__(self, config: TransformerConfig, is_decoder: bool):
        super(SelfAttention, self).__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_k = int(config.d_model / config.n_heads)
        self.w_q = Linear(config.d_model, config.d_model, bias=False)
        self.w_k = Linear(config.d_model, config.d_model, bias=False)
        self.w_v = Linear(config.d_model, config.d_model, bias=False)
        self.w_o = Linear(config.d_model, config.d_model, bias=False)
        self.is_decoder = is_decoder

    def split_heads(self, x: Tensor) -> Tensor:
        batch_size, n, _ = x.size()
        x = x.view((batch_size, n, self.n_heads, self.d_k))
        x = x.transpose(1, 2)
        return x

    def unify_heads(self, x: Tensor) -> Tensor:
        batch_size, _, n, _ = x.size()
        x = x.transpose(1, 2)
        x = x.reshape((batch_size, n, self.d_model))
        return x

    def forward(self, x: Tensor, position_bias: Tensor) -> Tensor:
        _, n, _ = x.size()
        if self.is_decoder:
            Q = self.w_q(x[:, -1, :].unsqueeze(1))
        else:
            Q = self.w_q(x)
        K, V = self.w_k(x), self.w_v(x)
        Q, K, V = self.split_heads(Q), self.split_heads(K), self.split_heads(V)
        scores = (Q @ K.transpose(-1, -2))
        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        split_attention = attn_weights @ V
        attention = self.unify_heads(split_attention)
        output = self.w_o(attention)
        return output


class EncDecAttention(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(EncDecAttention, self).__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_k = int(config.d_model / config.n_heads)
        self.w_q = Linear(config.d_model, config.d_model, bias=False)
        self.w_k = Linear(config.d_model, config.d_model, bias=False)
        self.w_v = Linear(config.d_model, config.d_model, bias=False)
        self.w_o = Linear(config.d_model, config.d_model, bias=False)

    def split_heads(self, x: Tensor) -> Tensor:
        batch_size, n, _ = x.size()
        x = x.view((batch_size, n, self.n_heads, self.d_k))
        x = x.transpose(1, 2)
        return x

    def unify_heads(self, x: Tensor) -> Tensor:
        batch_size, _, n, _ = x.size()
        x = x.transpose(1, 2)
        x = x.reshape((batch_size, n, self.d_model))
        return x

    def forward(self, x: Tensor, encoding: Tensor, position_bias: Tensor) -> Tensor:
        _, n, _ = x.size()
        _, encoding_n, _ = encoding.size()
        # Q = self.w_q(x[:, -1, :].unsqueeze(1))
        Q = self.w_q(x) # not sure about cross-attn implementation here
        K, V = self.w_k(encoding), self.w_v(encoding)
        Q, K, V = self.split_heads(Q), self.split_heads(K), self.split_heads(V)
        scores = (Q @ K.transpose(-1, -2))
        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        split_attention = attn_weights @ V
        attention = self.unify_heads(split_attention)
        output = self.w_o(attention)
        return output


class EncoderLayer(nn.Module):    
    def __init__(self, config: TransformerConfig):
        super(EncoderLayer, self).__init__()
        self.self_attention = SelfAttention(config=config, is_decoder=False)
        self.norm1 = Norm(config)
        self.mlp = MLP(config=config)
        self.norm2 = Norm(config)

    def forward(self, x: Tensor, position_bias: Tensor) -> Tensor:
        normed_x = self.norm1(x)
        x = x + self.self_attention(normed_x, position_bias=position_bias)
        normed_x = self.norm2(x)
        x = x + self.mlp(normed_x)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(DecoderLayer, self).__init__()
        self.self_attention = SelfAttention(config=config, is_decoder=True)
        self.norm1 = Norm(config)
        self.enc_dec_attention = EncDecAttention(config=config)
        self.norm2 = Norm(config)
        self.mlp = MLP(config=config)
        self.norm3 = Norm(config)

    def forward(self, x: Tensor, encoding: Tensor, self_attention_position_bias: Tensor, enc_dec_attention_position_bias: Tensor) -> Tensor:
        normed_x = self.norm1(x)
        attn_output = self.self_attention(normed_x, position_bias=self_attention_position_bias)
        x += attn_output
        normed_x = self.norm2(x)
        cross_attn_output = self.enc_dec_attention(normed_x, encoding, position_bias=enc_dec_attention_position_bias)
        x += cross_attn_output
        normed_x = self.norm3(x)
        mlp_output = self.mlp(normed_x)
        x += mlp_output
        return x


class Encoder(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(Encoder, self).__init__()
        self.layers = ModuleList([EncoderLayer(config=config) for i in range(config.n_layers)])
        self.self_attention_relative_attention_embedding = PositionEmbeddingLayer(config)

    def forward(self, x: Tensor) -> Tensor:
        _, n, _ = x.size()
        self.self_attention_position_bias = compute_bias(n, n, self.self_attention_relative_attention_embedding, is_decoder=False)
        for layer in self.layers:
            x = layer(x, self.self_attention_position_bias)
        return x


class Decoder(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(Decoder, self).__init__()
        self.layers = ModuleList([DecoderLayer(config=config) for i in range(config.n_layers)])
        self.self_attention_relative_attention_embedding = PositionEmbeddingLayer(config)
        self.enc_dec_attention_relative_attention_embedding = PositionEmbeddingLayer(config)

    def forward(self, x: Tensor, encoding: Tensor) -> Tensor:
        _, n, _ = x.size()
        _, encoding_n, _ = encoding.size()
        self_attention_position_bias = compute_bias(1, n, self.self_attention_relative_attention_embedding, is_decoder=True)
        enc_dec_attention_position_bias = compute_bias(1, encoding_n, self.enc_dec_attention_relative_attention_embedding, is_decoder=True)
        for layer in self.layers:
            x = layer(x, 
                      encoding, 
                      self_attention_position_bias=self_attention_position_bias, 
                      enc_dec_attention_position_bias=enc_dec_attention_position_bias)
        return x[:, -1, :].unsqueeze(1)

class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super(Transformer, self).__init__()
        self.embedding_layer = EmbeddingLayer(config=config)
        self.encoder = Encoder(config=config)
        self.final_encoder_layer_norm = Norm(config=config)
        self.decoder = Decoder(config=config)
        self.final_decoder_layer_norm = Norm(config=config)
        self.d_model = config.d_model
        self.lm_head = Linear(config.d_model, config.vocab_size, bias=False)

    @torch.no_grad()
    def forward(self, x: Tensor, max_tokens=50, debug=False) -> Tensor:
        x_emb = self.embedding_layer(x)
        encoding = self.encoder(x_emb)
        encoding = self.final_encoder_layer_norm(encoding)
        outputs = torch.IntTensor([[0]]).to("cuda")
        while outputs[:, -1] != 1 and len(outputs[0]) < max_tokens:
            if debug:
                bkpt()
            decoder_inputs_embedding = self.embedding_layer(outputs)
            decoding = self.decoder(decoder_inputs_embedding, encoding)
            decoding = self.final_decoder_layer_norm(decoding)
            decoding *= (self.d_model ** -0.5)
            next_logits = self.lm_head(decoding)
            next_token = torch.argmax(next_logits, dim=-1)
            outputs = torch.cat([outputs, next_token], dim=-1)
        return outputs

### Initialize and load weights

In [6]:
transformer_config = TransformerConfig(vocab_size=32128,
                                       n_layers=6,
                                       d_model=512,
                                       d_ff=2048,
                                       n_heads=8,
                                       d_k=64,
                                       num_relative_pos=32)

In [7]:
transformer = Transformer(config=transformer_config)

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [12]:
# Upload your weights here!
# Download from https://huggingface.co/t5-small/tree/main and upload to your Google Drive in the root folder
t5_small_weights = torch.load("drive/MyDrive/t5-small-torch-weights.bin")

In [13]:
def dict_from_array_of_dicts(dicts_array):
    res = {}
    for dict_element in dicts_array:
        res = {**res, **dict_element}
    return res

embedding_dict = {'embedding_layer.word_embedding_layer.weight': torch.Size([32128, 512])}
first_encoder_layer_dict = {
 'encoder.layers.0.self_attention.w_q.weight': torch.Size([512, 512]),
 'encoder.layers.0.self_attention.w_k.weight': torch.Size([512, 512]),
 'encoder.layers.0.self_attention.w_v.weight': torch.Size([512, 512]),
 'encoder.layers.0.self_attention.w_o.weight': torch.Size([512, 512]),
 'encoder.self_attention_relative_attention_embedding.pos_embedding_layer.weight': torch.Size([32, 8]),
 'encoder.layers.0.norm1.weight': torch.Size([512]),
 'encoder.layers.0.mlp.ff1.weight': torch.Size([2048, 512]),
 'encoder.layers.0.mlp.ff2.weight': torch.Size([512, 2048]),
 'encoder.layers.0.norm2.weight': torch.Size([512])
}
encoder_dicts = [{
    f'encoder.layers.{i}.self_attention.w_q.weight': torch.Size([512, 512]),
    f'encoder.layers.{i}.self_attention.w_k.weight': torch.Size([512, 512]),
    f'encoder.layers.{i}.self_attention.w_v.weight': torch.Size([512, 512]),
    f'encoder.layers.{i}.self_attention.w_o.weight': torch.Size([512, 512]),
    f'encoder.layers.{i}.norm1.weight': torch.Size([512]),
    f'encoder.layers.{i}.mlp.ff1.weight': torch.Size([2048, 512]),
    f'encoder.layers.{i}.mlp.ff2.weight': torch.Size([512, 2048]),
    f'encoder.layers.{i}.norm2.weight': torch.Size([512]),
} for i in range(1, 6)]
final_encoder_layer_norm_dict = {'final_encoder_layer_norm.weight': torch.Size([512])}
first_decoder_layer_dict = {
 'decoder.layers.0.self_attention.w_q.weight': torch.Size([512, 512]),
 'decoder.layers.0.self_attention.w_k.weight': torch.Size([512, 512]),
 'decoder.layers.0.self_attention.w_v.weight': torch.Size([512, 512]),
 'decoder.layers.0.self_attention.w_o.weight': torch.Size([512, 512]),
 'decoder.self_attention_relative_attention_embedding.pos_embedding_layer.weight': torch.Size([32, 8]),
 'decoder.layers.0.norm1.weight': torch.Size([512]),
 'decoder.layers.0.enc_dec_attention.w_q.weight': torch.Size([512, 512]),
 'decoder.layers.0.enc_dec_attention.w_k.weight': torch.Size([512, 512]),
 'decoder.layers.0.enc_dec_attention.w_v.weight': torch.Size([512, 512]),
 'decoder.layers.0.enc_dec_attention.w_o.weight': torch.Size([512, 512]),
 'decoder.enc_dec_attention_relative_attention_embedding.pos_embedding_layer.weight': torch.Size([32, 8]),
 'decoder.layers.0.norm2.weight': torch.Size([512]),
 'decoder.layers.0.mlp.ff1.weight': torch.Size([2048, 512]),
 'decoder.layers.0.mlp.ff2.weight': torch.Size([512, 2048]),
 'decoder.layers.0.norm3.weight': torch.Size([512]),
}
decoder_dicts = [{
    f'decoder.layers.{i}.self_attention.w_q.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.self_attention.w_k.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.self_attention.w_v.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.self_attention.w_o.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.norm1.weight': torch.Size([512]),
    f'decoder.layers.{i}.enc_dec_attention.w_q.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.enc_dec_attention.w_k.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.enc_dec_attention.w_v.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.enc_dec_attention.w_o.weight': torch.Size([512, 512]),
    f'decoder.layers.{i}.norm2.weight': torch.Size([512]),
    f'decoder.layers.{i}.mlp.ff1.weight': torch.Size([2048, 512]),
    f'decoder.layers.{i}.mlp.ff2.weight': torch.Size([512, 2048]),
    f'decoder.layers.{i}.norm3.weight': torch.Size([512]),
} for i in range(1, 6)]
final_decoder_layer_norm_dict = {'final_decoder_layer_norm.weight': torch.Size([512])}
lm_head_dict = {'lm_head.weight': torch.Size([32128, 512])}

new_weights = {
    **embedding_dict,
    **first_encoder_layer_dict,
    **(dict_from_array_of_dicts(encoder_dicts)),
    **final_encoder_layer_norm_dict,
    **first_decoder_layer_dict,
    **(dict_from_array_of_dicts(decoder_dicts)),
    **final_decoder_layer_norm_dict,
    **lm_head_dict
}

t5_small_weights['lm_head.weight'] = t5_small_weights['shared.weight']

In [14]:
# Check that number of parameter tensors matches
len(new_weights.keys()), len(t5_small_weights.keys())

(133, 133)

In [15]:
# Copy weights
t5_modified_weights = OrderedDict()

for (kv0, kv1) in zip(t5_small_weights.items(), new_weights.items()):
    key0, value0 = kv0
    key1, value1 = kv1
    t5_modified_weights[key1] = value0.detach().clone()

In [16]:
# Call this function to load new weights
def load_weights():
    transformer.load_state_dict(t5_modified_weights)

In [17]:
load_weights()

In [18]:
# Put model on GPU
transformer = transformer.to("cuda")

### Inference

In [19]:
# Use the tokenizer from Hugging Face transformers library
tokenizer = T5Tokenizer.from_pretrained("t5-small")

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [20]:
input_sentence = "summarize: Space exploration has always fascinated humans, and has led to countless discoveries and advancements in technology. The study of our solar system and beyond continues to uncover new mysteries, from the formation of planets to the search for extraterrestrial life. With the help of advanced technologies such as telescopes and spacecraft, scientists are able to gain a deeper understanding of the universe and our place in it. The pursuit of knowledge and discovery drives continued investment in space exploration, leading to a better understanding of not just the cosmos, but also ourselves."
inputs = tokenizer(input_sentence, return_tensors="pt")
inputs.to("cuda")

{'input_ids': tensor([[21603,    10,  5844,  9740,    65,   373, 24631,  6917,     6,    11,
            65,  2237,    12,     3, 11394, 25175,    11, 14500,     7,    16,
           748,     5,    37,   810,    13,    69,  3693,   358,    11,  1909,
          3256,    12, 19019,   126, 29063,     6,    45,     8,  3239,    13,
          4345,     7,    12,     8,   960,    21,   996,   449,  6216, 12042,
           280,     5,   438,     8,   199,    13,  2496,  2896,   224,    38,
         27480,     7,    11,   628,  6696,     6,  7004,    33,     3,   179,
            12,  2485,     3,     9,  7231,  1705,    13,     8,  8084,    11,
            69,   286,    16,    34,     5,    37, 13709,    13,  1103,    11,
          9087,  9350,  2925,  1729,    16,   628,  9740,     6,  1374,    12,
             3,     9,   394,  1705,    13,    59,   131,     8,   576,     7,
          3972,     6,    68,    92,  3242,     5,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 

In [21]:
use_profiler = False

In [22]:
if use_profiler:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            transformer_outputs = transformer(inputs.input_ids, debug=False, max_tokens=20)
else:
    transformer_outputs = transformer(inputs.input_ids, debug=False, max_tokens=50)
    
decoded_outputs = tokenizer.decode(transformer_outputs[0], skip_special_tokens=True)
decoded_outputs

'the study of our solar system and beyond continues to uncover new mysteries.'