In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
from aml4cv import FIGURES_DIR
from aml4cv.train import get_data_transforms, get_model_and_processor
from aml4cv.dataset import FlowersDataset
import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F
import numpy as np

# Output folder for all class/colour distribution figures; created up front so
# later savefig calls never fail on a missing directory.
DISTRIBUTIONS_DIR = FIGURES_DIR.joinpath('distributions')
DISTRIBUTIONS_DIR.mkdir(exist_ok=True, parents=True)

def plot_flower_distribution(dataset: FlowersDataset):
    """
    Print summary statistics and plot the per-class image counts of a split.

    Classes are ordered by ascending image count. Two bar charts are drawn,
    shown, and saved under ``DISTRIBUTIONS_DIR``:

    * ``<split>_no_labels.pdf`` -- x-axis is the class rank (a tick every 10 classes),
    * ``<split>_with_labels.pdf`` -- x-axis shows the class names.

    Args:
        dataset (FlowersDataset): The dataset containing flower images and labels.
            Iterating it must yield ``(image, label)`` pairs where ``label`` has a
            ``"class_name"`` entry; ``dataset.split`` names the split.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    # NOTE(review): iterating the dataset loads every image just to read its
    # label; a label-only accessor on FlowersDataset would be much cheaper.
    class_names = [label["class_name"] for _, label in dataset]
    if not class_names:
        raise ValueError(f"Split {dataset.split!r} is empty; nothing to plot.")

    # Ascending by count so the bars form a monotone staircase.
    label_counts = sorted(Counter(class_names).items(), key=lambda item: item[1])
    labels, counts = zip(*label_counts)

    print("Flower class distribution:")
    print(f"Min class count: {min(counts)}")
    print(f"Max class count: {max(counts)}")
    # np.median averages the two middle values when the number of classes is
    # even; indexing sorted(counts)[n // 2] would report the upper-middle one.
    print(f"Median class count: {np.median(counts):g}")

    # The count -> colour legend is only drawn for the test split.
    show_legend = dataset.split == "test"

    # Figure 1: class rank on the x-axis (names omitted for readability).
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(x=range(len(labels)), y=counts, hue=counts, palette="crest",
                legend=show_legend, ax=ax)
    ax.set_xticks(range(0, len(labels), 10))
    ax.tick_params(axis='both', labelsize=15, length=6, width=2)
    ax.set_title('Flower Class Distribution', fontsize=25)
    ax.set_xlabel('Flower Class', fontsize=20)
    ax.set_ylabel('Number of Images', fontsize=20)
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / f'{dataset.split}_no_labels.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Figure 2: same data with the class names on the x-axis.
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(x=labels, y=counts, hue=counts, palette="crest",
                legend=show_legend, ax=ax)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title('Flower Class Distribution')
    ax.set_xlabel('Flower Class')
    ax.set_ylabel('Number of Images')
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / f'{dataset.split}_with_labels.pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)


In [None]:
# Load the three predefined splits (in train -> val -> test order) and report sizes.
train, val, test = (FlowersDataset(split=split) for split in ('train', 'val', 'test'))
print("\n".join([
    f'Train set size: {len(train)}',
    f'Validation set size: {len(val)}',
    f'Test set size: {len(test)}',
    f"Total dataset size: {sum(len(subset) for subset in (train, val, test))}",
]))

In [None]:
# Per-class image counts for the training split.
plot_flower_distribution(dataset=train)

In [None]:
# Per-class image counts for the test split.
plot_flower_distribution(dataset=test)

In [None]:
# Per-class image counts for the validation split.
plot_flower_distribution(dataset=val)

In [None]:
model_id = "google/vit-large-patch16-224-in21k"
# NOTE(review): `model` is not referenced in the visible cells; only the
# processor's resize/normalisation settings are used below.
model, processor = get_model_and_processor(model_id, "cpu")


def _per_channel(stat, n_channels: int = 3) -> list:
    """
    Return a per-channel list for a processor normalisation statistic.

    The statistic may be stored as a single scalar or as a per-channel
    sequence. Scalars (and length-1 sequences) are broadcast to
    ``n_channels`` entries. Any sequence-like value (list, tuple, numpy
    array) is returned as a plain list of floats; a bare ``isinstance(...,
    list)`` check would wrap a tuple/array as ``[seq] * 3`` instead.
    """
    values = np.asarray(stat, dtype=np.float64).ravel().tolist()
    if len(values) == 1:
        return values * n_channels
    return values


image_mean = _per_channel(processor.image_mean)
image_std = _per_channel(processor.image_std)



In [None]:

# Augmentations to preview; ``None`` is the baseline panel (resize + normalise
# only). List order fixes the subplot order (row-major).
transforms_to_show = [
    None,
    v2.Pad(padding=10, padding_mode="constant"),
    v2.RandomRotation(degrees=[-180, 180]),
    v2.RandomVerticalFlip(p=1.0),
    v2.GaussianBlur(kernel_size=(19, 19), sigma=(5.0, 10.0)),
    v2.CenterCrop(380),
]


def show_transforms(image, label, transforms_list=None, ncols: int = 3):
    """
    Show one image under each augmentation in a grid and save the figure.

    Every panel goes through the model processor's preprocessing (resize to
    ``processor.size``, scale to float32, normalise with ``image_mean`` /
    ``image_std``) and is then un-normalised for display. The figure is saved
    to ``FIGURES_DIR / "data_augmentations.pdf"``.

    Args:
        image: Image as returned by the dataset. NOTE(review): assumed to be a
            3-channel tensor accepted by the v2 transforms (the un-normalisation
            reshapes mean/std to ``(3, 1, 1)``).
        label: Label dict; its ``"class_name"`` entry is shown in every title.
        transforms_list: Transforms to preview (``None`` means "no augmentation").
            Defaults to the module-level ``transforms_to_show``.
        ncols: Number of subplot columns; the row count is derived from the
            list length (3 columns reproduces the original 2x3 layout).

    Raises:
        ValueError: If ``transforms_list`` is empty or ``ncols`` < 1.
    """
    if transforms_list is None:
        transforms_list = transforms_to_show
    if not transforms_list or ncols < 1:
        raise ValueError("Need at least one transform and ncols >= 1.")

    # Identical preprocessing for every panel, so build it once.
    normalize = v2.Compose(
        [
            v2.Resize((processor.size["height"], processor.size["width"])),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=image_mean, std=image_std),
        ]
    )
    mean = torch.tensor(image_mean).view(3, 1, 1)
    std = torch.tensor(image_std).view(3, 1, 1)

    nrows = (len(transforms_list) + ncols - 1) // ncols  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    panels = axes.ravel()

    # Each transform takes the next panel in row-major order.
    for panel, transform in zip(panels, transforms_list):
        augmented = image if transform is None else transform(image)
        # Undo Normalize for display; clamp absorbs float round-off outside [0, 1].
        display_image = (normalize(augmented) * std + mean).clamp(0.0, 1.0)
        title = "No Transform" if transform is None else type(transform).__name__
        panel.imshow(v2.functional.to_pil_image(display_image))
        panel.set_title(f"{title}\nClass: {label['class_name']}")
        panel.grid(False)
        panel.axis('off')
    # Blank out unused grid cells when the list does not fill the last row.
    for panel in panels[len(transforms_list):]:
        panel.axis('off')

    fig.suptitle('Sample Image with Various Transforms', fontsize=16)
    fig.tight_layout()
    fig.savefig(FIGURES_DIR / "data_augmentations.pdf", bbox_inches='tight')
    plt.show()


# Hand-picked test-split example used for the augmentation figure.
image, label = test[4264]
show_transforms(image, label)

In [None]:
# gather sizes of all images in the training set
# Gather (height, width) of every training image; the indexing assumes
# channel-first images (C, H, W).
image_sizes = [(img.shape[1], img.shape[2]) for img, _ in train]
assert image_sizes, "training split is empty"

# Tuple min/max compares lexicographically (height first, width only breaks
# ties), so (500, 2000) would count as "smaller" than (501, 501). Rank by
# pixel count so min/max report the genuinely smallest/largest image.
def _pixel_count(size):
    return size[0] * size[1]

print(f"Min image size (height, width): {min(image_sizes, key=_pixel_count)}")
print(f"Max image size (height, width): {max(image_sizes, key=_pixel_count)}")

# Per-dimension extremes (these may come from different images).
image_heights = [height for height, _ in image_sizes]
image_widths = [width for _, width in image_sizes]
min_height, max_height = min(image_heights), max(image_heights)
min_width, max_width = min(image_widths), max(image_widths)
print(f"Min image height: {min_height}, Max image height: {max_height}")
print(f"Min image width: {min_width}, Max image width: {max_width}")

In [None]:
from collections import defaultdict
def plot_image_color_distribution(datasets: list[FlowersDataset]):
    """
    Plot per-class HSV colour statistics pooled over the given FlowersDatasets.

    Each image is converted with OpenCV's ``COLOR_RGB2HSV_FULL`` (all three
    channels on a 0-255 scale for 8-bit input) and summarised by its mean hue,
    saturation and value. Per class, the mean and the standard deviation of
    those per-image summaries are drawn as two heatmaps, classes ordered by
    mean hue.

    Hue is an angle (0 and 255 are neighbouring reds), so it is averaged and
    its spread measured with circular statistics; an arithmetic mean would put
    a red flower straddling the wrap-around near cyan. Saturation and value use
    the ordinary mean / std.

    Args:
        datasets: The datasets containing flower images and labels. Iterating
            each must yield ``(image, label)`` pairs with ``label["class_name"]``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(color_df, color_std_df)`` with rows
        ``hue`` / ``saturation`` / ``value`` and one column per class, sorted by
        mean hue.

    Raises:
        ValueError: If the datasets contain no images.
    """
    # OpenCV's *_FULL HSV conversions spread hue over [0, 256) for 8-bit input.
    hue_period = 256.0
    to_radians = 2.0 * np.pi / hue_period

    def _circular_mean(hues) -> float:
        """Circular mean of hue values, expressed back in hue units."""
        angles = np.asarray(hues, dtype=np.float64) * to_radians
        mean_angle = np.arctan2(np.sin(angles).mean(), np.cos(angles).mean())
        return float(np.mod(mean_angle, 2.0 * np.pi) / to_radians)

    def _circular_std(hues) -> float:
        """Circular standard deviation sqrt(-2 ln R), in hue units."""
        angles = np.asarray(hues, dtype=np.float64) * to_radians
        resultant = np.hypot(np.sin(angles).mean(), np.cos(angles).mean())
        # R can exceed 1 by round-off, and R -> 0 would give log(0).
        resultant = np.clip(resultant, np.finfo(np.float64).tiny, 1.0)
        return float(np.sqrt(2.0 * np.log(1.0 / resultant)) / to_radians)

    def _image_hsv_means(img) -> tuple[float, float, float]:
        """Mean hue (circular), saturation and value of one image."""
        # convert("RGB") makes greyscale/RGBA inputs 3-channel, as the
        # RGB -> HSV conversion requires.
        rgb = np.array(F.to_pil_image(img).convert("RGB"))
        hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV_FULL)
        return _circular_mean(hsv[:, :, 0]), hsv[:, :, 1].mean(), hsv[:, :, 2].mean()

    # class name -> per-image mean hue / saturation / value lists
    class_color_means = defaultdict(lambda: {'hue': [], 'saturation': [], 'value': []})
    for dataset in datasets:
        for img, label in dataset:
            hue, saturation, value = _image_hsv_means(img)
            stats = class_color_means[label['class_name']]
            stats['hue'].append(hue)
            stats['saturation'].append(saturation)
            stats['value'].append(value)

    if not class_color_means:
        raise ValueError("No images found in the given datasets.")

    # Classes as columns, HSV channels as rows.
    color_df = pd.DataFrame({
        class_name: {
            'hue': _circular_mean(values['hue']),
            'saturation': np.mean(values['saturation']),
            'value': np.mean(values['value']),
        }
        for class_name, values in class_color_means.items()
    })
    # Spread of the per-image summaries within each class.
    color_std_df = pd.DataFrame({
        class_name: {
            'hue': _circular_std(values['hue']),
            'saturation': np.std(values['saturation']),
            'value': np.std(values['value']),
        }
        for class_name, values in class_color_means.items()
    })

    # Order classes by mean hue for a smoother heatmap.
    sorted_classes = color_df.loc['hue'].sort_values().index
    color_df = color_df[sorted_classes]
    color_std_df = color_std_df[sorted_classes]

    fig, axes = plt.subplots(2, 1, figsize=(20, 10))

    # Heatmap 1: inter-class comparison of raw (un-normalised) means.
    sns.heatmap(color_df, ax=axes[0], cmap='viridis', cbar_kws={'label': 'Mean (0-255 scale)'})
    axes[0].set_title('Inter-Class: Mean HSV Values (sorted by Hue)', fontsize=14)
    axes[0].set_xlabel('Flower Class')
    axes[0].set_ylabel('HSV Channel')
    axes[0].tick_params(axis='x', rotation=90, labelsize=8)

    # Heatmap 2: intra-class spread (raw, un-normalised std devs).
    sns.heatmap(color_std_df, ax=axes[1], cmap='magma', cbar_kws={'label': 'Std Dev'})
    axes[1].set_title('Intra-Class: HSV Variance (Std Dev) - Brighter = More Variation Within Class', fontsize=14)
    axes[1].set_xlabel('Flower Class')
    axes[1].set_ylabel('HSV Channel')
    axes[1].tick_params(axis='x', rotation=90, labelsize=8)

    fig.tight_layout()
    plt.show()

    return color_df, color_std_df

color_df, color_std_df = plot_image_color_distribution([train, val, test])
color_df

In [None]:
def plot_hsv_boxplots(datasets: list[FlowersDataset]):
    """
    Plot boxplots of per-image mean HSV values, one box per split.

    Each image is converted with OpenCV's ``COLOR_RGB2HSV_FULL`` and reduced
    to its mean hue, saturation and value. Hue is an angle on [0, 256), so its
    per-image mean is circular (reds near 0 and 255 average to red, not cyan).
    The figure is shown and saved as ``DISTRIBUTIONS_DIR / 'hsv_boxplots.pdf'``.

    Args:
        datasets: List of FlowersDataset objects to compare. Iterating each must
            yield ``(image, label)`` pairs; ``dataset.split`` labels the box.

    Raises:
        ValueError: If the datasets contain no images.
    """
    # OpenCV's *_FULL HSV conversions spread hue over [0, 256) for 8-bit input.
    hue_period = 256.0
    to_radians = 2.0 * np.pi / hue_period

    def _circular_mean_hue(hues) -> float:
        """Circular mean of hue values, expressed back in hue units."""
        angles = np.asarray(hues, dtype=np.float64) * to_radians
        mean_angle = np.arctan2(np.sin(angles).mean(), np.cos(angles).mean())
        return float(np.mod(mean_angle, 2.0 * np.pi) / to_radians)

    records = []
    for dataset in datasets:
        split_name = dataset.split.capitalize()
        for img, _ in dataset:
            # convert("RGB") makes greyscale/RGBA inputs 3-channel for cvtColor.
            rgb = np.array(F.to_pil_image(img).convert("RGB"))
            hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV_FULL)
            records.append({
                'split': split_name,
                'hue': _circular_mean_hue(hsv[:, :, 0]),
                'saturation': hsv[:, :, 1].mean(),
                'value': hsv[:, :, 2].mean(),
            })

    df = pd.DataFrame(records)
    if df.empty:
        raise ValueError("No images found in the given datasets.")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, channel in zip(axes, ('hue', 'saturation', 'value')):
        pretty = channel.capitalize()
        # hue duplicates x purely for colouring; the legend would be redundant.
        sns.boxplot(data=df, x='split', y=channel, hue='split', palette='Set2',
                    legend=False, ax=ax)
        ax.set_title(f'{pretty} by Split', fontsize=14)
        ax.set_xlabel('Split')
        ax.set_ylabel(f'Mean {pretty}')

    fig.suptitle('HSV Distribution Across Splits', fontsize=16)
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / 'hsv_boxplots.pdf', bbox_inches='tight')
    plt.show()

plot_hsv_boxplots([train, val, test])