In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random

# Load "raw_all_data.pkl" (path is relative to the current working directory)
# into raw_df; the cells below group it by its 'robot', 'domain' and
# 'image_ref' columns.
# NOTE(security): pd.read_pickle unpickles arbitrary Python objects and can
# execute code embedded in the file -- only load pickles from a trusted source.
raw_df = pd.read_pickle("raw_all_data.pkl")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Seaborn's white-grid theme; this is a global setting, so it also applies to
# figures drawn by later cells.
sns.set_style("whitegrid")

# Step 1: one row per (robot, domain, image_ref) group, holding how many rows
# of raw_df belong to that group.
image_keys = ['robot', 'domain', 'image_ref']
entries_per_image = (
    raw_df
    .groupby(image_keys)
    .size()
    .reset_index(name='num_entries')
)

# Step 2: for every (domain, num_entries) pair, how many groups from step 1
# have that entry count.
domain_counts = (
    entries_per_image
    .groupby(['domain', 'num_entries'])
    .size()
    .reset_index(name='count')
)

# Only the domains listed here are drawn, in this legend order.
domain_order = [
    'Home',
    'BigOffice-2',
    'BigOffice-3',
    'Hallway',
    'MeetingRoom',
    'SmallOffice',
]

fig, ax = plt.subplots(figsize=(12, 5))
sns.barplot(
    data=domain_counts,
    x='num_entries',
    y='count',
    hue='domain',
    hue_order=domain_order,
    palette='viridis',
    ax=ax,
)

# Write the integer count just above every bar that has a positive height;
# zero-height and NaN-height patches are skipped.
for bar in ax.patches:
    bar_height = bar.get_height()
    if not bar_height > 0:
        continue
    label_x = bar.get_x() + bar.get_width() / 2.
    label_y = bar_height + 0.5  # slightly above the bar
    ax.text(label_x, label_y, int(bar_height), ha='center', va='bottom', fontsize=11)

# The y-axis runs from 0 to the sum of every count in domain_counts, so each
# bar is shown relative to that overall total.
ax.set_ylim(0, domain_counts['count'].sum())

ax.set_xlabel('Number of Entries per Image')
ax.set_ylabel('Number of Images')
ax.set_title('Distribution of Number of Entries per Image by Domain')
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Duplicate analysis: for every (robot, domain, image_ref) group, count how
# many of its rows are exact copies of an earlier row (all columns compared),
# then plot how many groups have each duplicate count, split by domain.

# Every row that has at least one identical twin, first occurrence included.
# It is no longer used for the counts below; it is kept so the name retains
# its original meaning in case another cell reads it.
duplicated_rows = raw_df.duplicated(keep=False)

# Redundant copies only: the first occurrence of each set of identical rows is
# False, every later copy is True. Summing this per image gives the exact
# number of duplicates. The previous approach (sum of keep=False flags minus
# one) over-counted whenever one image held two or more distinct sets of
# identical rows, e.g. rows A, A, B, B were reported as 3 duplicates, not 2.
redundant_rows = raw_df.duplicated(keep='first')

# Identical rows share the same robot/domain/image_ref values, so every
# redundant copy lands in the same group as the row it duplicates.
duplicates_per_index = (
    redundant_rows
    .groupby([raw_df['robot'], raw_df['domain'], raw_df['image_ref']])
    .sum()
    .reset_index(name='num_duplicates')
)

# For every (domain, num_duplicates) pair, how many images have that count.
domain_counts = (
    duplicates_per_index
    .groupby(['domain', 'num_duplicates'])
    .size()
    .reset_index(name='count')
)


plt.figure(figsize=(12, 6))
ax = sns.barplot(
    data=domain_counts,
    x='num_duplicates',
    y='count',
    hue='domain',
    # Only the domains listed here are drawn, in this legend order.
    hue_order = ['Home', 'BigOffice-2', 'BigOffice-3', 'Hallway', 'MeetingRoom',
       'SmallOffice'],
    palette='viridis'
)

# Write the integer count just above every bar with a positive height.
for p in ax.patches:
    height = p.get_height()
    if height > 0:
        ax.text(
            p.get_x() + p.get_width() / 2.,
            height + 0.5,  # slightly above the bar
            int(height),
            ha='center',
            va='bottom',
            fontsize=11
        )

# Set y-axis max to the sum of all counts, i.e. the total number of images
# (robot/domain/image_ref groups), so each bar is shown relative to that total.
ax.set_ylim(0, domain_counts['count'].sum())

plt.xlabel('Number of Duplicates per Image')
plt.ylabel('Number of Images')
plt.title('Distribution of Number of Duplicates per Image by Domain')
plt.tight_layout()
plt.show()