-
Notifications
You must be signed in to change notification settings - Fork 0
/
generateMNISTColabModel.py
116 lines (93 loc) · 3.02 KB
/
generateMNISTColabModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import tensorflow as tf
import pickle
import sys
import colabModel
# Name of the experiment to run; it must be a key of BOTH dicts below.
config_name = "5_non_overlap_parties"

# For each experiment, the MNIST digit labels held by each party
# (outer list index = party index). Digit sets may overlap between parties.
party_digits_config = {
    "5_non_overlap_parties": [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
    "5_parties_no_replication": [[0], [1, 2], [0, 1, 2, 3], [3, 4, 5], [6, 7, 8, 9]],
    "5_parties_no_replication_1all": [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2],
        [0, 1, 2, 3],
        [3, 4, 5],
        [6, 7, 8, 9],
    ],
    "6_parties_replication_1": [
        [0],
        [1, 2],
        [1, 2],
        [0, 1, 2, 3],
        [3, 4, 5],
        [6, 7, 8, 9],
    ],
    "6_parties_replication_4": [
        [0],
        [1, 2],
        [0, 1, 2, 3],
        [3, 4, 5],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
    ],
}

# For each experiment, groups of party indices that hold identical digit sets
# (replicated parties). Saved alongside the results at the end of the script.
replicated_party_idxs = {
    "5_non_overlap_parties": [],
    "5_parties_no_replication": [],
    "5_parties_no_replication_1all": [],
    "6_parties_replication_1": [[1, 2]],
    "6_parties_replication_4": [[4, 5]],
}

# Validate the selected experiment up front: otherwise a missing
# replicated_party_idxs entry only surfaces when results are saved, after
# every coalition has already been trained.
if config_name not in party_digits_config or config_name not in replicated_party_idxs:
    raise KeyError(
        "config_name {!r} must be a key of both party_digits_config and "
        "replicated_party_idxs".format(config_name)
    )
party_digits = party_digits_config[config_name]
for _group in replicated_party_idxs[config_name]:
    # Every party listed in a replicated group must really hold the same digits.
    if len({tuple(sorted(party_digits[_idx])) for _idx in _group}) > 1:
        raise ValueError(
            "replicated parties {} in config {!r} do not hold identical "
            "digits".format(_group, config_name)
        )
print("config_name:", config_name)
print(party_digits)
def mnist_parties_data_split(x_train, y_train, *, digit_groups=None):
    """Split a labelled dataset into per-party subsets by digit label.

    Party ``p`` receives every sample whose label is in ``digit_groups[p]``.
    Within a party, samples are grouped digit by digit in the order the digits
    are listed, and keep their original relative order within each digit.

    Args:
        x_train: Feature array, indexed along its first axis.
        y_train: Label array aligned with ``x_train`` (compared with ``==``
            against each digit).
        digit_groups: Optional list of digit lists, one per party. Defaults
            to the module-level ``party_digits`` selected by ``config_name``.

    Returns:
        ``(x_trains, y_trains)``: two lists holding one array per party, in
        party order. A party with no digits gets empty arrays.
    """
    if digit_groups is None:
        digit_groups = party_digits
    x_trains = []
    y_trains = []
    for digits in digit_groups:
        # Concatenate per-digit index runs (rather than one combined mask) so
        # the samples stay grouped by digit, in the order listed for the party.
        per_digit = [np.where(y_train == digit)[0] for digit in digits]
        if per_digit:
            idxs = np.concatenate(per_digit)
        else:
            idxs = np.empty(0, dtype=np.intp)
        x_trains.append(x_train[idxs])
        y_trains.append(y_train[idxs])
    return x_trains, y_trains
def test_mnist_parties_data_split():
    """Check the party split on the real MNIST training set.

    Loads MNIST through ``tf.keras.datasets.mnist.load_data()``, flattens each
    image to 784 features, splits the training set with
    ``mnist_parties_data_split`` and prints each party's label-array shape and
    distinct labels. It asserts that every party received exactly the digits
    configured for it in ``party_digits``.
    """
    # The test split is not needed for this check.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = np.reshape(x_train, (-1, 784))
    x_trains, y_trains = mnist_parties_data_split(x_train, y_train)
    for i, (xs, ys) in enumerate(zip(x_trains, y_trains)):
        print("Party", i)
        print(ys.shape)
        labels = np.unique(ys)
        print(labels)
        assert labels.tolist() == sorted(set(party_digits[i]))
        assert len(xs) == len(ys)
# Resolve everything needed to save results *before* any training, so a
# misconfigured experiment fails immediately. Otherwise it would only fail
# after all 2**n_party - 1 coalitions have been trained.
replicated = replicated_party_idxs[config_name]
output_path = "mnist_training/mnist_v_{}.pkl".format(config_name)

# Model wrapper from colabModel; it is handed the party split function above
# and the MLP hyper-parameters below. NOTE(review): exact training semantics
# live in colabModel.py.
model = colabModel.ColabModel(
    dataset_name="mnist",
    party_data_split=mnist_parties_data_split,
    model_config={
        "n_epoch": 5,
        "batchsize": 64,
        "n_neurons": [64, 64],
        "activations": ["relu", "relu"],
    },
)

# Every training run starts from the same stored initial weights. The file
# was produced by pickling model.model.get_weights() with
# pickle.HIGHEST_PROTOCOL to this path. Regenerate it that way if missing.
with open("mnist_training/mnist_init_weights.pkl", "rb") as infile:
    init_weights = pickle.load(infile)

n_party = len(party_digits)
# One slot per bitmask over the parties. Mask 0 (empty coalition) is never
# trained and stays 0. Presumably model.train(mask, ...) trains on the
# parties whose bits are set in mask. TODO confirm against colabModel.
test_accs = np.zeros(1 << n_party)
for i in range(1, 1 << n_party):
    test_accs[i] = model.train(i, init_weights=init_weights)
    sys.stdout.flush()  # surface per-run progress promptly in the log

with open(output_path, "wb") as outfile:
    pickle.dump(
        {
            "party_labels": party_digits,
            "name": config_name,
            "test_acc": test_accs,
            "replicated_party_idxs": replicated,
        },
        outfile,
        protocol=pickle.HIGHEST_PROTOCOL,
    )