In [2]:
import torch
from torch import nn

In [6]:

class Time2Vector(nn.Module):
    """Time2Vec-style embedding with one linear and one periodic feature per step.

    For every time step ``t`` the layer learns its own scalar weight/bias pairs
    and emits ``w_lin[t] * x[t] + b_lin[t]`` and
    ``sin(w_per[t] * x[t] + b_per[t])``.

    Args:
        seq_len: number of time steps in each input window; sizes the
            per-step weight and bias vectors.

    Shape:
        input:  ``(batch, seq_len)`` or ``(batch, seq_len, n_features)``; a
                3-D input is averaged over its feature axis first.
        output: ``(batch, seq_len, 2)`` -- ``[..., 0]`` is the linear term,
                ``[..., 1]`` the periodic term.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """

    def __init__(self, seq_len, **kwargs):
        super(Time2Vector, self).__init__()
        self.seq_len = seq_len
        # One scalar weight and bias per time step (element-wise, not a dense
        # projection across time), so each position gets its own encoding.
        self.weights_linear = nn.Parameter(torch.empty(seq_len))
        self.bias_linear = nn.Parameter(torch.empty(seq_len))
        self.weights_periodic = nn.Parameter(torch.empty(seq_len))
        self.bias_periodic = nn.Parameter(torch.empty(seq_len))
        for param in self.parameters():
            nn.init.uniform_(param, -0.05, 0.05)

    def forward(self, x):
        if x.dim() == 3:
            # Collapse the feature axis so every time step is a single scalar.
            x = x.mean(dim=-1)  # (batch, seq_len)
        elif x.dim() != 2:
            raise ValueError(
                f"Time2Vector expects a 2-D or 3-D input, got shape {tuple(x.shape)}"
            )
        # (seq_len,) parameters broadcast against (batch, seq_len).
        time_linear = self.weights_linear * x + self.bias_linear
        time_periodic = torch.sin(self.weights_periodic * x + self.bias_periodic)
        return torch.stack([time_linear, time_periodic], dim=-1)  # (batch, seq_len, 2)
    
    

class SingleAttention(nn.Module):
    def __init__(self, d_k, d_v, seq_len):
        super(SingleAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.seq_len = seq_len
        
        self.Q = nn.Linear(seq_len, d_k)
        self.K = nn.Linear(seq_len, d_k)
        self.V = nn.Linear(seq_len, d_v)
        
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        
        attn_weights = torch.matmul(q, k.T)
        attn_weights = attn_weights / torch.sqrt(self.d_k)
        attn_weights = self.softmax(attn_weights)
        
        attn_out = torch.matmul(attn_weights, v)
        return attn_out  
    
    
class MultiAttention(nn.Module):
    def __init__(self, d_k, d_v, n_heads, seq_len):
        super(MultiAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.seq_len = seq_len
        self.n_heads = n_heads
        self.attn_heads = list()
        
        for n in range(self.n_heads):
              self.attn_heads.append(SingleAttention(self.d_k, self.d_v, seq_len))  
        self.linear = nn.Linear(input_shape, 3)
        
    def forward(self, x):
        attn = [self.attn_heads[i](x) for i in range(self.n_heads)]
        concat_attn = torch.concat(attn, axis=-1)
        multi_linear = self.linear(concat_attn)
        return multi_linear 
    

class TransformerEncoder(nn.Module):
    def __init__(self, d_k, d_v, n_heads, ff_dim, seq_len, dropout=0.1, **kwargs):
        super(TransformerEncoder, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.attn_heads = list()
        self.dropout_rate = dropout

        self.attn_multi = MultiAttention(self.d_k, self.d_v, self.n_heads, seq_len)
        self.attn_dropout = nn.Dropout(self.dropout_rate)
        self.attn_normalize = nn.LayerNorm(input_shape, eps=1e-6)

        self.ff_conv1D_1 = nn.Conv1D(1, self.ff_dim, kernel_size=1, activation='relu')
        self.ff_conv1D_2 = nn.Conv1D(self.ff_dim, filters=3, kernel_size=1) # input_shape[0]=(batch, seq_len, 7), input_shape[0][-1]=7 
        self.ff_dropout = nn.Dropout(self.dropout_rate)
        self.ff_normalize = nn.LayerNorm(input_shape, eps=1e-6)    

 
  
    def call(self, x): # inputs = (in_seq, in_seq, in_seq)
        attn_layer = self.attn_multi(x)
        attn_layer = self.attn_dropout(attn_layer)
        attn_layer = self.attn_normalize(x[0] + attn_layer)

        ff_layer = self.ff_conv1D_1(attn_layer)
        ff_layer = self.ff_conv1D_2(ff_layer)
        ff_layer = self.ff_dropout(ff_layer)
        ff_layer = self.ff_normalize(x[0] + ff_layer)
        return ff_layer 

In [8]:
class Transformer(nn.Module):
    """Time2Vector embedding + three encoder layers + pooled regression head.

    Args:
        d_k, d_v, n_heads, ff_dim, seq_len: forwarded to every TransformerEncoder.
        dropout: dropout rate for the encoders and the regression head.
        n_features: number of raw input features per time step. The encoders
            operate on ``n_features + 2`` channels (the raw features plus the
            two Time2Vector features). The default 1 gives width 3.
        hidden_dim: width of the hidden dense layer in the head.

    Shape:
        input:  ``(batch, seq_len, n_features)``; a 2-D ``(batch, seq_len)``
                input is treated as ``n_features == 1``.
        output: ``(batch, 1)``.
    """

    def __init__(self, d_k, d_v, n_heads, ff_dim, seq_len, dropout=0.1,
                 n_features=1, hidden_dim=64, **kwargs):
        super(Transformer, self).__init__()
        d_model = n_features + 2
        self.time_embedding = Time2Vector(seq_len)
        self.attn_layer1 = TransformerEncoder(d_k, d_v, n_heads, ff_dim, seq_len, dropout, d_model=d_model)
        self.attn_layer2 = TransformerEncoder(d_k, d_v, n_heads, ff_dim, seq_len, dropout, d_model=d_model)
        self.attn_layer3 = TransformerEncoder(d_k, d_v, n_heads, ff_dim, seq_len, dropout, d_model=d_model)
        # Averages over the time axis; applied to a (batch, d_model, seq_len) view.
        self.pooling = nn.AdaptiveAvgPool1d(1)

        self.dropout = nn.Dropout(dropout)
        self.lin1 = nn.Linear(d_model, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 1)

        self.relu = nn.ReLU()

    def forward(self, x):
        if x.dim() == 2:
            x = x.unsqueeze(-1)  # (batch, seq_len, 1)

        # The time embedding consumes one scalar per step: average raw features.
        time_feats = self.time_embedding(x.mean(dim=-1))  # (batch, seq_len, 2)
        x = torch.cat([time_feats, x], dim=-1)  # (batch, seq_len, d_model)
        x = self.attn_layer1(x)
        x = self.attn_layer2(x)
        x = self.attn_layer3(x)

        x = self.pooling(x.transpose(1, 2)).squeeze(-1)  # (batch, d_model)
        x = self.dropout(x)
        x = self.dropout(self.relu(self.lin1(x)))

        return self.lin2(x)
        
        
        
def create_model():
    '''Build and compile what appears to be the Keras reference version of this model.

    NOTE(review): this is leftover Keras-style code and cannot run in this
    PyTorch notebook as written:
      * `Input`, `Concatenate`, `GlobalAveragePooling1D`, `Dropout`, `Dense`
        and `Model` are not imported in any visible cell (presumably they come
        from `tensorflow.keras` -- confirm);
      * `seq_len`, `d_k`, `d_v`, `n_heads` and `ff_dim` are read as globals
        that no visible cell defines;
      * `TransformerEncoder` resolves to the PyTorch class defined earlier in
        this notebook, whose constructor also requires `seq_len`, so the
        four-argument calls below would fail.
    The PyTorch `Transformer` class appears to be the intended replacement.

    Returns:
        The model, compiled with MSE loss, the Adam optimizer and MAE/MAPE
        metrics.
    '''
    time_embedding = Time2Vector(seq_len)
    # NOTE(review): these calls omit the required `seq_len` argument.
    attn_layer1 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)
    attn_layer2 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)
    attn_layer3 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)

    '''Construct model'''
    # Input window: seq_len time steps x 5 raw features (batch axis implicit).
    in_seq = Input(shape=(seq_len, 5))
    x = time_embedding(in_seq)
    # Append the time features to the raw features along the last axis.
    x = Concatenate(axis=-1)([in_seq, x])
    # Self-attention: the same tensor is passed as query, key and value.
    x = attn_layer1((x, x, x))
    x = attn_layer2((x, x, x))
    x = attn_layer3((x, x, x))
    x = GlobalAveragePooling1D(data_format='channels_first')(x)
    x = Dropout(0.1)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.1)(x)
    # Single linear output unit (regression target).
    out = Dense(1, activation='linear')(x)

    model = Model(inputs=in_seq, outputs=out)
    model.compile(loss='mse', optimizer='adam', metrics=['mae', 'mape'])
    return model