# Introduction

The purpose of this notebook is to explore the idea of dataset distillation as presented in this [Google AI Blogpost](https://ai.googleblog.com/2021/12/training-machine-learning-models-more.html). We will build a simple CNN and compare its accuracy and training times on:
1. the whole dataset (support size=33600)
2. a tiny subset (support size=10)
3. a tiny distilled subset (support size=10)
4. a small subset (support size=500)
5. a small distilled subset (support size=500)

TL;DR: A summary of the results is at the end.

# Libraries

In [None]:
# Core
import numpy as np
np.random.seed(0)
import pandas as pd
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.4)
import matplotlib.pyplot as plt
%matplotlib inline
import time

# Sklearn
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score

# Tensorflow
import tensorflow as tf
tf.random.set_seed(0)
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import callbacks

# Data

In [None]:
# Load data
mnist_data=pd.read_csv('../input/digit-recognizer/train.csv')

# Print shape + head
print('Training dataframe dimensions:',mnist_data.shape)
mnist_data.head()

**Preview first few images**

In [None]:
# Figure size
plt.figure(figsize=(9,9))
plt.suptitle('Training set images', fontsize=20)

# Subplot 
for i in range(9):
    img = np.asarray(mnist_data.iloc[i,1:].values.reshape((28,28))/255);
    ax=plt.subplot(3, 3, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.title.set_text(f'{mnist_data.iloc[i,0]}')
    plt.imshow(img, cmap='gray')
    
plt.show()

**Label distribution**

In [None]:
# Figure size
plt.figure(figsize=(10,5))

# Countplot
sns.countplot(x='label', data=mnist_data)
plt.title('Distribution of digits in dataset')

The distribution is balanced.
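
The visual impression can be backed by a number. The sketch below is one possible way to quantify balance: it computes each class's share of the data and the ratio between the largest and smallest class. The helper name `class_balance` and the toy labels are illustrative and not part of the original notebook; in the notebook one would pass `mnist_data.label` instead.

```python
import numpy as np

def class_balance(labels):
    """Return (classes, shares, ratio) for a 1-D collection of class labels."""
    classes, counts = np.unique(np.asarray(labels), return_counts=True)
    shares = counts / counts.sum()        # fraction of samples per class
    ratio = counts.max() / counts.min()   # 1.0 means perfectly balanced
    return classes, shares, ratio

# Toy labels for illustration only (not MNIST)
classes, shares, ratio = class_balance([0, 0, 1, 1, 2, 2, 2])
print(classes, np.round(shares, 3), ratio)
```

A ratio close to 1 supports calling the distribution balanced; a much larger ratio would suggest stratified sampling or class weights.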

# Preprocessing

**Labels and features**

In [None]:
# Labels
y=mnist_data.label

# Scale features to be in [0,1]
X=mnist_data.drop('label', axis=1)/255

# Reshape (-1 means unspecified)
X = X.values.reshape(-1, 28, 28, 1)

# One-hot encode target
y=pd.get_dummies(y)

**Train test split**

In [None]:
# Create a validation set
X_train, X_valid, y_train, y_valid = train_test_split(X,y,train_size=0.8, test_size=0.2,
                                                      stratify=y, random_state=0)

# Model

In [None]:
# Parameters
BATCH_SIZE=250
EPOCHS=50

# Define model
def build_model():
    model = keras.Sequential([

        # Convolutional layer 1
        layers.Conv2D(filters=64, kernel_size=5, strides=1, padding='same',
                      input_shape=[28,28,1], activation='relu'),
        layers.MaxPool2D(pool_size=2, padding='same'),
        layers.Dropout(rate=0.4),

        # Convolutional layer 2
        layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='same',
                      activation='relu'),
        layers.MaxPool2D(pool_size=4, padding='same'),
        layers.Dropout(rate=0.4),
        layers.Flatten(),

        # Hidden layer 3
        layers.Dense(units=256, activation='relu'),
        layers.Dropout(rate=0.4),

        # Output layer (softmax returns a probability distribution)
        layers.Dense(units=10, activation='softmax')
    ])

    # Define optimizer, loss function and accuracy metric
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])
    
    return model
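
As a sanity check on the architecture above, the layer shapes and parameter count can be worked out by hand. The sketch below assumes the standard Keras conventions that `padding='same'` gives an output size of `ceil(input / stride)` and that a pooling layer's stride defaults to its `pool_size`; the helper `same_out` is a hypothetical name. The result can be cross-checked against `build_model().summary()`.

```python
import math

def same_out(size, stride):
    """Spatial output size of a 'same'-padded conv or pooling layer."""
    return math.ceil(size / stride)

size, channels, params = 28, 1, 0

# Conv2D(64, kernel 5, stride 1) followed by MaxPool2D(2)
params += 5 * 5 * channels * 64 + 64      # weights + biases
size, channels = same_out(size, 1), 64
size = same_out(size, 2)                  # 28 -> 14

# Conv2D(128, kernel 3, stride 1) followed by MaxPool2D(4)
params += 3 * 3 * channels * 128 + 128
size, channels = same_out(size, 1), 128
size = same_out(size, 4)                  # 14 -> 4

# Flatten, Dense(256), Dense(10); Dropout layers add no parameters
flat = size * size * channels
params += flat * 256 + 256
params += 256 * 10 + 10

print(flat, params)  # 2048 602634
```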

# Training

In [None]:
# Define model
model=build_model()

# Measure training time
start=time.time()

# Train model
history = model.fit(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True)

stop=time.time()

**Learning curves**

In [None]:
# Convert to dataframe
history_df = pd.DataFrame(history.history)

# Subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,4))

# Plot loss metric
plt.subplot(1,2,1)
ax=history_df.reset_index().loc[:, ['loss', 'val_loss']].plot(title="Cross-entropy", ax=axes[0])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

# Plot accuracy metric
plt.subplot(1,2,2)
ax=history_df.reset_index().loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Accuracy", ax=axes[1])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

**Evaluate**

In [None]:
# Predicted class probabilities (computed once and reused)
probs=model.predict(X_valid)

# Predictions
preds=np.argmax(probs, axis=1)

# Confidence
conf=np.max(probs, axis=1)

# Final accuracy and time
score1=accuracy_score(np.argmax(y_valid.values, axis=1), preds)
time1=np.round(stop-start,1)
print(f'Final accuracy on validation set:{np.round(100*score1,1)}%')
print(f'Training time: {time1} secs')

**Plot predictions**

In [None]:
# Plot some model predictions
plt.figure(figsize=(15,4))
plt.suptitle('Model predictions', fontsize=20, y=1.05)

# Subplot
for i in range(20):
    img = X_valid[i];
    ax=plt.subplot(2, 10, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f'Pred:{preds[i]} \n Conf:{np.round(100*conf[i],1)}', fontdict = {'fontsize':14})
    plt.imshow(img, cmap='gray')
    
plt.show()

# Dataset distillation (support size=10)

The idea is that, instead of optimising the network, we optimise the dataset: we shrink it as much as possible whilst keeping the key information. A much smaller training set could massively reduce training times and hence energy demand and GPU time.

Blog: [https://ai.googleblog.com/2021/12/training-machine-learning-models-more.html](https://ai.googleblog.com/2021/12/training-machine-learning-models-more.html).

**Baseline**

We choose one sample for each digit. No distillation has occurred yet.

In [None]:
# Training set with only 1 sample for each digit
mnist_10=mnist_data.drop_duplicates('label').sort_values('label')

# Plot entire training set
plt.figure(figsize=(15,6))
plt.suptitle('Entire training set', fontsize=20, y=1.02)

# Subplot
for i in range(10):
    img = np.asarray(mnist_10.iloc[i,1:].values.reshape((28,28))/255);
    ax=plt.subplot(2, 5, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.title.set_text(f'{mnist_10.iloc[i,0]}')
    plt.imshow(img, cmap='gray')
    
plt.show()

**Train model**

In [None]:
# Define model
model=build_model()

# Measure training time
start=time.time()

# Train model
history = model.fit(
    # Scale pixels to [0,1] so the training data matches the scaling of X_valid
    mnist_10.iloc[:,1:].values.reshape(-1, 28, 28, 1)/255, pd.get_dummies(mnist_10.label),
    validation_data=(X_valid, y_valid),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True)

stop = time.time()

**Learning curves**

In [None]:
# Convert to dataframe
history_df = pd.DataFrame(history.history)

# Subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,4))

# Plot loss metric
plt.subplot(1,2,1)
ax=history_df.reset_index().loc[:, ['loss', 'val_loss']].plot(title="Cross-entropy", ax=axes[0])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

# Plot accuracy metric
plt.subplot(1,2,2)
ax=history_df.reset_index().loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Accuracy", ax=axes[1])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

**Evaluate**

In [None]:
# Predicted class probabilities (computed once and reused)
probs=model.predict(X_valid)

# Predictions
preds=np.argmax(probs, axis=1)

# Confidence
conf=np.max(probs, axis=1)

# Final accuracy and time
score2=accuracy_score(np.argmax(y_valid.values, axis=1), preds)
time2=np.round(stop-start,1)
print(f'Final accuracy on validation set:{np.round(100*score2,1)}%')
print(f'Training time: {time2} secs')

**Plot predictions**

In [None]:
# Plot some model predictions
plt.figure(figsize=(15,4))
plt.suptitle('Model predictions', fontsize=20, y=1.05)

# Subplot
for i in range(20):
    img = X_valid[i];
    ax=plt.subplot(2, 10, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f'Pred:{preds[i]} \n Conf:{np.round(100*conf[i],1)}', fontdict = {'fontsize':14})
    plt.imshow(img, cmap='gray')
    
plt.show()

**Distillation - KIP**

The code below is based on the following GitHub repo.

GitHub: [https://github.com/google-research/google-research/tree/master/kip](https://github.com/google-research/google-research/tree/master/kip).

**Plot distillation**

These images have been created using an algorithm called KIP (Kernel Inducing Points). Broadly speaking, it works by optimising a loss function arising from Kernel Regression.
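
To make "a loss function arising from kernel regression" more concrete, here is a minimal sketch of the kind of objective involved. A small support set `(x_support, y_support)` is used to fit kernel ridge regression, and the loss measures how well that fit predicts the labels of real data `(x_target, y_target)`: half the squared norm of `y_target - K_ts (K_ss + reg*I)^-1 y_support`. Several things here are assumptions for illustration: the RBF kernel stands in for the neural-network kernels used in the KIP repo, the data is random toy data, the regularisation value is arbitrary, and all function names are hypothetical. The sketch only evaluates the loss; KIP would then update the support images (and optionally labels) by gradient descent on it.

```python
import numpy as np

def rbf_kernel(a, b, gamma=0.1):
    """RBF kernel matrix between the rows of a and the rows of b (stand-in kernel)."""
    sq_dists = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2 * a @ b.T
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))

def kip_loss(x_support, y_support, x_target, y_target, reg=1e-6):
    """Kernel ridge regression loss of the support set evaluated on the target set."""
    k_ss = rbf_kernel(x_support, x_support)
    k_ts = rbf_kernel(x_target, x_support)
    # Solve (K_ss + reg*I) alpha = y_support instead of forming the inverse
    alpha = np.linalg.solve(k_ss + reg * np.eye(len(x_support)), y_support)
    residual = y_target - k_ts @ alpha
    return 0.5 * np.sum(residual**2)

rng = np.random.default_rng(0)
x_t = rng.normal(size=(100, 5))                   # toy "real" inputs
y_t = np.eye(2)[(x_t[:, 0] > 0).astype(int)]      # toy one-hot labels
x_s, y_s = x_t[:10].copy(), y_t[:10].copy()       # toy support set of 10 points

print(kip_loss(x_s, y_s, x_t, y_t))
```

Lower values mean the ten support points summarise the target data better under the chosen kernel.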

In [None]:
# Load KIP data (1 image per class)
# np.load reads .npz members lazily, so materialise the arrays while the file is open
with tf.io.gfile.GFile('gs://kip-datasets/kip/mnist/ConvNet_ssize10_nozca_l_noaug_ckpt78.npz', 'rb') as f:
    npz = dict(np.load(f))

# Linear projection of each image onto [0,1]
npz_normalised=npz['images'].copy()
for i in range(len(npz_normalised)):
    img=npz['images'][i]
    npz_normalised[i] = (img-img.min())/(img.max()-img.min())

# Plot entire training set
plt.figure(figsize=(15,6))
plt.suptitle('Entire distilled training set', fontsize=20, y=1.02)

# Subplot
for i in range(10):
    img = npz_normalised[i];
    ax=plt.subplot(2, 5, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.title.set_text(f'{i}')
    plt.imshow(img, cmap='gray')
    
plt.show()

**Plot labels**

Labels can be optimised alongside the images to maximise performance at no extra training cost.
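
Learned labels are real-valued vectors rather than one-hot rows, so it can be useful to check how often their largest entry still points at the original class. The sketch below shows one way to do that; `labels_agree` is a hypothetical helper and the 3-class matrix is made-up illustration data. In this notebook the call would be `labels_agree(npz['labels'], np.arange(10))`.

```python
import numpy as np

def labels_agree(learned_labels, hard_labels):
    """Fraction of rows whose largest learned entry matches the hard class index."""
    predicted = np.argmax(learned_labels, axis=1)
    return float(np.mean(predicted == np.asarray(hard_labels)))

# Made-up learned labels for 3 classes (rows need not be non-negative or sum to 1)
toy_labels = np.array([[ 0.9, -0.1, 0.0],
                       [-0.2,  0.8, 0.1],
                       [ 0.1,  0.3, 0.2]])
print(labels_agree(toy_labels, [0, 1, 2]))
```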

In [None]:
plt.figure(figsize=(14,5.5))

ax1=plt.subplot(1, 2, 1)
ax1.title.set_text('Untrained labels')
sns.heatmap(pd.get_dummies(np.arange(10)), cmap='magma')

ax2=plt.subplot(1, 2, 2)
ax2.title.set_text('Trained labels (unscaled)')
sns.heatmap(npz['labels'], cmap='magma')

**Train model**

In [None]:
# Define model
model=build_model()

# Measure training time
start=time.time()

# Train model
history = model.fit(
    # One-hot labels are used here; the learned labels npz['labels'] are left unused
    npz_normalised, pd.get_dummies(np.arange(10)),
    validation_data=(X_valid, y_valid),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True)

stop = time.time()

**Learning curves**

In [None]:
# Convert to dataframe
history_df = pd.DataFrame(history.history)

# Subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,4))

# Plot loss metric
plt.subplot(1,2,1)
ax=history_df.reset_index().loc[:, ['loss', 'val_loss']].plot(title="Cross-entropy", ax=axes[0])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

# Plot accuracy metric
plt.subplot(1,2,2)
ax=history_df.reset_index().loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Accuracy", ax=axes[1])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

**Evaluation**

In [None]:
# Predicted class probabilities (computed once and reused)
probs=model.predict(X_valid)

# Predictions
preds=np.argmax(probs, axis=1)

# Confidence
conf=np.max(probs, axis=1)

# Final accuracy and time
score3=accuracy_score(np.argmax(y_valid.values, axis=1), preds)
time3=np.round(stop-start,1)
print(f'Final accuracy on validation set:{np.round(100*score3,1)}%')
print(f'Training time: {time3} secs')

**Plot predictions**

In [None]:
# Plot some model predictions
plt.figure(figsize=(15,4))
plt.suptitle('Model predictions', fontsize=20, y=1.05)

# Subplot
for i in range(20):
    img = X_valid[i];
    ax=plt.subplot(2, 10, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f'Pred:{preds[i]} \n Conf:{np.round(100*conf[i],1)}', fontdict = {'fontsize':14})
    plt.imshow(img, cmap='gray')
    
plt.show()

# Dataset distillation (support size=500)

This time each of the 10 classes gets 50 training samples. 

**Baseline**

No dataset distillation happens here.

In [None]:
# Training set with the first 50 samples of each digit
# (pd.concat replaces DataFrame.append, which was removed in pandas 2.0)
mnist_500=pd.concat(
    [mnist_data[mnist_data.label==i].iloc[:50,:] for i in range(10)],
    ignore_index=True)

**Train model**

In [None]:
# Define model
model=build_model()

# Measure training time
start=time.time()

# Train model
history = model.fit(
    # Scale pixels to [0,1] so the training data matches the scaling of X_valid
    mnist_500.iloc[:,1:].values.reshape(-1, 28, 28, 1)/255, pd.get_dummies(mnist_500.label),
    validation_data=(X_valid, y_valid),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True)

stop = time.time()

**Learning curves**

In [None]:
# Convert to dataframe
history_df = pd.DataFrame(history.history)

# Subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,4))

# Plot loss metric
plt.subplot(1,2,1)
ax=history_df.reset_index().loc[:, ['loss', 'val_loss']].plot(title="Cross-entropy", ax=axes[0])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

# Plot accuracy metric
plt.subplot(1,2,2)
ax=history_df.reset_index().loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Accuracy", ax=axes[1])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

**Evaluation**

In [None]:
# Predicted class probabilities (computed once and reused)
probs=model.predict(X_valid)

# Predictions
preds=np.argmax(probs, axis=1)

# Confidence
conf=np.max(probs, axis=1)

# Final accuracy and time
score4=accuracy_score(np.argmax(y_valid.values, axis=1), preds)
time4=np.round(stop-start,1)
print(f'Final accuracy on validation set:{np.round(100*score4,1)}%')
print(f'Training time: {time4} secs')

**Plot predictions**

In [None]:
# Plot some model predictions
plt.figure(figsize=(15,4))
plt.suptitle('Model predictions', fontsize=20, y=1.05)

# Subplot
for i in range(20):
    img = X_valid[i];
    ax=plt.subplot(2, 10, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f'Pred:{preds[i]} \n Conf:{np.round(100*conf[i],1)}', fontdict = {'fontsize':14})
    plt.imshow(img, cmap='gray')
    
plt.show()

**Distillation - KIP**

In [None]:
# Load KIP data (50 images per class)
# np.load reads .npz members lazily, so materialise the arrays while the file is open
with tf.io.gfile.GFile('gs://kip-datasets/kip/mnist/ConvNet_ssize500_nozca_l_noaug_ckpt78.npz', 'rb') as f:
    npz2 = dict(np.load(f))

# Linear projection of each of the 500 images onto [0,1]
npz_normalised2=npz2['images'].copy()
for i in range(len(npz_normalised2)):
    img=npz2['images'][i]
    npz_normalised2[i] = (img-img.min())/(img.max()-img.min())

# Construct (untrained) labels, assuming the images are stored in class order, 50 per class
labels_500=np.repeat(np.arange(10), 50)

**Train model**

In [None]:
# Define model
model=build_model()

# Measure training time
start=time.time()

# Train model
history = model.fit(
    npz_normalised2, pd.get_dummies(labels_500),
    validation_data=(X_valid, y_valid),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=True)

stop = time.time()

**Learning curves**

In [None]:
# Convert to dataframe
history_df = pd.DataFrame(history.history)

# Subplots
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,4))

# Plot loss metric
plt.subplot(1,2,1)
ax=history_df.reset_index().loc[:, ['loss', 'val_loss']].plot(title="Cross-entropy", ax=axes[0])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

# Plot accuracy metric
plt.subplot(1,2,2)
ax=history_df.reset_index().loc[:, ['categorical_accuracy', 'val_categorical_accuracy']].plot(title="Accuracy", ax=axes[1])
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'])

**Evaluation**

In [None]:
# Predicted class probabilities (computed once and reused)
probs=model.predict(X_valid)

# Predictions
preds=np.argmax(probs, axis=1)

# Confidence
conf=np.max(probs, axis=1)

# Final accuracy and time
score5=accuracy_score(np.argmax(y_valid.values, axis=1), preds)
time5=np.round(stop-start,1)
print(f'Final accuracy on validation set:{np.round(100*score5,1)}%')
print(f'Training time: {time5} secs')

**Plot predictions**

In [None]:
# Plot some model predictions
plt.figure(figsize=(15,4))
plt.suptitle('Model predictions', fontsize=20, y=1.05)

# Subplot
for i in range(20):
    img = X_valid[i];
    ax=plt.subplot(2, 10, i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title(f'Pred:{preds[i]} \n Conf:{np.round(100*conf[i],1)}', fontdict = {'fontsize':14})
    plt.imshow(img, cmap='gray')
    
plt.show()

# Conclusion

In [None]:
# Summary of results
ssize=[33600, 10, 10, 500, 500]
distillation=['No','No','Yes','No','Yes']
scores=np.round(100*np.array([score1, score2, score3, score4, score5]),1)
times=[time1, time2, time3, time4, time5]

# Dataframe
results=pd.DataFrame({'Support size': ssize, 'Distillation': distillation, 'Training time (s)': times, 'Accuracy (%)': scores})
results

* Training on a distilled subset instead of the whole dataset **reduced the training time by roughly a factor of 10** whilst retaining a **relatively high accuracy**.
* The learning curves suggest that distillation helps prevent overfitting.
* These results are based on a simple CNN architecture. Higher accuracies can be achieved with the right model - see the [blog](https://ai.googleblog.com/2021/12/training-machine-learning-models-more.html).
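
The speed-up quoted above can be read directly off the timings. The helper below is a small sketch of that arithmetic; `speedup_vs_baseline` is a hypothetical name, and the numbers in the example call are made up for illustration rather than taken from this notebook. With the table above, one could add a column via `results['Speed-up'] = speedup_vs_baseline(results['Training time (s)'])`.

```python
import numpy as np

def speedup_vs_baseline(times, baseline_index=0):
    """Return how many times faster each run is than the baseline run."""
    times = np.asarray(times, dtype=float)
    return times[baseline_index] / times

# Made-up timings in seconds (not results from this notebook)
print(speedup_vs_baseline([100.0, 10.0, 20.0]))
```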