In [None]:
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
from aml4cv import FIGURES_DIR
from aml4cv.train import get_data_transforms, get_model_and_processor
from aml4cv.dataset import FlowersDataset
import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F
import numpy as np

DISTRIBUTIONS_DIR = FIGURES_DIR / 'distributions'
DISTRIBUTIONS_DIR.mkdir(parents=True, exist_ok=True)

def plot_flower_distribution(dataset: FlowersDataset):
    """
    Plots the distribution of flower classes in the given FlowersDataset.

    Prints the minimum, maximum and median number of images per class, then
    draws two bar charts of the per-class counts with the classes ordered by
    ascending count: one with numeric class positions on the x-axis and one
    with the class names as tick labels. Both figures are saved as PDFs in
    ``DISTRIBUTIONS_DIR`` (file names start with ``dataset.split``) and shown.

    Args:
        dataset (FlowersDataset): The dataset containing flower images and labels.
            Iterating it must yield ``(image, label)`` pairs whose ``label``
            has a ``"class_name"`` entry.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    # Local import: nothing else in the notebook needs it.
    from statistics import median

    # Extract labels from the dataset (this walks over every sample once)
    class_names = [label["class_name"] for _, label in dataset]
    if not class_names:
        raise ValueError(
            f"Cannot plot class distribution: the '{dataset.split}' split is empty."
        )

    # Count occurrences of each class, ordered from rarest to most frequent
    label_counts = sorted(Counter(class_names).items(), key=lambda item: item[1])
    labels, counts = zip(*label_counts)
    print("Flower class distribution:")
    print(f"Min class count: {min(counts)}")
    print(f"Max class count: {max(counts)}")
    # statistics.median averages the two middle values when the number of
    # classes is even instead of silently reporting the upper one.
    print(f"Median class count: {median(counts)}")

    # The colour legend is only drawn for the test split.
    show_legend = dataset.split == "test"

    # without labels
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(
        x=range(len(labels)),
        y=counts,
        hue=counts,
        palette="crest",
        legend=show_legend,
        ax=ax,
    )
    # Keep only every 10th class position as a tick so the axis stays readable.
    ax.set_xticks(range(0, len(labels), 10))
    ax.tick_params(axis='both', labelsize=15, length=6, width=2)
    ax.set_title('Flower Class Distribution', fontsize=25)
    ax.set_xlabel('Flower Class', fontsize=20)
    ax.set_ylabel('Number of Images', fontsize=20)
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / f'{dataset.split}_no_labels.pdf', bbox_inches='tight')
    plt.show()

    # with labels
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(
        x=labels,
        y=counts,
        hue=counts,
        palette="crest",
        legend=show_legend,
        ax=ax,
    )
    # Rotate the class names so neighbouring labels do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title('Flower Class Distribution')
    ax.set_xlabel('Flower Class')
    ax.set_ylabel('Number of Images')
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / f'{dataset.split}_with_labels.pdf', bbox_inches='tight')
    plt.show()


In [None]:
# Instantiate the three dataset splits and report how many samples each one holds.
train, val, test = (
    FlowersDataset(split=split_name) for split_name in ("train", "val", "test")
)
for heading, subset in (("Train", train), ("Validation", val), ("Test", test)):
    print(f'{heading} set size: {len(subset)}')
print(f"Total dataset size: {sum(len(subset) for subset in (train, val, test))}")

In [None]:
# min, max image sizes across the train, validation and test splits
# ``img.size()`` unpacks into three values; the images are treated as
# channel-first tensors elsewhere in this notebook (shape[1] = height,
# shape[2] = width), so height has to be unpacked before width.
image_sizes = [img.size() for split in (train, val, test) for img, _ in split]
_channels, heights, widths = zip(*image_sizes)
print(f"Min image width: {min(widths)}, height: {min(heights)}")
print(f"Max image width: {max(widths)}, height: {max(heights)}")
print(f"Average image width: {sum(widths)/len(widths):.2f}, height: {sum(heights)/len(heights):.2f}")

In [None]:
# Class balance of the training split (prints summary counts, saves both PDFs to DISTRIBUTIONS_DIR).
plot_flower_distribution(train)

In [None]:
# Class balance of the test split (this is the only split for which the legend is drawn).
plot_flower_distribution(test)

In [None]:
# Class balance of the validation split (prints summary counts, saves both PDFs to DISTRIBUTIONS_DIR).
plot_flower_distribution(val)

In [None]:
# Load the model and its image processor on CPU; only the processor's
# preprocessing settings (input size, mean, std) are used for the figures below.
model_id = "google/vit-large-patch16-224-in21k"
model, processor = get_model_and_processor(model_id, "cpu")


def _per_channel_stats(stat, channels=3):
    """Returns ``stat`` as a list with one entry per channel.

    A list or tuple is taken to be per-channel already and is returned as a
    list; any other value is treated as a scalar and repeated ``channels`` times.
    """
    if isinstance(stat, (list, tuple)):
        return list(stat)
    return [stat] * channels


image_mean = _per_channel_stats(processor.image_mean)
image_std = _per_channel_stats(processor.image_std)



In [None]:

# Augmentations previewed by show_transforms, one panel per entry (in this order).
transforms_to_show = [
    # No augmentation: show_transforms applies only its resize/normalize step.
    None,
    # 10-pixel border on every side, constant padding mode.
    v2.Pad(padding=10, padding_mode="constant"),
    # Rotation angle drawn from the full [-180, 180] degree range.
    v2.RandomRotation(degrees=[-180, 180]),
    # p=1.0, so the vertical flip is always applied.
    v2.RandomVerticalFlip(p=1.0),
    # 19x19 kernel; sigma is drawn from the range [5.0, 10.0].
    v2.GaussianBlur(kernel_size=(19, 19), sigma=(5.0, 10.0)),
    # Central crop of size 380.
    v2.CenterCrop(380),
]
def show_transforms(image, label):
    """
    Displays one sample image under every entry of ``transforms_to_show``.

    Each panel shows the image after the (optional) augmentation followed by
    the model preprocessing (resize to the processor's input size, conversion
    to float, normalization); the normalization is then undone so the colours
    are viewable. Panels are laid out three per row, the figure is saved to
    ``FIGURES_DIR / "data_augmentations.pdf"`` and shown.

    Relies on the notebook-level ``processor``, ``image_mean``, ``image_std``
    and ``transforms_to_show``.

    Args:
        image: Image accepted by the torchvision v2 transforms
            (presumably a channel-first tensor -- confirm against FlowersDataset).
        label: Mapping with a ``"class_name"`` entry, used in the panel titles.
    """
    image_width = processor.size["width"]
    image_height = processor.size["height"]

    # The preprocessing pipeline is the same for every panel, so build it once.
    normalize = v2.Compose(
        [
            v2.Resize((image_height, image_width)),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=image_mean, std=image_std),
        ]
    )
    # Per-channel statistics reshaped to broadcast over (channels, height, width).
    std = torch.tensor(image_std).view(-1, 1, 1)
    mean = torch.tensor(image_mean).view(-1, 1, 1)

    # Three panels per row; the number of rows follows the number of transforms.
    n_cols = 3
    n_rows = max(1, (len(transforms_to_show) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    panels = list(axes.flat)

    for panel, transform in zip(panels, transforms_to_show):
        augmented = image if transform is None else transform(image)
        normalized = normalize(augmented)
        # Unnormalize the image for visualization. Clamp to [0, 1]: floating-point
        # round-off can leave values marginally outside the valid range before
        # the conversion to an 8-bit image.
        restored = (normalized * std + mean).clamp(0.0, 1.0)
        image_pil = F.to_pil_image(restored)
        title = "No Transform" if transform is None else transform.__class__.__name__
        panel.imshow(image_pil)
        panel.set_title(f"{title}\nClass: {label['class_name']}", fontsize=20)

    # Hide grid and axes everywhere, including panels left empty by the layout.
    for panel in panels:
        panel.grid(False)
        panel.axis('off')

    fig.suptitle('Sample Image with Various Transforms', fontsize=25)
    fig.tight_layout()
    fig.savefig(FIGURES_DIR / "data_augmentations.pdf", bbox_inches='tight')
    plt.show()

# Seed torch's RNG so the randomly parameterised transforms (rotation angle,
# blur sigma) give the same augmentation figure on every run.
torch.manual_seed(42)
SAMPLE_INDEX = 4264  # test-set sample used for the augmentation figure
image, label = test[SAMPLE_INDEX]
show_transforms(image, label)

In [None]:
# gather (height, width) of all images in the training set
image_sizes = [(img.shape[1], img.shape[2]) for img, _ in train]

# Rank whole images by pixel area. Comparing the (height, width) tuples directly
# would be lexicographic: height first, width only as a tie-breaker.
print(f"Min image size (height, width): {min(image_sizes, key=lambda hw: hw[0] * hw[1])}")
print(f"Max image size (height, width): {max(image_sizes, key=lambda hw: hw[0] * hw[1])}")

# per-dimension extremes (may come from different images)
image_heights = [height for height, _ in image_sizes]
image_widths = [width for _, width in image_sizes]
min_height, max_height = min(image_heights), max(image_heights)
min_width, max_width = min(image_widths), max(image_widths)
print(f"Min image height: {min_height}, Max image height: {max_height}")
print(f"Min image width: {min_width}, Max image width: {max_width}")

In [None]:
def _circular_mean_hue(hue_channel: np.ndarray) -> float:
    """Returns the circular mean of an OpenCV 8-bit hue channel, on the 0-180 scale."""
    # OpenCV stores 8-bit hue as degrees / 2 (0..179), so one unit is pi / 90 radians.
    angles = hue_channel.astype(np.float64) * (np.pi / 90.0)
    mean_angle = np.arctan2(np.sin(angles).mean(), np.cos(angles).mean())
    # arctan2 returns (-pi, pi]; wrap to [0, 2*pi) before converting back.
    return float(np.mod(mean_angle, 2.0 * np.pi) * (90.0 / np.pi))


def plot_hsv_boxplots(datasets: list[FlowersDataset]):
    """
    Plots boxplots of HSV values comparing different splits.

    Every image is converted to HSV with OpenCV and reduced to three numbers:
    mean hue, mean saturation and mean value. Hue is an angle that wraps
    around (reds sit at both ends of the scale), so it is averaged with a
    circular mean; saturation and value use the ordinary mean. One boxplot
    per channel is drawn, grouped by split, and the figure is saved to
    ``DISTRIBUTIONS_DIR / 'hsv_boxplots.pdf'`` and shown.

    Args:
        datasets: List of FlowersDataset objects to compare.

    Raises:
        ValueError: If the datasets yield no images at all.
    """
    records = []
    for dataset in datasets:
        split_name = dataset.split.capitalize()
        for img, _ in dataset:
            rgb = np.array(F.to_pil_image(img))
            hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
            records.append({
                'split': split_name,
                'hue': _circular_mean_hue(hsv[:, :, 0]),
                'saturation': hsv[:, :, 1].mean(),
                'value': hsv[:, :, 2].mean()
            })

    if not records:
        raise ValueError("No images found in the given datasets; nothing to plot.")

    df = pd.DataFrame(records)

    # (column, panel title, y-axis label, upper y-limit). Fixed limits cover the
    # full 8-bit OpenCV range of each channel: hue 0-180, saturation/value 0-255.
    channels = [
        ('hue', 'Hue by Split', 'Mean Hue', 180),
        ('saturation', 'Saturation by Split', 'Mean Saturation', 255),
        ('value', 'Value by Split', 'Mean Value', 255),
    ]

    fig, axs = plt.subplots(1, len(channels), figsize=(15, 5))
    for ax, (column, title, ylabel, upper_limit) in zip(axs, channels):
        ax.set_ylim(0, upper_limit)
        sns.boxplot(data=df, x='split', y=column, ax=ax, hue="split", palette='Set2')
        ax.set_title(title, fontsize=14)
        ax.set_xlabel('Split', fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.tick_params(axis='both',  labelsize=12)

    fig.suptitle('HSV Distribution Across Splits', fontsize=16)
    fig.tight_layout()
    fig.savefig(DISTRIBUTIONS_DIR / 'hsv_boxplots.pdf', bbox_inches='tight')
    plt.show()

# Compare per-image HSV statistics across the three splits (iterates every image of each split).
plot_hsv_boxplots([train, val, test])