In [1]:
# Automatically reload edited project modules (e.g. `utils`, `experiments`,
# `tasks`) before executing each cell, so code changes are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# import only frequently used modules
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

### Adult Dataset

##### **[Adult-1]** Dataset Preparation

In [4]:
# Load the raw Adult data, drop the "fnlwgt" column, drop rows with missing
# values (encoded as "?" in the CSV), and encode the income label as 0/1.
adult_raw = pd.read_csv("./data/adult.csv")

INCOME_LABELS = {"<=50K": 0, ">50K": 1}

adult = (
    adult_raw
    .drop(columns=["fnlwgt"])
    .replace({"?": np.nan})
    .dropna()
    .reset_index(drop=True)
)

# Encode only the target column, and do it with an explicit map. A whole-frame
# `replace` relies on silent object->int downcasting (deprecated in pandas 2.2).
# Without the downcast the label stays object dtype, which scikit-learn may
# reject as an unknown label type.
unmapped_labels = set(adult["income"].unique()) - INCOME_LABELS.keys()
assert not unmapped_labels, f"unexpected income labels: {unmapped_labels}"
adult["income"] = adult["income"].map(INCOME_LABELS).astype("int64")

In [6]:
from sklearn.model_selection import train_test_split

# The label is the binary income column; every other column is a feature.
y = adult["income"]
X = adult.drop(columns=["income"])

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# reproducible between runs.
(X_train_, X_test_,
 y_train_, y_test_) = train_test_split(X, y, random_state=42, test_size=0.2)

In [7]:
# Transform categorical features to one-hot encoding and scale numerical
# features. The transformer is fit on the training split only, so no test-set
# statistics leak into the scaling or the category vocabulary.
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Partition columns by dtype: every numeric column (any bit width) is scaled,
# and everything else (object/string/bool/category) is one-hot encoded.
categorical_features = X.select_dtypes(exclude="number").columns
categorical_transformer = OneHotEncoder(handle_unknown="ignore")

numerical_features = X.select_dtypes(include="number").columns
numerical_transformer = StandardScaler()

preprocessor = ColumnTransformer(
    transformers=[
        ("cat", categorical_transformer, categorical_features),
        ("num", numerical_transformer, numerical_features),
    ],
    # Always return a dense ndarray. Calling `.A` on the output instead breaks
    # whenever the stacked result is dense enough to be returned as an ndarray,
    # and `.A` was removed from SciPy sparse matrices in SciPy 1.14.
    sparse_threshold=0.0,
)

X_train: np.ndarray = preprocessor.fit_transform(X_train_)
X_test: np.ndarray = preprocessor.transform(X_test_)

# Pin the label dtype so a non-numeric label fails here rather than deep
# inside model training.
y_train: np.ndarray = y_train_.to_numpy(dtype=np.int64)
y_test: np.ndarray = y_test_.to_numpy(dtype=np.int64)

assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]
assert X_train.shape[1] == X_test.shape[1]

✅ `X_train`, `X_test`, `y_train`, and `y_test` are now prepared.

##### **[Adult-2]** Experiment

In [8]:
from utils import get_params_combination

# Hyper-parameter grid passed to `get_params_combination`: gamma takes 20
# evenly spaced values in [0.001, 0.4], and every other entry holds a single
# value. The rendered table below has one row per gamma value.
adult_params = get_params_combination(
    dict(
        alpha=[0.0],
        gamma=list(np.linspace(start=0.001, stop=0.4, num=20)),
        r=[5.0],
        nu=[0.04],
        lambda_max=[10.0],
    )
)

# Wrapped in a DataFrame purely for a tabular display of the combinations.
pd.DataFrame(adult_params)

Unnamed: 0,alpha,gamma,r,nu,lambda_max
0,0.0,0.001,5.0,0.04,10.0
1,0.0,0.022,5.0,0.04,10.0
2,0.0,0.043,5.0,0.04,10.0
3,0.0,0.064,5.0,0.04,10.0
4,0.0,0.085,5.0,0.04,10.0
5,0.0,0.106,5.0,0.04,10.0
6,0.0,0.127,5.0,0.04,10.0
7,0.0,0.148,5.0,0.04,10.0
8,0.0,0.169,5.0,0.04,10.0
9,0.0,0.19,5.0,0.04,10.0


In [9]:
# Run the main experiment over every parameter combination in `adult_params`
# using the binary logistic-classification task.
from experiments import MainExperiment
from tasks import BinaryLogisticClassificationTask

adult_exp = MainExperiment(BinaryLogisticClassificationTask)

# Hand the preprocessed train/test splits to the task before the sweep.
# NOTE(review): whether `train` fits a model and `test` evaluates (or merely
# stores the data) is defined in `tasks` — confirm there.
adult_exp.task.train(X_train, y_train)
adult_exp.task.test(X_test, y_test)

# Expensive: in the recorded output, the oracle-cache stage alone took ~2.5 min
# per combination, and combinations appear to run concurrently (interleaved
# progress bars). The result is indexed by parameter combination in the next cell.
adult_results = adult_exp.run(adult_params)

5 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]]
5 / 20 [2] Generating lagrangian caches: 200it [00:00, 3882.77it/s]


5 / 20 [3] Solving:   0%|          | 0/60416 [00:00<?, ?it/s]

9 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
11 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
9 / 20 [2] Generating lagrangian caches: 200it [00:00, 4023.24it/s]


9 / 20 [3] Solving:   0%|          | 0/78828 [00:00<?, ?it/s]

10 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
10 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]

5 / 20 [3] Solving:  10%|▉         | 5990/60416 [00:00<00:00, 59890.78it/s]

11 / 20 [2] Generating lagrangian caches: 200it [00:00, 5410.64it/s]


11 / 20 [3] Solving:   0%|          | 0/88951 [00:00<?, ?it/s]

3 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
1 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
20 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
10 / 20 [2] Generating lagrangian caches: 200it [00:00, 2710.90it/s]


9 / 20 [3] Solving:   8%|▊         | 6029/78828 [00:00<00:01, 60263.62it/s]

18 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
18 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]

11 / 20 [3] Solving:   6%|▋         | 5623/88951 [00:00<00:01, 56222.01it/s]

1 / 20 [2] Generating lagrangian caches: 200it [00:00, 2109.52it/s]


1 / 20 [3] Solving:   0%|          | 0/44449 [00:00<?, ?it/s]

20 / 20 [2] Generating lagrangian caches: 200it [00:00, 2670.71it/s]


20 / 20 [3] Solving:   0%|          | 0/142072 [00:00<?, ?it/s]

3 / 20 [2] Generating lagrangian caches: 200it [00:00, 1746.76it/s]


3 / 20 [3] Solving:   0%|          | 0/52127 [00:00<?, ?it/s]

16 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
18 / 20 [2] Generating lagrangian caches: 200it [00:00, 3824.32it/s]


18 / 20 [3] Solving:   0%|          | 0/129197 [00:00<?, ?it/s]

16 / 20 [2] Generating lagrangian caches: 200it [00:00, 5400.12it/s]


18 / 20 [3] Solving:   5%|▍         | 6056/129197 [00:00<00:02, 60554.86it/s]

8 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]]

8 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]]

9 / 20 [3] Solving:  23%|██▎       | 18233/78828 [00:00<00:00, 60860.96it/s]]

12 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
13 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]

7 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
7 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]

1 / 20 [3] Solving:  27%|██▋       | 12086/44449 [00:00<00:00, 60452.23it/s]]

4 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
2 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]

20 / 20 [3] Solving:   9%|▊         | 12134/142072 [00:00<00:02, 60691.97it/s]

2 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
14 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
14 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]

3 / 20 [3] Solving:  23%|██▎       | 12173/52127 [00:00<00:00, 60880.71it/s]

6 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.28it/s]
15 / 20 [1] Generating oracle caches: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]
15 / 20 [2] Generating lagrangian caches: 0it [00:00, ?it/s]

18 / 20 [3] Solving:   9%|▉         | 12114/129197 [00:00<00:01, 60563.82it/s]

8 / 20 [2] Generating lagrangian caches: 200it [00:00, 2027.86it/s]


8 / 20 [3] Solving:   0%|          | 0/73996 [00:00<?, ?it/s]

19 / 20 [2] Generating lagrangian caches: 200it [00:00, 1840.42it/s]


9 / 20 [3] Solving:  31%|███       | 24320/78828 [00:00<00:00, 60805.67it/s]]]

7 / 20 [2] Generating lagrangian caches: 101it [00:00, 1009.26it/s]]

1 / 20 [3] Solving:  41%|████      | 18132/44449 [00:00<00:00, 60387.24it/s]]

4 / 20 [2] Generating lagrangian caches: 109it [00:00, 1074.80it/s]

20 / 20 [3] Solving:  13%|█▎        | 18205/142072 [00:00<00:02, 60697.46it/s]

14 / 20 [2] Generating lagrangian caches: 96it [00:00, 954.53it/s]]

3 / 20 [3] Solving:  35%|███▌      | 18303/52127 [00:00<00:00, 61070.38it/s]

13 / 20 [2] Generating lagrangian caches: 200it [00:00, 1265.46it/s]


18 / 20 [3] Solving:  14%|█▍        | 18224/129197 [00:00<00:01, 60803.83it/s]

12 / 20 [2] Generating lagrangian caches: 200it [00:00, 1152.95it/s]


9 / 20 [3] Solving:  39%|███▊      | 30421/78828 [00:00<00:00, 60876.81it/s]]]

7 / 20 [2] Generating lagrangian caches: 200it [00:00, 1056.04it/s]


7 / 20 [3] Solving:   0%|          | 0/69316 [00:00<?, ?it/s]

17 / 20 [2] Generating lagrangian caches: 200it [00:00, 976.86it/s] 


17 / 20 [3] Solving:   0%|          | 0/122989 [00:00<?, ?it/s]

4 / 20 [2] Generating lagrangian caches: 200it [00:00, 1110.78it/s]


5 / 20 [3] Solving:  60%|██████    | 36337/60416 [00:00<00:00, 60732.19it/s]

15 / 20 [2] Generating lagrangian caches: 200it [00:00, 1264.60it/s]


15 / 20 [3] Solving:   0%|          | 0/111032 [00:00<?, ?it/s]

14 / 20 [2] Generating lagrangian caches: 200it [00:00, 1128.43it/s]
6 / 20 [2] Generating lagrangian caches: 200it [00:00, 1174.53it/s]



1 / 20 [3] Solving: 100%|██████████| 44449/44449 [00:00<00:00, 60383.66it/s]s]
5 / 20 [3] Solving: 100%|██████████| 60416/60416 [00:00<00:00, 60558.58it/s]s]
3 / 20 [3] Solving: 100%|██████████| 52127/52127 [00:00<00:00, 61039.45it/s]s]
9 / 20 [3] Solving: 100%|██████████| 78828/78828 [00:01<00:00, 60881.33it/s]]]
2 / 20 [3] Solving: 100%|██████████| 48211/48211 [00:00<00:00, 61137.86it/s]s]
4 / 20 [3] Solving: 100%|██████████| 56195/56195 [00:00<00:00, 60441.45it/s]s]
10 / 20 [3] Solving: 100%|██████████| 83813/83813 [00:01<00:00, 60935.85it/s]
11 / 20 [3] Solving: 100%|██████████| 88951/88951 [00:01<00:00, 60675.91it/s]]
6 / 20 [3] Solving: 100%|██████████| 64790/64790 [00:01<00:00, 61089.60it/s]]]
8 / 20 [3] Solving: 100%|██████████| 73996/73996 [00:01<00:00, 60683.15it/s]s]
13 / 20 [3] Solving:  73%|███████▎  | 73149/99686 [00:01<00:00, 60747.03it/s]]
12 / 20 [3] Solving: 100%|██████████| 94242/94242 [00:01<00:00, 60612.75it/s]]]
13 / 20 [3] Solving: 100%|██████████| 99686/99686 [0

In [10]:
# Select the runs on the fixed (alpha, r) slice and collect the metrics to
# plot against gamma.
target_alpha = 0.0
target_r = 5.0

# Positions within each result entry: index 2 feeds the I(alpha) plot and
# index 3 the err plot.
# NOTE(review): confirm these against MainExperiment.run's return layout.
I_ALPHA_IDX = 2
ERR_IDX = 3

points = []
for param in adult_results:
    param_dict = dict(param)  # param is an iterable of (name, value) pairs
    if param_dict["alpha"] != target_alpha or param_dict["r"] != target_r:
        continue
    result = adult_results[param]
    points.append((param_dict["gamma"], result[I_ALPHA_IDX], result[ERR_IDX]))

# The runs finish out of order, so results may be stored in completion order.
# Sort by gamma; otherwise plt.plot would connect the points in a zig-zag.
points.sort(key=lambda point: point[0])

plot_x = [point[0] for point in points]
plot_I_alpha = [point[1] for point in points]
plot_err = [point[2] for point in points]

In [None]:
# I(alpha) as a function of gamma on the selected (alpha, r) slice.
# (matplotlib.pyplot is already imported as `plt` in the imports cell.)
fig, ax = plt.subplots()
ax.plot(plot_x, plot_I_alpha, marker="o")
ax.set(
    xlabel="gamma",
    ylabel="I(alpha)",
    title=f"Adult: I(alpha) vs gamma (alpha={target_alpha}, r={target_r})",
)
plt.show()

In [None]:
# The `err` value (index 3 of each result) as a function of gamma on the
# selected (alpha, r) slice.
fig, ax = plt.subplots()
ax.plot(plot_x, plot_err, marker="o")
ax.set(
    xlabel="gamma",
    ylabel="err",
    title=f"Adult: err vs gamma (alpha={target_alpha}, r={target_r})",
)
plt.show()