In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX
import matplotlib.pyplot as plt

In [3]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
plt.rcParams["font.family"] = "Times New Roman"
plt.style.use('seaborn-whitegrid')
plt.style.use('seaborn-poster')
plt.style.use('seaborn-dark-palette')
plt.rcParams["mathtext.fontset"] = "cm"

In [4]:
data = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
data_chamber = pd.read_csv('../dataset/insect/ant/time_series_chamber.csv')
#print(data)

In [5]:
data_x = data[['location_x', 'ant_id', 'colony_id']]
data_y = data[['location_y', 'ant_id', 'colony_id']]
#print(data_x, data_y)

In [6]:
data_x = pd.DataFrame(np.unique(data.time), columns = ['time'])
data_y = pd.DataFrame(np.unique(data.time), columns = ['time'])

k = 1
for i in tqdm(range(len(np.unique(data.colony_id)))):
    for j in tqdm(range(len(np.unique(data[data['colony_id'] == i+1]['ant_id'])))):
        tmp = data[(data['colony_id'] == i+1) & (data['ant_id']== j)]['location_x']
        tmp1 = data[(data['colony_id'] == i+1) & (data['ant_id']== j)]['location_y']
        tmp.reset_index(drop = True, inplace = True)
        tmp1.reset_index(drop = True, inplace = True)
        data_x = pd.concat([data_x, tmp], axis = 1)
        data_x.rename({'location_x':'%i'%k}, axis = 'columns', inplace = True)
        data_y = pd.concat([data_y, tmp1], axis = 1)
        data_y.rename({'location_y':'%i'%k}, axis = 'columns', inplace = True)
        #print(time_list)
        k += 1

data_x.drop(columns = ['time'], inplace = True)
data_y.drop(columns = ['time'], inplace = True)

  0%|          | 0/3 [00:00<?, ?it/s]
  0%|          | 0/73 [00:00<?, ?it/s][A
 14%|█▎        | 10/73 [00:00<00:00, 93.28it/s][A
 29%|██▉       | 21/73 [00:00<00:00, 100.68it/s][A
 44%|████▍     | 32/73 [00:00<00:00, 98.43it/s] [A
 58%|█████▊    | 42/73 [00:00<00:00, 90.33it/s][A
 71%|███████   | 52/73 [00:00<00:00, 83.91it/s][A
 84%|████████▎ | 61/73 [00:00<00:00, 82.27it/s][A
100%|██████████| 73/73 [00:00<00:00, 83.53it/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.08it/s]
  0%|          | 0/80 [00:00<?, ?it/s][A
  8%|▊         | 6/80 [00:00<00:01, 58.70it/s][A
 15%|█▌        | 12/80 [00:00<00:01, 55.63it/s][A
 22%|██▎       | 18/80 [00:00<00:01, 53.41it/s][A
 30%|███       | 24/80 [00:00<00:01, 53.01it/s][A
 38%|███▊      | 30/80 [00:00<00:00, 53.56it/s][A
 45%|████▌     | 36/80 [00:00<00:00, 53.22it/s][A
 52%|█████▎    | 42/80 [00:00<00:00, 53.48it/s][A
 60%|██████    | 48/80 [00:00<00:00, 52.80it/s][A
 68%|██████▊   | 54/80 [00:01<00:00, 49.89it/s][A
 75%|███████▌  

In [7]:
datax_corr = data_x.corr()
datay_corr = data_y.corr()
cpl_x = list()
cpl_y = list()

num_ant = 13

for j in range(len(datax_corr)):
    if (datax_corr.iloc[num_ant, j] > 0.6):
        cpl_x.append('%i'%j)
    if (datay_corr.iloc[num_ant, j] > 0.6):
        cpl_y.append('%i'%j)

ant_related = list()
for j in cpl_x:
    if (j in cpl_y):
        ant_related.append(j)
         
print(ant_related)

['13', '28']


In [8]:
data_df = pd.DataFrame(data_chamber['%i'%num_ant])
data_df.rename({'%i'%num_ant:'chamber'}, axis = 'columns', inplace = True)
for j in ant_related:
    data_df = pd.concat([data_df, data_x[j]], axis = 1)
    data_df.rename({j:'location_x_%s'%j}, axis = 'columns', inplace = True)
    data_df = pd.concat([data_df, data_y[j]], axis = 1)
    data_df.rename({j:'location_y_%s'%j}, axis = 'columns', inplace = True)
print(data_df)

      chamber  location_x_13  location_y_13  location_x_28  location_y_28
0           4          38.89         162.52          16.97          28.78
1           4          38.89         162.52          15.11          27.10
2           4          38.89         162.52          22.93          30.13
3           4          38.89         162.52          27.48          29.77
4           4          38.89         162.52          27.31          25.67
...       ...            ...            ...            ...            ...
2875        5           0.00           0.00          35.23         151.88
2876        5           0.00           0.00          35.23         151.88
2877        5           0.00           0.00          35.23         151.88
2878        5           0.00           0.00          33.77         153.15
2879        5           0.00           0.00          34.45         153.61

[2880 rows x 5 columns]


In [9]:
train = data_df[:int(0.5*(len(data_df)))]
valid = data_df[int(0.5*(len(data_df))):]

In [11]:
tmp = []
for p in tqdm(range(5)):
    for q in tqdm(range(5)):
        try:
            tmp.append([VARMAX(train, order = (p,q)).fit().bic, p, q])
        except:
            tmp.append([None, p, q])
tmp = pd.DataFrame(tmp,columns = ['bic', 'p', 'q'])
print(tmp)
tmp[tmp['bic'] == tmp['bic'].min()]

  0%|          | 0/5 [00:00<?, ?it/s]




100%|██████████| 5/5 [03:39<00:00, 43.91s/it][A
 20%|██        | 1/5 [03:39<14:38, 219.58s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [04:25<00:00, 53.10s/it][A
 40%|████      | 2/5 [08:05<12:19, 246.60s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [08:52<00:00, 106.52s/it][A
 60%|██████    | 3/5 [16:57<12:34, 377.19s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  war


  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [11:25<00:00, 137.10s/it][A
 80%|████████  | 4/5 [28:23<08:18, 498.91s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [44:29<00:00, 533.84s/it][A
100%|██████████| 5/5 [1:12:52<00:00, 874.48s/it] 


             bic  p  q
0            NaN  0  0
1   50546.388364  0  1
2   58453.674686  0  2
3   54846.086287  0  3
4   55207.972408  0  4
5   29413.179349  1  0
6   28837.504166  1  1
7   28601.630616  1  2
8   28754.428700  1  3
9   28802.109197  1  4
10  28415.277955  2  0
11  28557.368454  2  1
12  28505.156860  2  2
13  28569.033622  2  3
14  28549.566974  2  4
15  28233.808396  3  0
16  28382.877919  3  1
17  28533.919607  3  2
18  28602.208580  3  3
19  28635.118986  3  4
20  28221.163784  4  0
21  28399.862440  4  1
22  28547.156117  4  2
23  28698.133584  4  3
24  28768.649490  4  4


Unnamed: 0,bic,p,q
20,28221.163784,4,0


In [13]:
print(tmp)

             bic  p  q
0            NaN  0  0
1   50546.388364  0  1
2   58453.674686  0  2
3   54846.086287  0  3
4   55207.972408  0  4
5   29413.179349  1  0
6   28837.504166  1  1
7   28601.630616  1  2
8   28754.428700  1  3
9   28802.109197  1  4
10  28415.277955  2  0
11  28557.368454  2  1
12  28505.156860  2  2
13  28569.033622  2  3
14  28549.566974  2  4
15  28233.808396  3  0
16  28382.877919  3  1
17  28533.919607  3  2
18  28602.208580  3  3
19  28635.118986  3  4
20  28221.163784  4  0
21  28399.862440  4  1
22  28547.156117  4  2
23  28698.133584  4  3
24  28768.649490  4  4


In [None]:
prediction_x = list()
prediction_y = list()

for i in range(len(train)):
    prediction_x.append(train.iloc[i, 1])
    prediction_y.append(train.iloc[i, 2])

start_t = len(train)
for t_i in tqdm(range(len(valid))):
    current_t = t_i + start_t
    model = VARMAX(data_df[t_i:current_t], order = (3,0))
    fitted_model = model.fit()
    prediction = fitted_model.forecast().reset_index(drop=True) 
    #print(prediction)
    prediction_x.append(prediction['location_x_%i'%num_ant]) 
    prediction_y.append(prediction['location_y_%i'%num_ant]) 