In [None]:
import simpy
import random
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import os
from torchvision.datasets.folder import default_loader
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.optim import AdamW
import torchvision.transforms as transforms
import timm
import psutil
import time
import statistics
from glob import glob
from sklearn.metrics import recall_score, precision_score, f1_score
import torchvision.transforms as transforms

# Simulation / training configuration
NUM_NODES: int = 5  # number of simulated peer nodes
EPOCHS: int = 10  # epochs each node trains for
BATCH_SIZE: int = 32  # mini-batch size used by every DataLoader
# [share of the training split given to the first node (node 0),
#  share divided evenly among the remaining nodes]; main() reads only the first entry.
DATA_PATTERN: list = [0.3, 0.7]

class CustomDataset(Dataset):
    """Image dataset laid out as ``root/<patient_id>/<class_label>/<image>``.

    Each folder name in ``class_labels`` is converted with ``int()`` to give the
    label of every file inside it. Directory listings are sorted, so the sample
    order (and therefore any seeded split built on top of it) is reproducible
    across runs and file systems.

    Args:
        root (str): Dataset root directory.
        transform (callable, optional): Applied to each loaded image.
        class_labels (tuple[str, ...]): Class sub-folder names to scan; must be
            integer strings. Defaults to ``('0', '1')``.
    """

    def __init__(self, root, transform=None, class_labels=('0', '1')):
        self.root = root
        self.transform = transform
        self.samples = []  # (image_path, int_label) pairs
        self.labels = set()  # distinct labels for which at least one file was found

        for patient_id in sorted(os.listdir(root)):
            patient_path = os.path.join(root, patient_id)
            for class_label in class_labels:
                class_path = os.path.join(patient_path, class_label)
                # Skips plain files at the top level and patients lacking this class.
                if not os.path.isdir(class_path):
                    continue
                label = int(class_label)
                for img_name in sorted(os.listdir(class_path)):
                    self.samples.append((os.path.join(class_path, img_name), label))
                    self.labels.add(label)

    def __len__(self):
        """Return the number of (image, label) samples found."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load image ``index`` with ``default_loader``, apply ``transform``, return ``(image, label)``."""
        path, label = self.samples[index]
        sample = default_loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

# Define model
class SwinTransformerModel(nn.Module):
    """Frozen pretrained Swin-Tiny backbone (timm) with a trainable MLP classification head.

    Args:
        num_classes (int): Number of logits produced by the head.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.swin_transformer = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)

        # Freeze the whole pre-trained backbone; only the new head below is trained.
        for param in self.swin_transformer.parameters():
            param.requires_grad = False

        # Width of the backbone's final feature vector = input size of the new head.
        num_features = self.swin_transformer.num_features
        self.swin_transformer.head = nn.Sequential(
            nn.Dropout(0.5),  # regularise the small head against overfitting
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

        # Freshly created modules are trainable by default; stated explicitly for clarity.
        for param in self.swin_transformer.head.parameters():
            param.requires_grad = True

        # NOTE(review): only the head is trainable. To also fine-tune the last Swin
        # stage, unfreeze parameters whose names start with 'layers.3.' (timm naming —
        # confirm against named_parameters()). ResNet-style names such as
        # 'layer4.2.conv3.weight' do not exist in this model.

        # Pools a channels-first feature map down to one vector per image.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.swin_transformer.forward_features(x)
        if features.dim() == 4:
            # Channels-last map, e.g. [B, 7, 7, 768] -> [B, 768, 7, 7] -> [B, 768, 1, 1].
            features = self.global_avg_pool(features.permute(0, 3, 1, 2))
        elif features.dim() == 3:
            # Token sequence [B, L, C] (older timm releases may return this): mean over tokens.
            features = features.mean(dim=1)
        # An already pooled [B, C] result passes through unchanged.
        features = torch.flatten(features, 1)
        return self.swin_transformer.head(features)

def check_dataset_labels(dataset):
    """Print the distinct labels in ``dataset`` and return ``max(label) + 1``.

    If the dataset exposes a ``samples`` list of ``(path, label)`` pairs, labels
    are read from it directly. Otherwise the dataset is iterated, which loads and
    transforms every item.

    Args:
        dataset: An iterable of ``(sample, label)`` pairs, optionally with a
            ``samples`` attribute holding ``(path, label)`` pairs.

    Returns:
        int: ``max(label) + 1`` (the number of logits needed to index every
        label), or 0 when the dataset is empty.
    """
    samples = getattr(dataset, 'samples', None)
    source = samples if samples is not None else dataset
    unique_labels = {label for _, label in source}
    print("Unique labels in the dataset:", unique_labels)
    if not unique_labels:
        return 0
    return max(unique_labels) + 1


# Load dataset
def load_data(root='D:/USYD S3/archive/IDC_regular_ps50_idx5'):
    """Build the image dataset and a shuffled DataLoader over it.

    Args:
        root (str): Dataset root laid out as ``<root>/<patient_id>/<0|1>/<image>``.
            Defaults to the original hard-coded location.

    Returns:
        tuple: ``(data_loader, num_classes)`` where ``num_classes`` is
        ``max(label) + 1``.

    Raises:
        ValueError: If no images are found under ``root``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
        transforms.ToTensor(),
        # The Swin backbone is ImageNet-pretrained, so normalise the RGB channels
        # with ImageNet statistics (single-channel MNIST values were being
        # broadcast across all three channels before).
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # NOTE(review): these random augmentations also apply to the validation split
    # that main() carves out of this dataset; a non-augmented validation transform
    # would give cleaner validation metrics.

    dataset = CustomDataset(root=root, transform=transform)
    if not dataset.labels:
        raise ValueError(f"No images found under {root!r}")
    # Labels index the logits directly, so the head needs max(label) + 1 outputs;
    # len(labels) would be too small if e.g. only label 1 were present.
    num_classes = max(dataset.labels) + 1
    print("Number of samples in the dataset:", len(dataset))
    print("Detected number of classes:", num_classes)

    # DataLoader handles batching and shuffling.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=12,
    )
    return data_loader, num_classes



def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is left in eval mode; callers switch back to training mode themselves.

    Args:
        model (torch.nn.Module): Classifier producing one logit per class.
        val_loader: Iterable of ``(data, target)`` tensor batches.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(accuracy, recall, f1, precision)``. Accuracy is in percent; the
        other three are macro-averaged sklearn scores. All four are 0.0 when the
        loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(target.cpu().tolist())

    if total == 0:
        # Nothing to score; avoid ZeroDivisionError and sklearn errors on empty input.
        return 0.0, 0.0, 0.0, 0.0

    accuracy = 100 * correct / total
    # zero_division=0: a class that is never predicted/present scores 0 silently.
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    return accuracy, recall, f1, precision

def calculate_dynamic_threshold(epoch, max_epochs, base_threshold, max_threshold=100):
    """Interpolate linearly from ``base_threshold`` towards ``max_threshold``.

    The interpolation fraction is ``epoch / max_epochs``. The threshold therefore
    equals ``base_threshold`` at epoch 0 and reaches ``max_threshold`` at
    ``epoch == max_epochs``.

    Args:
        epoch (int): Current epoch number.
        max_epochs (int): Total number of planned epochs.
        base_threshold (float): Starting threshold (e.g. median first-epoch usage).
        max_threshold (float): Threshold approached at the end of training.

    Returns:
        float: The threshold for ``epoch``.
    """
    fraction = epoch / max_epochs
    span = max_threshold - base_threshold
    return base_threshold + span * fraction

def should_node_continue(node_id, epoch, model, val_loader, threshold):
    """Decide whether a node keeps training, based on CUDA memory usage.

    "GPU usage" here is the currently allocated CUDA memory as a percentage of the
    peak allocation so far. It is 0 when CUDA is unavailable or nothing has been
    allocated yet (the latter previously raised ZeroDivisionError).

    Args:
        node_id (int): Node ID, used only in log messages.
        epoch (int): Current epoch number, used only in log messages.
        model (torch.nn.Module): Unused; kept for interface compatibility.
        val_loader (torch.utils.data.DataLoader): Unused; kept for interface compatibility.
        threshold (float): Usage percentage above which the node stops.

    Returns:
        bool: ``False`` if usage exceeds ``threshold``, otherwise ``True``.
    """
    gpu_usage = 0
    if torch.cuda.is_available():
        peak = torch.cuda.max_memory_allocated()
        if peak > 0:
            gpu_usage = torch.cuda.memory_allocated() / peak * 100

    print(f'Node {node_id}, Epoch {epoch}, GPU Usage: {gpu_usage}%')

    if gpu_usage > threshold:
        print(f'Node {node_id} exiting due to high GPU usage: GPU {gpu_usage}%')
        return False

    return True

def _plot_node_metrics(node_id, panels):
    """Show one figure with a training-vs-validation subplot per metric for one node.

    Args:
        node_id (int): Node index used in the subplot titles.
        panels (list): ``(metric_name, train_values, val_values, train_color, val_color)``
            tuples, one per subplot row, drawn top to bottom.
    """
    plt.figure(figsize=(6, 18))
    for row, (name, train_values, val_values, train_color, val_color) in enumerate(panels, start=1):
        plt.subplot(len(panels), 1, row)
        plt.plot(range(len(train_values)), train_values, train_color, label=f'Training {name}')
        plt.plot(range(len(val_values)), val_values, val_color, label=f'Validation {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        plt.title(f'Training and Validation {name} at Node {node_id}')
        plt.legend()
    plt.tight_layout()
    plt.show()


def node_process(env, node_id, net, train_loader, val_loader, global_weights, num_classes, status, all_done):
    """SimPy process: train one node, averaging parameters with peers after every batch.

    Args:
        env: SimPy environment driving the simulation.
        node_id (int): Index of this node into ``status`` / ``all_done``.
        net (P2PNetwork): Network used to broadcast and gather parameters.
        train_loader: This node's training DataLoader.
        val_loader: Shared validation DataLoader.
        global_weights (float): This node's data-size weight. It is passed through
            ``gather_params`` but not used in the averaging below.
        num_classes (int): Number of model outputs.
        status (list[bool]): ``status[node_id]`` is set False when the node drops out
            on the GPU-usage check, and True when it passes.
        all_done (list[bool]): ``all_done[node_id]`` is set True once the node
            finishes without dropping out.

    Yields:
        SimPy events: the broadcast/gather sub-processes and a final 1-unit timeout.

    Relies on the module-level ``device`` assigned in ``main``.
    """
    model = SwinTransformerModel(num_classes=num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)  # LR *= 0.7 after every epoch
    criterion = nn.CrossEntropyLoss()

    gpu_usages = []
    base_threshold = None

    accuracy_list, recall_list, f1_list, precision_list = [], [], [], []
    val_accuracy_list, val_recall_list, val_f1_list, val_precision_list = [], [], [], []

    for epoch in range(EPOCHS):
        # Sample memory usage (allocated / peak, in %) at the start of the first epoch.
        if epoch == 0 and torch.cuda.is_available():
            peak = torch.cuda.max_memory_allocated()
            if peak > 0:
                gpu_usages.append(torch.cuda.memory_allocated() / peak * 100)

        # Base threshold = median of the first-epoch samples. Without CUDA there are
        # no samples and statistics.median([]) raises StatisticsError, so use 0.
        if epoch == 1:
            base_threshold = statistics.median(gpu_usages) if gpu_usages else 0.0

        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        true_positive = 0
        false_negative = 0
        false_positive = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}, Node {node_id}')
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

            # Running training statistics; label 1 is treated as the positive class.
            predicted = output.argmax(dim=1)
            true_positive += ((predicted == 1) & (target == 1)).sum().item()
            false_negative += ((predicted != 1) & (target == 1)).sum().item()
            false_positive += ((predicted == 1) & (target != 1)).sum().item()
            correct += (predicted == target).sum().item()
            total += target.size(0)
            pbar.set_postfix(loss=loss.item(), accuracy=f'{100 * correct / total:.2f}%')

            # Share this node's parameters, collect peers', and load their average.
            params = [param.data for param in model.parameters()]
            yield env.process(broadcast_params(env, net, params))
            params_list, _ = yield env.process(gather_params(env, net, global_weights))
            avg_params = fed_avg(params_list)
            with torch.no_grad():
                for param, avg_param in zip(model.parameters(), avg_params):
                    param.copy_(avg_param)

        # Epoch-level training metrics; guards cover an empty train_loader.
        training_accuracy = 100 * correct / total if total > 0 else 0.0
        predicted_pos = true_positive + false_positive
        actual_pos = true_positive + false_negative
        training_precision = true_positive / predicted_pos if predicted_pos > 0 else 0
        training_recall = true_positive / actual_pos if actual_pos > 0 else 0
        pr_sum = training_precision + training_recall
        training_f1_score = 2 * training_precision * training_recall / pr_sum if pr_sum > 0 else 0
        epoch_loss = running_loss / num_batches if num_batches > 0 else 0.0

        accuracy_list.append(training_accuracy)
        recall_list.append(training_recall)
        f1_list.append(training_f1_score)
        precision_list.append(training_precision)

        val_accuracy, val_recall, val_f1, val_precision = validate(model, val_loader, device)
        val_accuracy_list.append(val_accuracy)
        val_recall_list.append(val_recall)
        val_f1_list.append(val_f1)
        val_precision_list.append(val_precision)

        scheduler.step()
        # Loss reported is the mean batch loss over the epoch.
        print(f'Node {node_id}, Epoch {epoch}, Loss: {epoch_loss}, Training Accuracy: {training_accuracy:.2f}%, Training Recall: {training_recall:.2f}, Training F1 Score: {training_f1_score:.2f}, Training Precision Score: {training_precision:.2f}, Val Accuracy: {val_accuracy:.2f}%, Val Recall: {val_recall:.2f}, Val F1 Score: {val_f1:.2f}, Val Precision Score: {val_precision:.2f}')

        # From the second epoch on, drop out if usage exceeds the ramping threshold.
        if epoch > 0:
            current_threshold = calculate_dynamic_threshold(epoch, EPOCHS, base_threshold)
            if not should_node_continue(node_id, epoch, model, val_loader, current_threshold):
                status[node_id] = False
                break
            status[node_id] = True

    _plot_node_metrics(node_id, [
        ('Accuracy', accuracy_list, val_accuracy_list, 'navy', 'skyblue'),
        ('Recall', recall_list, val_recall_list, 'darkred', 'salmon'),
        ('F1 Score', f1_list, val_f1_list, 'darkgreen', 'lightgreen'),
        ('Precision Score', precision_list, val_precision_list, 'purple', 'lavender'),
    ])

    # Only a node that did not drop out counts as done.
    if status[node_id]:
        all_done[node_id] = True
        print(f'Node {node_id} marked as done.')
    yield env.timeout(1)  # simulated delay before the process ends

def broadcast_params(env, net, params):
    """SimPy sub-process: send ``params`` to every node through ``net.broadcast``.

    The broadcast runs in its own child process, and this generator waits for
    that child to finish.
    """
    def _send(network, payload):
        # Wait on the event returned by the broadcast.
        yield network.broadcast(payload)

    child = env.process(_send(net, params))
    yield child

def gather_params(env, net, weights):
    """SimPy sub-process: collect the parameter payloads delivered by ``net.gather``.

    Returns:
        tuple: ``(params_list, weights)``. ``params_list`` holds the non-None
        values of the gathered events (events created without a value, such as
        the delay timeouts, are dropped). ``weights`` is returned unchanged.
    """
    def _collect(network):
        finished = yield network.gather()
        payloads = []
        for event in finished:
            if event.value is not None:
                payloads.append(event.value)
        return payloads

    params_list = yield env.process(_collect(net))
    return params_list, weights

# fedAvg to merge parameters
def fed_avg(params_list, weights=None):
    """Average corresponding parameters across nodes (FedAvg).

    Args:
        params_list: One sequence of parameters (tensors or numbers) per node,
            aligned position by position. Positions beyond the shortest
            sequence are ignored.
        weights: Optional per-node weights, e.g. local sample counts. ``None``
            gives the plain unweighted mean.

    Returns:
        list: One averaged parameter per position; empty if ``params_list`` is empty.

    Raises:
        ValueError: If ``weights`` does not have one entry per node or sums to zero.
    """
    if not params_list:
        return []
    if weights is None:
        num_nodes = len(params_list)
        return [sum(group) / num_nodes for group in zip(*params_list)]

    if len(weights) != len(params_list):
        raise ValueError(
            f"expected {len(params_list)} weights, got {len(weights)}"
        )
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("weights must not sum to zero")
    return [
        sum(w * p for w, p in zip(weights, group)) / total_weight
        for group in zip(*params_list)
    ]

# compute network objects
class P2PNetwork:
    """Simulated peer-to-peer network: one SimPy Store per node plus a fixed delay per node.

    Attributes:
        env: SimPy environment that owns the stores and timeouts.
        num_nodes (int): Number of nodes (and therefore stores).
        pipes (list): One ``simpy.Store`` per node.
        delays (list[int]): Per-node delay, drawn once from ``random.randint(1, 10)``.
    """

    def __init__(self, env, num_nodes):
        self.env = env
        self.num_nodes = num_nodes
        self.pipes = []
        for _ in range(num_nodes):
            self.pipes.append(simpy.Store(env))
        self.delays = [random.randint(1, 10) for _ in range(num_nodes)]

    def broadcast(self, value):
        """Put ``value`` into every store; return an event firing once all puts succeed."""
        put_events = [pipe.put(value) for pipe in self.pipes]
        return self.env.all_of(put_events)

    def gather(self):
        """Take one item from every store and wait every node's delay.

        Returns an ``all_of`` event over the get events followed by the delay timeouts.
        """
        gets = [pipe.get() for pipe in self.pipes]
        waits = [self.env.timeout(delay) for delay in self.delays]
        return self.env.all_of(gets + waits)

# main function
def _partition_indices(num_items, num_nodes, first_share):
    """Split ``range(num_items)`` into exactly ``num_nodes`` contiguous index lists.

    The first node gets ``floor(first_share * num_items)`` indices. The remainder
    is divided as evenly as possible among the other nodes (sizes differ by at
    most one), so no index is dropped and no extra shard is created.
    """
    indices = list(range(num_items))
    if num_nodes <= 1:
        return [indices]
    split = int(np.floor(first_share * num_items))
    shards = [indices[:split]]
    remaining = np.asarray(indices[split:], dtype=np.int64)
    shards.extend(chunk.tolist() for chunk in np.array_split(remaining, num_nodes - 1))
    return shards


def main():
    """Load the data, shard it across NUM_NODES simulated nodes, and run the SimPy loop.

    Nodes that drop out on the GPU-usage check are restarted until every node has
    completed all epochs. This function also sets the module-level ``device``
    used by ``node_process``.
    """
    global device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    env = simpy.Environment()
    net = P2PNetwork(env, NUM_NODES)

    # all_done[i] becomes True once node i finishes every epoch without dropping out.
    all_done = [False] * NUM_NODES

    train_loader, num_classes = load_data()
    full_dataset = train_loader.dataset

    # Random 80/20 train/validation split.
    num_samples = len(full_dataset)
    num_train = int(0.8 * num_samples)
    num_val = num_samples - num_train
    train_subset, val_subset = random_split(full_dataset, [num_train, num_val])

    # Validation metrics do not depend on sample order, so no shuffling.
    val_subset_loader = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

    # Exactly one shard per node: DATA_PATTERN[0] of the training split for node 0,
    # the rest divided evenly among the other nodes.
    node_indices = _partition_indices(num_train, NUM_NODES, DATA_PATTERN[0])
    node_subsets = [torch.utils.data.Subset(train_subset, idx) for idx in node_indices]
    node_loaders = [torch.utils.data.DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True) for subset in node_subsets]

    # Each node's fraction of the training data.
    node_weights = [len(subset) / num_train if num_train else 0.0 for subset in node_subsets]

    # status[i] is cleared by node i when it drops out on the GPU-usage check.
    status = [True] * NUM_NODES

    def start_node(i):
        """Start (or restart) the SimPy process for node ``i``."""
        return env.process(node_process(env, i, net, node_loaders[i], val_subset_loader, node_weights[i], num_classes, status, all_done))

    processes = [start_node(i) for i in range(NUM_NODES)]

    while not all(all_done):
        env.run(until=env.timeout(1))
        # Restart nodes that dropped out of the current round.
        for i in range(NUM_NODES):
            if not status[i] and not all_done[i]:
                print(f'Restarting node {i} for the next round')
                status[i] = True
                processes.append(start_node(i))

    print("All nodes have completed training.")

if __name__ == "__main__":
    # Run the simulation only when executed as a script, not when imported.
    main()
