In [1]:
from typing import Any, Dict, List, Optional, Tuple, Sequence
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets, transforms

In [2]:
# DATA & SAMPLER
class MNISTDataset(datasets.MNIST):

    N_CLASSES = 10

    def __init__(self, root: str, train: bool):
        transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        super().__init__(root=root, train=train, download=True, transform=transform)

    def __getitem__(self, index):
        x, y = self.data[index], self.targets[index]
        x = self.transform(x)

        return x, y

class FederatedSampler(Sampler):
    def __init__(
        self,
        dataset: Sequence,
        non_iid: int,
        n_clients: Optional[int] = 100,
        n_shards: Optional[int] = 200,
    ):
        """Sampler for federated learning in both iid and non-iid settings.

        Args:
            dataset (Sequence): Dataset to sample from.
            non_iid (int): 0: IID, 1: Non-IID
            n_clients (Optional[int], optional): Number of clients. Defaults to 100.
            n_shards (Optional[int], optional): Number of shards. Defaults to 200.
        """
        self.dataset = dataset
        self.non_iid = non_iid
        self.n_clients = n_clients
        self.n_shards = n_shards

        if self.non_iid:
            self.dict_users = self._sample_non_iid()
        else:
            self.dict_users = self._sample_iid()

    def _sample_iid(self) -> Dict[int, List[int]]:
        num_items = int(len(self.dataset) / self.n_clients)
        dict_users, all_idxs = {}, [i for i in range(len(self.dataset))]

        for i in range(self.n_clients):
            dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
            all_idxs = list(set(all_idxs) - dict_users[i])

        return dict_users

    def _sample_non_iid(self) -> Dict[int, List[int]]:
        num_imgs = len(self.dataset) // self.n_shards  # 300

        idx_shard = [i for i in range(self.n_shards)]
        dict_users = {i: np.array([]) for i in range(self.n_clients)}
        idxs = np.arange(self.n_shards * num_imgs)
        labels = self.dataset.targets.numpy()

        # sort labels
        idxs_labels = np.vstack((idxs, labels))
        idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
        idxs = idxs_labels[0, :]

        # divide and assign 2 shards/client
        for i in range(self.n_clients):
            rand_set = set(np.random.choice(idx_shard, 2, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand * num_imgs : (rand + 1) * num_imgs]),
                    axis=0,
                )

        return dict_users

    def set_client(self, client_id: int):
        self.client_id = client_id

    def __iter__(self):
        # fetch dataset indexes based on current client
        client_idxs = list(self.dict_users[self.client_id])
        for item in client_idxs:
            yield int(item)

In [3]:
# MODELS
class MLP(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_classes: int):
        super(MLP, self).__init__()

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return self.softmax(x)

class CNN(nn.Module):
    def __init__(self, n_channels: int, n_classes: int):
        super(CNN, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(-1, 1024)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)

In [4]:
# UTILS
class Logger:
    def __init__(self, args):
        self.args = args

    def log(self, logs: Dict[str, Any]) -> None:
        print(logs)

def average_weights(weights: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    weights_avg = copy.deepcopy(weights[0])

    for key in weights_avg.keys():
        for i in range(1, len(weights)):
            weights_avg[key] += weights[i][key]
        weights_avg[key] = torch.div(weights_avg[key], len(weights))

    return weights_avg


In [5]:
# FEDAVG
class FedAvg:
    def __init__(self, args: Dict[str, Any]):
        self.args = args
        self.device = torch.device(
            f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
        )
        self.logger = Logger(args)

        self.train_loader, self.test_loader = self._get_data(
            n_clients=self.args.n_clients,
            n_shards=self.args.n_shards,
            non_iid=self.args.non_iid,
        )

        if self.args.model_name == "mlp":
            self.root_model = MLP(input_size=784, hidden_size=128, n_classes=10).to(
                self.device
            )
            self.target_acc = 0.97
        elif self.args.model_name == "cnn":
            self.root_model = CNN(n_channels=1, n_classes=10).to(self.device)
            self.target_acc = 0.99
        else:
            raise ValueError(f"Invalid model name, {self.args.model_name}")

        self.reached_target_at = None  # type: int

    def _get_data(
        self, n_clients: int, n_shards: int, non_iid: int
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Args:
            n_clients (int): number of clients.
            n_shards (int): number of shards.
            non_iid (int): 0: IID, 1: Non-IID

        Returns:
            Tuple[DataLoader, DataLoader]: train_loader, test_loader
        """
        train_set = MNISTDataset(root='./data', train=True)
        test_set = MNISTDataset(root='./data', train=False)

        sampler = FederatedSampler(
            train_set, non_iid=non_iid, n_clients=n_clients, n_shards=n_shards
        )

        train_loader = DataLoader(train_set, batch_size=128, sampler=sampler)
        test_loader = DataLoader(test_set, batch_size=128)

        return train_loader, test_loader

    def _train_client(
        self, root_model: nn.Module, train_loader: DataLoader, client_idx: int
    ) -> Tuple[nn.Module, float]:
        """Train a client model.

        Args:
            root_model (nn.Module): server model.
            train_loader (DataLoader): client data loader.
            client_idx (int): client index.

        Returns:
            Tuple[nn.Module, float]: client model, average client loss.
        """
        model = copy.deepcopy(root_model)
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(), lr=self.args.lr, momentum=self.args.momentum
        )

        for epoch in range(self.args.n_client_epochs):
            epoch_loss = 0.0
            epoch_correct = 0
            epoch_samples = 0

            for idx, (data, target) in enumerate(train_loader):
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()

                logits = model(data)
                loss = F.nll_loss(logits, target)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_correct += (logits.argmax(dim=1) == target).sum().item()
                epoch_samples += data.size(0)

            # Calculate average accuracy and loss
            epoch_loss /= idx
            epoch_acc = epoch_correct / epoch_samples

            print(
                f"Client #{client_idx} | Epoch: {epoch}/{self.args.n_client_epochs} | Loss: {epoch_loss} | Acc: {epoch_acc}",
                end="\r",
            )

        return model, epoch_loss / self.args.n_client_epochs

    def train(self) -> None:
        """Train a server model."""
        train_losses = []

        for epoch in range(self.args.n_epochs):
            clients_models = []
            clients_losses = []

            # Randomly select clients
            m = max(int(self.args.frac * self.args.n_clients), 1)
            idx_clients = np.random.choice(range(self.args.n_clients), m, replace=False)

            # Train clients
            self.root_model.train()

            for client_idx in idx_clients:
                # Set client in the sampler
                self.train_loader.sampler.set_client(client_idx)

                # Train client
                client_model, client_loss = self._train_client(
                    root_model=self.root_model,
                    train_loader=self.train_loader,
                    client_idx=client_idx,
                )
                clients_models.append(client_model.state_dict())
                clients_losses.append(client_loss)

            # Update server model based on clients models
            updated_weights = average_weights(clients_models)
            self.root_model.load_state_dict(updated_weights)

            # Update average loss of this round
            avg_loss = sum(clients_losses) / len(clients_losses)
            train_losses.append(avg_loss)

            if (epoch + 1) % self.args.log_every == 0:
                # Test server model
                total_loss, total_acc = self.test()
                avg_train_loss = sum(train_losses) / len(train_losses)

                # Log results
                logs = {
                    "train/loss": avg_train_loss,
                    "test/loss": total_loss,
                    "test/acc": total_acc,
                    "round": epoch,
                }
                if total_acc >= self.target_acc and self.reached_target_at is None:
                    self.reached_target_at = epoch
                    logs["reached_target_at"] = self.reached_target_at
                    print(
                        f"\n -----> Target accuracy {self.target_acc} reached at round {epoch}! <----- \n"
                    )

                self.logger.log(logs)

                # Print results to CLI
                print(f"\n\nResults after {epoch + 1} rounds of training:")
                print(f"---> Avg Training Loss: {avg_train_loss}")
                print(
                    f"---> Avg Test Loss: {total_loss} | Avg Test Accuracy: {total_acc}\n"
                )

                # Early stopping
                if self.args.early_stopping and self.reached_target_at is not None:
                    print(f"\nEarly stopping at round #{epoch}...")
                    break

    def test(self) -> Tuple[float, float]:
        """Test the server model.

        Returns:
            Tuple[float, float]: average loss, average accuracy.
        """
        self.root_model.eval()

        total_loss = 0.0
        total_correct = 0.0
        total_samples = 0

        for idx, (data, target) in enumerate(self.test_loader):
            data, target = data.to(self.device), target.to(self.device)

            logits = self.root_model(data)
            loss = F.nll_loss(logits, target)

            total_loss += loss.item()
            total_correct += (logits.argmax(dim=1) == target).sum().item()
            total_samples += data.size(0)

        # calculate average accuracy and loss
        total_loss /= idx
        total_acc = total_correct / total_samples

        return total_loss, total_acc

In [None]:
class Arguments:
    def __init__(self):
        # 하이퍼파라미터
        self.n_client_epochs = 20          # 클라이언트 당 에포크 수
        self.non_iid = 1                   # 0: IID, 1: Non-IID
        self.lr = 0.01                     # 학습률 (옵션: 0.01, 0.001)

        # 고정 설정
        self.model_name = "cnn"            # 모델 아키텍처 ("cnn" 또는 "mlp")
        self.n_clients = 100               # 클라이언트 수
        self.n_shards = 200                # 데이터 샤드 수
        self.frac = 0.1                    # 각 라운드마다 참여하는 클라이언트의 비율
        self.n_epochs = 1000               # 총 학습 에포크 수

        # 디바이스 및 로깅
        self.device = 0                    # 디바이스 ID (예: GPU 인덱스)
        self.log_every = 1                 # 로깅 빈도 (에포크 단위)
        self.early_stopping = True         # 조기 종료 활성화 여부

        # 필수 추가 설정
        self.momentum = 0.9                # 옵티마이저의 모멘텀 값

args = Arguments()
fed_avg = FedAvg(args)
fed_avg.train()


{'train/loss': 0.0006421083043869658, 'test/loss': 2.1293575091239734, 'test/acc': 0.2726, 'round': 0}


Results after 1 rounds of training:
---> Avg Training Loss: 0.0006421083043869658
---> Avg Test Loss: 2.1293575091239734 | Avg Test Accuracy: 0.2726

{'train/loss': 0.0004090876436589497, 'test/loss': 1.8342124300125318, 'test/acc': 0.5228, 'round': 1}


Results after 2 rounds of training:
---> Avg Training Loss: 0.0004090876436589497
---> Avg Test Loss: 1.8342124300125318 | Avg Test Accuracy: 0.5228

{'train/loss': 0.00031939172990191525, 'test/loss': 1.1643322041401496, 'test/acc': 0.6168, 'round': 2}


Results after 3 rounds of training:
---> Avg Training Loss: 0.00031939172990191525
---> Avg Test Loss: 1.1643322041401496 | Avg Test Accuracy: 0.6168

{'train/loss': 0.0002773267267350416, 'test/loss': 1.1384692451892755, 'test/acc': 0.5965, 'round': 3}


Results after 4 rounds of training:
---> Avg Training Loss: 0.0002773267267350416
---> Avg Test Loss: 1.1384692451892755 | Avg T