In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import os
import numpy as np
import tensorflow as tf
import keras.backend as K
import matplotlib.pyplot as plt
from sklearn import model_selection
from tqdm.notebook import tqdm
import umap
import time

BASE_DIR = '../../../'
import sys
sys.path.append(BASE_DIR)

# custom code
import utils.utils
CONFIG = utils.utils.load_config("../../config.json")
import utils.papers
import utils.custom_tf

Using TensorFlow backend.


In [3]:
DATASET = os.path.basename(os.getcwd()) # name of folder this file is in
RANDOM_SEED = CONFIG['random_seed']
EPOCHS = CONFIG["experiment_configs"][DATASET]["epochs"]
BATCH_SIZE = CONFIG["experiment_configs"][DATASET]["batch_size"]
IMAGE_X_SIZE = CONFIG["experiment_configs"][DATASET]["image_x_size"]
IMAGE_Y_SIZE = CONFIG["experiment_configs"][DATASET]["image_y_size"]
IMAGE_SIZE = (IMAGE_Y_SIZE, IMAGE_X_SIZE)
HYPER_VAL_SPLIT = CONFIG['experiment_configs'][DATASET]['hyper_val_split']

print(DATASET, RANDOM_SEED)

# folders for processed, models
DATA_F = os.path.join(BASE_DIR, f"data/{DATASET}/")
PROCESSED_DIR = os.path.join(BASE_DIR, f'processed/{DATASET}/rs={RANDOM_SEED}')
MODELS_DIR = os.path.join(BASE_DIR, f'models/{DATASET}/rs={RANDOM_SEED}')

BASE_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, DATASET, ".h5", mt="base") # mt = model_type

# base model saved here
if not os.path.exists(BASE_MODEL_SAVEPATH):
    print(f"warning: no model has been run for rs={RANDOM_SEED}")
    

adience 55


In [4]:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "train"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    seed = RANDOM_SEED,
)

hyper_train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "hyper_train"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    seed = RANDOM_SEED,
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "val"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    seed = RANDOM_SEED,
)

hyper_val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "hyper_val"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    seed = RANDOM_SEED,
)

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "test"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    seed = RANDOM_SEED,
)

Found 9795 files belonging to 2 classes.
Found 2449 files belonging to 2 classes.
Found 199 files belonging to 2 classes.
Found 200 files belonging to 2 classes.
Found 3585 files belonging to 2 classes.


In [5]:
'''
This will standardize the pixel data
'''
def preprocess(imgs, labels):
    # turn from <0..255> to <0..1>
    imgs = imgs / 255.0
    means = np.array( [0.5, 0.5, 0.5] )
    stds = np.array( [0.5, 0.5, 0.5] )
    imgs = (imgs - means) / stds
    return imgs, labels

In [6]:
train_ds = train_ds.map(preprocess)
hyper_train_ds = hyper_train_ds.map(preprocess)
val_ds = val_ds.map(preprocess)
hyper_val_ds = hyper_val_ds.map(preprocess)
test_ds = test_ds.map(preprocess)

# create a full validation set for baselines that need it
val_full_ds = val_ds.concatenate(hyper_val_ds)

In [7]:
model = utils.utils.make_resnet(
    depth=2,
    random_state=RANDOM_SEED,
    input_shape=(*IMAGE_SIZE, 3),
    nc=2,
)

model.load_weights(BASE_MODEL_SAVEPATH)

In [8]:
# hyper train acc
preds, labels = utils.utils.compute_preds(
    model,
    hyper_train_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 77/77 [00:18<00:00,  4.10it/s]


0.7897100857492855

In [9]:
# val acc
preds, labels = utils.utils.compute_preds(
    model,
    val_full_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 14/14 [00:03<00:00,  4.36it/s]


0.7493734335839599

In [10]:
# test acc
preds, labels = utils.utils.compute_preds(
    model,
    test_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 113/113 [00:23<00:00,  4.74it/s]


0.7403068340306834

# Baseline 1: Fine Tune
This is a very widely used technique in deep learning. The idea is simple: do a little bit more training on the validation set.


In [10]:
model.load_weights(BASE_MODEL_SAVEPATH)

In [11]:
optimizer = tf.keras.optimizers.SGD(lr=5e-6, momentum=0.9)

In [12]:
model.compile(optimizer, loss='categorical_crossentropy', metrics = ['accuracy'])

In [11]:
FT_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, DATASET, ".h5", mt="ft")

save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=FT_MODEL_SAVEPATH,
    monitor="val_loss",
    verbose=1,
    save_weights_only=True,
    save_best_only=True,
)


In [14]:
callbacks = [save_best]

In [15]:
model.fit(
        x=val_ds,
        epochs=EPOCHS,
        validation_data=hyper_val_ds,
        callbacks=callbacks,
)

Epoch 1/25
Epoch 00001: val_loss improved from inf to 0.55712, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 2/25
Epoch 00002: val_loss improved from 0.55712 to 0.55485, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 3/25
Epoch 00003: val_loss improved from 0.55485 to 0.54945, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 4/25
Epoch 00004: val_loss improved from 0.54945 to 0.54769, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 5/25
Epoch 00005: val_loss did not improve from 0.54769
Epoch 6/25
Epoch 00006: val_loss improved from 0.54769 to 0.54472, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 7/25
Epoch 00007: val_loss did not improve from 0.54472
Epoch 8/25
Epoch 00008: val_loss improved from 0.54472 to 0.54377, saving model to ../../../models/adience/rs=15/adience_mt=ft.h5
Epoch 9/25
Epoch 00009: val_loss improved from 0.54377 to 0.54228, saving model to ../../../models/a

<tensorflow.python.keras.callbacks.History at 0x7f24903bb350>

In [12]:
model.load_weights(FT_MODEL_SAVEPATH)

In [13]:
# hyper val acc
preds, labels = utils.utils.compute_preds(
    model,
    hyper_val_ds,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 13/13 [00:02<00:00,  4.63it/s]


0.7969924812030075

In [14]:
# test acc
preds, labels = utils.utils.compute_preds(
    model,
    test_ds,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 100/100 [00:22<00:00,  4.53it/s]


0.7806714778788829

# Baseline 2: Learn to Weigh Examples
Paper: https://arxiv.org/pdf/1803.09050.pdf

This is a type of meta-learning, which doesn't quite work with the keras API. We will need to manually implement the training loop.

In [8]:
model.load_weights(BASE_MODEL_SAVEPATH)

In [9]:
optimizer = tf.keras.optimizers.SGD(lr=5e-5, momentum=0.9)

In [10]:
# Reduction.NONE means the cross entropy is computed per entry in the batch
# but is not aggregated. Traditional cross entropy will average the results.
ce = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

In [11]:
def yield_batches_indef(ds):
    # keep yielding batches from `val_ds` indefinitely; loop around when finishing a dataset
    while True:
        for x_val, y_val in ds:
            yield x_val, y_val

In [12]:
val_batch_gen = yield_batches_indef(val_ds)

In [13]:
best_loss = float('inf')
LRW_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, DATASET, ".h5", mt="lrw")

for epoch in range(EPOCHS):
    # implements a train loop
    print(f"Epoch {epoch}:\n----------")
    
    total_batches = len(train_ds)
    batch_c = 0
    loss_sum = 0
    for batch, labels in tqdm(train_ds):
        x_val, y_val = next(val_batch_gen)
        # most of the details are abstracted away in this function
        loss = utils.papers.train_step(model, batch, labels, x_val, y_val, ce, optimizer)
        loss_sum += loss
        batch_c += 1
        # print ongoing avg loss
        print(f"Loss: {loss_sum / batch_c}", end='\r')
        
        # we need to tally this ourselves because the iterator simply restarts another epoch
        if batch_c >= total_batches:
            break
    
    # compute validation accuracy
    preds, labels = utils.utils.compute_preds(
        model,
        hyper_val_ds,
        batch_size=BATCH_SIZE,
    )
    val_acc = (np.argmax(preds, axis=1) == labels).mean()
    loss_avg = loss_sum / total_batches
    
    print(f"Hyper Val Acc: {val_acc}")
    print(f"Hyper Val Loss: {loss_avg}", end='\n\n')
        
    # implements save best logic
    if loss_avg < best_loss:
        best_loss = loss_avg
        print(f"Saving new best weights to {LRW_MODEL_SAVEPATH}")
        model.save_weights(
            filepath=LRW_MODEL_SAVEPATH,
            save_format="h5",
        )
        last_best_epoch = epoch
        
    # because this takes a long time (roughly 3x the normal train time)
    # we use early stopping 
    early_stop = 10
    if last_best_epoch + early_stop <= epoch:
        print(f"no improvement for {early_stop} epochs, ending training")
        break

Epoch 0:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.21759152412414555

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.21688276529312134


100%|██████████| 13/13 [00:03<00:00,  3.80it/s]

Hyper Val Acc: 0.7568922305764411
Hyper Val Loss: 0.21688276529312134

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 1:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.20383436977863312

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.20317041873931885


100%|██████████| 13/13 [00:02<00:00,  4.58it/s]


Hyper Val Acc: 0.7619047619047619
Hyper Val Loss: 0.20317041873931885

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 2:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.17989750206470497

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.17931151390075684


100%|██████████| 13/13 [00:02<00:00,  4.57it/s]


Hyper Val Acc: 0.7493734335839599
Hyper Val Loss: 0.17931151390075684

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 3:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.18651047348976135

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.18665733933448792


100%|██████████| 13/13 [00:02<00:00,  4.61it/s]

Hyper Val Acc: 0.7694235588972431
Hyper Val Loss: 0.18665733933448792

Epoch 4:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.19130821526050568

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.19103917479515076


100%|██████████| 13/13 [00:02<00:00,  4.65it/s]

Hyper Val Acc: 0.7769423558897243
Hyper Val Loss: 0.19103917479515076

Epoch 5:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.21711790561676025

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.21659691631793976


100%|██████████| 13/13 [00:02<00:00,  4.70it/s]

Hyper Val Acc: 0.7593984962406015
Hyper Val Loss: 0.21659691631793976

Epoch 6:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.19805613160133362

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.19741100072860718


100%|██████████| 13/13 [00:02<00:00,  4.53it/s]

Hyper Val Acc: 0.7669172932330827
Hyper Val Loss: 0.19741100072860718

Epoch 7:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.172115981578826945

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.17155534029006958


100%|██████████| 13/13 [00:02<00:00,  4.65it/s]


Hyper Val Acc: 0.7669172932330827
Hyper Val Loss: 0.17155534029006958

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 8:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.18839119374752045

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.1877775341272354


100%|██████████| 13/13 [00:02<00:00,  4.69it/s]

Hyper Val Acc: 0.7619047619047619
Hyper Val Loss: 0.1877775341272354

Epoch 9:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.17725113034248352

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.17681726813316345


100%|██████████| 13/13 [00:02<00:00,  4.40it/s]

Hyper Val Acc: 0.7493734335839599
Hyper Val Loss: 0.17681726813316345

Epoch 10:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.17868730425834656

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.17840944230556488


100%|██████████| 13/13 [00:02<00:00,  4.71it/s]

Hyper Val Acc: 0.7619047619047619
Hyper Val Loss: 0.17840944230556488

Epoch 11:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.19441524147987366

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.19378195703029633


100%|██████████| 13/13 [00:02<00:00,  4.74it/s]

Hyper Val Acc: 0.7593984962406015
Hyper Val Loss: 0.19378195703029633

Epoch 12:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.18723870813846588

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.1866288036108017


100%|██████████| 13/13 [00:02<00:00,  4.76it/s]

Hyper Val Acc: 0.7744360902255639
Hyper Val Loss: 0.1866288036108017

Epoch 13:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.17335519194602966

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.1727905124425888


100%|██████████| 13/13 [00:02<00:00,  4.55it/s]

Hyper Val Acc: 0.7543859649122807
Hyper Val Loss: 0.1727905124425888

Epoch 14:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.16737031936645508

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.16716289520263672


100%|██████████| 13/13 [00:02<00:00,  4.61it/s]


Hyper Val Acc: 0.7644110275689223
Hyper Val Loss: 0.16716289520263672

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 15:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.16387143731117249

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.16333766281604767


100%|██████████| 13/13 [00:02<00:00,  4.64it/s]


Hyper Val Acc: 0.7518796992481203
Hyper Val Loss: 0.16333766281604767

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 16:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.18798233568668365

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.18737001717090607


100%|██████████| 13/13 [00:02<00:00,  4.64it/s]

Hyper Val Acc: 0.7669172932330827
Hyper Val Loss: 0.18737001717090607

Epoch 17:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.15981453657150269

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.15929396450519562


100%|██████████| 13/13 [00:02<00:00,  4.73it/s]

Hyper Val Acc: 0.7619047619047619
Hyper Val Loss: 0.15929396450519562

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 18:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.16343134641647344

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.16366778314113617


100%|██████████| 13/13 [00:02<00:00,  4.62it/s]

Hyper Val Acc: 0.7644110275689223
Hyper Val Loss: 0.16366778314113617

Epoch 19:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.19355389475822456

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.19302715361118317


100%|██████████| 13/13 [00:02<00:00,  4.38it/s]

Hyper Val Acc: 0.7518796992481203
Hyper Val Loss: 0.19302715361118317

Epoch 20:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.16845244169235235

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.16790373623371124


100%|██████████| 13/13 [00:02<00:00,  4.70it/s]

Hyper Val Acc: 0.7543859649122807
Hyper Val Loss: 0.16790373623371124

Epoch 21:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.15247532725334167

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.15346217155456543


100%|██████████| 13/13 [00:02<00:00,  4.71it/s]


Hyper Val Acc: 0.7744360902255639
Hyper Val Loss: 0.15346217155456543

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 22:
----------


HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.15086673200130463

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.15037532150745392


100%|██████████| 13/13 [00:02<00:00,  4.71it/s]

Hyper Val Acc: 0.7769423558897243
Hyper Val Loss: 0.15037532150745392

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5
Epoch 23:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.16070169210433967

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.16017822921276093


100%|██████████| 13/13 [00:02<00:00,  4.60it/s]

Hyper Val Acc: 0.7769423558897243
Hyper Val Loss: 0.16017822921276093

Epoch 24:
----------





HBox(children=(FloatProgress(value=0.0, max=307.0), HTML(value='')))

Loss: 0.14117944240570068

  0%|          | 0/13 [00:00<?, ?it/s]

Loss: 0.14152872562408447


100%|██████████| 13/13 [00:02<00:00,  4.65it/s]


Hyper Val Acc: 0.7744360902255639
Hyper Val Loss: 0.14152872562408447

Saving new best weights to ../../../models/adience/rs=55/adience_mt=lrw.h5


In [14]:
model.load_weights(LRW_MODEL_SAVEPATH)

In [15]:
# val acc
preds, labels = utils.utils.compute_preds(
    model,
    hyper_val_ds,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 13/13 [00:02<00:00,  4.70it/s]


0.7744360902255639

In [16]:
# test acc
preds, labels = utils.utils.compute_preds(
    model,
    test_ds,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 100/100 [00:21<00:00,  4.59it/s]


0.7565108252274867

# Baseline 3: KMM
Paper: https://papers.nips.cc/paper/2006/file/a2186aa7c086b46ad4e8bf81e2a3a19b-Paper.pdf

In [17]:
TRAIN_IMGNET_PREDS = utils.utils.get_savepath(PROCESSED_DIR, "adience_imgnet_preds_train", ".npy")
VAL_IMGNET_PREDS = utils.utils.get_savepath(PROCESSED_DIR, "adience_imgnet_preds_val", ".npy")

x_train = np.load(TRAIN_IMGNET_PREDS)
x_val = np.load(VAL_IMGNET_PREDS)

In [18]:
x_train.shape, x_val.shape

((9795, 2048), (398, 2048))

In [19]:
n_neighbors = 10
dim = 10

umap_emb = umap.UMAP(
    n_neighbors=n_neighbors,
    min_dist=0.5, 
    n_components=dim,
    metric='euclidean',
    random_state=RANDOM_SEED,
)

start = time.time()

umap_emb.fit(x_train)

end = time.time()

print(f"took { np.round(end - start, decimals=2) } seconds")

took 47.68 seconds


In [20]:
x_train_emb = umap_emb.transform(x_train)
x_val_emb = umap_emb.transform(x_val)

In [21]:
x_train_emb.shape, x_val_emb.shape

((9795, 10), (398, 10))

In [22]:
# no longer needed, delete to save memory
del x_train, x_val, umap_emb

In [23]:
# NOTE: this will take a couple minutes
# the KMM algorithm does not scale well as inputs grow
# if we were to use the full train and test set, I don't even know
# how long it would take. Instead, we bunch the train into groups (randomly)
# and apply KMM to the group and the full test set
# we then stitch together the estimated betas for the full train set
group_size = 2500
rand_inds = np.random.RandomState(seed=RANDOM_SEED).permutation( np.arange(len(x_train_emb)) )
betas_ordered = np.zeros(len(x_train_emb))

start_i = 0
end_i = start_i + group_size
while start_i < len(x_train_emb):
    print(f"({start_i}-{end_i})")
    
    inds = rand_inds[start_i : end_i]
    
    kmm = utils.papers.KMM()
    betas = kmm.fit(x_train_emb[inds], x_val_emb)
    betas_ordered[inds] = betas.reshape(-1) # flatten
    
    start_i = end_i
    end_i = start_i + group_size

(0-2500)
     pcost       dcost       gap    pres   dres
 0: -1.2453e+09 -1.2454e+09  1e+06  9e-02  5e-16
 1: -1.2453e+09 -1.2454e+09  6e+05  3e-02  5e-16
 2: -1.2453e+09 -1.2452e+09  4e+05  2e-02  5e-16
 3: -1.2453e+09 -1.2450e+09  4e+05  1e-02  5e-16
 4: -1.2452e+09 -1.2449e+09  4e+05  1e-02  5e-16
 5: -1.2452e+09 -1.2447e+09  4e+05  1e-02  5e-16
 6: -1.2452e+09 -1.2446e+09  4e+05  1e-02  5e-16
 7: -1.2451e+09 -1.2445e+09  5e+05  1e-02  5e-16
 8: -1.2451e+09 -1.2444e+09  5e+05  1e-02  5e-16
 9: -1.2450e+09 -1.2443e+09  5e+05  1e-02  5e-16
10: -1.2450e+09 -1.2442e+09  6e+05  1e-02  4e-16
11: -1.2450e+09 -1.2442e+09  6e+05  9e-03  5e-16
12: -1.2449e+09 -1.2440e+09  6e+05  9e-03  5e-16
13: -1.2449e+09 -1.2439e+09  6e+05  8e-03  5e-16
14: -1.2448e+09 -1.2437e+09  7e+05  8e-03  5e-16
15: -1.2447e+09 -1.2436e+09  7e+05  7e-03  5e-16
16: -1.2446e+09 -1.2434e+09  7e+05  7e-03  5e-16
17: -1.2445e+09 -1.2433e+09  8e+05  6e-03  4e-16
18: -1.2444e+09 -1.2432e+09  8e+05  6e-03  4e-16
19: -1.2443e

15: -1.0488e+09 -1.0479e+09  6e+05  7e-03  4e-16
16: -1.0488e+09 -1.0478e+09  6e+05  6e-03  4e-16
17: -1.0487e+09 -1.0477e+09  7e+05  6e-03  4e-16
18: -1.0486e+09 -1.0476e+09  7e+05  5e-03  4e-16
19: -1.0485e+09 -1.0475e+09  7e+05  5e-03  4e-16
20: -1.0484e+09 -1.0475e+09  7e+05  5e-03  4e-16
21: -1.0483e+09 -1.0474e+09  7e+05  4e-03  4e-16
22: -1.0482e+09 -1.0473e+09  7e+05  4e-03  4e-16
23: -1.0481e+09 -1.0473e+09  7e+05  4e-03  4e-16
24: -1.0480e+09 -1.0472e+09  7e+05  3e-03  4e-16
25: -1.0479e+09 -1.0472e+09  7e+05  3e-03  4e-16
26: -1.0478e+09 -1.0471e+09  7e+05  3e-03  4e-16
27: -1.0477e+09 -1.0471e+09  7e+05  2e-03  4e-16
28: -1.0476e+09 -1.0471e+09  7e+05  2e-03  4e-16
29: -1.0475e+09 -1.0470e+09  6e+05  2e-03  4e-16
30: -1.0474e+09 -1.0470e+09  6e+05  2e-03  4e-16
31: -1.0474e+09 -1.0470e+09  6e+05  1e-03  4e-16
32: -1.0473e+09 -1.0470e+09  5e+05  1e-03  4e-16
33: -1.0472e+09 -1.0470e+09  5e+05  1e-03  4e-16
34: -1.0472e+09 -1.0469e+09  5e+05  9e-04  4e-16
35: -1.0471e+09 -1.0

In [24]:
train_df = utils.utils.load_sorted_df(PROCESSED_DIR, "train")

In [25]:
train_df['beta'] = betas_ordered
train_df.head()

Unnamed: 0,user_id,original_image,face_id,age,gender,beta
0,7464014@N04,10218534135_6c73e2982d_o.jpg,961,"(25, 32)",f,0.999999
100,10897942@N03,8403758902_a1d5ba65e7_o.jpg,636,"(25, 32)",f,0.999999
1000,113445054@N07,11764107793_5ec337a088_o.jpg,1325,"(25, 32)",f,1.0
1002,7398884@N04,8725912445_166a5ba9d1_o.jpg,1649,"(15, 20)",f,1.0
1003,11008464@N06,11345824903_e6355034f8_o.jpg,970,"(0, 2)",f,1.0


In [26]:
def train_paths_to_df_rows(train_paths):
    rns = []
    for path in train_paths:
        int_str = path.split('/')[-1].split('.')[0]
        try:
            rn = int( int_str )
        except:
            print(f"failed to convert {int_str}, path: {path}")
            
        rns.append(rn)
    return rns

In [27]:
# see docstring of this function for details
kmm_train_ds, kmm_train_paths = utils.custom_tf.image_dataset_from_directory(
    directory=os.path.join(PROCESSED_DIR, "train"),
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    label_mode='categorical',
    follow_links=True,
    shuffle=True,
    seed=RANDOM_SEED,
)

kmm_train_ds = kmm_train_ds.map(preprocess)

Found 9795 files belonging to 2 classes.


In [28]:
row_nums = train_paths_to_df_rows(kmm_train_paths)

# grab the sample weights corresponding to shuffled data
sample_weights = train_df.loc[row_nums]['beta'].values.reshape(-1, 1)
sample_weights = tf.convert_to_tensor(sample_weights, dtype=tf.float32)

In [29]:
# make fresh model
model = utils.utils.make_resnet(
    depth=2,
    random_state=RANDOM_SEED,
    input_shape=(*IMAGE_SIZE, 3),
    nc=2,
)

In [30]:
ce = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

In [31]:
def loss_with_sample_weight_and_tfds(loss_f, sample_weight, batch_size):
    '''
    Compute loss of `loss_f` using sample weights and a tf.data.Dataset input.
    Keras API does not allow sample weights with tf.data.Dataset.
    '''
    start_i = 0
    end_i = start_i + batch_size
    
    def loss_inner(y_true, y_pred):
        # each time this is called, we are at the next batch so we increment the indices
        nonlocal start_i, end_i
        
        batch_weight = sample_weight[start_i : end_i]
        loss = tf.math.reduce_mean( loss_f(y_true, y_pred) * batch_weight )
        start_i = end_i
        end_i = start_i + batch_size
        
        return loss
        
    return loss_inner

In [32]:
loss = loss_with_sample_weight_and_tfds(
    loss_f=ce,
    sample_weight=sample_weights,
    batch_size=BATCH_SIZE,
)

In [33]:
optimizer = tf.keras.optimizers.SGD(lr=1e-4, momentum=0.9)

In [34]:
model.compile(optimizer, loss=loss, metrics = ['accuracy'])

In [35]:
def scheduler(epoch):
    if epoch > 10:
        return 5e-5
    else:
        return 1e-4

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

KMM_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, DATASET, ".h5", mt="kmm")
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=KMM_MODEL_SAVEPATH,
    monitor="val_loss",
    verbose=1,
    save_weights_only=True,
    save_best_only=True,
)

callbacks = [lr_scheduler, save_best]

In [36]:
history = model.fit(
    x=kmm_train_ds,
    epochs=EPOCHS,
    validation_data=hyper_val_ds,
    verbose=1,
    callbacks=callbacks,
)

Epoch 1/25
Epoch 00001: val_loss improved from inf to 1.11998, saving model to ../../../models/adience/rs=55/adience_mt=kmm.h5
Epoch 2/25
Epoch 00002: val_loss improved from 1.11998 to 0.76467, saving model to ../../../models/adience/rs=55/adience_mt=kmm.h5
Epoch 3/25
Epoch 00003: val_loss did not improve from 0.76467
Epoch 4/25
Epoch 00004: val_loss improved from 0.76467 to 0.72075, saving model to ../../../models/adience/rs=55/adience_mt=kmm.h5
Epoch 5/25
Epoch 00005: val_loss did not improve from 0.72075
Epoch 6/25
Epoch 00006: val_loss did not improve from 0.72075
Epoch 7/25
Epoch 00007: val_loss did not improve from 0.72075
Epoch 8/25
Epoch 00008: val_loss did not improve from 0.72075
Epoch 9/25
Epoch 00009: val_loss did not improve from 0.72075
Epoch 10/25
Epoch 00010: val_loss did not improve from 0.72075
Epoch 11/25
Epoch 00011: val_loss improved from 0.72075 to 0.68409, saving model to ../../../models/adience/rs=55/adience_mt=kmm.h5
Epoch 12/25
Epoch 00012: val_loss did not im

In [37]:
# load best model
model.load_weights(KMM_MODEL_SAVEPATH)

In [38]:
# val acc
preds, labels = utils.utils.compute_preds(
    model,
    hyper_val_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 13/13 [00:03<00:00,  3.94it/s]


0.7568922305764411

In [39]:
# test acc
preds, labels = utils.utils.compute_preds(
    model,
    test_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 100/100 [00:21<00:00,  4.60it/s]


0.7455287103859429

# Baseline 4: Just Train on Validation Set

In [40]:
# make a fresh instance
model = utils.utils.make_resnet(
    depth=2,
    random_state=RANDOM_SEED,
    input_shape=(*IMAGE_SIZE, 3),
    nc=2,
)

In [41]:
model.summary()

Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 256, 256, 16) 448         input_3[0][0]                    
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 256, 256, 16) 64          conv2d_30[0][0]                  
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 256, 256, 16) 0           batch_normalization_26[0][0]     
_______________________________________________________________________________________

In [42]:
optimizer = tf.keras.optimizers.SGD(lr=1e-4, momentum=0.9)

In [43]:
model.compile(optimizer, loss='categorical_crossentropy', metrics = ['accuracy'])

In [44]:
def scheduler(epoch):
    if epoch > 10:
        return 5e-5
    else:
        return 1e-4

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

JV_MODEL_SAVEPATH = utils.utils.get_savepath(MODELS_DIR, DATASET, ".h5", mt="jv")
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=JV_MODEL_SAVEPATH,
    monitor="val_loss",
    verbose=1,
    save_weights_only=True,
    save_best_only=True,
)

callbacks = [lr_scheduler, save_best]

In [45]:
model.fit(
    x=val_ds,
    epochs=EPOCHS,
    validation_data=hyper_val_ds,
    callbacks=callbacks,
)

Epoch 1/25
Epoch 00001: val_loss improved from inf to 2.22663, saving model to ../../../models/adience/rs=55/adience_mt=jv.h5
Epoch 2/25
Epoch 00002: val_loss improved from 2.22663 to 1.13196, saving model to ../../../models/adience/rs=55/adience_mt=jv.h5
Epoch 3/25
Epoch 00003: val_loss did not improve from 1.13196
Epoch 4/25
Epoch 00004: val_loss did not improve from 1.13196
Epoch 5/25
Epoch 00005: val_loss did not improve from 1.13196
Epoch 6/25
Epoch 00006: val_loss improved from 1.13196 to 1.11807, saving model to ../../../models/adience/rs=55/adience_mt=jv.h5
Epoch 7/25
Epoch 00007: val_loss did not improve from 1.11807
Epoch 8/25
Epoch 00008: val_loss did not improve from 1.11807
Epoch 9/25
Epoch 00009: val_loss improved from 1.11807 to 1.03408, saving model to ../../../models/adience/rs=55/adience_mt=jv.h5
Epoch 10/25
Epoch 00010: val_loss did not improve from 1.03408
Epoch 11/25
Epoch 00011: val_loss did not improve from 1.03408
Epoch 12/25
Epoch 00012: val_loss improved from 

<tensorflow.python.keras.callbacks.History at 0x7f123ed42ed0>

In [46]:
model.load_weights(JV_MODEL_SAVEPATH)

In [47]:
# val acc
preds, labels = utils.utils.compute_preds(
    model,
    hyper_val_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 13/13 [00:03<00:00,  3.93it/s]


0.6691729323308271

In [48]:
# test acc
preds, labels = utils.utils.compute_preds(
    model,
    test_ds,
    batch_size=BATCH_SIZE,
)
(np.argmax(preds, axis=1) == labels).mean()

100%|██████████| 100/100 [00:21<00:00,  4.60it/s]


0.7003451521807342