In [None]:
import os
import shutil
import random
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os.path
import glob
from pathlib import Path
import sys
import seaborn as sns
import pandas as pd

# Seed both Python's and NumPy's global random generators so that the
# shuffles and random splits performed later are repeatable across runs.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)

In [None]:
datapath = r'./data'  # Root of the dataset; it must contain 'train' and 'test' sub-directories, each holding one folder per class

File Separation Code

In [None]:
def min_images_class(dir):
    """Return the smallest entry count among the class sub-directories of `dir`.

    Args:
        dir: Path of a directory that holds one sub-directory per class.

    Returns:
        The minimum of ``len(os.listdir(class_dir))`` over every
        sub-directory of `dir`, or 0 when `dir` has no sub-directories.
    """
    counts = []
    for cl in os.listdir(dir):
        cl_path = os.path.join(dir, cl)
        # Stray files sitting next to the class folders are not classes, and
        # calling os.listdir on them would raise NotADirectoryError.
        if not os.path.isdir(cl_path):
            continue
        counts.append(len(os.listdir(cl_path)))

    # min() of an empty sequence raises ValueError; report 0 instead.
    return min(counts, default=0)

In [None]:
def files_trace(dir):
    """Build per-class bookkeeping dictionaries for the entries of `dir`.

    Args:
        dir: Path of a directory whose entries are class folders.

    Returns:
        A pair ``(used, remaining)`` of dictionaries keyed by class name in
        sorted order: ``used`` starts every class at 0 and ``remaining``
        holds the number of entries found in that class folder.
    """
    class_names = sorted(os.listdir(dir))
    remaining = {
        name: len(os.listdir(os.path.join(dir, name)))
        for name in class_names
    }
    used = dict.fromkeys(class_names, 0)
    return used, remaining

In [None]:
def random_split(total_images, num_clients):
    """Randomly partition `total_images` items into `num_clients` counts.

    ``num_clients - 1`` cut points are drawn with ``random.randint`` from the
    inclusive range ``0..total_images`` (using the global `random` state).
    The gaps between consecutive sorted cut points are the per-client
    counts, so individual clients may receive 0 items.

    Args:
        total_images: Number of items to distribute (must be >= 0).
        num_clients: Number of shares to produce (must be >= 1).

    Returns:
        A list of `num_clients` non-negative integers summing to
        `total_images`.

    Raises:
        ValueError: If `num_clients` is smaller than 1 or `total_images` is
            negative.
    """
    # Without this guard num_clients == 0 would yield a one-element list,
    # i.e. a share for a client that does not exist.
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")
    if total_images < 0:
        raise ValueError(f"total_images must be non-negative, got {total_images}")

    # Draw the cut points in the same order as a plain loop would, then sort.
    cuts = sorted(random.randint(0, total_images) for _ in range(num_clients - 1))

    # Bracket the cut points with 0 and the total; adjacent differences are
    # the share sizes (the last client takes whatever remains).
    bounds = [0] + cuts + [total_images]
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]

In [None]:
def _list_image_files(directory):
    """Return the sorted names in `directory` that end with 'g' or 'G'.

    NOTE(review): the suffix test presumably targets .jpg/.jpeg/.png in either
    case, but it also accepts any other name ending in 'g' -- confirm.
    """
    return sorted(f for f in os.listdir(directory) if f.endswith(('g', 'G')))


def _assign_files(client_data, client_info, client_name, cl, class_dir, filenames):
    """Give `filenames` of class `cl` to one client and record how many were given."""
    client_data[client_name].setdefault(cl, []).extend(
        [os.path.join(class_dir, name) for name in filenames]
    )
    client_info.append((client_name, cl, len(filenames)))


def _assign_by_counts(client_data, client_info, client_names, counts, cl, class_dir, filenames):
    """Hand consecutive slices of `filenames` to `client_names`, sized by `counts`."""
    start = 0
    for client_name, count in zip(client_names, counts):
        end = start + count
        _assign_files(client_data, client_info, client_name, cl, class_dir, filenames[start:end])
        start = end


def files_split(parent_dir, number_clients=5, initial_ratio=0.06, dist_type='equal'):
    """Split the training images of every class among simulated clients.

    Nothing is copied on disk: the function only builds lists of file paths.

    Args:
        parent_dir: Dataset root containing 'train' and 'test' folders, each
            with one sub-directory per class.
        number_clients: Number of clients for 'equal' and 'random'. Ignored
            for 'skew', where one client is created per class.
        initial_ratio: Fraction of the smallest class size (as reported by
            `min_images_class`) taken from each class as initial global data.
        dist_type: One of
            'equal'  - each class is divided as evenly as possible;
            'random' - each class is divided with `random_split`;
            'skew'   - class i sends 90% of its files to Client_{i+1} and
                       spreads the rest randomly over the other clients.

    Returns:
        A tuple ``(classes, initial_global_data, client_data, client_info,
        test_data)`` where
        - ``classes`` is the sorted list of class names,
        - ``initial_global_data`` maps class -> list of file paths,
        - ``client_data`` maps client name -> {class -> list of file paths},
        - ``client_info`` is a list of ``(client, class, n_files)`` tuples,
        - ``test_data`` maps class -> sorted list of test file paths.

    Raises:
        ValueError: If `dist_type` is not recognised, or if `number_clients`
            is smaller than 1 for 'equal' / 'random'.
    """
    # An unknown dist_type used to fall through every branch and silently
    # return empty client data.
    if dist_type not in ('equal', 'random', 'skew'):
        raise ValueError(
            f"dist_type must be 'equal', 'random' or 'skew', got {dist_type!r}"
        )

    train_path = os.path.join(parent_dir, 'train')
    # Only directories are classes; stray files in 'train' are skipped.
    classes = sorted(
        entry for entry in os.listdir(train_path)
        if os.path.isdir(os.path.join(train_path, entry))
    )

    if dist_type == 'skew':
        number_clients = len(classes)
    elif number_clients < 1:
        raise ValueError(f"number_clients must be at least 1, got {number_clients}")

    client_names = [f'Client_{i + 1}' for i in range(number_clients)]

    initial_global_data = {}  # class -> paths of the initial global data
    client_data = {name: {} for name in client_names}  # client -> class -> paths
    client_info = []  # (client, class, number of images) records

    min_files = min_images_class(train_path)  # Minimum number of images across all classes
    initial_files = int(min_files * initial_ratio)  # same for every class

    ##################### CREATING THE INITIAL GLOBAL DATA #################################
    # NOTE(review): these files are also part of the pool distributed to the
    # clients below -- confirm that this overlap is intended.
    for cl in classes:
        cl_path_src = os.path.join(train_path, cl)
        global_files = _list_image_files(cl_path_src)[:initial_files]
        initial_global_data[cl] = [os.path.join(cl_path_src, img) for img in global_files]

    print("Initial Data Created!")

    ################# CREATING THE CLIENT DATA ########################################
    for i, cl in enumerate(classes):
        class_path_src = os.path.join(train_path, cl)
        all_files = _list_image_files(class_path_src)
        random.shuffle(all_files)

        if dist_type == 'equal':
            # The first `extra` clients receive one additional file each.
            base, extra = divmod(len(all_files), number_clients)
            counts = [base + (1 if j < extra else 0) for j in range(number_clients)]
            _assign_by_counts(client_data, client_info, client_names, counts,
                              cl, class_path_src, all_files)

            print(f"Class '{cl}' equally distributed among clients.")

        elif dist_type == 'random':
            # Random unbalanced distribution
            counts = random_split(len(all_files), number_clients)
            _assign_by_counts(client_data, client_info, client_names, counts,
                              cl, class_path_src, all_files)

            print(f"Class '{cl}' randomly distributed among clients.")

        else:  # dist_type == 'skew'
            main_class_ratio = 0.9
            main_client_name = client_names[i]
            other_clients = client_names[:i] + client_names[i + 1:]

            if other_clients:
                main_client_files_count = int(len(all_files) * main_class_ratio)
            else:
                # Single-class dataset: nobody else can take the remainder,
                # so the only client keeps every file.
                main_client_files_count = len(all_files)

            _assign_files(client_data, client_info, main_client_name, cl,
                          class_path_src, all_files[:main_client_files_count])

            # Spread whatever is left over the other clients.
            remaining_files = all_files[main_client_files_count:]
            random.shuffle(remaining_files)

            if remaining_files:
                counts = random_split(len(remaining_files), len(other_clients))
                _assign_by_counts(client_data, client_info, other_clients, counts,
                                  cl, class_path_src, remaining_files)

            print(f"Class '{cl}' distributed: {main_class_ratio*100}% to '{main_client_name}' and the rest across other clients.")

    ######################### COPY TEST ######################################
    test_path = os.path.join(parent_dir, 'test')
    test_data = {}

    for cl in classes:
        class_test_path = os.path.join(test_path, cl)
        # Sorted so the order does not depend on os.listdir.
        test_data[cl] = [os.path.join(class_test_path, f) for f in _list_image_files(class_test_path)]

    print("Test Data Created!")

    return classes, initial_global_data, client_data, client_info, test_data


In [None]:
def plot_client_distribution(client_info, save_path='distribution.pdf'):
    """Draw one bar chart per client showing its number of images per class.

    Args:
        client_info: Iterable of ``(client, class, number_of_images)`` records.
        save_path: File the figure is written to (defaults to
            'distribution.pdf' in the working directory).

    Raises:
        ValueError: If `client_info` contains no records.
    """
    # Convert client_info to a DataFrame for easier manipulation
    df = pd.DataFrame(client_info, columns=['Client', 'Class', 'Number_of_Images'])
    if df.empty:
        # A zero-row subplot grid cannot be created; fail with a clear message.
        raise ValueError("client_info is empty; there is nothing to plot.")

    # Map class names to 1-based index numbers, in order of first appearance
    class_indices = {class_name: idx + 1 for idx, class_name in enumerate(df['Class'].unique())}
    df['Class_Index'] = df['Class'].map(class_indices)  # Add index column for plotting

    clients = df['Client'].unique()
    num_clients = len(clients)

    # The style must be set before the axes are created: axes pick up their
    # grid/background settings at creation time, so setting it afterwards only
    # affected figures made on a later run.
    sns.set(style="whitegrid")

    # Two charts per row; the row count grows with the number of clients
    fig, axes = plt.subplots(nrows=(num_clients + 1) // 2, ncols=2, figsize=(14, num_clients * 3))
    axes = axes.flatten()

    # Plot each client's data on a separate subplot
    for ax, client in zip(axes, clients):
        client_rows = df[df['Client'] == client]

        sns.barplot(ax=ax, x='Class_Index', y='Number_of_Images', data=client_rows, color='gray')
        ax.set_title(f"Image Distribution for {client}")
        ax.set_xlabel("Class Index")
        ax.set_ylabel("Number of Images")

    # Remove the unused subplot left over when the client count is odd
    for ax in axes[num_clients:]:
        fig.delaxes(ax)

    # Adjust layout to avoid overlap
    fig.tight_layout()
    fig.savefig(save_path)

Patching Code

In [None]:
def create_directories(classes, clients, subset):
    """Create the per-class output folders for one subset under 'patches'.

    For ``subset == 'clients'`` the layout is
    ``patches/clients/<client>/<class>``; for any other subset it is
    ``patches/<subset>/<class>``. Existing folders are left untouched.
    """
    subset_root = os.path.join("patches", subset)

    if subset == 'clients':
        parent_dirs = [os.path.join(subset_root, client) for client in clients]
    else:
        parent_dirs = [subset_root]

    for parent in parent_dirs:
        for cl in classes:
            os.makedirs(os.path.join(parent, cl), exist_ok=True)

In [None]:
def center_crop(img, dim):
    """Return the centred `dim`-sized window of `img`.

    Args:
        img: Array whose first two axes are (height, width); any further
            axes (e.g. colour channels) are kept as they are.
        dim: ``(crop_height, crop_width)``. Each value is clamped to the
            range ``0..image size`` along its axis.

    Returns:
        A slice of `img` with exactly the (clamped) requested height and width.
    """
    # shape[:2] also accepts 2-D (single-channel) arrays.
    height, width = img.shape[:2]
    crop_height = max(0, min(dim[0], height))
    crop_width = max(0, min(dim[1], width))

    # Deriving the end from start + size keeps odd crop sizes exact; halving
    # the size on both sides of the midpoint would drop one pixel.
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return img[top:top + crop_height, left:left + crop_width]

In [None]:
def laplacian_variance(patch_gray):
    """Return the variance of the Laplacian of `patch_gray`.

    The Laplacian is computed with 64-bit float output (``cv2.CV_64F``).
    """
    laplacian = cv2.Laplacian(patch_gray, cv2.CV_64F)
    return laplacian.var()

In [None]:
def _select_sharpest_patches(img_crop, patch_size, total_patches):
    """Tile `img_crop` into square patches and keep the highest-scoring ones.

    Every `patch_size` x `patch_size` tile is scored with
    `laplacian_variance` on its grayscale version; the `total_patches` tiles
    with the largest score are returned, best first.
    """
    h, w = img_crop.shape[:2]
    scored = []
    for i in range(0, h, patch_size):
        for j in range(0, w, patch_size):
            patch = img_crop[i:i + patch_size, j:j + patch_size, :]
            patch_gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
            scored.append((patch, laplacian_variance(patch_gray)))

    # Stable sort: tiles with equal scores keep their scan order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return [patch for patch, _ in scored[:total_patches]]


def process_images(data, classes, output_path, patch_size=64, total_patches=128):
    """Extract the sharpest patches of every image and save them per class.

    Each readable image is centre-cropped to a multiple of `patch_size` in
    both directions, tiled, and its `total_patches` best tiles (by Laplacian
    variance) are written to ``<output_path>/<class>/<stem>_<k>.jpg``.
    Progress is printed and appended to 'patch_extraction_process_status.txt'.

    Args:
        data: Dict mapping class name -> list of image file paths. A class
            missing from the dict is treated as having no files.
        classes: Sequence of class names to process.
        output_path: Directory under which one folder per class is created.
        patch_size: Side length of the square patches, in pixels.
        total_patches: Maximum number of patches kept per image.

    Raises:
        ValueError: If `patch_size` is smaller than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be at least 1, got {patch_size}")

    separator = '\n\n----------------------------------------------------------------------------------\n\n'

    with open('patch_extraction_process_status.txt', 'a') as f:
        error = 0

        for class_count, cl in enumerate(classes, start=1):
            class_dir = os.path.join(output_path, cl)
            os.makedirs(class_dir, exist_ok=True)
            # A client may hold no entry at all for a class; treat it as empty
            # instead of failing with KeyError.
            files = data.get(cl, [])
            total_files_in_class = len(files)
            files_processed = 0

            for file in files:
                img = cv2.imread(file)
                if img is None:
                    # cv2.imread returned None: the file could not be read as an image.
                    f.write(f"Not an image: {file}\n")
                    error += 1
                    continue

                height, width, _ = img.shape
                f.write("Original image dim: " + str(img.shape) + "\n")

                # Largest area that is an exact multiple of patch_size.
                crop_h = (height // patch_size) * patch_size
                crop_w = (width // patch_size) * patch_size
                img_crop = center_crop(img, (crop_h, crop_w))
                f.write("Crop image dim: " + str(img_crop.shape) + "\n")

                selected_patches = _select_sharpest_patches(img_crop, patch_size, total_patches)
                if not selected_patches:
                    # Image smaller than one patch: nothing to save or count.
                    continue

                for i, patch in enumerate(selected_patches):
                    filename = Path(file).stem + '_' + str(i + 1) + '.jpg'
                    cv2.imwrite(os.path.join(class_dir, filename), patch)

                files_processed += 1
                status = ("Class processed  [" + str(class_count) + "/" + str(len(classes)) + "] -> (" +
                          str(files_processed) + "/" + str(total_files_in_class) + ")")
                print(status)
                print(separator)
                f.write(status + "\n")
                f.write(separator)

        # Trailing newline so the next appended run starts on its own line.
        f.write("Total error: " + str(error) + "\n")


In [None]:
# Build the in-memory split (lists of file paths) and chart how many images
# of each class every client received.
classes, initial_global_data, client_data, client_info, test_data = files_split(
    datapath,
    number_clients=5,
    initial_ratio=0.06,
    dist_type='equal',
)
plot_client_distribution(client_info)

In [None]:
# Extract patches for the test set, the initial global set and every client.
subset = ['test','initial','clients']
clients = list(client_data.keys())

for sub in subset:
    create_directories(classes, clients, sub)
    if sub == 'clients':
        for client in clients:
            data = client_data[client]
            output_path = os.path.join('patches', sub, client)
            process_images(data, classes, output_path)
    else:
        data = initial_global_data if sub == 'initial' else test_data
        output_path = os.path.join('patches', sub)
        process_images(data, classes, output_path)

    # process_images appends to this same file. Opening it here only for the
    # marker (instead of holding it open around the whole loop) makes the
    # marker land right after the subset it refers to, rather than being
    # buffered until every subset has finished.
    with open('patch_extraction_process_status.txt', 'a') as f:
        f.write("\n\n"+sub+" Completed!\n\n")