In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import List, Dict, Tuple, Optional
import copy
from datasets import load_dataset
from PIL import Image
import random
import json
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner,PathologicalPartitioner
from flwr_datasets.preprocessor import Divider


In [None]:
SEED = 11
# Seed every RNG the notebook uses: torch (torch.manual_seed seeds all devices),
# NumPy's legacy global RNG (used by the Dirichlet split below) and Python's `random`.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Pick the best available backend: CUDA, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

First, we load the FEMNIST dataset and visualize some of its samples.

In [None]:
# Raw FEMNIST from the Hugging Face Hub; used below for visualisation and for the
# hand-rolled Dirichlet partitioning analysis.
dataset = load_dataset("flwrlabs/femnist")

# Federated view used for training. Later cells call
# fds.load_partition(split="train", partition_id=0..61) and fds.load_split("test"),
# so `fds` must be defined here with 62 partitions and a "test" split.
# The Divider first splits the original "train" split 80/20 into "train"/"test";
# only "train" is partitioned, one character class per partition (extreme label skew).
partitioner = PathologicalPartitioner(
    num_partitions=62,
    partition_by="character",
    num_classes_per_partition=1,
    class_assignment_mode="first-deterministic",
)

fds = FederatedDataset(
    dataset="flwrlabs/femnist",
    preprocessor=Divider(divide_config={"train": 0.8, "test": 0.2}, divide_split="train"),
    partitioners={"train": partitioner},
)

# Alternative with milder label skew. Note: it creates only 10 partitions and no
# "test" split, so the client/server construction below would need adjusting.
# partitioner = DirichletPartitioner(num_partitions=10, partition_by="character",
#                                    alpha=0.5, min_partition_size=10,
#                                    self_balancing=False,seed=SEED)
# fds = FederatedDataset(dataset="flwrlabs/femnist", partitioners={"train": partitioner})

In [None]:
# Scratch (disabled): inspect which labels each federated partition holds and
# count samples per (partition, label).
# NOTE(review): range(100) exceeds the num_partitions=62 used by the
# PathologicalPartitioner configuration; adjust before re-enabling.
# print(set(fds.load_partition(partition_id=0)[i]['character'] for i in range(len(fds.load_partition(partition_id=0)))))
# lens=[[0]*62 for _ in range(100)]

# for _ in range(100):
#     partition = fds.load_partition(partition_id=_)
#     for i in range(len(partition)):
#         lens[_][partition[i]["character"]]+=1


In [None]:
# Peek at one client's federated training partition.
partition_9 = fds.load_partition(partition_id=9, split="train")
print(partition_9)

In [None]:
# Summary of the raw dataset, then a 10x10 grid of the first 100 training images,
# each annotated with its "character" label.
print(dataset)
print("=" * 60)
train_ds = dataset["train"]
print(f"We have {len(train_ds)} data samples.")
example = train_ds[0]
fig, axs = plt.subplots(10, 10, tight_layout=True)
# axs.flat walks the grid row by row, so sample_idx == row * 10 + col.
for sample_idx, ax in enumerate(axs.flat):
    sample = train_ds[sample_idx]
    ax.imshow(sample["image"], cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.text(0.2, 0.2, sample["character"],
            horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes)
print("The keys of dictionary are-\n", example.keys())

In [None]:
# Hold out 20% of the raw training split (fixed seed -> reproducible split).
split = train_ds.train_test_split(test_size=0.2, seed=42)
train_data, test_data = split["train"], split["test"]

print(f"Samples in train - {len(train_data)}, Samples in test - {len(test_data)}")

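<h3>Task 1</h3>
Now we distribute the training samples among the clients.<br>
We do this by partitioning the data with a Dirichlet prior.<br>
For each class, a Dirichlet draw gives the <b>fraction of that class assigned to each client</b>. A small concentration parameter α concentrates a class on very few clients (strongly non-IID); a large α spreads it almost uniformly.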



In [None]:
# Number of character classes and of simulated clients.
# (Alternative: num_classes = len(train_data.features["character"].names))
num_classes = 62
num_clients = 62
print(num_classes, num_clients)

# Integer class label of every training sample. The next cell groups sample
# indices by this array, so it must be defined here.
labels = np.array(train_data["character"])

In [None]:
alpha = 0.01  # larger alpha -> near-uniform split across clients; smaller alpha -> skewed split

# indices_by_class[c] = positions in train_data of every sample with label c
indices_by_class = [np.where(labels == i)[0] for i in range(num_classes)]

# client_data[c][k] = array of indices of class c assigned to client k.
# Every class gets exactly num_clients entries (possibly empty), so later cells
# can index client_data[c][k] for any class c and client k.
client_data = dict()
print("Number of data samples per class")
for i in range(num_classes):
    print(f"{i}->{len(indices_by_class[i])}")

print("Distribution of classes per client")
for c in range(num_classes):
    samples_in_class = indices_by_class[c]
    num_samples_in_class = len(samples_in_class)
    if num_samples_in_class == 0:
        # Give every client an empty chunk; skip the RNG draws so the random
        # stream used for the remaining classes is unchanged.
        client_data[c] = [samples_in_class[:0] for _ in range(num_clients)]
        continue

    # Dirichlet proportions for this class across clients
    p = np.random.dirichlet(alpha * np.ones(num_clients))

    # Integer per-client counts that sum exactly to num_samples_in_class
    counts = np.random.multinomial(num_samples_in_class, p)

    # Cut the class's index array into consecutive chunks of the sampled sizes
    client_data[c] = np.split(samples_in_class, np.cumsum(counts)[:-1])
    print(f"Class {c}: {[chunk.shape[0] for chunk in client_data[c]]}")



In [None]:
print("Per class label distribution across clients plots")

# Summary of how each class is spread over the clients; also checks that the
# per-class sample total is preserved by the split.
for c in range(num_classes):
    assert c in client_data
    per_client_counts = np.array([len(chunk) for chunk in client_data[c]], dtype=int)
    total = per_client_counts.sum()
    assert total == indices_by_class[c].shape[0]
    print(
        f"Class{c:2d} |Total={total:5d} |"
        f"min={per_client_counts.min():4d} |max={per_client_counts.max():6d} |"
        f"mean={per_client_counts.mean():7.2f} |std={per_client_counts.std():8.2f} |"
    )


In [None]:
# Bar charts of the fraction of each selected class held by each client.
classes_to_plot = list(range(5))
n_plot = len(classes_to_plot)

counts_list = [np.array([len(chunk) for chunk in client_data[c]]) for c in classes_to_plot]

fig, axes = plt.subplots(1, n_plot, figsize=(6 * n_plot, 4), sharey=True)
if n_plot == 1:
    axes = [axes]  # plt.subplots returns a bare Axes when there is a single panel

for ax, c, counts in zip(axes, classes_to_plot, counts_list):
    total = counts.sum()
    # An empty class would give 0/0 -> NaN bars; plot zeros instead.
    frac = counts / total if total > 0 else np.zeros(len(counts))
    ax.bar(np.arange(len(frac)), frac)
    ax.set_title(f"Class {c}")
    ax.set_xlabel("Client ID")
    ax.set_xticks(np.arange(len(frac)))
    ax.set_ylim(0, 1.0)                   # same y-axis for all

axes[0].set_ylabel("Fraction of this class")
plt.tight_layout()
plt.show()


In [None]:
class FemnistCNN(nn.Module):
    """Two-block CNN classifier for 28x28 single-channel images.

    Layout: conv3x3(32) -> ReLU -> maxpool2 -> conv3x3(64) -> ReLU -> maxpool2
    -> Dropout2d(0.25) -> flatten -> Linear(64*7*7 -> 128) -> ReLU -> Dropout(0.5)
    -> Linear(128 -> num_classes). `forward` returns raw logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Submodule construction order fixes both the default-init RNG usage and
        # the state_dict key names, so keep it stable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2)  # halves H and W
        self.dropout2d = nn.Dropout2d(p=0.25)
        # Two poolings take 28 -> 14 -> 7, with 64 channels
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, num_classes)

        # He-normal weights and zero biases for every learnable layer.
        for layer in (self.conv1, self.conv2, self.fc1, self.fc2):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        # assumes x is (N, 1, 28, 28); fc1's input size requires 7x7 maps after pooling
        features = self.pool(F.relu(self.conv1(x)))                          # (N, 32, 14, 14)
        features = self.dropout2d(self.pool(F.relu(self.conv2(features))))   # (N, 64, 7, 7)
        hidden = self.dropout(F.relu(self.fc1(torch.flatten(features, 1))))  # (N, 128)
        return self.fc2(hidden)                                              # (N, num_classes)
    
# Instantiate one model on the selected device, report its size, and check that
# a random batch of 4 images yields (4, num_classes) logits.
model = FemnistCNN(num_classes).to(device)
print(model)
total_params = 0
trainable_params = 0
for tensor in model.parameters():
    total_params += tensor.numel()
    if tensor.requires_grad:
        trainable_params += tensor.numel()
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
test_input = torch.randn(4, 1, 28, 28).to(device)
test_output = model(test_input)

print(f"\nTest input shape: {test_input.shape}")
print(f"Test output shape: {test_output.shape}")

assert test_output.shape == (4, num_classes)

In [None]:
# Build per-client index lists from the per-class Dirichlet splits (client_data):
# client cid gets its chunk of every class, concatenated, then shuffled in place
# so its samples are not grouped by class.
# NOTE(review): client_indices feeds only the checks below; the training cells
# build clients from `fds` partitions instead.
client_indices = []
for cid in range(num_clients):
    merged = np.hstack([client_data[c][cid] for c in range(num_classes)])
    np.random.shuffle(merged)
    client_indices.append(merged)

# Sanity check: the total number of assigned indices equals the training-set size.
assigned = list(map(len, client_indices))
print("Distribution of data samples per client is", assigned)
print(f"Total samples assigned to clients: {sum(assigned)}")
print(f"Train set size: {len(train_data)}")

assert sum(assigned) == len(train_data)
print(np.asarray(train_data[0]["image"]).shape)

In [None]:
# Image preprocessing: force a single grayscale channel, then convert the PIL
# image to a (1, H, W) float tensor.
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

class Custom_Dataset(Dataset):
    """Map-style torch Dataset exposing selected rows of a FEMNIST-style dataset.

    Args:
        dataset: indexable collection whose items are dicts with "image" and
            "character" keys (e.g. a Hugging Face `datasets.Dataset`).
        indices: positions in `dataset` exposed by this view, in order.
        transform: optional callable applied to each image; defaults to the
            module-level `transform` captured when the class is defined.

    Item `idx` is `(transform(image), label)` for row `dataset[indices[idx]]`.
    """

    def __init__(self, dataset, indices, transform=transform):
        self.dataset, self.indices, self.transform = dataset, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        record = self.dataset[self.indices[idx]]
        image = record["image"]
        if self.transform is not None:
            image = self.transform(image)
        return image, record["character"]

class client():
    """A federated-learning client: a local dataset plus a local model replica.

    Args:
        client_id: identifier shown in progress messages.
        dataset: local training data; items are (image_tensor, label) pairs.
        model: local model; its weights are replaced by `receive_from_server`.
        optimizer: local optimizer name (case-insensitive): "sgd",
            "sgd_momentum", "adam", "adamw", "rmsprop", "radam", "nadam" or
            "fedprox" (plain SGD plus the FedProx proximal term).
        lr: learning rate used by every optimizer above.
        batch_size: local mini-batch size.
        mu: FedProx coefficient (stored as `self.u`); only used with "fedprox".

    Raises:
        ValueError: if `optimizer` is not one of the names above.
    """

    def __init__(self, client_id: int, dataset: Custom_Dataset, model: FemnistCNN, optimizer: str, lr: float,
                 batch_size: int = 640, mu: float = 1.0):
        self.client_id = client_id
        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        self.model = model
        self.lr = lr
        self.name = optimizer.lower()
        self.losses = []  # per-step training loss history, across all rounds
        # Initialise the FedProx flag/coefficient BEFORE building the optimizer:
        # _get_optimizer("fedprox") switches the flag on, and resetting it
        # afterwards would silently disable FedProx.
        self.fedprox = False
        self.u = mu
        self.optimizer = self._get_optimizer(self.name)

    def _get_optimizer(self, name):
        """Build optimizer `name` over the local model's parameters using `self.lr`.

        Sets `self.fedprox = True` for "fedprox". The optimizer is returned, not
        stored: callers must assign it to `self.optimizer` themselves.

        Raises:
            ValueError: for an unknown optimizer name.
        """
        params = self.model.parameters()
        if name == "sgd":
            return optim.SGD(params, lr=self.lr)
        if name == "sgd_momentum":
            return optim.SGD(params, lr=self.lr, momentum=0.9, nesterov=False)
        if name == "fedprox":
            self.fedprox = True
            return optim.SGD(params, lr=self.lr)
        # Adaptive optimizers honour the configured learning rate too, rather than
        # falling back to their library defaults.
        adaptive = {
            "adam": optim.Adam,
            "adamw": optim.AdamW,
            "rmsprop": optim.RMSprop,
            "radam": optim.RAdam,
            "nadam": optim.NAdam,
        }
        if name in adaptive:
            return adaptive[name](params, lr=self.lr)
        raise ValueError(f"Unknown optimizer: {name}")

    def train(self, epochs, status=False):
        """Run local training for at most `epochs` mini-batch steps.

        Despite its name, `epochs` caps the number of optimizer steps: the loop
        makes at most one pass over the local DataLoader and stops after
        `epochs` batches (a value <= 0 means one full pass). Each step's loss,
        including the FedProx term when enabled, is appended to `self.losses`.

        Args:
            epochs: maximum number of mini-batches (optimizer steps) to run.
            status: if True, print progress every 5 steps.
        """
        self.model.train()
        # Detached snapshot of the weights received from the server; the proximal
        # term differentiates only w.r.t. the live parameters.
        global_params = (
            {name: p.detach().clone() for name, p in self.model.named_parameters()}
            if self.fedprox else None
        )
        for step, (x, y) in enumerate(self.dataloader, start=1):
            x = x.to(device)
            y = y.to(device)
            if status and step % 5 == 0:
                print(f"Client {self.client_id} | Epoch {step}/{epochs}")
            outputs = self.model(x)
            loss = F.cross_entropy(outputs, y)
            if self.fedprox:
                # FedProx proximal term (mu / 2) * ||w - w_global||^2. It must use
                # named_parameters(): state_dict() returns detached tensors, so a
                # term built from it would contribute no gradient.
                prox = sum(((p - global_params[name]) ** 2).sum()
                           for name, p in self.model.named_parameters())
                loss = loss + 0.5 * self.u * prox
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.losses.append(loss.item())
            if step == epochs:
                break

    def send_to_server(self):
        """Return (copy of the local state_dict, number of local samples)."""
        state = {name: tensor.clone() for name, tensor in self.model.state_dict().items()}
        return (state, len(self.dataset))

    def receive_from_server(self, param_dict):
        """Overwrite the local model with a copy of the server's state_dict."""
        self.model.load_state_dict({name: tensor.clone() for name, tensor in param_dict.items()})

    def set_learning_rate(self, lr):
        """Change the learning rate of the current optimizer (and `self.lr`).

        PyTorch optimizers read the rate from `param_groups`; setting an `lr`
        attribute on the optimizer object has no effect.
        """
        self.lr = lr
        for group in self.optimizer.param_groups:
            group["lr"] = lr
class server():
    """Federated server: broadcasts the global model, runs local training on every
    client, aggregates the returned weights and tracks evaluation metrics.

    Args:
        clients: participating `client` objects (non-empty); all train every round.
        dataset: evaluation dataset; items are (image_tensor, label) pairs.
        model: global model whose state_dict is broadcast to the clients.
        optimizer: aggregation rule (case-insensitive):
            "sgd"  -> FedAvg, sample-count-weighted average of client weights;
            "adam" -> Adam-style step on the pseudo-gradient
                      (weighted client average minus current global weights).
        lr: server learning rate; default step size of the "adam" rule
            (the "sgd"/FedAvg rule has no server step size).
        eval_batch_size: batch size of the evaluation DataLoader.

    Raises:
        ValueError: if `clients` is empty.
    """

    def __init__(self, clients: List["client"], dataset: Custom_Dataset, model, optimizer: str, lr: float,
                 eval_batch_size: int = 1280):
        if not clients:
            raise ValueError("server needs at least one client")
        self.clients = clients
        self.optimizer = optimizer.lower()
        self.model = model
        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size=eval_batch_size, shuffle=True)
        self.X = None  # not used within this class
        self.Y = None  # not used within this class
        self.lr = lr
        self.losses = []               # eval loss (mean per sample), one entry per evaluate()
        self.predictions_correct = []  # eval accuracy (fraction in [0, 1]), one per evaluate()
        self.local_rounds = 1000       # max local mini-batch steps per client per round
        self.path = f"c1_{self.clients[0].name}"  # file stem for checkpoints/metrics
        self.optimizer_dict = None     # state of the "adam" rule, created lazily

    def global_round(self):
        """Run one communication round: broadcast, local training, aggregation.

        Raises:
            ValueError: if `self.optimizer` is neither "sgd" nor "adam" (checked
                before any client trains).
        """
        if self.optimizer not in ("sgd", "adam"):
            raise ValueError(f"Unknown global optimizer: {self.optimizer}")
        global_state = self.model.state_dict()
        for c in self.clients:
            c.receive_from_server(global_state)
        # Every client participates in every round (no client sampling yet).
        for c in self.clients:
            c.train(self.local_rounds, True)  # at most local_rounds mini-batch steps
        updates = [c.send_to_server() for c in self.clients]

        averaged = self._weighted_average(updates)
        if self.optimizer == "adam":
            averaged = self._adam_step(averaged)
        self.model.load_state_dict(averaged)

    @staticmethod
    def _weighted_average(updates):
        """Sample-count-weighted average of client state_dicts (the FedAvg rule).

        Args:
            updates: list of (state_dict, num_samples) pairs from the clients.
        Raises:
            ValueError: if there are no updates or the sample counts sum to 0.
        """
        total = sum(n for _, n in updates)
        if not updates or total == 0:
            raise ValueError("No client samples to aggregate.")
        averaged = {k: torch.zeros_like(v) for k, v in updates[0][0].items()}
        for state, n in updates:
            for k, v in state.items():
                averaged[k] += v * n / total
        return averaged

    def _adam_step(self, averaged):
        """Apply one Adam-style server step towards the client average.

        The pseudo-gradient is (averaged - current global weights); moments and
        bias correction follow Adam, and the result moves *towards* the average.
        Hyperparameters live in `self.optimizer_dict` and can be preset there.
        """
        if self.optimizer_dict is None:
            self.optimizer_dict = {}
        cfg = self.optimizer_dict
        cfg.setdefault('learning_rate', self.lr)
        cfg.setdefault('beta1', 0.9)
        cfg.setdefault('beta2', 0.999)
        cfg.setdefault('epsilon', 1e-8)
        cfg.setdefault('m', {k: torch.zeros_like(v) for k, v in averaged.items()})
        cfg.setdefault('v', {k: torch.zeros_like(v) for k, v in averaged.items()})
        cfg.setdefault('t', 0)
        cfg['t'] += 1

        lr, b1, b2, eps, t = cfg['learning_rate'], cfg['beta1'], cfg['beta2'], cfg['epsilon'], cfg['t']
        first_moment, second_moment = cfg['m'], cfg['v']
        current = self.model.state_dict()
        new_state = {}
        for k, avg in averaged.items():
            w = current[k]
            delta = avg - w  # pseudo-gradient
            first_moment[k] = b1 * first_moment[k] + (1 - b1) * delta
            second_moment[k] = b2 * second_moment[k] + (1 - b2) * delta * delta
            m_hat = first_moment[k] / (1 - b1 ** t)
            v_hat = second_moment[k] / (1 - b2 ** t)
            new_state[k] = w + lr * m_hat / (eps + v_hat.sqrt())
        return new_state

    def train(self, rounds, eval_batches=50):
        """Run `rounds` communication rounds, checkpointing after each one.

        After every round the model is evaluated on `eval_batches` shuffled
        evaluation batches (a cheap estimate); after the last round it is
        evaluated on the full evaluation set.
        """
        for i in range(rounds):
            print(f"Running round {i+1}/{rounds} on server.")
            self.global_round()
            self.evaluate(eval_batches)
            print(f"Loss:{self.losses[-1]}, Accuracy: {self.predictions_correct[-1]*100}%")
            self.save_state()
        self.evaluate(len(self.dataloader))  # every batch of the evaluation set
        print(f"Final Loss:{self.losses[-1]}, Accuracy: {self.predictions_correct[-1]*100}%")

    def evaluate(self, batches):
        """Evaluate the global model on up to `batches` shuffled batches.

        A value of `batches` <= 0 evaluates every batch. Appends the mean
        per-sample cross-entropy to `self.losses` and the accuracy (fraction in
        [0, 1]) to `self.predictions_correct`. The mean keeps the loss comparable
        between partial and full evaluations.

        Raises:
            ValueError: if no samples were evaluated (empty dataset).
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():  # no autograd graph needed for evaluation
            for i, (x, y) in enumerate(self.dataloader, start=1):
                x = x.to(device)
                y = y.to(device)
                outputs = self.model(x)
                loss_sum += F.cross_entropy(outputs, y, reduction="sum").item()
                correct += (outputs.argmax(dim=1) == y).sum().item()
                total += y.shape[0]
                if i == batches:
                    break
        if total == 0:
            raise ValueError("Evaluation dataset is empty; nothing to evaluate.")
        self.losses.append(loss_sum / total)
        self.predictions_correct.append(correct / total)

    def save_state(self):
        """Write model weights to ./<path>.pt and metric histories to ./vs-code/*.json.

        The "training_accuracy" file holds the accuracies from `evaluate`, i.e.
        measured on the server's evaluation dataset.
        """
        import os
        os.makedirs("./vs-code", exist_ok=True)  # open(..., 'w') fails if the directory is missing
        torch.save(self.model.state_dict(), f"./{self.path}.pt")
        with open(f"./vs-code/losses_{self.path}.json", 'w') as file:
            json.dump(self.losses, file)
        with open(f"./vs-code/training_accuracy_{self.path}.json", 'w') as file:
            json.dump(list(self.predictions_correct), file)

    def load_state(self):
        """Restore model weights and metric histories written by `save_state`."""
        # map_location="cpu" lets a checkpoint saved on an accelerator load anywhere;
        # load_state_dict then copies the tensors onto the model's own device.
        state = torch.load(f"./{self.path}.pt", weights_only=True, map_location="cpu")
        self.model.load_state_dict(state)
        with open(f"./vs-code/losses_{self.path}.json", 'r') as file:
            self.losses = json.load(file)
        with open(f"./vs-code/training_accuracy_{self.path}.json", 'r') as file:
            self.predictions_correct = json.load(file)

In [None]:
# Simulated federation: client k (1-based id) trains on federated train partition
# k-1 with its own model replica; the server evaluates the global model on the
# held-out "test" split and aggregates with the "sgd" (FedAvg) rule.
CLIENT_OPTIMIZER, CLIENT_LR = "sgd_momentum", 0.001
SERVER_OPTIMIZER, SERVER_LR = "sgd", 0.01

clients = []
for i in range(num_clients):
    # Load each partition once and expose all of its rows.
    train_partition = fds.load_partition(split="train", partition_id=i)
    local_data = Custom_Dataset(train_partition, list(range(len(train_partition))))
    clients.append(client(i + 1, local_data, FemnistCNN(num_classes).to(device), CLIENT_OPTIMIZER, CLIENT_LR))

test_split = fds.load_split("test")
myServer = server(clients, Custom_Dataset(test_split, list(range(len(test_split)))),
                  FemnistCNN(num_classes).to(device), SERVER_OPTIMIZER, SERVER_LR)

In [None]:
# To resume from the last checkpoint instead of starting fresh:
# try:
#     myServer.load_state()
#     myServer.train(1000)
# except KeyboardInterrupt:
#     myServer.save_state()
#
# To switch the clients to Adam, the optimizer returned by _get_optimizer must be
# assigned (and avoid `client` as the loop variable: it shadows the client class):
# for c in clients:
#     c.lr = 0.1
#     c.optimizer = c._get_optimizer("adam")

NUM_GLOBAL_ROUNDS = 150  # communication rounds to run
myServer.train(NUM_GLOBAL_ROUNDS)

In [None]:
# Optional (disabled): switch the clients to FedProx after a run.
# NOTE(review): _get_optimizer returns the new optimizer without storing it, so
# the loop below has no effect as written; assign it, e.g.
#     c.optimizer = c._get_optimizer("fedprox")
# Also, `for client in clients` rebinds the global name `client` (the client class).
# myServer.save_state()
# for client in clients:
#     client.lr=0.1
#     client._get_optimizer("fedprox")

In [None]:
# Disabled: plots training/test curves from "fedavg_results_with_trainloss.csv",
# a file this notebook does not write. The server's save_state writes JSON
# histories instead (./vs-code/losses_<path>.json and
# ./vs-code/training_accuracy_<path>.json); adapt this cell to read those.
# # %% [code]
# import pandas as pd
# import matplotlib.pyplot as plt

# results_csv = "fedavg_results_with_trainloss.csv"
# df = pd.read_csv(results_csv)

# rounds = df["round"]

# # --- Training Loss vs Rounds ---
# plt.figure(figsize=(6,4))
# plt.plot(rounds, df["train_loss"], linewidth=2)
# plt.title("Training Loss vs Rounds")
# plt.xlabel("Round")
# plt.ylabel("Training Loss")
# plt.grid(True)
# plt.show()

# # --- Test Loss vs Rounds ---
# plt.figure(figsize=(6,4))
# plt.plot(rounds, df["global_loss"], linewidth=2)
# plt.title("Test Loss vs Rounds")
# plt.xlabel("Round")
# plt.ylabel("Test Loss")
# plt.grid(True)
# plt.show()

# # --- Test Accuracy vs Rounds ---
# plt.figure(figsize=(6,4))
# plt.plot(rounds, df["global_acc"], linewidth=2)
# plt.title("Test Accuracy vs Rounds")
# plt.xlabel("Round")
# plt.ylabel("Accuracy")
# plt.grid(True)
# plt.show()