This tutorial is the dataset analysis part of the following blog post:
[Improve the conditional model]().

Read it for more explanation and context.

In [1]:
import plotly.express as px
import numpy as np
from diffusers_tutorials.datasets import SpritesDataset
from torch.utils.data import WeightedRandomSampler

# Dataset

In [2]:
dataset = SpritesDataset(
    "./dlai_lib/sprites_1788_16x16.npy",
    "./dlai_lib/sprite_labels_nc_1788_16x16.npy",
    null_context=False,
)

# Analyze the dataset

In [3]:
labels = dataset.slabels.argmax(axis=1)
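
The `argmax(axis=1)` call suggests that the labels are stored as one-hot rows, one row per sprite. As a minimal sketch of that conversion, on made-up data rather than the real label file:

```python
import numpy as np

# Toy one-hot labels for 4 samples and 5 classes (made-up data, not the real file).
one_hot = np.array([
    [1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 0, 1],
])

# argmax along axis 1 returns, for each row, the column index holding the 1.
class_ids = one_hot.argmax(axis=1)
print(class_ids)  # [0 2 1 4]
```

The resulting integer class ids are what the histogram and the class counts below operate on.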

In [4]:
fig = px.histogram(labels, nbins=5, title="Histogram of the sprite dataset")
fig.update_xaxes(
    title_text="Labels",
    tickmode="array",
    tickvals=[0, 1, 2, 3, 4],
    ticktext=["human", "non-human", "food", "spell", "side-facing"],
)

fig

## Balanced dataset

Compute class weights so that rarer classes are sampled more often.

In [5]:
u_labels, class_counts = np.unique(labels, return_counts=True)
class_weights = 1 - class_counts / class_counts.sum()

In [6]:
class_weights  # per-class weights: the rarer the class, the larger its weight

array([0.91051454, 0.63758389, 0.93288591, 0.60850112, 0.91051454])
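
It can help to check how much these weights actually rebalance the classes. The sketch below recovers the class frequencies from the weights printed above (as `1 - weight`) and computes the expected share of draws per class under weighted sampling. The helper `expected_share` and the inverse-frequency alternative are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

# Class frequencies implied by the weights printed above (frequency = 1 - weight).
class_weights = np.array([0.91051454, 0.63758389, 0.93288591, 0.60850112, 0.91051454])
freq = 1 - class_weights

def expected_share(freq: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Expected fraction of draws per class when every sample of class c has weight weights[c]."""
    mass = freq * weights  # total sampling mass carried by each class
    return mass / mass.sum()

share_one_minus = expected_share(freq, 1 - freq)  # weights used in this notebook
share_inverse = expected_share(freq, 1 / freq)    # hypothetical inverse-frequency alternative

print(np.round(share_one_minus, 3))
print(np.round(share_inverse, 3))
```

Under these assumptions, the `1 - frequency` weights narrow the gap between classes (shares range from roughly 9% to 34%) without making them equal, while inverse-frequency weights give every class the same expected share of 20%.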

In [7]:
sample_weights = tuple(class_weights[label] for label in labels)
dataset_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(labels), replacement=True)
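
`WeightedRandomSampler` draws sample indices with probability proportional to their weight, and with `replacement=True` the same index can be drawn several times. It is typically passed to a `DataLoader` through its `sampler` argument. The NumPy sketch below imitates that behavior on a made-up three-class label array, to show what one "epoch" of drawn indices looks like. It is an illustration of the idea, not PyTorch's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy imbalanced labels: 800 / 150 / 50 samples for classes 0 / 1 / 2 (made-up data).
labels = np.array([0] * 800 + [1] * 150 + [2] * 50)

_, class_counts = np.unique(labels, return_counts=True)
class_weights = 1 - class_counts / class_counts.sum()
sample_weights = class_weights[labels]  # vectorized lookup: one weight per sample

# Draw one "epoch" of indices with replacement, probability proportional to weight.
probs = sample_weights / sample_weights.sum()
drawn = rng.choice(len(labels), size=len(labels), replace=True, p=probs)

# Share of each class among the drawn samples, and fraction of distinct samples seen.
drawn_share = np.bincount(labels[drawn], minlength=3) / len(drawn)
unique_fraction = len(np.unique(drawn)) / len(labels)
print(drawn_share, unique_fraction)
```

Because sampling is done with replacement, some sprites appear more than once in an epoch and others not at all, which is the price paid for seeing rare classes more often.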

# Look at some images from the class with issues

In [8]:
from diffusers_tutorials.tools.plotly import plot_generated_images

In [9]:
non_human_indexes = [i for i, value in enumerate(labels) if value == 1]
plot_generated_images(dataset.sprites[non_human_indexes][:30], 6, 5, title_text="Glimpse of non human images")

Here, as you can see, the problem comes not from our model but from the data! Let's try to automatically detect most of the wrong ones.

In [10]:
def detect_wrong_images(image: np.ndarray, number_of_white: int = 26) -> bool:
    """Function to detect if an image is a wrong one and should be excluded from training.

    Parameters
    ----------
    image : np.ndarray
        the RGB image with shape [width, height, channel]
    number_of_white : int, optional
        the minimum number of white pixels an image must have to be considered valid, by default 26

    Returns
    -------
    bool
        True if it is a wrong image for generation and False if it is a good one.
    """

    wrong_image = False
    # White is (255, 255, 255), so a white pixel sums to 765
    # The background is white, so a minimum number of pixels should be white
    wrong_image |= np.sum(image.sum(axis=-1) == 765) < number_of_white

    # 4 corners must be white
    wrong_image |= image[0, 0].sum() != 765
    wrong_image |= image[-1, 0].sum() != 765
    wrong_image |= image[0, -1].sum() != 765
    wrong_image |= image[-1, -1].sum() != 765

    return wrong_image

non_human_data = dataset.sprites[non_human_indexes]
wrong_non_human_index = [detect_wrong_images(image) for image in non_human_data]
plot_generated_images(non_human_data[wrong_non_human_index][:30], 6, 5, title_text="Wrong images")
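
To see the heuristic in isolation, here is a small sketch that restates the two rules (enough white pixels, four white corners) and applies them to two made-up 16x16 images. The function name `looks_wrong` and the test images are illustrative assumptions, not part of the dataset.

```python
import numpy as np

WHITE_SUM = 3 * 255  # a pure white RGB pixel sums to 765

def looks_wrong(image: np.ndarray, number_of_white: int = 26) -> bool:
    """Compact restatement of the heuristic above (hypothetical helper)."""
    white = image.sum(axis=-1) == WHITE_SUM  # boolean mask of white pixels
    corners_white = white[0, 0] and white[-1, 0] and white[0, -1] and white[-1, -1]
    return bool(white.sum() < number_of_white or not corners_white)

# Made-up image: a red square centered on a white background.
good = np.full((16, 16, 3), 255, dtype=np.uint8)
good[4:12, 4:12] = (200, 30, 30)

# Same image with one black corner pixel, which breaks the corner rule.
bad = good.copy()
bad[0, 0] = (0, 0, 0)

print(looks_wrong(good), looks_wrong(bad))  # False True
```

A single non-white corner is enough to flag an image, so the rule is strict. That suits sprites drawn on a clean white background, but it would likely over-flag other kinds of data.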

In [11]:
sum(wrong_non_human_index)

30000

So 30k of our 32.4k non-human samples are wrong images. No wonder our model gave wrong results.

**Note**: this is a simple case, not a real-world dataset, and our function to detect wrong samples is equally simple. In real-world applications, we would also have to check for other defects, such as the blurriness of the data.
Nevertheless, let us check that the remaining good samples look right.

In [12]:
good_non_human_index = [not elem for elem in wrong_non_human_index]
plot_generated_images(non_human_data[good_non_human_index][:30], 6, 5, title_text="Right images")
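
The cells above only inspect the flagged images. One possible way to exclude them from training is a boolean keep mask over the whole dataset, sketched below on made-up labels and flags. The names `flagged`, `keep` and `kept_indices` are hypothetical, and only the non-human class is filtered because it is the only one examined here.

```python
import numpy as np

# Made-up stand-ins: class ids for 6 samples and per-sample "wrong" flags.
labels = np.array([1, 0, 1, 3, 1, 2])
flagged = np.array([True, True, False, False, True, False])  # e.g. heuristic output

# Only drop flagged samples that belong to the non-human class (id 1);
# other classes are kept even if the heuristic would flag them.
NON_HUMAN = 1
keep = ~((labels == NON_HUMAN) & flagged)

kept_indices = np.flatnonzero(keep)
print(kept_indices)  # [1 2 3 5]
```

The kept indices could then be used to index the sprite and label arrays, or wrapped in `torch.utils.data.Subset`. After filtering, the class counts change, so the class weights should be recomputed.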