In [12]:
from utils.viz import plot_time_series, plot_balance, plot_balance_vs_price
import matplotlib.pyplot as plt
import numpy as np

In [13]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
plt.rcParams["font.family"] = "Times New Roman"
plt.style.use('seaborn-whitegrid')
plt.style.use('seaborn-poster')
plt.style.use('seaborn-dark-palette')
plt.rcParams["mathtext.fontset"] = "cm"

In [14]:
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX

In [15]:
# Step 1: Data Loading
data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
colony_id = 2
ant_id = 3
colony = len(np.unique(data_df[data_df['colony_id'] == colony_id-1][['ant_id']])) + 1
data_df = data_df[(data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)]

ant_num = colony + ant_id
data_df.reset_index(inplace = True)
data = data_df[['location_x', 'location_y']]

In [5]:
# Step 2: Data Preparation
#creating the train and validation set
train = data[:int(0.5*(len(data)))]
valid = data[int(0.5*(len(data))):]
print(train)

      location_x  location_y
0          11.66      151.22
1          11.66      151.22
2          11.66      151.22
3          11.66      151.22
4          11.66      151.22
...          ...         ...
1435        1.61      141.82
1436        1.61      141.82
1437        1.61      141.82
1438        1.61      141.82
1439        1.61      141.82

[1440 rows x 2 columns]


In [6]:
tmp = []
for p in tqdm(range(5)):
    for q in tqdm(range(5)):
        try:
            tmp.append([VARMAX(train, order = (p,q)).fit().bic, p, q])
        except:
            tmp.append([None, p, q])
tmp = pd.DataFrame(tmp,columns = ['bic', 'p', 'q'])
print(tmp)
order = tmp[tmp['bic'] == tmp['bic'].min()]
print(order)

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A



100%|██████████| 5/5 [00:17<00:00,  3.50s/it][A
 20%|██        | 1/5 [00:17<01:09, 17.50s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:20<00:00,  4.09s/it][A
 40%|████      | 2/5 [00:37<00:57, 19.23s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:29<00:00,  5.95s/it][A
 60%|██████    | 3/5 [01:07<00:48, 24.05s/it]


100%|██████████| 5/5 [00:38<00:00,  7.79s/it][A
 80%|████████  | 4/5 [01:46<00:29, 29.93s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:51<00:00, 10.25s/it][A
100%|██████████| 5/5 [02:37<00:00, 31.59s/it]

             bic  p  q
0            NaN  0  0
1   13787.492961  0  1
2   13434.761233  0  2
3   12962.612854  0  3
4   12710.460043  0  4
5    5682.065804  1  0
6    5672.789545  1  1
7    5651.262373  1  2
8    5626.363617  1  3
9    5652.842797  1  4
10   5672.366066  2  0
11   5676.803615  2  1
12   5671.462012  2  2
13   5650.297533  2  3
14   5663.330354  2  4
15   5667.904472  3  0
16   5668.696436  3  1
17   5688.471917  3  2
18   5676.971169  3  3
19   5694.569158  3  4
20   5623.917410  4  0
21   5647.057444  4  1
22   5672.666719  4  2
23   5616.073371  4  3
24   5678.669090  4  4
            bic  p  q
23  5616.073371  4  3





In [None]:
prediction_x = list()
prediction_y = list()

for i in range(len(train)):
    prediction_x.append(train.iloc[i, 0])
    prediction_y.append(train.iloc[i, 1])

start_t = len(train)
for t_i in tqdm(range(len(valid))):
    current_t = t_i + start_t
    model = VARMAX(data[t_i:current_t], order = (order['p'], order['q']))
    fitted_model = model.fit()
    prediction = fitted_model.forecast().reset_index(drop=True) 
    prediction_x.append(prediction['location_x']) 
    prediction_y.append(prediction['location_y']) 
    


  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not g

  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'
  warn('Estimation of VARMA(p,q) models is not generically robust,'


In [None]:
#plot_time_series(ts_1 = prediction_chamber, ts_label_1 = 'VARMA Model', ts_2 = valid['chamber'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth')
plot_time_series(ts_1 = prediction_x, ts_label_1 = 'VARMA Model', ts_2 = data_df['location_x'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location x')
plot_time_series(ts_1 = prediction_y, ts_label_1 = 'VARMA Model', ts_2 = data_df['location_y'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location y')


In [None]:
def num_chamber(list_x, list_y):
    list_c = list()
    for i in range(len(list_x)):
        x = int(list_x[i])
        y = int(list_y[i])
        if (y > 178):
            list_c.append(5)
        else:
            if ((y <= 0)&(x <= 0)):
                list_c.append(5)
            else:
                if (y <= 46):
                    list_c.append(1)
                else:
                    if (y <= 92):
                        list_c.append(2)
                    else:
                        if (y <= 138):
                            list_c.append(3)
                        else:
                            list_c.append(4)
    return list_c  

In [None]:
prediction_c = num_chamber(prediction_x, prediction_y)
prediction_c = pd.DataFrame(prediction_c, columns = ['%i'%num_ant])
prediction_c.to_csv('../dataset/insect/ant/prediction_%i.csv'%num_ant, index = False)