In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX

In [2]:
# NumPy: print floats with three decimals.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

# matplotlib 3.6 renamed the bundled seaborn sheets to 'seaborn-v0_8-*' (the old
# 'seaborn-*' names were deprecated then and are unavailable in newer releases);
# use whichever name this install provides.
for _sheet in ('whitegrid', 'poster', 'dark-palette'):
    _candidates = ['seaborn-v0_8-' + _sheet, 'seaborn-' + _sheet]
    _available = [name for name in _candidates if name in plt.style.available]
    # If neither name is listed, fall back to the legacy one so the error stays visible.
    plt.style.use(_available[0] if _available else _candidates[1])

# Font settings must come AFTER the style sheets: the seaborn sheets set
# font.family themselves and would otherwise override Times New Roman.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["mathtext.fontset"] = "cm"

In [3]:
def num_chamber(list_x, list_y):
    """Assign a chamber number to every (x, y) position.

    Both coordinates are truncated with int() before classification:
      * y > 178, or x <= 0 and y <= 0  -> 5
      * y <= 46                         -> 1
      * y <= 92                         -> 2
      * y <= 138                        -> 3
      * anything else                   -> 4

    The loop runs over the positions of list_x; list_y is indexed at the same
    positions, so any extra trailing y values are ignored.

    Returns:
        list[int]: one chamber number per element of list_x.
    """
    # (inclusive upper bound on y, chamber number), checked in order.
    y_bands = ((46, 1), (92, 2), (138, 3))
    chambers = []
    for idx in range(len(list_x)):
        x, y = int(list_x[idx]), int(list_y[idx])
        if y > 178 or (y <= 0 and x <= 0):
            label = 5
        else:
            label = next((band for limit, band in y_bands if y <= limit), 4)
        chambers.append(label)
    return chambers

In [4]:
def plot_time_series(ts_1, ts_label_1, ts_2, ts_label_2, title, path):
    """Overlay two equal-length series on a new figure, save it, and show it.

    Parameters
    ----------
    ts_1, ts_2 : sequence of numbers
        Series to compare. Both are drawn against their position 0..n-1, so
        any index they carry (e.g. a pandas index) is ignored.
    ts_label_1, ts_label_2 : str
        Legend labels for ts_1 (green) and ts_2 (red).
    title : str
        Figure title.
    path : str or path-like
        Image file to write; missing parent directories are created.

    Raises
    ------
    AssertionError
        If the two series differ in length.
    """
    assert len(ts_1) == len(ts_2), (
        'series lengths differ: %i vs %i' % (len(ts_1), len(ts_2)))
    xs = list(range(len(ts_1)))

    # Explicit figure: never draws onto a stale figure left open by earlier cells.
    fig, ax = plt.subplots()
    ax.plot(xs, ts_1, c='green', label=ts_label_1)
    ax.plot(xs, ts_2, c='red', label=ts_label_2)
    ax.set_title(title)
    ax.legend(loc='upper left')

    # Per-run output folders may not exist yet; savefig would raise otherwise.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)
    plt.show()
    plt.close(fig)  # free the figure when called repeatedly

In [None]:
# Step 1: Data Loading
# Select one ant's (location_x, location_y) trajectory from one colony.
data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
colony_id = 1   # matched against the 'colony_id' column
ant_id = 70     # matched against the 'ant_id' column
colony = 1      # only used to build the output ID below
# Output ID used in figure and CSV file names.
# NOTE(review): colony + ant_id can collide between colonies — confirm this ID scheme.
num_ant = colony + ant_id

MAX_ORDER = 5          # VARMA orders p and q are searched over range(MAX_ORDER)
TRAIN_FRACTION = 0.5   # leading share of the trajectory used for order selection

is_selected = (data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)
# Build a new frame rather than mutating a filtered slice of data_df in place.
data = data_df.loc[is_selected, ['location_x', 'location_y']].reset_index(drop=True)
print(num_ant)

# Step 2: Data Preparation
# Chronological split: the first part selects the order, the rest is forecast.
split = int(TRAIN_FRACTION * len(data))
train = data.iloc[:split]
valid = data.iloc[split:]

try:
    if data.empty:
        raise ValueError('no rows for colony_id=%i, ant_id=%i' % (colony_id, ant_id))

    # Step 3: Order selection — fit VARMA(p, q) on the training part and keep the
    # lowest BIC. Orders whose fit raises get NaN so they never win min().
    # `except Exception` (not bare except) so KeyboardInterrupt still stops the search.
    candidate_orders = [(p, q) for p in range(MAX_ORDER) for q in range(MAX_ORDER)]
    tmp = []
    for p, q in tqdm(candidate_orders, desc='order search'):
        try:
            bic = VARMAX(train, order=(p, q)).fit(disp=False).bic
        except Exception:
            bic = np.nan
        tmp.append([bic, p, q])
    tmp = pd.DataFrame(tmp, columns=['bic', 'p', 'q'])
    print(tmp)
    if tmp['bic'].isna().all():
        raise ValueError('no VARMAX(p, q) order could be fitted on the training data')
    order = tmp[tmp['bic'] == tmp['bic'].min()]
    print(order)
    # `order` can hold several rows on a BIC tie; VARMAX needs plain ints.
    best_p = int(order['p'].iloc[0])
    best_q = int(order['q'].iloc[0])

    # Step 4: Walk-forward forecasting. The first len(train) points are copied
    # from the data; every later point is a one-step forecast from a model
    # refitted on the len(train) observations immediately before it.
    prediction_x = list(train['location_x'])
    prediction_y = list(train['location_y'])

    start_t = len(train)
    for t_i in tqdm(range(len(valid)), desc='walk-forward'):
        current_t = t_i + start_t
        model = VARMAX(data.iloc[t_i:current_t], order=(best_p, best_q))
        fitted_model = model.fit(disp=False)
        prediction = fitted_model.forecast()
        # Store scalars, not 1-element Series, so the lists stay numeric for
        # plotting and for num_chamber().
        prediction_x.append(prediction['location_x'].iloc[0])
        prediction_y.append(prediction['location_y'].iloc[0])

    # Step 5: Plot predictions against the full ground-truth trajectory.
    fig_dir = Path('../figures/insect') / ('%i' % num_ant)
    fig_dir.mkdir(parents=True, exist_ok=True)
    plot_time_series(ts_1=prediction_x, ts_label_1='VARMA Model',
                     ts_2=data['location_x'], ts_label_2='Ground truth',
                     title='VARMA predictions vs. ground truth of location x',
                     path=fig_dir / 'x.png')
    plot_time_series(ts_1=prediction_y, ts_label_1='VARMA Model',
                     ts_2=data['location_y'], ts_label_2='Ground truth',
                     title='VARMA predictions vs. ground truth of location y',
                     path=fig_dir / 'y.png')

    # Step 6: Convert predicted positions to chamber numbers and save them.
    prediction_dir = Path('../dataset/insect/ant/prediction')
    prediction_dir.mkdir(parents=True, exist_ok=True)
    prediction_c = num_chamber(prediction_x, prediction_y)
    prediction_c = pd.DataFrame(prediction_c, columns=['%i' % num_ant])
    prediction_c.to_csv(prediction_dir / ('prediction_%i.csv' % num_ant), index=False)
except Exception as reason:
    # Report the failure together with the ant ID instead of raising.
    print('%i' % num_ant, reason)

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A

224





 80%|████████  | 4/5 [00:08<00:02,  2.57s/it][A