In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from statsmodels.tsa.stattools import adfuller as ADF
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.varmax import VARMAX

In [2]:
# Notebook-wide display settings: NumPy floats print with three decimals and
# every figure uses Times New Roman plus the seaborn-derived Matplotlib styles.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
plt.rcParams["font.family"] = "Times New Roman"
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names. Prefer the new name when this
# Matplotlib ships it and fall back to the legacy name otherwise, so the cell
# runs on both old and new installations. Application order is unchanged.
for _style in ('whitegrid', 'poster', 'dark-palette'):
    _style_name = 'seaborn-v0_8-%s' % _style
    if _style_name not in plt.style.available:
        _style_name = 'seaborn-%s' % _style
    plt.style.use(_style_name)
plt.rcParams["mathtext.fontset"] = "cm"

In [3]:
def num_chamber(list_x, list_y):
    """Map each (x, y) position to a chamber label in 1..5.

    Parameters
    ----------
    list_x, list_y : objects exposing ``.values`` (e.g. pandas Series)
        Coordinates, paired by position. Iteration is driven by the length
        of ``list_x``.

    Returns
    -------
    list of int
        One label per position:
        * 5 when ``y > 178``, or when both ``y <= 0`` and ``x <= 0``;
        * otherwise 1, 2, 3 for ``y`` up to 46, 92, 138 respectively;
        * 4 for everything else.
    """
    xs = list_x.values
    ys = list_y.values
    chambers = []
    for idx, x in enumerate(xs):
        y = ys[idx]
        # Branch order matters: the two "label 5" rules take precedence over
        # the y-band lookup below.
        if y > 178:
            label = 5
        elif (y <= 0) & (x <= 0):
            label = 5
        elif y <= 46:
            label = 1
        elif y <= 92:
            label = 2
        elif y <= 138:
            label = 3
        else:
            label = 4
        chambers.append(label)
    return chambers

In [4]:
def plot_time_series(ts_1, ts_label_1, ts_2, ts_label_2, title, path):
    """Plot two equal-length series on one figure, save it, and show it.

    Parameters
    ----------
    ts_1, ts_2 : sequences of equal length
        Drawn in green and red respectively against the positions 0..len-1.
    ts_label_1, ts_label_2 : str
        Legend labels for the two series.
    title : str
        Figure title.
    path : str or path-like
        Where the figure is saved. Missing parent directories are created.

    Raises
    ------
    AssertionError
        If the two series differ in length.

    Notes
    -----
    Sets ``savefig.dpi`` and ``figure.dpi`` to 300 in the global rcParams.
    """
    import os

    assert len(ts_1) == len(ts_2)
    xs = list(range(0, len(ts_1)))

    plt.rcParams['savefig.dpi'] = 300
    plt.rcParams['figure.dpi'] = 300

    # The callers write into per-ant folders; create the folder up front so a
    # missing directory does not abort the run at save time.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Use a dedicated figure so lines never pile up on whatever figure
    # happened to be current, and close it afterwards to free its memory.
    fig, ax = plt.subplots()
    ax.plot(xs, ts_1, c='green', label=ts_label_1, lw=1)
    ax.plot(xs, ts_2, c='red', label=ts_label_2, lw=1)

    ax.set_title(title)
    ax.legend(loc='upper left')
    fig.savefig(path)
    plt.show()
    plt.close(fig)

In [5]:
def ARIMA_model_loop(data, coef, times):
    """Walk-forward one-step ARIMA forecasts over the tail of ``data``.

    The first ``int(len(data) * coef)`` observations form the training prefix.
    For every remaining position an ARIMA model is refit on all observations
    before that position, and its one-step forecast is recorded.

    Parameters
    ----------
    data : pandas.Series
        The univariate series to forecast.
    coef : float
        Fraction of ``data`` used as the training prefix.
    times : int
        Differencing order ``d`` of the ARIMA model.

    Returns
    -------
    pandas.DataFrame
        A single column (named ``0``) with a fresh 0..n-1 index: the training
        values followed by one forecast per test position, so its length
        equals ``len(data)``.
    """
    split = int(len(data) * coef)
    train = data.iloc[:split]
    test = data.iloc[split:]

    # Choose (p, q) from the training prefix only; selecting on the full
    # series would let the held-out half influence the model choice.
    par = order(train, times)
    # `order` returns a one-row frame; ARIMA needs plain integers.
    p = int(par['p'].iloc[0])
    q = int(par['q'].iloc[0])

    # Forecast
    predictions = []
    for t in tqdm(range(len(test))):
        current_t = split + t
        model_fit = ARIMA(data.iloc[:current_t], order=(p, times, q)).fit()
        predictions.append(model_fit.forecast().iloc[0])

    # Build one continuous column. Concatenating the named training Series
    # with a frame whose column is `0` would produce two half-NaN columns.
    return pd.DataFrame(train.tolist() + predictions)

def adf_test(ts, signif=0.05, max_diff=None):
    """Count how many differencing passes make ``ts`` pass the ADF test.

    The augmented Dickey-Fuller test is run on ``ts``; while its p-value
    exceeds ``signif`` the series is differenced once more and tested again.

    Parameters
    ----------
    ts : pandas.Series
        Series to test.
    signif : float, default 0.05
        p-value threshold at or below which the series counts as stationary.
    max_diff : int or None, default None
        Optional upper bound on the number of differencing passes. ``None``
        keeps the unbounded behaviour.

    Returns
    -------
    int
        Number of differencing passes applied (0 if ``ts`` already passes).
    """
    times = 0
    while True:
        # A constant (or empty) series cannot be tested -- adfuller rejects
        # constant input -- and needs no further differencing anyway.
        if ts.nunique() <= 1:
            return times
        p_value = ADF(ts)[1]
        if p_value <= signif:
            return times
        if max_diff is not None and times >= max_diff:
            return times
        ts = ts.diff().dropna()
        times += 1
        
def order(train, times):
    """Grid-search ARIMA (p, q) by BIC for a fixed differencing order.

    Fits ARIMA(p, ``times``, q) for p in 1..5 and q in 0..4 and keeps the
    combination with the lowest BIC.

    Parameters
    ----------
    train : pandas.Series
        Series the candidate models are fitted on.
    times : int
        Differencing order ``d`` shared by all candidates.

    Returns
    -------
    pandas.DataFrame
        A one-row frame with columns ``bic``, ``p``, ``q`` for the best
        candidate (the first one in search order on ties). The frame is
        empty when every fit failed.
    """
    rows = []
    for p in tqdm(range(1, 6)):
        for q in tqdm(range(5), leave=False):
            try:
                bic = ARIMA(train, order=(p, times, q)).fit().bic
            except Exception:
                # A failed fit is recorded as "no score". KeyboardInterrupt
                # is deliberately not caught so the search can be stopped.
                bic = float('nan')
            rows.append([bic, p, q])
    grid = pd.DataFrame(rows, columns=['bic', 'p', 'q'])
    # min() skips NaN, so failed fits never win; head(1) breaks ties.
    return grid[grid['bic'] == grid['bic'].min()].head(1)

In [None]:
# Walk-forward ARIMA predictions for ants 69..82 of colony 3. For each ant the
# x and y tracks are forecast independently, turned into chamber labels, and
# written out together with comparison plots.
data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
coef = 0.5       # fraction of each track used as the training prefix
colony_id = 3
colony = 154     # offset added to ant_id to form the output identifier

for ant_id in range(69, 83):
    num_ant = colony + ant_id
    ant_mask = (data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)
    data = data_df[ant_mask]

    # Differencing order for each coordinate, decided by the ADF helper.
    times_x = adf_test(data['location_x'])
    times_y = adf_test(data['location_y'])

    pred_x = ARIMA_model_loop(data['location_x'], coef, times_x)
    pred_y = ARIMA_model_loop(data['location_y'], coef, times_y)

    # Chamber labels derived from the predicted coordinates, one CSV per ant.
    prediction_c = pd.DataFrame(num_chamber(pred_x, pred_y), columns = ['%i'%num_ant])
    prediction_c.to_csv('../dataset/insect/ant/prediction/prediction_%i.csv'%num_ant, index = False)

    plot_time_series(
        ts_1=pred_x,
        ts_label_1='ARIMA Model',
        ts_2=data['location_x'],
        ts_label_2='True data',
        title='ARIMA predictions vs. ground truth of x',
        path = '../figures/insect/%s/x.png'%num_ant,
    )
    plot_time_series(
        ts_1=pred_y,
        ts_label_1='ARIMA Model',
        ts_2=data['location_y'],
        ts_label_2='True data',
        title='ARIMA predictions vs. ground truth of y',
        path = '../figures/insect/%s/y.png'%num_ant,
    )


  0%|          | 0/5 [00:00<?, ?it/s]





100%|██████████| 5/5 [00:05<00:00,  1.19s/it][A
 20%|██        | 1/5 [00:05<00:23,  5.96s/it]

  warn('Non-invertible starting MA parameters found.'






100%|██████████| 5/5 [00:08<00:00,  1.63s/it][A
 40%|████      | 2/5 [00:14<00:21,  7.25s/it]





100%|██████████| 5/5 [00:10<00:00,  2.03s/it][A
 60%|██████    | 3/5 [00:24<00:17,  8.56s/it]


  warn('Non-stationary starting autoregressive parameters'





100%|██████████| 5/5 [00:13<00:00,  2.75s/it][A
 80%|████████  | 4/5 [00:37<00:10, 10.61s/it]





100%|██████████| 5/5 [00:16<00:00,  3.38s/it][A
100%|██████████| 5/5 [00:54<00:00, 10.98s/it]






























  warn('Non-invertible starting MA parameters found.'
































































































































































In [None]:
# Step 1: Data Loading
data_df = pd.read_csv('../dataset/insect/ant/location_in_mm.csv')
colony_id = 2
colony = 74      # offset added to ant_id to form the output identifier

for ant_id in range(0, 40):
    num_ant = colony + ant_id
    # Keep only the two coordinates, on a fresh 0..n-1 index. This builds a
    # new frame instead of resetting the index of a filtered slice in place.
    data = (
        data_df[(data_df['colony_id'] == colony_id) & (data_df['ant_id'] == ant_id)]
        [['location_x', 'location_y']]
        .reset_index(drop=True)
    )
    print(num_ant)

    # Step 2: Data Preparation
    # First half is the training set, second half the validation set.
    split = int(0.5 * len(data))
    train = data[:split]
    valid = data[split:]

    # Step 3: pick the VARMA (p, q) with the lowest BIC on the training half.
    bic_rows = []
    for p in tqdm(range(5)):
        for q in tqdm(range(5)):
            try:
                bic_rows.append([VARMAX(train, order = (p, q)).fit().bic, p, q])
            except Exception:
                # Failed fits get no score. KeyboardInterrupt is not caught,
                # so the search stays interruptible.
                bic_rows.append([None, p, q])
    bic_grid = pd.DataFrame(bic_rows, columns = ['bic', 'p', 'q'])
    print(bic_grid)
    # Named `best_order` so it does not overwrite the `order()` helper that
    # the ARIMA cells rely on.
    best_order = bic_grid[bic_grid['bic'] == bic_grid['bic'].min()]
    print(best_order)

    # Step 4: walk-forward one-step forecasts. Best effort per ant: any
    # failure is reported and the loop moves on to the next ant.
    try:
        # VARMAX needs plain integers; an empty selection raises here and is
        # reported by the handler below.
        best_p = int(best_order['p'].iloc[0])
        best_q = int(best_order['q'].iloc[0])

        # Predictions start with the observed training values so the final
        # series lines up with the full track.
        prediction_x = train['location_x'].tolist()
        prediction_y = train['location_y'].tolist()

        start_t = len(train)
        for t_i in tqdm(range(len(valid))):
            current_t = t_i + start_t
            # Sliding window of the most recent `start_t` observations.
            model = VARMAX(data[t_i:current_t], order = (best_p, best_q))
            fitted_model = model.fit()
            # Only the next step is needed; store it as a scalar per axis.
            next_step = fitted_model.forecast(1)
            prediction_x.append(next_step['location_x'].iloc[0])
            prediction_y.append(next_step['location_y'].iloc[0])

        prediction = pd.DataFrame({'location_x': prediction_x, 'location_y': prediction_y})

        plot_time_series(ts_1 = prediction['location_x'], ts_label_1 = 'VARMA Model', ts_2 = data['location_x'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location x', path = '../figures/insect/%i/x1.png'%num_ant)
        plot_time_series(ts_1 = prediction['location_y'], ts_label_1 = 'VARMA Model', ts_2 = data['location_y'], ts_label_2 = 'Close', title = 'VARMA predictions vs. ground truth of location y', path = '../figures/insect/%i/y1.png'%num_ant)

        prediction_c = num_chamber(prediction['location_x'], prediction['location_y'])
        prediction_c = pd.DataFrame(prediction_c, columns = ['%i'%num_ant])
        prediction_c.to_csv('../dataset/insect/ant/predictions/prediction_%i.csv'%num_ant, index = False)
    except Exception as reason:
        print('%i'%num_ant, reason)

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/5 [00:00<?, ?it/s][A

74






100%|██████████| 5/5 [00:14<00:00,  2.93s/it][A
 20%|██        | 1/5 [00:14<00:58, 14.63s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:08<00:00,  1.68s/it][A
 40%|████      | 2/5 [00:23<00:32, 10.97s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'

100%|██████████| 5/5 [00:19<00:00,  3.83s/it][A
 60%|██████    | 3/5 [00:42<00:29, 14.70s/it]
  0%|          | 0/5 [00:00<?, ?it/s][A
  warn('Estimation of VARMA(p,q) mod


  warn('Estimation of VARMA(p,q) models is not generically robust,'

  warn('Estimation of VARMA(p,q) models is not generically robust,'
