In [3]:
import argparse
import torch
import torch.nn as NN
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import SMART_FEDERAL as SF
from distillation import *
from train import train
from test import test
import time
from model import CNN_Model
import os
import csv
from utils import XKs, Measure, get_result
from dataloader import MyDataset


def main(world_size, epochs, rank, batch_size=200, backend='nccl', data_path='/dataset',
         lr=1e-5, momentum=0.01, no_cuda=False, seed=35, aggregation_method='naive',
         load_model = False, load_path = '/data'
        ):
    '''Main Function'''

#     parser = argparse.ArgumentParser(description='PyTorch: Deep Mutual Learning')
#     parser.add_argument('--worker_size', type=int, default=5, metavar='N',
#                         help='the number of wokers/nodes (default: 5)')
#     parser.add_argument('--batch_size', type=int, default=128, metavar='N',
#                         help='input batch size for training (default: 128)')
#     parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
#                         help='input batch size for testing (default: 1000)')
#     parser.add_argument('--epochs', type=int, default=50000, metavar='N',
#                         help='number of epochs to train (default: 50000)')
#     parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
#                         help='learning rate (default: 0.01)')
#     parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
#                         help='SGD momentum (default: 0.5)')
#     parser.add_argument('--no-cuda', action='store_true', default=False,
#                         help='disables CUDA training')
#     parser.add_argument('--seed', type=int, default=1, metavar='S',
#                         help='random seed (default: 1)')
#     parser.add_argument('--aggregation_method_pool', type=str, default='naive',
#                         help='choices of aggregation method')
#     parser.add_argument('--save-model', action='store_true', default=True,
#                         help='For Saving the current Model')
#     parser.add_argument('--rank', type=int,
#                         help='the id of node. 0 is master and others are servants')
#     parser.add_argument('--backend', type=str, default='nccl',
#                         help='communication protocol')
#     parser.add_argument('--data_path', type=str, default='/dataset',
#                         help='data path')
#     args = parser.parse_args()

    use_cuda = not no_cuda and torch.cuda.is_available() #使用的使用多个gpu进行训练，需要改进一下
    torch.manual_seed(seed)
    timeline = time.strftime('%m%d%Y_%H:%M', time.localtime())
    device = torch.device("cuda" if use_cuda else "cpu")
    aggregation_method_pool = ["naive", "bsz_average", "weight_average", "distillation"]
    ratio = 0.8551957853612336

    weight = torch.FloatTensor([ratio, 1 - ratio]).to(device)
    Loss = NN.BCELoss(weight=weight)

    if rank == 0:
        name = 'master'
        result_dir = 'result_' + timeline + '/' + name + '/'
        model_dir = 'models_' + timeline + '/' + name + '/'
        csvname = '{}_log'.format(name) + timeline + '.csv'
        modelname = 'model_{:d}.pth'
                             
        if aggregation_method == 'distillation':
            DistillationData = MyDataset(root=data_path, train=True, data_root='distillation.csv')
            distillation_dataloader = DataLoader(dataset=DistillationData, batch_size=batch_size,
                                                 shuffle=True, drop_last=True)
        
        model_set = []
        for worker_id in range(world_size):
            model_set.append(CNN_Model().to(device))
             
        if load_model:
            if aggregation_method == 'distillation':
                raise ValueError('Unexpected model')
            for worker_id in range(world_size):
                model_set[worker_id].load_state_dict(torch.load(load_path))
            
        opt_set = []
        for worker_id in range(world_size):
            opt_set.append(optim.SGD(model_set[worker_id].parameters(), lr=lr, momentum=momentum))

        model = SF.Master(model=model_set[0], backend=backend, rank=rank, world_size=world_size, learning_rate=lr,
                          device=device, aggregation_method=aggregation_method)
        for epoch in range(1, epochs+1):
            model.train()
            model.step(model_buffer=model_set)
            model.update(model_set[1:])
            if aggregation_method == 'distillation':         
                distillation(NN_set=model_set[1:], opt_set=opt_set[1:], dataset=distillation_dataloader,
                             world_size=world_size, epoch=epoch, device=device)                
#                 best_idx = choose_best(NN_set=model_set[1:], name=name, dataset=dataloader,world_size=world_size,
#                                        epoch=epoch, Loss=Loss, time=timeline)
#                 best_state_dict = model_set[best_idx+1].state_dict()
#                 model_set[0].load_state_dict(best_state_dict)

#这里要回传所有的模型

    else:
        name = 'worker'+str(rank)
        result_dir = 'result_' + timeline + '/' + name + '/'
        model_dir = 'models_' + timeline + '/' + name + '/'
        csvname = '{}_log'.format(name) + timeline + '.csv'
        modelname = 'model_{:d}.pth'       
        
        DataSet_train = MyDataset(root=data_path, train=True, data_root='{}.csv'.format(name)) 
        dataloader_train = DataLoader(dataset=DataSet_train, batch_size=batch_size, shuffle=True,
                                drop_last=True)
        DataSet_test = MyDataset(root=data_path, train=True, data_root='{}.csv'.format('test')) 
        dataloader_test = DataLoader(dataset=DataSet_test, batch_size=batch_size, shuffle=True,
                                drop_last=True)
        
        model_set = []
        for worker_id in range(world_size):
            model_set.append(CNN_Model().to(device))
             
        if load_model:
            for worker_id in range(world_size):
                model_set[worker_id].load_state_dict(torch.load(load_path))
        backup_model = CNN_Model().to(device)
        train_model = model_set[0]
        optimizer = optim.SGD(train_model.parameters(), lr=lr, momentum=momentum)  
        model = SF.Servent(model=train_model, backend=backend, rank=rank, world_size=world_size,
                           device=device, aggregation_method=aggregation_method)
        for epoch in range(1, epochs+1):
            model.train()
            model.step(model_buffer=model_set, rank=rank)

            best_state_dict = train_model.state_dict()
            backup_model.load_state_dict(best_state_dict)

            train(dataloader=dataloader_train, model=train_model, optimizer=optimizer, Loss = Loss, 
                  epoch=epoch, time=timeline, result_dir=result_dir, model_dir=model_dir, device=device,
                  csvname=csvname, modelname=modelname)

            model.update(backup_model)

            test(dataloader=dataloader_test, model=train_model, epoch=epoch, Loss=Loss, time=timeline,
                 result_dir=result_dir, model_dir=model_dir, csvname=csvname, modelname=modelname,
                 device=device)



In [4]:
# def init_process():
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = "3456"

In [5]:
# init_process()

In [None]:
# Run configuration for this notebook: rank 0 selects the master branch of `main`.
RUN_CONFIG = dict(
    world_size=4,
    epochs=2000,
    rank=0,
    batch_size=200,
    backend='nccl',
    data_path='./sp_data/',
    lr=1e-5,
    momentum=0.01,
    no_cuda=False,
    seed=35,
)
main(**RUN_CONFIG)