In [None]:
import os
import re
import cv2
import math
import random
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import matplotlib as mpl
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from collections import Counter
from sklearn.manifold import TSNE
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

In [2]:
# Read the synthetic-data CSV into a DataFrame.
data = pd.read_csv(filepath_or_buffer="/kaggle/input/synthetic-data-csv/synthetic_data.csv")

# Preview the first 30 rows; as the cell's last expression it is rendered.
data.head(n=30)


In [3]:
# Full class names; their position in this list becomes the integer class id.
labels_long = ["Diabet", "Glaucoma", "Cataract", "AMD", "Hypertension", "Myopia", "Normal"]

# One-letter codes: the initial of every full name.
labels_short = [name[0] for name in labels_long]

# Short code -> full name (e.g. 'D' -> 'Diabet').
class_short2full = dict(zip(labels_short, labels_long))

# Short code -> integer id (encoding) ...
class_dict = {code: idx for idx, code in enumerate(class_short2full)}

# ... and integer id -> short code (decoding).
class_dict_rev = {idx: code for code, idx in class_dict.items()}

# Number of distinct classes (7 for the list above).
NUM_CLASSES = len(class_dict)


In [4]:
# Explicit short-code -> full-name table. It spells out the same pairs that
# the previous cell derives from `labels_long`.
class_short2full = dict(
    D="Diabet",
    G="Glaucoma",
    C="Cataract",
    A="AMD",            # Age-related Macular Degeneration
    H="Hypertension",
    M="Myopia",
    N="Normal",
)

# Explicit short-code -> integer class id table (same ids as the derived one).
class_dict = dict(D=0, G=1, C=2, A=3, H=4, M=5, N=6)


In [5]:
# Seed shared by the stochastic steps of the notebook.
SEED = 42

# Actually apply the seed. Declaring SEED alone changes nothing: the dataset
# builder draws from `random.random()` and the dataset is shuffled with
# `random.shuffle`, so without this call every run samples differently.
# NumPy's global generator is seeded too so NumPy-based randomness repeats.
random.seed(SEED)
np.random.seed(SEED)

# Colour palette for the figures.
COLORS = {
    "fig_bg": "#f6f5f5",              # Figure / axes background
    "plot_neut": "#ddbea9",           # Neutral plot colour
    "plot_text": "#343a40",           # Text colour

    # Ordered colour list; the first entries are handed out to the classes below.
    "cmap_color_list": ["#001219", "#005F73", "#0A9396", "#94D2BD", "#E9D8A6",
                        "#EE9B00", "#CA6702", "#BB3E03", "#AE2012", "#9B2226"],

    # One colour per data split.
    "split": {
        "train": "#264653",
        "val": "#2a9d8f",
        "test": "#e9c46a",
    },
}

# Short class code -> colour. `zip` stops at the shorter input, so each class
# simply receives the next colour from the list.
COLORS["class"] = dict(zip(class_short2full, COLORS["cmap_color_list"]))

# Continuous colour maps interpolated from colour lists.
COLORS["cmap"] = mpl.colors.LinearSegmentedColormap.from_list("", COLORS["cmap_color_list"])
COLORS["cmap_pos"] = mpl.colors.LinearSegmentedColormap.from_list("", ["#F0F3F8", "#D1DBE9", "#A2B7D2", "#7493BC", "#6487B4", "#3D5A80"])

# Class colours as a plain list, in class order.
colors_class_list = list(COLORS["class"].values())

# Every font preset is serif / normal style and differs only in weight and
# size, so the presets are generated from a (weight, size) table.
_FONT_SPECS = {
    "plot_title": ("bold", "25"),           # Main figure title
    "plot_title_small": ("bold", "16"),     # Smaller figure title
    "plot_subtitle": ("bold", "12"),        # Figure subtitle
    "subplot_title": ("bold", "18"),        # Subplot title
    "subplot_title_small": ("bold", "12"),  # Smaller subplot title
    "plot_label": ("bold", "16"),           # Axis / tick labels
    "plot_label_small": ("bold", "12"),     # Smaller axis / tick labels
    "plot_text": ("normal", "12"),          # General text
    "plot_text_small": ("normal", "8"),     # Small text
}

# Keyword dictionaries that are splatted into matplotlib text calls.
FONT_KW = {
    preset: {"fontname": "serif", "weight": weight, "size": size, "style": "normal"}
    for preset, (weight, size) in _FONT_SPECS.items()
}


In [None]:
def count_values_relative(y):
    """Return the distinct values of ``y`` and the share of each, in percent.

    Returns a pair ``(values, percentages)`` of NumPy arrays; ``values`` is
    sorted (as produced by ``np.unique``) and ``percentages`` sums to 100.
    """
    uniques, counts = np.unique(y, return_counts=True)
    scaled_counts = 100 * counts
    return uniques, scaled_counts / counts.sum()

def ceil_d(n, d=1000):
    """Round ``n`` up to the next multiple of ``d`` (default 1000) as an ``int``."""
    whole_steps = np.ceil(n / d)
    return int(whole_steps * d)

def get_subplot_dims(N):
    """Return ``(rows, cols)`` of a near-square grid with at least ``N`` cells."""
    root = np.sqrt(N)
    rows, cols = np.ceil(root), np.floor(root)
    # ceil(sqrt) x floor(sqrt) can be one row short (N=3 gives 2x1): add a row.
    rows = rows + 1 if rows * cols < N else rows
    return int(rows), int(cols)

def atoi(text):
    """Return ``int(text)`` when ``text`` is a decimal number, else ``text`` itself.

    ``str.isdecimal`` is used rather than ``str.isdigit``: ``isdigit`` also
    accepts characters such as superscript digits ("²") that ``int()`` cannot
    parse, which made the conversion raise ``ValueError`` on such strings.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a sort key that orders digit runs inside ``text`` numerically.

    The string is split on runs of digits (the runs are kept), and every
    piece is passed through ``atoi``.
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))

def natural_sort_col_unique(df, colname, missing="NA"):
    """Return the unique values of ``df[colname]`` in natural sort order.

    Missing entries are replaced by ``missing`` before sorting. ``pd.isna``
    is used for the check because ``np.nan in list`` only matches the
    ``np.nan`` singleton by identity; it misses ``None``, ``pd.NA`` and NaN
    floats created by conversion, which then break the string-based sort key.

    Parameters
    ----------
    df : pandas.DataFrame
    colname : str
        Column to read; its values are expected to be strings (the sort key
        splits them with a regular expression).
    missing : str, default "NA"
        Placeholder substituted for missing values.

    Returns
    -------
    list
        Unique values, naturally sorted (digit runs compared as numbers).
    """
    values = [missing if pd.isna(value) else value for value in df[colname].unique().tolist()]
    values.sort(key=natural_keys)
    return values

def conditional_entropy(x, y):
    """Return the conditional entropy H(x | y) in nats.

    ``x`` and ``y`` are paired sequences of hashable values. Probabilities
    are estimated from counts; the natural logarithm is used.
    """
    y_freq = Counter(y)
    pair_freq = Counter(zip(x, y))
    n_total = sum(y_freq.values())

    result = 0
    for (_, y_value), joint_count in pair_freq.items():
        p_joint = joint_count / n_total        # P(x, y)
        p_given = y_freq[y_value] / n_total    # P(y)
        result += p_joint * math.log(p_given / p_joint)
    return result

def theil_u(x, y):
    """Return Theil's U (uncertainty coefficient) of ``x`` given ``y``.

    Computed as ``(H(x) - H(x | y)) / H(x)``. When ``x`` has zero entropy
    (a single distinct value) the function returns 1.

    Parameters
    ----------
    x, y : paired sequences of hashable values.

    Returns
    -------
    float or int
    """
    # The notebook never imports `ss`, so the original `ss.entropy` call
    # raised NameError; bring SciPy's stats module into scope here.
    import scipy.stats as ss

    s_xy = conditional_entropy(x, y)                       # H(x | y)
    x_counter = Counter(x)
    total_occurrences = sum(x_counter.values())
    p_x = [count / total_occurrences for count in x_counter.values()]
    s_x = ss.entropy(p_x)                                  # H(x)
    if s_x == 0:
        return 1
    return (s_x - s_xy) / s_x
    
def correlation_ratio(categories, measurements):
    """Return the correlation ratio (eta) between a categorical and a numeric variable.

    Eta is the square root of the between-group sum of squares divided by
    the total sum of squares; it is 0.0 when every category has the same mean.

    Parameters
    ----------
    categories : array-like
        Category label of each observation. Series, ndarrays and plain
        sequences are accepted (plain lists previously failed on indexing).
    measurements : array-like
        Numeric value of each observation, aligned with ``categories``.

    Returns
    -------
    float
    """
    # Normalise both inputs to array form so boolean-mask indexing works
    # for any sequence type, not only for Series / ndarrays.
    categories = categories.values if isinstance(categories, pd.Series) else np.asarray(categories)
    measurements = measurements.values if isinstance(measurements, pd.Series) else np.asarray(measurements)

    fcat, _ = pd.factorize(categories)          # integer code per category
    cat_num = np.max(fcat) + 1                  # number of distinct categories
    y_avg_array = np.zeros(cat_num)             # per-category mean
    n_array = np.zeros(cat_num)                 # per-category count
    for i in range(cat_num):
        cat_measures = measurements[fcat == i]
        n_array[i] = len(cat_measures)
        y_avg_array[i] = np.average(cat_measures)

    y_total_avg = np.sum(y_avg_array * n_array) / np.sum(n_array)
    numerator = np.sum(n_array * (y_avg_array - y_total_avg) ** 2)   # between-group SS
    denominator = np.sum((measurements - y_total_avg) ** 2)          # total SS
    if numerator == 0:
        # No between-group variation; also avoids 0/0 for constant data.
        return 0.0
    return np.sqrt(numerator / denominator)

In [None]:
# Image directory and the square side length (pixels) images are resized to.
DATA_PATH = "/kaggle/input/synthetic-data/synthetic_data"
IMG_SIZE = 224
IMAGE_SIZE = [IMG_SIZE] * 2   # [height, width]-style pair built from IMG_SIZE

def label_image(c):
    """Return a one-hot integer vector of length ``NUM_CLASSES`` with index ``c`` set."""
    one_hot = np.zeros(NUM_CLASSES, dtype=int)
    one_hot[c] = 1
    return one_hot

def get_gaussian_filter_shape(IMG_SIZE):
    """Return the Gaussian kernel size for an image side: a quarter of it, minus one."""
    quarter_side = IMG_SIZE // 4
    return quarter_side - 1

def blur_image(image, sigma=10):
    """Gaussian-blur ``image`` with TensorFlow Addons.

    The kernel size is derived from the global ``IMG_SIZE`` (not from the
    shape of ``image``); ``sigma`` is forwarded unchanged.
    """
    return tfa.image.gaussian_filter2d(
        image,
        filter_shape=get_gaussian_filter_shape(IMG_SIZE),
        sigma=sigma,
    )

def weighted_image(image, alpha=4, beta=-4, gamma=128):
    """Return ``alpha * image + beta * blur_image(image) + gamma``."""
    scaled_original = image * alpha
    scaled_blur = blur_image(image) * beta
    return scaled_original + scaled_blur + gamma

In [8]:
# Work on an independent (deep) copy during EDA so `data` stays untouched.
df_eda = data.copy(deep=True)

In [9]:
# Derive a "class" column from the raw `labels` strings: keep only the runs of
# ASCII letters found in each string and join them with single spaces, which
# drops brackets, quotes, commas and any other non-letter characters.
letters_only = re.compile("[a-zA-Z]+")
df_eda["class"] = df_eda["labels"].map(lambda raw: " ".join(letters_only.findall(raw)))


In [None]:
def _plot_class_distribution_panel(ax, counts, title, subtitle):
    """Draw one horizontal class-share bar chart on ``ax`` and return the bars.

    ``counts`` is a DataFrame indexed by short class code with a ``num``
    column (absolute counts) and a ``percent`` column (fractions in [0, 1]).
    Both panels of the figure share this styling, so it lives in one place.
    """
    bars = ax.barh(counts.index, counts["percent"])

    # Tick labels show the full class names.
    ax.set_yticks(counts.index, [class_short2full[code] for code in counts.index],
                  **FONT_KW["plot_label"], color=COLORS["plot_text"])
    ax.tick_params(axis="y", length=0)
    ax.set_title(title, loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"], pad=30)
    # y=8.2 is in data coordinates; with seven bars (y = 0..6) it sits above them.
    ax.text(0, 8.2, subtitle, **FONT_KW["subplot_title_small"], color=COLORS["plot_text"])

    # Annotate each bar with "<count>\n(<percent>%)".
    ax.bar_label(bars,
                 labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)"
                         for val, pcnt in zip(counts["num"], counts["percent"])],
                 color=COLORS["plot_text"], **FONT_KW["plot_text"])

    ax.set_facecolor(COLORS["fig_bg"])
    # Colour every bar and its tick label with the class colour. Iterating the
    # rows actually present avoids an IndexError when a class is missing.
    for tick_label, bar, code in zip(ax.get_yticklabels(), bars, counts.index):
        class_colour = COLORS["class"][code]
        tick_label.set_color(class_colour)
        bar.set_color(class_colour)

    # Minimal look: no x axis, only the left spine.
    ax.axes.get_xaxis().set_visible(False)
    for spine in ["bottom", "right", "top"]:
        ax.spines[spine].set_visible(False)
    return bars


# Two panels side by side.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), dpi=70, gridspec_kw={"wspace": 0.5})
fig.patch.set_facecolor(COLORS["fig_bg"])

# Label-based distribution: one class string per row.
value_counts = df_eda["class"].value_counts().rename("num").to_frame()
value_counts["percent"] = value_counts["num"] / value_counts["num"].sum()
# Put the rows in class-colour order. `reindex`/`loc` return a new frame, so
# the result has to be assigned (the original call discarded it). Only the
# classes actually present are kept, so no NaN rows are introduced.
class_order = [code for code in COLORS["class"] if code in value_counts.index]
value_counts = value_counts.loc[class_order]

b1 = _plot_class_distribution_panel(ax1, value_counts, "Label-Based", "(Multi-class)")

#########################################################
# Multiclass Multi-label part (based on diagnosis)

# Per-class column sums, as a share of all rows (a row may count for several classes).
value_count_diag = df_eda[labels_short].sum().rename("num").to_frame()
value_count_diag["percent"] = value_count_diag["num"] / df_eda.shape[0]
value_count_diag = value_count_diag.reindex(index=value_counts.index)  # same row order as the left panel

b2 = _plot_class_distribution_panel(ax2, value_count_diag, "Diagnosis-Based", "(Multi-class Multi-label)")

# Overall figure title.
plt.figtext(0, 1.05, "Class Distribution", **FONT_KW["plot_title"], color=COLORS["plot_text"])

plt.show()


In [None]:
# Print "<full class name>: <count>" for every row of the class-count table.
for cls, num in value_counts["num"].items():
    print(f"{class_short2full[cls]}: {num}")


In [None]:
# Show the short class code stored in the first row of the class-count table.
first_code = value_counts.index[0]
print(first_code)

In [13]:
# Add a "class" column to `data`: the runs of ASCII letters found in each raw
# `labels` string, joined with single spaces.
label_letters = re.compile("[a-zA-Z]+")
data["class"] = data["labels"].map(lambda raw: " ".join(label_letters.findall(raw)))

In [14]:
# Short class code -> array of the image filenames labelled with that class.
dict_img_list = {}
for class_ in class_short2full:
    in_class = data["class"] == class_
    dict_img_list[class_] = data.loc[in_class, "filename"].values

In [None]:
from sklearn.preprocessing import MinMaxScaler  # Import the MinMaxScaler to normalize the input features
from matplotlib.offsetbox import AnnotationBbox, OffsetImage  # Import for adding images as annotations to the plot

def plot_classes(X, y, min_distance=0.05, images=None, figsize=(13, 10), cmap=COLORS["cmap"], annot=False):
    """Scatter-plot a 2-D embedding with one colour per class.

    Parameters
    ----------
    X : array-like of shape (n_samples, 2)
        Point coordinates; they are min-max scaled to [0, 1] before plotting.
    y : numpy.ndarray of shape (n_samples,)
        Integer class ids (keys of ``class_dict_rev``).
    min_distance : float, default 0.05
        Minimum distance, in scaled coordinates, between two annotated
        points; closer points are left unannotated to avoid clutter.
    images : sequence of arrays, optional
        Thumbnails used as annotations instead of the class letter. Colour
        images (last axis of size 3 or 4) are shown as-is; anything else is
        reshaped to a square, so flat 784-element vectors still become 28x28.
    figsize : tuple, default (13, 10)
    cmap : unused
        Accepted for backward compatibility; colours come from ``COLORS["class"]``.
    annot : bool, default False
        Whether to annotate points at all.
    """
    X_normalized = MinMaxScaler().fit_transform(X)

    # Already-annotated coordinates. Seeding the array with a far-away dummy
    # point means the distance computation never sees an empty array.
    neighbors = np.array([[10., 10.]])

    fig, ax = plt.subplots(figsize=figsize)

    # One scatter call per class so each gets its own colour and legend entry.
    classes = np.unique(y)
    for class_ in classes:
        members = y == class_
        ax.scatter(
            X_normalized[members, 0],
            X_normalized[members, 1],
            c=COLORS["class"][class_dict_rev[class_]],
            alpha=0.7,
        )

    if annot:
        for index, image_coord in enumerate(X_normalized):
            # Distance to the nearest point that already carries an annotation.
            closest_distance = np.linalg.norm(neighbors - image_coord, axis=1).min()
            if closest_distance <= min_distance:
                continue
            neighbors = np.r_[neighbors, [image_coord]]

            if images is None:
                short_label = class_dict_rev[y[index]]
                ax.text(
                    image_coord[0],
                    image_coord[1],
                    short_label,
                    color=COLORS["class"][short_label],
                    alpha=0.7,
                    **FONT_KW["plot_text_small"]
                )
            else:
                thumbnail = np.asarray(images[index])
                is_colour = thumbnail.ndim == 3 and thumbnail.shape[-1] in (3, 4)
                if not is_colour:
                    # Square greyscale thumbnail. The side is derived from the
                    # element count instead of a hard-coded 28.
                    side = math.isqrt(thumbnail.size)
                    thumbnail = thumbnail.reshape(side, side)
                imagebox = AnnotationBbox(OffsetImage(thumbnail, cmap="binary"), image_coord)
                ax.add_artist(imagebox)

    # Styling: shared background colour, no axes.
    fig.patch.set_facecolor(COLORS["fig_bg"])
    ax.set_facecolor(COLORS["fig_bg"])
    ax.axis("off")

    # Legend entries follow the order of the scatter calls above.
    ax.legend(
        [class_short2full[class_dict_rev[label]] for label in classes],
        prop={"family": "serif", "size": 8},
        facecolor=COLORS["fig_bg"]
    )

    plt.show()


In [16]:
def create_dataset(img_list, class_label, ratio=1.0, max_images=None):
    """Load, convert and resize a random subset of images for one class.

    Parameters
    ----------
    img_list : sequence of str
        Image filenames, relative to ``DATA_PATH``.
    class_label : int
        Label stored next to every loaded image.
    ratio : float, default 1.0
        Used twice: each file is kept with probability ``ratio``, and at most
        ``int(ratio * len(img_list))`` images are returned.
    max_images : int, optional
        Additional hard cap on the number of returned images.

    Returns
    -------
    list of [numpy.ndarray, int]
        ``[image, class_label]`` pairs; images are RGB and resized to
        ``IMG_SIZE x IMG_SIZE``. Files OpenCV cannot read are skipped.
    """
    # Fold both caps into one limit that is checked before each file. The
    # original compared `count == num_images` only after appending, so a cap
    # of 0 (tiny list or small ratio) was never hit and images leaked through.
    limit = int(ratio * len(img_list))
    if max_images is not None:
        limit = min(limit, max_images)

    dataset = []
    for img in img_list:
        if len(dataset) >= limit:
            break

        # Random sub-sampling: keep the file with probability `ratio`.
        if random.random() > ratio:
            continue

        image = cv2.imread(os.path.join(DATA_PATH, img))
        if image is None:
            # cv2.imread returns None for unreadable files; best effort, skip.
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   # OpenCV loads BGR
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))

        dataset.append([image, class_label])

    return dataset


In [None]:
# Short class codes handled from here on, in class-id order.
CLASSES = ["D", "G", "C", "A", "H", "M", "N"]
NUM_CLASSES = len(CLASSES)

# Code -> id and id -> code lookups.
class_dict = dict(zip(CLASSES, range(NUM_CLASSES)))
class_dict_rev = dict(enumerate(CLASSES))

# Keep only the rows whose "class" value is one of the codes above.
data = data[data["class"].isin(CLASSES)]

# Code -> array of filenames belonging to that class.
dict_img_list = {}
for class_ in CLASSES:
    dict_img_list[class_] = data.loc[data["class"] == class_, "filename"].values


In [None]:
def _plot_augmentation_panel(ax, counts, title):
    """Draw one horizontal class-share bar chart on ``ax`` and return the bars.

    ``counts`` is a DataFrame indexed by short class code with a ``num``
    column (absolute counts) and a ``percent`` column (fractions in [0, 1]).
    Shared by the "original" and "augmented" panels.
    """
    bars = ax.barh(counts.index, counts["percent"])

    # Tick labels show the full class names.
    ax.set_yticks(counts.index, [class_short2full[code] for code in counts.index],
                  **FONT_KW["plot_label_small"], color=COLORS["plot_text"])
    ax.tick_params(axis="y", length=0)
    ax.set_title(title, loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"])

    # Annotate each bar with "<count>\n(<percent>%)".
    ax.bar_label(bars,
                 labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)"
                         for val, pcnt in zip(counts["num"], counts["percent"])],
                 padding=5, color=COLORS["plot_text"], **FONT_KW["plot_text"])

    ax.set_facecolor(COLORS["fig_bg"])
    # Colour bars and tick labels per class. Iterating the rows actually
    # present avoids an IndexError when a class is missing.
    for tick_label, bar, code in zip(ax.get_yticklabels(), bars, counts.index):
        class_colour = COLORS["class"][code]
        tick_label.set_color(class_colour)
        bar.set_color(class_colour)

    # Minimal look: no x axis, only the left spine.
    ax.axes.get_xaxis().set_visible(False)
    for spine in ["bottom", "right", "top"]:
        ax.spines[spine].set_visible(False)
    return bars


# Two panels side by side.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14, 6), dpi=70, gridspec_kw={"wspace": 0.5})
fig.patch.set_facecolor(COLORS["fig_bg"])

# Class distribution of the (filtered) data.
value_counts = data["class"].value_counts().rename("num").to_frame()
value_counts["percent"] = value_counts["num"] / value_counts["num"].sum()
# Order rows like CLASSES. The result must be assigned (the original
# `reindex` call discarded it); only classes actually present are kept.
value_counts = value_counts.loc[[code for code in CLASSES if code in value_counts.index]]

b1 = _plot_augmentation_panel(ax1, value_counts, "Original")

# Projected distribution if the 'G' class (described as the minority class)
# is augmented NUM_AUGMENTATIONS-fold; only the counts are scaled here.
NUM_AUGMENTATIONS = 5
value_counts_aug = value_counts.copy()
value_counts_aug.loc["G", "num"] *= NUM_AUGMENTATIONS
value_counts_aug["percent"] = value_counts_aug["num"] / value_counts_aug["num"].sum()

b2 = _plot_augmentation_panel(ax2, value_counts_aug, "With Minority Class Augmentations")

# Overall figure title.
plt.figtext(0, 1.05, "Class Distribution", **FONT_KW["plot_title"], color=COLORS["plot_text"])

plt.show()


In [None]:
# Accumulates [image, label] pairs for all classes.
dataset = []

# Defined for potential future use; it is not passed to `create_dataset` below.
max_images_per_class = 1000

print("START building dataset")

for i, class_ in enumerate(CLASSES):
    # Progress line: "[k/total] adding <full class name> to dataset ..."
    print(f"[{i+1}/{len(CLASSES)}] adding {class_short2full[class_]} to dataset ...")

    # ratio=0.5 sub-samples the class: each file is kept with probability 0.5
    # and at most half of the class's files are loaded.
    dataset.extend(
        create_dataset(dict_img_list[class_], class_dict[class_], ratio=0.5)
    )

# Mix the classes so the samples are no longer grouped by class.
random.shuffle(dataset)

print("COMPLETE building dataset")


In [None]:
# Display how many [image, label] pairs were collected.
len(dataset)

In [None]:
from sklearn.model_selection import train_test_split  
from tensorflow.keras.utils import to_categorical

# Reuse the notebook-wide constants instead of repeating their literal values
# (224, 7, 42), so the reshape, the one-hot width and the split seed cannot
# silently drift from IMG_SIZE / NUM_CLASSES / SEED.
image_size = IMG_SIZE
num_classes = NUM_CLASSES

# Stack the images into one array of shape (n, image_size, image_size, 3).
# The reshape also gives an empty dataset a well-formed 4-D shape.
train_x_concate = np.array([sample[0] for sample in dataset]).reshape(-1, image_size, image_size, 3)

# Integer class id of every image, in the same order.
train_y_concate = np.array([sample[1] for sample in dataset])

# 80 % / 20 % train/validation split, seeded for reproducibility.
x_train_concate, x_val_concate, y_train_concate, y_val_concate = train_test_split(
    train_x_concate, train_y_concate, test_size=0.2, random_state=SEED
)

print(f"Number of images - Train: {len(x_train_concate)}, Validation: {len(x_val_concate)}")


In [None]:
# One-hot encode the integer class ids for categorical cross-entropy.
# The ndim guards make the cell safe to re-run: without them a second run
# would one-hot encode the already encoded (n, num_classes) arrays again.
if y_train_concate.ndim == 1:
    y_train_concate = to_categorical(y_train_concate, num_classes=num_classes)
if y_val_concate.ndim == 1:
    y_val_concate = to_categorical(y_val_concate, num_classes=num_classes)


In [None]:
# Report the shape of each split; every line now names the array it prints
# (all four were previously labelled "train_x").
print("Shape of train_x:", x_train_concate.shape)
print("Shape of train_y:", y_train_concate.shape)
print("Shape of val_x:", x_val_concate.shape)
print("Shape of val_y:", y_val_concate.shape)

In [None]:
# Sample counts: train images, train labels, validation images, validation labels.
for split_array in (x_train_concate, y_train_concate, x_val_concate, y_val_concate):
    print(len(split_array))

In [None]:
# Collapse the one-hot rows of the synthetic training labels to integer class ids.
y_train_labels = np.argmax(y_train_concate, axis=1)

# Images per class id. `minlength` keeps classes with zero images in the
# report; without it np.bincount stops at the largest id present.
# (The variable name is historical: these are training-split counts.)
test_class_counts = np.bincount(y_train_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")


In [None]:
# Collapse the one-hot rows of the synthetic validation labels to integer class ids.
y_val_labels = np.argmax(y_val_concate, axis=1)

# Images per class id. `minlength` keeps classes with zero images in the
# report; without it np.bincount stops at the largest id present.
# (The variable name is historical: these are validation-split counts.)
test_class_counts = np.bincount(y_val_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")


In [None]:
# Load the pre-saved train / validation / test arrays of the real dataset,
# one (images, labels) pair per split.
x_train, y_train = (
    np.load(path) for path in (
        '/kaggle/input/tsne-7-class-train-test-val/x_train.npy',
        '/kaggle/input/tsne-7-class-train-test-val/y_train.npy',
    )
)
x_val, y_val = (
    np.load(path) for path in (
        '/kaggle/input/tsne-7-class-train-test-val/x_val.npy',
        '/kaggle/input/tsne-7-class-train-test-val/y_val.npy',
    )
)
x_test, y_test = (
    np.load(path) for path in (
        '/kaggle/input/tsne-7-class-train-test-val/x_test.npy',
        '/kaggle/input/tsne-7-class-train-test-val/y_test.npy',
    )
)

In [None]:
# Training set: synthetic training samples followed by the loaded training samples.
# NOTE(review): assumes both sources share image shape, dtype/scaling and
# one-hot label layout - confirm against the .npy files.
combined_train_x = np.concatenate([x_train_concate, x_train])
combined_train_y = np.concatenate([y_train_concate, y_train])

# Validation set: synthetic validation samples followed by the loaded validation samples.
combined_val_x = np.concatenate([x_val_concate, x_val])
combined_val_y = np.concatenate([y_val_concate, y_val])

In [None]:
# Class ids of the loaded training labels (`y_train`), taken as the argmax of
# each row - assumes `y_train` is one-hot encoded; confirm against the file.
# (The `y_val_labels` / `test_class_counts` names are historical: this cell
# reports the loaded TRAINING split.)
y_val_labels = np.argmax(y_train, axis=1)

# Images per class id. `minlength` keeps classes with zero images in the
# report; without it np.bincount stops at the largest id present.
test_class_counts = np.bincount(y_val_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")

In [None]:
# Class ids of the combined training labels (argmax of each one-hot row).
# (The `y_val_labels` / `test_class_counts` names are historical: this cell
# reports the COMBINED TRAINING split.)
y_val_labels = np.argmax(combined_train_y, axis=1)

# Images per class id. `minlength` keeps classes with zero images in the
# report; without it np.bincount stops at the largest id present.
test_class_counts = np.bincount(y_val_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")

In [None]:
# Class ids of the loaded validation labels (`y_val`), taken as the argmax of
# each row - assumes `y_val` is one-hot encoded; confirm against the file.
y_val_labels = np.argmax(y_val, axis=1)

# Images per class id in the loaded validation split. `minlength` keeps
# classes with zero images in the report; without it np.bincount stops at the
# largest id present. (The `test_class_counts` name is historical.)
test_class_counts = np.bincount(y_val_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")

In [None]:
# Class ids of the combined validation labels (argmax of each one-hot row).
y_val_labels = np.argmax(combined_val_y, axis=1)

# Images per class id in the combined validation split. `minlength` keeps
# classes with zero images in the report; without it np.bincount stops at the
# largest id present. (The `test_class_counts` name is historical.)
test_class_counts = np.bincount(y_val_labels, minlength=NUM_CLASSES)

for class_label, count in enumerate(test_class_counts):
    print(f"Class {class_label}: {count} images")

In [None]:
# Persist the combined train / validation arrays and the untouched test
# arrays to the working directory for later use.
arrays_to_save = {
    '/kaggle/working/combined_train_x.npy': combined_train_x,
    '/kaggle/working/combined_train_y.npy': combined_train_y,
    '/kaggle/working/combined_val_x.npy': combined_val_x,
    '/kaggle/working/combined_val_y.npy': combined_val_y,
    '/kaggle/working/x_test.npy': x_test,
    '/kaggle/working/y_test.npy': y_test,
}
for out_path, array in arrays_to_save.items():
    np.save(out_path, array)