In [34]:
from utils.model_utils import read_data, read_client_data
import torch
import os
import torch.nn as nn
import multiprocessing
from torch.utils.data import DataLoader
import torch.nn.functional as F
from multiprocessing import Process

定义 embedding 模型

In [35]:
class EmbModel(nn.Module):
    """Single-layer linear autoencoder that yields a low-dimensional code.

    The encoder is one ``nn.Linear(input_size, embedding_size)`` with no
    activation; the decoder is one ``nn.Linear(embedding_size, input_size)``
    followed by a sigmoid, so reconstructions lie in [0, 1].

    Args:
        input_size: number of features per (flattened) input sample.
        embedding_size: width of the code produced by ``encode``.
    """

    def __init__(self, input_size, embedding_size):
        super().__init__()
        self.embedding_size = embedding_size
        # Encoder is built before decoder so parameter initialisation draws
        # from the RNG in the same order.
        self.enc_linear_1 = nn.Linear(input_size, embedding_size)
        self.dec_linear_1 = nn.Linear(embedding_size, input_size)

    def forward(self, images):
        """Return ``(reconstruction, code)`` for a batch of flattened inputs."""
        latent = self.encode(images)
        return self.decode(latent), latent

    def encode(self, code):
        """Project inputs of width ``input_size`` down to ``embedding_size``."""
        # The parameter keeps its historical name ``code`` for keyword
        # compatibility; it actually holds the raw input.
        return self.enc_linear_1(code)

    def decode(self, code):
        """Map codes back to input space, squashed into [0, 1] by a sigmoid."""
        return torch.sigmoid(self.dec_linear_1(code))

定义客户端的训练函数 `train` 及结果保存函数 `save_results`

In [41]:
def save_results(cid, epochs, model, emb, model_dir=os.path.join("saved_models", "cifia100")):
    """Persist one checkpoint: the client's embedding vector and its model.

    Appends ``emb`` as one comma-separated line to
    ``C<cid>_E<epochs>_result.txt`` in the current working directory, and
    saves the whole ``model`` object with ``torch.save`` to
    ``<model_dir>/C<cid>_E<epochs>_result.pt``.

    Args:
        cid: client identifier, used in the file names.
        epochs: epoch number, used in the file names.
        model: the object to save (pickled whole by ``torch.save``, not just
            its ``state_dict``).
        emb: iterable of numbers (the embedding vector).
        model_dir: directory for the model file; created if missing.
            Defaults to ``saved_models/cifia100`` as before.
    """
    save_name = 'C' + str(cid) + '_E' + str(epochs) + '_result'

    # Append mode: re-running adds a new line instead of overwriting.
    with open(save_name + '.txt', 'a+') as emb_fp:
        emb_fp.write(','.join(map(str, emb)) + '\n')

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model, os.path.join(model_dir, save_name + '.pt'))
    

def _mean_embedding(model, dataset, input_size, batch_size):
    """Return the mean encoder output over every sample of ``dataset`` as a (1, D) tensor.

    Works in mini-batches rather than one full-dataset batch to bound memory,
    runs in eval mode under ``torch.no_grad``, and restores the model's
    previous train/eval mode afterwards. Assumes each dataset item is an
    ``(X, y)`` pair, as in the training loop of ``train``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    was_training = model.training
    model.eval()
    total = None
    count = 0
    with torch.no_grad():
        for X, _ in loader:
            # Unpack the batch and flatten X; encode expects (N, input_size).
            codes = model.encode(X.reshape(-1, input_size))
            batch_sum = codes.sum(dim=0, keepdim=True)
            total = batch_sum if total is None else total + batch_sum
            count += codes.shape[0]
    model.train(was_training)
    return total / count


def train(cid, train_data, epochs, lr, input_size=3072, embedding_size=15,
          batch_size=128, save_every=100):
    """Train one client's ``EmbModel`` autoencoder and return its mean embedding.

    Every ``save_every`` epochs the current mean embedding and model are
    written via ``save_results``. The epoch number in those file names is the
    count of completed epochs (100, 200, ...). Previously the 0-based loop
    index (99, 199, ...) was passed, which was off by one.

    Args:
        cid: client id, used only in the saved file names.
        train_data: dataset whose items are ``(X, y)`` pairs. ``X`` is reshaped
            to ``(-1, input_size)``. The default 3072 presumably corresponds to
            flattened 32x32 RGB images; confirm against the data loader.
        epochs: number of passes over ``train_data``.
        lr: Adam learning rate.
        input_size: features per flattened sample (default 3072).
        embedding_size: width of the embedding (default 15).
        batch_size: mini-batch size for training and for computing the mean.
        save_every: checkpoint interval in epochs; 0/None disables checkpoints.

    Returns:
        Nested list ``[[e_1, ..., e_D]]`` (one row, ``D = embedding_size``): the
        mean encoder output over all of ``train_data``.

    Raises:
        ValueError: if ``train_data`` is empty.
    """
    if len(train_data) == 0:
        raise ValueError("train_data is empty; cannot compute an embedding")

    loss_fn = nn.MSELoss()
    model = EmbModel(input_size, embedding_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for X, _ in train_data_loader:
            X = X.reshape(-1, input_size)
            optimizer.zero_grad()
            out, _ = model(X)
            # Autoencoder objective: reconstruct the (flattened) input.
            train_loss = loss_fn(out, X)
            train_loss.backward()
            optimizer.step()

        completed = epoch + 1
        if save_every and completed % save_every == 0:
            emb = _mean_embedding(model, train_data, input_size, batch_size)
            save_results(cid, completed, model, emb.squeeze(0).tolist())

    return _mean_embedding(model, train_data, input_size, batch_size).tolist()

In [37]:
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # float32, same as torch.Tensor(...)
print(a.sum(dim=0))  # reducing over dim 0 gives per-column sums

tensor([4., 6.])


训练参数

In [38]:
# Experiment configuration.
total_clients = 5      # number of clients; train() builds one model per client
epochs = 500           # training epochs per client
learning_rate = 0.05   # Adam learning rate passed to train()
# NOTE(review): "cifia100" looks like a misspelling of "cifar100", but it must
# match whatever name read_data / read_client_data expect; confirm before changing.
dataset = "cifia100"
seed = 0               # fixes weight init and DataLoader shuffling across runs
torch.manual_seed(seed)
data = read_data(dataset)

In [39]:
def tSNEVisual(save_name, input_vector, group_size=5, perplexity=None):
    """Project embedding vectors to 2-D with t-SNE and save a scatter plot.

    Points are coloured by consecutive groups: rows ``0..group_size-1`` get
    colour 0, the next ``group_size`` rows colour 1, and so on. This matches
    the previous hard-coded labelling (100 points in groups of 5), but now
    adapts to however many points are passed.

    Args:
        save_name: output path for ``savefig``; the format follows the extension.
        input_vector: array-like with one row per point. Nested rows such as
            ``[[...]]`` (the single-row lists returned by ``train``) are
            flattened per point.
        group_size: number of consecutive points sharing a colour.
        perplexity: t-SNE perplexity. Defaults to 30 (sklearn's default) but is
            capped at ``n_points - 1``, because sklearn requires
            perplexity < n_samples.

    Raises:
        ValueError: if fewer than 2 points are given.
    """
    import inspect

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.manifold import TSNE

    points = np.asarray(input_vector, dtype=float)
    points = points.reshape(points.shape[0], -1)
    n_points = points.shape[0]
    if n_points < 2:
        raise ValueError("t-SNE needs at least 2 points, got %d" % n_points)

    labels = [i // group_size for i in range(n_points)]

    if perplexity is None:
        perplexity = min(30.0, n_points - 1)

    # scikit-learn renamed ``n_iter`` to ``max_iter`` (and later dropped
    # ``n_iter``); use whichever this installation accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(n_components=2, init='pca', random_state=0, n_jobs=30,
                verbose=1, perplexity=perplexity, **{iter_kw: 10000})
    X_tsne = tsne.fit_transform(points)

    # Scale each coordinate to [0, 1]; a constant column maps to 0 instead of NaN.
    x_min, x_max = X_tsne.min(axis=0), X_tsne.max(axis=0)
    span = np.where(x_max > x_min, x_max - x_min, 1.0)
    aim_data = (X_tsne - x_min) / span

    fig, ax = plt.subplots()
    ax.scatter(aim_data[:, 0], aim_data[:, 1], c=labels)
    fig.savefig(save_name, dpi=600)

训练阶段

In [42]:
# Train each client's model sequentially and collect one mean-embedding
# vector per client for the t-SNE visualization.
embeddings = []

for c in range(total_clients):
    print('current: ', c)
    cid, train_data, test_data = read_client_data(c, data, dataset)
    res = train(c + 1, train_data, epochs, learning_rate)
    print(res, len(res))
    # `train` returns a single-row nested list ([[...]]); keep just the row.
    embeddings.append(res[0])

current:  0
2
[[-617.7127075195312, -568.6993408203125, 473.13043212890625, -53.72748947143555, -8.818199157714844, 125.99597930908203, 267.3271484375, 218.1778564453125, 338.5292053222656, 384.20953369140625, -371.4195861816406, -94.0815658569336, 404.69732666015625, -118.89968872070312, 295.4736633300781]] 1
current:  1
2
[[-471.06683349609375, -264.04168701171875, 339.3094177246094, -392.83099365234375, 143.5758056640625, -348.8356018066406, 364.4430847167969, -279.51678466796875, -217.2294158935547, -80.01964569091797, -283.36614990234375, -658.3558349609375, 569.9845581054688, -621.1047973632812, -270.90411376953125]] 1
current:  2
2
[[108.86150360107422, 6.485332012176514, 557.6822509765625, -373.96820068359375, 336.7411193847656, -388.1986999511719, 465.2929992675781, 132.4920196533203, 48.64500045776367, 308.4232177734375, 179.9937286376953, -181.40274047851562, 95.08342742919922, 708.9345092773438, -304.086669921875]] 1
current:  3
2
[[-252.37281799316406, 522.3475341796875, 5

In [37]:
# Visualize the per-client mean embeddings with t-SNE.
# `embeddings` must hold one embedding vector per client, gathered during training.
tSNEVisual('cifia.pdf', embeddings)


# Scratch check of the comma-joined line format used for embedding files.
# NOTE(review): appends "1,2,3" to test.txt on every run; likely a leftover, consider removing.
with open('test.txt', 'a+') as emd_fp:
    lst = [1, 2, 3]
    emd_fp.write(','.join(map(str, lst)))