In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [2]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, random_split
import torchvision.transforms as transforms
import torch
import numpy as np
from collections import Counter
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models



# Check for CUDA GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
#### With class 'Other'


from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
from tqdm import tqdm
from collections import Counter

# Updated class mapping to combine classes into unique classes (including 'Augmented') 
# put PVC and Other to one class
class_mapping = {
    #BigBag2
    'BigBag2_1_PET': 0,  # PET
    'BigBag2_2_PP': 1,   # PP
    'BigBag2_3_PE': 2,   # PE
    'BigBag2_4_Tetra': 3, # Tetra
    'BigBag2_5_PVC': 5, # PVC
    'BigBag2_6_PS': 4,   # PS
    'BigBag2_7_Other': 6, # Other
    'BigBag2_4_Tetra_Augmented': 3,  # Augmented Tetra
    'BigBag2_6_PS_Augmented': 4,  # Augmented PS
    
    #BigBag4
    'BigBag4_1_PET': 0,  # PET
    'BigBag4_2_PP': 1,   # PP
    'BigBag4_3_PE': 2,   # PE
    'BigBag4_4_Tetra': 3, # Tetra
    'BigBag4_6_PS': 4,   # PS
    'BigBag4_5_PVC': 5, # PVC
    'BigBag4_7_Other': 6, # Other
    
    #BigBag1
    'BigBag1_1_PET': 0,  # PET
    'BigBag1_2_PP': 1,   # PP
    'BigBag1_3_PE': 2,   # PE
    'BigBag1_4_Tetra': 3, # Tetra
    #'BigBag2_4_Tetra_Augmented': 3,  # Augmented Tetra
    #'BigBag2_5_PVC': 5, # PVC
    'BigBag1_6_PS': 4,   # PS
    'BigBag1_7_Other': 6, # Other
    #'BigBag2_6_PS_Augmented': 4,  # Augmented PS
    
    #BigBag3
    'BigBag3_PET': 0,  # PET
    'BigBag3_2_PP': 1,   # PP
    'BigBag3_PE': 2,   # PE
    'BigBag3_TETRA': 3, # Tetra
    #'BigBag3_PVC': 5, # PVC
    'BigBag3_6_PS': 4,   # PS
    'BigBag3_Other': 6, # Other
    
    'DWRL7_extension_2_PVC': 5, # 
    'BigBag2_5_PVC_Augmented': 5,  # Augmented PVC
}

class CustomPlasticDataset(Dataset):
    def __init__(self, root_dir, class_mapping, transform, tetra_transform=None, ps_transform=None, pvc_transform=None, diverse_transform=None):
        self.root_dir = root_dir
        self.class_mapping = class_mapping
        self.transform = transform
        self.tetra_transform = tetra_transform
        self.ps_transform = ps_transform
        self.pvc_transform = pvc_transform
        self.diverse_transform = diverse_transform
        self.image_paths = []
        self.labels = []
        
        # Gather image paths and labels
        for class_folder in class_mapping.keys():
            # Load original images
            image_dir = os.path.join(root_dir, class_folder, 'images_cutout')
            if os.path.exists(image_dir):
                image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
                print(f"Loaded {len(image_files)} original images for {class_folder}")
                self.image_paths.extend([os.path.join(image_dir, img) for img in image_files])
                self.labels.extend([class_mapping[class_folder]] * len(image_files))

            # Load augmented images if they exist
            augmented_dir = os.path.join(root_dir, class_folder)  # Path to the augmented class
            if os.path.exists(augmented_dir) and 'Augmented' in class_folder:
                augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
                print(f"Loaded {len(augmented_files)} augmented images for {class_folder}")
                self.image_paths.extend([os.path.join(augmented_dir, img) for img in augmented_files])
                self.labels.extend([class_mapping[class_folder]] * len(augmented_files))  # Map to the same class label

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path)
        label = self.labels[idx]
        
        # Apply specific transformations based on class
        if label == 3:  # Tetra
            image = self.tetra_transform(image)
        elif label == 4:  # PS
            image = self.ps_transform(image)
        elif label == 5:  # PVC
            image = self.pvc_transform(image)
            
        else:
            # Use diverse_transform with a 50% chance for other classes
            if self.diverse_transform is not None and random.random() > 0.5:
                image = self.diverse_transform(image)

            else:
                image = self.transform(image)
        
        return image, label


# Augmentation for TETRA (Moderate)
tetra_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Augmentation for PS (Light)
ps_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Regular transform for other classes
regular_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
])
    
# Augmentation for PVC (Light)
pvc_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define more diverse transformations for augmentation
diverse_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
    transforms.RandomRotation(degrees=30),  # Random rotation by up to 30 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color adjustments
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # Random crop with resizing
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def augment_and_save_images(original_dataset, class_label, transform, target_dir, num_augmented_images):
    os.makedirs(target_dir, exist_ok=True)  # Create the target directory if it doesn't exist
    class_images = [original_dataset.image_paths[idx] for idx, label in enumerate(original_dataset.labels) if label == class_label]

    for img_path in tqdm(class_images, desc=f'Augmenting Class {class_label}'):
        img = Image.open(img_path)
        for i in range(num_augmented_images):
            augmented_img = transform(img)  # Apply the transformation
            augmented_img = transforms.ToPILImage()(augmented_img)  # Convert back to PIL Image
            
            # Create a unique filename using the original filename and the index
            base_filename = os.path.basename(img_path).split('.')[0]  # Get the original filename without extension
            augmented_img.save(os.path.join(target_dir, f'augmented_{base_filename}_{i}.png'))  # Save with a unique name

# Function to count images in each class
def count_images_in_classes(base_dir, class_mapping):
    class_counts = {class_name: 0 for class_name in class_mapping.keys()}

    for class_name in class_mapping.keys():
        # Check for original images
        image_dir = os.path.join(base_dir, class_name, 'images_cutout')  # Original images path
        if os.path.exists(image_dir):
            image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
            class_counts[class_name] += len(image_files)

        # Check for augmented images in their respective folders
        if 'Augmented' in class_name:
            augmented_dir = os.path.join(base_dir, class_name)  # Path to the augmented class
            if os.path.exists(augmented_dir):
                augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
                class_counts[class_name] += len(augmented_files)  # Count augmented images

    return class_counts

# Directory where all class folders are stored
data_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data'

# Create dataset instance
plastic_dataset = CustomPlasticDataset(
    root_dir=data_dir, 
    class_mapping=class_mapping, 
    transform=regular_transform,
    tetra_transform=tetra_transform,
    ps_transform=ps_transform,
    pvc_transform=pvc_transform,
    diverse_transform=diverse_transform 
)

# After creating the dataset instance, count the classes
initial_class_counts = Counter(plastic_dataset.labels)
print('Initial ClassCounts after augmentation:', initial_class_counts)

Loaded 742 original images for BigBag2_1_PET
Loaded 1403 original images for BigBag2_2_PP
Loaded 1203 original images for BigBag2_3_PE
Loaded 192 original images for BigBag2_4_Tetra
Loaded 4 original images for BigBag2_5_PVC
Loaded 227 original images for BigBag2_6_PS
Loaded 1268 original images for BigBag2_7_Other
Loaded 1350 augmented images for BigBag2_4_Tetra_Augmented
Loaded 1694 augmented images for BigBag2_6_PS_Augmented
Loaded 904 original images for BigBag4_1_PET
Loaded 1483 original images for BigBag4_2_PP
Loaded 833 original images for BigBag4_3_PE
Loaded 173 original images for BigBag4_4_Tetra
Loaded 254 original images for BigBag4_6_PS
Loaded 3 original images for BigBag4_5_PVC
Loaded 1373 original images for BigBag4_7_Other
Loaded 458 original images for BigBag1_1_PET
Loaded 414 original images for BigBag1_2_PP
Loaded 518 original images for BigBag1_3_PE
Loaded 21 original images for BigBag1_4_Tetra
Loaded 47 original images for BigBag1_6_PS
Loaded 984 original images for

In [None]:
# #### Without class 'Other'


# from torchvision import transforms
# from torch.utils.data import Dataset
# from PIL import Image
# import os
# from tqdm import tqdm
# from collections import Counter

# # Updated class mapping to combine classes into unique classes (including 'Augmented')
# # Removed the 'Other' class (class 6)
# class_mapping = {
#     #BigBag2
#     'BigBag2_1_PET': 0,  # PET
#     'BigBag2_2_PP': 1,   # PP
#     'BigBag2_3_PE': 2,   # PE
#     'BigBag2_4_Tetra': 3, # Tetra
#     'BigBag2_5_PVC': 5,  # PVC
#     'BigBag2_6_PS': 4,   # PS
#     'BigBag2_4_Tetra_Augmented': 3,  # Augmented Tetra
#     'BigBag2_6_PS_Augmented': 4,  # Augmented PS
    
#     #BigBag4
#     'BigBag4_1_PET': 0,  # PET
#     'BigBag4_2_PP': 1,   # PP
#     'BigBag4_3_PE': 2,   # PE
#     'BigBag4_4_Tetra': 3, # Tetra
#     'BigBag4_6_PS': 4,   # PS
#     'BigBag4_5_PVC': 5,  # PVC
    
#     #BigBag1
#     'BigBag1_1_PET': 0,  # PET
#     'BigBag1_2_PP': 1,   # PP
#     'BigBag1_3_PE': 2,   # PE
#     'BigBag1_4_Tetra': 3, # Tetra
#     'BigBag1_6_PS': 4,   # PS
    
#     #BigBag3
#     'BigBag3_PET': 0,  # PET
#     'BigBag3_2_PP': 1,   # PP
#     'BigBag3_PE': 2,   # PE
#     'BigBag3_TETRA': 3, # Tetra
#     #'BigBag3_PVC': 5, # PVC
#     'BigBag3_6_PS': 4,   # PS
    
#     'DWRL7_extension_2_PVC': 5,  # PVC
#     'BigBag2_5_PVC_Augmented': 5,  # Augmented PVC
# }

# class CustomPlasticDataset(Dataset):
#     def __init__(self, root_dir, class_mapping, transform, tetra_transform=None, ps_transform=None, pvc_transform=None, diverse_transform=None):
#         self.root_dir = root_dir
#         self.class_mapping = class_mapping
#         self.transform = transform
#         self.tetra_transform = tetra_transform
#         self.ps_transform = ps_transform
#         self.pvc_transform = pvc_transform
#         self.diverse_transform = diverse_transform
#         self.image_paths = []
#         self.labels = []
        
#         # Gather image paths and labels
#         for class_folder in class_mapping.keys():
#             # Load original images
#             image_dir = os.path.join(root_dir, class_folder, 'images_cutout')
#             if os.path.exists(image_dir):
#                 image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
#                 print(f"Loaded {len(image_files)} original images for {class_folder}")
#                 self.image_paths.extend([os.path.join(image_dir, img) for img in image_files])
#                 self.labels.extend([class_mapping[class_folder]] * len(image_files))

#             # Load augmented images if they exist
#             augmented_dir = os.path.join(root_dir, class_folder)  # Path to the augmented class
#             if os.path.exists(augmented_dir) and 'Augmented' in class_folder:
#                 augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
#                 print(f"Loaded {len(augmented_files)} augmented images for {class_folder}")
#                 self.image_paths.extend([os.path.join(augmented_dir, img) for img in augmented_files])
#                 self.labels.extend([class_mapping[class_folder]] * len(augmented_files))  # Map to the same class label

#     def __len__(self):
#         return len(self.image_paths)
    
#     def __getitem__(self, idx):
#         img_path = self.image_paths[idx]
#         image = Image.open(img_path)
#         label = self.labels[idx]
        
#         # Apply specific transformations based on class
#         if label == 3:  # Tetra
#             image = self.tetra_transform(image)
#         elif label == 4:  # PS
#             image = self.ps_transform(image)
#         elif label == 5:  # PVC
#             image = self.pvc_transform(image)
#         else:
#             # Use diverse_transform with a 50% chance for other classes
#             if self.diverse_transform is not None and random.random() > 0.5:
#                 image = self.diverse_transform(image)
#             else:
#                 image = self.transform(image)
        
#         return image, label


# # Augmentation for TETRA (Moderate)
# tetra_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(30),
#     transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
#     transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# # Augmentation for PS (Light)
# ps_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomCrop(224, padding=4),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# # Regular transform for other classes
# regular_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
# ])
    
# # Augmentation for PVC (Light)
# pvc_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomHorizontalFlip(),
#     transforms.RandomRotation(15),
#     transforms.ColorJitter(brightness=0.3, contrast=0.3),
#     transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# # Define more diverse transformations for augmentation
# diverse_transform = transforms.Compose([
#     transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
#     transforms.RandomRotation(degrees=30),  # Random rotation by up to 30 degrees
#     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color adjustments
#     transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # Random crop with resizing
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# def augment_and_save_images(original_dataset, class_label, transform, target_dir, num_augmented_images):
#     os.makedirs(target_dir, exist_ok=True)  # Create the target directory if it doesn't exist
#     class_images = [original_dataset.image_paths[idx] for idx, label in enumerate(original_dataset.labels) if label == class_label]

#     for img_path in tqdm(class_images, desc=f'Augmenting Class {class_label}'):
#         img = Image.open(img_path)
#         for i in range(num_augmented_images):
#             augmented_img = transform(img)  # Apply the transformation
#             augmented_img = transforms.ToPILImage()(augmented_img)  # Convert back to PIL Image
            
#             # Create a unique filename using the original filename and the index
#             base_filename = os.path.basename(img_path).split('.')[0]  # Get the original filename without extension
#             augmented_img.save(os.path.join(target_dir, f'augmented_{base_filename}_{i}.png'))  # Save with a unique name

# # Function to count images in each class
# def count_images_in_classes(base_dir, class_mapping):
#     class_counts = {class_name: 0 for class_name in class_mapping.keys()}

#     for class_name in class_mapping.keys():
#         # Check for original images
#         image_dir = os.path.join(base_dir, class_name, 'images_cutout')  # Original images path
#         if os.path.exists(image_dir):
#             image_files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
#             class_counts[class_name] += len(image_files)

#         # Check for augmented images in their respective folders
#         if 'Augmented' in class_name:
#             augmented_dir = os.path.join(base_dir, class_name)  # Path to the augmented class
#             if os.path.exists(augmented_dir):
#                 augmented_files = [f for f in os.listdir(augmented_dir) if f.endswith(('.jpg', '.png'))]
#                 class_counts[class_name] += len(augmented_files)  # Count augmented images

#     return class_counts

# # Directory where all class folders are stored
# data_dir = '/raid/home/somayeh.shami/project/somayeh_workspace/DWRL7/data'

# # Create dataset instance
# plastic_dataset = CustomPlasticDataset(
#     root_dir=data_dir, 
#     class_mapping=class_mapping, 
#     transform=regular_transform,
#     tetra_transform=tetra_transform,
#     ps_transform=ps_transform,
#     pvc_transform=pvc_transform,
#     diverse_transform=diverse_transform 
# )

# # After creating the dataset instance, count the classes
# initial_class_counts = Counter(plastic_dataset.labels)
# print('Initial ClassCounts after augmentation:', initial_class_counts)

In [6]:
from torch.utils.data import Subset, random_split, DataLoader, WeightedRandomSampler
from collections import defaultdict, Counter
import random
import torch

# Step 1: Separate Original and Augmented Images by class
original_indices_by_class = defaultdict(list)
augmented_indices_by_class = defaultdict(list)

for i, (path, label) in enumerate(zip(plastic_dataset.image_paths, plastic_dataset.labels)):
    if 'Augmented' not in path:  # Original images
        original_indices_by_class[label].append(i)
    else:  # Augmented images
        augmented_indices_by_class[label].append(i)

# Step 2: Calculate test size based on total images (original + augmented)
test_indices = []
train_val_indices_by_class = defaultdict(list)

for label, original_indices in original_indices_by_class.items():
    total_images_in_class = len(original_indices) + len(augmented_indices_by_class[label])  # Total images in class
    test_size = int(0.1 * total_images_in_class)  # 10% of total images (original + augmented)
    
    # Ensure there are enough original images to select from for the test set
    if test_size > len(original_indices):
        raise ValueError(f"Not enough original images in class {label} for test set.")
    
    random.shuffle(original_indices)  # Shuffle for randomness
    test_indices.extend(original_indices[:test_size])  # First 10% goes to test set
    train_val_indices_by_class[label].extend(original_indices[test_size:])  # Remaining original images

# Step 3: Combine remaining original and augmented images for training/validation split
train_val_indices = []
for label, original_remaining in train_val_indices_by_class.items():
    train_val_indices.extend(original_remaining)  # Add remaining original images
    train_val_indices.extend(augmented_indices_by_class[label])  # Add augmented images for this class

# Step 4: Split the remaining images into training and validation sets (80/20 split)
train_size = int(0.8 * len(train_val_indices))  # 80% for training
val_size = len(train_val_indices) - train_size  # 20% for validation

# Randomly split the combined dataset into train and validation sets
train_indices, val_indices = random_split(train_val_indices, [train_size, val_size])

# Step 5: Create Subsets for train, validation, and test sets
train_dataset = Subset(plastic_dataset, train_indices)
val_dataset = Subset(plastic_dataset, val_indices)
test_dataset = Subset(plastic_dataset, test_indices)  # Test set with only original images

# Step 6: Split train_dataset into 4 clients
client_split_size = len(train_dataset) // 4
remaining = len(train_dataset) - client_split_size * 4
split_sizes = [client_split_size] * 4
split_sizes[-1] += remaining  # Adjust last split for any remainder

client_datasets = random_split(train_dataset, split_sizes)

# Step 7: Create DataLoaders with WeightedRandomSampler for each client
batch_size = 32
client_loaders = []

for client_data in client_datasets:
    # Calculate sample weights for each client
    client_labels = [train_dataset.dataset.labels[idx] for idx in client_data.indices]
    client_class_counts = Counter(client_labels)
    client_class_weights = {cls: 1.0 / client_class_counts[cls] for cls in client_class_counts}
    
    # Create sample weights for each image
    sample_weights = [client_class_weights[label] for label in client_labels]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    
    # Create DataLoader with the sampler
    client_loader = DataLoader(client_data, batch_size=batch_size, sampler=sampler)
    client_loaders.append(client_loader)

# Step 8: Create DataLoaders for validation and test sets without sampling
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Print dataset sizes for verification
print(f'Training set size (with original + augmented): {len(train_dataset)}')
print(f'Validation set size (with original + augmented): {len(val_dataset)}')
print(f'Test set size (original only): {len(test_dataset)}')
print("DataLoader setup complete for 4 clients.")

Training set size (with original + augmented): 14721
Validation set size (with original + augmented): 3681
Test set size (original only): 2040
DataLoader setup complete for 4 clients.


In [5]:
# Print sizes of each client dataset
for i, client_data in enumerate(client_datasets):
    print(f"Client {i+1} dataset size: {len(client_data)}")


Client 1 dataset size: 3680
Client 2 dataset size: 3680
Client 3 dataset size: 3680
Client 4 dataset size: 3681


In [8]:
import torch
import torch.nn as nn
from collections import Counter

# Count the frequency of each class in the entire training dataset
train_labels = [plastic_dataset.labels[i] for i in train_dataset.indices]
train_class_counts = Counter(train_labels)

# Print class counts for debugging
print('Training ClassCounts:', train_class_counts)

# Calculate weights for each class: inverse of frequency for the classes (0 to 6)
class_weights = {cls: 1.0 / count for cls, count in train_class_counts.items() if cls < 7}

# Create the class_weight_tensor based on class weights (for classes 0-6)
class_weight_tensor = torch.tensor(
    [class_weights.get(i, 0.0) for i in range(7)]
).to(device)

# Print class weights for debugging
print("Class Weights:", class_weight_tensor)

# Use CrossEntropyLoss with class weights in each client’s training
criterion = nn.CrossEntropyLoss(weight=class_weight_tensor)

Training ClassCounts: Counter({6: 3102, 1: 2810, 2: 2102, 5: 1862, 0: 1818, 4: 1730, 3: 1297})
Class Weights: tensor([0.0006, 0.0004, 0.0005, 0.0008, 0.0006, 0.0005, 0.0003],
       device='cuda:0')


In [10]:
scaled_class_weight_tensor = class_weight_tensor * 1000
print("Scaled Class Weights:", scaled_class_weight_tensor)

Scaled Class Weights: tensor([0.5501, 0.3559, 0.4757, 0.7710, 0.5780, 0.5371, 0.3224],
       device='cuda:0')


### Normalizing Class Wights

In [None]:
# # Calculate the total sum of class weights
# total_weight = sum(class_weight_tensor.cpu().numpy())

# # Normalize the weights
# normalized_class_weight_tensor = class_weight_tensor / total_weight

# # Convert back to tensor if necessary
# normalized_class_weight_tensor = normalized_class_weight_tensor.to(device)

# # Print the normalized weights for debugging
# print("Normalized Class Weights:", normalized_class_weight_tensor)

# # Use CrossEntropyLoss with class weights
# criterion = nn.CrossEntropyLoss(weight=normalized_class_weight_tensor)

Normalized Class Weights: tensor([0.1531, 0.0990, 0.1359, 0.2139, 0.1608, 0.1486, 0.0887],
       device='cuda:0')


### FL with FedAvg

In [None]:
# import torch
# import torch.optim as optim
# import torch.nn as nn
# from torchvision import models

# # Define the model preparation function for 7 classes (DWRL)
# def prepare_model(num_classes=7):
#     """Load a pre-trained ResNet18 model and modify it for DWRL."""
#     model = models.resnet18(pretrained=True)
#     model.fc = nn.Linear(model.fc.in_features, num_classes)
#     return model

# # Early Stopping Class (unchanged from previous code)
# class EarlyStopping:
#     def __init__(self, patience=5, min_delta=0):
#         self.patience = patience
#         self.min_delta = min_delta
#         self.counter = 0
#         self.best_loss = None
#         self.early_stop = False

#     def __call__(self, val_loss):
#         if self.best_loss is None:
#             self.best_loss = val_loss
#         elif val_loss > self.best_loss - self.min_delta:
#             self.counter += 1
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_loss = val_loss
#             self.counter = 0

# # Federated Training Functions
# def train_client(model, train_loader, criterion, optimizer, epochs=1):
#     model.to(device)
#     model.train()
#     for _ in range(epochs):
#         for inputs, labels in train_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#     return model.state_dict()

# def federated_averaging(state_dicts):
#     avg_state_dict = {}
#     for key in state_dicts[0].keys():
#         avg_state_dict[key] = sum(state_dict[key] for state_dict in state_dicts) / len(state_dicts)
#     return avg_state_dict

# # Federated Training Loop for DWRL with 4 Clients
# def train_federated_model(client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0):
#     model = prepare_model(num_classes=7).to(device)
#     criterion = nn.CrossEntropyLoss(weight=normalized_class_weight_tensor)  # Apply class weights for imbalance handling
#     early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
#     for round in range(num_epochs):
#         print(f"Starting federated learning round {round+1}/{num_epochs}...")
#         state_dicts = []
        
#         # Train each client's model
#         for i, client_loader in enumerate(client_loaders):
#             print(f"Training model for client {i+1}...")
#             client_model = prepare_model(num_classes=7).to(device)
#             client_model.load_state_dict(model.state_dict())
#             optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
#             client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
#             state_dicts.append(client_state_dict)

#         # Federated Averaging
#         avg_state_dict = federated_averaging(state_dicts)
#         model.load_state_dict(avg_state_dict)
#         model.to(device)
        
#         # Validation Phase
#         print("Validating model...")
#         model.eval()
#         val_running_loss = 0.0
#         val_correct = 0
#         val_total = 0
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs, labels = inputs.to(device), labels.to(device)
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
#                 val_running_loss += loss.item()
#                 _, predicted = torch.max(outputs.data, 1)
#                 val_total += labels.size(0)
#                 val_correct += (predicted == labels).sum().item()

#         val_loss = val_running_loss / len(val_loader)
#         val_accuracy = 100 * val_correct / val_total
#         print(f'Federated Round {round+1}/{num_epochs}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

#         # Early stopping
#         early_stopping(val_loss)
#         if early_stopping.early_stop:
#             print("Early stopping")
#             break

#     # Test Phase
#     print("Testing model...")
#     model.eval()
#     test_running_loss = 0.0
#     test_correct = 0
#     test_total = 0
#     with torch.no_grad():
#         for inputs, labels in test_loader:
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             test_running_loss += loss.item()
#             _, predicted = torch.max(outputs.data, 1)
#             test_total += labels.size(0)
#             test_correct += (predicted == labels).sum().item()

#     test_loss = test_running_loss / len(test_loader)
#     test_accuracy = 100 * test_correct / test_total
#     print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

#     # Save the final model
#     model_save_path = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/final_federated_model_DWRL_4_clients.pth'
#     torch.save(model.state_dict(), model_save_path)
#     print(f"Model saved to {model_save_path}")

#     return model, val_loss, val_accuracy, test_loss, test_accuracy

# # Parameters
# num_clients = 4
# num_epochs = 40
# learning_rate = 0.0001
# patience = 5
# min_delta = 0.01

# # Start federated training
# print("Starting federated training for DWRL...")
# model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model(
#     client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate, patience, min_delta)
# print("Federated training complete.")

Starting federated training for DWRL...




Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 1/40, Val Loss: 0.6332, Val Accuracy: 72.40%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 2/40, Val Loss: 0.5697, Val Accuracy: 74.87%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 3/40, Val Loss: 0.4952, Val Accuracy: 77.15%
Starting federated learning round 4/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Validating model...
Federated Round 4/40, Val Loss: 0.4822, Val Accuracy: 78.67%
Starting fed

### FL with FedLAW

In [12]:
# Sample client losses for testing (replace these with realistic values if needed)
client_losses = [1.2, 0.8, 1.5, 1.0]  # Simulated validation losses for 4 clients

# Normalize client losses to use as weights (lower loss = higher weight)
total_loss = sum(client_losses)
weights = [(1 - loss / total_loss) for loss in client_losses]

# Normalize weights to ensure they sum to 1
weights = [w / sum(weights) for w in weights]

# Print results for debugging
print(f"Client Losses: {client_losses}")
print(f"Calculated Weights: {weights}")

Client Losses: [1.2, 0.8, 1.5, 1.0]
Calculated Weights: [0.24444444444444446, 0.2740740740740741, 0.22222222222222224, 0.25925925925925924]


In [13]:
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import models

# Define the model preparation function for 7 classes (DWRL)
def prepare_model(num_classes=7):
    """Load a pre-trained ResNet18 model and modify it for DWRL."""
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Federated Training Function with adaptive weighting for FedLAW
def train_client(model, train_loader, criterion, optimizer, epochs=1):
    model.to(device)
    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()

def federated_adaptive_weighting(state_dicts, weights):
    # Adaptive weighted averaging
    avg_state_dict = {}
    for key in state_dicts[0].keys():
        avg_state_dict[key] = sum(weights[i] * state_dicts[i][key] for i in range(len(state_dicts)))
    return avg_state_dict

# Federated Training Loop with FedLAW for DWRL with 4 Clients
def train_federated_model_with_FedLAW(client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate=0.0001, patience=5, min_delta=0):
    model = prepare_model(num_classes=7).to(device)
    criterion = nn.CrossEntropyLoss(weight=scaled_class_weight_tensor)  # Apply class weights for imbalance handling
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
    
    for round in range(num_epochs):
        print(f"Starting federated learning round {round+1}/{num_epochs}...")
        state_dicts = []
        client_losses = []
        
        # Train each client's model
        for i, client_loader in enumerate(client_loaders):
            print(f"Training model for client {i+1}...")
            client_model = prepare_model(num_classes=7).to(device)
            client_model.load_state_dict(model.state_dict())
            optimizer = optim.Adam(client_model.parameters(), lr=learning_rate)
            client_state_dict = train_client(client_model, client_loader, criterion, optimizer)
            state_dicts.append(client_state_dict)

            # Calculate validation loss to determine adaptive weight
            client_model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = client_model(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
            val_loss /= len(val_loader)
            client_losses.append(val_loss)

        # Normalize client losses to use as weights (lower loss = higher weight)
        # total_loss = sum(client_losses)
        # weights = [(1 - loss / total_loss) for loss in client_losses]
        
       # Normalize client losses to use as weights (lower loss = higher weight)
        total_loss = sum(client_losses)
        weights = [(1 - loss / total_loss) for loss in client_losses]

        # Normalize weights to ensure they sum to 1
        weights = [w / sum(weights) for w in weights]
        print(f"Round {round+1} Client Weights: {weights}")

        
        # Perform Federated Adaptive Weighted Averaging (FedLAW)
        avg_state_dict = federated_adaptive_weighting(state_dicts, weights)
        model.load_state_dict(avg_state_dict)
        model.to(device)
        
        # Validation Phase
        print("Validating model...")
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_running_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        print(f'Federated Round {round+1}/{num_epochs}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Test Phase
    print("Testing model...")
    model.eval()
    test_running_loss = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_loss = test_running_loss / len(test_loader)
    test_accuracy = 100 * test_correct / test_total
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    # Save the final model
    model_save_path = '/raid/home/somayeh.shami/project/somayeh_workspace/federated_learning/models/final_federated_model_DWRL_FedLAW.pth'
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    return model, val_loss, val_accuracy, test_loss, test_accuracy

# Parameters
num_clients = 4
num_epochs = 40
learning_rate = 0.0001
patience = 5
min_delta = 0.01

# Start federated training with FedLAW
print("Starting federated training for DWRL with FedLAW...")
model, val_loss, val_accuracy, test_loss, test_accuracy = train_federated_model_with_FedLAW(
    client_loaders, val_loader, test_loader, num_clients, num_epochs, learning_rate, patience, min_delta)
print("Federated training complete.")

Starting federated training for DWRL with FedLAW...
Starting federated learning round 1/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Round 1 Client Weights: [0.2450785343709934, 0.2551459911047975, 0.25552735733174214, 0.244248117192467]
Validating model...
Federated Round 1/40, Val Loss: 0.6588, Val Accuracy: 70.99%
Starting federated learning round 2/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Round 2 Client Weights: [0.25128937026691145, 0.25101796629929657, 0.24709984520386205, 0.25059281822992996]
Validating model...
Federated Round 2/40, Val Loss: 0.5580, Val Accuracy: 74.98%
Starting federated learning round 3/40...
Training model for client 1...
Training model for client 2...
Training model for client 3...
Training model for client 4...
Round 3 Client Weights: [0.25595510152336204, 0.250070379311115, 0.2540091