In [None]:
import sys
sys.path.insert(0, "../pycre/")

import matplotlib.pyplot as plt
import seaborn as sns

import time
import numpy as np
import pandas as pd

from cre import *
from parsers import get_parser
from dataset import dataset_generator

# Simulations

In [None]:
# Time how long CRE.fit takes on generated datasets of increasing size.
# One row is recorded per (B, seed, n) combination and the full table is
# written to ../results/computation_time_py.csv.
#
# NOTE(review): B is only stored in the output rows; it is never passed to
# CRE(...) or dataset_generator(...), so as written every B level times the
# same configuration. Confirm whether CRE should receive B.

# Rows are collected in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0 and re-copied the whole
# frame on every iteration.
timing_records = []

for B in [1,10,100]:
    print(f"B: {B}")
    for seed in [1,2,3,4]:
        # The global NumPy seed is reset once per seed, not once per data size.
        np.random.seed(seed)
        print(f"Seed: {seed}")
        for n in np.logspace(2, 6, 10).astype(int):
            print(f"Data Size: {n}")
            # Data generation happens before the timer starts, so it is not
            # part of the measured time.
            X, y, z, ite = dataset_generator(effect_size = 5,
                                            M = 2,
                                            binary_out = False,
                                            N = n)

            # perf_counter is a monotonic, high-resolution clock intended for
            # measuring intervals; time.time() can jump if the system clock
            # is adjusted.
            start_time = time.perf_counter()
            model = CRE(verbose = False)
            model.fit(X, y, z)
            end_time = time.perf_counter()

            execution_time = end_time - start_time
            print(f"Execution time: {execution_time} seconds")
            timing_records.append({"B": B,
                                   "seed": seed,
                                   "n": n,
                                   "time": execution_time})

# Passing columns explicitly fixes the column order and keeps the header
# even if no rows were recorded.
computation_time = pd.DataFrame(timing_records,
                                columns = ["B", "seed", "n", "time"])
computation_time.to_csv("../results/computation_time_py.csv", index = False)

# Comparison Python vs R

In [None]:
# Read the saved Python and R timing tables from ../results in one pass.
computation_time_py, computation_time_R = map(
    pd.read_csv,
    ("../results/computation_time_py.csv", "../results/computation_time_R.csv"),
)

In [None]:
# Compare Python and R execution times for a single B level on log-log axes.
B = 100
computation_time_py_B = computation_time_py.loc[computation_time_py["B"] == B]
computation_time_R_B = computation_time_R.loc[computation_time_R["B"] == B]

# Draw both implementations on one explicit Axes instead of relying on the
# pyplot "current axes" state.
fig, ax = plt.subplots(figsize=(8, 4))
for timings, implementation in ((computation_time_py_B, "Python"),
                                (computation_time_R_B, "R")):
    sns.lineplot(x = "n", y = "time", data = timings, ci = 100,
                 label = implementation, ax = ax)

ax.set_title(f"Computation time (B = {B})")
ax.set_xlabel("Number of Individuals (n)")
ax.set_ylabel("Execution time (seconds)")
ax.set_xscale("log")
ax.set_yscale("log")

# JOSS paper

In [None]:
# Load the timing table used in the JOSS paper section; the bare name on the
# last line makes the notebook render it.
computation_time_JOSS = pd.read_csv(
    filepath_or_buffer = "../results/computation_time_JOSS.csv"
)
computation_time_JOSS

In [None]:
# Split the JOSS timings by the value of column "p" (5, 10, 50) and plot one
# line per subset on log-log axes.
computation_time_JOSS_p5 = computation_time_JOSS.loc[computation_time_JOSS["p"] == 5]
computation_time_JOSS_p10 = computation_time_JOSS.loc[computation_time_JOSS["p"] == 10]
computation_time_JOSS_p50 = computation_time_JOSS.loc[computation_time_JOSS["p"] == 50]

fig, ax = plt.subplots(figsize=(8, 4))
for subset, covariate_label in ((computation_time_JOSS_p5, "5"),
                                (computation_time_JOSS_p10, "10"),
                                (computation_time_JOSS_p50, "50")):
    sns.lineplot(x = "n", y = "time", data = subset, ci = "sd",
                 label = covariate_label, ax = ax)

ax.set_xlabel("Number of Individuals (n)")
ax.set_ylabel("Execution time (seconds)")
ax.set_xscale("log")
ax.set_yscale("log")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend(title = "N. Covariates");
# Uncomment to export the figure:
#plt.savefig("computation_time_JOSS.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
