<a href="https://colab.research.google.com/github/sb2539/AI-study/blob/master/Triformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

@inproceedings{RazvanIJCAI2022,
  author    = {Razvan-Gabriel Cirstea and
                Chenjuan Guo and
                Bin Yang and
                Tung Kieu and
                Xuanyi Dong and
                Shirui Pan},
  title     = {Triformer: Triangular, Variable-Specific Attentions for Long Sequence
                Multivariate Time Series Forecasting},
  booktitle = {IJCAI},
  year      = {2022}
}

In [None]:
import math
import torch
import torch.nn as nn
from torch.nn import init

In [None]:
class Triformer(nn.Module):
    """Stack of patch-attention layers with per-layer skip projections.

    The input series is lifted to ``channels`` features, passed through one
    ``Layer`` per entry of ``patch_sizes``, and every layer's output is
    projected to a 256-wide skip vector.  The summed skip vectors feed a
    small MLP that emits ``horizon`` values.

    Args:
        device: Device forwarded to each ``Layer``.
        num_nodes: Size of the variable axis (used when reshaping skips).
        input_dim: Input feature size of the first linear layer.
        output_dim: Stored on the instance; not otherwise used here.
        channels: Hidden feature size used throughout the stack.
        dynamic: Forwarded to each ``Layer``.
        lag: Input sequence length; must be divisible by the running
            product of ``patch_sizes``.
        horizon: Number of output steps produced by the final MLP.
        patch_sizes: Patch length for each stacked layer, in order.
        supports: Stored on the instance; not otherwise used here.
        mem_dim: Forwarded to each skip ``WeightGenerator``.

    Raises:
        Exception: If the remaining length is not divisible by a patch size.
    """

    def __init__(self, device, num_nodes, input_dim, output_dim, channels, dynamic, lag,
                 horizon, patch_sizes, supports, mem_dim):
        super().__init__()
        self.factorized = True
        print('Prediction {} steps ahead'.format(horizon))

        # Plain configuration values.
        self.num_nodes = num_nodes
        self.output_dim = output_dim
        self.channels = channels
        self.dynamic = dynamic
        self.horizon = horizon
        self.supports = supports
        self.lag = lag

        # Sub-modules are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.start_fc = nn.Linear(in_features=input_dim, out_features=self.channels)
        self.layers = nn.ModuleList()
        self.skip_generators = nn.ModuleList()

        # `remaining` is the sequence length seen by the next layer: every
        # layer collapses each patch, so the length shrinks by the patch size.
        remaining = lag
        for size in patch_sizes:
            if remaining % size != 0:
                raise Exception('Lag not divisible by patch size')
            remaining = int(remaining / size)

            stage = Layer(device=device, input_dim=channels, dynamic=dynamic,
                          num_nodes=num_nodes, cuts=remaining, cut_size=size,
                          factorized=self.factorized)
            self.layers.append(stage)

            # Generates the weights that map this layer's flattened output
            # (remaining * channels features) to the shared 256-wide skip space.
            skip_gen = WeightGenerator(in_dim=remaining * channels, out_dim=256,
                                       number_of_weights=1, mem_dim=mem_dim,
                                       num_nodes=num_nodes, factorized=False)
            self.skip_generators.append(skip_gen)

        self.custom_linear = CustomLinear(factorized=False)

        # Predictor head applied to the summed skip vectors.
        self.projections = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, horizon),
        )
        self.notprinted = True

    def forward(self, batch_x, batch_x_mark, dec_inp, batch_y_mark):
        """Run the stack on ``batch_x`` and return the projected forecast.

        ``batch_x_mark``, ``dec_inp`` and ``batch_y_mark`` are accepted but
        not used.  ``batch_x`` is presumably (batch, time, variables) --
        TODO confirm against the data loader.
        """
        # One-time debug print of the input shape.
        if self.notprinted:
            self.notprinted = False
            print(batch_x.shape)

        # A trailing feature axis is added before the input projection.
        hidden = self.start_fc(batch_x.unsqueeze(-1))
        n_batch = hidden.size(0)

        skip_total = 0
        for stage, skip_gen in zip(self.layers, self.skip_generators):
            hidden = stage(hidden)
            weights, biases = skip_gen()
            # Swap axes 1 and 2, then flatten the trailing axes per variable.
            flat = hidden.transpose(2, 1).reshape(n_batch, 1, self.num_nodes, -1)
            skip_total = skip_total + self.custom_linear(flat, weights[-1], biases[-1])

        activated = torch.relu(skip_total).squeeze(1)
        return self.projections(activated).transpose(2, 1)

In [None]:
class Layer(nn.Module):
    """One patch-attention layer.

    The time axis of the input is split into ``cuts`` consecutive patches of
    ``cut_size`` steps.  For every patch a learned embedding is prepended and
    used as the attention query over (embedding + patch); the single-step
    result is carried to the next patch through a gated transform and all
    per-patch results are concatenated along the time axis.

    Args:
        device: Device on which the learned patch embeddings are created.
        input_dim: Feature size of the input and of the layer's output.
        num_nodes: Size of the variable axis.
        cuts: Number of patches the input is split into.
        cut_size: Number of time steps per patch.
        dynamic: Stored on the instance; not otherwise used here.
        factorized: Forwarded to ``TemporalAttention`` and to the distinct
            ``WeightGenerator``.
    """

    def __init__(self, device, input_dim, num_nodes, cuts, cut_size, dynamic, factorized):
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        self.num_nodes = num_nodes
        self.dynamic = dynamic
        self.cuts = cuts
        self.cut_size = cut_size

        # One learned 5-dim embedding per patch and variable.  The two
        # singleton axes are intentional: indexing by patch leaves
        # (1, 1, num_nodes, 5), which forward() repeats over the batch.
        # No trailing `.to(device)` on the Parameter: that call may return a
        # plain tensor, which would not be registered as a parameter.
        self.temporal_embeddings = nn.Parameter(
            torch.rand(cuts, 1, 1, self.num_nodes, 5).to(device),
            requires_grad=True)

        # Per-patch projection of the 5-dim embedding to the feature size.
        self.embeddings_generator = nn.ModuleList(
            [nn.Sequential(nn.Linear(5, input_dim)) for _ in range(cuts)])

        # out_net1 (tanh) and out_net2 (sigmoid) are multiplied in forward()
        # to gate what is carried over from the previous patch's output.
        self.out_net1 = nn.Sequential(
            nn.Linear(input_dim, input_dim ** 2),
            nn.Tanh(),
            nn.Linear(input_dim ** 2, input_dim),
            nn.Tanh(),
        )
        self.out_net2 = nn.Sequential(
            nn.Linear(input_dim, input_dim ** 2),
            nn.Tanh(),
            nn.Linear(input_dim ** 2, input_dim),
            nn.Sigmoid(),
        )

        self.temporal_att = TemporalAttention(input_dim, factorized=factorized)
        # "Distinct" weights use a memory of size 5 and the caller's
        # factorized flag; "shared" weights use no memory and no factorization.
        self.weights_generator_distinct = WeightGenerator(
            input_dim, input_dim, mem_dim=5, num_nodes=num_nodes,
            factorized=factorized, number_of_weights=2)
        self.weights_generator_shared = WeightGenerator(
            input_dim, input_dim, mem_dim=None, num_nodes=num_nodes,
            factorized=False, number_of_weights=2)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Attend over each patch of ``x`` and concatenate the results.

        Args:
            x: Tensor indexed as (batch, time, variables, channels); the time
                axis is expected to hold ``cuts * cut_size`` steps.

        Returns:
            The per-patch attention outputs concatenated along axis 1, after
            dropout.
        """
        batch_size = x.size(0)
        patch_outputs = []
        out = 0  # nothing is carried into the first patch

        weights_shared, biases_shared = self.weights_generator_shared()
        weights_distinct, biases_distinct = self.weights_generator_distinct()

        for i in range(self.cuts):
            # Current patch: (batch, cut_size, variables, channels).
            patch = x[:, i * self.cut_size:(i + 1) * self.cut_size, :, :]

            if i != 0:
                # Gated carry-over of the previous patch's output.
                out = self.out_net1(out) * self.out_net2(out)

            # Patch embedding broadcast over the batch, plus the carry-over.
            emb = self.embeddings_generator[i](
                self.temporal_embeddings[i]).repeat(batch_size, 1, 1, 1) + out

            # Prepend the embedding so that it is both the query (first step)
            # and part of the keys/values together with the patch itself.
            patch = torch.cat([emb, patch], dim=1)
            out = self.temporal_att(patch[:, :1, :, :], patch, patch,
                                    weights_distinct, biases_distinct,
                                    weights_shared, biases_shared)
            patch_outputs.append(out)

        return self.dropout(torch.cat(patch_outputs, dim=1))
        