In [None]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project package are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:

from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from aml4cv.dataset import FlowersDataset
from aml4cv.constants import FIGURES_DIR

def plot_flower_distribution(dataset: FlowersDataset):
    """
    Plot and save the per-class image counts of a FlowersDataset split.

    Prints the minimum, maximum and median class count, draws a bar chart with
    the classes ordered from least to most frequent, writes the figure to
    ``FIGURES_DIR / 'flower_class_distribution_<split>.pdf'`` and shows it.

    Args:
        dataset (FlowersDataset): The dataset to summarise. Iterating it must
            yield ``(image, label)`` pairs where ``label["class_name"]`` is the
            class name, and it must expose a ``split`` attribute.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    # Local import keeps this cell self-contained.
    from statistics import median

    # Extract the class name of every sample
    class_names = [label["class_name"] for _, label in dataset]
    if not class_names:
        raise ValueError("Cannot plot the class distribution of an empty dataset.")

    # Order classes from rarest to most common so the bars rise left to right
    ordered = sorted(Counter(class_names).items(), key=lambda item: item[1])
    labels, counts = zip(*ordered)
    print("Flower class distribution:")
    print(f"Min class count: {min(counts)}")
    print(f"Max class count: {max(counts)}")
    # median() averages the two middle values when the number of classes is
    # even, instead of silently reporting the upper one.
    print(f"Median class count: {median(counts)}")

    # Bar chart coloured by count
    fig, ax = plt.subplots(figsize=(20, 5))
    sns.barplot(x=labels, y=counts, hue=counts, palette="crest", ax=ax)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title('Flower Class Distribution')
    ax.set_xlabel('Flower Class')
    ax.set_ylabel('Number of Images')
    # NOTE(review): the legend is only requested explicitly for the test
    # split; seaborn may still add its own hue legend for the other splits —
    # confirm which behaviour is intended.
    if dataset.split == "test":
        ax.legend()
    fig.tight_layout()
    fig.savefig(FIGURES_DIR / f'flower_class_distribution_{dataset.split}.pdf')
    plt.show()

In [None]:
# Instantiate the three dataset splits; the following cells refer to them as
# `train`, `val` and `test`.
train, val, test = (FlowersDataset(split=name) for name in ("train", "val", "test"))

# Report the size of each split and of the dataset as a whole.
n_train, n_val, n_test = len(train), len(val), len(test)
print(f'Train set size: {n_train}')
print(f'Validation set size: {n_val}')
print(f'Test set size: {n_test}')
print(f"Total dataset size: {n_train + n_val + n_test}")

In [None]:
# Class balance of the training split (also saved as a PDF under FIGURES_DIR).
plot_flower_distribution(train)

In [None]:
# Class balance of the test split (also saved as a PDF under FIGURES_DIR).
plot_flower_distribution(test)

In [None]:
# Class balance of the validation split (also saved as a PDF under FIGURES_DIR).
plot_flower_distribution(val)

In [None]:
from aml4cv.train import get_data_transforms, get_model_and_processor
def _per_channel_stats(value, num_channels=3):
    """Return ``value`` as a list with one entry per image channel.

    A list or tuple is taken to be per-channel already and is copied into a
    list; anything else is treated as a single scalar and repeated
    ``num_channels`` times.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value] * num_channels


model_id = "google/vit-large-patch16-224-in21k"
# "cpu" is presumably the target device — confirm against
# get_model_and_processor. Only `processor` is read in the rest of this cell.
model, processor = get_model_and_processor(model_id, "cpu")

# Normalisation statistics of the processor, always as per-channel lists.
image_mean = _per_channel_stats(processor.image_mean)
image_std = _per_channel_stats(processor.image_std)



In [None]:
import torch
from torchvision.transforms import v2
from aml4cv import FIGURES_DIR
# Transforms previewed by show_transforms, one subplot each. `None` is the
# untransformed baseline (titled "No Transform").
# NOTE(review): no seed is set in this cell, so the randomly drawn rotation
# angle and blur sigma change from run to run.
transforms_to_show = [
    None,
    v2.Pad(padding=10, padding_mode="constant"),
    v2.RandomRotation(degrees=[-180, 180]),  # angle drawn from the full circle
    v2.RandomVerticalFlip(p=1.0),  # p=1.0: the flip is always applied
    v2.GaussianBlur(kernel_size=(19, 19), sigma=(5.0, 10.0)),  # sigma drawn from this range
    v2.CenterCrop(380),
]
def show_transforms(image, label):
    """
    Show one sample image under every entry of ``transforms_to_show``.

    Each entry is applied to ``image`` (``None`` means no transform); the
    result is resized and normalised with the processor's size, mean and std,
    then un-normalised again for display. Panels are laid out three per row,
    and the figure is saved to
    ``FIGURES_DIR / "transforms_visualization.png"`` and shown.

    Args:
        image: The sample image — presumably a (3, H, W) tensor as returned by
            FlowersDataset; TODO confirm.
        label: Mapping whose ``"class_name"`` entry is used in the panel titles.
    """
    image_width = processor.size["width"]
    image_height = processor.size["height"]

    # The preprocessing pipeline and the de-normalisation constants do not
    # depend on the transform being previewed, so build them once.
    normalize = v2.Compose(
        [
            v2.Resize((image_height, image_width)),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=image_mean, std=image_std),
        ]
    )
    std = torch.tensor(image_std).view(3, 1, 1)
    mean = torch.tensor(image_mean).view(3, 1, 1)

    # Size the grid from the list so that adding a transform cannot index past
    # the available axes; six entries still give the 2x3, 15x10 inch figure.
    n_cols = 3
    n_rows = max(1, -(-len(transforms_to_show) // n_cols))  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    panels = axes.ravel()

    # Pair panels with transforms positionally instead of looking each
    # transform up with list.index(), which is quadratic and would pick the
    # wrong panel if the same object appeared twice.
    for panel, transform in zip(panels, transforms_to_show):
        if transform is None:
            title = "No Transform"
            normalized_image = normalize(image)
        else:
            title = transform.__class__.__name__
            normalized_image = normalize(transform(image))

        # Undo the normalisation for display. Floating-point round-off can
        # leave values marginally outside [0, 1], so clamp before the
        # conversion to an 8-bit image.
        unnormalized_image = (normalized_image * std + mean).clamp(0, 1)
        image_pil = v2.functional.to_pil_image(unnormalized_image)

        panel.imshow(image_pil)
        panel.set_title(f"{title}\nClass: {label['class_name']}")
        panel.grid(False)
        panel.axis('off')

    # Hide panels left over when the list does not fill the last row.
    for panel in panels[len(transforms_to_show):]:
        panel.axis('off')

    fig.suptitle('Sample Image with Various Transforms', fontsize=16)
    fig.tight_layout()
    fig.savefig(FIGURES_DIR / "transforms_visualization.png", dpi=300, bbox_inches='tight')
    plt.show()


# Index of the training sample used for the transform preview figure.
SAMPLE_INDEX = 4264
image, label = train[SAMPLE_INDEX]
show_transforms(image, label)

In [None]:
# Gather the (height, width) of every image in the training set. This walks
# the whole split once. Assumes images are laid out as
# (channels, height, width), as the printed labels imply.
image_sizes = [(img.shape[1], img.shape[2]) for img, _ in train]


def _pixel_area(size):
    """Return height * width for a (height, width) pair."""
    return size[0] * size[1]


# Comparing (height, width) tuples directly is lexicographic (height first,
# width only as a tie-breaker), so rank whole images by pixel area instead.
print(f"Min image size (height, width): {min(image_sizes, key=_pixel_area)}")
print(f"Max image size (height, width): {max(image_sizes, key=_pixel_area)}")

# Per-dimension extremes, each taken independently of the other dimension.
image_heights = [height for height, _ in image_sizes]
image_widths = [width for _, width in image_sizes]
min_height, max_height = min(image_heights), max(image_heights)
min_width, max_width = min(image_widths), max(image_widths)
print(f"Min image height: {min_height}, Max image height: {max_height}")
print(f"Min image width: {min_width}, Max image width: {max_width}")

In [None]:
# plot distribution of image colors
def plot_image_color_distribution(dataset: FlowersDataset):
    """
    Draw kernel-density curves of the per-image mean of each colour channel.

    Args:
        dataset (FlowersDataset): Iterable of (image, label) pairs; channels
            0, 1 and 2 of every image are read as red, green and blue.
    """
    # (column name, channel index, matplotlib colour) for every curve
    channel_specs = (("Red", 0, "r"), ("Green", 1, "g"), ("Blue", 2, "b"))
    channel_means = {name: [] for name, _idx, _color in channel_specs}

    # Single pass over the dataset: one mean per channel per image
    for img, _label in dataset:
        for name, channel_idx, _color in channel_specs:
            channel_means[name].append(
                img[channel_idx, :, :].mean(dtype=torch.float32).item()
            )

    # Collect the means in a DataFrame, one column per channel
    import pandas as pd
    color_data = pd.DataFrame(channel_means)

    # One filled density curve per channel on a shared figure
    plt.figure(figsize=(12, 6))
    for name, _idx, color in channel_specs:
        sns.kdeplot(data=color_data, x=name, fill=True, legend=True, label=name, color=color)

    plt.title('Distribution of Image Colors')
    plt.xlabel('Color Intensity')
    plt.ylabel('Density')
    plt.legend(title='Color Channel')
    plt.tight_layout()
    plt.show()
# Per-channel mean colour densities for the training split.
plot_image_color_distribution(train)

In [None]:
# Per-channel mean colour densities for the test split.
plot_image_color_distribution(test)

In [None]:
# Per-channel mean colour densities for the validation split.
plot_image_color_distribution(val)