# In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from typing import Dict
import math
from tqdm.std import tqdm
from sklearn.model_selection import train_test_split
from matplotlib import pyplot
import torch
import math
import torch.nn.functional as F
from torch import nn
from typing import Tuple, List
import torch
from torch import nn, Tensor
import torchvision
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader,RandomSampler
from sklearn.metrics import accuracy_score,recall_score,f1_score,precision_score,roc_auc_score
from torchvision.transforms import transforms, ToTensor,Resize
from matplotlib import pyplot as plt
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import glob
import os.path as osp
import cv2
from PIL import Image
from torchvision.transforms import Compose,ToTensor,Resize,RandomErasing,RandomHorizontalFlip,RandomVerticalFlip

# --- Run configuration -------------------------------------------------------
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of outer rounds (broadcast -> local training -> averaging).
total_epochs = 10
# Number of local models created and averaged per round.
number_of_samples = 5
# Epochs each local model trains for inside one round.
numEpoch = 4
# SGD hyper-parameters and DataLoader batch size.
learning_rate = 0.01
momentum = 0.9
batch_size = 32
# Only local models with index < print_amount report per-epoch accuracy.
print_amount = 2

class Net2nn(nn.Module):
    """EfficientNet-B0 (built with ``pretrained=False``) with a 4-output head.

    The wrapped network is stored as ``self.fw``. Its first convolution is
    replaced by a fresh 3 -> 32 channel, 3x3, stride-2 convolution, and
    ``classifier[1]`` is replaced by ``Dropout(0.2)`` followed by
    ``Linear(1280, 4)``.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.efficientnet_b0(pretrained=False)
        # Stem convolution: 3 input channels, 32 output channels, no bias.
        backbone.features[0][0] = nn.Conv2d(
            3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        # Classification head producing 4 logits from the 1280-d features.
        head = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(1280, 4),
        )
        backbone.classifier[1] = head
        self.fw = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.fw(x)

def train(model, train_loader, criterion, device, optimizer):
    """Train ``model`` for one pass over ``train_loader``.

    Parameters:
    - model: the network to optimise; it is moved to ``device``.
    - train_loader: iterable yielding ``(data, target)`` batches.
    - criterion: loss function called as ``criterion(output, target)``.
    - device: device the model and every batch are moved to.
    - optimizer: optimiser stepping the model's parameters.

    Returns:
    - ``(mean_batch_loss, accuracy)``. The loss is averaged over the batches
      that were processed and the accuracy over the samples that were
      actually seen. Both are ``0.0`` when the loader yields no batches.
    """
    model = model.to(device)
    model.train()
    train_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0

    for data, target in tqdm(train_loader, desc="Training"):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        prediction = output.argmax(dim=1, keepdim=True)
        correct += prediction.eq(target.view_as(prediction)).sum().item()
        # Count the samples really processed: a loader built with
        # drop_last=True skips the final partial batch, so dividing by
        # len(train_loader.dataset) would understate the accuracy.
        seen += target.size(0)
        num_batches += 1

    if num_batches == 0:
        # Nothing was processed (e.g. dataset smaller than one batch with
        # drop_last=True); avoid dividing by zero.
        return 0.0, 0.0
    return train_loss / num_batches, correct / seen

def validation(
    model: nn.Module,
    test_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    relp: bool = False,
    threshold: float = 0.65,
) -> Tuple:
    """
    Validate the neural network model on test data (thresholded variant).

    The softmax probability of class index 1 is treated as the positive
    score and is compared against ``threshold`` to obtain a 0/1 prediction,
    which is then compared with the target.

    NOTE(review): another function named ``validation`` is defined later in
    this file and rebinds the name, so this definition is not the one used
    at runtime unless that later definition is removed.

    Parameters:
    - model (nn.Module): The neural network model to validate.
    - test_loader (DataLoader): The DataLoader for test data.
    - criterion (nn.Module): The loss function.
    - device (torch.device): The device to run the validation on.
    - relp (bool): Flag to also return labels, predictions and scores.
    - threshold (float): Cut-off applied to the class-1 probability.

    Returns:
    - ``(test_loss, accuracy)`` when ``relp`` is False.
    - ``(test_loss, accuracy, labels, class-1 probabilities, softmax scores)``
      when ``relp`` is True; the last three are concatenated tensors.
    """
    model.eval()
    model = model.to(device)
    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0
    all_labels = []
    all_preds = []
    all_scores = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()

            score = torch.softmax(output, dim=1)
            prediction = score[:, 1]
            all_scores.append(score)
            all_preds.append(prediction)
            all_labels.append(target)

            prediction_class = (prediction >= threshold).int()
            correct += prediction_class.eq(target.int()).sum().item()
            seen += target.size(0)
            num_batches += 1

    # The criteria built in this file are nn.CrossEntropyLoss() with the
    # default mean reduction, so each term is already a per-batch mean:
    # average over batches, not over samples.
    test_loss = total_loss / num_batches if num_batches else 0.0
    # Divide by the samples actually evaluated; with drop_last=True the
    # loader sees fewer samples than len(test_loader.dataset).
    accuracy = correct / seen if seen else 0.0

    if relp:
        if not all_labels:
            # torch.cat rejects an empty list; hand back empty tensors.
            return test_loss, accuracy, torch.empty(0), torch.empty(0), torch.empty(0)
        return (
            test_loss,
            accuracy,
            torch.cat(all_labels, dim=0),
            torch.cat(all_preds, dim=0),
            torch.cat(all_scores, dim=0),
        )
    return test_loss, accuracy


def _default_validation_device():
    """Return the module-level ``device`` used when none is passed in."""
    return device


def validation(model, test_loader, criterion, device=None, relp=False):
    """Evaluate ``model`` on ``test_loader`` using argmax predictions.

    Parameters:
    - model: network to evaluate; it is switched to eval mode and moved to
      the evaluation device.
    - test_loader: iterable yielding ``(data, target)`` batches.
    - criterion: loss function called as ``criterion(output, target)``.
    - device: optional device to evaluate on. When omitted, the module-level
      ``device`` is used. Call sites in this file pass the device as the
      fourth positional argument together with ``relp=...``; accepting it
      here keeps those calls from raising ``TypeError``. A bare ``bool`` in
      this position is still interpreted as ``relp`` so that the older
      ``validation(model, loader, criterion, True)`` form keeps working.
    - relp: when true, also return labels, predictions and scores.

    Returns:
    - ``(test_loss, accuracy)`` when ``relp`` is false.
    - ``(test_loss, accuracy, labels, preds, scores)`` when ``relp`` is true:
      ``labels`` is a list of targets, ``preds`` a list of one-element lists
      holding the argmax class, and ``scores`` a NumPy array of the softmax
      outputs for every evaluated sample.
    """
    if isinstance(device, bool):
        # Legacy positional form: the fourth argument used to be ``relp``.
        relp, device = device, None
    if device is None:
        device = _default_validation_device()

    model.eval()
    model = model.to(device)
    test_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0
    labels = []
    preds = []
    scores = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += criterion(output, target).item()
            scores.append(torch.softmax(output, dim=1))

            prediction = output.argmax(dim=1, keepdim=True)
            # extend() instead of list + list: no re-copying of everything
            # accumulated so far on each batch.
            preds.extend(prediction.cpu().tolist())
            labels.extend(target.cpu().tolist())
            correct += prediction.eq(target.view_as(prediction)).sum().item()
            seen += target.size(0)
            num_batches += 1

    test_loss = test_loss / num_batches if num_batches else 0.0
    # Use the number of samples actually evaluated: a loader created with
    # drop_last=True covers fewer samples than len(test_loader.dataset).
    accuracy = correct / seen if seen else 0.0

    if relp:
        if scores:
            all_scores = torch.cat(scores, dim=0).cpu().numpy()
        else:
            # torch.cat rejects an empty list.
            all_scores = np.empty((0, 0))
        return (test_loss, accuracy, labels, preds, all_scores)
    return (test_loss, accuracy)


def create_model_optimizer_criterion_dict(number_of_samples):
    """Build one model, SGD optimiser and loss per local node.

    For every index ``i`` in ``range(number_of_samples)`` a fresh ``Net2nn``
    is created together with an SGD optimiser over its parameters (using the
    module-level ``learning_rate`` and ``momentum``) and a
    ``nn.CrossEntropyLoss``.

    Returns three dicts keyed ``"model<i>"``, ``"optimizer<i>"`` and
    ``"criterion<i>"`` respectively.
    """
    models, optimizers, criteria = {}, {}, {}

    for idx in range(number_of_samples):
        net = Net2nn()
        models[f"model{idx}"] = net
        optimizers[f"optimizer{idx}"] = torch.optim.SGD(
            net.parameters(), lr=learning_rate, momentum=momentum
        )
        criteria[f"criterion{idx}"] = nn.CrossEntropyLoss()

    return models, optimizers, criteria


def get_averaged_weights(main_model:nn.Module,model_dict:dict,number_of_samples:int):
    """Load into ``main_model`` the element-wise mean of the node models.

    For every entry of ``main_model.state_dict()`` the corresponding tensors
    of ``model_dict['model0'] .. model_dict['model<n-1>']`` are summed on the
    CPU, divided by ``number_of_samples`` and the result is loaded back into
    ``main_model``.

    Parameters:
    - main_model: model that receives the averaged state.
    - model_dict: mapping holding the node models under ``'model<i>'`` keys.
    - number_of_samples: how many node models to average; must be positive.

    Raises:
    - ValueError: if ``number_of_samples`` is not positive (dividing by zero
      would otherwise load NaN weights into ``main_model``).
    """
    if number_of_samples <= 0:
        raise ValueError(
            f"number_of_samples must be positive, got {number_of_samples}"
        )

    # Build each node's state dict once instead of once per parameter name.
    node_states = [
        model_dict['model' + str(i)].state_dict() for i in range(number_of_samples)
    ]

    state_dict = {}
    for name, param in tqdm(main_model.state_dict().items()):
        total = torch.zeros(param.shape)
        for node_state in node_states:
            total = total + node_state[name].cpu()
        # NOTE(review): entries are averaged in floating point regardless of
        # their original dtype - confirm this is acceptable for any integer
        # buffers the model may contain.
        state_dict[name] = total / number_of_samples
    main_model.load_state_dict(state_dict)

def send_main_model_to_nodes_and_update_model_dict(main_model:torch.nn.Module, model_dict:Dict[str,nn.Module], number_of_samples):
    """Copy ``main_model``'s state into every node model.

    Each of ``model_dict["model0"] .. model_dict["model<n-1>"]`` loads the
    current ``main_model.state_dict()``; a progress bar tracks the nodes.
    """
    for node_idx in tqdm(range(number_of_samples)):
        node_model = model_dict[f"model{node_idx}"]
        node_model.load_state_dict(main_model.state_dict())

def start_train_end_node_process_without_print(number_of_samples):
    """Train every node model for ``numEpoch`` epochs without printing.

    For each index ``i`` in ``range(number_of_samples)`` the model, loss and
    optimiser are looked up in the module-level ``model_dict``,
    ``criterion_dict`` and ``optimizer_dict``; the model is then trained on
    ``train_dl`` and evaluated on ``test_dl`` once per epoch. The metrics are
    discarded.
    """
    for i in range(number_of_samples):
        model = model_dict['model' + str(i)]
        criterion = criterion_dict['criterion' + str(i)]
        optimizer = optimizer_dict['optimizer' + str(i)]

        for _ in range(numEpoch):
            # train() takes (model, loader, criterion, device, optimizer);
            # the device argument was missing here, which raised TypeError.
            train(model, train_dl, criterion, device, optimizer)
            validation(model, test_dl, criterion)

def start_train_end_node_process_print_some(number_of_samples, print_amount, device):
    """Train every node model, printing progress for the first few nodes.

    Parameters:
    - number_of_samples: number of node models to train; each is looked up
      in the module-level ``model_dict`` / ``criterion_dict`` /
      ``optimizer_dict`` under ``'model<i>'`` etc.
    - print_amount: nodes with index below this value print a header and
      their per-epoch train/test accuracy.
    - device: device passed to ``train`` for the model and batches.

    Each node is trained on ``train_dl`` and evaluated on ``test_dl`` for
    ``numEpoch`` epochs.
    """
    for i in range(number_of_samples):
        model = model_dict['model' + str(i)]
        criterion = criterion_dict['criterion' + str(i)]
        optimizer = optimizer_dict['optimizer' + str(i)]
        verbose = i < print_amount

        if verbose:
            print("Subset" ,i)

        for epoch in range(numEpoch):
            train_loss, train_accuracy = train(model, train_dl, criterion, device, optimizer)
            # The `validation` in effect is the last one defined in this
            # file; passing `device` positionally together with `relp=`
            # raised "multiple values for argument 'relp'". That definition
            # selects the module-level device itself, so it is omitted here.
            test_loss, test_accuracy = validation(model, test_dl, criterion, relp=False)

            if verbose:
                print("epoch: {:3.0f}".format(epoch+1) + " | train accuracy: {:7.5f}".format(train_accuracy) + " | test accuracy: {:7.5f}".format(test_accuracy))

class Data(Dataset):
    """Image dataset reading ``*.jpg`` files from one folder per class.

    The folders ``benign``, ``pre_b``, ``pro_b`` and ``early_pre_b`` under
    ``root`` map to labels 0, 1, 2 and 3. Within each class the (sorted)
    file list is split 80/20: ``train=True`` keeps the first 80 %,
    ``train=False`` the remaining files.

    Parameters:
    - train: select the training part (with random flip/erasing
      augmentation) or the held-out part (resize only).
    - root: directory containing the four class folders.
    """

    DEFAULT_ROOT = '/content/bloodata/bloodata/BCC'
    # Index in this tuple is the integer label.
    CLASS_NAMES = ('benign', 'pre_b', 'pro_b', 'early_pre_b')
    TRAIN_FRACTION = 0.8

    def __init__(self, train=False, root=DEFAULT_ROOT) -> None:
        super().__init__()

        per_class = []
        for class_name in self.CLASS_NAMES:
            # glob returns files in arbitrary order; sort so the train and
            # held-out parts are reproducible and never overlap between runs.
            found = sorted(glob.glob(os.path.join(root, class_name, '*.jpg')))
            cut = int(len(found) * self.TRAIN_FRACTION)
            per_class.append(found[:cut] if train else found[cut:])

        # Kept as separate attributes for code that inspects them directly.
        self.files1, self.files2, self.files3, self.files4 = per_class
        self.files = [path for group in per_class for path in group]

        if train:
            self.transform = Compose([ToTensor(),Resize((50,50)),RandomVerticalFlip(),RandomErasing()])
        else:
            self.transform = Compose([ToTensor(),Resize((50,50))])

    @classmethod
    def _label_from_path(cls, path):
        """Return the integer label encoded by ``path``'s parent folder.

        Raises ValueError when the folder is not one of ``CLASS_NAMES``.
        """
        # Use the parent directory name rather than a fixed index into
        # path.split('/'): for the absolute paths used here index 3 is
        # 'bloodata', so no class ever matched and the label stayed a list.
        class_name = os.path.basename(os.path.dirname(path))
        try:
            return cls.CLASS_NAMES.index(class_name)
        except ValueError:
            raise ValueError(
                f"cannot derive a label from {path!r}: unknown class folder {class_name!r}"
            ) from None

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(transformed image, integer label)`` for ``index``."""
        path = self.files[index]
        label = self._label_from_path(path)
        # Close the file once decoded, and force 3 channels so grayscale or
        # CMYK JPEGs still match the model's 3-channel input.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB')), label

def create_data_load(batch,drop_last,shuffle):
    """Create the training and held-out ``DataLoader`` pair.

    Both loaders share the same ``batch_size``, ``shuffle`` and
    ``drop_last`` settings; the first wraps ``Data(train=True)`` and the
    second ``Data(train=False)``.

    Returns ``(train_loader, test_loader)``.
    """
    loader_options = dict(batch_size=batch, shuffle=shuffle, drop_last=drop_last)

    train_dataload = DataLoader(Data(train=True), **loader_options)
    test_dataload = DataLoader(Data(train=False), **loader_options)

    return train_dataload, test_dataload

# Data: one shuffled loader per split; the held-out loader is used both as
# the validation and as the test loader.
train_dl, test_dl = create_data_load(batch=batch_size, drop_last=True, shuffle=True)
valid_dl = test_dl

# Global ("main") model with its own loss and optimiser.
main_model = Net2nn()
main_criterion = nn.CrossEntropyLoss()
main_optimizer = torch.optim.SGD(
    main_model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

# Per-node models / optimisers / losses, all initialised from the main model.
model_dict, optimizer_dict, criterion_dict = create_model_optimizer_criterion_dict(
    number_of_samples
)
send_main_model_to_nodes_and_update_model_dict(
    main_model, model_dict, number_of_samples
)

from sklearn.metrics import roc_curve, auc, precision_score, recall_score, f1_score, accuracy_score
import numpy as np
import matplotlib.pyplot as plt

def compute_metrics(scores, labels, threshold=0.5):
    """Plot ROC curve(s) and return precision, recall, F1, accuracy and AUC.

    Two input layouts are supported:

    - Binary: ``scores`` is one score per sample (1-D, or a single column).
      One ROC curve is drawn and a sample is predicted positive when its
      score is ``>= threshold``.
    - Multi-class: ``scores`` is 2-D with one column per class. A
      one-vs-rest ROC curve is drawn for every class that has both positive
      and negative samples in ``labels``, the returned AUC is the mean of
      those per-class AUCs, and the predicted class is the arg-max column
      (``threshold`` is not used).

    Parameters:
    - scores: array-like of scores as described above.
    - labels: array-like of ground-truth labels, one per sample.
    - threshold: decision threshold for the binary layout.

    Returns:
    - ``(precision, recall, f1, accuracy, roc_auc)`` with macro-averaged
      precision / recall / F1. ``roc_auc`` is NaN in the multi-class layout
      when no class admits a ROC curve.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).ravel()

    plt.figure()
    if scores.ndim == 2 and scores.shape[1] > 1:
        predictions = scores.argmax(axis=1)
        class_aucs = []
        for cls in range(scores.shape[1]):
            is_cls = labels == cls
            if is_cls.all() or not is_cls.any():
                # ROC is undefined without both positives and negatives.
                continue
            fpr, tpr, _ = roc_curve(is_cls, scores[:, cls])
            cls_auc = auc(fpr, tpr)
            class_aucs.append(cls_auc)
            plt.plot(fpr, tpr, lw=2, label='class %d ROC curve (area = %0.2f)' % (cls, cls_auc))
        roc_auc = float(np.mean(class_aucs)) if class_aucs else float('nan')
    else:
        scores = scores.ravel()
        fpr, tpr, _ = roc_curve(labels, scores)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
        predictions = (scores >= threshold).astype(int)

    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

    p = precision_score(labels, predictions, average='macro')
    r = recall_score(labels, predictions, average='macro')
    f1 = f1_score(labels, predictions, average='macro')
    acc = accuracy_score(labels, predictions)

    return p, r, f1, acc, roc_auc


# Initial round: train every node once from the shared start, then average.
start_train_end_node_process_print_some(number_of_samples, print_amount, device)

get_averaged_weights(main_model,model_dict,number_of_samples)

for epoch_ in range(total_epochs):
    # One round: broadcast the main model, train the nodes, average back.
    send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)
    # The `device` argument is required by this function; it was missing.
    start_train_end_node_process_print_some(number_of_samples, print_amount, device)
    get_averaged_weights(main_model,model_dict,number_of_samples)

    # `validation` is called without a positional device: the definition in
    # effect moves the model to the module-level device itself, and passing
    # one positionally alongside `relp=` raised TypeError.
    test_loss, correct, label, preds, scores = validation(main_model, test_dl, main_criterion, relp=True)
    # Score the full per-class softmax matrix against the labels. The
    # model has four outputs, so hard predictions or a single thresholded
    # column cannot be evaluated as a binary ROC.
    p, r, f1, acc, roc_auc = compute_metrics(np.asarray(scores), np.asarray(label))
    print("test_loss:",test_loss)
    print(f"P: {p} | R: {r} | F1: {f1} | Acc: {acc} | AUC: {roc_auc}")