In [1]:
import numpy as np
import tensorflow as tf
from xfmers import utils
from xfmers import ops
from xfmers import layers

In [2]:
def Transformer(vocab_size, dec_layers, ff_units, d_model, num_heads, dropout, max_seq_len=512, causal=False,
                       weight_sharing=False, efficient_attention=False, shared_qk=False, activation=ops.gelu,
                       conv_filter=1, conv_padding="same", reversible=False, fused_qkv=False, name="Transformer"):
    inputs = tf.keras.Input(shape=(None, ), name="inputs")
    padding_mask = layers.PaddingMaskGenerator()(inputs)
    embeddings = layers.TokenPosEmbedding(d_vocab=vocab_size, d_model=d_model, pos_length=max_seq_len, scale=1)(inputs)
    
    decoder_block = layers.TransformerStack(layers=dec_layers,
                                            ff_units=ff_units,
                                            d_model=d_model,
                                            num_heads=num_heads,
                                            dropout=dropout,
                                            causal=causal,
                                            activation=activation,
                                            weight_sharing=weight_sharing,
                                            conv_filter=conv_filter,
                                            conv_padding=conv_padding,
                                            reversible=reversible,
                                            fused_qkv=fused_qkv,
                                            name="DecoderBlock")
    dec_outputs = decoder_block({"token_inputs": embeddings,
                                 "mask_inputs": padding_mask})
    
    l_dropout = tf.keras.layers.Dropout(rate=dropout)(dec_outputs)
    
    preds = tf.keras.layers.Dense(vocab_size, name="outputs")(l_dropout)
    
    return tf.keras.Model(inputs=inputs, outputs=preds, name=name)

In [8]:
model = Transformer(vocab_size=8192,
                    dec_layers=12,
                    ff_units=3072,
                    d_model=768,
                    num_heads=12,
                    dropout=0.1,
                    max_seq_len=128,
                    fused_qkv=True,
                    causal=True)
model.summary()

Model: "Transformer"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inputs (InputLayer)             [(None, None)]       0                                            
__________________________________________________________________________________________________
PaddingMaskGenerator (PaddingMa (None, 1, 1, None)   0           inputs[0][0]                     
__________________________________________________________________________________________________
TokenPosEmbedding (TokenPosEmbe (None, None, 768)    6391296     inputs[0][0]                     
__________________________________________________________________________________________________
DecoderBlock (TransformerStack) (None, None, 768)    85054464    PaddingMaskGenerator[0][0]       
                                                                 TokenPosEmbedding[0][0]

In [4]:
tlayer = layers.TransformerLayer(ff_units=768*4, d_model=768, num_heads=12, dropout=0.01, causal=True)
_ = tlayer({"token_inputs": np.zeros((1,128,768)),
            "mask_inputs": np.zeros((1,128,1))})



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



In [5]:
tlayer.summary()

Model: "TransformerLayer"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
TransformerLayer_MultiHeadAt multiple                  2362368   
_________________________________________________________________
dropout_25 (Dropout)         multiple                  0         
_________________________________________________________________
layer_normalization_25 (Laye multiple                  1536      
_________________________________________________________________
conv1d_24 (Conv1D)           multiple                  2362368   
_________________________________________________________________
activation_12 (Activation)   multiple                  0         
_________________________________________________________________
conv1d_25 (Conv1D)           multiple                  2360064   
_________________________________________________________________
dropout_26 (Dropout)         multiple             

In [6]:
for layer in tlayer.layers:
    print("name:", layer.name, " - params:", layer.count_params())

name: TransformerLayer_MultiHeadAttention  - params: 2362368
name: dropout_25  - params: 0
name: layer_normalization_25  - params: 1536
name: conv1d_24  - params: 2362368
name: activation_12  - params: 0
name: conv1d_25  - params: 2360064
name: dropout_26  - params: 0
name: layer_normalization_26  - params: 1536


In [7]:
tlayer.layers[0].summary()

Model: "TransformerLayer_MultiHeadAttention"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
d_query (Dense)              multiple                  590592    
_________________________________________________________________
d_key (Dense)                multiple                  590592    
_________________________________________________________________
d_value (Dense)              multiple                  590592    
_________________________________________________________________
d_mha_final (Dense)          multiple                  590592    
Total params: 2,362,368
Trainable params: 2,362,368
Non-trainable params: 0
_________________________________________________________________
