In [2]:
import pandas as pd
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [3]:
def get_onehots(entry, unique_classes):
    """Encode a collection of class labels as a multi-hot count vector.

    Args:
        entry: Iterable of class labels; every label must occur in
            ``unique_classes``.
        unique_classes: Ordered list of all known labels. The position of a
            label in this list is its position in the returned vector.

    Returns:
        A list of ``len(unique_classes)`` ints where position ``i`` holds the
        number of times ``unique_classes[i]`` occurs in ``entry`` (all zeros
        for an empty ``entry``).

    Raises:
        ValueError: If a label in ``entry`` is missing from ``unique_classes``.
    """
    indices = [unique_classes.index(c) for c in entry]
    if not indices:
        # torch.tensor([]) is a float tensor, which F.one_hot rejects, so the
        # empty case is answered directly.
        return [0] * len(unique_classes)
    index_tensor = torch.tensor(indices, dtype=torch.long)
    return F.one_hot(index_tensor, num_classes=len(unique_classes)).sum(dim=0).tolist()

def get_species_classes(records, unique_classes, level, filter_uncommon=False, min_fraction=0.01):
    """Collect, per species, the classes it is observed with and their encoding.

    Args:
        records: DataFrame with a ``"species_key"`` column and a ``level``
            column, one row per observation.
        unique_classes: Ordered list of all class labels, passed on to
            ``get_onehots``.
        level: Name of the column in ``records`` that holds the class label.
        filter_uncommon: When True, a (species, class) pair is only kept if it
            accounts for more than ``min_fraction`` of that species' records.
        min_fraction: Exclusive lower bound on the per-species fraction used
            when ``filter_uncommon`` is True.

    Returns:
        DataFrame with one row per species and the columns ``"species_key"``,
        ``"classes"`` (unique class labels) and ``"classes_onehot"``
        (the ``get_onehots`` encoding of ``"classes"``).
    """
    if filter_uncommon:
        # Name the value_counts results explicitly: pandas < 2.0 names them
        # None / after the source column, pandas >= 2.0 names them "count".
        pair_counts = records[["species_key", level]].value_counts().rename("count").reset_index()
        species_totals = records["species_key"].value_counts().rename("total")
        pair_counts = pair_counts.join(species_totals, on="species_key", how="inner")
        pair_counts["fraction"] = pair_counts["count"] / pair_counts["total"]
        # From here on `records` holds one row per surviving (species, class) pair.
        records = pair_counts[pair_counts["fraction"] > min_fraction]
    species_classes = records.groupby("species_key")[level].unique().reset_index().rename(columns={level: "classes"})
    species_classes["classes_onehot"] = species_classes["classes"].apply(lambda x: get_onehots(x, unique_classes))
    return species_classes

In [4]:
def check(species_list, reference_list):
    """Return 1 if every item of ``species_list`` is in ``reference_list``, else 0.

    An empty ``species_list`` yields 1.
    """
    # Written with `any` because this notebook later rebinds the name `all`
    # to a DataFrame.
    has_unknown = any(spe not in reference_list for spe in species_list)
    return 0 if has_unknown else 1

In [5]:
def retrieve_in_ref_list(species_list, ref_list):
    """Keep the items of ``species_list`` that also occur in ``ref_list``.

    The order of ``species_list`` and any duplicates in it are preserved.
    """
    kept = []
    for candidate in species_list:
        if candidate in ref_list:
            kept.append(candidate)
    return kept

In [6]:
# Load the species-key table and index it by its "ID" column.
a = pd.read_json("/data/nicola/WSH/final_data/species_keys.json", orient="records").set_index("ID")
# half1 is a seeded (random_state=42) random sample of 50% of the rows;
# half2 is every row whose index label was not drawn into half1.
half1 = a.sample(frac=0.5, random_state=42)
half2 = a.drop(half1.index)

In [7]:
# Display the rows of the species-key table that were not sampled into half1.
half2

Unnamed: 0_level_0,species_key
ID,Unnamed: 1_level_1
1,3170807
2,3105433
3,2883073
4,4299368
5,2891147
...,...
3612,8961574
3616,2682592
3618,2669208
3625,7331276


In [8]:
all_data = pd.read_json("/data/nicola/WSH/final_data/all_data.json", orient="records")
# Build each lookup once instead of calling to_list() inside the lambda for
# every row; a set also makes each membership test constant-time.
half1_keys = set(half1["species_key"].to_list())
half2_keys = set(half2["species_key"].to_list())
# For every row, keep only the species keys that belong to the given half.
all_data["half1_species"] = all_data["species_key"].apply(lambda x : retrieve_in_ref_list(x, half1_keys))
all_data["half2_species"] = all_data["species_key"].apply(lambda x : retrieve_in_ref_list(x, half2_keys))

In [9]:
# Number of species each row has in half1 and in half2.
all_data["len_half1"] = all_data["half1_species"].apply(len)
all_data["len_half2"] = all_data["half2_species"].apply(len)

In [10]:
# Show the full species list of each row next to its half1 / half2 subsets.
all_data[["zone_id","species_key","half1_species","half2_species","species_based_class"]]

Unnamed: 0,zone_id,species_key,half1_species,half2_species,species_based_class
0,9,"[3032837, 3170807, 3105433, 2883073]",[3032837],"[3170807, 3105433, 2883073]","[0, 0, 1, 1, 0, 1, 0, 0, 0]"
1,10,[4299368],[],[4299368],"[0, 0, 1, 1, 0, 1, 1, 1, 0]"
2,13,[2891147],[],[2891147],"[0, 0, 1, 1, 0, 1, 0, 0, 1]"
3,26,"[5137582, 7799370]","[5137582, 7799370]",[],"[0, 0, 1, 1, 1, 1, 0, 0, 0]"
4,29,"[8207244, 5352367, 3170807]",[],"[8207244, 5352367, 3170807]","[0, 0, 1, 1, 0, 1, 0, 0, 0]"
...,...,...,...,...,...
25958,98095,"[3170040, 2812375, 3029627, 1537719, 3189747, ...","[2812375, 3029627, 1537719, 3189747, 5357013]","[3170040, 8890062]","[0, 0, 0, 1, 0, 1, 0, 0, 0]"
25959,98096,[3928139],[3928139],[],"[0, 0, 1, 1, 0, 1, 0, 0, 1]"
25960,98097,[2882835],[],[2882835],"[0, 0, 1, 1, 1, 1, 0, 0, 0]"
25961,98101,"[5275365, 3112620]",[5275365],[3112620],"[0, 0, 1, 1, 0, 1, 0, 0, 1]"


In [11]:
# A row goes into a split when it has at least one species from that half;
# the columns describing the other half are dropped.
has_half1 = all_data["len_half1"] > 0
has_half2 = all_data["len_half2"] > 0
split_1 = all_data[has_half1].drop(columns=["len_half2", "half2_species"])
split_2 = all_data[has_half2].drop(columns=["len_half1", "half1_species"])

In [12]:
# Display the rows that contain at least one half2 species.
split_2

Unnamed: 0,zone_id,maps_based_class,split,shape_area,species_key,species_count,species_based_class,num_classes,half2_species,len_half2
0,9,4,train,665918.930058,"[3032837, 3170807, 3105433, 2883073]",4,"[0, 0, 1, 1, 0, 1, 0, 0, 0]",3,"[3170807, 3105433, 2883073]",3
1,10,4,train,2991.243434,[4299368],1,"[0, 0, 1, 1, 0, 1, 1, 1, 0]",5,[4299368],1
2,13,6,train,230749.174683,[2891147],1,"[0, 0, 1, 1, 0, 1, 0, 0, 1]",4,[2891147],1
4,29,5,train,101384.855230,"[8207244, 5352367, 3170807]",3,"[0, 0, 1, 1, 0, 1, 0, 0, 0]",3,"[8207244, 5352367, 3170807]",3
5,37,4,train,10094.769110,[5405976],1,"[0, 0, 1, 1, 1, 1, 0, 0, 1]",5,[5405976],1
...,...,...,...,...,...,...,...,...,...,...
25956,98092,5,test,4501.734903,"[2891147, 3033289]",2,"[0, 0, 1, 1, 0, 1, 0, 0, 1]",4,[2891147],1
25957,98093,3,test,42639.142599,[5220170],1,"[0, 0, 1, 1, 1, 1, 0, 0, 0]",4,[5220170],1
25958,98095,4,test,63666.903126,"[3170040, 2812375, 3029627, 1537719, 3189747, ...",7,"[0, 0, 0, 1, 0, 1, 0, 0, 0]",2,"[3170040, 8890062]",2
25960,98097,5,test,3613.157923,[2882835],1,"[0, 0, 1, 1, 1, 1, 0, 0, 0]",4,[2882835],1


In [13]:
# Summary statistics for split_1.
split_1.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes,len_half1
count,19588.0,19588.0,19588.0,19588.0,19588.0,19588.0
mean,48512.095773,5.16454,155453.0,4.543343,3.693435,2.466459
std,27756.626142,1.889594,571396.6,3.408896,1.469071,1.807968
min,9.0,1.0,17.94476,1.0,0.0,1.0
25%,23904.5,4.0,16071.7,2.0,3.0,1.0
50%,49374.0,5.0,55916.27,3.0,4.0,2.0
75%,71023.75,6.0,150007.6,8.0,4.0,4.0
max,98118.0,9.0,31007130.0,10.0,8.0,10.0


In [14]:
# How many split_1 rows have 1, 2, ... half1 species.
split_1.len_half1.value_counts()

1     8846
2     3577
3     2248
4     1732
5     1453
6     1005
7      518
8      166
9       38
10       5
Name: len_half1, dtype: int64

In [15]:
# Summary statistics for split_2.
split_2.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes,len_half2
count,19796.0,19796.0,19796.0,19796.0,19796.0,19796.0
mean,48872.377501,5.178167,156183.1,4.522176,3.629218,2.494443
std,27693.936859,1.867371,567208.7,3.400884,1.355342,1.808045
min,9.0,1.0,8.114492,1.0,0.0,1.0
25%,24444.0,4.0,16817.8,2.0,3.0,1.0
50%,49426.0,5.0,56347.54,3.0,3.0,2.0
75%,71289.25,6.0,150412.9,8.0,4.0,4.0
max,98101.0,9.0,31007130.0,10.0,8.0,10.0


In [16]:
# How many split_2 rows have 1, 2, ... half2 species.
split_2.len_half2.value_counts()

1     8704
2     3698
3     2307
4     1827
5     1517
6      993
7      536
8      177
9       30
10       7
Name: len_half2, dtype: int64

In [17]:
train = pd.read_json("/data/nicola/WSH/final_data/train_data.json", orient="records")
# Flatten the per-row species lists and de-duplicate them in one pass;
# sum(list_of_lists, []) copies the growing list on every addition.
train_species = list({spe for keys in train["species_key"] for spe in keys})
# Number of distinct species in the train split.
len(train_species)

3137

In [18]:
val = pd.read_json("/data/nicola/WSH/final_data/val_data.json", orient="records")
# Flatten the per-row species lists and de-duplicate them in one pass;
# sum(list_of_lists, []) copies the growing list on every addition.
val_species = list({spe for keys in val["species_key"] for spe in keys})
# Number of distinct species in the validation split.
len(val_species)

1625

In [19]:
test = pd.read_json("/data/nicola/WSH/final_data/test_data.json", orient="records")
# Flatten the per-row species lists and de-duplicate them in one pass;
# sum(list_of_lists, []) copies the growing list on every addition.
test_species = list({spe for keys in test["species_key"] for spe in keys})
# Number of distinct species in the test split.
len(test_species)

2470

In [20]:
# Species present in both the train and test splits, in test_species order.
# The set keeps each membership test constant-time instead of scanning the
# whole train list for every test species.
train_species_lookup = set(train_species)
common_train_test_species = [spe for spe in test_species if spe in train_species_lookup]

In [21]:
# Number of species shared by the train and test splits.
len(common_train_test_species)

2077

In [22]:
# NOTE(review): the name `all` shadows the Python builtin for the rest of the
# session. It is kept here only because later cells may rely on it; rename it
# (e.g. `all_common`) once every use has been checked.
all = pd.read_json("/data/nicola/WSH/final_data/all_data.json", orient="records")
# One set lookup instead of scanning the common-species list for every key.
common_species_lookup = set(common_train_test_species)
# Restrict each row to species seen in both train and test, then recount.
all["species_key"] = all["species_key"].apply(lambda x : retrieve_in_ref_list(x, common_species_lookup))
all["species_count"] = all["species_key"].apply(len)
# Drop rows left without any common species.
all = all[all["species_count"]>0]
all.describe()

Unnamed: 0,zone_id,maps_based_class,shape_area,species_count,num_classes
count,25752.0,25752.0,25752.0,25752.0,25752.0
mean,48677.570519,5.147872,138730.6,3.670395,3.926724
std,27952.401223,1.900734,512677.6,3.171782,1.477475
min,9.0,1.0,8.114492,1.0,0.0
25%,24026.75,4.0,12963.18,1.0,3.0
50%,49356.0,5.0,46714.08,2.0,4.0
75%,71434.0,6.0,132516.4,6.0,5.0
max,98118.0,9.0,31007130.0,10.0,8.0
