导入必要的包，设置数据集和输出路径和需要分析的数据集列表。

In [3]:
import os
import pickle as pkl
import numpy as np
import csv
from tqdm import tqdm
from statsmodels.tsa.seasonal import STL
from statsmodels.tsa.tsatools import detrend
from statsmodels.tsa.stattools import adfuller, kpss, acf
import matplotlib.pyplot as plt

root_pth = '/nas/datasets/Tensor-Time-Series-Dataset/Processed_Data'
output_csv_pth = '/home/wangzihan/workspace/TTS_results/csv'
output_img_pth = '/home/wangzihan/workspace/TTS_results/imgs'
dataset_list = ['JONAS_NYC_taxi', 'JONAS_NYC_bike', 'METRO_HZ', 'METRO_SH',
                'PEMS03', 'PEMS07', 
                'ETT_hour', 'electricity', 'weather',
                'nasdaq100'
                ]


参考ProbTS的思路，将MTS/Tensor数据分解为1D时间序列的集合，将1D时间序列加窗分段，利用FFT检测每一小段的周期并进行STL分解，计算trend和seasonality strength，将两者的均值和方差做为数据集的特征。

DataParserBase实现了加窗分段，stride参数表示相邻窗之间的间隔（stride默认为1时包含了所有可能的分段，但计算时间太长了）。

In [2]:
class DataParserBase:
    def __init__(self, dataset_name:str) -> None:
        # get data from pkl file 
        pkl_name = dataset_name + '.pkl'
        pkl_path = os.path.join(root_pth, dataset_name, pkl_name)
        if not os.path.exists(pkl_path):
            raise FileExistsError(f"Can not find file: {pkl_path}")
        with open(pkl_path, 'rb') as file:
            self.data_pkl = pkl.load(file)
            self.data = self.data_pkl['data']
        self.transform_data()
    
    # treat data as a collection of 1D time series tobtain statistics of channels
    def transform_data(self):
        temp = self.data.reshape(self.data.shape[0], -1, 1)
        self.data = np.squeeze(temp, axis=-1)
        self.data_shape =self.data.shape
        self.timesteps = self.data_shape[0]
        self.channel_nums = self.data_shape[1]
    
    def get_data_shape(self):
        return self.data_shape

STParser类实现了时间序列周期的检测以及trend和seasonality的计算，其中get_period()方法来自tasks/data_analysis，取FFT谱的最大值为season周期。
分解序列时，先检测序列season周期period，再以period为参数进行STL分解（seasonal smoother长度取最接近period的奇数），得到三个分量后计算seasonality和trend strength值。

In [4]:
class STParser(DataParserBase):
    def __init__(self, dataset_name:str, window_size, stride=0) -> None:
        super().__init__(dataset_name, window_size, stride)
    
    def get_period_acf(self, series, verbose=False):
        # This method generally follows the steps in the paper:
        # Characteristic-Based Clustering for Time Series Data (2006)
        # TODO: can't cope well with shorter sequences
        # 1. detrend data
        data_detrend = detrend(series,order=1)
        ser_len = data_detrend.shape[0]
        maxlag = max(int(ser_len/3))
        # 2. calculate autocorrelation function (maxlag up to 1/3 series length)
        acorr = acf(data_detrend, nlags=maxlag, fft=True)
        # 3. find peaks and trough in acf and determine period according to prominence
        from scipy.signal import find_peaks
        corr_peaks = find_peaks(acorr, height=0.1, distance=12)[0]
        no_season = False
        period = 1 # period default to 1 (no seasonality)
        for i in range(1, corr_peaks.shape[0]):
            seg = acorr[corr_peaks[i-1] : corr_peaks[i]]
            trough = np.min(seg)
            if (abs(trough-corr_peaks[i]) > 0.1):
                period = corr_peaks[i]
                break
            no_season = (i==(corr_peaks.shape[0]-1))
        return period if not no_season else 1
    
    def get_period(self, ts, k=1):
        ts = detrend(ts, order=1)
        # 计算FFT
        if np.nonzero(ts)[0].shape[0] == 0:
            return np.ones(k), np.ones(k)
        fft = np.abs(np.fft.fft(ts, axis=0))
        frequencies = np.fft.fftfreq(len(ts))

        # 找到最大的k个频率
        indices = np.argsort(np.abs(fft[1:-1]))[-k:]
        periods = (1 / (frequencies[indices]+1e-10) if frequencies[indices] != 0 else np.array([1]))
        strength = fft[indices]/np.sum(fft)
        return np.abs(periods.astype(int))

    def decompose(self, seg):
        trd = []
        ses = []
        for i in range(0, self.channel_nums):
            ser = seg[:, i]
            if np.all(ser - np.mean(ser) == 0):
                trd.append(0)
                ses.append(0)
            else:
                period = self.get_period(ser)[0]
                L = 2 * (period // 2) + 1
                if period != 1:
                    stl = STL(ser, period = int(period), seasonal = max(L, 7)).fit()
                    seasonal, trend, resid = stl.seasonal, stl.trend, stl.resid
                    # seasonality and trend strength
                    val_trd = 1- (np.var(resid)/np.var(trend+resid))
                    trd.append((val_trd if val_trd > 0 else 0))
                    val_ses = 1- (np.var(resid)/np.var(seasonal+resid))
                    ses.append((val_ses if val_ses > 0 else 0))
                else:
                    resid = detrend(ser, order=1)
                    trend = ser - resid
                    # seasonality and trend strength
                    val_trd = 1- (np.var(resid)/np.var(trend+resid))
                    trd.append((val_trd if val_trd > 0 else 0))
                    ses.append(0)
        return trd, ses
    
    def st_parse(self):
        all_trd = []
        all_ses = []
        print(f"Start parsing dataset: {self.dataset_name}")
        print(f"iter nums: {len(range(0, self.timesteps, self.stride))}")
        for elems in tqdm(self.loader):
            trd, ses = self.decompose(elems)
            all_trd = all_trd + trd
            all_ses = all_ses + ses
        all_trd = np.array(all_trd)
        all_ses = np.array(all_ses)
        sm = np.mean(all_ses)
        sv = np.var(all_ses)
        tm = np.mean(all_trd)
        tv = np.var(all_trd)
        return sm, sv, tm, tv

遍历数据集列表并输出。

In [4]:
class Aggregator:
    def __init__(self, dataset_list: list) -> None:
        self.dataset_list = dataset_list
    
    def st_to_csv(self, output_pth, window_size, stride=0, verbose=True, exp_id=0):
        with open(os.path.join(output_pth, 'dataset_stats_' + str(exp_id) + '.csv'), 'w') as file:
                writer = csv.writer(file)
                writer.writerow(['dataset_name', 'seasonality_mean', 'seasonality_var', 'trend_mean', 'trend_var'])
                for elems in self.dataset_list:
                    parser = STParser(elems, window_size, stride)
                    sm, sv, tm, tv = parser.st_parse()
                    if verbose:
                        print(f"seasonal mean: {sm}, seasonal var: {sv}")
                        print(f"trend mean: {tm}, trend var: {tv}") 
                    writer.writerow([f'{elems}', sm, sv, tm, tv])
                writer.writerow([f'window_size: {window_size}'])
                writer.writerow([f'stride: {stride}'])


if __name__ == '__main__':
    aggregator = Aggregator(dataset_list)
    aggregator.st_to_csv(output_csv_pth, window_size=336, stride=24, exp_id=4, verbose=True)
    print("Finished parsing all datasets")

Start parsing dataset: JONAS_NYC_taxi
iter nums: 200


200it [07:52,  2.36s/it]


seasonal mean: 0.36480094834370064, seasonal var: 0.028795999364248745
trend mean: 0.2678069947183953, trend var: 0.034050372499239236
Start parsing dataset: JONAS_NYC_bike
iter nums: 200


200it [06:22,  1.91s/it]


seasonal mean: 0.2664028903791252, seasonal var: 0.03387400936294714
trend mean: 0.14804500721195502, trend var: 0.019789749294984373
Start parsing dataset: METRO_HZ
iter nums: 77


77it [01:47,  1.39s/it]


seasonal mean: 0.30468968511042055, seasonal var: 0.03218503090800619
trend mean: 0.2424300875566895, trend var: 0.026542667808434027
Start parsing dataset: METRO_SH
iter nums: 280


280it [24:30,  5.25s/it]


seasonal mean: 0.2925391435840112, seasonal var: 0.0173356560822418
trend mean: 0.26250950378993154, trend var: 0.01665480604176324
Start parsing dataset: PEMS03
iter nums: 1092


1092it [24:29,  1.35s/it]


seasonal mean: 0.09108738143510978, seasonal var: 0.062394588874263086
trend mean: 0.2021890686861127, trend var: 0.08089491909702981
Start parsing dataset: PEMS07
iter nums: 1176


1176it [1:54:45,  5.86s/it]


seasonal mean: 0.17734584686491675, seasonal var: 0.11033743190138805
trend mean: 0.278869273251859, trend var: 0.12393006236097137
Start parsing dataset: ETT_hour
iter nums: 726


726it [00:41, 17.42it/s]


seasonal mean: 0.4015875245608988, seasonal var: 0.0635661592955661
trend mean: 0.391413566403501, trend var: 0.051206288890890864
Start parsing dataset: electricity
iter nums: 1096


1096it [40:10,  2.20s/it]


seasonal mean: 0.5281561285203609, seasonal var: 0.06787297025035276
trend mean: 0.250130570098088, trend var: 0.03469123818187717
Start parsing dataset: weather
iter nums: 2196


2196it [15:31,  2.36it/s]


seasonal mean: 0.5810504317295482, seasonal var: 0.15616022223444742
trend mean: 0.5899956520709586, trend var: 0.11314750541944754
Start parsing dataset: nasdaq100
iter nums: 3112


3112it [4:59:20,  5.77s/it]


seasonal mean: 0.32000442523891476, seasonal var: 0.147398157908286
trend mean: 0.5305769569932318, trend var: 0.11868567869620894
Finished parsing all datasets
