In [3]:
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})

parameters = {"axes.labelsize": 20, "legend.fontsize": 16, "xtick.labelsize": 20, "ytick.labelsize": 20, "lines.linewidth": 2, "lines.markersize": 10}


import lcpfn.lcpfn as lcpfn

path_all = '/mnt/c/Users/prath/PycharmProjects/rp/LCDB_localised/all_curves_preprocessed.pkl'
df_all = pd.read_pickle(path_all)

# model_name = '/mnt/c/Users/prath/PycharmProjects/rp/Data/model_lcdb_2.pt'
model_name = None
if model_name is None:
    model = lcpfn.LCPFN()
else:
    model = lcpfn.LCPFN(model_name=model_name)
print(model_name)

None


In [30]:
#vairables
CUTOFF = 10 #percentage of the curve use as input
FAIL_PER = 60 #percentage of points that should be within the CI

## Check for which curves the lcpfn model is not able to predict

In [31]:
def get_curve(i:int):
    """Retrieve the curve for a given model and dataset."""
    row = df_all.iloc[i]
    openlid = row['openmlid']
    learner = row['learner']
    anchors = row['anchors']
    means = row['means']
    std = row['std']
    return anchors, means, std, openlid, learner

def get_closest_index(lst, target):
    return min(range(len(lst)), key=lambda i: abs(lst[i] - target))

def get_common_data(cutoff, anchors, means, ):
    means = np.array(means)
    anchors = np.array(anchors)
    anchors = ((anchors - np.min(anchors)) / (np.max(anchors) - np.min(anchors))) * 100

    cutoff_index = get_closest_index(anchors, cutoff)
    cutoff_index = cutoff_index + 1 if cutoff_index == 0 else cutoff_index
    curve = np.array(means[:cutoff_index])
    anchors = anchors.astype(int)

    x = torch.from_numpy(anchors).unsqueeze(1)
    y = torch.from_numpy(curve).float().unsqueeze(1)
    return x, y, anchors, means, cutoff_index

def plot_it(i, anchors, means, predictions, cutoff_index, x, learner):
    plt.figure(figsize=(7.5, 7.5))
    plt.tight_layout()

    plt.plot(anchors, means, "*", label="target")
    plt.plot(anchors, means, label="target")
    plt.plot(x[cutoff_index:], predictions[:, 1], "r*", label="Extrapolation by PFN")
    plt.fill_between(
        x[cutoff_index:].flatten(), predictions[:, 0], predictions[:, 2], color="red", alpha=0.3, label="CI of 90%"
    )
    plt.vlines(CUTOFF, 0, 1, linewidth=0.5, color="k", label="cutoff")
    plt.ylim(0, 1)
    plt.legend(loc="lower right")
    plt.savefig("/mnt/c/Users/prath/PycharmProjects/rp/Data/extrapolation_"+str(i)+"_"+str(CUTOFF)+".png", dpi=400)
    plt.show()

In [32]:
curve_that_failed = []
for i in range(len(df_all)):
    print(str(i) + f'{len(df_all)}', end="\r")
    a, m, std, openlid, learner = get_curve(i)
    x, y, anchors, means, cutoff_index = get_common_data(CUTOFF, a, m)
    predictions = model.predict_quantiles(x_train=x[:cutoff_index], y_train=y, x_test=x[cutoff_index:], qs=[0.05, 0.5, 0.95])
    predictions = predictions.detach().numpy()
    low_ci = predictions[:, 0]
    high_ci = predictions[:, 2]

    #check if the truth points are within the CI area
    conditions = (means[cutoff_index:] >= low_ci) & (means[cutoff_index:] <= high_ci)
    within_ci = np.where(conditions)[0]

    #check if the points outside the CI area are above or below the CI
    conditions_low = (means[cutoff_index:] < low_ci)
    conditions_high = (means[cutoff_index:] > high_ci)
    outside_ci_low = np.where(conditions_low)[0]
    outside_ci_high = np.where(conditions_high)[0]

    if len(within_ci) < FAIL_PER/100 * len(means[cutoff_index:]):
        curve_that_failed.append((i, openlid, learner, len(within_ci), len(means[cutoff_index:]), len(outside_ci_low), len(outside_ci_high)))
        print(i, openlid, learner, len(within_ci), len(means[cutoff_index:]))
        if i == 23:
            plot_it(i, anchors, means, predictions, cutoff_index, x, learner)
            break

3 44 SVC_sigmoid 3 7
20 188 SVC_linear 0 8
21 188 SVC_poly 2 8
22 188 SVC_rbf 4 8


  plt.show()


In [33]:
i=23
a, m, std, openlid, learner = get_curve(i)
x, y, anchors, means, cutoff_index = get_common_data(CUTOFF, a, m)
predictions = model.predict_quantiles(x_train=x[:cutoff_index], y_train=y, x_test=x[cutoff_index:], qs=[0.05, 0.5, 0.95])
predictions = predictions.detach().numpy()
plot_it(i, anchors, means, predictions, cutoff_index, x, learner)

  plt.show()


In [63]:
print(len(curve_that_failed))
df = pd.DataFrame(curve_that_failed, columns=['index', 'openmlid', 'learner', 'within_ci', 'total', 'outside_ci_low', 'outside_ci_high'])
a = df['learner'].value_counts()
b= df['openmlid'].value_counts()
#save the failed curves count learner

df.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_{CUTOFF}.csv')
a.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_count_{CUTOFF}.csv')

1255


# Plot the curves that failed

In [None]:
plt.ylim(0, 1)
for i in curve_that_failed:
    a, m, std, openlid, learner = get_curve(i[0])
    x, y, a_2, m_2, cutoff_index = get_common_data(CUTOFF, a, m)
    plt.plot(a_2, m_2, label="target", alpha=0.4)
plt.show()

In [None]:
#Analyse the failed curves

In [76]:
df_10 = pd.read_csv('/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_10.csv')
df_20 = pd.read_csv('/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_20.csv')
df_40 = pd.read_csv('/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_40.csv')
df_80 = pd.read_csv('/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_80.csv')

a = df['learner'].value_counts()
a_10 = df_10['learner'].value_counts()
b_10 = df_10['openmlid'].value_counts()

a_20 = df_20['learner'].value_counts()
b_20 = df_20['openmlid'].value_counts()

a_40 = df_40['learner'].value_counts()
b_40 = df_40['openmlid'].value_counts()

a_80 = df_80['learner'].value_counts()
b_80 = df_80['openmlid'].value_counts()


b_10.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_d_count_10.csv')
b_20.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_d_count_20.csv')
b_40.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_d_count_40.csv')
b_80.to_csv(f'/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_d_count_80.csv')




In [182]:
# print counts
overestimate_10 = df_10[df_10['outside_ci_high'] > df_10['outside_ci_low']].index.tolist()
underestimate_10 = df_10[df_10['outside_ci_high'] < df_10['outside_ci_low']].index.tolist()
count_e_10 = df_10[df_10['outside_ci_high'] == df_10['outside_ci_low']].index.tolist()

overestimate_20 = df_20[df_20['outside_ci_high'] > df_20['outside_ci_low']].index.tolist()
underestimate_20 = df_20[df_20['outside_ci_high'] < df_20['outside_ci_low']].index.tolist()
count_e_20 = df_20[df_20['outside_ci_high'] == df_20['outside_ci_low']].index.tolist()

overestimate_40 = df_40[df_40['outside_ci_high'] > df_40['outside_ci_low']].index.tolist()
underestimate_40 = df_40[df_40['outside_ci_high'] < df_40['outside_ci_low']].index.tolist()
count_e_40 = df_40[df_40['outside_ci_high'] == df_40['outside_ci_low']].index.tolist()

overestimate_80 = df_80[df_80['outside_ci_high'] > df_80['outside_ci_low']].index.tolist()
underestimate_80 = df_80[df_80['outside_ci_high'] < df_80['outside_ci_low']].index.tolist()
count_e_80 = df_80[df_80['outside_ci_high'] == df_80['outside_ci_low']].index.tolist()

print(len(overestimate_10), len(underestimate_10), len(count_e_10))
print(len(overestimate_20), len(underestimate_20), len(count_e_20))
print(len(overestimate_40), len(underestimate_40), len(count_e_40))
print(len(overestimate_80), len(underestimate_80), len(count_e_80))
#print total
print(len(df_10), len(df_20), len(df_40), len(df_80))


266 158 1
272 296 4
406 557 23
543 703 9
425 572 986 1255


In [196]:
plt.close('all')
plt.figure(figsize=(20, 10))

for i in df_10['index']:
    #check if the curve is from lerner SVC_sigmod
    if df_all.iloc[i]['learner'] == 'SVC_sigmoid':
        a, m, std, openlid, learner = get_curve(i)
        x, y, a_2, m_2, cutoff_index = get_common_data(CUTOFF, a, m)
        plt.plot(a_2, m_2, alpha=0.4)
plt.savefig("/mnt/c/Users/prath/PycharmProjects/rp/Data/failed_curves_svc_sigmoid.png", dpi=400)
plt.show()


  plt.show()
