In [None]:
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import copy
import random
import time
import pandas as pd
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils, datasets
from argparse import ArgumentParser
from torchvision import transforms as tt
from torchvision.models import resnet18, ResNet18_Weights

In [None]:
def compute_mean_std(loader):
    """Estimate per-channel mean and standard deviation of the images in ``loader``.

    Each batch's mean/std is weighted by its size and the weighted values are
    averaged, so the returned std is an average of per-batch stds (an
    approximation of the dataset-wide std, not the exact value).

    Params:
      - loader: iterable yielding ``(images, labels)`` pairs; ``images`` is
        expected to be a 4-D tensor ``(batch, channels, height, width)``.
        Labels are ignored.

    Returns:
      - (mean, std): 1-D tensors with one entry per channel (also printed).

    Raises:
      - ValueError: if ``loader`` yields no images.
    """
    mean = 0.0
    std = 0.0
    total_samples = 0

    for images, _ in loader:
        batch_samples = images.size(0)
        # Flatten the spatial dims using the tensor's own shape instead of a
        # hard-coded 3x244x244 view, so any channel count / image size works
        # (e.g. the 512x244 centre crop used by `transform`).
        flat = images.reshape(batch_samples, images.size(1), -1)
        mean += flat.mean([0, 2]) * batch_samples
        std += flat.std([0, 2]) * batch_samples
        total_samples += batch_samples

    if total_samples == 0:
        raise ValueError("compute_mean_std: loader yielded no images")

    mean /= total_samples
    std /= total_samples

    print("Mean:", mean)
    print("Std:", std)
    return mean, std

In [None]:
seed = 42

# Seed every random number generator the notebook draws from (Python, NumPy,
# PyTorch) so that e.g. the data partitions and client sampling are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# GPU-specific: disable cuDNN autotuning and force deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

"""## Partitioning the Data (IID and non-IID)"""

def get_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def to_device(data, device):
    """Move ``data`` to ``device``.

    ``data`` may be anything with a ``.to`` method (tensor, module) or a
    (possibly nested) list/tuple of such objects; lists and tuples are always
    returned as lists.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

def iid_partition(dataset, clients):
    """
    I.I.D. partitioning of data over clients.

    Each client receives ``int(len(dataset) / clients)`` indexes drawn uniformly
    without replacement from those not yet assigned; any remainder is left
    unassigned.

    params:
      - dataset (torch.utils.Dataset): Dataset containing the Images (only its length is used)
      - clients (int): Number of Clients to split the data between

    returns:
      - Dictionary mapping each client id (0..clients-1) to a set of image indexes
    """
    shard_size = int(len(dataset) / clients)
    unassigned = list(range(len(dataset)))
    partition = {}

    for client_id in range(clients):
        shard = set(np.random.choice(unassigned, shard_size, replace=False))
        partition[client_id] = shard
        # Remove this client's indexes from the pool for the next draw.
        unassigned = list(set(unassigned) - shard)

    return partition


def non_iid_partition(dataset, n_nets, alpha, min_require_size=10):
    """Split a dataset across clients with a per-class Dirichlet label skew.

    For every class, the samples of that class are divided among the clients
    according to proportions drawn from ``Dirichlet(alpha, ..., alpha)``.
    Clients that already hold at least ``N / n_nets`` samples get no more. The
    whole draw is repeated until every client holds at least
    ``min_require_size`` samples.

    Params:
      - dataset: dataset object exposing a ``targets`` sequence with one class
        label per sample; only ``targets`` is read.
      - n_nets (int): number of clients.
      - alpha (float): Dirichlet concentration parameter; smaller values give
        more skewed label distributions per client.
      - min_require_size (int): minimum samples per client (default 10).

    Returns:
      - dict mapping client id (0..n_nets-1) to a shuffled numpy array of sample
        indexes into ``dataset``.

    Raises:
      - ValueError: if ``n_nets * min_require_size`` exceeds the dataset size,
        in which case the resampling loop could never terminate.
    """
    print('non iid setup')
    y_train = np.array(dataset.targets)
    N = y_train.shape[0]
    print(N)
    if n_nets * min_require_size > N:
        raise ValueError(
            f"non_iid_partition: cannot give each of {n_nets} clients at least "
            f"{min_require_size} samples from a dataset of {N}")

    # Iterate over the labels actually present, so non-contiguous label values
    # (e.g. {3, 7}) are handled. For labels 0..K-1 this is the same order as range(K).
    class_labels = np.unique(y_train)

    min_size = 0
    while min_size < min_require_size:
        idx_batch = [[] for _ in range(n_nets)]
        for k in class_labels:
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
            # Balance: clients already holding >= N / n_nets samples get a zero share.
            proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Cumulative proportions -> split points inside idx_k.
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)

    # net_dataidx_map: {client id: array of sample indexes assigned to that client}
    net_dataidx_map = {}
    for j in range(n_nets):
        np.random.shuffle(idx_batch[j])
        net_dataidx_map[j] = np.array(idx_batch[j])

    print('partitioj done')
    return net_dataidx_map


"""## Federated Averaging

### Local Training (Client Update)

Local training for the model on client side
"""


class CustomDataset(Dataset):
    """View over ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``i`` of this view is item ``idxs[i]`` of the wrapped dataset, which
    must yield ``(image, label)`` pairs.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise as a list so positional access works for sets and arrays alike.
        self.idxs = [idx for idx in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label

def collate_fn(batch):
    """Stack ``(image, label)`` pairs into a batch, promoting 1-channel images to 3 channels.

    Single-channel images are broadcast with ``expand`` (no copy) to three
    identical channels so grey and colour images can share one batch. Labels
    are returned as a 1-D tensor.
    """
    images, labels = zip(*batch)
    stacked = torch.stack([
        image if image.shape[0] != 1 else image.expand(3, -1, -1)
        for image in images
    ])
    return stacked, torch.tensor(labels)

class ClientUpdate(object):
    """Client-side (local) training step of Federated Averaging.

    Holds one client's shard of the training data and trains a model on it for
    ``epochs`` passes with Adam and cross-entropy loss.
    """

    def __init__(self, dataset, batchSize, learning_rate, epochs, idxs, sch_flag):
        """
        params:
          - dataset: full training dataset; only the samples at ``idxs`` are used
          - batchSize (int): local mini-batch size (batches are reshuffled every epoch)
          - learning_rate (float): Adam learning rate
          - epochs (int): number of local passes over the client's data
          - idxs: iterable of sample indexes belonging to this client
          - sch_flag: stored for compatibility; the LR scheduler it toggled is currently disabled
        """
        self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True,
                                       collate_fn=collate_fn)
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.sch_flag = sch_flag

    def train(self, model):
        """Train ``model`` in place on this client's data.

        params:
          - model (nn.Module): model to train (the server passes a deep copy of the global model)

        returns:
          - (state_dict, loss): the trained weights and the mean over epochs of the
            per-sample training loss (0.0 if nothing was trained)
        """
        criterion = nn.CrossEntropyLoss()
        use_cuda = torch.cuda.is_available()
        # Move the model once, before building the optimizer, instead of on every batch.
        if use_cuda:
            model.cuda()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

        e_loss = []
        for _ in range(self.epochs):
            train_loss = 0.0
            samples_seen = 0

            model.train()
            for data, labels in self.train_loader:
                torch.cuda.empty_cache()
                # Skip single-sample batches (presumably so BatchNorm never has to
                # compute batch statistics from one sample).
                if data.size(0) < 2:
                    continue
                if use_cuda:
                    data, labels = data.cuda(), labels.cuda()

                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * data.size(0)
                samples_seen += data.size(0)

            # Average over the samples actually trained on (skipped batches excluded).
            e_loss.append(train_loss / samples_seen if samples_seen else 0.0)

        total_loss = sum(e_loss) / len(e_loss) if e_loss else 0.0
        return model.state_dict(), total_loss


"""### Server Side Training

Following Algorithm 1 from the paper
"""


def _average_state_dicts(state_dicts):
    """Return the element-wise, unweighted mean of a non-empty list of state dicts with identical keys."""
    averaged = copy.deepcopy(state_dicts[0])
    for key in averaged.keys():
        for other in state_dicts[1:]:
            averaged[key] += other[key]
        # torch.div is true division, so integer buffers (e.g. BatchNorm's
        # num_batches_tracked) become floats here; load_state_dict copies them
        # back into the model's existing buffers.
        averaged[key] = torch.div(averaged[key], len(state_dicts))
    return averaged


def training(model, rounds, batch_size, lr, ds, data_dict, C, K, E, plt_title, plt_color, cifar_data_test,
             test_batch_size, criterion, num_classes, classes_test, sch_flag):
    """
    Server-side Federated Averaging loop.

    Each round samples ``max(int(C * K), 1)`` clients, trains a copy of the
    global model on each client's data, replaces the global weights with the
    unweighted mean of the client weights, and evaluates the result on the test
    set. Train loss, test accuracy and test loss are plotted at the end.

    Params:
      - model:            PyTorch model to train; its weights are overwritten every round
      - rounds:           Number of communication rounds
      - batch_size:       Batch size for client update training
      - lr:               Initial client learning rate; multiplied by 0.99 after every round
      - ds:               Training dataset indexed by ``data_dict``
      - data_dict:        Mapping client id -> sample indexes (IID or non-IID partition);
                          client ids 0..K-1 must be present
      - C:                Fraction of clients randomly chosen to perform computation on each round
      - K:                Total number of clients
      - E:                Number of training passes each client makes over its local dataset per round
      - plt_title:        Title of the results plot
      - plt_color:        Currently unused
      - cifar_data_test:  Test dataset evaluated after every round
      - test_batch_size:  Batch size used for evaluation
      - criterion:        Loss used for evaluation
      - num_classes:      Number of classes, forwarded to ``testing``
      - classes_test:     Forwarded to ``testing``
      - sch_flag:         Forwarded to ``ClientUpdate``
    Returns:
      - model:            Trained model on the server
    """
    train_loss = []
    test_loss = []
    test_accuracy = []
    best_accuracy = 0
    start = time.time()

    for curr_round in range(1, rounds + 1):
        # Number of clients taking part in this round (at least one).
        m = max(int(C * K), 1)
        S_t = np.random.choice(range(K), m, replace=False)

        w, local_loss = [], []
        for k in S_t:
            torch.cuda.empty_cache()  # Free up unused memory
            local_update = ClientUpdate(dataset=ds, batchSize=batch_size, learning_rate=lr, epochs=E,
                                        idxs=data_dict[k], sch_flag=sch_flag)
            # Each client trains its own copy, so the global model is untouched until aggregation.
            weights, loss = local_update.train(model=copy.deepcopy(model))
            w.append(weights)
            local_loss.append(loss)

        # Exponential decay of the client learning rate for the next round.
        lr = 0.99 * lr

        # Aggregate: the new global weights are the mean of the client weights.
        model.load_state_dict(_average_state_dicts(w))

        loss_avg = sum(local_loss) / len(local_loss)
        train_loss.append(loss_avg)

        t_accuracy, t_loss = testing(model, cifar_data_test, test_batch_size, criterion, num_classes, classes_test)
        test_accuracy.append(t_accuracy)
        test_loss.append(t_loss)

        if best_accuracy < t_accuracy:
            best_accuracy = t_accuracy
        # Report this round's accuracy, not the first round's.
        print(f"Round {curr_round}, loss_avg: {loss_avg}, t_loss: {t_loss}, test_acc: {t_accuracy}, best_acc: {best_accuracy}")

    end = time.time()
    plt.rcParams.update({'font.size': 8})
    fig, ax = plt.subplots()
    x_axis = np.arange(1, rounds + 1)

    ax.plot(x_axis, np.array(train_loss), 'tab:green', label='train_loss')
    ax.plot(x_axis, np.array(test_accuracy), 'tab:blue', label='test_accuracy')
    ax.plot(x_axis, np.array(test_loss), 'tab:red', label='test_loss')
    ax.legend(loc='upper left')
    # The plot mixes losses and accuracy (in percent), so label the axis accordingly.
    ax.set(xlabel='Number of Rounds', ylabel='Loss / Accuracy (%)', title=plt_title)
    ax.grid()
    print("Training Done!")
    print("Total time taken to Train: {}".format(end - start))

    return model



class MyGroupNorm(nn.Module):
    """Two-group ``nn.GroupNorm`` whose constructor takes only ``num_channels``.

    The single-argument signature lets it be passed where a norm-layer factory
    called with a channel count is expected.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.norm = nn.GroupNorm(
            num_groups=2,
            num_channels=num_channels,
            eps=1e-5,
            affine=True,
        )

    def forward(self, x):
        return self.norm(x)


"""## Testing Loop"""


def testing(model, dataset, bs, criterion, num_classes, classes):
    """Evaluate ``model`` on ``dataset`` and return ``(accuracy_percent, mean_loss)``.

    params:
      - model (nn.Module): model to evaluate. When CUDA is available only the
        batches are moved here, so the model must already be on the GPU.
      - dataset: evaluation dataset yielding ``(image, label)`` pairs with
        integer labels in ``0..num_classes-1``
      - bs (int): evaluation batch size
      - criterion: loss function applied to ``(output, labels)``
      - num_classes (int): number of classes; sizes the per-class counters
      - classes: unused; kept for interface compatibility

    returns:
      - accuracy in percent over all evaluated samples
      - mean per-sample loss over the dataset
    """
    test_loss = 0.0
    # Per-class counters (correct predictions / samples seen), indexed by label.
    correct_class = [0.0] * num_classes
    total_class = [0.0] * num_classes

    test_loader = DataLoader(dataset, batch_size=bs, collate_fn=collate_fn)
    model.eval()
    print("testing")
    with torch.no_grad():
        for data, labels in test_loader:
            torch.cuda.empty_cache()

            if torch.cuda.is_available():
                data, labels = data.cuda(), labels.cuda()

            output = model(data)
            loss = criterion(output, labels)
            test_loss += loss.item() * data.size(0)

            _, pred = torch.max(output, 1)
            # 1-D numpy arrays, even for a batch holding a single sample.
            correct = np.atleast_1d(pred.eq(labels.view_as(pred)).cpu().numpy())
            batch_labels = np.atleast_1d(labels.cpu().numpy())

            # Accumulate over every sample in the batch, not just the first num_classes.
            for label, is_correct in zip(batch_labels, correct):
                correct_class[label] += float(is_correct)
                total_class[label] += 1

    # avg test loss
    test_loss = test_loss / len(test_loader.dataset)

    return 100. * np.sum(correct_class) / np.sum(total_class), test_loss


# if __name__ == '__main__':

    # parser = ArgumentParser()
    # parser.add_argument('--norm', default="bn")
    # parser.add_argument('--partition', default="noniid")
    # parser.add_argument('--client_number', default=100)
    # parser.add_argument('--alpha_partition', default=0.5)
    # parser.add_argument('--commrounds', type=int, default=100)
    # parser.add_argument('--clientfr', type=float, default=0.1)
    # parser.add_argument('--numclient', type=int, default=100)
    # parser.add_argument('--clientepochs', type=int, default=20)
    # parser.add_argument('--clientbs', type=int, default=64)
    # parser.add_argument('--clientlr', type=float, default=0.001)
    # parser.add_argument('--sch_flag', default=False)

    # args = parser.parse_args()
    

class Args:
    """Notebook stand-in for the argparse namespace of the commented-out command-line entry point."""

    def __init__(self):
        settings = dict(
            norm="bn",
            partition="noniid",      # "noniid" -> Dirichlet split, anything else -> IID split
            client_number=20,        # number of data partitions to create
            alpha_partition=1,       # Dirichlet concentration for the non-IID split
            commrounds=30,           # communication rounds
            clientfr=0.25,           # fraction of clients sampled per round
            numclient=20,            # total number of clients sampled from
            clientepochs=10,         # local epochs per round
            clientbs=16,             # local batch size
            clientlr=0.001,          # initial client learning rate
            sch_flag=False,          # LR-scheduler toggle
        )
        for name, value in settings.items():
            setattr(self, name, value)


args = Args()


# classes = np.array(list(cifar_data_train.class_to_idx.values()))
# classes_test = np.array(list(cifar_data_test.class_to_idx.values()))
# num_classes = len(classes_test)


# custom dataset class for knee x_ray dataset
import os
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotations file.

    Column 0 of the CSV holds the image file name (relative to ``img_dir``) and
    column 1 holds the label.
    """

    def __init__(self, annotations_file, img_dir, transform, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        file_name = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 1]
        image = read_image(os.path.join(self.img_dir, file_name))
        if self.transform:
            # Scale 0-255 pixel values to [0, 1] floats before the transform;
            # without a transform the image is returned as read.
            image = self.transform(image.float() / 255.0)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    @property
    def targets(self):
        """Labels of all samples (column 1 of the annotations), in dataset order."""
        return list(self.img_labels.values[:, 1])


# Normalisation statistics for the knee X-ray images: one mean/std pair that is
# applied to every channel. Recompute with `compute_mean_std(loader)` if the data changes.
XRAY_MEAN = 0.2654
XRAY_STD = 0.2872
normalisation_stats = (XRAY_MEAN, XRAY_STD)  # (mean, std)

transform = transforms.Compose([
    transforms.CenterCrop((512, 244)),  # (height, width)
    transforms.Normalize(*normalisation_stats),
    # Replicate to 3 channels for the ImageNet-pretrained ResNet used below.
    transforms.Grayscale(num_output_channels=3),
])

# Knee X-ray train/test split (Kaggle input layout).
train_annotation_file_path = "/kaggle/input/knee-xray-split-dataset/knee_xray_split_dataset/train/train_annotations.csv"
test_annotation_file_path = "/kaggle/input/knee-xray-split-dataset/knee_xray_split_dataset/test/test_annotations.csv"
train_img_file_path = "/kaggle/input/knee-xray-split-dataset/knee_xray_split_dataset/train"
test_img_file_path = "/kaggle/input/knee-xray-split-dataset/knee_xray_split_dataset/test"

# Do NOT use the older "large-knee-xray-data" dataset: its train and test images were
# mixed, so results obtained with it are invalid. Its paths were:
# train_annotation_file_path = "/kaggle/input/large-knee-xray-data/large-knee-xray/train/annotations.csv"
# test_annotation_file_path = "/kaggle/input/large-knee-xray-data/large-knee-xray/train/annotations.csv"
# train_img_file_path = "/kaggle/input/large-knee-xray-data/large-knee-xray/train"
# test_img_file_path = "/kaggle/input/large-knee-xray-data/large-knee-xray/train"

xray_train = CustomImageDataset(train_annotation_file_path,
                                train_img_file_path,
                                transform=transform)
xray_test = CustomImageDataset(test_annotation_file_path,
                               test_img_file_path,
                               transform=transform)

# Alternative for quick experiments:
# xray_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms.ToTensor())

# Binary task: labels are 0 and 1.
classes = np.array([0, 1])
classes_test = np.array([0, 1])
num_classes = len(classes_test)

criterion = nn.CrossEntropyLoss()

# Hyperparameters_List (H) = [rounds, client_fraction, number_of_clients, number_of_training_rounds_local, local_batch_size, lr_client]
H = [args.commrounds, args.clientfr, args.numclient, args.clientepochs, args.clientbs, args.clientlr]

# `training` samples clients 0..numclient-1 from the partition, so both settings must agree:
# fewer partitions would raise KeyError mid-training, more would silently leave data unused.
if args.client_number != args.numclient:
    raise ValueError(
        f"args.client_number ({args.client_number}) must equal args.numclient ({args.numclient})")

if args.partition == 'noniid':
    # Label-skewed split; alpha is the Dirichlet concentration (smaller = more skewed).
    print('creating noniid partition')
    data_dict = non_iid_partition(xray_train, args.client_number, float(args.alpha_partition))
else:
    print('creating iid partition')
    # Partition across the configured number of clients, matching the non-IID branch.
    data_dict = iid_partition(xray_train, args.client_number)

# if args.norm == 'gn':
#     cifar_cnn = resnet.ResNet(resnet.Bottleneck, [3, 4, 6, 3], num_classes=10, zero_init_residual=False, groups=1,
#                               width_per_group=64, replace_stride_with_dilation=None, norm_layer=MyGroupNorm)
# else:
#     cifar_cnn = resnet.ResNet(resnet.Bottleneck, [3, 4, 6, 3], num_classes=10, zero_init_residual=False, groups=1,
#                               width_per_group=64, replace_stride_with_dilation=None)




# ImageNet-pretrained ResNet-18 with a new classification head with `num_classes` outputs.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# Alternatives tried: a single-channel stem
#   model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# or `model = CustomVGG19(num_classes=2)`.
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

print('model created')

device = get_device()
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = torch.nn.DataParallel(model)
model = to_device(model, device)

plot_str = (f"{args.partition}_{args.norm}_comm_rounds_{args.commrounds}_clientfr_{args.clientfr}"
            f"_numclients_{args.numclient}_clientepochs_{args.clientepochs}"
            f"_clientbs_{args.clientbs}_clientLR_{args.clientlr}")
print(plot_str)

trained_model = training(model, rounds=H[0], batch_size=H[4], lr=H[5], ds=xray_train, data_dict=data_dict,
                         C=H[1], K=H[2], E=H[3], plt_title=plot_str, plt_color="green",
                         cifar_data_test=xray_test, test_batch_size=128, criterion=criterion,
                         num_classes=num_classes, classes_test=classes_test, sch_flag=args.sch_flag)


In [None]:
# Persist the whole trained model object (a DataParallel wrapper when several GPUs
# were used); reload it with `torch.load('model.pth', weights_only=False)`.
torch.save(trained_model, 'model.pth')

# To evaluate a previously saved checkpoint instead, load it here, e.g.:
# model = torch.load('/kaggle/input/96pc/pytorch/default/1/96pc.pth', weights_only=False)
model = trained_model
model.eval()  # disable dropout / use BatchNorm running statistics

test_loader = DataLoader(xray_test, batch_size=32, shuffle=False, collate_fn=collate_fn)
correct = 0
total = 0
model.to(device)

with torch.no_grad():  # No need to track gradients during evaluation
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)  # index of the highest-scoring class

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise ValueError("The test dataset is empty; cannot compute accuracy.")
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

