1. Input Embedding: https://enakai00.hatenablog.com/entry/2023/02/10/102036
2. Multi-Head Attention: https://enakai00.hatenablog.com/entry/2023/02/10/144940
3. https://enakai00.hatenablog.com/entry/2023/02/10/180105
4. https://enakai00.hatenablog.com/entry/2023/02/10/195227

In [53]:
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame
from functools import partial

import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import fit, train_state, checkpoints

plt.rcParams.update({'font.size': 12})

ImportError: cannot import name 'fit' from 'flax.training' (/Users/junhyeong.kim/Workspaces/JAX_transformer/.venv/lib/python3.11/site-packages/flax/training/__init__.py)

In [3]:
from datasets import load_dataset
emotions = load_dataset('emotion')

Downloading builder script: 100%|██████████| 3.97k/3.97k [00:00<00:00, 485kB/s]
Downloading metadata: 100%|██████████| 3.28k/3.28k [00:00<00:00, 821kB/s]
Downloading readme: 100%|██████████| 8.78k/8.78k [00:00<00:00, 1.45MB/s]


Downloading and preparing dataset emotion/split to /Users/junhyeong.kim/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd...


Downloading data: 100%|██████████| 592k/592k [00:00<00:00, 8.51MB/s]
Downloading data: 100%|██████████| 74.0k/74.0k [00:00<00:00, 5.07MB/s]
Downloading data: 100%|██████████| 74.9k/74.9k [00:00<00:00, 4.98MB/s]
Downloading data files: 100%|██████████| 3/3 [00:08<00:00,  2.81s/it]
Extracting data files: 100%|██████████| 3/3 [00:00<00:00, 108.88it/s]
                                                                                       

Dataset emotion downloaded and prepared to /Users/junhyeong.kim/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 547.37it/s]


In [6]:
emotions['train']['text'][:2]

['i didnt feel humiliated',
 'i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake']

In [8]:
from transformers import AutoTokenizer, AutoConfig

model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
vocab_size = AutoConfig.from_pretrained(model_ckpt).vocab_size

vocab_size

30522

In [9]:
max([len(text.split(' ')) for text in emotions['train']['text'] + emotions['validation']['text']])

66

In [17]:
text_length = 128

# training set
train_set = tokenizer(
    emotions['train']['text'],
    max_length=text_length,
    padding='max_length',
    truncation=True
)
train_text = np.array(train_set['input_ids'])
train_mask = np.array(train_set['attention_mask'])
train_label = np.eye(6)[emotions['train']['label']]

# validation set
valid_set = tokenizer(
    emotions['validation']['text'],
    max_length=text_length,
    padding='max_length',
    truncation=True
)
valid_text = np.array(valid_set['input_ids'])
valid_mask = np.array(valid_set['attention_mask'])
valid_label = np.eye(6)[emotions['validation']['label']]

# label map
emotion_labels = emotions['train'].features['label'].names

In [18]:
train_text[0]

array([  101,  1045,  2134,  2102,  2514, 26608,   102,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0]

In [19]:
class Embeddings(nn.Module):
    embed_dim: int
    text_length: int = text_length
    vocab_size: int = vocab_size
    
    @nn.compact
    def __call__(self, input_ids, eval):
        token_embeddings = nn.Embed(
            self.vocab_size, self.embed_dim
        )(input_ids)

        position_ids = jnp.arange(self.text_length)
        position_embeddings = nn.Embed(
            self.text_length, self.embed_dim
        )(position_ids)

        # [N, トークン数, 埋め込み空間の次元] + [トークン数, 埋め込み空間の次元]
        embeddings = token_embeddings + position_embeddings

        embeddings = nn.LayerNorm(epsilon=1e-12)(embeddings)
        embeddings = nn.Dropout(0.5, deterministic=eval)(embeddings)
        
        return embeddings

In [28]:
variables = Embeddings(embed_dim=512).init(random.PRNGKey(0), train_text[:1], eval=True)
jax.tree_util.tree_map(lambda x: x.shape, variables['params'])

FrozenDict({
    Embed_0: {
        embedding: (30522, 512),
    },
    Embed_1: {
        embedding: (128, 512),
    },
    LayerNorm_0: {
        bias: (512,),
        scale: (512,),
    },
})

In [29]:
input_text = train_text[:3]
output = Embeddings(embed_dim=512).apply(variables, input_text, eval=True)

input_text.shape, output.shape

((3, 128), (3, 128, 512))

In [30]:
class AttentionHead(nn.Module):
    head_dim: int
    
    def scaled_dot_product_attention(self, q, k, v, mask):  # mask: [テキスト数, トークン数]
        scores = jnp.matmul(q, jnp.transpose(k, (0, 2, 1)))
        if mask is not None:
            mask = jnp.tile(mask, mask.shape[-1]).reshape(
                mask.shape[0], -1, mask.shape[-1]
            )  # mask: [テキスト数, トークン数, トークン数]
            scores = jnp.where(mask == 0, -jnp.inf, scores)
        w = nn.softmax(scores / jnp.sqrt(self.head_dim))  # w: [テキスト数, トークン数（Query側）, トークン数（Key側）]
        return jnp.matmul(w, v)

    @nn.compact
    def __call__(self, hidden_state, attention_mask): # hidden_state: [テキスト数, トークン数, 埋め込み空間の次元]
        q = nn.Dense(features=self.head_dim)(hidden_state)  # q: [テキスト数, トークン数, Query の次元数]
        k = nn.Dense(features=self.head_dim)(hidden_state)  # k: [テキスト数, トークン数, Key の次元数]
        v = nn.Dense(features=self.head_dim)(hidden_state)  # v: [テキスト数, トークン数, Value の次元数]
        output = self.scaled_dot_product_attention(
            q=q,
            k=k,
            v=v,
            mask=attention_mask
        )
        return output  # output: [テキスト数, トークン数, Value の次元数]

In [37]:
class MultiHeadAttention(nn.Module):
    num_heads: int
    embed_dim: int

    def setup(self):
        head_dim = self.embed_dim // self.num_heads
        self.attention_heads = [AttentionHead(head_dim=head_dim) for _ in jnp.arange(self.num_heads)]

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        attention_outputs = [head(hidden_state, attention_mask) for head in self.attention_heads]
        x = jnp.concatenate(attention_outputs, axis=-1)
        x = nn.Dense(features=self.embed_dim)(x)
        return x

In [43]:
class FeedForward(nn.Module):
    embed_dim: int
    intermediate_size: int = 2_048
    
    @nn.compact
    def __call__(self, x, eval):
        x = nn.Dense(features=self.intermediate_size)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.embed_dim)(x)
        x = nn.Dropout(rate=0.1, deterministic=eval)(x)
        return x

In [48]:
class TransformerEncoderBlock(nn.Module):
    num_heads: int
    embed_dim: int
    
    def setup(self):
        self.attention = MultiHeadAttention(
            num_heads=self.num_heads,
            embed_dim=self.embed_dim
        )
        self.feed_forward = FeedForward(
            embed_dim=self.embed_dim
        )
    
    @nn.compact
    def __call__(self, x, attention_mask, eval):
        x = x + self.attention(hidden_state=x, attention_mask=attention_mask)  # Skip connection
        x = nn.LayerNorm()(x)
        x = x + self.feed_forward(x, eval)
        x = nn.LayerNorm()(x)
        return x

In [49]:
class TransformerEncoder(nn.Module):
    num_heads: int
    embed_dim: int
    num_hidden_layers: int
    
    def setup(self):
        self.embeddings = Embeddings(embed_dim=self.embed_dim)
        self.layers = [
            TransformerEncoderBlock(
                num_heads=self.num_heads,
                embed_dim=self.embed_dim
            ) for _ in range(self.num_hidden_layers)
        ]
    
    def __call__(self, input_ids, attention_mask, eval):
        x = self.embeddings(input_ids=input_ids, eval=eval)
        for layer in self.layers:
            x = layer(x=x, attention_mask=attention_mask, eval=eval)
        return x

In [50]:
class TransformerForSequenceClassifier(nn.Module):
    num_labels: int
    num_heads: int
    embed_dim: int
    num_hidden_layers: int
    
    def setup(self):
        self.transformer_encoder = TransformerEncoder(
            num_heads=self.num_heads,
            embed_dim=self.embed_dim,
            num_hidden_layers=self.num_hidden_layers
        )
    
    @nn.compact
    def __call__(self, input_ids, attention_mask=None, eval=True):
        x = self.transformer_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            eval=eval
            )[:, 0, :]  # select [CLS] token
        x = nn.Dropout(rate=0.1, deterministic=eval)(x)
        logits = nn.Dense(features=self.num_labels)(x)
        return logits