In [3]:
import torch
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from torchvision import datasets
import numpy as np
# np.random.seed(42)


def dirichlet_split_noniid(train_labels, alpha, n_clients, rng=None):
    '''
    Split sample indices into ``n_clients`` subsets using a Dirichlet
    distribution with concentration parameter ``alpha``.

    For every class, a proportion vector over the clients is drawn from
    Dir(alpha, ..., alpha), and that class's sample indices are cut into
    consecutive chunks of those proportions.

    Args:
        train_labels: 1-D array-like of integer class labels, expected to be
            non-negative (classes are enumerated as ``range(max + 1)``).
        alpha: Dirichlet concentration shared by all clients.
        n_clients: number of subsets to produce.
        rng: optional random source exposing ``dirichlet(alpha, size)``,
            e.g. ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        A list of ``n_clients`` 1-D integer arrays of indices into
        ``train_labels``; every index appears in exactly one subset.
        Empty ``train_labels`` yields ``n_clients`` empty arrays.
    '''
    if rng is None:
        # Preserve the historical behaviour: draw from the global NumPy RNG.
        rng = np.random
    labels = np.asarray(train_labels)
    if labels.size == 0:
        # Nothing to split; .max() would raise on an empty array.
        return [np.empty(0, dtype=np.intp) for _ in range(n_clients)]

    n_classes = labels.max() + 1
    # (K, N) label-distribution matrix: row k holds the fraction of class k
    # assigned to each of the N clients.
    label_distribution = rng.dirichlet([alpha] * n_clients, n_classes)

    # Sample indices belonging to each of the K classes.
    class_idcs = [np.argwhere(labels == y).flatten()
                  for y in range(n_classes)]

    # For each client, the list of index chunks it receives (one per class).
    client_idcs = [[] for _ in range(n_clients)]
    for c, fracs in zip(class_idcs, label_distribution):
        # Cut points are the cumulative fractions scaled to the class size;
        # the final cumulative value (1.0) is dropped so np.split returns
        # exactly n_clients pieces, the i-th of which goes to client i.
        cuts = (np.cumsum(fracs)[:-1] * len(c)).astype(int)
        for i, idcs in enumerate(np.split(c, cuts)):
            client_idcs[i].append(idcs)

    return [np.concatenate(idcs) for idcs in client_idcs]


# torch.manual_seed(42)

if __name__ == "__main__":

    N_CLIENTS = 100
    CLIENTS_PER_ROUND = 10  # clients sampled (and stacked) in each subplot
    N_ROUNDS = 5  # figure rows: one simulated communication round per row
    ALPHAS = [1, 0.1, 0.01]  # figure columns: one Dirichlet alpha per column

    # NOTE(review): the output recorded for this cell is
    # "AttributeError: module 'torch' has no attribute '_six'", which usually
    # points to a torchvision release older than the installed torch —
    # confirm the two versions are compatible before running.
    train_data = datasets.EMNIST(
        root=".", split="digits", download=True, train=True)
    num_cls = len(train_data.classes)
    train_labels = np.array(train_data.targets)

    n_plotted = min(CLIENTS_PER_ROUND, N_CLIENTS)
    # hist() with stacked datasets needs exactly one colour per dataset.
    color = plt.get_cmap('Spectral')(np.linspace(0, 1, n_plotted))
    # Unit-width bins centred on each integer label (computed once; the
    # builtin min/max would iterate the whole label array per subplot).
    bins = np.arange(train_labels.min() - 0.5, train_labels.max() + 1.5, 1)
    # One legend label per stacked dataset (not one per client in the pool).
    client_names = ["Client {}".format(j) for j in range(n_plotted)]

    # Show the per-label data distribution of the clients sampled each round.
    fig, ax = plt.subplots(N_ROUNDS, len(ALPHAS), figsize=(10, 10),
                           squeeze=False)
    for r, dirichlet_alpha in enumerate(ALPHAS):
        # Each client gets a different number of samples per label, which
        # is what makes the split Non-IID.
        client_idcs = dirichlet_split_noniid(
            train_labels, alpha=dirichlet_alpha, n_clients=N_CLIENTS)
        ax[0][r].set_title(r'$\alpha ={}$'.format(dirichlet_alpha))
        for i in range(N_ROUNDS):
            # Simulate one round: reshuffle the client pool, take the first few.
            np.random.shuffle(client_idcs)
            sampled = client_idcs[:n_plotted]
            ax[i][r].hist([train_labels[idc] for idc in sampled], stacked=True,
                          bins=bins, label=client_names, rwidth=0.5,
                          color=color)
            ax[i][r].set_xticks(np.arange(num_cls))
            ax[i][r].set_xlabel('Labels')
    for i in range(N_ROUNDS):
        ax[i][0].set_ylabel('r={}'.format(i + 1))
    # Every subplot uses the same labels/colours, so any axes' handles work.
    lines, labels = ax[-1][-1].get_legend_handles_labels()

    fig.supxlabel('Different Alpha')
    fig.supylabel('Communication Rounds')
    fig.legend(lines, labels, loc=4, bbox_to_anchor=(
        1.1, 0.01), borderaxespad=0., mode=None)
    fig.tight_layout()
    fig.suptitle('Period drift')
    # Leave room at the top for the suptitle after tight_layout.
    fig.subplots_adjust(top=0.92)
    fig.savefig('period_drift.pdf', bbox_inches='tight')
    
    


AttributeError: module 'torch' has no attribute '_six'