In [1]:
import pandas as pd
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [2]:
def get_onehots(entry, unique_classes):
    """Encode a collection of class labels as a multi-hot vector.

    Parameters
    ----------
    entry : iterable
        Class labels to encode. Every label must occur in ``unique_classes``.
    unique_classes : list
        Ordered list of all known labels; position ``i`` of the result
        corresponds to ``unique_classes[i]``.

    Returns
    -------
    list of int
        Vector of length ``len(unique_classes)`` holding, for each class, how
        often it occurs in ``entry`` (0/1 when the labels are distinct).
        An empty ``entry`` yields an all-zero vector.

    Raises
    ------
    ValueError
        If a label in ``entry`` is not found in ``unique_classes``.
    """
    indices = [unique_classes.index(c) for c in entry]
    num_classes = len(unique_classes)
    if num_classes == 0:
        # No classes at all (and therefore no labels, otherwise .index() would
        # have raised): F.one_hot cannot build a zero-width encoding.
        return []
    # Force an integer dtype: torch.tensor([]) defaults to float32, which
    # F.one_hot rejects, so an empty entry would otherwise raise.
    index_tensor = torch.tensor(indices, dtype=torch.long)
    return F.one_hot(index_tensor, num_classes=num_classes).sum(dim=0).tolist()

def get_species_classes(records, unique_classes, level, filter_uncommon=False, min_fraction=0.01):
    """Collect the classes observed for each species and encode them.

    Parameters
    ----------
    records : pandas.DataFrame
        Observations with a ``"species_key"`` column and a ``level`` column.
    unique_classes : list
        Ordered list of all class labels, passed on to ``get_onehots``.
    level : str
        Name of the column in ``records`` that holds the class label.
    filter_uncommon : bool, default False
        When True, a (species, class) pair is kept only if it accounts for
        more than ``min_fraction`` of that species' records.
    min_fraction : float, default 0.01
        Threshold used when ``filter_uncommon`` is True.

    Returns
    -------
    pandas.DataFrame
        One row per species with columns ``"species_key"``, ``"classes"``
        (the unique class labels) and ``"classes_onehot"``.
    """
    if filter_uncommon:
        # Name the count column explicitly: the default name of a
        # value_counts() result differs between pandas versions, so relying
        # on it (0 / the original column name) breaks on pandas >= 2.0.
        pair_counts = (
            records[["species_key", level]]
            .value_counts()
            .rename("count")
            .reset_index()
        )
        # Number of records per species, indexed by species key.
        totals = records["species_key"].value_counts()
        pair_counts["total"] = pair_counts["species_key"].map(totals)
        pair_counts["fraction"] = pair_counts["count"] / pair_counts["total"]
        records = pair_counts[pair_counts["fraction"] > min_fraction]
    species_classes = (
        records.groupby("species_key")[level]
        .unique()
        .reset_index()
        .rename(columns={level: "classes"})
    )
    species_classes["classes_onehot"] = species_classes["classes"].apply(
        lambda classes: get_onehots(classes, unique_classes)
    )
    return species_classes

In [3]:
def check(species_list, reference_list):
    """Return 1 when every item of species_list is in reference_list, else 0."""
    for species in species_list:
        if species not in reference_list:
            # One missing item is enough to fail the check.
            return 0
    return 1

In [4]:
def retrieve_in_ref_list(species_list, ref_list):
    """Return the items of species_list that also occur in ref_list.

    Order and duplicates of species_list are preserved.
    """
    kept = []
    for candidate in species_list:
        if candidate in ref_list:
            kept.append(candidate)
    return kept

In [27]:
# Load the species keys (indexed by ID) and split them into two random halves.
a = pd.read_json(
    "/data/nicola/WSH/final_data/L2_species_keys.json", orient="records"
).set_index("ID")
# Fixed random_state so the split is reproducible.
half1 = a.sample(frac=0.5, random_state=42)
# Everything that was not drawn into half1.
half2 = a.drop(index=half1.index)

In [28]:
# Display the second half of the species-key table.
half2

Unnamed: 0_level_0,species_key
ID,Unnamed: 1_level_1
1,3170807
2,3105433
3,2883073
4,4299368
5,2891147
...,...
3612,8961574
3616,2680229
3618,2682592
3625,3033252


In [7]:
# Load all records and split each row's species list by the half it belongs to.
all_data = pd.read_json("/data/nicola/WSH/final_data/L2_all_data.json", orient="records")
# Build the lookups once, as sets: the previous version re-created a list from
# half1/half2 for every row and scanned it linearly for every species.
half1_keys = set(half1["species_key"])
half2_keys = set(half2["species_key"])
all_data["half1_species"] = all_data["species_key"].apply(lambda keys: retrieve_in_ref_list(keys, half1_keys))
all_data["half2_species"] = all_data["species_key"].apply(lambda keys: retrieve_in_ref_list(keys, half2_keys))

In [8]:
# Number of species each row contributes to half 1 and to half 2.
all_data["len_half1"] = all_data["half1_species"].map(len)
all_data["len_half2"] = all_data["half2_species"].map(len)

In [9]:
# Show each row's original species list next to its half1/half2 partitions
# and its species_based_class vector.
all_data[["zone_id","species_key","half1_species","half2_species","species_based_class"]]

Unnamed: 0,zone_id,species_key,half1_species,half2_species,species_based_class
0,9,"[3032837, 3170807, 3105433, 2883073]",[3032837],"[3170807, 3105433, 2883073]","[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, ..."
1,10,[4299368],[],[4299368],"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, ..."
2,13,[2891147],[],[2891147],"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, ..."
3,26,"[5137582, 7799370]","[5137582, 7799370]",[],"[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, ..."
4,29,"[8207244, 5352367, 3170807]",[],"[8207244, 5352367, 3170807]","[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, ..."
...,...,...,...,...,...
25958,98095,"[3170040, 2812375, 3029627, 1537719, 3189747, ...","[3170040, 2812375, 3029627, 1537719, 3189747, ...",[8890062],"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, ..."
25959,98096,[3928139],[3928139],[],"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, ..."
25960,98097,[2882835],[],[2882835],"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, ..."
25961,98101,"[5275365, 3112620]",[5275365],[3112620],"[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, ..."


In [23]:
# Build one dataset per half: keep only rows that have at least one species in
# that half, and let the half-specific columns take over the generic
# "species_key" / "species_count" names.
split_1 = (
    all_data[all_data["len_half1"] > 0]
    .drop(columns=["len_half2", "half2_species", "species_key", "species_count"])
    .rename(columns={"half1_species": "species_key", "len_half1": "species_count"})
)
split_2 = (
    all_data[all_data["len_half2"] > 0]
    .drop(columns=["len_half1", "half1_species", "species_key", "species_count"])
    .rename(columns={"half2_species": "species_key", "len_half2": "species_count"})
)

In [24]:
# Display the dataset restricted to half-2 species.
split_2

Unnamed: 0,zone_id,maps_based_class,split,shape_area,species_based_class,num_classes,species_key,species_count
0,9,45,train,665918.930058,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, ...",9,"[3170807, 3105433, 2883073]",3
1,10,43,train,2991.243434,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, ...",16,[4299368],1
2,13,66,train,230749.174683,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, ...",16,[2891147],1
4,29,54,train,101384.855230,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, ...",10,"[8207244, 5352367, 3170807]",3
5,37,43,train,10094.769110,"[0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, ...",14,[5405976],1
...,...,...,...,...,...,...,...,...
25956,98092,53,test,4501.734903,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, ...",14,[2891147],1
25957,98093,34,test,42639.142599,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, ...",17,[5220170],1
25958,98095,43,test,63666.903126,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, ...",6,[8890062],1
25960,98097,53,test,3613.157923,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, ...",15,[2882835],1


In [12]:
# Summary statistics for split_1.
split_1.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes,len_half1
count,19955.0,19955.0,19955.0,19955.0,19955.0,19955.0
mean,48538.779604,50.682335,153995.0,4.483538,8.907793,2.460937
std,27893.566574,24.056461,569874.7,3.406487,4.90534,1.789828
min,9.0,2.0,17.94476,1.0,0.0,1.0
25%,23763.0,42.0,15610.38,1.0,5.0,1.0
50%,49330.0,45.0,54573.41,3.0,9.0,2.0
75%,71173.0,66.0,147609.7,7.0,12.0,4.0
max,98118.0,94.0,31007130.0,10.0,21.0,10.0


In [13]:
# Distribution of the number of half-1 species per row.
# The cell that builds split_1 renames "len_half1" to "species_count", so the
# old attribute access fails on a fresh Restart & Run All.
split_1["species_count"].value_counts()

1     8912
2     3718
3     2332
4     1872
5     1436
6      973
7      491
8      177
9       35
10       9
Name: len_half1, dtype: int64

In [14]:
# Summary statistics for split_2.
split_2.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes,len_half2
count,19310.0,19310.0,19310.0,19310.0,19310.0,19310.0
mean,48965.372294,50.755567,155637.2,4.592595,8.572553,2.516054
std,27570.263443,23.850685,552202.8,3.407622,4.615418,1.832243
min,9.0,2.0,8.114492,1.0,0.0,1.0
25%,24607.0,42.0,17298.96,2.0,5.0,1.0
50%,49507.5,45.0,57623.47,3.0,8.0,2.0
75%,71234.0,66.0,152116.7,8.0,12.0,4.0
max,98101.0,94.0,31007130.0,10.0,21.0,10.0


In [15]:
# Distribution of the number of half-2 species per row.
# The cell that builds split_2 renames "len_half2" to "species_count", so the
# old attribute access fails on a fresh Restart & Run All.
split_2["species_count"].value_counts()

1     8485
2     3553
3     2238
4     1745
5     1463
6     1062
7      539
8      189
9       32
10       4
Name: len_half2, dtype: int64

In [16]:
# Distinct species keys occurring in L2_train_data.json.
train = pd.read_json("/data/nicola/WSH/final_data/L2_train_data.json", orient="records")
# Flatten the per-row key lists straight into a set: sum(lists, []) copies the
# accumulated list on every addition and is quadratic in the number of rows.
train_species = list({key for keys in train["species_key"] for key in keys})
len(train_species)

3137

In [17]:
# Distinct species keys occurring in L2_val_data.json.
val = pd.read_json("/data/nicola/WSH/final_data/L2_val_data.json", orient="records")
# Flatten the per-row key lists straight into a set: sum(lists, []) copies the
# accumulated list on every addition and is quadratic in the number of rows.
val_species = list({key for keys in val["species_key"] for key in keys})
len(val_species)

1625

In [18]:
# Distinct species keys occurring in L2_test_data.json.
test = pd.read_json("/data/nicola/WSH/final_data/L2_test_data.json", orient="records")
# Flatten the per-row key lists straight into a set: sum(lists, []) copies the
# accumulated list on every addition and is quadratic in the number of rows.
test_species = list({key for keys in test["species_key"] for key in keys})
len(test_species)

2470

In [19]:
# Species that occur in both the train and the test data, in test_species order.
# Membership is checked against a set so each lookup is O(1) rather than a
# linear scan of the train list.
train_species_lookup = set(train_species)
common_train_test_species = [spe for spe in test_species if spe in train_species_lookup]

In [20]:
# Number of species shared by the train and test data.
len(common_train_test_species)

2077

In [21]:
# Restrict every record to the species shared by train and test, then drop
# records that are left without any species.
# NOTE(review): the name `all` shadows the Python builtin for the rest of the
# kernel session; it is kept because other cells may refer to it. Consider
# renaming it.
all = pd.read_json("/data/nicola/WSH/final_data/L2_all_data.json", orient="records")
# Set lookup built once: avoids a linear scan of the common-species list for
# every species of every row.
common_species_lookup = set(common_train_test_species)
all["species_key"] = all["species_key"].apply(lambda keys: retrieve_in_ref_list(keys, common_species_lookup))
all["species_count"] = all["species_key"].apply(len)
all = all[all["species_count"] > 0]
all.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes
count,25753.0,25753.0,25753.0,25753.0,25753.0
mean,48666.938764,50.615695,138738.8,3.670252,9.844989
std,27949.58184,23.933494,512668.4,3.173786,4.921894
min,9.0,2.0,8.114492,1.0,0.0
25%,24009.0,42.0,12945.84,1.0,6.0
50%,49354.0,45.0,46712.61,2.0,10.0
75%,71432.0,66.0,132550.5,6.0,14.0
max,98118.0,94.0,31007130.0,10.0,21.0
