# Animals with Attributes 2

In [None]:
import os
import sys

import quanproto.datasets.config_parser as quan_dataloader
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from quanproto.eda import eda
from quanproto.utils.workspace import *

In [None]:
# Notebook-wide settings: the dataset root directory and the dataset to load.
config = dict(dataset_dir=DATASET_DIR, dataset="awa2")

# Figure size (width, height) passed as `figsize` to every plot in this notebook.
figure_width, figure_height = 8, 5

# Create the Dataset

In [None]:
# Load the dataset named in `config` from the configured dataset directory.
dataset = quan_dataloader.get_dataset(
    config["dataset_dir"],
    config["dataset"],
)

# Only create splits when the dataset does not have any yet.
# NOTE(review): seed=42 is fixed, presumably so the splits are reproducible.
if not dataset.has_splits():
    dataset.split_dataset(
        k=4,
        seed=42,
        shuffle=True,
        stratified=True,
        train_size=0.7,
    )

# Test-split metadata; later indexed by "paths", "ids" and "labels".
test_info = dataset.test_info()


In [None]:
# Output directory for this experiment; created up front (no error if it already exists).
log_dir = os.path.join(WORKSPACE_PATH, "experiments/awa2/all")
os.makedirs(log_dir, exist_ok=True)

# Make all image statistics
# Values of dataset.sample_labels() as one array, in the mapping's iteration order.
sample_class_labels = np.array(list(dataset.sample_labels().values()))
# Alias kept so both names stay available. The array used to be built twice
# from two identical dataset.sample_labels() calls; both names now share one array.
sample_labels = sample_class_labels
sample_dir = dataset.sample_dir()

In [None]:
# Class Statistics
# Summary computed by eda.class_statistics from the per-sample class labels.
overview = eda.class_statistics(sample_class_labels)
# Bare trailing expression: the notebook renders the value as the cell output.
overview

In [None]:
# Class Distribution
# NOTE(review): presumably counts[i] is the number of samples whose class label
# is labels[i] -- confirm against eda.class_histogram.
counts, labels = eda.class_histogram(sample_class_labels)
fig, ax = plt.subplots(figsize=(figure_width, figure_height))
ax.plot(labels, counts)
# Title and axis captions so the figure is readable on its own, in line with
# the labelled color-histogram figure further down.
ax.set_title("Class Distribution")
ax.set_xlabel("Class Label")
ax.set_ylabel("Number of Samples")
plt.show()

In [None]:
# Image Histograms
# Histogram data for the images under sample_dir; rows 0-3 are plotted below
# as the red, green, blue and exposure curves.
counts, vals = eda.color_histogram(sample_dir)

# Divide every row by its own total so each curve sums to one.
row_totals = np.sum(counts, axis=1, keepdims=True)
norm_counts = counts / row_totals

# Line colour and legend entry for each of the four plotted rows.
colors = ["red", "green", "blue", "black"]
labels = ["Red", "Green", "Blue", "Exposure"]

_, ax = plt.subplots(figsize=(figure_width, figure_height))
for i, (color, label) in enumerate(zip(colors, labels)):
    ax.plot(vals, norm_counts[i], color=color, label=label)

ax.set_xlabel("Pixel Value")
ax.set_ylabel("Frequency")
ax.set_title("Color Histograms")
ax.legend()
plt.show()

In [None]:
# Image Statistics
# Derived by eda.color_statistics from the histogram counts and values of the previous cell.
statistics = eda.color_statistics(counts, vals)
# Bare trailing expression: the notebook renders the value as the cell output.
statistics

In [None]:
# Show one test sample: print and export its active predicates, then draw the image.
k = 6010  # index into the test_info lists

# Full path of the image file inside the test directory.
test_root = dataset.test_dirs()["test"]
image = os.path.join(test_root, test_info["paths"][k])

# Sample id -> class label -> class name.
sample_id = test_info["ids"][k]
class_label = dataset.sample_labels()[sample_id]
class_name = dataset.class_names()[class_label]

# Keep the names of the predicates whose entry for this sample equals 1.
labels = test_info["labels"][k]
predicate_names = dataset.predicate_names()
rows = [
    name
    for position, name in enumerate(predicate_names.values())
    if labels[position] == 1
]

# One-column table headed by the class name.
stat = pd.DataFrame(rows, columns=[class_name])
print(stat)

# Save the dataframe as a LaTeX table.
# NOTE(review): the file name is hard-coded and written to the current working
# directory regardless of class_name, while log_dir is never used -- confirm.
stat.to_latex("zebra.tex", index=False)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(figure_width, figure_height))
pixels = plt.imread(image)
ax.imshow(pixels)
# Title is left off; to enable it: ax.set_title(class_name)
ax.axis("off")

# Tighten the layout around the image.
plt.tight_layout()