In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm_notebook as tqdm

from dyn_fed.data.mnist import MNist
import dyn_fed as df

%load_ext autoreload
%autoreload
%config Completer.use_jedi=False

In [2]:
tf.executing_eagerly()

True

In [3]:
X_train, y_train, X_test, y_test = df.data.mnist.load_data()

In [4]:
X_train.shape

(60000, 28, 28)

In [22]:
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

Non customizable training loop way

In [23]:
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    # tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['sparse_categorical_accuracy'])

In [24]:
%%time
model.fit(train_dataset, epochs=10)

Train for 469 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
CPU times: user 11.1 s, sys: 1.81 s, total: 12.9 s
Wall time: 6.71 s


<tensorflow.python.keras.callbacks.History at 0x120a03be0>

In [25]:
model.evaluate(test_dataset)



[1.7410907488835008, 0.8446]

Customizable training loop

In [26]:
# Define logistic regression model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    # tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

# Define optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Define loss function
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

epoch_loss_avg = tf.keras.metrics.Mean()
epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

In [27]:
@tf.function
def train_loop(x, y):

    # Calculate gradients
    with tf.GradientTape() as t:
        # training=training is needed only if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(x, training=True)
        loss = loss_func(y, predictions)

    grads = t.gradient(loss, model.trainable_variables)

    # Optimize the model
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    # Track progress
    epoch_loss_avg(loss)

    # Compare predicted label to actual
    epoch_accuracy.update_state(y, predictions)
    
#     return loss, predictions

In [28]:
%%time
train_loss_results = []
train_accuracy_results = []
epochs = 50
n_batches = len(list(train_dataset))

for epoch in tqdm(np.arange(epochs)):
    
    for x, y in tqdm(train_dataset, total=n_batches):
        train_loop(x, y)

    # End epoch
    train_loss_results.append(epoch_loss_avg.result())
    train_accuracy_results.append(epoch_accuracy.result())
    
    print(
        "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
            epoch,
            epoch_loss_avg.result(),
            epoch_accuracy.result()
        )
    )
    
    # Clear the current state of the metrics
    epoch_loss_avg.reset_states()
    epoch_accuracy.reset_states()
    # valid_loss.reset_states(), valid_acc.reset_states()
    
    

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 000: Loss: 2.170, Accuracy: 46.763%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 001: Loss: 2.010, Accuracy: 73.305%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 002: Loss: 1.924, Accuracy: 78.222%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 003: Loss: 1.872, Accuracy: 80.215%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 004: Loss: 1.837, Accuracy: 81.267%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 005: Loss: 1.811, Accuracy: 81.920%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 006: Loss: 1.792, Accuracy: 82.443%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 007: Loss: 1.777, Accuracy: 82.923%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 008: Loss: 1.764, Accuracy: 83.208%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 009: Loss: 1.754, Accuracy: 83.467%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 010: Loss: 1.745, Accuracy: 83.707%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 011: Loss: 1.737, Accuracy: 83.972%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 012: Loss: 1.730, Accuracy: 84.188%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 013: Loss: 1.724, Accuracy: 84.402%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 014: Loss: 1.718, Accuracy: 84.575%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 015: Loss: 1.714, Accuracy: 84.697%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 016: Loss: 1.709, Accuracy: 84.863%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 017: Loss: 1.705, Accuracy: 84.998%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 018: Loss: 1.701, Accuracy: 85.133%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 019: Loss: 1.698, Accuracy: 85.263%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 020: Loss: 1.694, Accuracy: 85.377%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 021: Loss: 1.691, Accuracy: 85.523%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 022: Loss: 1.688, Accuracy: 85.625%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 023: Loss: 1.686, Accuracy: 85.713%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 024: Loss: 1.683, Accuracy: 85.790%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 025: Loss: 1.681, Accuracy: 85.870%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 026: Loss: 1.678, Accuracy: 85.985%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 027: Loss: 1.676, Accuracy: 86.052%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 028: Loss: 1.674, Accuracy: 86.135%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 029: Loss: 1.672, Accuracy: 86.198%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 030: Loss: 1.670, Accuracy: 86.280%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 031: Loss: 1.669, Accuracy: 86.348%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 032: Loss: 1.667, Accuracy: 86.402%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 033: Loss: 1.665, Accuracy: 86.473%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 034: Loss: 1.664, Accuracy: 86.512%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 035: Loss: 1.662, Accuracy: 86.553%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 036: Loss: 1.661, Accuracy: 86.592%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 037: Loss: 1.660, Accuracy: 86.613%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 038: Loss: 1.658, Accuracy: 86.648%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 039: Loss: 1.657, Accuracy: 86.695%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 040: Loss: 1.656, Accuracy: 86.748%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 041: Loss: 1.654, Accuracy: 86.807%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 042: Loss: 1.653, Accuracy: 86.842%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 043: Loss: 1.652, Accuracy: 86.895%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 044: Loss: 1.651, Accuracy: 86.922%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 045: Loss: 1.650, Accuracy: 86.950%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 046: Loss: 1.649, Accuracy: 86.973%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 047: Loss: 1.648, Accuracy: 87.012%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 048: Loss: 1.647, Accuracy: 87.047%


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 049: Loss: 1.646, Accuracy: 87.078%

CPU times: user 35.8 s, sys: 6.43 s, total: 42.2 s
Wall time: 30.3 s


## Other

In [265]:
import functools

In [269]:
def repeat(num_times):
    def decorator_repeat(func):
        @functools.wraps(func)
        def wrapper_repeat(*args, **kwargs):
            for _ in range(num_times):
                print("Something")
                value = func(*args, **kwargs)
            return value
        return wrapper_repeat
    return decorator_repeat

In [270]:
@repeat(num_times=4)
def greet(name):
    print(f"Hello {name}")

In [271]:
greet("Sashlin")

Something
Hello Sashlin
Something
Hello Sashlin
Something
Hello Sashlin
Something
Hello Sashlin


In [329]:
class config(dict):
    MARKER = object()

    def __init__(self, value=None):
        if value is None:
            pass
        elif isinstance(value, dict):
            for key in value:
                self.__setitem__(key, value[key])
        else:
            raise TypeError('expected dict')

    def __setitem__(self, key, value):
        if isinstance(value, dict) and not isinstance(value, config):
            value = config(value)
        super(config, self).__setitem__(key, value)

    def __getitem__(self, key):
        found = self.get(key, config.MARKER)
        if found is config.MARKER:
            found = config()
            super(config, self).__setitem__(key, found)
        return found

    __setattr__, __getattr__ = __setitem__, __getitem__