In [104]:
import scipy
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
from torch.utils.data import dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat


from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
import more_itertools
import math
from tqdm import tqdm

In [2]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
DEVICE = torch.device(device)
DEVICE

device(type='mps')

## DATA

In [3]:
X = np.load('sub1.npy')
stim1 = np.load('stim.npy')
stim2 = np.load('mask.npy')

In [4]:
Xx = []
stimulus = []
masker = []
for i in range(80):
    x = []
    for j in range(66):
        t = np.array(list(more_itertools.sliced(X[i][j],n=128)))
        x.append(t.T)
    Xx.append(np.asarray(x).T)
    stimulus.append(np.array(list(more_itertools.sliced(stim1[i],n=128))))
    masker.append(np.array(list(more_itertools.sliced(stim2[i],n=128))))

X = np.array(Xx)
stim = np.array(stimulus)
mask = np.array(masker)

X = X.reshape((80*33,128,66))
stim = stim.reshape((80*33,128))
mask = mask.reshape((80*33,128))

In [5]:
X_train,X_valid,y_tr,y_val,m,mv = train_test_split(X,stim,mask,test_size=0.2)

In [6]:
trsc = StandardScaler()
X_tr = []
X_val = []
for i in range(len(X_train)):
    X_tr.append(np.expand_dims(trsc.fit_transform(X_train[i]),axis=-1).astype(np.float32))

for i in range(len(X_valid)):
    X_val.append(np.expand_dims(trsc.fit_transform(X_valid[i]),axis=-1).astype(np.float32))

X_tr = np.array(X_tr)
X_val = np.array(X_val)

trst = StandardScaler()
y_tr = trst.fit_transform(y_tr.T).astype(np.float32).T
y_val = trst.fit_transform(y_val.T).astype(np.float32).T

trsm = StandardScaler()
m_tr = trsm.fit_transform(m.T).astype(np.float32).T
m_val = trsm.fit_transform(mv.T).astype(np.float32).T

In [7]:
y_tr = np.expand_dims(y_tr,axis=-1)
y_val = np.expand_dims(y_val,axis=-1)
m_val = np.expand_dims(m_val,axis=-1)
m_tr = np.expand_dims(m_tr,axis=-1)

In [36]:
y_tr.shape

(2112, 128, 1)

In [11]:
stimtr = np.roll(y_tr,1,axis=1)
stimval = np.roll(y_val,1,axis=1)
stimtr[:,0] = np.mean(X_tr,axis=2)[:,0]
stimval[:,0] =  np.mean(X_val,axis=2)[:,0]

In [12]:
train_dataset = tf.data.Dataset.from_tensor_slices(((X_tr,stimtr),y_tr))
val_dataset = tf.data.Dataset.from_tensor_slices(((X_val,stimval),y_val))

In [13]:
train_dataset = train_dataset.batch(4).shuffle(32)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

In [35]:
train_dataset

<_PrefetchDataset element_spec=((TensorSpec(shape=(None, 128, 66, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 1), dtype=tf.float32, name=None)), TensorSpec(shape=(None, 128, 1), dtype=tf.float32, name=None))>

In [14]:
val_dataset = val_dataset.batch(4)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

In [51]:
class PairDataset(Dataset):
    def __init__(self, eeg, att):
        self.eeg = eeg
        self.att = att

    def __len__(self):
        return len(self.eeg)

    def __getitem__(self, idx):
        return self.eeg[idx,:,:,:], self.att[idx,:,:]

In [65]:
train_dataset = PairDataset(X_tr, y_tr)
train_dataloader = DataLoader(train_dataset, batch_size=1,
                              shuffle=True)

In [64]:
val_dataset = PairDataset(X_val, y_val)
val_dataloader = DataLoader(val_dataset, batch_size=1)

In [66]:
for eeg, att in train_dataloader:  
    sample_eeg = eeg    # Reshape them according to your needs.
    sample_att = att
    break

In [67]:
sample_eeg.shape

torch.Size([1, 128, 66, 1])

## TRANSFORMER

In [327]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)
                        * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)
        self.emb_size = emb_size

    def forward(self, token_embedding: Tensor):
#         print("token_embedding.shape: ",token_embedding.shape)
#         print("token_embedding.size(0): ",token_embedding.size(0))
#         print("elf.pos_embedding[:token_embedding.size(0), :].shape: ",self.pos_embedding[:token_embedding.size(0), :].shape)
        return self.dropout(token_embedding * math.sqrt(self.emb_size)
                            + self.pos_embedding[:token_embedding.size(0), :])

In [446]:
class Embedding(nn.Module):
    def __init__(self, kdim):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, (1,16),(1,1),padding="same")
        self.gelu = F.gelu
        self.norm = nn.BatchNorm2d(5)
        self.depthConv = nn.Conv2d(5, kdim, (1,15),padding="same")
        self.conv2 = nn.Conv2d(5, 16, (15,1),padding="same")
        self.dense = nn.Linear(16,13)
        self.conv3 = nn.Conv1d(1, kdim, 2, padding="same")
        self.rearrange = Rearrange('b e (h) (w) -> b (h w) e')


    def forward(self, x):
        if len(x.shape) == 4:
            x = x.permute(0, 3, 1, 2)
            #print("INPUT: ", x.shape)
            x = self.gelu(self.conv1(x))
            #print("CONV: ", x.shape)
            x = self.norm(x)
            #print("BATCH: ", x.shape)
            x = self.gelu(self.conv2(x))
            #print("CONV: ", x.shape)
            x = self.rearrange(x) #x.permute(0, 2, 3, 1)
            #print("PERM: ", x.shape)
#             x = self.dense(x)
#             print("DENSE: ", x.shape)
        else:
            x = x.permute(0, 2, 1)
            x = self.gelu(self.conv3(x))
            x = x.permute(0, 2, 1)
        return x

In [448]:
Emb = Embedding(16)
for eeg, att in train_dataloader:  
    print(eeg.shape)
    print(att.shape)
    sample_eeg_embedding = Emb(eeg)    # Reshape them according to your needs.
    sample_att_embedding = Emb(att)
    break

print(sample_eeg_embedding.shape)
print(sample_att_embedding.shape)

torch.Size([1, 128, 66, 1])
torch.Size([1, 128, 1])
torch.Size([1, 8448, 16])
torch.Size([1, 128, 16])


In [442]:
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size):
        # self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            nn.Conv2d(1, 2, (1, 51), (1, 1)),
            nn.BatchNorm2d(2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.one = nn.Conv2d(1, 2, (1, 51), (1, 1))
        self.two = nn.BatchNorm2d(2)
        self.three = nn.LeakyReLU(0.2)
        self.four = nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5))
        self.five = Rearrange('b e (h) (w) -> b (h w) e')
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 3, 1, 2)
        b, _, _, _ = x.shape
#         x = self.projection(x)
        print("INPUT: ", x.shape)
        x = self.one(x)
        print("CONV: ", x.shape)
        x = self.two(x)
        print("BATCH: ", x.shape)
        x = self.three(x)
        print("RELU: ", x.shape)
        x = self.four(x)
        print("CONV: ", x.shape)
        x = self.five(x)
        print("REARRANGE: ", x.shape)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

        # position
        # x += self.positions
        return x

In [443]:
Emb = PatchEmbedding(16)
for eeg, att in train_dataloader:  
    print(eeg.shape)
#     print(att.shape)
    sample_eeg_embedding = Emb(eeg)    # Reshape them according to your needs.
#     sample_att_embedding = Emb(att)
    break

print(sample_eeg_embedding.shape)
# print(sample_att_embedding.shape)

torch.Size([1, 128, 66, 1])
INPUT:  torch.Size([1, 1, 128, 66])
CONV:  torch.Size([1, 2, 128, 16])
BATCH:  torch.Size([1, 2, 128, 16])
RELU:  torch.Size([1, 2, 128, 16])
CONV:  torch.Size([1, 16, 113, 3])
REARRANGE:  torch.Size([1, 339, 16])
torch.Size([1, 339, 16])


In [477]:
class EEG_Transformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
#                  emb_size_eeg: int,
#                  emb_size_att: int,
                 nhead: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)

        self.emb_size = emb_size
        self.embedding = Embedding(emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
        # Bias to be compatible with embeddings
#         self.generator = nn.Linear(emb_size, vocab_size, bias=False)
#         self.generator.weight = self.embedding.weight  # Shared weights Sect. 3.4

    def forward(self,
                src: Tensor,
                trg: Tensor):
        src_emb = self.positional_encoding(self.embedding(src))
        tgt_emb = self.positional_encoding(self.embedding(trg))
        print("SIZE OF src_emb: ", src_emb.shape)
        print("SIZE OF tgt_emb: ", tgt_emb.shape)
        src_mask, tgt_mask = create_mask(src_emb, tgt_emb)
        print("SIZE OF src_mask: ", src_mask.shape)
        print("SIZE OF tgt_mask: ", tgt_mask.shape)
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None)
        return outs #self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
            self.embedding(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
            self.embedding(tgt)), memory,
            tgt_mask)


In [478]:
EMB_SIZE = 16
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = EEG_Transformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [479]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float(
        '-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [480]:
def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),
                           device=DEVICE).type(torch.bool)

    return src_mask, tgt_mask

In [481]:
for eeg, att in train_dataloader:  
    sample_src_mask, sample_tgt_mask = create_mask(eeg, att)
    break

print(sample_src_mask)
print(sample_tgt_mask)

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='mps:0')
tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')


In [482]:
def train_epoch(model, optimizer, dataloader):
    model.train()
    losses = 0
    sent_cnt = 0
    correct, total = 0, 0
    for eeg, att in tqdm(train_dataloader):
        eeg = eeg.to(DEVICE)
        att = att.to(DEVICE)

        att_input = att[:,:-1, :]

#         eeg_mask, att_mask = create_mask(eeg, att_input)
        
        print("EEG: ",eeg.shape)
        print("ATT: ",att_input.shape)
#         print("EEG MASK: ",eeg_mask.shape)
#         print("ATT MASK: ",att_mask.shape)

        logits = model(eeg,
                       att_input)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(
            logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

        with torch.no_grad():
            sent_cnt += tgt_out.size()[-1]

            total += torch.numel(tgt_input)
            _, char_pred = torch.max(logits, -1)
            correct += (char_pred == tgt_out).sum().item()

    return losses / sent_cnt, correct / total


In [483]:
train_loss, train_acc = train_epoch(transformer, optimizer, train_dataloader)

  0%|                                                                                                                                                               | 0/2112 [00:00<?, ?it/s]

EEG:  torch.Size([1, 128, 66, 1])
ATT:  torch.Size([1, 127, 1])
SIZE OF src_emb:  torch.Size([1, 8448, 16])
SIZE OF tgt_emb:  torch.Size([1, 127, 16])
SIZE OF src_mask:  torch.Size([8448, 8448])
SIZE OF tgt_mask:  torch.Size([127, 127])


  0%|                                                                                                                                                               | 0/2112 [00:10<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 13.95 GB, other allocations: 2.13 GB, max allowed: 18.13 GB). Tried to allocate 2.13 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [102]:
def evaluate(model, dataloader):
    model.eval()
    losses = 0
    sent_cnt = 0
    correct, total = 0, 0

    for src_batch, tgt_batch in dataloader:
        src = src_batch.to(DEVICE)
        tgt = tgt_batch.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
            src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(
            logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
        sent_cnt += tgt_out.size()[-1]

        total += torch.numel(tgt_input)
        _, char_pred = torch.max(logits, -1)
        correct += (char_pred == tgt_out).sum().item()

    return losses / sent_cnt, correct / total

In [121]:
from timeit import default_timer as timer
NUM_EPOCHS = 20#15
train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss, train_acc = train_epoch(
        transformer, optimizer, train_dataloader)
    train_losses += [train_loss]
    train_accs += [train_acc]
    end_time = timer()
    val_loss, val_acc = evaluate(transformer, val_dataloader)
    val_losses += [val_loss]
    val_accs += [val_acc]
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Train acc.: {train_acc:.3f}, Val acc.: {val_acc:.3f}, Epoch time = {(end_time - start_time):.3f}s"))


  0%|                                                  | 0/2112 [00:00<?, ?it/s]

torch.Size([1, 16, 1, 1])


  0%|                                                  | 0/2112 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [64, 128, 8, 1], expected input[1, 0, 128, 1] to have 128 channels, but got 0 channels instead

In [None]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    #print(ys)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        #print(ys) # Remove the comment to understand the loop
        if next_word == EOS_IDX:
            break
    return ys


In [None]:
# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text2codes([src_sentence], token2idx)[0].view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 20, start_symbol=BOS_IDX).flatten()
    tgt_chars_specials = codes2text([tgt_tokens], idx2token)[0]
    tgt_chars = []
    for char in tgt_chars_specials:
        if char != '<bos>' and char != '<eos>':
            tgt_chars += [char]
    tgt_chars = ''.join(tgt_chars)
    return tgt_chars

## CNN

In [8]:
eeg = np.expand_dims(eeg,-1)

In [13]:
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(66, 4224, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

In [26]:
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))
model.add(layers.Dense(1))

In [27]:
model.summary()

In [28]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=['accuracy'])

history = model.fit(eeg[0:60,:,:,:], labels[0:60], epochs=10, 
                    validation_data=(eeg[61:-1,:,:,:], labels[61:-1]))

Epoch 1/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2s/step - accuracy: 0.5208 - loss: 7.7712 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.4896 - loss: 8.2030 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5312 - loss: 7.6273 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5104 - loss: 7.9151 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.4896 - loss: 8.2030 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5104 - loss: 7.9151 - val_accuracy: 0.5000 - val_loss: 8.0590
Epoch 7/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [None]:
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

In [None]:
print(test_acc)

## TRANSFORMER

In [6]:
# Implementing the Add & Norm Layer
class AddNormalization(Layer):
    def __init__(self, **kwargs):
        super(AddNormalization, self).__init__(**kwargs)
        self.layer_norm = LayerNormalization()  # Layer normalization layer
 
    def call(self, x, sublayer_x):
        # The sublayer input and output need to be of the same shape to be summed
        add = x + sublayer_x
 
        # Apply layer normalization to the sum
        return self.layer_norm(add)

In [7]:
# Implementing the Feed-Forward Layer
class FeedForward(Layer):
    def __init__(self, d_ff, d_model, **kwargs):
        super(FeedForward, self).__init__(**kwargs)
        self.fully_connected1 = Dense(d_ff)  # First fully connected layer
        self.fully_connected2 = Dense(d_model)  # Second fully connected layer
        self.activation = ReLU()  # ReLU activation layer
 
    def call(self, x):
        # The input is passed into the two fully-connected layers, with a ReLU in between
        x_fc1 = self.fully_connected1(x)
 
        return self.fully_connected2(self.activation(x_fc1))

In [11]:
# Implementing the Encoder Layer
class EncoderLayer(Layer):
    def __init__(self, h, d_k, d_v, d_model, d_ff, rate, **kwargs):
        super(EncoderLayer, self).__init__(**kwargs)
        self.multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
        self.dropout1 = Dropout(rate)
        self.add_norm1 = AddNormalization()
        self.feed_forward = FeedForward(d_ff, d_model)
        self.dropout2 = Dropout(rate)
        self.add_norm2 = AddNormalization()
 
    def call(self, x, padding_mask, training):
        # Multi-head attention layer
        multihead_output = self.multihead_attention(x, x, x, padding_mask)
        # Expected output shape = (batch_size, sequence_length, d_model)
 
        # Add in a dropout layer
        multihead_output = self.dropout1(multihead_output, training=training)
 
        # Followed by an Add & Norm layer
        addnorm_output = self.add_norm1(x, multihead_output)
        # Expected output shape = (batch_size, sequence_length, d_model)
 
        # Followed by a fully connected layer
        feedforward_output = self.feed_forward(addnorm_output)
        # Expected output shape = (batch_size, sequence_length, d_model)
 
        # Add in another dropout layer
        feedforward_output = self.dropout2(feedforward_output, training=training)
 
        # Followed by another Add & Norm layer
        return self.add_norm2(addnorm_output, feedforward_output)

In [13]:
# Implementing the Encoder
class Encoder(Layer):
    def __init__(self, vocab_size, sequence_length, h, d_k, d_v, d_model, d_ff, n, rate, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.pos_encoding = PositionEmbeddingFixedWeights(sequence_length, vocab_size, d_model)
        self.dropout = Dropout(rate)
        self.encoder_layer = [EncoderLayer(h, d_k, d_v, d_model, d_ff, rate) for _ in range(n)]
 
    def call(self, input_sentence, padding_mask, training):
        # Generate the positional encoding
        pos_encoding_output = self.pos_encoding(input_sentence)
        # Expected output shape = (batch_size, sequence_length, d_model)
 
        # Add in a dropout layer
        x = self.dropout(pos_encoding_output, training=training)
 
        # Pass on the positional encoded values to each encoder layer
        for i, layer in enumerate(self.encoder_layer):
            x = layer(x, padding_mask, training)
 
        return x

## TESTS

In [19]:
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_ff = 2048  # Dimensionality of the inner fully connected layer
d_model = 512  # Dimensionality of the model sub-layers' outputs
n = 6  # Number of layers in the encoder stack
 
batch_size = 64  # Batch size from the training process
dropout_rate = 0.1  # Frequency of dropping the input units in the dropout layers

enc_vocab_size =  # Vocabulary size for the encoder
input_seq_length =   # Maximum length of the input sequence
 
input_seq = random.random((batch_size, input_seq_length))

encoder = Encoder(enc_vocab_size, input_seq_length, h, d_k, d_v, d_model, d_ff, n, dropout_rate)

(80, 66, 4224)