## Transformer 구조 구현

In [1]:
# !pip install sentencepiece

#### 1. 데이터 확인

In [2]:
# Directory the notebook reads its vocab/model files and corpora from
data_dir = "./data"

#### 2. Imports

In [3]:
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import sentencepiece as spm
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

#### 3. 폴더의 목록을 확인
data_dir 목록을 확인 합니다.

In [4]:
# Print every entry of data_dir so the expected files can be checked by eye
for f in os.listdir(data_dir):
  print(f)

.DS_Store
kor_w2v_cbow
kor_w2v_cbow.model
kor_w2v_skipgram
kor_w2v_skipgram.model
kowiki.model
kowiki.txt
kowiki.vocab
naver.model
naver.vocab
naver_review.txt
ratings_train.csv
ratings_train.txt


#### 4. Vocab 및 입력
SentencePiece로 미리 만들어 둔 vocab을 로드합니다.  
: wiki corpus로 만들어 놓음

로딩된 vocab을 이용해 input을 만듭니다.

In [5]:
# Build the vocab
# Load the pretrained SentencePiece model
import csv
vocab_file = f"{data_dir}/kowiki.model"
vocab = spm.SentencePieceProcessor()
vocab.load(vocab_file)

# Input sentences
lines = [
    "겨울은 추워요.",
    "감기 조심하세요."
]

# Encode every sentence into a tensor of token ids
inputs = []
for line in lines:
    pieces = vocab.EncodeAsPieces(line) # split into subword pieces
    ids = vocab.EncodeAsIds(line) # map to vocabulary indices
    inputs.append(torch.tensor(ids))
    print(pieces)
    print(ids)

# Pad with 0 so every sequence matches the longest one in the batch
inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=0)

# shape
print(inputs.size())

# values
print(inputs)

['▁겨울', '은', '▁추', '워', '요', '.']
[3234, 3744, 205, 4081, 3902, 3730]
['▁감', '기', '▁조', '심', '하', '세', '요', '.']
[199, 3746, 54, 3974, 3736, 3826, 3902, 3730]
torch.Size([2, 8])
tensor([[3234, 3744,  205, 4081, 3902, 3730,    0,    0],
        [ 199, 3746,   54, 3974, 3736, 3826, 3902, 3730]])


#### 5. Embedding

##### - Input Embedding

In [6]:
# Embed the input token ids
n_vocab = len(vocab)
print(n_vocab)
d_hidn = 128 # hidden size (the original paper uses 512)
nn_emb = nn.Embedding(n_vocab, d_hidn)
input_embs = nn_emb(inputs)
print(input_embs.size())

8007
torch.Size([2, 8, 128])


##### - Position Embedding

1. 문장의 position 별 angle 값을 구함  
2. 구해진 angle 중 짝수 index의 값에 대한 sin 값을 구합니다.  
3. 구해진 angle 중 홀수 index의 값에 대한 cos 값을 구합니다.

In [7]:
""" sinusoid position embedding """
def get_sinusoid_encoding_table(n_seq, d_hidn):
    """Build the sinusoidal position-encoding table.

    Entry ``[pos, i]`` is ``sin(pos / 10000 ** (2 * (i // 2) / d_hidn))`` for
    even ``i`` and the cosine of the same angle for odd ``i``.

    Args:
        n_seq: Number of positions (rows of the table).
        d_hidn: Encoding width (columns of the table).

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(n_seq, d_hidn)``. ``n_seq == 0``
        yields an empty ``(0, d_hidn)`` table.
    """
    # Column vector of positions so it broadcasts against the per-column exponents
    positions = np.arange(n_seq, dtype=np.float64)[:, np.newaxis]
    # Columns 2k and 2k + 1 share the exponent 2k / d_hidn
    exponents = 2 * (np.arange(d_hidn) // 2) / d_hidn
    sinusoid_table = positions / np.power(10000.0, exponents)
    # Even columns take the sine of the angle, odd columns the cosine
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
    return sinusoid_table



In [8]:
# Build the table for 64 positions, d_hidn columns each
n_seq = 64
pos_encoding = get_sinusoid_encoding_table(n_seq, d_hidn)
print(pos_encoding.shape)
# Optional: draw the encoding table as a heat map
# plt.pcolormesh(pos_encoding, cmap='RdBu')
# plt.xlabel('Depth')
# plt.xlim((0, d_hidn))
# plt.ylabel('Position')
# plt.colorbar()
# plt.show()

(64, 128)


In [9]:
# Build the position embedding from the sinusoid table
pos_encoding = torch.FloatTensor(pos_encoding)

# Frozen embedding layer initialised from the table
nn_pos  = nn.Embedding.from_pretrained(pos_encoding, freeze=True)

# inputs: (2, 8); position ids start at 1 because of the trailing "+ 1"
positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).expand(inputs.size(0), inputs.size(1)).contiguous() + 1

# Padding tokens (id 0) are given position id 0
pos_mask = inputs.eq(0)
positions.masked_fill_(pos_mask, 0) # overwrite the padded slots in place
# NOTE(review): row 0 of the table is the encoding of position 0 (sin 0 / cos 0),
# not a zero vector, so padded slots still receive a non-zero position embedding.
pos_embs = nn_pos(positions)

print(inputs)
print(positions)
print(pos_embs.size())

tensor([[3234, 3744,  205, 4081, 3902, 3730,    0,    0],
        [ 199, 3746,   54, 3974, 3736, 3826, 3902, 3730]])
tensor([[1, 2, 3, 4, 5, 6, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8]])
torch.Size([2, 8, 128])


In [10]:
# Initial model input: token embedding plus position embedding
input_sums = input_embs +pos_embs
print(input_sums.size())
print(input_sums)
print(input_embs)
print(pos_embs)

torch.Size([2, 8, 128])
tensor([[[ 0.2217,  1.0359, -1.6467,  ...,  0.0962, -1.2751,  1.2077],
         [ 0.0784,  0.5207,  0.5911,  ...,  0.5457, -0.5813,  1.4269],
         [-0.8689, -0.3755, -0.7425,  ...,  1.8775, -0.0199,  0.8167],
         ...,
         [-0.6697,  0.2188, -1.3910,  ...,  0.4820, -2.9566,  1.5345],
         [ 1.8596,  1.3996, -0.5283,  ...,  0.5701, -2.2868,  0.9269],
         [ 1.8596,  1.3996, -0.5283,  ...,  0.5701, -2.2868,  0.9269]],

        [[-0.2576,  0.0702, -1.2664,  ..., -0.5051, -1.0110,  1.9549],
         [ 0.7859, -0.6066,  1.4636,  ..., -0.3887,  1.6415,  0.4860],
         [ 2.1200, -4.0953,  1.3539,  ...,  2.8743,  1.1692,  0.3495],
         ...,
         [-0.5206,  0.0577, -0.5392,  ...,  1.9668, -0.9671,  0.8915],
         [ 1.1237,  0.9027, -1.4957,  ...,  0.8373, -0.4433,  3.0406],
         [ 0.5991, -0.8869,  0.0952,  ...,  0.4820, -2.9564,  1.5345]]],
       grad_fn=<AddBackward0>)
tensor([[[-0.6198,  0.4956, -2.4084,  ..., -0.9038, -1.2752, 

#### 6. Scaled Dot-Product Attention

##### Input

In [11]:
# Prepare the attention inputs; Q, K and V are all the same tensor here
Q = input_sums
K = input_sums
V = input_sums

# True where the key token is padding (id 0), expanded to (batch, n_q_seq, n_k_seq)
attn_mask = inputs.eq(0).unsqueeze(1).expand(Q.size(0) , Q.size(1), K.size(1))
print(attn_mask.size())
print(attn_mask[0])

torch.Size([2, 8, 8])
tensor([[False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True]])


##### Q * K-transpose

In [12]:
# Raw attention scores: dot product of every query row with every key row
scores = torch.matmul(Q, K.transpose(-1, -2))
print(scores.size())
print(scores[0])

torch.Size([2, 8, 8])
tensor([[181.1045,  54.7520,  51.1379,  55.8373,  69.7627,  63.9615,  78.3058,
          78.3058],
        [ 54.7520, 195.3990,  35.4192,  40.9941,  42.6141,  21.3285,  50.6848,
          50.6848],
        [ 51.1379,  35.4192, 178.9026,  61.1841,  78.2546,  47.3389,  61.2755,
          61.2755],
        [ 55.8373,  40.9941,  61.1841, 156.4162,  67.0148,  44.1819,  74.0393,
          74.0393],
        [ 69.7627,  42.6141,  78.2546,  67.0148, 233.2007,  94.2432,  75.7979,
          75.7979],
        [ 63.9615,  21.3285,  47.3389,  44.1819,  94.2432, 165.0052,  65.8726,
          65.8726],
        [ 78.3058,  50.6848,  61.2755,  74.0393,  75.7979,  65.8726, 189.4782,
         189.4782],
        [ 78.3058,  50.6848,  61.2755,  74.0393,  75.7979,  65.8726, 189.4782,
         189.4782]], grad_fn=<SelectBackward0>)


##### Scale

In [13]:
# Scale the scores by 1 / sqrt(d_head)
d_head = 64
scores = scores.mul_(1/d_head**0.5) # in-place: re-running this cell scales the scores again
print(scores.size())
print(scores[0])


torch.Size([2, 8, 8])
tensor([[22.6381,  6.8440,  6.3922,  6.9797,  8.7203,  7.9952,  9.7882,  9.7882],
        [ 6.8440, 24.4249,  4.4274,  5.1243,  5.3268,  2.6661,  6.3356,  6.3356],
        [ 6.3922,  4.4274, 22.3628,  7.6480,  9.7818,  5.9174,  7.6594,  7.6594],
        [ 6.9797,  5.1243,  7.6480, 19.5520,  8.3769,  5.5227,  9.2549,  9.2549],
        [ 8.7203,  5.3268,  9.7818,  8.3769, 29.1501, 11.7804,  9.4747,  9.4747],
        [ 7.9952,  2.6661,  5.9174,  5.5227, 11.7804, 20.6257,  8.2341,  8.2341],
        [ 9.7882,  6.3356,  7.6594,  9.2549,  9.4747,  8.2341, 23.6848, 23.6848],
        [ 9.7882,  6.3356,  7.6594,  9.2549,  9.4747,  8.2341, 23.6848, 23.6848]],
       grad_fn=<SelectBackward0>)


##### Mask (Opt.)

In [14]:
# Overwrite (in place) the scores of padded key positions with a large negative value
scores.masked_fill_(attn_mask, -1e9)
print(scores.size())
print(scores[0])

torch.Size([2, 8, 8])
tensor([[ 2.2638e+01,  6.8440e+00,  6.3922e+00,  6.9797e+00,  8.7203e+00,
          7.9952e+00, -1.0000e+09, -1.0000e+09],
        [ 6.8440e+00,  2.4425e+01,  4.4274e+00,  5.1243e+00,  5.3268e+00,
          2.6661e+00, -1.0000e+09, -1.0000e+09],
        [ 6.3922e+00,  4.4274e+00,  2.2363e+01,  7.6480e+00,  9.7818e+00,
          5.9174e+00, -1.0000e+09, -1.0000e+09],
        [ 6.9797e+00,  5.1243e+00,  7.6480e+00,  1.9552e+01,  8.3769e+00,
          5.5227e+00, -1.0000e+09, -1.0000e+09],
        [ 8.7203e+00,  5.3268e+00,  9.7818e+00,  8.3769e+00,  2.9150e+01,
          1.1780e+01, -1.0000e+09, -1.0000e+09],
        [ 7.9952e+00,  2.6661e+00,  5.9174e+00,  5.5227e+00,  1.1780e+01,
          2.0626e+01, -1.0000e+09, -1.0000e+09],
        [ 9.7882e+00,  6.3356e+00,  7.6594e+00,  9.2549e+00,  9.4747e+00,
          8.2341e+00, -1.0000e+09, -1.0000e+09],
        [ 9.7882e+00,  6.3356e+00,  7.6594e+00,  9.2549e+00,  9.4747e+00,
          8.2341e+00, -1.0000e+09, -1.0000e

##### Softmax

In [15]:
# Apply softmax over the last (key) axis
attn_prob = nn.Softmax(dim=-1)(scores)
print(attn_prob.size())
print(attn_prob[0])

torch.Size([2, 8, 8])
tensor([[1.0000e+00, 1.3827e-07, 8.8010e-08, 1.5836e-07, 9.0283e-07, 4.3720e-07,
         0.0000e+00, 0.0000e+00],
        [2.3159e-08, 1.0000e+00, 2.0664e-09, 4.1481e-09, 5.0792e-09, 3.5503e-10,
         0.0000e+00, 0.0000e+00],
        [1.1589e-07, 1.6246e-08, 1.0000e+00, 4.0685e-07, 3.4367e-06, 7.2081e-08,
         0.0000e+00, 0.0000e+00],
        [3.4664e-06, 5.4212e-07, 6.7631e-06, 9.9997e-01, 1.4018e-05, 8.0751e-07,
         0.0000e+00, 0.0000e+00],
        [1.3411e-09, 4.5047e-11, 3.8768e-09, 9.5126e-10, 1.0000e+00, 2.8605e-08,
         0.0000e+00, 0.0000e+00],
        [3.2703e-06, 1.5856e-08, 4.0946e-07, 2.7594e-07, 1.4404e-04, 9.9985e-01,
         0.0000e+00, 0.0000e+00],
        [3.7319e-01, 1.1816e-02, 4.4404e-02, 2.1894e-01, 2.7277e-01, 7.8882e-02,
         0.0000e+00, 0.0000e+00],
        [3.7319e-01, 1.1816e-02, 4.4404e-02, 2.1894e-01, 2.7277e-01, 7.8882e-02,
         0.0000e+00, 0.0000e+00]], grad_fn=<SelectBackward0>)


##### attn_prob * V

In [16]:
# Weight the values by the attention probabilities
context = torch.matmul(attn_prob, V)
print(context.size())

torch.Size([2, 8, 128])


##### Implementation Class

In [17]:
""" scale dot product attention """
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with a boolean mask.

    Args:
        d_head: Head width; scores are multiplied by ``1 / sqrt(d_head)``.
    """

    def __init__(self, d_head):
        super().__init__()
        self.scale = 1 / (d_head ** 0.5)

    def forward(self, Q, K, V, attn_mask):
        """Attend from ``Q`` over ``K`` / ``V``.

        Positions where ``attn_mask`` is True are filled with ``-1e9`` before
        the softmax.

        Returns:
            ``(context, attn_prob)``: the weighted values and the attention
            probabilities.
        """
        # Query/key similarity, scaled in place on the freshly created tensor
        similarity = torch.matmul(Q, K.transpose(-1, -2))
        scores = similarity.mul_(self.scale)
        scores.masked_fill_(attn_mask, -1e9)
        # Normalise over the key axis
        attn_prob = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_prob, V), attn_prob
    
    

#### 7. Multi-Head Attention

##### Input

In [18]:
Q = input_sums # (batch, n_seq, d_hidn)
K = input_sums
V = input_sums
# Padding mask: True where the key token is padding (id 0)
attn_mask = inputs.eq(0).unsqueeze(1).expand(Q.size(0), Q.size(1), K.size(1))

batch_size = Q.size(0)
n_head = 2 # number of heads (the original Transformer paper uses 8)


##### Multi Head Q, K, V

In [19]:
# Linear projections sized for all heads at once
W_Q = nn.Linear(d_hidn, n_head * d_head) 
W_K = nn.Linear(d_hidn, n_head * d_head)
W_V = nn.Linear(d_hidn, n_head * d_head) 

# (bs, n_seq, n_head * d_head)
q_s = W_Q(Q)
print(q_s.size())

# (bs, n_seq, n_head, d_head)
q_s = q_s.view(batch_size, -1, n_head, d_head) # reshape to 4 dimensions: split the last axis per head
print(q_s.size())

# (bs, n_head, n_seq, d_head)
q_s = q_s.transpose(1, 2)
print(q_s.size())

torch.Size([2, 8, 128])
torch.Size([2, 8, 2, 64])
torch.Size([2, 2, 8, 64])


In [20]:
# Same project -> split -> transpose steps, applied to Q, K and V
q_s = W_Q(Q).view(batch_size, -1, n_head, d_head).transpose(1, 2)
k_s = W_K(K).view(batch_size, -1, n_head, d_head).transpose(1, 2)
v_s = W_V(V).view(batch_size, -1, n_head, d_head).transpose(1, 2)
print(q_s.size(), k_s.size(), v_s.size())

torch.Size([2, 2, 8, 64]) torch.Size([2, 2, 8, 64]) torch.Size([2, 2, 8, 64])


##### Multi Head Attention Mask

In [21]:
# Give the mask a head axis too: (bs, n_seq, n_seq) -> (bs, n_head, n_seq, n_seq)
print(attn_mask.size())
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_head, 1, 1)
print(attn_mask.size())

torch.Size([2, 8, 8])
torch.Size([2, 2, 8, 8])


##### Attention

In [22]:
# Attention module that scales scores by 1 / sqrt(d_head)
scaled_dot_attn = ScaledDotProductAttention(d_head)

In [23]:
# Call the module itself rather than .forward so nn.Module hooks are honoured
context, attn_prob = scaled_dot_attn(q_s, k_s, v_s, attn_mask)

In [24]:
# Per-head context and attention probabilities
print(context.size())
print(attn_prob.size() )

torch.Size([2, 2, 8, 64])
torch.Size([2, 2, 8, 8])


##### Concat

In [25]:
# context before: (bs [batch size], n_head [number of heads], n_seq [max length], d_head [per-head width])
# context after:  (bs [batch size], n_seq [max length], n_head * d_head)
# transpose(1, 2) swaps the head and sequence axes; contiguous() then lets view() merge the heads
context = context.transpose(1,2).contiguous().view(batch_size, -1, n_head*d_head)
print(context.size())

torch.Size([2, 8, 128])


##### Linear

In [26]:
# Project the concatenated heads back to the hidden width
linear = nn.Linear(n_head * d_head, d_hidn) 
# (bs, n_seq, d_hidn)
output = linear(context)
print(output.size())

torch.Size([2, 8, 128])


##### Implementation Class

In [34]:
""" multi head attention """
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, merge, project back.

    Args:
        d_hidn: Width of the input and output features.
        n_head: Number of attention heads.
        d_head: Width of each head.
    """

    def __init__(self, d_hidn, n_head, d_head):
        super().__init__()
        self.d_hidn = d_hidn
        self.d_head = d_head
        self.n_head = n_head

        # Input projections, scaled dot-product attention, output projection
        self.W_Q = nn.Linear(d_hidn, n_head * d_head)
        self.W_K = nn.Linear(d_hidn, n_head * d_head)
        self.W_V = nn.Linear(d_hidn, n_head * d_head)
        self.scaled_dot_attn = ScaledDotProductAttention(d_head)
        self.linear = nn.Linear(n_head * d_head, d_hidn)

    def forward(self, Q, K, V, attn_mask):
        """Run multi-head attention.

        Args:
            Q: Queries, (bs, n_q_seq, d_hidn).
            K: Keys, (bs, n_k_seq, d_hidn).
            V: Values, (bs, n_k_seq, d_hidn).
            attn_mask: Mask of shape (bs, n_q_seq, n_k_seq), passed to the
                attention module after a head axis is added.

        Returns:
            ``(output, attn_prob)`` with shapes (bs, n_q_seq, d_hidn) and
            (bs, n_head, n_q_seq, n_k_seq).
        """
        batch_size = Q.size(0)
        # q_s: (bs, n_head, n_q_seq, d_head)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_head, self.d_head).transpose(1, 2)
        # k_s: (bs, n_head, n_k_seq, d_head)
        k_s = self.W_K(K).view(batch_size, -1, self.n_head, self.d_head).transpose(1, 2)
        # v_s: (bs, n_head, n_v_seq, d_head)
        v_s = self.W_V(V).view(batch_size, -1, self.n_head, self.d_head).transpose(1, 2)

        # mask: (bs, n_head, n_q_seq, n_k_seq)
        # expand() gives a broadcast view; repeat() would copy the mask n_head times
        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_head, -1, -1)

        # scaled dot-product attention
        # (bs, n_head, n_q_seq, d_head), (bs, n_head, n_q_seq, n_k_seq)
        # Call the sub-module itself (not .forward) so nn.Module hooks are honoured
        context, attn_prob = self.scaled_dot_attn(q_s, k_s, v_s, attn_mask)

        # concat heads
        # (bs, n_q_seq, n_head * d_head)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_head)

        # output projection
        # (bs, n_q_seq, d_hidn)
        output = self.linear(context)

        # (bs, n_q_seq, d_hidn), (bs, n_head, n_q_seq, n_k_seq)
        return output, attn_prob

In [36]:
mul_head_attn = MultiHeadAttention(d_hidn, n_head, d_head)
# attn_mask was turned into a 4-D (bs, n_head, n_seq, n_seq) tensor in the manual
# walkthrough; MultiHeadAttention adds the head axis itself, so rebuild the
# 3-D padding mask before calling it.
attn_mask = inputs.eq(0).unsqueeze(1).expand(Q.size(0), Q.size(1), K.size(1))
# Call the module itself rather than .forward so nn.Module hooks are honoured
output, attn_prob = mul_head_attn(Q, K, V, attn_mask)
print(output.size())
print(attn_prob.size())

torch.Size([2, 8, 128])
torch.Size([2, 2, 8, 8])


#### 8. Masked Multi Head Attention

In [37]:
""" attention decoder mask """
def get_attn_decoder_mask(seq):
    """Return the look-ahead (subsequent-position) mask for ``seq``.

    Args:
        seq: Tensor indexed as (batch, n_seq).

    Returns:
        Tensor of shape (batch, n_seq, n_seq) with the dtype of ``seq``:
        1 strictly above the main diagonal, 0 elsewhere.
    """
    n_batch, n_seq = seq.size(0), seq.size(1)
    # One column of ones per position, broadcast to a square per batch entry
    ones_column = torch.ones_like(seq).unsqueeze(-1)
    ones_square = ones_column.expand(n_batch, n_seq, n_seq)
    # Keep only the strictly-upper triangle
    return torch.triu(ones_square, diagonal=1)


Q = input_sums
K = input_sums
V = input_sums

# attn_pad_mask: padding mask for the input tokens
attn_pad_mask = inputs.eq(0).unsqueeze(1).expand(Q.size(0), Q.size(1), K.size(1)) # add an axis and expand the basic padding mask
print(attn_pad_mask[0]) # the same per-sentence row, once per query position


# attn_dec_mask: attention mask that lets each token see only itself and earlier tokens

attn_dec_mask = get_attn_decoder_mask(inputs)
print(attn_dec_mask[0])
# attn_mask: attn_pad_mask + attn_dec_mask; a position is masked if either mask marks it
attn_mask = torch.gt((attn_pad_mask + attn_dec_mask),0)
print(attn_mask)


batch_size = Q.size(0)
n_head = 2


# Masked positions are the ones marked 1 (True)

# The padding part of the input sequence is masked,
# and tokens that come later are not attended to (True)

tensor([[False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True,  True]])
tensor([[0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[[False,  True,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True,  True,  True

In [38]:
attention = MultiHeadAttention(d_hidn, n_head, d_head) # fresh, randomly initialised module
# Self-attention using the combined padding + look-ahead mask
output, attn_prob = attention(Q, K, V, attn_mask )
print(output.size(), attn_prob.size())

torch.Size([2, 8, 128]) torch.Size([2, 2, 8, 8])


#### 9. Feed Forward

##### f1 (Linear)

In [39]:
# Linear layer implemented as a kernel-size-1 Conv1d
conv1 = nn.Conv1d(in_channels=d_hidn, out_channels=d_hidn*4, kernel_size=1)
print(conv1)
# transpose(1, 2) puts the feature axis second before the convolution
ff_1 = conv1(output.transpose(1,2))
print(ff_1.size())
# (bs, d_hidn * 4, n_seq)

Conv1d(128, 512, kernel_size=(1,), stride=(1,))
torch.Size([2, 512, 8])


##### Activation (relu or gelu)

![](https://raw.githubusercontent.com/paul-hyun/paul-hyun.github.io/master/assets/2019-12-19/activation.png)

In [40]:
# Activation applied between the two feed-forward layers: GELU
active = F.gelu
ff_2 = active(ff_1)

##### f3 (Linear)

In [42]:
# Linear layer implemented as a kernel-size-1 Conv1d
# Projects back to the original width and restores the (bs, n_seq, d_hidn) layout
conv2 = nn.Conv1d(in_channels=d_hidn*4, out_channels=d_hidn, kernel_size=1)
ff3 = conv2(ff_2).transpose(1,2)
print(ff3.size())

torch.Size([2, 8, 128])


##### Implementation Class

In [43]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network built from two kernel-size-1 Conv1d layers.

    The hidden width is ``4 * d_hidn`` and the activation is GELU.

    Args:
        d_hidn: Width of the input and output features.
    """

    def __init__(self, d_hidn):
        super().__init__()
        self.d_hidn = d_hidn
        self.conv1 = nn.Conv1d(in_channels=self.d_hidn, out_channels=self.d_hidn*4, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=self.d_hidn*4, out_channels=self.d_hidn, kernel_size=1)
        self.active = F.gelu

    def forward(self, inputs):
        """Map ``inputs`` of shape (bs, n_seq, d_hidn) to an output of the same shape."""
        # Conv1d takes the feature axis second, so transpose first:
        # (bs, n_seq, d_hidn) -> (bs, d_hidn, n_seq) -> (bs, d_hidn * 4, n_seq)
        ff_1 = self.conv1(inputs.transpose(1, 2))
        # (bs, d_hidn * 4, n_seq)
        ff_2 = self.active(ff_1)
        # (bs, d_hidn, n_seq) -> (bs, n_seq, d_hidn)
        ff3 = self.conv2(ff_2).transpose(1, 2)
        return ff3

In [None]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network (compact variant).

    Two kernel-size-1 Conv1d layers with a GELU in between; the hidden width
    is ``4 * d_hidn``.

    Args:
        d_hidn: Width of the input and output features.
    """

    def __init__(self, d_hidn):
        super().__init__()
        self.d_hidn = d_hidn
        self.conv1 = nn.Conv1d(in_channels=self.d_hidn, out_channels=self.d_hidn*4, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=self.d_hidn*4, out_channels=self.d_hidn, kernel_size=1)
        self.active = F.gelu

    def forward(self, inputs):
        """Map ``inputs`` of shape (bs, n_seq, d_hidn) to an output of the same shape."""
        # Expand + activate; Conv1d takes the feature axis second: (bs, d_hidn * 4, n_seq)
        output = self.active(self.conv1(inputs.transpose(1, 2)))
        # Project back and restore the layout: (bs, n_seq, d_hidn)
        output = self.conv2(output).transpose(1, 2)
        return output