In [1]:
from collections import OrderedDict
from typing import List, Tuple, Dict, Optional, Callable, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

from gradient_descent_the_ultimate_optimizer import gdtuo

import flwr as fl
from flwr.common import Metrics
from flwr_datasets import FederatedDataset
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}
    
    
disable_progress_bar()


Training on cpu using PyTorch 2.0.1+cpu and Flower 1.7.0


In [2]:
NUM_CLIENTS = 1
BATCH_SIZE = 32

def load_datasets():
    fds = FederatedDataset(dataset="cifar10", partitioners={"train" : NUM_CLIENTS})

    def apply_transforms(batch):
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )
        batch["img"] = [transform(img) for img in batch["img"]]
        return batch

    trainloaders = []
    valloaders = []
    for partition_id in range(NUM_CLIENTS):
        partition = fds.load_partition(partition_id, "train")
        partition = partition.with_transform(apply_transforms)
        partition = partition.train_test_split(train_size=0.8)
        trainloaders.append(DataLoader(partition["train"], batch_size=BATCH_SIZE))
        valloaders.append(DataLoader(partition["test"], batch_size=BATCH_SIZE))
    testset = fds.load_full("test").with_transform(apply_transforms)
    testloaders = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloaders

trainloaders, valloaders, testloaders = load_datasets()   

In [12]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
                                   nn.Conv2d(3, 32, 3, padding=1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                   )
        self.conv2 = nn.Sequential(
                                   nn.Conv2d(32, 64, 3, padding=1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                   )
        self.conv3 = nn.Sequential(
                                   nn.Conv2d(64, 64, 3, padding=1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                   )
        self.dropout = nn.Dropout(0.0)
        self.fc = nn.Sequential(
                                nn.Linear(1024, 64),
                                nn.ReLU(),
                                )
        self.clf = nn.Linear(64, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.fc(self.dropout(x.flatten(1)))
        return self.clf(self.dropout(x))


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(mw, trainloader, epochs: int, verbose=False):
    criterion = torch.nn.CrossEntropyLoss()
    epoch_loss_dict = []
    for i in range(1, epochs+1):
        running_loss = 0.0
        alpha_grad = 0
        mu_grad = 0
        j = 0
        for batch in trainloader:
            mw.begin() # call this before each step, enables gradient tracking on desired params
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = mw.forward(images)
            loss = criterion(outputs, labels)
            mw.zero_grad()
            loss.backward(create_graph=True) # important! use create_graph=True
            for name, param in mw.optimizer.parameters.items():
                if name == 'alpha_0':
                    alpha_grad += param.grad.clone().detach()
                if name == 'mu_0':
                    mu_grad += param.grad.clone().detach()
            mw.step()
            if j % 250 == 0:
                print(mw.optimizer.state_dict())
            running_loss += loss.item()
            j += 1
        train_loss = running_loss / len(trainloaders[0].dataset)
        epoch_loss_dict.append(train_loss)
        if verbose == True:
            print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))
            print(alpha_grad)
            print(mu_grad)
    return epoch_loss_dict


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            net.begin()
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            net.zero_grad()
            outputs = net.forward(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [16]:
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [18]:
from gradient_descent_the_ultimate_optimizer import gdtuo

model = Net().to(DEVICE)
optim = gdtuo.SGD(alpha=0.01, mu=0.3, height=0, optimizer=gdtuo.SGD(alpha=5e-8, mu=1e-5, height=1))
mw = gdtuo.ModuleWrapper(model, optimizer=optim)
mw.initialize()

# loss1, accuracy1 = test(net, valloaders[0])
# print("Test Dataset: ",loss1, accuracy1)
# loss2, accuracy2 = test(net, trainloaders[0])
# print("Train Dataset :",loss2, accuracy2)

epoch_loss_dict = train(mw, trainloaders[0], epochs=30, verbose=True)

# epoch_loss_dict1, epoch_acc1, optimizer_state_dict1 = train(net, optimizer, trainloaders[0], epochs=30)
# epoch_loss_dict2, epoch_acc2, optimizer_state_dict2 = train(net, optimizer, trainloaders[1], epochs=30)
# epoch_loss_dict3, epoch_acc3, optimizer_state_dict3 = train(net, optimizer, trainloaders[2], epochs=30)

loss1, accuracy1 = test(mw, valloaders[0])
print("Test Dataset: ",loss1, accuracy1)
loss2, accuracy2 = test(mw, trainloaders[0])
print("Train Dataset :",loss2, accuracy2)


# print(get_parameters(net))
# print(get_parameters(optimizer))

{'alpha_0': tensor(0.0100), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0100), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0100), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0100), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0100), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
EPOCH: 1, TRAIN LOSS: 0.06384977233409882
tensor(1499.7126)
tensor(13.4975)
{'alpha_0': tensor(0.0099), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0099), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}
{'alpha_0': tensor(0.0098), 'mu_0': tensor(0.3000), 'alpha_1': tensor(5.0000e-08), 'mu_1': tensor(1.0000e-05)}


KeyboardInterrupt: 