In [None]:
from __future__ import print_function
from __future__ import print_function
from __future__ import division
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torchvision
import torch.utils.data as data
from torchsummary import summary
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from PIL import Image
from pathlib import Path
import os.path
import matplotlib.pyplot as plt
import csv
from torchvision import models
import time
import os
import copy
import scipy.io
import numpy as np
import math
import glob as glob
import pandas as pd
from collections import defaultdict
import random
import warnings
from collections import Counter

warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Seed every RNG the pipeline touches (Python, NumPy, PyTorch CPU and the
# current CUDA device) so shuffling and weight initialisation are repeatable.
random_seed = 42

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)  # seeds the current GPU only

# Trade cuDNN autotuning speed for run-to-run reproducible convolutions.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Hyperparameters

In [None]:
root = './patches'

# Set to False when there is no server-side "initial" split; the pre-training
# stage and its CSV log are then skipped.
global_available = True  # If not available, then make it false

# Hyper-parameters
batchsize = 128
ROUNDS = 100          # federated rounds (local client training + aggregation)
CLIENT_EPOCHS = 1     # local epochs per client per round
GLOBAL_EPOCHS = 30    # epochs of server-side training on the initial split
learning_rate = 0.001

# Aggregation algorithm name: used in the checkpoint filename and passed to
# clientTrain. It was referenced but never defined (NameError).
agg_algo = "FedAvg"

use_cuda = torch.cuda.is_available()
# Fall back to CPU rather than failing at the first .to(device) on CPU-only hosts.
device = torch.device('cuda:0' if use_cuda else 'cpu')  # Update according to your requirement

# Extra printing param
prompt = True

In [None]:
# Run identifiers come from the working-directory layout .../<dataset>/<method>.
# pathlib handles the OS path separator (splitting on '/' breaks on Windows).
cwd = os.getcwd()
dataset = Path(cwd).parent.name
method = Path(cwd).name

if global_available:
    initial_datapath = os.path.join(root, 'initial')  # server-side pre-training split

clients_path = os.path.join(root, 'clients')  # one sub-directory per client
test_datapath = os.path.join(root, 'test')

# Checkpoint path shared by the pre-training and the federated stages.
model_name = agg_algo + "_" + dataset + "_" + method + ".pth"

if global_available:
    preclient_csv = 'preClientTrain_' + method + '.csv'

postclient_csv = 'postClientTrain_' + method + '.csv'

DATA LOADER

In [None]:
# Accepted file extensions. Matching is case-sensitive, so upper- and
# lower-case spellings are listed explicitly.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.mat',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def find_classes(dir):
    """List the class sub-directories of ``dir`` and map each to an integer label.

    Only directories count as classes. A stray file such as ``.DS_Store`` would
    otherwise become a class: it would shift every label index, inflate the
    class count, and disagree with ``make_dataset``, which skips non-directories.

    Args:
        dir: dataset root whose immediate sub-directories are class names.

    Returns:
        (classes, class_to_idx): sorted class names and a name -> index dict.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

In [None]:
def make_dataset(dir, class_to_idx):
    """Collect ``(relative_path, class_index)`` pairs for every image under ``dir``.

    Class folders and files are visited in sorted order, so the sample list is
    identical across runs and filesystems. ``os.listdir`` order is arbitrary,
    which undermines the fixed random seeds.

    Args:
        dir: dataset root containing one sub-directory per class.
        class_to_idx: mapping from class directory name to integer label.

    Returns:
        list of ``("<class>/<file>", label)`` tuples; paths are relative to ``dir``.
    """
    images = []
    for target in sorted(os.listdir(dir)):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue

        for filename in sorted(os.listdir(class_dir)):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                images.append((path, class_to_idx[target]))

    return images

def default_loader(path):
    """Open the image at ``path`` and return it converted to an RGB PIL image."""
    with Image.open(path) as img:
        return img.convert('RGB')

def mat_loader(path):
    """Load the MATLAB ``.mat`` file at ``path`` and return scipy's variable dict.

    The body previously referenced an undefined ``path1`` and always raised
    NameError.
    """
    return scipy.io.loadmat(path)

In [None]:
# Class names/indices come from the test split; `class_to_idx` also sizes the
# model's output layer.
classes, class_to_idx = find_classes(test_datapath)
# Distinct name: `dataset` already holds the dataset-name string derived from
# the working directory.
test_samples = make_dataset(test_datapath, class_to_idx)
if prompt:
    print("Dataset Structure(first 5):\n", test_samples[:5])
    print("\nLength of Dataset: ", len(test_samples))
    print("\nClass Mapping:\n", class_to_idx)

In [None]:
class ImageFolderLoader(data.Dataset):
    """Folder-per-class dataset yielding an image view and an FFT view of each sample.

    Expected layout: ``<root1>/<class_name>/<image file>``. Items are
    ``(image, fft_magnitude, label)``, or
    ``(image, fft_magnitude, label, filename_stem)`` when ``filename_return``.

    Note: ``__getitem__`` stores the current sample on ``self.img_rgb`` and
    ``self.img_noise`` before returning it.
    """

    def __init__(self, root1, transform_1=None,
                 target_transform=None,
                 loader=default_loader, filename_return=False):
        """
        Args:
            root1: dataset root directory.
            transform_1: optional callable applied to the loaded image.
            target_transform: optional callable applied to the integer label.
            loader: callable ``path -> image`` (defaults to ``default_loader``).
            filename_return: also return the file name without extension.
        """
        self.root1 = root1
        self.classes1, self.class_to_idx1 = find_classes(root1)
        self.imgs1 = make_dataset(root1, self.class_to_idx1)

        self.transform_1 = transform_1
        self.target_transform = target_transform
        self.loader = loader
        self.filename_return = filename_return

        # Per-sample scratch state written by __getitem__ / FFT.
        self.img_rgb = None
        self.img_noise = None

    def FFT(self):
        """Store the centred log-magnitude spectrum of ``self.img_rgb`` in ``self.img_noise``.

        Every channel becomes ``20 * log(|fftshift(fft2(channel))| + 1e-8)``
        (the epsilon avoids log(0)); channels are stacked into a float32 tensor
        of shape (C, H, W). A 2-D input is treated as a single channel.
        """
        pixels = np.array(self.img_rgb)
        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]

        spectra = [
            20 * np.log(np.abs(np.fft.fftshift(np.fft.fft2(pixels[:, :, ch]))) + 1e-8)
            for ch in range(pixels.shape[2])
        ]
        self.img_noise = torch.tensor(np.stack(spectra, axis=0), dtype=torch.float32)

    def __getitem__(self, index):
        """Load sample ``index``; return its image view, FFT view and label."""
        rel_path, label = self.imgs1[index]
        self.img_rgb = self.loader(os.path.join(self.root1, rel_path))
        # The spectrum is taken from the untransformed image, so transform_1
        # does not affect the FFT view.
        self.FFT()

        image = self.img_rgb
        if self.transform_1 is not None:
            image = self.transform_1(self.img_rgb)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if not self.filename_return:
            return image, self.img_noise, label
        return image, self.img_noise, label, Path(rel_path).stem

    def __len__(self):
        """Number of samples found under ``root1``."""
        return len(self.imgs1)


# Only ToTensor is applied: no resizing, cropping or normalisation of the
# image view.
data_transforms = transforms.Compose([transforms.ToTensor()])

In [None]:
def _make_loader(samples, shuffle):
    """Wrap ``samples`` in a DataLoader using the notebook batch size and 8 workers."""
    return torch.utils.data.DataLoader(
        samples, batch_size=batchsize, shuffle=shuffle, num_workers=8
    )


if global_available:
    initial_dataset = ImageFolderLoader(initial_datapath, transform_1=data_transforms)
    initial_loader = _make_loader(initial_dataset, shuffle=True)

# One loader per client sub-directory. Stray files are skipped because
# ImageFolderLoader expects a folder-per-class root.
clients_loader = []
clients = sorted(
    name for name in os.listdir(clients_path)
    if os.path.isdir(os.path.join(clients_path, name))
)
for client in clients:
    clt_path = os.path.join(clients_path, client)
    client_dataset = ImageFolderLoader(clt_path, transform_1=data_transforms)
    client_loader = _make_loader(client_dataset, shuffle=True)
    clients_loader.append(client_loader)

# Shuffling adds nothing to evaluation; a fixed order keeps metrics reproducible.
test_dataset = ImageFolderLoader(test_datapath, transform_1=data_transforms)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Per-client sample counts; fedAvg weights each client by its share of the total.
client_images = [len(loader.dataset) for loader in clients_loader]
total_client_images = sum(client_images)
# NOTE: `clients` is rebound from the list of client folder names to the
# number of clients.
clients = len(clients_loader)

# fedAvg divides by this total; fail here with a clear message instead.
if total_client_images == 0:
    raise ValueError(f"No client images found under {clients_path!r}")

if prompt:
    print("Total client images: ", total_client_images)
    print("Images per client:", client_images)

MODEL BUILDING

In [None]:
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Average- and max-pooled channel descriptors pass through a shared
    bottleneck of 1x1 convolutions. Their sum goes through a sigmoid, giving a
    (B, C, 1, 1) weight for a (B, C, H, W) input.

    Args:
        in_planes: number of input channels C.
        reduction: bottleneck reduction ratio. The hidden width is
            ``max(1, in_planes // reduction)``, so inputs with fewer than
            ``reduction`` channels still get a non-empty bottleneck instead of a
            zero-channel convolution.
    """

    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        hidden = max(1, in_planes // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False)
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel weights in [0, 1] for ``x``."""
        avg_out = self.shared(self.avg_pool(x))
        max_out = self.shared(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    The channel-wise mean map and max map are concatenated (mean first),
    convolved down to one channel and passed through a sigmoid. This gives a
    (B, 1, H, W) weight map for a (B, C, H, W) input.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # kernel_size // 2 padding keeps H and W unchanged for odd kernels.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    The input is rescaled by ChannelAttention weights, and that result is
    rescaled again by SpatialAttention weights computed from it.
    """

    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = x * self.ca(x)
        return channel_refined * self.sa(channel_refined)

class Net(nn.Module):
    """Two-stream MobileNetV2 classifier with CBAM attention on each stream.

    In Train/Test, stream 1 receives the image view and stream 2 the FFT
    log-magnitude view from ImageFolderLoader. Each stream runs MobileNetV2
    features, then CBAM, then global average pooling, producing a 1280-d
    vector. The two vectors are concatenated and passed through two linear
    layers.

    Args:
        num_classes: number of output classes; defaults to
            ``len(class_to_idx)`` (the notebook-level class mapping).
        pretrained: load ImageNet weights for both backbones (default True).

    ``forward(x1, x2)`` returns ``(log_probs, logits)``.
    """

    def __init__(self, num_classes=None, pretrained=True):
        super(Net, self).__init__()
        if num_classes is None:
            num_classes = len(class_to_idx)

        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`; kept for compatibility with older versions.
        mobilenet1 = models.mobilenet_v2(pretrained=pretrained)
        mobilenet2 = models.mobilenet_v2(pretrained=pretrained)

        # Keep only the convolutional feature extractors (classifier dropped).
        self.feature1 = mobilenet1.features
        self.feature2 = mobilenet2.features

        # The MobileNetV2 feature extractor ends with 1280 channels.
        self.cbam1 = CBAM(1280)
        self.cbam2 = CBAM(1280)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # -> [B, 1280, 1, 1]

        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # layers compose to a single linear map; confirm whether a ReLU was intended.
        self.fc1 = nn.Linear(1280 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x1, x2):
        """Classify a batch given its two views.

        Returns:
            (log-softmax scores, raw logits), each of shape [B, num_classes].
        """
        x1 = self.pool(self.cbam1(self.feature1(x1))).flatten(1)  # [B, 1280]
        x2 = self.pool(self.cbam2(self.feature2(x2))).flatten(1)  # [B, 1280]

        x = torch.cat((x1, x2), dim=1)  # [B, 2560]
        out_fc = self.fc2(self.fc1(x))  # [B, num_classes]
        return self.logsoftmax(out_fc), out_fc

In [None]:
def custom_categorical_cross_entropy(y_pred, y_true):
    """Mean categorical cross-entropy between predicted probabilities and targets.

    Args:
        y_pred: predicted probabilities (not log-probabilities); assumed
            (N, C), with classes on dim 1.
        y_true: target distribution / one-hot labels, same shape as ``y_pred``.

    Returns:
        Scalar tensor ``mean_n(-sum_c y_true * log(clamp(y_pred)))``.
    """
    # Clamping keeps log() finite when a probability is exactly 0.
    probs = y_pred.clamp(min=1e-9, max=1 - 1e-9)
    per_sample = -(y_true * probs.log()).sum(dim=1)
    return per_sample.mean()

In [None]:
def modelInit():
    """Build a fresh ``Net`` and return it on the configured ``device``."""
    return Net().to(device)

In [None]:
model = modelInit()
# Net returns log-softmax scores, so NLLLoss yields the usual cross-entropy.
criterion = nn.NLLLoss()
optimizerGlobal = torch.optim.Adam(model.parameters(), lr=learning_rate)

if prompt:
    print("Device:- ", device)
    print(model)

TRAINING AND TESTING

In [None]:
def Test(model, test_loader, return_result=False, verbose=False):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Batches must be ``(view1, view2, labels)`` 3-tuples; the model is called as
    ``model(view1, view2)`` and must return ``(log_probs, logits)``. Uses the
    notebook-level ``criterion`` (assumed reduction='mean', as constructed
    above) and ``device``.

    Args:
        model: network to evaluate; it is left in eval mode.
        test_loader: evaluation DataLoader.
        return_result: return the metrics dict instead of None.
        verbose: print batch progress (default False, i.e. silent as before;
            the old body disabled this via a local ``prompt=False``).

    Returns:
        ``{'accuracy': percent, 'loss': per-sample mean loss}`` if
        ``return_result``, else None. Both metrics are 0.0 for an empty loader.
    """
    model.eval()

    total_batches = len(test_loader)
    loss_sum = 0.0
    total_correct_test = 0
    seen = 0

    with torch.no_grad():
        for batch_idx, (imgs1, imgs2, labels1) in enumerate(test_loader):
            img_org, mat_img, target = imgs1.to(device), imgs2.to(device), labels1.to(device)
            output, _ = model(img_org, mat_img)

            batch_size = target.size(0)
            # criterion averages over the batch; weight by batch size so a
            # short final batch does not skew the overall mean.
            loss_sum += criterion(output, target).item() * batch_size
            seen += batch_size

            _, predicted = torch.max(output, 1)
            total_correct_test += (predicted == target).sum().item()

            if verbose and (batch_idx % 10 == 0 or batch_idx == total_batches - 1):
                print(f"Processed {batch_idx + 1}/{total_batches} batches.")

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = loss_sum / seen if seen else 0.0
    accuracy = 100.0 * total_correct_test / seen if seen else 0.0

    print(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')

    if return_result:
        return {'accuracy': accuracy, 'loss': avg_loss}
    return None

In [None]:
def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy %)."""
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for imgs1, imgs2, labels1 in loader:
        img_org, mat_img, target = imgs1.to(device), imgs2.to(device), labels1.to(device)

        optimizer.zero_grad()
        output, _ = model(img_org, mat_img)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # criterion averages over the batch; weight by batch size so a short
        # final batch does not skew the epoch mean.
        loss_sum += loss.item() * batch_size
        _, predicted = torch.max(output.detach(), 1)
        correct += (predicted == target).sum().item()
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, 100.0 * correct / seen


def Train(model, train_data, optimizer, valid_data=None, epochs=1, return_result=True,
          overwrite_model=True, save_path='best_model.pth', csv_file=None,
          global_model=None):
    """Train ``model`` for ``epochs`` epochs, optionally validating and checkpointing.

    Batches must be ``(view1, view2, labels)``; the model is called as
    ``model(view1, view2)`` and must return ``(log_probs, logits)``. Loss comes
    from the notebook-level ``criterion``; tensors are moved to ``device``.

    Args:
        model: network, trained in place.
        train_data: training DataLoader.
        optimizer: optimizer over ``model``'s parameters.
        valid_data: optional DataLoader, evaluated with ``Test`` after each epoch.
        epochs: number of passes over ``train_data``.
        return_result: return the history dict.
        overwrite_model: save ``model.state_dict()`` to ``save_path`` whenever
            validation accuracy reaches a new best.
        save_path: checkpoint path.
        csv_file: if given, append one metrics row per epoch. A header is
            written only if the file did not exist when training started.
        global_model: unused; kept for call compatibility.

    Returns:
        When ``return_result``: a dict of per-epoch ``train_loss`` and
        ``train_accuracy`` lists, plus ``valid_loss`` and ``valid_accuracy``
        lists (None without ``valid_data``).
    """
    train_loss_history = []
    valid_loss_history = []
    train_accuracy_history = []
    valid_accuracy_history = []
    # Start below any reachable accuracy so the first validated epoch is always
    # checkpointed. With 0, nothing was saved if accuracy stayed at 0%, and a
    # later torch.load(save_path) would fail.
    best_accuracy = float('-inf')
    has_valid = valid_data is not None
    csv_exists = bool(csv_file) and os.path.isfile(csv_file)

    for epoch in range(epochs):
        avg_train_loss, train_accuracy = _train_one_epoch(model, train_data, optimizer)
        train_loss_history.append(avg_train_loss)
        train_accuracy_history.append(train_accuracy)
        print(f"\nEpoch [{epoch + 1}/{epochs}] - Average Training Loss: {avg_train_loss:.4f}, "
              f"Training Accuracy: {train_accuracy:.2f}%")

        avg_valid_loss = None
        valid_accuracy = None
        if has_valid:
            validation_results = Test(model, valid_data, return_result=True)
            avg_valid_loss = validation_results['loss']
            valid_accuracy = validation_results['accuracy']
            valid_loss_history.append(avg_valid_loss)
            valid_accuracy_history.append(valid_accuracy)

            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                if overwrite_model:
                    torch.save(model.state_dict(), save_path)
                    print(f"New best model saved with accuracy: {valid_accuracy:.2f}%\n")

        if csv_file:
            with open(csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                if not csv_exists:
                    writer.writerow(['Epoch', 'Train Loss', 'Train Accuracy', 'Valid Loss', 'Valid Accuracy'])
                    csv_exists = True  # header written once per file
                writer.writerow([epoch + 1, avg_train_loss, train_accuracy,
                                 avg_valid_loss, valid_accuracy])

    if return_result:
        return {
            "train_loss": train_loss_history,
            "train_accuracy": train_accuracy_history,
            "valid_loss": valid_loss_history if has_valid else None,
            "valid_accuracy": valid_accuracy_history if has_valid else None,
        }
    return None

In [None]:
def fedAvg(global_model, client_models, client_images=None, total_client_images=None):
    """Federated averaging: load the sample-weighted mean of client states into ``global_model``.

    Client ``i`` gets weight ``client_images[i] / total_client_images``.
    Floating-point tensors are averaged directly. Non-floating tensors (e.g.
    BatchNorm's integer ``num_batches_tracked``) are averaged in float64,
    rounded and cast back. Summing them, as before, multiplied such counters by
    the number of clients every round.

    Args:
        global_model: model whose state is overwritten in place.
        client_models: locally trained models with the same architecture.
        client_images: per-client sample counts aligned with ``client_models``.
            Defaults to the notebook-level ``client_images``, looked up at call
            time rather than frozen at definition time.
        total_client_images: weight normaliser; defaults to
            ``sum(client_images[:len(client_models)])``.

    Raises:
        ValueError: if ``client_models`` is empty or the total is not positive.
    """
    if not client_models:
        raise ValueError("fedAvg needs at least one client model")
    if client_images is None:
        client_images = globals()["client_images"]
    if total_client_images is None:
        total_client_images = sum(client_images[:len(client_models)])
    if total_client_images <= 0:
        raise ValueError("total_client_images must be positive")

    weights = [client_images[i] / total_client_images for i in range(len(client_models))]
    client_states = [client.state_dict() for client in client_models]

    avg_state_dict = {}
    for key, ref in global_model.state_dict().items():
        if ref.is_floating_point():
            merged = torch.zeros_like(ref)
            for weight, state in zip(weights, client_states):
                merged += state[key] * weight
        else:
            acc = torch.zeros_like(ref, dtype=torch.float64)
            for weight, state in zip(weights, client_states):
                acc += state[key].to(torch.float64) * weight
            merged = torch.round(acc).to(ref.dtype)
        avg_state_dict[key] = merged

    global_model.load_state_dict(avg_state_dict)

In [None]:
def clientTrain(global_model, clients_loader, rounds=1, client_epochs=1, method="FedAvg", return_result=True, save_path='best_model.pth', csv_file=None):
    """Run federated training rounds starting from ``global_model``.

    Each round proceeds as follows:
      1. Every client starts from a copy of the current global weights and
         trains for ``client_epochs`` epochs with ``Train``.
      2. The client models are merged into ``global_model`` by ``fedAvg``,
         weighted by the notebook-level ``client_images`` /
         ``total_client_images``. These are assumed to be aligned with
         ``clients_loader``.
      3. The merged model is evaluated on the notebook-level ``test_loader``
         and saved to ``save_path`` when its accuracy is the best in this call.

    NOTE: the first round always overwrites ``save_path`` (the best starts at
    None), even if a better pre-federation checkpoint was stored there.

    Args:
        global_model: model updated in place each round.
        clients_loader: list of per-client training DataLoaders.
        rounds: number of communication rounds.
        client_epochs: local epochs per client per round.
        method: aggregation name; only FedAvg is implemented and the value is
            otherwise unused.
        return_result: return the validation history.
        save_path: checkpoint path for the best global model.
        csv_file: if given, append ``Round, Validation Loss, Validation Accuracy``
            rows; the header is written only if the file did not already exist.

    Returns:
        ``{"valid_loss": [...], "valid_accuracy": [...]}`` (one entry per
        round) when ``return_result``, else None.
    """
    acc = []
    loss = []
    best_acc = None
    csv_exists = bool(csv_file) and os.path.isfile(csv_file)

    for rnd in range(rounds):  # `rnd`, not `round`, to avoid shadowing the builtin
        if prompt:
            print("Round: ", (rnd + 1))
        client_models = []

        # Iterate over the loaders actually passed in, not the global client count.
        for i, client_loader in enumerate(clients_loader):
            if prompt:
                print("Client: ", (i + 1))
            # deepcopy reproduces the global weights without rebuilding Net,
            # which would reload two pretrained MobileNetV2 backbones each time.
            modelClient = copy.deepcopy(global_model)
            modelClient.zero_grad(set_to_none=True)  # don't carry copied grads
            optimizerClient = torch.optim.Adam(modelClient.parameters(), lr=learning_rate)

            Train(modelClient, client_loader, optimizer=optimizerClient,
                  epochs=client_epochs, overwrite_model=False)
            client_models.append(modelClient)

        fedAvg(global_model, client_models, client_images=client_images, total_client_images=total_client_images)

        result = Test(global_model, test_loader, return_result=True)
        acc.append(result['accuracy'])
        loss.append(result['loss'])

        if best_acc is None or acc[-1] > best_acc:
            best_acc = acc[-1]
            torch.save(global_model.state_dict(), save_path)

        if csv_file:
            with open(csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                if not csv_exists:
                    writer.writerow(['Round', 'Validation Loss', 'Validation Accuracy'])
                    csv_exists = True  # header written once per file
                writer.writerow([rnd + 1, result['loss'], result['accuracy']])

    if return_result:
        return {"valid_loss": loss, "valid_accuracy": acc}
    return None

In [None]:
if global_available:
    # Server-side pre-training on the initial split. The best epoch (by test
    # accuracy) is checkpointed to model_name and logged to preclient_csv.
    history = Train(
        model,
        train_data=initial_loader,
        optimizer=optimizerGlobal,
        valid_data=test_loader,
        epochs=GLOBAL_EPOCHS,
        save_path=model_name,
        csv_file=preclient_csv,
    )

In [None]:
# Resume from the best pre-training checkpoint. It exists only when the
# pre-training stage ran (global_available); otherwise federated training starts
# from the freshly built model. map_location loads the checkpoint onto `device`
# regardless of where it was saved.
if global_available:
    model.load_state_dict(torch.load(model_name, map_location=device))

In [None]:
# Federated stage: ROUNDS rounds of local client training plus aggregation;
# the best global model (by test accuracy) is kept in model_name.
history2 = clientTrain(
    model,
    clients_loader,
    rounds=ROUNDS,
    client_epochs=CLIENT_EPOCHS,
    method=agg_algo,
    save_path=model_name,
    csv_file=postclient_csv,
)

FINDING PLA AND ILA (PATCH-LEVEL AND IMAGE-LEVEL ACCURACY)

In [None]:
# Reload the best global model saved during the federated rounds onto `device`.
model.load_state_dict(torch.load(model_name, map_location=device))

In [None]:
# Evaluation dataset that also returns each patch's filename stem, which is
# needed to group patches back into whole images for the image-level accuracy.
val_dataset = ImageFolderLoader(
    test_datapath,
    data_transforms,
    filename_return=True,
)

# NOTE: this rebinds `test_loader` to a loader yielding 4-tuples
# (view1, view2, label, filename). Test/Train/clientTrain expect 3-tuples, so
# re-run the data-loader cell before training again.
# The model runs in eval mode during scoring, so batching only affects speed
# (batch_size=1 was needlessly slow).
test_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batchsize,
    shuffle=False, num_workers=12,
)

In [None]:
def patch_level_accuracy(test_loader, model, result_folder='Results'):
    """Predict every patch, write per-class prediction CSVs and return patch-level accuracy.

    Args:
        test_loader: iterable of ``(view1, view2, targets, filenames)`` batches,
            e.g. an ImageFolderLoader with ``filename_return=True``.
        model: two-input model returning ``(scores, _)``; the argmax of
            ``scores`` is the predicted class.
        result_folder: directory (created if missing). It receives one
            ``class_<target>.csv`` per target class, with columns
            ``Filename, Target Class, Predicted Class``.

    Returns:
        A tuple ``(pla, predicted_data)``:
          - pla: patch-level accuracy in percent (0 when there are no samples).
          - predicted_data: a list of ``(filename, target, predicted)`` tuples,
            grouped by target class.
    """
    model.eval()
    device = next(model.parameters()).device
    os.makedirs(result_folder, exist_ok=True)

    class_data = defaultdict(list)
    total_correct = 0
    total_samples = 0
    num_batches = len(test_loader)
    progress_interval = max(num_batches // 10, 1)  # report roughly every 10%

    with torch.no_grad():
        for i, (imgs1, imgs2, targets, filenames) in enumerate(test_loader):
            output, _ = model(imgs1.to(device), imgs2.to(device))
            # One device->host transfer per batch instead of per-sample .item().
            predicted = torch.max(output, 1)[1].cpu().tolist()
            target_list = targets.tolist()

            for filename, target, pred in zip(filenames, target_list, predicted):
                class_data[target].append({
                    'Filename': filename,
                    'Target Class': target,
                    'Predicted Class': pred,
                })
                total_correct += int(target == pred)
                total_samples += 1

            if (i + 1) % progress_interval == 0 or i == num_batches - 1:
                print(f"Progress Done (PLA): {((i + 1) / num_batches) * 100:.1f}%")

    for target, entries in class_data.items():
        pd.DataFrame(entries).to_csv(
            os.path.join(result_folder, f'class_{target}.csv'), index=False
        )

    # Named `pla` so the local does not shadow this function's own name.
    pla = (total_correct / total_samples) * 100 if total_samples > 0 else 0
    print(f"Patch Level Accuracy: {pla:.2f}%")

    predicted_data = [
        (entry['Filename'], entry['Target Class'], entry['Predicted Class'])
        for entries in class_data.values()
        for entry in entries
    ]
    return pla, predicted_data

In [None]:
def image_level_predictions(directory):
    """Majority-vote patch predictions into image predictions and return image-level accuracy.

    Reads every ``class_*.csv`` in ``directory`` (in sorted order). The first
    three columns are filename, target and predicted class. Patches are grouped
    into images by the filename part before the last ``_``; a name without
    ``_`` is its own image.

    Each image's prediction is its most common patch prediction, with ties
    going to the prediction seen first. Results are written to
    ``Image Level Prediction.csv`` in ``directory``.

    Returns:
        Image-level accuracy in percent. Returns 0.0 when no images are found,
        matching ``patch_level_accuracy``, instead of dividing by zero.
    """
    csv_files = sorted(
        f for f in os.listdir(directory)
        if f.startswith("class_") and f.endswith(".csv")
    )

    total_images = 0
    correct_images = 0
    predictions = []  # rows of [image id, target, voted class]

    for file in csv_files:
        df = pd.read_csv(os.path.join(directory, file))

        # image id -> (target of its first patch, patch predictions). Dict
        # insertion order keeps Counter's tie-breaking stable.
        votes = {}
        for filename, target_class, predicted_class in df.iloc[:, :3].itertuples(index=False, name=None):
            image_id = str(filename).rsplit("_", 1)[0]
            votes.setdefault(image_id, (target_class, []))[1].append(predicted_class)

        for image_id, (target_class, patch_preds) in votes.items():
            voted_class = Counter(patch_preds).most_common(1)[0][0]
            if voted_class == target_class:
                correct_images += 1
            predictions.append([image_id, target_class, voted_class])
        total_images += len(votes)

    output_file = os.path.join(directory, "Image Level Prediction.csv")
    df_predictions = pd.DataFrame(predictions, columns=["Filename", "Target", "Predictions"])
    df_predictions.to_csv(output_file, index=False)

    if total_images == 0:
        return 0.0
    return (correct_images / total_images) * 100

In [None]:
result_folder = "Results"  # receives the per-class patch CSVs and the image-level CSV

In [None]:
pla, predicted_data = patch_level_accuracy(
    test_loader, model, result_folder=result_folder
)

In [None]:
print("Patch-level accuracy (%):", pla)

In [None]:
result_path = f'./{result_folder}'
ila = image_level_predictions(result_path)

In [None]:
print("Image-level accuracy (%):", ila)