In [None]:
from pathlib import Path
import json

import pandas as pd
import numpy as np

import plotly.express as px

In [None]:
labels_json = Path('TreeSatBA_v9_60m_multi_labels.json')

# Read the raw mapping from the JSON file.
with labels_json.open() as f:
    multilabels_dict = json.load(f)

# Every value in the mapping is a list of (key, value) pairs. Turn each list
# into a plain dict so pandas can build one row per top-level entry.
new_values = [
    {name: amount for name, amount in pairs}
    for pairs in multilabels_dict.values()
]

# One row per top-level JSON key, one column per key seen in the pair lists.
# Keys missing from a row come out as NaN and are replaced by 0.
labels_df = pd.DataFrame.from_records(
    new_values, index=multilabels_dict.keys()
).fillna(0)

In [None]:
# Show the assembled labels table as this cell's output.
labels_df

In [None]:
# Column-wise totals over all rows of labels_df.
counts = labels_df.sum(axis=0)

# Display the totals from largest to smallest; `counts` itself is left unsorted.
counts.sort_values(ascending=False)

In [None]:
# Grand total over all columns; used both to normalise and in the title.
total_counts = counts.sum()

# Share of each column in the grand total, rounded to two decimals for display.
frequencies = (counts / total_counts).round(2)

fig = px.histogram(x=counts.index, y=frequencies, text_auto=True)
fig.update_xaxes(categoryorder="total descending")
# Kept as the last expression so the notebook renders the figure.
fig.update_layout(
    xaxis_title="Species",
    yaxis_title="Frequency",
    title=f'Total filled images: {int(total_counts)}',
)

In [None]:
# Rows whose values sum to more than `tol` away from 1 are selected below.
tol = 0.1
row_sums = labels_df.sum(axis=1)
not_one = labels_df[(row_sums - 1).abs() > tol]

# Per-row sums of the selected rows; this is what the histogram shows.
not_one_sums = not_one.sum(axis=1)

# NOTE(review): this is the sum of the selected rows' sums, not the number of
# selected rows (that would be len(not_one)). Confirm which one the "total"
# in the title is meant to report.
not_one_count = int(not_one_sums.sum())

px.histogram(
    not_one_sums,
    text_auto=True,
    nbins=20,
    title=f'Images with sums outside the range of {1-tol} to {1+tol}, total: {not_one_count}',
)

In [None]:
def plotly_hist(df, filename):
    """Plot per-column totals for the rows of ``df`` listed in ``filename``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose index values are matched against the entries of the file.
    filename : str or path-like
        File read with ``pd.read_csv(..., header=None)``; only its first
        column is used as the list of index values to keep.

    Returns
    -------
    The figure created by ``px.histogram``, with x-axis categories ordered by
    descending total and the axis titles set.
    """
    # First (header-less) column holds the index values to keep.
    wanted = pd.read_csv(filename, header=None)[0]
    subset = df.loc[df.index.isin(wanted)]

    # Column totals of the kept rows, computed once and reused for both axes.
    totals = subset.sum()

    fig = px.histogram(x=totals.index, y=totals.round(), text_auto=True)
    fig.update_xaxes(categoryorder="total descending")
    fig.update_layout(xaxis_title="Species", yaxis_title="Occurrences")
    return fig

In [None]:
# Build one histogram per split from the corresponding filename list.
train_fig = plotly_hist(labels_df, 'train_filenames.lst')
test_fig = plotly_hist(labels_df, 'test_filenames.lst')

# Render both figures, train first, from this single cell.
display(train_fig, test_fig)

In [None]:
save_path = Path('s2') / 's2_60m.npy'

# SECURITY: allow_pickle=True unpickles arbitrary Python objects, which can
# execute code on load. Only use this with a file from a trusted source.
# .item() pulls the single stored object out of the loaded array
# (presumably a dict, going by the variable name; confirm against the writer).
tif_dict = np.load(save_path, allow_pickle=True).item()