In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataclasses import dataclass
import lovely_tensors as lt
lt.monkey_patch()

In [78]:
class Model(nn.Module):

    # FITS: Frequency Interpolation Time Series Forecasting

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.individual = configs.individual
        self.channels = configs.enc_in

        self.dominance_freq=configs.cut_freq # 720/24
        self.length_ratio = (self.seq_len + self.pred_len)/self.seq_len

        if self.individual:
            self.freq_upsampler = nn.ModuleList()
            for i in range(self.channels):
                self.freq_upsampler.append(nn.Linear(self.dominance_freq, int(self.dominance_freq*self.length_ratio)).to(torch.cfloat))

        else:
            self.freq_upsampler = nn.Linear(self.dominance_freq, int(self.dominance_freq*self.length_ratio)).to(torch.cfloat) # complex layer for frequency upsampling

    def forward(self, x):
        print("length_ratio", self.length_ratio)

        # RIN 
        print("x", x)
        x_mean = torch.mean(x, dim=1, keepdim=True)
        print("x_mean", x_mean)
        x = x - x_mean
        x_var = torch.var(x, dim=1, keepdim=True) + 1e-5
        print("x_var", x_var)
        x = x / torch.sqrt(x_var)
        print("x", x)

        low_specx = torch.fft.rfft(x, dim=1)
        print("low_specx", low_specx.shape)
        print("low_specx_var", torch.var(low_specx, dim=1, keepdim=True))
        print("low_specx_mean", torch.mean(low_specx, dim=1, keepdim=True))

        low_specx[:,self.dominance_freq:] = 0 # LPF
        print("low_specx", low_specx.shape)
        low_specx = low_specx[:,0:self.dominance_freq,:] # LPF
        print("low_specx", low_specx.shape)

        # print(low_specx.permute(0,2,1))

        if self.individual:
        
            low_specxy_ = torch.zeros(
                [low_specx.size(0),
                 int(self.dominance_freq * self.length_ratio),
                 low_specx.size(2)],
                 dtype=low_specx.dtype).to(low_specx.device)
            
            for i in range(self.channels):
                low_specxy_[:,:,i] = self.freq_upsampler[i](
                    low_specx[:,:,i]
                    .permute(0,1)
                    ).permute(0,1)
        else:
            low_specxy_ = self.freq_upsampler(
                low_specx
                .permute(0,2,1)
                ).permute(0,2,1)
                
            print("low_specxy_", low_specxy_.shape)
            
        low_specxy = torch.zeros([
            low_specxy_.size(0),
            int((self.seq_len + self.pred_len) / 2 + 1 ),
            low_specxy_.size(2)],
            dtype=low_specxy_.dtype
            ).to(low_specxy_.device)
        
        print("low_specxy", low_specxy.shape)
        
        low_specxy[:, 0:low_specxy_.size(1), :] = low_specxy_ # zero padding

        print(low_specxy[:, low_specxy_.size(1):, :].shape)

        print("low_specxy", low_specxy.shape)

        low_xy = torch.fft.irfft(low_specxy, dim=1)

        print("low_xy", low_xy.shape)

        low_xy = low_xy * self.length_ratio # compensate the length change

        print("low_xy", low_xy.shape)
        
        xy = (low_xy) * torch.sqrt(x_var) + x_mean

        return xy, low_xy * torch.sqrt(x_var)


@dataclass
class Config:
    seq_len: int = 120
    pred_len: int = 60
    individual: bool = False
    enc_in: int = 1
    cut_freq: int = 20

model=Model(Config())
for param in model.parameters():
    param.data.fill_(0)

x = torch.rand(1, 120, 1)
y = model(x)
y[0].shape

length_ratio 1.5
x tensor[1, 120, 1] x∈[0.009, 0.994] μ=0.485 σ=0.287
x_mean tensor[1, 1, 1] [[[0.485]]]
x_var tensor[1, 1, 1] [[[0.083]]]
x tensor[1, 120, 1] x∈[-1.657, 1.771] μ=-1.192e-07 σ=1.000
low_specx torch.Size([1, 61, 1])
low_specx_var tensor[1, 1, 1] [[[119.863]]]
low_specx_mean tensor([[[-0.0314-0.9799j]]])
low_specx torch.Size([1, 61, 1])
low_specx torch.Size([1, 20, 1])
low_specxy_ torch.Size([1, 30, 1])
low_specxy torch.Size([1, 91, 1])
torch.Size([1, 61, 1])
low_specxy torch.Size([1, 91, 1])
low_xy torch.Size([1, 180, 1])
low_xy torch.Size([1, 180, 1])


torch.Size([1, 180, 1])

In [79]:
print(model)

Model(
  (freq_upsampler): Linear(in_features=20, out_features=30, bias=True)
)
