Commit 257ca2f: update data scaling

cookieminions committed Mar 1, 2021
1 parent b1db53f
Showing 2 changed files with 27 additions and 12 deletions.
data/data_loader.py (24 additions, 9 deletions)
@@ -41,7 +41,7 @@ def __init__(self, root_path, flag='train', size=None,
         self.__read_data__()

     def __read_data__(self):
-        scaler = StandardScaler()
+        self.scaler = StandardScaler()
         df_raw = pd.read_csv(os.path.join(self.root_path,
                                           self.data_path))

@@ -57,7 +57,9 @@ def __read_data__(self):
             df_data = df_raw[[self.target]]

         if self.scale:
-            data = scaler.fit_transform(df_data.values)
+            train_data = df_data[border1s[0]:border2s[0]]
+            self.scaler.fit(train_data.values)
+            data = self.scaler.transform(df_data.values)
         else:
             data = df_data.values

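The substantive change: the scaler becomes an instance attribute and is fitted on the training slice only (border1s[0]:border2s[0]) before transforming the full series, so validation and test statistics no longer leak into the normalization. Dataset_ETT_minute and Dataset_Custom below get the same treatment. A minimal standalone sketch of the pattern, with synthetic data and hypothetical split borders (the real loaders compute their own):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    values = np.random.randn(1000, 7)                     # stand-in for df_data.values
    border1s, border2s = [0, 600, 800], [600, 800, 1000]  # hypothetical train/val/test borders

    scaler = StandardScaler()
    scaler.fit(values[border1s[0]:border2s[0]])  # statistics come from the train split only
    data = scaler.transform(values)              # the same statistics scale every split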
@@ -70,7 +72,7 @@ def __read_data__(self):
             df_stamp['hour'] = df_stamp.date.apply(lambda row:row.hour,1)
             data_stamp = df_stamp.drop(['date'],1).values
         elif self.timeenc==1:
-            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq='h')
+            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
             data_stamp = data_stamp.transpose(1,0)

         self.data_x = data[border1:border2]
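With freq=self.freq, the timeenc==1 branch now honors the --freq argument instead of hardcoding hourly features; the minute-level loader below, which hardcoded 't', gets the same fix. A hedged sketch of the call shape using a synthetic date range (time_features is the repository's utility from utils/timefeatures.py; the dates here are illustrative):

    import pandas as pd
    from utils.timefeatures import time_features

    # A synthetic 15-minute range standing in for df_stamp['date'].values
    dates = pd.to_datetime(pd.date_range('2016-07-01', periods=96, freq='15min').values)
    data_stamp = time_features(dates, freq='t')  # 't' = minutely, matching --freq
    data_stamp = data_stamp.transpose(1, 0)      # to (timesteps, num_time_features)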
@@ -93,6 +95,9 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.data_x) - self.seq_len- self.pred_len + 1

+    def inverse_transform(self, data):
+        return self.scaler.inverse_transform(data)
+
 class Dataset_ETT_minute(Dataset):
     def __init__(self, root_path, flag='train', size=None,
                  features='S', data_path='ETTm1.csv',
@@ -123,7 +128,7 @@ def __init__(self, root_path, flag='train', size=None,
         self.__read_data__()

     def __read_data__(self):
-        scaler = StandardScaler()
+        self.scaler = StandardScaler()
         df_raw = pd.read_csv(os.path.join(self.root_path,
                                           self.data_path))

@@ -139,7 +144,9 @@ def __read_data__(self):
             df_data = df_raw[[self.target]]

         if self.scale:
-            data = scaler.fit_transform(df_data.values)
+            train_data = df_data[border1s[0]:border2s[0]]
+            self.scaler.fit(train_data.values)
+            data = self.scaler.transform(df_data.values)
         else:
             data = df_data.values

@@ -154,7 +161,7 @@ def __read_data__(self):
             df_stamp['minute'] = df_stamp.minute.map(lambda x:x//15)
             data_stamp = df_stamp.drop(['date'],1).values
         elif self.timeenc==1:
-            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq='t')
+            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
             data_stamp = data_stamp.transpose(1,0)

         self.data_x = data[border1:border2]
@@ -177,6 +184,9 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.data_x) - self.seq_len- self.pred_len + 1

+    def inverse_transform(self, data):
+        return self.scaler.inverse_transform(data)
+

 class Dataset_Custom(Dataset):
     def __init__(self, root_path, flag='train', size=None,
@@ -208,7 +218,7 @@ def __init__(self, root_path, flag='train', size=None,
         self.__read_data__()

     def __read_data__(self):
-        scaler = StandardScaler()
+        self.scaler = StandardScaler()
         df_raw = pd.read_csv(os.path.join(self.root_path,
                                           self.data_path))
         '''
@@ -232,7 +242,9 @@ def __read_data__(self):
             df_data = df_raw[[self.target]]

         if self.scale:
-            data = scaler.fit_transform(df_data.values)
+            train_data = df_data[border1s[0]:border2s[0]]
+            self.scaler.fit(train_data.values)
+            data = self.scaler.transform(df_data.values)
         else:
             data = df_data.values

@@ -266,4 +278,7 @@ def __getitem__(self, index):
         return seq_x, seq_y, seq_x_mark, seq_y_mark

     def __len__(self):
-        return len(self.data_x) - self.seq_len- self.pred_len + 1
\ No newline at end of file
+        return len(self.data_x) - self.seq_len- self.pred_len + 1
+
+    def inverse_transform(self, data):
+        return self.scaler.inverse_transform(data)
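Because the fitted scaler is now kept on the dataset, each of the three loaders exposes inverse_transform to map standardized values back to the original units, e.g. when reporting predictions in the raw scale. A hypothetical usage sketch (the construction arguments are illustrative, not part of the commit):

    from data.data_loader import Dataset_ETT_hour

    ds = Dataset_ETT_hour(root_path='./data/ETT/', flag='test',
                          features='S', target='OT')  # scale=True by default
    seq_x, seq_y, seq_x_mark, seq_y_mark = ds[0]
    seq_y_raw = ds.inverse_transform(seq_y)  # back to the target's original units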
main_informer.py (3 additions, 3 deletions)
@@ -13,7 +13,7 @@
 parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
 parser.add_argument('--features', type=str, default='M', help='forecasting task, options:[M, S, MS(TBD)]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
 parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
-parser.add_argument('--freq', type=str, default='h', help='freq for time features encoding')
+parser.add_argument('--freq', type=str, default='h', help='freq for time features encoding, options:[t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly]')

 parser.add_argument('--seq_len', type=int, default=96, help='input sequence length of Informer encoder')
 parser.add_argument('--label_len', type=int, default=48, help='start token length of Informer decoder')
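The expanded help string enumerates the offset codes the time-feature encoder accepts, matching the hardcoding removed from the loaders above. An illustrative invocation (not from the commit) pairing minute-level data with minutely encoding:

    python -u main_informer.py --model informer --data ETTm1 --freq t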
@@ -72,9 +72,9 @@

 for ii in range(args.itr):
     # setting record of experiments
-    setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_eb{}_dt{}_{}_{}'.format(args.model, args.data, args.features,
+    setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_{}_{}'.format(args.model, args.data, args.features,
                 args.seq_len, args.label_len, args.pred_len,
-                args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.embed, args.distil, args.des, ii)
+                args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor, args.embed, args.distil, args.des, ii)

     exp = Exp(args) # set experiments
     print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
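The setting string now records args.factor (the ProbSparse attention sampling factor) as fc{} between the attention type and the embedding type, so runs differing only in factor no longer share an experiment name. With the script's defaults it comes out roughly like this (illustrative):

    setting = 'informer_ETTh1_ftM_sl96_ll48_pl24_dm512_nh8_el2_dl1_df2048_atprob_fc5_ebtimeF_dtTrue_test_0'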