# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transfroms
import numpy as np
from torch.utils.data import DataLoader
from collections import OrderedDict
import flwr as fl
import math

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# NOTE: change 'cuda:0' to your team's GPU index (e.g. team 1 -> 'cuda:1').
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print('Using PyTorch version:', torch.__version__, ' Device:', DEVICE)

class MLP(nn.Module):
    """Two-layer perceptron for 28x28 single-channel images with 10 classes.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, 10)``.
    It deliberately does NOT apply softmax: ``nn.CrossEntropyLoss`` (the loss
    this script trains with) applies log-softmax internally, so a softmax
    layer here would normalise twice and flatten the gradients.  Apply
    ``torch.softmax(model(x), dim=1)`` when probabilities are needed; the
    arg-max prediction is identical either way.
    """

    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.dense2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: tensor whose non-batch dims flatten to 28 * 28 features,
               e.g. ``(batch, 1, 28, 28)``.

        Returns:
            Logits tensor of shape ``(batch, 10)``.
        """
        x = self.flatten(x)
        x = self.relu(self.dense1(x))
        x = self.dropout(x)
        return self.dense2(x)

def train(model, epoch, train_loader, optimizer, log_interval, loss_fn, device=None):
    """Run one training pass of ``model`` over every batch in ``train_loader``.

    Args:
        model: the ``nn.Module`` to optimise; it is switched to training mode.
        epoch: epoch number, used only in the progress message.
        train_loader: ``DataLoader`` yielding ``(image, label)`` batches; its
            ``.dataset`` length is shown in the progress message.
        optimizer: optimiser over ``model``'s parameters.
        log_interval: print progress every ``log_interval`` batches; ``0``
            (or any falsy value) disables progress output.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar tensor.
        device: device each batch is moved to; defaults to the module-level
            ``DEVICE`` when omitted.
    """
    if device is None:
        device = DEVICE
    model.train()
    for batch_idx, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(image)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()

        # Guard the modulo so log_interval=0 disables logging instead of
        # raising ZeroDivisionError.
        if log_interval and batch_idx % log_interval == 0:
            # batch_idx * len(image) approximates samples seen so far; it is
            # exact when every batch before this one is full-sized.
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
                epoch, batch_idx * len(image),
                len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()))
            
def evaluate(model, test_loader, loss_fn, device=None):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: the ``nn.Module`` to evaluate; it is switched to eval mode.
        test_loader: ``DataLoader`` yielding ``(image, label)`` batches;
            ``len(test_loader.dataset)`` is used as the total sample count.
        loss_fn: callable ``loss_fn(output, label)`` returning a scalar tensor.
            Assumed to average over the batch (the default
            ``reduction='mean'`` of ``nn.CrossEntropyLoss``).
        device: device each batch is moved to; defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        ``(test_loss, test_accuracy)``: the mean per-sample loss over the
        whole dataset and the accuracy as a percentage (0-100).

    Raises:
        ValueError: if the loader's dataset is empty.
    """
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError("evaluate() needs a non-empty test dataset")
    if device is None:
        device = DEVICE
    model.eval()
    total_loss = 0.0
    correct = 0

    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)
            output = model(image)
            # Weight each batch's mean loss by its size so a short final batch
            # is not over-counted, and so the result does not depend on a
            # global batch-size constant matching this loader.
            total_loss += loss_fn(output, label).item() * len(label)
            prediction = output.argmax(dim=1, keepdim=True)
            correct += prediction.eq(label.view_as(prediction)).sum().item()

    test_loss = total_loss / num_samples
    test_accuracy = 100. * correct / num_samples
    return test_loss, test_accuracy


# Download (if needed) and load the MNIST train and test splits, in that order.
# ToTensor converts 0-255 pixel values to floats in [0, 1] for items fetched
# through these dataset objects.
# NOTE: the per-client split below reads the raw `.data` / `.targets` tensors
# directly, so this transform is not applied to the client data.
train_set, test_set = (
    torchvision.datasets.MNIST(
        root='./data/MNIST',
        train=is_train,
        download=True,
        transform=transfroms.Compose([transfroms.ToTensor()]),
    )
    for is_train in (True, False)
)

BATCH_SIZE = 128

def list_split(arr, n):
    """Split a sliceable sequence into exactly ``n`` contiguous chunks.

    Leading chunks hold ``ceil(len(arr) / n)`` items each; the trailing
    chunk(s) may be shorter or empty.  Returning exactly ``n`` chunks makes
    ``result[k]`` valid for every index ``0 <= k < n``.

    Args:
        arr: any sequence supporting ``len`` and slicing (list, NumPy array,
            torch tensor, ...).
        n: number of chunks; must be >= 1.

    Returns:
        A list of ``n`` slices of ``arr``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {}".format(n))
    chunk = math.ceil(len(arr) / n)
    # Slice by chunk index instead of stepping range(0, len(arr), chunk):
    # stepping yields fewer than n chunks when the ceil rounds up far enough
    # (e.g. len 9, n 4 -> 3 chunks), and a zero step on empty input.
    return [arr[k * chunk:(k + 1) * chunk] for k in range(n)]

num_clients = 3
# Partition each tensor into num_clients contiguous shards; shard k of every
# list belongs to client k.
x_train_list, y_train_list, x_val_list, y_val_list = (
    list_split(tensor, num_clients)
    for tensor in (train_set.data, train_set.targets,
                   test_set.data, test_set.targets)
)


class MnistDataSet(torch.utils.data.Dataset):
    """Map-style dataset over in-memory image and label collections.

    Each item is ``(image, label)``:
    - ``image`` is a ``float32`` NumPy array of shape ``(1, 28, 28)``. It holds
      the raw pixel values; no scaling is done here unless ``transforms`` does it.
    - ``label`` is ``labels[i]``, returned unchanged.

    Args:
        images: indexable collection where ``images[i]`` holds 28 * 28 values
            (e.g. an ``(N, 28, 28)`` tensor or array).
        labels: indexable collection of labels aligned with ``images``.
        transforms: optional callable applied to each ``(1, 28, 28)`` float32
            image before it is returned (e.g. to scale pixels into [0, 1]).
    """

    def __init__(self, images, labels, transforms=None):
        self.X = images
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        data = np.array(self.X[i]).astype(np.float32).reshape(1, 28, 28)
        # The transforms argument was previously stored but never used.
        if self.transforms is not None:
            data = self.transforms(data)
        return data, self.y[i]

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    Weights are exchanged with the server as a list of NumPy arrays in
    ``model.state_dict()`` order.

    Args:
        model: the local ``nn.Module``.
        tarinloader: training ``DataLoader`` (parameter name kept unchanged
            for keyword-argument compatibility).
        testloader: evaluation ``DataLoader``.
        opt: optimizer over ``model``'s parameters.
        loss_fn: loss callable ``loss_fn(output, label)``.
    """

    def __init__(self, model, tarinloader, testloader, opt, loss_fn):
        self.model = model
        self.train_loader = tarinloader
        self.test_loader = testloader
        self.optimizer = opt
        self.loss_fn = loss_fn

    def get_parameters(self, config):
        """Return the model's state_dict tensors as CPU NumPy arrays; ``config`` is unused."""
        state = self.model.state_dict()
        return [tensor.cpu().numpy() for tensor in state.values()]

    def set_parameters(self, parameters):
        """Load a list of arrays, ordered like ``state_dict()``, into the model.

        Loading raw arrays into a PyTorch model requires re-keying them by
        parameter name, hence this helper.  ``strict=True`` makes a key
        mismatch raise instead of being silently ignored.
        """
        names = self.model.state_dict().keys()
        new_state = OrderedDict(
            (name, torch.tensor(array)) for name, array in zip(names, parameters)
        )
        self.model.load_state_dict(new_state, strict=True)

    def fit(self, parameters, config):
        """Load the server weights, train one local epoch, and return the new weights.

        Returns ``(weights, number of training samples, {})``.
        """
        self.set_parameters(parameters)
        train(self.model, 1, self.train_loader, self.optimizer, 200, self.loss_fn)
        updated = self.get_parameters(config={})
        num_examples = len(self.train_loader.dataset)
        return updated, num_examples, {}

    def evaluate(self, parameters, config):
        """Load the server weights and evaluate on the local test data.

        Returns ``(loss, number of test samples, {"accuracy": accuracy})``.
        """
        self.set_parameters(parameters)
        loss, accuracy = evaluate(self.model, self.test_loader, self.loss_fn)
        metrics = {"accuracy": accuracy}
        return loss, len(self.test_loader.dataset), metrics
    
# Build this client's model, loss and optimizer on DEVICE.
model_fl = MLP().to(DEVICE)
criterion_fl = nn.CrossEntropyLoss().to(DEVICE)
optimizer_fl = torch.optim.Adam(model_fl.parameters())

# Index of the data shard this process uses (0 .. num_clients - 1).
# NOTE(review): presumably each participating client runs this cell with a
# different value — confirm.
client_num = 1

train_dataset_fl = MnistDataSet(x_train_list[client_num], y_train_list[client_num])
test_dataset_fl = MnistDataSet(x_val_list[client_num], y_val_list[client_num])
# BATCH_SIZE is already defined once above; it is not reassigned here so the
# two definitions cannot silently diverge.
# Reshuffle the training data every epoch; keep the evaluation order fixed.
train_loader_fl = DataLoader(train_dataset_fl, batch_size=BATCH_SIZE, shuffle=True)
test_loader_fl = DataLoader(test_dataset_fl, batch_size=BATCH_SIZE)

flwr_client = FlowerClient(model_fl, train_loader_fl, test_loader_fl, optimizer_fl, criterion_fl)

# Connect to the Flower server and serve its fit/evaluate requests.
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=flwr_client)

# --- Cell output ---
# INFO flwr 2023-08-25 11:36:31,197 | grpc.py:50 | Opened insecure gRPC connection (no certificates were passed)
# DEBUG flwr 2023-08-25 11:36:31,202 | connection.py:39 | ChannelConnectivity.IDLE
# DEBUG flwr 2023-08-25 11:36:31,204 | connection.py:39 | ChannelConnectivity.CONNECTING
# DEBUG flwr 2023-08-25 11:36:31,205 | connection.py:39 | ChannelConnectivity.READY
#
#
# Using PyTorch version: 1.12.1+cu113  Device: cuda:0
#
#
# DEBUG flwr 2023-08-25 11:36:54,630 | connection.py:113 | gRPC channel closed
# INFO flwr 2023-08-25 11:36:54,631 | app.py:185 | Disconnect and shut down
