-
Notifications
You must be signed in to change notification settings - Fork 0
/
dba.py
60 lines (49 loc) · 2.03 KB
/
dba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
"""
Created on Sat May 28 08:04:58 2022
@author: yin
"""
import time
import datetime
import sklearn
from data_utils import get_data
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from tslearn.clustering import TimeSeriesKMeans
from sklearn.cluster import KMeans
from tslearn.datasets import CachedDatasets
from tslearn.preprocessing import TimeSeriesScalerMeanVariance, TimeSeriesResampler
from utils import cluster_acc
if __name__ == '__main__':
    # Cluster the samples stored in 'earthb.mat' with DBA k-means (DTW metric)
    # and report clustering accuracy on train+valid and on the held-out test
    # split, together with the total wall-clock time.
    train_data = loadmat('earthb.mat')
    all_data = train_data['images']
    all_target = train_data['labels']

    # `shape` selects one preset: a (start, end) selection window and the
    # matching target shape passed to get_data.
    # NOTE(review): the exact meaning of `select`/`shape` is defined by
    # data_utils.get_data, which is not visible here -- confirm there.
    shape = 3
    select_maps = {0: None, 1: [0, 2000], 2: [300, 1300], 3: [500, 1000]}
    shape_maps = {0: (-1, 1, 6000), 1: (-1, 1, 2000), 2: (-1, 1, 1000), 3: (-1, 1, 500)}

    # load data
    data = get_data(all_data, all_target, dataset='eq', seed=1,
                    shape=shape_maps[shape], select=select_maps[shape], size=0.95)
    x_train, x_valid, x_test, y_train, y_valid, y_test, splits, splits_test = data

    # set contains training and validation
    X = np.concatenate([x_train, x_valid])
    y = np.concatenate([y_train, y_valid])

    start = time.time()

    # Flatten every sample to one row: (n_samples, n_features).
    X_feat = X.reshape(X.shape[0], -1)
    print(X_feat.shape)

    print("DBA k-means")
    dba_km = TimeSeriesKMeans(n_clusters=3,
                              n_init=2,
                              metric="dtw",
                              n_jobs=-1,
                              verbose=True,
                              max_iter_barycenter=10,
                              random_state=1)
    y_pred = dba_km.fit_predict(X_feat)

    print('train accuracy')
    # cluster_acc is unpacked into two values here; presumably
    # (accuracy, remapped predictions) -- verify against utils.cluster_acc.
    train_valid_acc, y_pred = cluster_acc(y, y_pred)

    # Evaluate the model fitted above on the held-out data. Use predict()
    # rather than fit_predict() so the model is NOT re-fitted on the test set,
    # and flatten with the test set's own sample count (not X's).
    x_test_feat = x_test.reshape(x_test.shape[0], -1)
    y_pred = dba_km.predict(x_test_feat)

    print('test accuracy')
    test_acc, y_pred = cluster_acc(y_test, y_pred)

    end = time.time()
    print("total time (DBA k-means training + evaluate cluster accuracy) takes %d seconds, %s" % (end - start, str(datetime.timedelta(seconds=end - start))))