In [12]:
from utils.viz import plot_time_series, plot_balance, plot_balance_vs_price
import matplotlib.pyplot as plt
import numpy as np

In [13]:
# Print NumPy floats with three decimals.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

# Apply the style sheets first: a style sheet may define font settings itself,
# which would silently overwrite an rcParams override made before it.
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'; use
# the new name when it is available and fall back to the legacy one otherwise.
for _style_suffix in ('whitegrid', 'poster', 'dark-palette'):
    _renamed_style = 'seaborn-v0_8-' + _style_suffix
    if _renamed_style in plt.style.available:
        plt.style.use(_renamed_style)
    else:
        plt.style.use('seaborn-' + _style_suffix)

# Font overrides go last so they survive the style sheets above.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["mathtext.fontset"] = "cm"

In [14]:
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX

In [15]:
# Step 1: Data Loading
# Select the trajectory of a single ant (identified by colony and ant id).
colony_id = 1
ant_id = 66
# Identifier used later to name the prediction column and output file.
# NOTE(review): a plain sum is not unique across (colony_id, ant_id) pairs.
ant_num = colony_id + ant_id

data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
is_selected_ant = (data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)
# reset_index() keeps the old row labels in an 'index' column, as before.
data_df = data_df[is_selected_ant].reset_index()

# Only the two coordinate columns are modelled.
data = data_df[['location_x', 'location_y']]

In [16]:
# Step 2: Data Preparation
# Chronological 50/50 split: first half for training, second half for validation.
split_idx = int(0.5 * len(data))
train = data.iloc[:split_idx]
valid = data.iloc[split_idx:]
print(train)

      location_x  location_y
0          59.34       29.65
1          59.34       29.65
2          59.34       29.65
3          59.34       29.65
4          59.34       29.65
...          ...         ...
1435       60.01       29.13
1436       60.01       29.13
1437       60.01       29.13
1438       60.01       29.13
1439       60.01       29.13

[1440 rows x 2 columns]


In [17]:
# Grid-search the VARMA(p, q) order over p, q in 0..4 on the training split,
# scoring every fit by its BIC (lower is better).
tmp = []
for p in tqdm(range(5), desc='p'):
    for q in tqdm(range(5), desc='q', leave=False):
        try:
            bic = VARMAX(train, order = (p,q)).fit().bic
        except Exception as err:
            # Best-effort search: a failing order is recorded with a missing
            # BIC instead of aborting the whole grid (in the recorded run
            # order (0, 0) produced no BIC). Catching Exception rather than
            # using a bare except keeps Ctrl-C / kernel interrupt working.
            tqdm.write('VARMAX order (%d, %d) failed: %s' % (p, q, err))
            bic = None
        tmp.append([bic, p, q])
tmp = pd.DataFrame(tmp,columns = ['bic', 'p', 'q'])
print(tmp)
# Row(s) with the smallest BIC; `order` stays a DataFrame with columns bic/p/q.
order = tmp[tmp['bic'] == tmp['bic'].min()]
print(order)

  0%|          | 0/5 [00:00<?, ?it/s]




100%|██████████| 5/5 [00:22<00:00,  4.48s/it][A
 20%|██        | 1/5 [00:22<01:29, 22.40s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [01:10<00:00, 14.11s/it][A
 40%|████      | 2/5 [01:32<02:32, 50.72s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [02:07<00:00, 25.50s/it][A
 60%|██████    | 3/5 [03:40<02:51, 85.77s/it]
  0%|          | 0/5 [00:00<?, ?it/s][


100%|██████████| 5/5 [00:37<00:00,  7.51s/it][A
100%|██████████| 5/5 [04:41<00:00, 56.28s/it]

             bic  p  q
0            NaN  0  0
1   -4157.465754  0  1
2   -4328.984341  0  2
3   -4909.872417  0  3
4   -4990.670166  0  4
5  -13196.842015  1  0
6  -13348.798697  1  1
7  -13330.875874  1  2
8  -13550.157106  1  3
9  -13556.150665  1  4
10 -13321.791244  2  0
11 -13302.868999  2  1
12 -13418.576213  2  2
13 -13534.932642  2  3
14 -13519.060054  2  4
15 -13582.178134  3  0
16 -13568.874339  3  1
17 -13541.903650  3  2
18 -13596.987530  3  3
19 -13633.003202  3  4
20 -13683.357502  4  0
21 -13656.033069  4  1
22 -13627.741680  4  2
23 -13609.479122  4  3
24 -13627.799618  4  4
             bic  p  q
20 -13683.357502  4  0





In [None]:
# Rolling one-step-ahead forecast over the validation period.
# The prediction lists start with the observed training values so that they
# line up index-for-index with `data`; forecasts are appended after them.

# `order` is the one-row DataFrame selected by the BIC search; VARMAX needs
# plain integers, not pandas Series, for its (p, q) order.
best_p = int(order['p'].iloc[0])
best_q = int(order['q'].iloc[0])

prediction_x = train['location_x'].tolist()
prediction_y = train['location_y'].tolist()

start_t = len(train)
for t_i in tqdm(range(len(valid))):
    current_t = t_i + start_t
    # Sliding window of len(train) observations ending just before current_t.
    model = VARMAX(data[t_i:current_t], order = (best_p, best_q))
    fitted_model = model.fit()
    # forecast() returns a one-row DataFrame; store scalars, not Series, so the
    # lists stay homogeneous for plotting and chamber classification.
    prediction = fitted_model.forecast().reset_index(drop=True)
    prediction_x.append(float(prediction['location_x'].iloc[0]))
    prediction_y.append(float(prediction['location_y'].iloc[0]))




  warn('Non-stationary starting autoregressive parameters'




























 33%|███▎      | 477/1440 [40:44<2:32:51,  9.52s/it]

In [None]:
# Overlay the VARMA series (training values followed by forecasts) on the
# observed coordinates, one figure per axis.
plot_time_series(ts_1 = prediction_x, ts_label_1 = 'VARMA Model', ts_2 = data['location_x'], ts_label_2 = 'Ground truth', title = 'VARMA predictions vs. ground truth of location x')
plot_time_series(ts_1 = prediction_y, ts_label_1 = 'VARMA Model', ts_2 = data['location_y'], ts_label_2 = 'Ground truth', title = 'VARMA predictions vs. ground truth of location y')


In [None]:
def num_chamber(list_x, list_y):
    """Map each (x, y) coordinate pair to a chamber number from 1 to 5.

    Both coordinates are truncated with ``int()`` before classification.
    The rules are applied in this order:

    * ``y > 178``                 -> 5
    * ``y <= 0`` and ``x <= 0``   -> 5
    * ``y <= 46``                 -> 1
    * ``y <= 92``                 -> 2
    * ``y <= 138``                -> 3
    * otherwise                   -> 4

    Args:
        list_x: indexable sequence of x coordinates.
        list_y: indexable sequence of y coordinates; it is read at the same
            positions ``0 .. len(list_x) - 1`` as ``list_x``.

    Returns:
        A list with one chamber number per entry of ``list_x``.
    """
    def _chamber_of(x, y):
        # Guard clauses, checked in the same order as the rules above.
        if y > 178:
            return 5
        if y <= 0 and x <= 0:
            return 5
        if y <= 46:
            return 1
        if y <= 92:
            return 2
        if y <= 138:
            return 3
        return 4

    return [
        _chamber_of(int(list_x[idx]), int(list_y[idx]))
        for idx in range(len(list_x))
    ]

In [None]:
# Convert the predicted coordinates to chamber numbers and write them to a
# one-column CSV whose column name and file name are derived from ant_num.
chamber_labels = num_chamber(prediction_x, prediction_y)
column_name = '%i'%ant_num
output_path = '../dataset/insect/ant/prediction/prediction_%i.csv'%ant_num

prediction_c = pd.DataFrame(chamber_labels, columns = [column_name])
prediction_c.to_csv(output_path, index = False)