In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

BASE_DIR = '../../../'
import sys
sys.path.append(BASE_DIR)

# custom code
import utils.utils
CONFIG = utils.utils.load_config("../../config.json")
import utils.papers

Using TensorFlow backend.


In [3]:
DATASET = os.path.basename(os.getcwd()) # name of folder this file is in
RANDOM_SEED = CONFIG['random_seed']
EPOCHS = CONFIG["experiment_configs"][DATASET]["epochs"]
BATCH_SIZE = CONFIG["experiment_configs"][DATASET]["batch_size"]

print(RANDOM_SEED)

PROCESSED_DIR = os.path.join(BASE_DIR, f'processed/adult/rs={RANDOM_SEED}')
MODELS_DIR = os.path.join(BASE_DIR, f'models/adult/rs={RANDOM_SEED}')

PROCESSED_SAVEPATH = utils.utils.get_savepath(PROCESSED_DIR, "adult", ".pkl")
BASE_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult", ".h5", mt="base")

# models saved here
if not os.path.exists(BASE_MODEL_SAVEPATH):
    print(f"warning: no model has been done for rs={RANDOM_SEED}")

55


In [4]:
train_df = pd.read_csv(os.path.join(PROCESSED_DIR, "train.csv"))
hyper_train_df = pd.read_csv(os.path.join(PROCESSED_DIR, "hyper_train.csv"))
val_df = pd.read_csv(os.path.join(PROCESSED_DIR, "val.csv"))
hyper_val_df = pd.read_csv(os.path.join(PROCESSED_DIR, "hyper_val.csv"))
test_df = pd.read_csv(os.path.join(PROCESSED_DIR, "test.csv"))

val_full_df = pd.concat([val_df, hyper_val_df])

In [5]:
x_train = train_df.drop('label', axis=1).values
y_train = train_df['label'].values

x_hyper_train = hyper_train_df.drop('label', axis=1).values
y_hyper_train = hyper_train_df['label'].values

x_val_full = val_full_df.drop('label', axis=1).values
y_val_full = val_full_df['label'].values

x_val = val_df.drop('label', axis=1).values
y_val = val_df['label'].values

x_hyper_val = hyper_val_df.drop('label', axis=1).values
y_hyper_val = hyper_val_df['label'].values

x_test = test_df.drop('label', axis=1).values
y_test = test_df['label'].values

In [6]:
y_train = tf.keras.utils.to_categorical(y_train)
y_hyper_train = tf.keras.utils.to_categorical(y_hyper_train)
y_val_full = tf.keras.utils.to_categorical(y_val_full)
y_val = tf.keras.utils.to_categorical(y_val)
y_hyper_val = tf.keras.utils.to_categorical(y_hyper_val)
y_test = tf.keras.utils.to_categorical(y_test)

In [7]:
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=x_train.shape[1]),
    tf.keras.layers.Dense(2, activation=tf.nn.softmax),
])
model.load_weights(BASE_MODEL_SAVEPATH)

In [8]:
# val full acc
preds_val_full = utils.utils.compute_preds(
    model,
    x_val_full,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_val_full, axis=1) == np.argwhere(y_val_full)[:,1]).mean()

0.6319018404907976

In [9]:
# test acc
preds_test = utils.utils.compute_preds(
    model,
    x_test,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_test, axis=1) == np.argwhere(y_test)[:,1]).mean()

0.6192455639657526

# Baseline 1: Fine-Tune

This trains the model a little more on the validation set, in the hopes that it better generalized to test.

In [10]:
# reload save weights, in case being run out-of-order
model.load_weights(BASE_MODEL_SAVEPATH)

In [11]:
# default is 1e-3; we cut by a factor of 10 for fine-tuning
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

In [12]:
# Need to save the best model by validation loss
# mt stands for model_type
FT_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult", ".h5", mt="ft")

save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=FT_MODEL_SAVEPATH,
    monitor="val_loss",
    mode='min',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

callbacks = [save_best]

In [13]:
model.fit(
    x_val,
    y_val,
    batch_size = BATCH_SIZE,
    epochs = EPOCHS,
    validation_data = (x_hyper_val, y_hyper_val),
    callbacks=callbacks,
)

Epoch 1/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7265 - accuracy: 0.6875
Epoch 00001: val_loss improved from inf to 0.73159, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 2/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7229 - accuracy: 0.6875
Epoch 00002: val_loss improved from 0.73159 to 0.73108, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 3/100
1/6 [====>.........................] - ETA: 0s - loss: 1.0027 - accuracy: 0.5000
Epoch 00003: val_loss improved from 0.73108 to 0.73046, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 4/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7700 - accuracy: 0.5938
Epoch 00004: val_loss improved from 0.73046 to 0.72993, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 5/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6900 - accuracy: 0.6250
Epoch 00005: val_loss improved from 0.72993 to 0.72928, saving model to ../../.

Epoch 26/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6702 - accuracy: 0.6875
Epoch 00026: val_loss improved from 0.71850 to 0.71795, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 27/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7178 - accuracy: 0.6875
Epoch 00027: val_loss improved from 0.71795 to 0.71754, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 28/100
1/6 [====>.........................] - ETA: 0s - loss: 0.4929 - accuracy: 0.7812
Epoch 00028: val_loss improved from 0.71754 to 0.71695, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 29/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6455 - accuracy: 0.6562
Epoch 00029: val_loss improved from 0.71695 to 0.71651, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 30/100
1/6 [====>.........................] - ETA: 0s - loss: 0.4913 - accuracy: 0.7188
Epoch 00030: val_loss improved from 0.71651 to 0.71587, saving model t

Epoch 51/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6794 - accuracy: 0.6562
Epoch 00051: val_loss improved from 0.70534 to 0.70478, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 52/100
1/6 [====>.........................] - ETA: 0s - loss: 1.0076 - accuracy: 0.5938
Epoch 00052: val_loss improved from 0.70478 to 0.70439, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 53/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7824 - accuracy: 0.7188
Epoch 00053: val_loss improved from 0.70439 to 0.70401, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 54/100
1/6 [====>.........................] - ETA: 0s - loss: 0.8201 - accuracy: 0.5625
Epoch 00054: val_loss improved from 0.70401 to 0.70350, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 55/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7011 - accuracy: 0.6250
Epoch 00055: val_loss improved from 0.70350 to 0.70297, saving model t

1/6 [====>.........................] - ETA: 0s - loss: 0.7047 - accuracy: 0.7188
Epoch 00075: val_loss improved from 0.69441 to 0.69395, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 76/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6772 - accuracy: 0.7500
Epoch 00076: val_loss improved from 0.69395 to 0.69350, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 77/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7402 - accuracy: 0.6875
Epoch 00077: val_loss improved from 0.69350 to 0.69301, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 78/100
1/6 [====>.........................] - ETA: 0s - loss: 0.6260 - accuracy: 0.7500
Epoch 00078: val_loss improved from 0.69301 to 0.69265, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5
Epoch 79/100
1/6 [====>.........................] - ETA: 0s - loss: 0.7782 - accuracy: 0.5938
Epoch 00079: val_loss improved from 0.69265 to 0.69216, saving model to ../../../mo

Epoch 100/100
1/6 [====>.........................] - ETA: 0s - loss: 0.5858 - accuracy: 0.7500
Epoch 00100: val_loss improved from 0.68295 to 0.68247, saving model to ../../../models/adult/rs=55/adult_mt=ft.h5


<tensorflow.python.keras.callbacks.History at 0x7f3a35100390>

In [14]:
# see the best saved
model.load_weights(FT_MODEL_SAVEPATH)

In [15]:
# val acc
preds_hyper_val = utils.utils.compute_preds(
    model,
    x_hyper_val,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_hyper_val, axis=1) == np.argwhere(y_hyper_val)[:,1]).mean()

0.6993865030674846

In [16]:
# test acc
preds_test = utils.utils.compute_preds(
    model,
    x_test,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_test, axis=1) == np.argwhere(y_test)[:,1]).mean()

0.6372999131405882

# Baseline 2: Learn to Weigh Examples
Paper: https://arxiv.org/pdf/1803.09050.pdf

This is a type of meta-learning, which doesn't quite work with the keras API. We will need to manually implement the training loop.

In [30]:
model.load_weights(BASE_MODEL_SAVEPATH)

In [33]:
optimizer = tf.keras.optimizers.SGD() # initialize with defaults

In [34]:
# Reduction.NONE means the cross entropy is computed per entry in the batch
# but is not aggregated. Traditional cross entropy will average the results.
ce = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

In [35]:
best_loss = float('inf')
LRW_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult", ".h5", mt="lrw")

total_batches = len(x_train) // BATCH_SIZE 
if total_batches * BATCH_SIZE < len(x_train):
    # this usually happens as // operator rounds down
    # we want total_batches * BATCH_SIZE >= len(x_train)
    total_batches += 1
    
# custom train loop
for epoch in range(EPOCHS):
    # implements a train loop
    print(f"Epoch {epoch}:\n----------")
    
    loss_sum = 0
    for i in tqdm(range(total_batches)):
        # grab the batch and labels
        batch = x_train[i * BATCH_SIZE : i * BATCH_SIZE + BATCH_SIZE]
        labels = y_train[i * BATCH_SIZE : i * BATCH_SIZE + BATCH_SIZE]
        
        # most of the details are abstracted away in this function
        loss = utils.papers.train_step(model, batch, labels, x_val, y_val, ce, optimizer)
        loss_sum += loss

        # print ongoing avg loss
        print(f"Loss: {loss_sum / i}", end='\r')
    
    # compute validation accuracy
    preds = utils.utils.compute_preds(
        model,
        x_hyper_val,
        batch_size=BATCH_SIZE,
    )
    val_acc = (np.argmax(preds, axis=1) == np.argwhere(y_hyper_val)[:,1]).mean()
    loss_avg = loss_sum / total_batches
    
    print(f"Hyper Val Acc: {val_acc}")
    print(f"Hyper Val Loss: {loss_avg}")
    
    # implements save best logic
    if loss_avg < best_loss:
        best_loss = loss_avg
        print(f"Saving new best weights to {LRW_MODEL_SAVEPATH}")
        model.save_weights(
            filepath=LRW_MODEL_SAVEPATH,
            save_format="h5",
        )

Epoch 0:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.17984150350093842
Hyper Val Acc: 0.6625766871165644
Hyper Val Loss: 0.17292451858520508
Saving new best weights to ../../../models/adult/rs=55/adult_mt=lrw.h5
Epoch 1:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.019122552126646042
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.018387068063020706
Saving new best weights to ../../../models/adult/rs=55/adult_mt=lrw.h5
Epoch 2:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Saving new best weights to ../../../models/adult/rs=55/adult_mt=lrw.h5
Epoch 3:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 4:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 5:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 6:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 7:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 8:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 9:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 10:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 11:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 12:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 13:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 14:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 15:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 16:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 17:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 18:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 19:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 20:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 21:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 22:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 23:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 24:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 25:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 26:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 27:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 28:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 29:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 30:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 31:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 32:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 33:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 34:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 35:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 36:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 37:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 38:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 39:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 40:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 41:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 42:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 43:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 44:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 45:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 46:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 47:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 48:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 49:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 50:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 51:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 52:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 53:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 54:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 55:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 56:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 57:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 58:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 59:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 60:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 61:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 62:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 63:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 64:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 65:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 66:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 67:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 68:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 69:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 70:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 71:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 72:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 73:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0
Epoch 74:
----------


HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

Loss: 0.0
Hyper Val Acc: 0.6687116564417178
Hyper Val Loss: 0.0


In [36]:
model.load_weights(LRW_MODEL_SAVEPATH)

In [37]:
# hyper val acc
preds_hyper_val = utils.utils.compute_preds(
    model,
    x_hyper_val,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_hyper_val, axis=1) == np.argwhere(y_hyper_val)[:,1]).mean()

0.6687116564417178

In [38]:
# test acc
preds_test = utils.utils.compute_preds(
    model,
    x_test,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_test, axis=1) == np.argwhere(y_test)[:,1]).mean()

0.634722657474146

# Baseline 3: Kernel Mean Matching

Paper: https://papers.nips.cc/paper/2006/file/a2186aa7c086b46ad4e8bf81e2a3a19b-Paper.pdf

In [39]:
# the KMM algorithm does not scale well as inputs grow
# if we were to use the full train and test set, I don't even know
# how long it would take. Instead, we bunch the train into groups (randomly)
# of 2500 and apply KMM to the group and the full test set
# we then stitch together the estimated betas for the full train set
group_size = 2500
# these are a random arangement of indices
rand_inds = np.random.RandomState(seed=RANDOM_SEED).permutation( np.arange(len(x_train)) )
# these are the betas but ordered with respect to x_train
betas_ordered = np.zeros(len(x_train))

start_i = 0
end_i = start_i + group_size
while start_i < len(x_train):
    print(f"({start_i}-{end_i})")
    
    # grab the current group
    inds = rand_inds[start_i : end_i]
    
    kmm = utils.papers.KMM()
    # fit the group with the val
    betas = kmm.fit(x_train[inds], x_val)
    # fill in betas_ordered at the indices in the current group
    betas_ordered[inds] = betas.reshape(-1) # flatten
    
    start_i = end_i
    end_i = start_i + group_size

(0-2500)
     pcost       dcost       gap    pres   dres
 0: -1.5988e+07 -1.6112e+07  2e+06  1e-01  6e-16
 1: -1.5985e+07 -1.6077e+07  4e+05  2e-02  7e-16
 2: -1.5978e+07 -1.6026e+07  2e+05  1e-02  5e-16
 3: -1.5964e+07 -1.5974e+07  9e+04  5e-03  5e-16
 4: -1.5952e+07 -1.5950e+07  7e+04  3e-03  6e-16
 5: -1.5943e+07 -1.5937e+07  6e+04  2e-03  4e-16
 6: -1.5937e+07 -1.5932e+07  5e+04  2e-03  5e-16
 7: -1.5927e+07 -1.5921e+07  4e+04  1e-03  5e-16
 8: -1.5921e+07 -1.5917e+07  4e+04  1e-03  4e-16
 9: -1.5916e+07 -1.5914e+07  3e+04  8e-04  5e-16
10: -1.5913e+07 -1.5913e+07  3e+04  7e-04  6e-16
11: -1.5909e+07 -1.5911e+07  3e+04  5e-04  5e-16
12: -1.5906e+07 -1.5909e+07  2e+04  4e-04  4e-16
13: -1.5904e+07 -1.5907e+07  2e+04  3e-04  5e-16
14: -1.5901e+07 -1.5903e+07  7e+03  1e-04  5e-16
15: -1.5900e+07 -1.5901e+07  2e+03  2e-05  7e-16
16: -1.5900e+07 -1.5900e+07  5e+02  5e-06  6e-16
17: -1.5900e+07 -1.5900e+07  2e+02  4e-16  1e-15
18: -1.5900e+07 -1.5900e+07  2e+01  5e-16  2e-15
19: -1.5900e

 8: -1.5926e+07 -1.5920e+07  4e+04  1e-03  5e-16
 9: -1.5921e+07 -1.5918e+07  4e+04  1e-03  5e-16
10: -1.5917e+07 -1.5916e+07  4e+04  1e-03  5e-16
11: -1.5912e+07 -1.5912e+07  3e+04  7e-04  5e-16
12: -1.5905e+07 -1.5905e+07  2e+04  4e-04  6e-16
13: -1.5902e+07 -1.5903e+07  1e+04  3e-04  5e-16
14: -1.5898e+07 -1.5900e+07  7e+03  1e-04  5e-16
15: -1.5897e+07 -1.5897e+07  5e+02  5e-06  8e-16
16: -1.5897e+07 -1.5897e+07  3e+01  2e-07  2e-15
17: -1.5897e+07 -1.5897e+07  4e-01  3e-09  3e-15
Optimal solution found.
(22500-25000)
     pcost       dcost       gap    pres   dres
 0: -1.5988e+07 -1.6113e+07  2e+06  1e-01  1e-15
 1: -1.5985e+07 -1.6079e+07  4e+05  2e-02  6e-16
 2: -1.5978e+07 -1.6023e+07  2e+05  9e-03  6e-16
 3: -1.5963e+07 -1.5975e+07  8e+04  4e-03  6e-16
 4: -1.5951e+07 -1.5951e+07  6e+04  3e-03  5e-16
 5: -1.5942e+07 -1.5939e+07  4e+04  2e-03  5e-16
 6: -1.5937e+07 -1.5935e+07  4e+04  1e-03  6e-16
 7: -1.5932e+07 -1.5930e+07  3e+04  1e-03  5e-16
 8: -1.5928e+07 -1.5927e+07  3e+

In [40]:
# save betas for later analysis, if any
KMM_BETAS_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult_kmm_betas", ".npy")
np.save(
    file = KMM_BETAS_SAVEPATH,
    arr = betas_ordered,
)

In [41]:
# make a fresh model
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=x_train.shape[1]),
    tf.keras.layers.Dense(2, activation=tf.nn.softmax),
])

In [42]:
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

In [43]:
# Need to save the best model by validation loss
KMM_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult", ".h5", mt="kmm")
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=KMM_MODEL_SAVEPATH,
    monitor="val_loss",
    mode='min',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

callbacks = [save_best]

In [44]:
model.fit(
    x_train,
    y_train,
    batch_size = BATCH_SIZE,
    epochs = EPOCHS,
    validation_data = (x_hyper_val, y_hyper_val),
    callbacks=callbacks,
)

Epoch 1/75
Epoch 00001: val_loss improved from inf to 0.75929, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 2/75
Epoch 00002: val_loss improved from 0.75929 to 0.74983, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 3/75
Epoch 00003: val_loss improved from 0.74983 to 0.74017, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 4/75
Epoch 00004: val_loss improved from 0.74017 to 0.73143, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 5/75
Epoch 00005: val_loss did not improve from 0.73143
Epoch 6/75
Epoch 00006: val_loss improved from 0.73143 to 0.72399, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 7/75
Epoch 00007: val_loss improved from 0.72399 to 0.72380, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 8/75
Epoch 00008: val_loss improved from 0.72380 to 0.72063, saving model to ../../../models/adult/rs=55/adult_mt=kmm.h5
Epoch 9/75
Epoch 00009: val_loss improved from 0.7206

Epoch 29/75
Epoch 00029: val_loss did not improve from 0.71600
Epoch 30/75
Epoch 00030: val_loss did not improve from 0.71600
Epoch 31/75
Epoch 00031: val_loss did not improve from 0.71600
Epoch 32/75
Epoch 00032: val_loss did not improve from 0.71600
Epoch 33/75
Epoch 00033: val_loss did not improve from 0.71600
Epoch 34/75
Epoch 00034: val_loss did not improve from 0.71600
Epoch 35/75
Epoch 00035: val_loss did not improve from 0.71600
Epoch 36/75
Epoch 00036: val_loss did not improve from 0.71600
Epoch 37/75
Epoch 00037: val_loss did not improve from 0.71600
Epoch 38/75
Epoch 00038: val_loss did not improve from 0.71600
Epoch 39/75
Epoch 00039: val_loss did not improve from 0.71600
Epoch 40/75
Epoch 00040: val_loss did not improve from 0.71600
Epoch 41/75
Epoch 00041: val_loss did not improve from 0.71600
Epoch 42/75
Epoch 00042: val_loss did not improve from 0.71600
Epoch 43/75
Epoch 00043: val_loss did not improve from 0.71600
Epoch 44/75
Epoch 00044: val_loss did not improve from 

Epoch 59/75
Epoch 00059: val_loss did not improve from 0.71600
Epoch 60/75
Epoch 00060: val_loss did not improve from 0.71600
Epoch 61/75
Epoch 00061: val_loss did not improve from 0.71600
Epoch 62/75
Epoch 00062: val_loss did not improve from 0.71600
Epoch 63/75
Epoch 00063: val_loss did not improve from 0.71600
Epoch 64/75
Epoch 00064: val_loss did not improve from 0.71600
Epoch 65/75
Epoch 00065: val_loss did not improve from 0.71600
Epoch 66/75
Epoch 00066: val_loss did not improve from 0.71600
Epoch 67/75
Epoch 00067: val_loss did not improve from 0.71600
Epoch 68/75
Epoch 00068: val_loss did not improve from 0.71600
Epoch 69/75
Epoch 00069: val_loss did not improve from 0.71600
Epoch 70/75
Epoch 00070: val_loss did not improve from 0.71600
Epoch 71/75
Epoch 00071: val_loss did not improve from 0.71600
Epoch 72/75
Epoch 00072: val_loss did not improve from 0.71600
Epoch 73/75
Epoch 00073: val_loss did not improve from 0.71600
Epoch 74/75
Epoch 00074: val_loss did not improve from 

<tensorflow.python.keras.callbacks.History at 0x7efc9451cc50>

In [45]:
model.load_weights(KMM_MODEL_SAVEPATH)

In [46]:
# hyper val acc
preds_hyper_val = utils.utils.compute_preds(
    model,
    x_hyper_val,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_hyper_val, axis=1) == np.argwhere(y_hyper_val)[:,1]).mean()

0.5828220858895705

In [47]:
# test acc
preds_test = utils.utils.compute_preds(
    model,
    x_test,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_test, axis=1) == np.argwhere(y_test)[:,1]).mean()

0.5839548730805391

# Baseline 4: Just Train on Validation Set

In [48]:
# make a fresh model
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=x_val.shape[1]),
    tf.keras.layers.Dense(2, activation=tf.nn.softmax),
])

In [49]:
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

In [50]:
# Need to save the best model by validation loss
JV_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, "adult", ".h5", mt="jv")
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=JV_MODEL_SAVEPATH,
    monitor="val_loss",
    mode='min',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

callbacks = [save_best]

In [51]:
model.fit(
    x_val,
    y_val,
    batch_size = BATCH_SIZE,
    epochs = EPOCHS,
    validation_data = (x_hyper_val, y_hyper_val),
    callbacks=callbacks,
)

Epoch 1/75
Epoch 00001: val_loss improved from inf to 0.68141, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 2/75
Epoch 00002: val_loss improved from 0.68141 to 0.68006, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 3/75
Epoch 00003: val_loss improved from 0.68006 to 0.67877, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 4/75
Epoch 00004: val_loss improved from 0.67877 to 0.67752, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 5/75
Epoch 00005: val_loss improved from 0.67752 to 0.67634, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 6/75
Epoch 00006: val_loss improved from 0.67634 to 0.67520, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 7/75
Epoch 00007: val_loss improved from 0.67520 to 0.67409, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 8/75
Epoch 00008: val_loss improved from 0.67409 to 0.67302, saving model to ../../../models/adult/rs=55/adult_mt=jv

Epoch 26/75
Epoch 00026: val_loss improved from 0.65662 to 0.65572, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 27/75
Epoch 00027: val_loss improved from 0.65572 to 0.65483, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 28/75
Epoch 00028: val_loss improved from 0.65483 to 0.65395, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 29/75
Epoch 00029: val_loss improved from 0.65395 to 0.65307, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 30/75
Epoch 00030: val_loss improved from 0.65307 to 0.65219, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 31/75
Epoch 00031: val_loss improved from 0.65219 to 0.65132, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 32/75
Epoch 00032: val_loss improved from 0.65132 to 0.65046, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 33/75
Epoch 00033: val_loss improved from 0.65046 to 0.64960, saving model to ../../../models/adult/rs=55

Epoch 00050: val_loss improved from 0.63645 to 0.63567, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 51/75
Epoch 00051: val_loss improved from 0.63567 to 0.63491, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 52/75
Epoch 00052: val_loss improved from 0.63491 to 0.63415, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 53/75
Epoch 00053: val_loss improved from 0.63415 to 0.63339, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 54/75
Epoch 00054: val_loss improved from 0.63339 to 0.63264, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 55/75
Epoch 00055: val_loss improved from 0.63264 to 0.63190, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 56/75
Epoch 00056: val_loss improved from 0.63190 to 0.63116, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5
Epoch 57/75
Epoch 00057: val_loss improved from 0.63116 to 0.63043, saving model to ../../../models/adult/rs=55/adult_mt=jv

Epoch 75/75
Epoch 00075: val_loss improved from 0.61872 to 0.61808, saving model to ../../../models/adult/rs=55/adult_mt=jv.h5


<tensorflow.python.keras.callbacks.History at 0x7efc621dd910>

In [52]:
# see the best saved
model.load_weights(JV_MODEL_SAVEPATH)

In [53]:
# hyper val acc
preds_hyper_val = utils.utils.compute_preds(
    model,
    x_hyper_val,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_hyper_val, axis=1) == np.argwhere(y_hyper_val)[:,1]).mean()

0.6503067484662577

In [54]:
# test acc
preds_test = utils.utils.compute_preds(
    model,
    x_test,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds_test, axis=1) == np.argwhere(y_test)[:,1]).mean()

0.675963647759323