In [56]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

In [57]:
pwd

'/Users/wzx'

In [58]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cpu


In [59]:
dataset = 'mnist'
bias = 0.5
batch_size = 32
lr = 0.005

In [60]:
import math
import torch
from torch.optim import Optimizer


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v
        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
             v = \rho * v + lr * g \\
             p = p - v
        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, grads, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for i,p in enumerate(group['params']):
#                 if p.grad is None:
#                     continue
                
                d_p = grads[i]
                
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

        return loss

In [61]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 5)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(800, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [62]:
# print(summary(net, (1,28,28)))

In [63]:
criterion = nn.CrossEntropyLoss()

In [64]:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data = torch.utils.data.DataLoader(trainset, batch_size=60000, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_data = torch.utils.data.DataLoader(testset, batch_size=5000, shuffle=False)

In [65]:
bias = 0.5
num_workers = 100
bias_weight = bias
other_group_size = (1 - bias_weight) / 9.
worker_per_group = num_workers / 10

In [66]:
#This code block is distributing the data amongst the clients. This is called the fang distribution
each_worker_data = [[] for _ in range(num_workers)]
each_worker_label = [[] for _ in range(num_workers)]
for i, (data, labels) in enumerate(train_data):
    for (x, y) in zip(data, labels):
        print(np.shape(x))
#         x = np.reshape(x,(1,1,28,28))
#         if args.dataset == 'cifar10' and (args.net == 'cnn' or args.net == 'resnet20'):
#             x = x.as_in_context(ctx).reshape(1, 3, 32, 32)
#         elif args.dataset == 'mnist' and args.net == 'cnn':
#             x = x.as_in_context(ctx).reshape(1, 1, 28, 28)
#         else:
#             x = x.as_in_context(ctx).reshape(-1, num_inputs)
#         y = y.as_in_context(ctx)

        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        # assign a data point to a worker
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        each_worker_data[selected_worker].append(x)
        each_worker_label[selected_worker].append(y)

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 2

In [67]:
each_worker_data = [torch.stack(each_worker) for each_worker in each_worker_data]
each_worker_label = [torch.stack(each_worker) for each_worker in each_worker_label]

In [68]:
seed = 3
random_order = np.random.RandomState(seed=seed).permutation(num_workers)
each_worker_data = [each_worker_data[i] for i in random_order]
each_worker_label = [each_worker_label[i] for i in random_order]

In [69]:
each_worker_data = [each_worker.to(device) for each_worker in each_worker_data]
each_worker_label = [each_worker.to(device) for each_worker in each_worker_label]

In [70]:
pwd

'/Users/wzx'

In [71]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(gain)

In [73]:

import numpy as np
import time
import random

net = cnn().to(device)
gain = 1
net.apply(init_weights)

lr = 0.005

latency_threshold = 50

epochs = 300


# Read latencies from file
def read_latencies(file_path):
    with open(file_path, 'r') as file:
        latencies = [[float(value) for value in line.strip().split()] for line in file]
    return latencies

latencies = read_latencies('/Users/wzx/Downloads/new_latencies_300.txt')

# Training loop
for e in range(epochs):
    start_time = time.time()
    cnn_optimizer = SGD(net.parameters(), lr=lr*(0.999**e))
    user_grads = []

    for i in range(100):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i][:])
        loss = criterion(output, each_worker_label[i][:])
        loss.backward(retain_graph=True)
        
        param_grad = []
        for param in net_.parameters():
            param_grad = param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad, param.grad.view(-1)))

        # Get latency for the current client
        latency = latencies[e][i]

        if latency <= latency_threshold: # Only the gradients from clients with latency below the threshold are accepted
            user_grads = param_grad[None, :] if len(user_grads) == 0 else torch.cat((user_grads, param_grad[None, :]), 0)
        
        del net_

    if len(user_grads) > 0:
        agg_grads = torch.mean(user_grads, dim=0)
        del user_grads

        cnn_optimizer.zero_grad()
        model_grads = []
        start_idx = 0
        for param in net.parameters():
            param_ = agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
            start_idx += len(param.data.view(-1))
            model_grads.append(param_)

        cnn_optimizer.step(model_grads)

    # Evaluate model
    total, correct = 0, 0
    with torch.no_grad():
        for data in test_data:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    end_time = time.time()
    epoch_time = end_time - start_time

    print(f'Epoch {e+1}/{epochs}, Time per Epoch: {epoch_time:.2f} seconds, Accuracy: {accuracy:.4f}')

    

Epoch 1/300, Time per Epoch: 35.77 seconds, Accuracy: 0.0982
Epoch 2/300, Time per Epoch: 52.64 seconds, Accuracy: 0.0980
Epoch 3/300, Time per Epoch: 45.61 seconds, Accuracy: 0.1180
Epoch 4/300, Time per Epoch: 42.40 seconds, Accuracy: 0.1703
Epoch 5/300, Time per Epoch: 42.19 seconds, Accuracy: 0.1135
Epoch 6/300, Time per Epoch: 46.96 seconds, Accuracy: 0.1162
Epoch 7/300, Time per Epoch: 46.40 seconds, Accuracy: 0.0980
Epoch 8/300, Time per Epoch: 44.83 seconds, Accuracy: 0.1196
Epoch 9/300, Time per Epoch: 46.29 seconds, Accuracy: 0.1275
Epoch 10/300, Time per Epoch: 45.62 seconds, Accuracy: 0.1275
Epoch 11/300, Time per Epoch: 44.82 seconds, Accuracy: 0.1275
Epoch 12/300, Time per Epoch: 47.26 seconds, Accuracy: 0.0892
Epoch 13/300, Time per Epoch: 42.15 seconds, Accuracy: 0.0892
Epoch 14/300, Time per Epoch: 41.54 seconds, Accuracy: 0.1018
Epoch 15/300, Time per Epoch: 44.98 seconds, Accuracy: 0.1018
Epoch 16/300, Time per Epoch: 42.87 seconds, Accuracy: 0.2033
Epoch 17/300, Tim

In [54]:
total, correct = 0,0
with torch.no_grad():
    for i, data in enumerate(test_data):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct/total)

0.5621


In [55]:
len(global_models)

NameError: name 'global_models' is not defined

In [None]:
!nvidia-smi