In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX
import matplotlib.pyplot as plt

In [2]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
plt.rcParams["font.family"] = "Times New Roman"
plt.style.use('seaborn-whitegrid')
plt.style.use('seaborn-poster')
plt.style.use('seaborn-dark-palette')
plt.rcParams["mathtext.fontset"] = "cm"

In [3]:
data = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
data_chamber = pd.read_csv('../dataset/insect/ant/time_series_chamber.csv')
#print(data)

In [4]:
data_x = data[['location_x', 'ant_id', 'colony_id']]
data_y = data[['location_y', 'ant_id', 'colony_id']]
#print(data_x, data_y)

In [5]:
data_x = pd.DataFrame(np.unique(data.time), columns = ['time'])
data_y = pd.DataFrame(np.unique(data.time), columns = ['time'])

k = 1
for i in tqdm(range(len(np.unique(data.colony_id)))):
    for j in tqdm(range(len(np.unique(data[data['colony_id'] == i+1]['ant_id'])))):
        tmp = data[(data['colony_id'] == i+1) & (data['ant_id']== j)]['location_x']
        tmp1 = data[(data['colony_id'] == i+1) & (data['ant_id']== j)]['location_y']
        tmp.reset_index(drop = True, inplace = True)
        tmp1.reset_index(drop = True, inplace = True)
        data_x = pd.concat([data_x, tmp], axis = 1)
        data_x.rename({'location_x':'%i'%k}, axis = 'columns', inplace = True)
        data_y = pd.concat([data_y, tmp1], axis = 1)
        data_y.rename({'location_y':'%i'%k}, axis = 'columns', inplace = True)
        #print(time_list)
        k += 1

data_x.drop(columns = ['time'], inplace = True)
data_y.drop(columns = ['time'], inplace = True)

  0%|          | 0/3 [00:00<?, ?it/s]
  0%|          | 0/73 [00:00<?, ?it/s][A
 15%|█▌        | 11/73 [00:00<00:00, 104.03it/s][A
 32%|███▏      | 23/73 [00:00<00:00, 112.84it/s][A
 48%|████▊     | 35/73 [00:00<00:00, 106.49it/s][A
 63%|██████▎   | 46/73 [00:00<00:00, 96.98it/s] [A
 77%|███████▋  | 56/73 [00:00<00:00, 86.54it/s][A
100%|██████████| 73/73 [00:00<00:00, 84.98it/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.11it/s]
  0%|          | 0/80 [00:00<?, ?it/s][A
  9%|▉         | 7/80 [00:00<00:01, 69.82it/s][A
 18%|█▊        | 14/80 [00:00<00:01, 63.04it/s][A
 26%|██▋       | 21/80 [00:00<00:00, 61.54it/s][A
 35%|███▌      | 28/80 [00:00<00:00, 57.98it/s][A
 42%|████▎     | 34/80 [00:00<00:00, 56.79it/s][A
 50%|█████     | 40/80 [00:00<00:00, 55.60it/s][A
 57%|█████▊    | 46/80 [00:00<00:00, 54.01it/s][A
 65%|██████▌   | 52/80 [00:00<00:00, 52.50it/s][A
 72%|███████▎  | 58/80 [00:01<00:00, 52.46it/s][A
 80%|████████  | 64/80 [00:01<00:00, 52.27it/s][A
 88%|████████

In [6]:
datax_corr = data_x.corr()
datay_corr = data_y.corr()
cpl_x = list()
cpl_y = list()

num_ant = 13

for j in range(len(datax_corr)):
    if (datax_corr.iloc[num_ant, j] > 0.5):
        cpl_x.append('%i'%j)
    if (datay_corr.iloc[num_ant, j] > 0.5):
        cpl_y.append('%i'%j)

ant_related = list()
for j in cpl_x:
    if (j in cpl_y):
        ant_related.append(j)
         
print(ant_related)

['13', '28', '225']


In [7]:
data_df = pd.DataFrame(data_chamber['%i'%num_ant])
data_df.rename({'%i'%num_ant:'chamber'}, axis = 'columns', inplace = True)
for j in ant_related:
    data_df = pd.concat([data_df, data_x[j]], axis = 1)
    data_df.rename({j:'location_x_%s'%j}, axis = 'columns', inplace = True)
    data_df = pd.concat([data_df, data_y[j]], axis = 1)
    data_df.rename({j:'location_y_%s'%j}, axis = 'columns', inplace = True)
print(data_df)

      chamber  location_x_13  location_y_13  location_x_28  location_y_28  \
0           4          38.89         162.52          16.97          28.78   
1           4          38.89         162.52          15.11          27.10   
2           4          38.89         162.52          22.93          30.13   
3           4          38.89         162.52          27.48          29.77   
4           4          38.89         162.52          27.31          25.67   
...       ...            ...            ...            ...            ...   
2875        5           0.00           0.00          35.23         151.88   
2876        5           0.00           0.00          35.23         151.88   
2877        5           0.00           0.00          35.23         151.88   
2878        5           0.00           0.00          33.77         153.15   
2879        5           0.00           0.00          34.45         153.61   

      location_x_225  location_y_225  
0              51.78           11.35

In [8]:
train = data_df[:int(0.5*(len(data_df)))]
valid = data_df[int(0.5*(len(data_df))):]

In [57]:
tmp = []
for p in tqdm(range(5)):
    for q in tqdm(range(5)):
        try:
            tmp.append([VARMAX(train, order = (p,q)).fit().bic, p, q])
        except:
            tmp.append([None, p, q])
tmp = pd.DataFrame(tmp,columns = ['bic', 'p', 'q'])
print(tmp)
tmp[tmp['bic'] == tmp['bic'].min()]

  0%|          | 0/5 [00:00<?, ?it/s]




100%|██████████| 5/5 [13:19<00:00, 159.95s/it][A
 20%|██        | 1/5 [13:19<53:19, 799.77s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [09:21<00:00, 112.34s/it][A
 40%|████      | 2/5 [22:41<32:59, 659.74s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [18:22<00:00, 220.58s/it][A
 60%|██████    | 3/5 [41:04<28:44, 862.11s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  w


  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [29:24<00:00, 352.91s/it][A
 80%|████████  | 4/5 [1:10:28<20:18, 1218.38s/it]

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [40:19<00:00, 483.91s/it][A
100%|██████████| 5/5 [1:50:48<00:00, 1329.70s/it]


             bic  p  q
0            NaN  0  0
1   67754.475850  0  1
2   77206.057555  0  2
3   73022.086630  0  3
4   79306.548262  0  4
5   36323.447395  1  0
6   35945.808452  1  1
7   35828.146075  1  2
8   36122.437493  1  3
9   36313.958707  1  4
10  35499.394520  2  0
11  35813.531964  2  1
12  35902.291604  2  2
13  36108.254198  2  3
14  36239.478776  2  4
15  35454.191329  3  0
16  35781.247359  3  1
17  36103.805069  3  2
18  36324.744831  3  3
19  36507.936685  3  4
20  35579.786935  4  0
21  35932.917200  4  1
22  36264.200021  4  2
23  36593.618029  4  3
24  36815.981334  4  4


Unnamed: 0,bic,p,q
15,35454.191329,3,0


In [None]:
prediction_x = list()
prediction_y = list()

for i in range(len(train)):
    prediction_x.append(train.iloc[i, 1])
    prediction_y.append(train.iloc[i, 2])

start_t = len(train)
for t_i in tqdm(range(len(valid))):
    current_t = t_i + start_t
    model = VARMAX(data_df[t_i:current_t], order = (3,0))
    fitted_model = model.fit()
    prediction = fitted_model.forecast().reset_index(drop=True) 
    #print(prediction)
    prediction_x.append(prediction['location_x_%i'%num_ant]) 
    prediction_y.append(prediction['location_y_%i'%num_ant]) 