In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [10]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, random_split
import torchvision.transforms as transforms
import torch
import random
import numpy as np
from collections import Counter
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models



# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
#### With class 'Other'


from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
from tqdm import tqdm
from collections import Counter

# Updated class mapping to combine classes into unique classes (including 'Augmented') 
# put PVC and Other to one class
class_mapping = {
    #BigBag2
    'BigBag2_1_PET': 0,  # PET
    'BigBag2_2_PP': 1,   # PP
    'BigBag2_3_PE': 2,   # PE
    'BigBag2_4_Tetra': 3, # Tetra
    'BigBag2_5_PVC': 5, # PVC
    'BigBag2_6_PS': 4,   # PS
    'BigBag2_7_Other': 6, # Other
    'BigBag2_4_Tetra_Augmented': 3,  # Augmented Tetra
    'BigBag2_6_PS_Augmented': 4,  # Augmented PS
    
    #BigBag4
    'BigBag4_1_PET': 0,  # PET
    'BigBag4_2_PP': 1,   # PP
    'BigBag4_3_PE': 2,   # PE
    'BigBag4_4_Tetra': 3, # Tetra
    'BigBag4_6_PS': 4,   # PS
    'BigBag4_5_PVC': 5, # PVC
    'BigBag4_7_Other': 6, # Other
    
    #BigBag1
    'BigBag1_1_PET': 0,  # PET
    'BigBag1_2_PP': 1,   # PP
    'BigBag1_3_PE': 2,   # PE
    'BigBag1_4_Tetra': 3, # Tetra
    #'BigBag2_4_Tetra_Augmented': 3,  # Augmented Tetra
    #'BigBag2_5_PVC': 5, # PVC
    'BigBag1_6_PS': 4,   # PS
    'BigBag1_7_Other': 6, # Other
    #'BigBag2_6_PS_Augmented': 4,  # Augmented PS
    
    #BigBag3
    'BigBag3_PET': 0,  # PET
    'BigBag3_2_PP': 1,   # PP
    'BigBag3_PE': 2,   # PE
    'BigBag3_TETRA': 3, # Tetra
    #'BigBag3_PVC': 5, # PVC
    'BigBag3_6_PS': 4,   # PS
    'BigBag3_Other': 6, # Other
    
    'DWRL7_extension_2_PVC': 5, # 
    'BigBag2_5_PVC_Augmented': 5,  # Augmented PVC
}

class CustomPlasticDataset(Dataset):
    def __init__(self, root_dir, class_mapping, transform, tetra_transform=None, ps_transform=None, pvc_transform=None, diverse_transform=None):
        self.root_dir = root_dir
        self.class_mapping = class_mapping
        self.transform = transform
        self.tetra_transform = tetra_transform
        self.ps_transform = ps_transform
        self.pvc_transform = pvc_transform
        self.diverse_transform = diverse_transform
        self.image_paths = []
        self.labels = []
        
        # Gather image paths and labels
        for class_folder in class_mapping.keys():
            # Load original images
            image_dir = os.path.join(root_dir, class_folder, 'images_cutout')
            if os.path.exists(image_dir):
                image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
                print(f"Loaded {len(image_files)} original images for {class_folder}")
                self.image_paths.extend([os.path.join(image_dir, img) for img in image_files])
                self.labels.extend([class_mapping[class_folder]] * len(image_files))

            # Load augmented images if they exist
            augmented_dir = os.path.join(root_dir, class_folder)  # Path to the augmented class
            if os.path.exists(augmented_dir) and 'Augmented' in class_folder:
                augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
                print(f"Loaded {len(augmented_files)} augmented images for {class_folder}")
                self.image_paths.extend([os.path.join(augmented_dir, img) for img in augmented_files])
                self.labels.extend([class_mapping[class_folder]] * len(augmented_files))  # Map to the same class label

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path)
        label = self.labels[idx]
        
        # Apply specific transformations based on class
        if label == 3:  # Tetra
            image = self.tetra_transform(image)
        elif label == 4:  # PS
            image = self.ps_transform(image)
        elif label == 5:  # PVC
            image = self.pvc_transform(image)
            
        else:
            # Use diverse_transform with a 50% chance for other classes
            if self.diverse_transform is not None and random.random() > 0.5:
                image = self.diverse_transform(image)

            else:
                image = self.transform(image)
        
        return image, label


# Augmentation for TETRA (Moderate)
tetra_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Augmentation for PS (Light)
ps_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Regular transform for other classes
regular_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
])
    
# Augmentation for PVC (Light)
pvc_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define more diverse transformations for augmentation
diverse_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
    transforms.RandomRotation(degrees=30),  # Random rotation by up to 30 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color adjustments
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # Random crop with resizing
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def augment_and_save_images(original_dataset, class_label, transform, target_dir, num_augmented_images):
    os.makedirs(target_dir, exist_ok=True)  # Create the target directory if it doesn't exist
    class_images = [original_dataset.image_paths[idx] for idx, label in enumerate(original_dataset.labels) if label == class_label]

    for img_path in tqdm(class_images, desc=f'Augmenting Class {class_label}'):
        img = Image.open(img_path)
        for i in range(num_augmented_images):
            augmented_img = transform(img)  # Apply the transformation
            augmented_img = transforms.ToPILImage()(augmented_img)  # Convert back to PIL Image
            
            # Create a unique filename using the original filename and the index
            base_filename = os.path.basename(img_path).split('.')[0]  # Get the original filename without extension
            augmented_img.save(os.path.join(target_dir, f'augmented_{base_filename}_{i}.png'))  # Save with a unique name

# Function to count images in each class
def count_images_in_classes(base_dir, class_mapping):
    class_counts = {class_name: 0 for class_name in class_mapping.keys()}

    for class_name in class_mapping.keys():
        # Check for original images
        image_dir = os.path.join(base_dir, class_name, 'images_cutout')  # Original images path
        if os.path.exists(image_dir):
            image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
            class_counts[class_name] += len(image_files)

        # Check for augmented images in their respective folders
        if 'Augmented' in class_name:
            augmented_dir = os.path.join(base_dir, class_name)  # Path to the augmented class
            if os.path.exists(augmented_dir):
                augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
                class_counts[class_name] += len(augmented_files)  # Count augmented images

    return class_counts

# Directory where all class folders are stored
data_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data'

# Create dataset instance
plastic_dataset = CustomPlasticDataset(
    root_dir=data_dir, 
    class_mapping=class_mapping, 
    transform=regular_transform,
    tetra_transform=tetra_transform,
    ps_transform=ps_transform,
    pvc_transform=pvc_transform,
    diverse_transform=diverse_transform 
)

# After creating the dataset instance, count the classes
initial_class_counts = Counter(plastic_dataset.labels)
print('Initial ClassCounts after augmentation:', initial_class_counts)

Loaded 742 original images for BigBag2_1_PET
Loaded 1403 original images for BigBag2_2_PP
Loaded 1203 original images for BigBag2_3_PE
Loaded 192 original images for BigBag2_4_Tetra
Loaded 4 original images for BigBag2_5_PVC
Loaded 227 original images for BigBag2_6_PS
Loaded 1268 original images for BigBag2_7_Other
Loaded 1350 augmented images for BigBag2_4_Tetra_Augmented
Loaded 1694 augmented images for BigBag2_6_PS_Augmented
Loaded 904 original images for BigBag4_1_PET
Loaded 1483 original images for BigBag4_2_PP
Loaded 833 original images for BigBag4_3_PE
Loaded 173 original images for BigBag4_4_Tetra
Loaded 254 original images for BigBag4_6_PS
Loaded 3 original images for BigBag4_5_PVC
Loaded 1373 original images for BigBag4_7_Other
Loaded 458 original images for BigBag1_1_PET
Loaded 414 original images for BigBag1_2_PP
Loaded 518 original images for BigBag1_3_PE
Loaded 21 original images for BigBag1_4_Tetra
Loaded 47 original images for BigBag1_6_PS
Loaded 984 original images for

In [4]:
from torch.utils.data import random_split, DataLoader, WeightedRandomSampler
from collections import Counter
import torch

# Total number of images for splits
total_images = len(plastic_dataset)
num_train_images = int(0.8 * total_images)  # 80% for training
num_val_images = int(0.1 * total_images)    # 10% for validation
num_test_images = total_images - num_train_images - num_val_images  # Remaining 10% for testing

# Randomly split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = random_split(plastic_dataset, [num_train_images, num_val_images, num_test_images])

# Create DataLoaders for training, validation, and test sets
batch_size = 32  # Adjust as needed
num_workers = 4  # Adjust based on your system

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Output to check the size of the training, validation, and test sets
print(f"Training Dataset Size: {len(train_loader.dataset)} images")
print(f"Validation Dataset Size: {len(val_loader.dataset)} images")
print(f"Test Dataset Size: {len(test_loader.dataset)} images")

# Optional: Check class distribution in training dataset
train_labels = [plastic_dataset.labels[i] for i in train_dataset.indices]
train_class_counts = Counter(train_labels)
print(f"Training Class Distribution: {train_class_counts}")

# Calculate weights for the training dataset
train_class_weights = {cls: 1.0 / count for cls, count in train_class_counts.items() if count > 0}
print("Training Class Weights:", train_class_weights)  # Debug print

# Create sample weights based on the training labels
sample_weights = [train_class_weights[label] for label in train_labels]
print("Sample Weights:", sample_weights)  # Debug print

# Create a WeightedRandomSampler for the training dataset
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# Create DataLoader for training with sampler
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

# Print to confirm sizes after sampler
print(f"Training set size with sampler: {len(train_loader.dataset)}")

# Step 2: Split the training dataset into clients
client_count = 4  # Number of clients

# Calculate how many images each client will get
client_images_per_client = len(train_dataset) // client_count  # Images per client
remaining_images = len(train_dataset) % client_count  # Remaining images after even distribution

# Create the list of lengths for each client
client_lengths = [client_images_per_client + 1 if i < remaining_images else client_images_per_client for i in range(client_count)]

# Split the training dataset into clients
client_splits = random_split(train_dataset, client_lengths)

# Verify that each client has the expected number of images
for i, client_dataset in enumerate(client_splits):
    print(f"Client {i+1} Dataset Size: {len(client_dataset)} images")

Training Dataset Size: 16353 images
Validation Dataset Size: 2044 images
Test Dataset Size: 2045 images
Training Class Distribution: Counter({6: 3478, 1: 3127, 2: 2308, 5: 2102, 0: 1997, 4: 1907, 3: 1434})
Training Class Weights: {2: 0.0004332755632582322, 6: 0.0002875215641173088, 5: 0.0004757373929590866, 1: 0.0003197953309881676, 3: 0.000697350069735007, 0: 0.000500751126690035, 4: 0.0005243838489774515}
Sample Weights: [0.0004332755632582322, 0.0004332755632582322, 0.0002875215641173088, 0.0002875215641173088, 0.0002875215641173088, 0.0004757373929590866, 0.0004332755632582322, 0.0004332755632582322, 0.0003197953309881676, 0.0004332755632582322, 0.0003197953309881676, 0.0003197953309881676, 0.0003197953309881676, 0.0003197953309881676, 0.000697350069735007, 0.0004332755632582322, 0.0003197953309881676, 0.000697350069735007, 0.0004757373929590866, 0.0002875215641173088, 0.0003197953309881676, 0.0002875215641173088, 0.0003197953309881676, 0.000500751126690035, 0.000500751126690035, 0

In [5]:
# After creating the dataset instance, count the classes
initial_class_counts = Counter(plastic_dataset.labels)
print('Initial ClassCounts after augmentation:', initial_class_counts)

Initial ClassCounts after augmentation: Counter({6: 4327, 1: 3899, 2: 2899, 5: 2616, 0: 2495, 4: 2400, 3: 1806})


In [6]:
import numpy as np
from torch.utils.data import random_split, ConcatDataset

# Define the parameters
initial_ratio = 0.37  # Ratio for the initial batch based on previous splits
num_incremental_batches = 6  # Number of incremental batches

# Prepare list for the client splits (initial and incremental batches)
client_splits_cl = []

# Iterate over each client dataset and split based on new sizes
for client_idx, client_dataset in enumerate(client_splits):
    # Calculate the size for the initial batch
    initial_batch_size = int(len(client_dataset) * initial_ratio)
    remaining_images = len(client_dataset) - initial_batch_size

    # Split initial and remaining dataset
    initial_batch, remaining_dataset = random_split(client_dataset, [initial_batch_size, remaining_images])
    
    # Calculate incremental batch sizes
    incremental_batch_size = remaining_images // num_incremental_batches
    remaining_incremental = remaining_images % num_incremental_batches
    
    incremental_sizes = [incremental_batch_size] * num_incremental_batches
    # Distribute any remainder across the first few batches
    for i in range(remaining_incremental):
        incremental_sizes[i] += 1
    
    # Create incremental batches from remaining data
    incremental_batches = random_split(remaining_dataset, incremental_sizes)

    # Initialize the client split with the initial batch
    client_batches = [initial_batch]
    
    # Append each incremental batch with replay data
    for inc_batch in incremental_batches:
        previous_data = ConcatDataset(client_batches)  # All previous data
        replay_data, _ = random_split(previous_data, [300, len(previous_data) - 300])  # Replay size
        new_batch = ConcatDataset([inc_batch, replay_data])
        client_batches.append(new_batch)
    
    client_splits_cl.append(client_batches)

# Print out the splits for each client
for client_idx, splits in enumerate(client_splits_cl):
    for i, split in enumerate(splits):
        print(f"Client {client_idx + 1} - Batch {i + 1}: {len(split)} images")

Client 1 - Batch 1: 1512 images
Client 1 - Batch 2: 730 images
Client 1 - Batch 3: 730 images
Client 1 - Batch 4: 730 images
Client 1 - Batch 5: 729 images
Client 1 - Batch 6: 729 images
Client 1 - Batch 7: 729 images
Client 2 - Batch 1: 1512 images
Client 2 - Batch 2: 730 images
Client 2 - Batch 3: 730 images
Client 2 - Batch 4: 729 images
Client 2 - Batch 5: 729 images
Client 2 - Batch 6: 729 images
Client 2 - Batch 7: 729 images
Client 3 - Batch 1: 1512 images
Client 3 - Batch 2: 730 images
Client 3 - Batch 3: 730 images
Client 3 - Batch 4: 729 images
Client 3 - Batch 5: 729 images
Client 3 - Batch 6: 729 images
Client 3 - Batch 7: 729 images
Client 4 - Batch 1: 1512 images
Client 4 - Batch 2: 730 images
Client 4 - Batch 3: 730 images
Client 4 - Batch 4: 729 images
Client 4 - Batch 5: 729 images
Client 4 - Batch 6: 729 images
Client 4 - Batch 7: 729 images


### Step 5: Continual Learning (CL) for Each Client
#### Objective: Each client independently performs Continual Learning (CL) using incremental batches of data to adapt to new information without forgetting previous knowledge.

In [11]:
import os
import torch
import random
from torchvision import models
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

# Model Preparation Function (ResNet18 for DWRL)
def prepare_model(num_classes=7, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained ResNet18 model and modify it for DWRL with optional dropout."""
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# EarlyStopping class
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Training function for Continual Learning
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, patience=5, min_delta=0):
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
    
    return model

# Testing function to evaluate the model on the test set
def test_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    
    test_accuracy = 100 * test_correct / test_total
    print(f'Test Accuracy: {test_accuracy:.2f}%')

# Save function
def save_model(model, client_idx, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    torch.save(model.state_dict(), model_path)
    print(f'Model for Client {client_idx} saved at {model_path}')

# Prepare the training and validation loaders for each client
batch_size = 32
num_workers = 4
num_epochs = 5  # Set the number of epochs for each CL round

# Define the directory to save the models
save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/'

# Iterate over all clients and perform incremental training (CL)
for client_idx, client_batches in enumerate(client_splits_cl):
    print(f"\nStarting Continual Learning for Client {client_idx + 1}...")

    # Initialize the model once for each client, and then update it for each batch
    model = prepare_model(num_classes=7).to(device)   ### update num of classes if needed
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Train on each batch incrementally, continuing from the last updated model
    for batch_idx, batch in enumerate(client_batches):
        print(f"Training on Batch {batch_idx + 1}/{len(client_batches)}")

        # Create DataLoader for the current batch
        train_loader = DataLoader(batch, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        # Use the same validation set for all batches
        model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)

    # After training on all batches, test the model on the test set
    print(f"Testing model for Client {client_idx + 1} after Continual Learning...")
    test_model(model, test_loader)

    # Save the model for each client
    save_model(model, client_idx + 1, save_dir)

print("Continual Learning, testing, and saving models for all clients completed.")


Starting Continual Learning for Client 1...
Training on Batch 1/7
Epoch 1/5, Train Loss: 1.2316, Train Accuracy: 50.99%, Val Loss: 0.9869, Val Accuracy: 61.99%
Epoch 2/5, Train Loss: 0.7348, Train Accuracy: 73.81%, Val Loss: 0.8598, Val Accuracy: 68.20%
Epoch 3/5, Train Loss: 0.5385, Train Accuracy: 82.08%, Val Loss: 0.8053, Val Accuracy: 70.01%
Epoch 4/5, Train Loss: 0.4273, Train Accuracy: 85.78%, Val Loss: 0.8541, Val Accuracy: 68.44%
Epoch 5/5, Train Loss: 0.3861, Train Accuracy: 87.24%, Val Loss: 0.8782, Val Accuracy: 68.74%
Training on Batch 2/7
Epoch 1/5, Train Loss: 0.5559, Train Accuracy: 81.23%, Val Loss: 0.9416, Val Accuracy: 66.93%
Epoch 2/5, Train Loss: 0.3953, Train Accuracy: 87.12%, Val Loss: 0.7626, Val Accuracy: 73.34%
Epoch 3/5, Train Loss: 0.2722, Train Accuracy: 90.96%, Val Loss: 0.7787, Val Accuracy: 70.99%
Epoch 4/5, Train Loss: 0.2238, Train Accuracy: 93.42%, Val Loss: 0.8084, Val Accuracy: 71.43%
Epoch 5/5, Train Loss: 0.2231, Train Accuracy: 93.84%, Val Loss: 

### Step 6-a: Prepare Client Loaders (Validation Only)

### Step 6-b: Load Each Client’s Model (Saved After CL) into cl_models

In [None]:
# from torch.utils.data import DataLoader, random_split

# def create_client_val_loaders(dataset, num_clients=4, batch_size=32, val_split=0.1, num_workers=4):
#     """Split dataset into `num_clients` parts and create validation loaders."""
#     val_loaders = []
    
#     # Calculate sizes for each client
#     client_sizes = [len(dataset) // num_clients] * num_clients
#     for i in range(len(dataset) % num_clients):
#         client_sizes[i] += 1  # Distribute the remainder

#     # Split the dataset randomly into `num_clients` parts
#     client_datasets = random_split(dataset, client_sizes)
    
#     for client_dataset in client_datasets:
#         # Split each client's dataset into train and validation sets
#         train_size = int((1 - val_split) * len(client_dataset))
#         val_size = len(client_dataset) - train_size
#         _, val_dataset = random_split(client_dataset, [train_size, val_size])  # Only need val_dataset

#         # Create DataLoader for validation set
#         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
#                                 num_workers=num_workers, pin_memory=True)
#         val_loaders.append(val_loader)
    
#     return val_loaders

# # Assuming `train_dataset` is your main dataset from DWRL
# client_val_loaders = create_client_val_loaders(train_dataset, num_clients=4, batch_size=32, num_workers=4)

# # Verify the sizes of the validation loaders for each client
# for i, val_loader in enumerate(client_val_loaders):
#     print(f"Client {i+1} - Validation set size: {len(val_loader.dataset)}")

# import os
# import torch
# from torchvision import models
# import torch.nn as nn

# # Model Preparation Function (ResNet18 for DWRL)
# def prepare_model(num_classes=7):
#     """Initialize a ResNet18 model with the specified number of classes."""
#     model = models.resnet18(weights='DEFAULT')
#     num_ftrs = model.fc.in_features
#     model.fc = nn.Linear(num_ftrs, num_classes)
#     return model

# # Function to load a model from a file
# def load_client_model(client_idx, save_dir, num_classes=7):
#     """Load a model for a specific client after CL training."""
#     model = prepare_model(num_classes=num_classes).to(device)
#     model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
#     state_dict = torch.load(model_path)
#     model.load_state_dict(state_dict)
#     print(f'Loaded model for Client {client_idx} from {model_path}')
#     return model

# # Load each client’s model saved after CL
# save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models'
# cl_models = [load_client_model(client_idx + 1, save_dir) for client_idx in range(4)]

# # Function to average client models (FL aggregation)
# def federated_averaging(state_dicts):
#     """Average the model parameters from multiple clients."""
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# # Apply federated averaging on the loaded client models
# state_dicts = [model.state_dict() for model in cl_models]
# avg_state_dict = federated_averaging(state_dicts)

# # Load averaged state dict into a global model
# global_model = prepare_model(num_classes=7).to(device)
# global_model.load_state_dict(avg_state_dict)

# print("Global model created with federated averaging.")



# # Function to test the global model
# def test_global_model(model, test_loader):
#     """Evaluates the global model on the test dataset and returns accuracy."""
#     model.eval()  # Set model to evaluation mode
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():  # Disable gradient computation for testing
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_accuracy = 100 * test_correct / test_total
#     return test_accuracy

# # Assuming `test_loader` is defined for your DWRL dataset
# # Test the aggregated global model
# test_accuracy = test_global_model(global_model, test_loader)
# print(f'Test Accuracy of Global Model: {test_accuracy:.2f}%')

Global model created with federated averaging.


In [13]:
# import os
# import torch.optim as optim
# import torch.nn as nn
# from torchvision import models
# from torch.cuda.amp import autocast, GradScaler

# # Path where client models are saved
# save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/'

# # Model Preparation (ResNet18 for DWRL)
# def prepare_model(num_classes=7, use_dropout=False, dropout_prob=0.2):
#     """Load a pre-trained ResNet18 model and modify it for DWRL with optional dropout."""
#     model = models.resnet18(weights='DEFAULT')  # Adjusted to use the new weights parameter
#     num_ftrs = model.fc.in_features
#     if use_dropout:
#         model.fc = nn.Sequential(
#             nn.Dropout(p=dropout_prob),
#             nn.Linear(num_ftrs, num_classes)
#         )
#     else:
#         model.fc = nn.Linear(num_ftrs, num_classes)
#     return model

# # Function to load a model from a file
# def load_client_model(client_idx, save_dir, num_classes=7):
#     model = prepare_model(num_classes=num_classes).to(device)
#     model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
#     state_dict = torch.load(model_path)
#     model.load_state_dict(state_dict, strict=False)
#     print(f'Loaded model for Client {client_idx} from {model_path}')
#     return model

# # Function to average client models (FL aggregation)
# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# # Step 6: Federated Learning (FL) after Continual Learning (CL)
# def apply_federated_learning(cl_models, test_loader, num_clients=4, num_rounds=3):
#     # Initialize the global model with the averaged state dict from client models
#     state_dicts = [model.state_dict() for model in cl_models]
#     avg_state_dict = federated_averaging(state_dicts)
#     global_model = prepare_model().to(device)
#     global_model.load_state_dict(avg_state_dict)

#     for round in range(num_rounds):
#         print(f'\n--- Federated Learning Round {round + 1} ---')

#         # Perform federated averaging (without fine-tuning on local data)
#         avg_state_dict = federated_averaging(state_dicts)
#         global_model.load_state_dict(avg_state_dict)

#         # Test the global model after aggregation
#         test_accuracy = test_global_model(global_model, test_loader)
#         print(f'Test Accuracy after Round {round + 1}: {test_accuracy:.2f}%')

#     return global_model

# # Function to test the global model
# def test_global_model(model, test_loader):
#     model.eval()
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()
#     test_accuracy = 100 * test_correct / test_total
#     return test_accuracy

# # Load models saved after CL
# cl_models = [load_client_model(client_idx, save_dir, num_classes=7) for client_idx in range(1, 5)]

# # Apply FL and test the global model (without additional local fine-tuning)
# global_model = apply_federated_learning(cl_models, test_loader)

In [8]:
from torch.utils.data import DataLoader, random_split

# Function to create DataLoaders for each client
def create_client_loaders(dataset, num_clients=4, batch_size=32, val_split=0.1, num_workers=4):
    """Split dataset into `num_clients` parts and create train/validation loaders."""
    client_loaders = []
    val_loaders = []
    
    # Calculate sizes for each client
    client_sizes = [len(dataset) // num_clients] * num_clients
    for i in range(len(dataset) % num_clients):
        client_sizes[i] += 1  # Distribute the remainder

    # Split the dataset randomly into `num_clients` parts
    client_datasets = random_split(dataset, client_sizes)
    
    for client_dataset in client_datasets:
        # Further split each client's dataset into train and validation sets
        train_size = int((1 - val_split) * len(client_dataset))
        val_size = len(client_dataset) - train_size
        train_dataset, val_dataset = random_split(client_dataset, [train_size, val_size])
        
        # Create DataLoaders for the client with pin_memory and optimizations
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                                  num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
                                num_workers=num_workers, pin_memory=True)
        
        client_loaders.append(train_loader)
        val_loaders.append(val_loader)
    
    return client_loaders, val_loaders

# Assuming `train_dataset` is your training dataset from DWRL
client_train_loaders, client_val_loaders = create_client_loaders(train_dataset, num_clients=4, batch_size=32, num_workers=4)

# Verify the sizes of the loaders for each client
for i, (train_loader, val_loader) in enumerate(zip(client_train_loaders, client_val_loaders)):
    print(f"Client {i+1} - Training set size: {len(train_loader.dataset)}")
    print(f"Client {i+1} - Validation set size: {len(val_loader.dataset)}")

Client 1 - Training set size: 3680
Client 1 - Validation set size: 409
Client 2 - Training set size: 3679
Client 2 - Validation set size: 409
Client 3 - Training set size: 3679
Client 3 - Validation set size: 409
Client 4 - Training set size: 3679
Client 4 - Validation set size: 409


In [17]:
# Function to average client models with FedAvg (weighted averaging based on data sizes)
def fedavg_weighted_averaging(state_dicts, data_sizes):
    avg_state_dict = {}
    total_data = sum(data_sizes)
    
    for key in state_dicts[0].keys():
        weighted_sum = sum(state_dict[key] * (size / total_data) for state_dict, size in zip(state_dicts, data_sizes))
        avg_state_dict[key] = weighted_sum
    return avg_state_dict

# Function to load each client model and prepare for FL
def load_and_get_data_sizes(save_dir, num_clients=4, num_classes=7):
    cl_models = []
    data_sizes = []
    
    for client_idx in range(1, num_clients + 1):
        model = load_client_model(client_idx, save_dir, num_classes=num_classes)
        cl_models.append(model)
        
        # Get the size of the dataset for weighting
        data_sizes.append(len(client_train_loaders[client_idx - 1].dataset))  # Assume train_loader created earlier
        
    return cl_models, data_sizes

# Step 1: Load client models and dataset sizes
cl_models, data_sizes = load_and_get_data_sizes(save_dir, num_clients=4)

# Step 2: Federated averaging with weighted FedAvg
def apply_federated_learning_with_fedavg(cl_models, client_train_loaders, test_loader, num_clients=4, num_epochs=3):
    global_model = prepare_model(num_classes=7).to(device)

    # Collect state dicts from all clients after CL
    state_dicts = [model.state_dict() for model in cl_models]

    # Apply weighted federated averaging (FedAvg) using data sizes
    avg_state_dict = fedavg_weighted_averaging(state_dicts, data_sizes)
    global_model.load_state_dict(avg_state_dict)

    # Testing the aggregated model
    test_accuracy = test_global_model(global_model, test_loader)
    print(f"Test Accuracy of Global Model after FedAvg: {test_accuracy:.2f}%")

    return global_model

# Execute FedAvg
global_model = apply_federated_learning_with_fedavg(cl_models, client_train_loaders, test_loader)


Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_4_model.pth
Test Accuracy of Global Model after FedAvg: 68.75%


### Step 6a: BASELINE(FL After Full Continual Learning" (FL-FCL))- Apply Federated Learning (FL) After Continual Learning (CL)
#### Federated Averaging: We collect the models from each client after CL, average their weights, and load the averaged weights into the global model.
#### Test the Global Model: After aggregation, the global model is tested on the test set, and the test accuracy is reported.

In [None]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models
from torch.cuda.amp import autocast, GradScaler  # For mixed-precision training

# Path where client models are saved
#saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/cl_models/'
save_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/'

# Model Preparation (ResNet18 for DWRL)
def prepare_model(num_classes=7, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained Resnet18 model and modify it for DWRL with optional dropout."""
    model = models.resnet18(weights='DEFAULT')  # Adjusted to use the new weights parameter
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir, num_classes=7):
    model = prepare_model(num_classes=num_classes).to(device)  # Specify the correct num_classes
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    
    # Load the state dict with strict=False to ignore mismatched keys
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict, strict=False)
    
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FL aggregation)
def federated_averaging(state_dicts):
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=3, use_mixed_precision=True):
    model = global_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    
    # Optional: Mixed precision training
    scaler = GradScaler() if use_mixed_precision else None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            if use_mixed_precision:
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            if use_mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

    return model.state_dict()

# Step 6: Federated Learning (FL) after Continual Learning (CL)
def apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader, num_clients=4, num_epochs=3):
    global_model = prepare_model().to(device)

    # Collect state dicts from all clients after CL
    state_dicts = [model.state_dict() for model in cl_models]

    # Perform federated averaging
    avg_state_dict = federated_averaging(state_dicts)
    global_model.load_state_dict(avg_state_dict)

    for round in range(num_epochs):
        print(f'\n--- Federated Learning Round {round + 1} ---')
        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} with the global model')
            client_state_dict = fine_tune_client(global_model, train_loaders[client_idx], val_loaders[client_idx], num_epochs=1)
            client_state_dicts.append(client_state_dict)

        # Federated averaging after each round
        avg_state_dict = federated_averaging(client_state_dicts)
        global_model.load_state_dict(avg_state_dict)

        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Round {round + 1}: {test_accuracy:.2f}%')

    return global_model

# Function to test the global model
def test_global_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_accuracy = 100 * test_correct / test_total
    return test_accuracy

# Load models saved after CL
cl_models = [load_client_model(client_idx, save_dir, num_classes=7) for client_idx in range(1, 5)]

# Prepare DataLoader splits for each client with pin_memory=True and fewer workers
train_loaders, val_loaders = create_client_loaders(train_dataset, num_clients=4, batch_size=32)

# Apply FL and test the global model
global_model = apply_federated_learning(cl_models, train_loaders, val_loaders, test_loader)

Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_4_model.pth

--- Federated Learning Round 1 ---

Training client 1 with the global model
Epoch 1/1, Train Accuracy: 75.98%

Training client 2 with the global model
Epoch 1/1, Train Accuracy: 76.54%

Training client 3 with the global model
Epoch 1/1, Train Accuracy: 77.03%

Training client 4 with the global model
Epoch 1/1, Train Accuracy: 77.87%
Test Accuracy after Round 1: 81.37%

--- Federated Learning Round 2 ---

Training client 1 with the global model
Epoch 1/1, Train Accuracy: 81.90%

Training client 2 with the global model
Epoch 1/1, Train Acc

## Step 6b: (FL After Each Continual Learning Round (FL-CL)) - Apply Federated Learning (FL) After Each CL Round
- **Federated Averaging**: After each client completes a round of CL (training on a batch), we perform FL by averaging the client models.
- **Global Model Updates**: The aggregated global model is redistributed to all clients before continuing with the next batch.
- **Testing**: The global model is tested after each FL round to evaluate its performance.

In [14]:
import os
import torch.optim as optim
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader

# Path where client models are saved
saved_models_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/'

# Model Preparation (ResNet18 for DWRL)
def prepare_model(num_classes=7, use_dropout=False, dropout_prob=0.2):
    """Load a pre-trained ResNet18 model and modify it for DWRL with optional dropout."""
    model = models.resnet18(weights='DEFAULT')  # Update to use weights argument
    num_ftrs = model.fc.in_features
    if use_dropout:
        model.fc = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(num_ftrs, num_classes)
        )
    else:
        model.fc = nn.Linear(num_ftrs, num_classes)
    return model

# Function to load a model from a file
def load_client_model(client_idx, save_dir):
    model = prepare_model(num_classes=7).to(device)
    model_path = os.path.join(save_dir, f'client_{client_idx}_model.pth')
    model.load_state_dict(torch.load(model_path))
    print(f'Loaded model for Client {client_idx} from {model_path}')
    return model

# Function to average client models (FL aggregation)
def federated_averaging(state_dicts):
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
    return avg_state_dict

# Training function for each client after receiving the global model
def fine_tune_client(global_model, train_loader, val_loader, num_epochs=3):  # Adjusted to 3 epochs
    model = global_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Train Accuracy: {train_accuracy:.2f}%')

    return model.state_dict()  # Return the fine-tuned state_dict


# Step 6: Federated Learning (FL) after each batch of Continual Learning (CL)
# Assume client_splits_cl is correctly initialized with batch splits for each client.
# Check if client_splits_cl is correctly defined with the expected shape.

def apply_federated_learning_after_each_batch(client_splits_cl, val_loaders, test_loader, num_clients=4, num_epochs=3):
    # Initialize a global model
    global_model = prepare_model(num_classes=7).to(device)

    # Iterate over each batch of CL for all clients
    for batch_idx in range(len(client_splits_cl[0])):  
        print(f'\n--- Training and Federated Learning after Batch {batch_idx + 1} ---')

        client_state_dicts = []

        for client_idx in range(num_clients):
            print(f'\nTraining client {client_idx + 1} on Batch {batch_idx + 1}')

            # Create DataLoader for the current batch
            train_loader = DataLoader(client_splits_cl[client_idx][batch_idx], batch_size=32, shuffle=True)

            # Fine-tune client model with the current global model
            client_state_dict = fine_tune_client(global_model, train_loader, val_loaders[client_idx], num_epochs=num_epochs)
            client_state_dicts.append(client_state_dict)

        # Perform federated averaging after this batch for all clients
        avg_state_dict = federated_averaging(client_state_dicts)
        global_model.load_state_dict(avg_state_dict)

        # Optionally test the global model after each batch
        test_accuracy = test_global_model(global_model, test_loader)
        print(f'Test Accuracy after Batch {batch_idx + 1}: {test_accuracy:.2f}%')

    return global_model

# Make sure to load models saved after Continual Learning (CL) as before
cl_models = [load_client_model(client_idx, saved_models_dir) for client_idx in range(1, 5)]

# Ensure that client_splits_cl is defined properly with the batches
global_model = apply_federated_learning_after_each_batch(client_splits_cl, val_loaders, test_loader, num_clients=4, num_epochs=3)

Loaded model for Client 1 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_1_model.pth
Loaded model for Client 2 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_2_model.pth
Loaded model for Client 3 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_3_model.pth
Loaded model for Client 4 from /raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/models/client_4_model.pth

--- Training and Federated Learning after Batch 1 ---

Training client 1 on Batch 1
Epoch 1/3, Train Accuracy: 52.38%
Epoch 2/3, Train Accuracy: 72.35%
Epoch 3/3, Train Accuracy: 79.37%

Training client 2 on Batch 1
Epoch 1/3, Train Accuracy: 68.72%
Epoch 2/3, Train Accuracy: 77.71%
Epoch 3/3, Train Accuracy: 82.47%

Training client 3 on Batch 1
Epoch 1/3, Train Accuracy: 72.09%
Epoch 2/3, Train Accuracy: 79.96%
Epoch 3/3, Train Accuracy: 85.19%

Training client 4 on Batch 1
Epoch 1/3, Train Accuracy: 71.23%
Epoch 2/3, Train Accuracy: 80