In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.statespace.varmax import VARMAX

In [None]:
# Root directory for the data/ and insect/ paths used by later cells.
# `__file__` only exists when this code runs as a script or module; inside a
# notebook kernel it is undefined, so fall back to the current working
# directory (normally the notebook's own folder - verify for your setup).
try:
    path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    path = os.getcwd()

In [4]:
# Print numpy floats with three decimals.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and newer releases no longer accept the old
# 'seaborn-<name>' spelling. Pick whichever spelling this installation has.
for _style in ('whitegrid', 'poster', 'dark-palette'):
    _style_name = 'seaborn-v0_8-' + _style
    if _style_name not in plt.style.available:
        _style_name = 'seaborn-' + _style
    plt.style.use(_style_name)

# Explicit font overrides are applied last so that the style sheets above
# cannot reset them.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["mathtext.fontset"] = "cm"

In [5]:
def num_chamber(list_x, list_y):
    """Label every (x, y) position with a chamber number between 1 and 5.

    Both arguments must expose ``.values`` (e.g. pandas Series) and are read
    position by position; each coordinate is truncated with ``int()`` before
    it is compared.

    Labelling rules, checked in this order on the truncated values:
      * y > 178             -> 5
      * y <= 0 and x <= 0   -> 5
      * y <= 46             -> 1
      * y <= 92             -> 2
      * y <= 138            -> 3
      * anything else       -> 4

    Returns a list of ints with one label per element of ``list_x``.
    """
    xs = list_x.values
    ys = list_y.values

    def label_for(x, y):
        # Guard clauses mirror the rule order listed in the docstring.
        if y > 178:
            return 5
        if y <= 0 and x <= 0:
            return 5
        if y <= 46:
            return 1
        if y <= 92:
            return 2
        if y <= 138:
            return 3
        return 4

    return [label_for(int(xs[k]), int(ys[k])) for k in range(len(xs))]

In [9]:
def plot_time_series(ts_1, ts_label_1, ts_2, ts_label_2, title, path):
    """Plot two equally long series on one figure, save it and show it.

    Parameters
    ----------
    ts_1, ts_2 : sequences of numbers with the same length
        Drawn in green and red respectively against their position index.
    ts_label_1, ts_label_2 : str
        Legend entries for ``ts_1`` and ``ts_2``.
    title : str
        Figure title.
    path : str or path-like (or anything ``savefig`` accepts)
        Destination of the saved image. For string/path-like values the
        parent folder is created when it does not exist yet.

    Raises
    ------
    AssertionError
        If the two series differ in length.
    """
    assert len(ts_1) == len(ts_2)
    xs = list(range(len(ts_1)))

    # savefig does not create missing folders, so make sure the target exists.
    if isinstance(path, (str, os.PathLike)):
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # A dedicated figure per call keeps repeated calls from drawing onto the
    # same axes and lets us release the figure once it has been shown.
    fig, ax = plt.subplots()
    ax.plot(xs, ts_1, c='green', label=ts_label_1)
    ax.plot(xs, ts_2, c='red', label=ts_label_2)

    ax.set_title(title)
    ax.legend(loc='upper left')
    fig.savefig(path)
    plt.show()
    plt.close(fig)

In [None]:
# Data loading: the columns used below are colony_id, ant_id, location_x and
# location_y.
data_df = pd.read_csv(os.path.join(path, 'data', 'location_in_mm.csv'))

# Per colony, in colony_id order 1, 2, 3:
#   c          - offset added to ant_id to get a file-wide ant number
#   amount_ant - how many ant ids (0 .. n-1) are processed for the colony
c = [1, 74, 154]
amount_ant = [73, 80, 83]

# Chamber labels per ant, keyed by the file-wide ant number (as a string).
# Collecting Series in a dict and building the frame from it keeps tracks of
# different lengths intact (shorter ones are padded with NaN) instead of
# aligning every ant to the first ant's index.
pred_columns = {}
pred = pd.DataFrame()

for colony_id, (colony, n_ants) in enumerate(zip(c, amount_ant), start=1):

    for ant_id in tqdm(range(n_ants), desc='colony %i' % colony_id):

        num_ant = colony + ant_id

        mask = (data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)
        data = data_df.loc[mask, ['location_x', 'location_y']].reset_index(drop=True)

        if data.empty:
            # Nothing to fit for this ant; report it instead of crashing the run.
            tqdm.write('no rows for colony %i, ant %i - skipped' % (colony_id, ant_id))
            continue

        # Splitting into train and validation sets: first half / second half.
        split = len(data) // 2
        train = data[:split]
        valid = data[split:]

        model = VARMAX(train, order=(4, 3))
        # disp=False silences the optimizer's per-fit console output.
        fitted_model = model.fit(disp=False)

        # Forecast one step per held-out row, then prepend the observed
        # training half so the result lines up row-for-row with `data`.
        prediction = fitted_model.forecast(len(valid)).reset_index(drop=True)
        prediction = pd.concat([train, prediction], axis=0, ignore_index=True)

        prediction_c = num_chamber(prediction['location_x'], prediction['location_y'])

        pred_columns['%i' % num_ant] = pd.Series(prediction_c)
        pred = pd.DataFrame(pred_columns)

        # One folder per ant under <path>/insect/; created on demand.
        plot_dir = os.path.join(path, 'insect', '%i' % num_ant)
        os.makedirs(plot_dir, exist_ok=True)

        plot_time_series(ts_1 = prediction['location_x'], ts_label_1 = 'VARMA Model', ts_2 = data['location_x'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location x', path = os.path.join(plot_dir, 'x1.png'))
        plot_time_series(ts_1 = prediction['location_y'], ts_label_1 = 'VARMA Model', ts_2 = data['location_y'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location y', path = os.path.join(plot_dir, 'y1.png'))

        # Written after every ant so partial results survive an interrupted run.
        pred.to_csv(os.path.join(path, 'data', 'pred_noloop.csv'), index = False)
    
 