In [61]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX

In [62]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
plt.rcParams["font.family"] = "Times New Roman"
plt.style.use('seaborn-whitegrid')
plt.style.use('seaborn-poster')
plt.style.use('seaborn-dark-palette')
plt.rcParams["mathtext.fontset"] = "cm"

In [63]:
def num_chamber(list_x, list_y):
    list_c = list()
    for i in range(len(list_x)):
        x = int(list_x[i])
        y = int(list_y[i])
        if (y > 178):
            list_c.append(5)
        else:
            if ((y <= 0)&(x <= 0)):
                list_c.append(5)
            else:
                if (y <= 46):
                    list_c.append(1)
                else:
                    if (y <= 92):
                        list_c.append(2)
                    else:
                        if (y <= 138):
                            list_c.append(3)
                        else:
                            list_c.append(4)
    return list_c  

In [64]:
def plot_time_series(ts_1, ts_label_1, ts_2, ts_label_2, title, path):
    assert len(ts_1) == len(ts_2)
    xs = list(range(0, len(ts_1)))

    plt.plot(xs, ts_1, c='green', label=ts_label_1)
    plt.plot(xs, ts_2, c='red', label=ts_label_2)

    plt.title(title)
    plt.legend(loc='upper left')
    plt.savefig(path)
    plt.show()

In [None]:
# Step 1: Data Loading
data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
colony_id = 3
for ant_id in range(0, 83):
    #ant_id = 68
    colony = 154
    data_df = data_df[(data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)]

    num_ant = colony + ant_id
    data_df.reset_index(inplace = True)
    data = data_df[['location_x', 'location_y']]
    print(num_ant)

    # Step 2: Data Preparation
    #creating the train and validation set
    train = data[:int(0.5*(len(data)))]
    valid = data[int(0.5*(len(data))):]

    tmp = []
    for p in tqdm(range(5)):
        for q in tqdm(range(5)):
            try:
                tmp.append([VARMAX(train, order = (p,q)).fit().bic, p, q])
            except:
                tmp.append([None, p, q])
    tmp = pd.DataFrame(tmp,columns = ['bic', 'p', 'q'])
    print(tmp)
    order = tmp[tmp['bic'] == tmp['bic'].min()]
    print(order)

    prediction_x = list()
    prediction_y = list()

    for i in range(len(train)):
        prediction_x.append(train.iloc[i, 0])
        prediction_y.append(train.iloc[i, 1])

    start_t = len(train)
    for t_i in tqdm(range(len(valid))):
        current_t = t_i + start_t
        model = VARMAX(data[t_i:current_t], order = (order['p'], order['q']))
        fitted_model = model.fit()
        prediction = fitted_model.forecast().reset_index(drop=True) 
        prediction_x.append(prediction['location_x']) 
        prediction_y.append(prediction['location_y']) 

    #plot_time_series(ts_1 = prediction_chamber, ts_label_1 = 'VARMA Model', ts_2 = valid['chamber'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth')
    plot_time_series(ts_1 = prediction_x, ts_label_1 = 'VARMA Model', ts_2 = data_df['location_x'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location x', path = '../figures/insect/%i/x.png'%num_ant)
    plot_time_series(ts_1 = prediction_y, ts_label_1 = 'VARMA Model', ts_2 = data_df['location_y'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location y', path = '../figures/insect/%i/y.png'%num_ant)

    prediction_c = num_chamber(prediction_x, prediction_y)
    prediction_c = pd.DataFrame(prediction_c, columns = ['%i'%num_ant])
    prediction_c.to_csv('../dataset/insect/ant/prediction/prediction_%i.csv'%num_ant, index = False)


  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A

142






100%|██████████| 5/5 [00:15<00:00,  3.16s/it][A
 20%|██        | 1/5 [00:15<01:03, 15.82s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:18<00:00,  3.65s/it][A
 40%|████      | 2/5 [00:34<00:51, 17.26s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:31<00:00,  6.34s/it][A
 60%|██████    | 3/5 [01:05<00:47, 23.87s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) mod


100%|██████████| 5/5 [00:35<00:00,  7.01s/it][A
100%|██████████| 5/5 [02:28<00:00, 29.66s/it]
  0%|          | 0/1440 [00:00<?, ?it/s]

             bic  p  q
0            NaN  0  0
1   12257.377958  0  1
2   11667.506741  0  2
3   11182.239745  0  3
4   10708.676921  0  4
5    6235.815167  1  0
6    6249.820019  1  1
7    6274.273854  1  2
8    6284.064737  1  3
9    6291.595977  1  4
10   6249.207490  2  0
11   6278.291845  2  1
12   6301.642092  2  2
13   6309.624081  2  3
14   6317.602167  2  4
15   6273.735917  3  0
16   6302.774187  3  1
17   6329.403720  3  2
18   6327.901380  3  3
19   6338.927332  3  4
20   6279.659679  4  0
21   6308.672051  4  1
22   6332.812525  4  2
23   6354.526718  4  3
24   6363.771067  4  4
           bic  p  q
5  6235.815167  1  0


 16%|█▌        | 233/1440 [02:50<09:55,  2.03it/s]