<a href="https://colab.research.google.com/github/sb2539/AI-study/blob/master/Triformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

@inproceedings{RazvanIJCAI2022,
  author    = {Razvan-Gabriel Cirstea and
                Chenjuan Guo and
                Bin Yang and
                Tung Kieu and
                Xuanyi Dong and
                Shirui Pan},
  title     = {Triformer: Triangular, Variable-Specific Attentions for Long Sequence
                Multivariate Time Series Forecasting},
  booktitle = {IJCAI},
  year      = {2022}
}

In [None]:
import math
import torch
import torch.nn as nn
from torch.nn import init

In [None]:
class Triformer(nn.Module):
    """Triformer forecaster (Cirstea et al., IJCAI 2022).

    Stacks one ``Layer`` per entry of ``patch_sizes``; each layer is built with
    the patch count left after dividing by that patch size. After every layer a
    per-layer ``WeightGenerator`` supplies weights for a ``CustomLinear`` skip
    projection to ``skip_dim`` features. The skips are summed, passed through
    ReLU, and mapped to ``horizon`` outputs by a two-layer MLP.
    """

    def __init__(self, device, num_nodes, input_dim, output_dim, channels, dynamic, lag,
                 horizon, patch_sizes, supports, mem_dim, *, skip_dim=256, hidden_dim=512):
        """Build the input projection, the layer stack and the output head.

        Args:
            device: forwarded to every ``Layer``.
            num_nodes: number of variables; sizes the skip weight generators
                and the reshape of the skip input in ``forward``.
            input_dim: ``in_features`` of the input projection ``start_fc``.
            output_dim: stored on the instance; not used inside this class.
            channels: width produced by ``start_fc`` and passed to each ``Layer``.
            dynamic: forwarded to every ``Layer``.
            lag: input sequence length; it must be divisible, stage by stage,
                by every patch size in ``patch_sizes``.
            horizon: number of outputs produced by the final linear layer.
            patch_sizes: iterable of positive ints, one per layer.
            supports: stored on the instance; not used inside this class.
            mem_dim: forwarded to each skip ``WeightGenerator``.
            skip_dim: width of every skip projection (keyword-only; default 256,
                the value previously hard-coded).
            hidden_dim: hidden width of the output MLP (keyword-only; default 512,
                the value previously hard-coded).

        Raises:
            ValueError: if ``patch_sizes`` is empty, contains a non-positive
                value, or the remaining patch count is not divisible by a patch size.
        """
        super().__init__()
        # Materialise once so generators/arrays can be both checked and iterated.
        patch_sizes = list(patch_sizes)
        if not patch_sizes:
            # With no layers ``skip`` would stay the int 0 and ``forward`` would
            # fail in ``torch.relu``; reject the configuration up front instead.
            raise ValueError('patch_sizes must contain at least one patch size')

        self.factorized = True
        print('Prediction {} steps ahead'.format(horizon))
        self.num_nodes = num_nodes
        self.output_dim = output_dim
        self.channels = channels
        self.dynamic = dynamic
        self.start_fc = nn.Linear(in_features=input_dim, out_features=self.channels)
        self.layers = nn.ModuleList()
        self.skip_generators = nn.ModuleList()
        self.horizon = horizon
        self.supports = supports
        self.lag = lag  # length of the input time series
        self.skip_dim = skip_dim
        self.hidden_dim = hidden_dim

        cuts = lag  # number of patches remaining after each stage
        for patch_size in patch_sizes:
            if patch_size <= 0:
                raise ValueError('Patch sizes must be positive, got {}'.format(patch_size))
            if cuts % patch_size != 0:
                raise ValueError(
                    'Lag {} not divisible by patch sizes: {} patches remain, '
                    'which is not divisible by patch size {}'.format(lag, cuts, patch_size))

            # Floor division keeps the count exact (no float round-trip).
            cuts = int(cuts // patch_size)
            self.layers.append(Layer(device=device, input_dim=channels,
                                     dynamic=dynamic, num_nodes=num_nodes, cuts=cuts,
                                     cut_size=patch_size, factorized=self.factorized))
            # The skip input flattens the remaining patches and channels per node.
            self.skip_generators.append(WeightGenerator(in_dim=cuts * channels, out_dim=skip_dim,
                                                        number_of_weights=1, mem_dim=mem_dim,
                                                        num_nodes=num_nodes, factorized=False))

        self.custom_linear = CustomLinear(factorized=False)
        self.projections = nn.Sequential(
            nn.Linear(skip_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, horizon),
        )
        self.notprinted = True

    def forward(self, batch_x, batch_x_mark, dec_inp, batch_y_mark):
        """Run the layer stack and return the horizon projection.

        Only ``batch_x`` is used. ``batch_x_mark``, ``dec_inp`` and
        ``batch_y_mark`` are accepted and ignored (presumably to match a shared
        training-loop call signature; confirm against the caller).

        Args:
            batch_x: input tensor. A trailing axis of size 1 is appended before
                ``start_fc``, so ``input_dim`` presumably must be 1. The layout is
                likely (batch, lag, num_nodes); TODO confirm.

        Returns:
            The ``projections`` output with axes 1 and 2 swapped. This is presumably
            (batch, horizon, num_nodes); confirm against ``CustomLinear``'s output shape.
        """
        if self.notprinted:
            # Print the input shape once, on the first call only.
            self.notprinted = False
            print(batch_x.shape)
        x = self.start_fc(batch_x.unsqueeze(-1))
        batch_size = x.size(0)
        skip = 0

        for layer, skip_generator in zip(self.layers, self.skip_generators):
            x = layer(x)
            weights, biases = skip_generator()
            # Swap axes 1 and 2, then reshape to (batch, 1, num_nodes, -1). This
            # presumably assumes layer output is (batch, cuts, num_nodes, channels); confirm.
            skip_inp = x.transpose(2, 1).reshape(batch_size, 1, self.num_nodes, -1)
            # Generators are built with number_of_weights=1, so [-1] presumably
            # picks the single generated weight/bias.
            skip = skip + self.custom_linear(skip_inp, weights[-1], biases[-1])

        x = torch.relu(skip).squeeze(1)
        return self.projections(x).transpose(2, 1)