In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import pickle

import matplotlib.pyplot as plt
import seaborn as sns

import sys; sys.path.append('../../PGDL/sample_code_submission/')
from internal_rep.matrix_funcs import \
    get_KF_Schatten_norms, \
    compute_complexity, \
    get_df_tau, \
    evalues_from_regions, \
    get_local_rad_bound

In [4]:
def sample_dataset(x, y, n_samples, seed=None):
    if seed is not None:
        np.random.seed(seed)
    else:
        np.random.seed(int(time.time()))

    indices = np.random.choice(range(x.shape[0]), size=n_samples, replace=False)
    x = x[indices]
    y = y[indices]
    return x, y

def prepare_mnist_dataset(
    batch_size=64,
    train_sample_size=None,
    test_sample_size=None,
    seed=None,
    shuffle=True,
    shuffle_label_frac=None
):
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()#path='/mnt/ssd3/ronan/tensorflow_datasets/')
    y_train = y_train.astype('int32')
    y_test = y_test.astype('int32')

    if train_sample_size is not None:
        x_train, y_train = sample_dataset(x_train, y_train, int(train_sample_size), seed)
        if shuffle:
            shuffle_indices = np.arange(train_sample_size)
            np.random.seed(seed)
            np.random.shuffle(shuffle_indices)
            x_train = x_train[shuffle_indices]
            y_train = y_train[shuffle_indices]
        if shuffle_label_frac is not None:
            n_shuffle = int(shuffle_label_frac * train_sample_size)
            y_shuffle = y_train[:n_shuffle]
            np.random.seed(seed)
            np.random.shuffle(y_shuffle)
            y_train[:n_shuffle] = y_shuffle

    if test_sample_size is not None:
        x_test, y_test = sample_dataset(x_test, y_test, int(test_sample_size), seed)

    x_train, x_test = tf.cast(x_train, tf.float32) / 255.0, tf.cast(x_test, tf.float32) / 255.0
    print(f"x_train.shape={x_train.shape} y_train={y_train.shape} "
                 f"x_test.shape={x_test.shape} y_test={y_test.shape}")

    shuffle_buffer = 1000
    prefetch_buffer = 1000
    dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    # dataset = dataset.shuffle(shuffle_buffer)
    dataset_train = dataset_train.prefetch(prefetch_buffer).batch(batch_size)
    dataset_test = dataset_test.prefetch(prefetch_buffer).batch(batch_size)
    
    return dataset_train, dataset_test # , (x_train, y_train), (x_test, y_test)

In [5]:
def fit_model(model, n_epochs, optimizer, loss_fn, metric, metric_dict={}, save_path=None, epoch_save=5):
    @tf.function
    def _train_step(x, y):
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            loss_value = loss_fn(y, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric
        metric.update_state(y, logits)
        return loss_value

    # @tf.function
    def _test_step(x, y):
        logits = model(x, training=False)
        # Update val metrics
        metric.update_state(y, logits)
        return logits, y

    model_results = defaultdict(list)
    model_results['epochs'] = n_epochs
    
    train_eval_mat = []
    test_eval_mat = []
    
    for epoch in range(n_epochs):
        losses = []
        for step, (x_batch, y_batch) in enumerate(ds_train):
            loss_value = _train_step(x_batch, y_batch)
            losses.append(loss_value)
        model_results['train_loss'].append(np.mean(losses))
            
        # Display metrics at the end of each epoch.
        train_acc = metric.result()
        metric.reset_states()
        
        model_results['train_accuracy'].append(train_acc)
        
        # Run a validation loop at the end of each epoch.
        logit_list = []
        y_list = []
        for x_batch, y_batch in ds_test:
            logits, y = _test_step(x_batch, y_batch)
            logit_list.append(logits)
            y_list.append(y)
        y_list = tf.concat(y_list, axis=0)
        logit_list = tf.concat(logit_list, axis=0)
        model_results['test_ece'].append(tfp.stats.expected_calibration_error(10, logit_list, y_list))
        
            
        test_acc = metric.result()
        metric.reset_states()
        model_results['test_accuracy'].append(test_acc)
        
        # Get the TRAIN complexity at the end of each epoch
        for label, ds in zip(('train', 'test'), (ds_train, ds_test)):
            internal_rep = []
            for x_batch, y_batch in ds:
                for layer in model.layers[:-1]:
                    x_batch = layer(x_batch)
                internal_rep.append((x_batch.numpy() > 0).astype('bool'))
            internal_rep = np.vstack(internal_rep)
            evalues = evalues_from_regions(internal_rep)
            if epoch % epoch_save == 0:
                if label == 'train':
                    train_eval_mat.append(evalues)
                elif label == 'test':
                    test_eval_mat.append(evalues)
                else:
                    raise ValueError()
                
            h_star, h_argmin = get_local_rad_bound(evalues, from_evalues=True)
            model_results[f'h*_{label}'].append(h_star)
            model_results[f'h_argmin_{label}'].append(h_argmin)
            model_results[f'n_activated_regions_{label}'].append(sum(evalues > 0))
        
        # Get the TEST complexity at the end of each epoch

        
        if epoch % epoch_save == 0:
            print(f"Epoch {epoch}: Training acc={train_acc:.3f}, Validation acc={test_acc:.3f}")
        
    if save_path is not None:
        np.save(save_path + '_train_evalues.npy', train_eval_mat)
        np.save(save_path + '_test_evalues.npy', test_eval_mat)
        with open(save_path + '_results_dict.pkl', 'wb') as f:
            pickle.dump(model_results, f, protocol=pickle.HIGHEST_PROTOCOL)
        
    return model, model_results

## Run and save models, varying width

In [7]:
n_train_sample = 1000
n_units_list = [4, 8, 16, 32, 48, 64, 80, 100]
n_epochs = 200

for n_units in n_units_list:
    total_params = (28 * 28 + 1) * n_units + (n_units + 1) * 10
    
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),#input_shape=(28, 28)),
        tf.keras.layers.Dense(n_units, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    
    # model = tf.keras.models.clone_model(model_base)
    
    optimizer = tf.keras.optimizers.Adam(0.001)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy()
    
    ds_train, ds_test = prepare_mnist_dataset(
        train_sample_size=n_train_sample, test_sample_size=10000,
        seed=0)
    
    model, model_results = fit_model(
        model, n_epochs, optimizer, loss_fn, metric,
        save_path=f'./width_results/relu_net_units={n_units}')

x_train.shape=(1000, 28, 28) y_train=(1000,) x_test.shape=(10000, 28, 28) y_test=(10000,)
Epoch 0: Training acc=0.150, Validation acc=0.203
Epoch 5: Training acc=0.427, Validation acc=0.416
x_train.shape=(1000, 28, 28) y_train=(1000,) x_test.shape=(10000, 28, 28) y_test=(10000,)
Epoch 0: Training acc=0.170, Validation acc=0.283
Epoch 5: Training acc=0.560, Validation acc=0.627
x_train.shape=(1000, 28, 28) y_train=(1000,) x_test.shape=(10000, 28, 28) y_test=(10000,)


KeyboardInterrupt: 