In [1]:
import pandas as pd
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
from pathlib import Path
from shapely.geometry import box, Point, Polygon
import xarray as xr
import rioxarray
from rasterio.features import shapes
import os
import folium
from geocube.api.core import make_geocube
from src.utils import read_neon_trees
from functools import cached_property

# Project root: the notebook's working directory. Passed to BaCalculator,
# which builds every input/output path relative to it.
root = Path(os.getcwd())

In [None]:
class BaCalculator:
    """Per-Sentinel-pixel basal-area composition from a site's tree inventory.

    Tree points (a GeoPackage with at least ``taxonID`` and ``dbh_inches``
    columns) are spatially joined to a polygon grid built from the site's
    Sentinel raster. For every grid cell the class computes total basal area,
    tree count and each genus's share of basal area, then rasterises the
    results back onto the Sentinel grid.

    Input files (relative to ``root``)::

        output/<SITE>/all_trees_<SITE>.gpkg      tree points
        sentinel_data/<site>/<date>_<site>.nc    reference raster

    Parameters
    ----------
    root : pathlib.Path
        Project root directory.
    date : int or str
        Only used to build the Sentinel file name.
    site_name : str
        Site identifier; upper-cased for ``output/``, used as-is for
        ``sentinel_data/``.
    epsg : int
        Target EPSG code. Tree points are reprojected to it when needed and it
        is written onto the reference raster as its CRS.
    code_type : {'neon', 'fia'}
        Which species-code table the ``taxonID`` values come from.

    Raises
    ------
    ValueError
        If ``code_type`` is unknown (checked before any file is read), or if
        the tree GeoPackage has no CRS.
    """

    # Basal area (ft^2) of one tree from its DBH in inches:
    # pi * (d / 2)^2 / 144 in^2-per-ft^2 = pi / 576 * d^2 ~= 0.005454 * d^2.
    BA_FT2_PER_IN2 = .005454

    # NEON taxonID -> genus (common name).
    NEON_CODES = {
        'TSCA': 'hemlock', 'FAGR': 'beech', 'ACRU': 'maple', 'BEAL2': 'birch',
        'PIRU': 'spruce', 'ACSAS': 'maple', 'BEPAP': 'birch', 'FRAM2': 'ash',
        'ACPE': 'maple', 'POGR4': 'aspen', 'POTR5': 'aspen', 'BEPAC2': 'birch',
        'ABBA': 'balsam_fir', 'BECAC': 'birch', 'BEPO': 'birch', 'PIST': 'pine',
        'PIRE': 'pine', 'FRPE': 'ash', 'ULAM': 'elm', 'QURU': 'oak',
        'TIAM': 'basswood', 'OSVI': 'ironwood', 'PRPE': 'cherry', 'ACNI5': 'maple',
        'ACSA2': 'maple', 'BELE': 'birch', 'BEPA': 'birch', 'CADE12': 'chestnut',
        'CAGL8': 'hickory', 'CAOV2': 'hickory', 'FRNI': 'ash', 'HAVI4': 'witchhazel',
        'NYSY': 'tupelo', 'PIAB': 'spruce', 'PIMA': 'spruce', 'PIRI': 'pine',
        'PRSES': 'cherry', 'QUAL': 'oak', 'QUVE': 'oak',
    }

    # FIA integer species code -> genus (common name).
    FIA_CODES = {
        12: 'balsam_fir', 71: 'tamarack', 91: 'spruce', 94: 'spruce', 95: 'spruce',
        97: 'spruce', 105: 'pine', 125: 'pine', 129: 'pine', 130: 'pine',
        261: 'hemlock', 316: 'maple', 318: 'maple', 371: 'birch', 372: 'birch',
        375: 'birch', 531: 'beech', 541: 'ash', 543: 'ash', 743: 'aspen',
        746: 'aspen', 762: 'cherry', 833: 'oak', 950: 'basswood', 970: 'elm',
        999: 'unknown',
    }

    _CODE_TABLES = {'neon': NEON_CODES, 'fia': FIA_CODES}

    def __init__(self, root, date, site_name, epsg, code_type):
        # Validate first: a typo should fail immediately, not after slow file I/O
        # (previously an unknown code_type left `tree_dict` unset and failed later).
        if code_type not in self._CODE_TABLES:
            raise ValueError(
                f"code_type must be one of {sorted(self._CODE_TABLES)}, got {code_type!r}"
            )
        self.root = root
        self.date = date
        self.site_name = site_name
        self.epsg = epsg
        self.code_type = code_type
        # Copy so per-instance edits never leak into the shared class-level table.
        self.tree_dict = dict(self._CODE_TABLES[code_type])

        self.tree_points = self.open_tree_points()
        self.reference_raster = self.open_ref_raster()

        # Not used internally (the grid is the cached `vectorize_raster` property);
        # kept so existing code reading this attribute keeps working.
        self.vectorized_raster_grid = None
        # Set by calculate_basal_area().
        self.basal_area = None

    def open_tree_points(self):
        """Read the site's tree points and return them in ``self.epsg``.

        Raises
        ------
        ValueError
            If the GeoPackage carries no CRS, so it cannot be reprojected.
        """
        site = self.site_name.upper()
        tp = gpd.read_file(self.root / 'output' / site / f'all_trees_{site}.gpkg')
        if tp.crs is None:
            raise ValueError(f'tree points for {site} have no CRS; cannot reproject to EPSG:{self.epsg}')
        if tp.crs.to_epsg() != self.epsg:
            # to_crs returns a new frame; the result must be kept (it was discarded before).
            tp = tp.to_crs(self.epsg)
        return tp

    def open_ref_raster(self):
        """Open the site's Sentinel netCDF as a 2-D (first band, first time step) DataArray.

        The CRS is set to ``self.epsg`` and ``x``/``y`` are declared as the
        spatial dimensions.
        """
        r = xr.open_dataarray(self.root / 'sentinel_data' / self.site_name / f'{self.date}_{self.site_name}.nc')
        r = r.rio.write_crs(self.epsg).rio.set_spatial_dims(x_dim="x", y_dim="y").rio.write_coordinate_system()
        return r.isel(band=0, time=0)

    @cached_property
    def vectorize_raster(self):
        """GeoDataFrame with one square polygon per reference-raster pixel.

        Columns: ``geometry`` and ``cell_id`` (1..N in the order
        ``rasterio.features.shapes`` yields polygons). Computed once per
        instance and cached.
        """
        # Give every pixel a distinct value so `shapes` emits one polygon per pixel.
        # int32 rather than uint16: uint16 wraps after 65,536 pixels, and int32 is
        # one of the dtypes `shapes` accepts.
        cell_values = np.arange(self.reference_raster.size).reshape(self.reference_raster.shape).astype('int32')
        r1 = xr.DataArray(
            cell_values,
            coords={'y': self.reference_raster.y.values, 'x': self.reference_raster.x.values},
            dims=['y', 'x'],
        )
        r1 = r1.rio.write_crs(self.epsg).rio.set_spatial_dims(x_dim="x", y_dim="y").rio.write_coordinate_system()

        # `shapes` yields (geojson-like geometry, pixel value) pairs.
        polygons = shapes(cell_values, transform=r1.rio.transform())
        geometry = [Polygon(geom['coordinates'][0]) for geom, _ in polygons]

        return gpd.GeoDataFrame(
            {'geometry': geometry, 'cell_id': list(range(1, len(geometry) + 1))},
            crs=r1.rio.crs,
        )

    def join_trees_to_grid(self):
        """Spatially join tree points to the grid cells that contain them.

        Returns a GeoDataFrame (grid-cell geometry, one row per tree) with
        columns ``geometry, cell_id, taxonID, genus, dbh_inches, basal_area``.
        ``genus`` is NaN for codes missing from ``tree_dict``; ``basal_area``
        is in ft^2 (DBH in inches).
        """
        joined = gpd.sjoin(self.vectorize_raster, self.tree_points, predicate='contains')
        joined['genus'] = joined['taxonID'].map(self.tree_dict)
        # .copy(): the column subset is a new frame, so adding basal_area below
        # does not trigger SettingWithCopyWarning.
        joined = joined[['geometry', 'cell_id', 'taxonID', 'genus', 'dbh_inches']].copy()
        joined['basal_area'] = self.BA_FT2_PER_IN2 * joined['dbh_inches'] ** 2
        return joined

    @staticmethod
    def _genus_percent_table(trees):
        """Aggregate per-tree rows into one row per ``cell_id``.

        ``trees`` needs ``cell_id``, ``taxonID``, ``genus`` and ``basal_area``.
        Returns ``cell_id``, total ``basal_area``, ``num_trees`` (non-null
        ``taxonID`` count) and one ``<genus>_basal_area`` column per genus
        present, expressed as percent of the cell's total basal area. Trees
        with a NaN genus count toward totals but get no genus column.
        """
        grouped = trees.groupby('cell_id')
        table = pd.DataFrame({
            'basal_area': grouped['basal_area'].sum(),
            'num_trees': grouped['taxonID'].count(),
        })
        # pd.notna, not `is not None`: unmapped codes give NaN, which used to slip
        # through and crash looking up a 'nan_basal_area' column.
        genera = [g for g in trees['genus'].unique() if pd.notna(g)]
        for genus in genera:
            genus_ba = trees['basal_area'].where(trees['genus'] == genus, 0)
            genus_total = genus_ba.groupby(trees['cell_id']).sum()
            table[f'{genus}_basal_area'] = (genus_total / table['basal_area']) * 100
        return table.reset_index()

    @staticmethod
    def _category_labels(target_species, nontarget_species):
        """Human-readable labels; list position equals the category number from ``_category_codes``."""
        return (
            [f'0% {target_species}', f'100% {target_species}', f'mixed {target_species}']
            + [f'mixed {species}' for species in nontarget_species]
            + ['mixed other']
        )

    @staticmethod
    def _category_codes(cells, target_species, nontarget_species):
        """Assign each row a composition category from its genus basal-area percentages.

        The first matching rule wins:
        0 = target is 0%, 1 = target is 100%, 2 = target in [50, 100),
        2 + k = k-th non-target genus in [50, 100), last = anything else.
        """
        target = cells[f'{target_species}_basal_area']
        conditions = [target == 0.0, target == 100.0, (target >= 50) & (target < 100)]
        for species in nontarget_species:
            share = cells[f'{species}_basal_area']
            conditions.append((share >= 50) & (share < 100))
        return np.select(conditions, list(range(len(conditions))), default=len(conditions))

    def calculate_basal_area(self, num_species, always_include=('beech',)):
        """Rasterise per-cell genus basal-area percentages and composition categories.

        Targets are the ``num_species`` most frequent genera (by tree count),
        plus each genus in ``always_include`` that occurs at the site. For each
        target, ``<genus>_basal_area`` (percent) and ``<genus>_category``
        layers are rasterised onto the reference grid. Cells whose single tree
        is the target genus are left out of that genus's layers only.

        Side effects: writes ``<genus>_category_labels.csv`` per target and
        ``basal_area_<SITE>.nc`` to ``root/output/<SITE>/``, stores the
        dataset on ``self.basal_area``, prints the output path. Returns None.

        Parameters
        ----------
        num_species : int
            How many of the most frequent genera to include.
        always_include : iterable of str, default ('beech',)
            Genera always added as targets. A genus absent from the site is
            skipped with a message (previously this raised KeyError).
        """
        out_dir = self.root / 'output' / self.site_name.upper()

        joined = self.join_trees_to_grid()
        merged = self.vectorize_raster.merge(self._genus_percent_table(joined), on='cell_id', how='inner')

        all_species = joined['genus'].value_counts().head(num_species).index.to_list()
        for genus in always_include:
            if genus in all_species:
                continue
            if f'{genus}_basal_area' in merged.columns:
                all_species.append(genus)
            else:
                print(f'{genus} not found at {self.site_name}; skipping')

        all_ba = xr.Dataset()
        for target_species in all_species:
            nontarget_species = [s for s in all_species if s != target_species]
            ba_col = f'{target_species}_basal_area'
            cat_col = f'{target_species}_category'

            labels = self._category_labels(target_species, nontarget_species)
            label_df = pd.DataFrame({'cat_numbers': list(range(len(labels))), 'cat_labels': labels})
            label_df.to_csv(out_dir / f'{target_species}_category_labels.csv')  # saved for reference

            merged[cat_col] = self._category_codes(merged, target_species, nontarget_species)

            # Drop single-tree cells of the target for THIS species only. Rebinding
            # `merged` here (as before) silently removed those cells from every later
            # species' layers too.
            single_tree = (merged[ba_col] == 100.0) & (merged['num_trees'] == 1)
            target_cells = merged.loc[~single_tree]

            ba = make_geocube(
                vector_data=target_cells,
                measurements=[ba_col, cat_col],
                like=self.reference_raster,  # same grid as the Sentinel data
            )
            all_ba[ba_col] = ba[ba_col]
            all_ba[cat_col] = ba[cat_col]

        self.basal_area = all_ba

        out_path = out_dir / f'basal_area_{self.site_name.upper()}.nc'
        all_ba.to_netcdf(out_path)
        print(f'basal area saved to {out_path}')

    def plot_points_over_grid(self):
        """Return a folium map of tree points (orange circles) over the grid cells containing them.

        The map is centred on the first tree point.
        """
        first_xy = self.tree_points.to_crs(4326).geometry.get_coordinates().iloc[0, :]
        center = tuple(first_xy)[::-1]  # get_coordinates gives (x, y); folium wants (lat, lon)
        m = folium.Map(location=center, zoom_start=11)

        folium.GeoJson(self.join_trees_to_grid().to_crs(4326)).add_to(m)
        folium.GeoJson(
            self.tree_points.to_crs(4326),
            marker=folium.Circle(radius=.5, fill_color="orange", fill_opacity=0.4, color="red", weight=1),
        ).add_to(m)

        return m
                



In [20]:
# SUNY site: 2018 Sentinel scene, FIA integer species codes, EPSG:26918.
# Reads output/SUNY/all_trees_SUNY.gpkg and sentinel_data/suny/2018_suny.nc.
suny = BaCalculator(
    root=root,
    site_name='suny',
    date=2018,
    code_type='fia',
    epsg=26918,
)

In [None]:
# HARV site: 2023 Sentinel scene, NEON taxonID codes, EPSG:26918.
# Reads output/HARV/all_trees_HARV.gpkg and sentinel_data/harv/2023_harv.nc.
harv = BaCalculator(
    root=root,
    site_name='harv',
    date=2023,
    code_type='neon',
    epsg=26918,
)

In [11]:
# Interactive folium map: SUNY tree points drawn over the grid cells that contain them.
m = suny.plot_points_over_grid()
m  # bare last expression so Jupyter renders the map inline

In [21]:
# Targets: the 5 most frequent genera by tree count, plus beech.
# Writes per-species category-label CSVs and basal_area_SUNY.nc under output/SUNY/.
suny.calculate_basal_area(num_species=5)

basal area saved to c:\Users\roseh\OneDrive - Hunter - CUNY\Documents\beech_tree / output / SUNY / basal_area_SUNY.nc
