In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import plotly.express as px
import plotly.offline as py_offline
import json

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering

In [None]:
import ee
import geemap.plotlymap as geemap

In [None]:
# Google Cloud project whose Earth Engine quota this notebook uses.
EE_PROJECT = 'sentinel-treeclassification'

# Authenticate with Earth Engine, then initialise the client against that project.
ee.Authenticate()
ee.Initialize(project=EE_PROJECT)

In [None]:
# Load only the coordinates, the species label and the snapshot date of each tree.
target_column = 'tree_name'
usecols = ['latitude', 'longitude', target_column, 'load_date']
trees_df = pd.read_csv(
    "data/Borough_tree_list_2021July.csv",
    usecols=usecols,
    parse_dates=['load_date'],
)
trees_df.info()

In [None]:
# Distinct species labels in order of first appearance (NaN included if present).
pd.unique(trees_df[target_column])

In [None]:
# How many trees have no species label.
trees_df[target_column].isnull().sum()

In [None]:
# Unlabelled trees are useless as training targets, so drop them and confirm
# no other column still has missing values. Note: the original row index is kept (it has gaps now).
trees_df = trees_df.dropna(subset=[target_column])
trees_df.isna().sum()

In [None]:
# Plot a random subset of trees: rendering every point in the browser is slow.
# A fixed random_state keeps the map identical on Restart & Run All, and
# capping n at len(trees_df) avoids a ValueError on small inputs.
MAP_SAMPLE_SIZE = 10_000
london_trees_mapbox = px.scatter_mapbox(
    trees_df.sample(n=min(MAP_SAMPLE_SIZE, len(trees_df)), random_state=0),
    lat="latitude",
    lon="longitude",
    color=target_column,
    zoom=10,
    mapbox_style="carto-darkmatter",
    height=800,
)
london_trees_mapbox

In [None]:
# Number of trees recorded per snapshot (load) date, most frequent first.
trees_df['load_date'].value_counts(sort=True)

In [None]:
# Distinct load dates ordered from most to least frequent (value_counts' default order),
# NOT chronologically. Later cells index this by position.
# NOTE(review): confirm position 0 / 1 really are the 2018 / 2020 snapshots.
date_indices = trees_df['load_date'].value_counts(ascending=False).index

In [None]:
# Fold the third-most-common load date into the second one so only two snapshots remain.
# NOTE(review): presumably a small follow-up batch of the same survey — confirm.
# Use a single .loc assignment: chained indexing (df['col'][mask] = ...) raises
# SettingWithCopyWarning and silently does nothing under pandas Copy-on-Write.
if len(date_indices) > 2:
    trees_df.loc[trees_df['load_date'] == date_indices[2], 'load_date'] = date_indices[1]
trees_df['load_date'].value_counts()

In [None]:
# Class balance of the raw species labels, largest class first.
(
    px.histogram(trees_df, x=target_column, text_auto=True)
    .update_xaxes(categoryorder="total descending")
)

In [None]:
# Promote the table to a GeoDataFrame of lon/lat points in EPSG:4326.
# The longitude/latitude columns are intentionally kept next to the geometry.
trees_gdf = gpd.GeoDataFrame(
    trees_df,
    geometry=gpd.points_from_xy(trees_df['longitude'], trees_df['latitude']),
    crs=4326,
)
trees_gdf.head(1)

In [None]:
# Earth Engine bounding box of all trees: total_bounds is (minx, miny, maxx, maxy),
# i.e. (west, south, east, north) in lon/lat.
min_lon, min_lat, max_lon, max_lat = trees_gdf.geometry.total_bounds
london_trees_bbox = ee.Geometry.BBox(min_lon, min_lat, max_lon, max_lat)

# Centroid of all trees: merge them into one geometry, take its centroid in the
# EPSG:6933 projection rather than raw degrees, then convert back to lon/lat.
london_trees_centroid = (
    trees_gdf.dissolve()
    .to_crs(epsg=6933)
    .centroid
    .to_crs(epsg=4326)
    .iloc[0]
)

In [None]:
class SentinelGetter:
    """Build cloud-masked Sentinel-2 surface-reflectance composites from Earth Engine."""

    def mask_s2_clouds(self, image):
        """Mask cloudy and cirrus pixels of a Sentinel-2 image using its QA60 band.

        Args:
            image: an ``ee.Image`` that has a ``QA60`` band.

        Returns:
            The same image with every pixel whose QA60 cloud bit (10) or
            cirrus bit (11) is set masked out.

        NOTE(review): QA60 is reportedly not populated for some post-2021 scenes;
        consider s2cloudless / Cloud Score+ if later dates are used.
        """
        qa = image.select('QA60')
        # Bits 10 and 11 are clouds and cirrus, respectively.
        cloud_bit_mask = 1 << 10
        cirrus_bit_mask = 1 << 11
        # Keep a pixel only when both flags are zero, i.e. clear conditions.
        mask = (
            qa.bitwiseAnd(cloud_bit_mask)
            .eq(0)
            .And(qa.bitwiseAnd(cirrus_bit_mask).eq(0))
        )
        return image.updateMask(mask)

    def get_image(self, center_date, bbox, months=1, max_cloud_pct=20):
        """Mean cloud-masked Sentinel-2 composite around ``center_date``, clipped to ``bbox``.

        Args:
            center_date: centre of the time window (a pandas Timestamp / datetime).
            bbox: ``ee.Geometry`` region of interest.
            months: half-width of the window in months; the default 1 gives
                ``center_date ± 1 month``.
            max_cloud_pct: keep only granules whose ``CLOUDY_PIXEL_PERCENTAGE``
                is strictly below this value.

        Returns:
            A lazy (server-side) ``ee.Image``: the per-pixel mean of the
            cloud-masked images, clipped to ``bbox``.
        """
        window = pd.DateOffset(months=months)
        return (
            ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
            # Restrict to granules touching the ROI first; otherwise the collection
            # holds every Sentinel-2 granule worldwide in the date window.
            .filterBounds(bbox)
            .filterDate(center_date - window, center_date + window)
            # Pre-filter to get less cloudy granules.
            .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', max_cloud_pct))
            .map(self.mask_s2_clouds)
            .mean()
            .clip(bbox)
        )

In [None]:
# Cloud-masked Sentinel-2 composites around the two tree-survey load dates.
# NOTE(review): date_indices is ordered by frequency, not chronologically —
# confirm [0] is the 2018 snapshot and [1] the 2020 one.
getter = SentinelGetter()
sentinel_image_2020 = getter.get_image(date_indices[1], london_trees_bbox)
sentinel_image_2018 = getter.get_image(date_indices[0], london_trees_bbox)

# True-colour display (B4, B3, B2 = red, green, blue) stretched over 0..rgb_max.
rgb_max = 3000
rgb_bands = ['B4', 'B3', 'B2']
visualization = dict(min=0.0, max=rgb_max, bands=rgb_bands, layer="below")

london_2018map = geemap.Map(
    center=(london_trees_centroid.y, london_trees_centroid.x),  # (lat, lon)
    zoom=10,
)
london_2018map.addLayer(sentinel_image_2018, visualization, 'RGB')
py_offline.iplot(london_2018map)

In [None]:
# Lump species that make up less than `min_freq` of all trees into one 'Other_minor' class.
min_freq = 0.01
value_counts = trees_gdf[target_column].value_counts()
mask = (value_counts / value_counts.sum()).lt(min_freq)  # True for rare species
is_rare_tree = trees_gdf[target_column].isin(value_counts[mask].index)
# Build the replacement on trees_gdf's own index. A bare pd.Series(np.where(...)) gets a
# fresh RangeIndex, and the assignment aligns by label. dropna() left gaps in this index,
# so the original version shuffled labels between rows and NaN-filled the tail.
trees_gdf[target_column] = pd.Series(
    np.where(is_rare_tree, 'Other_minor', trees_gdf[target_column]),
    index=trees_gdf.index,
).astype('category')
(
    px.histogram(trees_gdf, x=target_column, text_auto=True)
    .update_xaxes(categoryorder="total descending")
)

In [None]:
# Sort trees chronologically, then turn each point into a square sampling footprint.
# The buffer is computed in a local UTM zone, whose units are true metres.
# EPSG:6933 is an equal-area cylindrical projection whose x/y scale is distorted
# at London's latitude, so a 10-unit buffer there is neither 10 m nor square on the ground.
REGION_HALF_WIDTH_M = 10
trees_gdf = trees_gdf.sort_values(by='load_date')
metric_crs = trees_gdf.estimate_utm_crs()
trees_regions = (
    trees_gdf.to_crs(metric_crs)
    .buffer(REGION_HALF_WIDTH_M, cap_style=3)  # cap_style=3 -> square caps
    .to_crs(epsg=4326)
)

In [None]:
# Quick static look at the square footprints.
ax = trees_regions.plot()
ax

In [None]:
%%time
trees_regions_gdf = gpd.GeoDataFrame(geometry=trees_regions)
date_mask_2018 = (trees_gdf['load_date'] == date_indices[0])
date_mask_2020 = (trees_gdf['load_date'] == date_indices[1])

In [None]:
# Upload the first N_REGIONS 2018 footprints and sample the 2018 composite inside them.
# `scale` must be explicit: a .mean() composite has Earth Engine's default 1-degree
# projection, so sampleRegions without it samples at ~111 km, not the 10 m band grid.
SAMPLE_SCALE_M = 10
N_REGIONS = 1000
trees_regions_2018_ee = geemap.geopandas_to_ee(trees_regions_gdf[date_mask_2018].head(N_REGIONS))
sentinel_image_2018.sampleRegions(collection=trees_regions_2018_ee, scale=SAMPLE_SCALE_M)

In [None]:
# TODO(review): unfinished draft, not runnable as written. `training_sample`, `label`
# and `img` are not defined anywhere in this notebook. Also, train_test_split(trees_gdf)
# returns (train, test) frames, not label vectors, despite the `_y` names.
# train_y, test_y = train_test_split(trees_gdf, train_size=0.7)
# # Train a 10-tree random forest classifier from the training sample.
# trained_classifier = ee.Classifier.smileRandomForest(10).train(
#     features=training_sample,
#     classProperty=label,
#     inputProperties=img.bandNames(),
# )

In [None]:
# TODO(review): unfinished draft. Its intent appears to be grouping trees that are within
# 50 projected units of each other. `trees_squares` is not defined anywhere in this
# notebook; presumably it should be `points_in_m` (or trees_gdf) — confirm before enabling.
# points_in_m = trees_gdf.to_crs(epsg=6933)
# xy = list(map(list, zip(points_in_m.geometry.x, points_in_m.geometry.y)))
# cluster = AgglomerativeClustering(
#     n_clusters=None, 
#     linkage='single',
#     metric='euclidean',
#     distance_threshold=50)
    
# cluster.fit(xy)
# trees_squares['group'] = cluster.labels_
# counts = trees_squares.value_counts('group')
# counts.shape, trees_squares.shape