In [3]:
import tensorflow as tf
import numpy as np
import pickle

In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

num_workers = 20
n_train = len(y_train)
n_test = len(y_test)

idxs = np.random.permutation(n_train)
batch_idxs = np.array_split(idxs, num_workers)   

In [3]:
from tensorflow.keras import datasets, layers, models

def create_mlp():
  model = models.Sequential()
  model.add(tf.keras.Input(shape=(28,28)))
  model.add(layers.Flatten())
  model.add(layers.Dense(100, activation='relu'))
  model.add(layers.Dense(10, activation='softmax'))

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(0.01),
      metrics=['accuracy']
  ) 

  return model 

In [4]:
model = create_mlp()
model.summary()

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 100)               78500     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010      
Total params: 79,510
Trainable params: 79,510
Non-trainable params: 0
_________________________________________________________________


In [None]:
for j in range(num_workers):
  model = create_mlp()
  train_features, train_labels = x_train[batch_idxs[j]], y_train[batch_idxs[j]]
  model.fit(train_features, train_labels, batch_size=32, epochs=9, validation_data=(x_test, y_test))

  Ws = model.get_weights()
  with open("mni_workernn_{}".format(j), 'wb+') as f:
    pickle.dump(Ws, f)

In [6]:
for j in range(num_workers):
    model = create_mlp()
    with open("mni_workernn_{}".format(j), 'rb') as f:
        ws = pickle.load(f)
    model.set_weights(ws)
    print(model.evaluate(x_test, y_test))

[0.4862384514331818, 0.8803]
[0.5021830556154251, 0.8758]
[0.4963450841426849, 0.8715]
[0.4986572093963623, 0.8705]
[0.49069374022483825, 0.8736]
[0.4976419484376907, 0.8769]
[0.48803191804885865, 0.8817]
[0.5078792052268982, 0.8686]
[0.47933087840080263, 0.8778]
[0.4905010535001755, 0.8759]
[0.48608291573524476, 0.8768]
[0.501139972114563, 0.8701]
[0.48320589809417724, 0.874]
[0.4853775868654251, 0.8739]
[0.49945286605358125, 0.869]
[0.4944356895685196, 0.8775]
[0.49078040227890013, 0.875]
[0.48326182775497434, 0.8774]
[0.4935495052576065, 0.8725]
[0.487951301074028, 0.8764]


# for cifa10

In [4]:
cifar10 = tf.keras.datasets.cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize pixel values to be between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0



In [5]:
from tensorflow.keras import datasets, layers, models

def create_mlp():
    model = models.Sequential()
    # model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    # model.add(layers.MaxPooling2D((2, 2)))
    # model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    # model.add(layers.MaxPooling2D((2, 2)))
    # model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(1536, activation='relu'))
    model.add(layers.Dense(384, activation='relu'))
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))

    model.compile(optimizer=tf.keras.optimizers.SGD(0.001, momentum=0.9, nesterov=True),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    
    return model

In [24]:
history = model.fit(x_train, y_train, epochs=10, batch_size=64,
                    validation_data=(x_test, y_test))

Train on 50000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [25]:
model.evaluate(x_test, y_test, verbose=1)



[1.3744550441741943, 0.5072]

In [6]:
num_workers = 40
n_train = len(y_train)
n_test = len(y_test)

idxs = np.random.permutation(n_train)
batch_idxs = np.array_split(idxs, num_workers)   

for j in range(num_workers):
  model = create_mlp()
  train_features, train_labels = x_train[batch_idxs[j]], y_train[batch_idxs[j]]
  model.fit(train_features, train_labels, batch_size=64, epochs=10, validation_data=(x_test, y_test), verbose=0)

  Ws = model.get_weights()
  with open("cifar10_{}.pb".format(j), 'wb+') as f:
    pickle.dump(Ws, f)    

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [10]:
maps = {i: dataidx for i, dataidx in enumerate(batch_idxs)}

def saved_cls_counts(net_dataidx_map):
    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp
        
    with open('cifar10_counts', 'wb+') as f:
        pickle.dump(net_cls_counts, f)

saved_cls_counts(maps)

{0: array([ 4219, 42286,  6642, ..., 48874, 48800, 48710]), 1: array([15990, 14719,  7778, ..., 10015, 34386, 45670]), 2: array([32013, 19734, 25504, ..., 41383,  7567, 10785]), 3: array([23218, 46749, 36137, ..., 12351, 41417, 33904]), 4: array([46143, 13257, 12382, ...,  2293, 18944, 45370]), 5: array([30980, 10077, 41684, ..., 22458, 36590, 41647]), 6: array([   75, 11978, 39749, ..., 19506, 14970, 41265]), 7: array([42988, 38910,  1386, ..., 26642, 12346,  3922]), 8: array([24392, 32694, 13181, ..., 14172,    91, 22183]), 9: array([33608, 49368, 46530, ..., 22613, 20269, 48593]), 10: array([49001, 25241, 41362, ...,  5497, 37134,   412]), 11: array([22660, 23078, 17909, ...,  8835, 34090, 34423]), 12: array([11248, 48604, 17840, ..., 32441, 23911, 42153]), 13: array([22284, 29144,  7871, ..., 28035, 39055, 45679]), 14: array([11273, 25669, 22939, ..., 12365, 33458, 26230]), 15: array([15090, 19619, 39679, ...,  9264, 47016, 23121]), 16: array([32600,  5187, 15626, ..., 32022, 39481