What do I want to know?

- How many are placed in 'other'?
- per level:
    - number of sublevels
    - cumulative number of samples

In [None]:
from pathlib import Path
from collections import defaultdict

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from moths.label_hierarchy import label_hierarchy_from_file

In [None]:
# Select seaborn's "darkgrid" style for the plots.
sns.set_style("darkgrid")

In [None]:
# Inputs of the analysis (machine-specific absolute paths):
# - the CSV file the label hierarchy is read from;
# - the data folder, which is expected to hold one sub-directory per class.
label_hierarchy_path = Path("/home/vlinderstichting/Data/moths/data/family.csv")
data_source_path = Path("/home/vlinderstichting/Data/moths/artifacts/image_folder")

In [None]:
# Build the label hierarchy from the CSV file and the data folder.
# NOTE(review): the meaning of the literal 50 is not visible here; presumably
# a minimum-samples threshold -- confirm against label_hierarchy_from_file.
label_hierarchy = label_hierarchy_from_file(label_hierarchy_path, data_source_path, 50)

In [None]:
# label_hierarchy.name_map maps each klass to a (group, family, genus) tuple

In [None]:
# Number of entries found in each class directory, keyed by directory name.
# Top-level entries that are not directories are skipped: calling iterdir()
# on a regular file raises NotADirectoryError and would abort the whole cell.
class_counts = {
    class_dir.name: sum(1 for _ in class_dir.iterdir())
    for class_dir in data_source_path.iterdir()
    if class_dir.is_dir()
}

In [None]:
# Four nested levels: group -> family -> genus -> {klass: sample count}.
# The three outer levels are created on first access; the innermost level is
# a plain dict.
label_tree = defaultdict(
    lambda: defaultdict(
        lambda: defaultdict(dict)
    )
)

In [None]:
# Seed the catch-all leaf so it exists even when every class is mapped.
label_tree['Other']['Other']['Other']['Other'] = 0
for klass, count in class_counts.items():
    if klass not in label_hierarchy.name_map:
        # Classes unknown to the hierarchy are pooled into a single bucket.
        label_tree['Other']['Other']['Other']['Other'] += count
        continue

    # Known classes are filed under their (group, family, genus) path.
    group, family, genus = label_hierarchy.name_map[klass]
    label_tree[group][family][genus][klass] = count

In [None]:
def node_leaf_sum(node):
    """Return the total of all integer leaves reachable from ``node``.

    ``node`` is either an ``int`` leaf, which is returned unchanged, or a
    mapping whose values are themselves leaves or nested mappings.
    """
    if isinstance(node, int):
        return node

    total = 0
    pending = [node]  # sub-trees that still have to be walked
    while pending:
        for child in pending.pop().values():
            if isinstance(child, int):
                total += child
            else:
                pending.append(child)
    return total

In [None]:
# Total of all leaf counts in the whole tree.
node_leaf_sum(label_tree)

In [None]:
def node_species_dict(node):
    """Flatten the tree below ``node`` into a single ``{name: count}`` dict.

    Every entry whose value is an ``int`` is treated as a leaf and copied to
    the result; every other value is treated as a nested mapping and walked
    recursively.  Entries are visited depth first in insertion order, so a
    name that occurs more than once keeps the count seen last.

    Empty nodes contribute nothing.  They must be tolerated because a
    ``defaultdict`` tree grows an empty node on any read of a missing key.
    Each entry is classified on its own, so a node mixing leaves and
    sub-trees is handled as well.
    """
    def _collect(_node, _out):
        for name, child in _node.items():
            if isinstance(child, int):
                _out[name] = child
            else:
                _collect(child, _out)

    out = {}
    _collect(node, out)
    return out

In [None]:
# Flat {name: count} view of every leaf in the tree.
node_species_dict(label_tree)

In [None]:
def sort_dict_other_last(counts):
    """Return the items of ``counts`` sorted by value, with "Other" last.

    Args:
        counts: Mapping from name to a sortable value (a count, or a tuple of
            counts).  It is only read, never modified.

    Returns:
        A list of ``(name, value)`` pairs in descending value order.  Entries
        with equal values keep their original relative order.  If ``counts``
        has an "Other" key, that pair is appended at the end regardless of
        its value.
    """
    count_list = sorted(
        ((k, c) for k, c in counts.items() if k != "Other"),
        key=lambda t: t[1],
        reverse=True,
    )

    # Filter instead of deleting the key so the caller's dict stays intact.
    if "Other" in counts:
        count_list.append(("Other", counts["Other"]))

    return count_list

In [None]:
# All leaf counts, largest first, with the "Other" bucket at the end.
sort_dict_other_last(node_species_dict(label_tree))

In [None]:
def plot_count_list(count_list):
    """Draw a bar chart of ``(label, count)`` pairs in the order given.

    Opens a new 15x8 figure, draws one bar per pair and rotates the x tick
    labels by 90 degrees.
    """
    plt.figure(figsize=(15, 8))

    labels = [pair[0] for pair in count_list]
    heights = [pair[1] for pair in count_list]
    sns.barplot(x=labels, y=heights)

    plt.xticks(rotation=90)

In [None]:
# Bar chart of all leaf counts, largest first, "Other" last.
plot_count_list(sort_dict_other_last(node_species_dict(label_tree)))

In [None]:
# Sample count of every class whose family in the hierarchy is "Noctuidae".
noctuidae_classes = {
    name: class_counts[name]
    for name, (_, family_name, _) in label_hierarchy.name_map.items()
    if family_name == "Noctuidae"
}

In [None]:
def plot_count_dict(count_dict):
    """Plot ``count_dict`` as a bar chart, largest value first, "Other" last.

    Ordering and drawing are delegated to ``sort_dict_other_last`` and
    ``plot_count_list`` so the dict-based and list-based plots cannot drift
    apart.

    Args:
        count_dict: Mapping from label to count.  It is not modified.
    """
    # Hand over a shallow copy so the caller's dict is left untouched
    # whatever the sorting helper does with its argument.
    plot_count_list(sort_dict_other_last(dict(count_dict)))

In [None]:
# Number of classes in the Noctuidae selection.
len(noctuidae_classes)

In [None]:
# Bar chart of the per-class counts of the Noctuidae selection.
plot_count_dict(noctuidae_classes)

In [None]:
# Bar chart of the per-class counts of every class directory.
plot_count_dict(class_counts)

In [None]:
# Per group: (total sample count, number of distinct leaf names).
group_dicts = {
    name: (node_leaf_sum(subtree), len(node_species_dict(subtree)))
    for name, subtree in label_tree.items()
}

# Show the groups ordered by that tuple, descending, with "Other" last.
sort_dict_other_last(group_dicts)

In [None]:
# Per family: (total sample count, number of distinct leaf names).
# The tuple order matches the group-level and genus-level summaries, so all
# three levels are ranked by sample count first.
# Families are keyed by name only: if the same family name occurs under
# several groups, the one visited last wins.
family_dicts = {
    family: (node_leaf_sum(family_dict), len(node_species_dict(family_dict)))
    for group_dict in label_tree.values()
    for family, family_dict in group_dict.items()
}

# Show the families ordered by that tuple, descending, with "Other" last.
sort_dict_other_last(family_dicts)

In [None]:
# Per genus: (total sample count, number of distinct leaf names).
# Genera are keyed by name only: if the same genus name occurs under several
# families, the one visited last wins.
genus_dicts = {
    name: (node_leaf_sum(subtree), len(node_species_dict(subtree)))
    for group_node in label_tree.values()
    for family_node in group_node.values()
    for name, subtree in family_node.items()
}

# Show the genera ordered by that tuple, descending, with "Other" last.
sort_dict_other_last(genus_dicts)