In [None]:
import os
import sys
from glob import glob
from loguru import logger
from tqdm import tqdm
from yaml import load, FullLoader

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio as rio
from matplotlib import colors
import matplotlib.gridspec as gridspec
from matplotlib import pyplot as plt
from rasterstats import zonal_stats
from rasterio.features import shapes
from shapely.geometry import shape
from skimage.color import rgb2hsv
from skimage.exposure import adjust_gamma, adjust_log, adjust_sigmoid, equalize_adapthist, equalize_hist, rescale_intensity

import cv2

In [None]:
# Add the directory two levels up to the import search path so that the
# project-local `functions` package below can be imported from this notebook.
sys.path.insert(1,'../..')
import functions.fct_misc as misc
from functions.fct_rasters import remove_black_border

# Rebind `logger` to whatever misc.format_logger returns
# (presumably the loguru logger with the project formatting - confirm in functions/fct_misc).
logger = misc.format_logger(logger)

## Functions

In [None]:
def print_images(image_dict, vmax, tile_names=None):
    """Show a selection of tiles on a grid of subplots, three per row.

    Args:
        image_dict: mapping from tile file name to an array accepted by `imshow`.
        vmax: upper bound of the color scale passed to `imshow` (the lower bound is 0).
        tile_names: optional sequence of keys of `image_dict` to display, filled
            row by row. Defaults to the nine reference tiles used throughout
            this notebook, which yields the original 3x3 grid.

    Raises:
        KeyError: if a requested tile name is missing from `image_dict`.
    """
    if tile_names is None:
        tile_names = (
            '0_2570184_1148461.tif', '4_2569842_1149296.tif', '1_2571614_1152259.tif',
            '5_2569300_1148156.tif', '0_2570190_1148491.tif', '10_2580845_1165703.tif',
            '4_2569483_1149035.tif', '5_2569281_1148151.tif', '6_2567727_1147671.tif',
        )
    tile_names = tuple(tile_names)

    n_cols = 3
    # Ceiling division; keep at least one row so that subplots() stays valid.
    n_rows = max(1, -(-len(tile_names) // n_cols))
    # squeeze=False guarantees a 2D array of axes, even for a single row.
    _, axarr = plt.subplots(n_rows, n_cols, squeeze=False)
    # axarr.flat walks the grid row by row, matching the original panel order.
    for ax, tile_name in zip(axarr.flat, tile_names):
        ax.imshow(image_dict[tile_name], vmin=0, vmax=vmax)

## Config

In [None]:
# Read the section of the shared YAML configuration dedicated to this notebook.
with open('../../../config/config_symbol_classif.yaml') as fp:
    cfg = load(fp, FullLoader)['test_notebooks.py']

# Directories and input file, all taken from the configuration section.
WORKING_DIR, OUTPUT_DIR, TILE_DIR, IMAGE_FILE = (
    cfg[entry] for entry in ('working_dir', 'output_dir', 'tile_dir', 'image_gpkg')
)

In [None]:
# Move to the working directory first: every later path, including OUTPUT_DIR
# on the next line if it is relative, is then resolved against it.
os.chdir(WORKING_DIR)
# Create the output directory; an already existing directory is not an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Data

In [None]:
logger.info('Read data...')

# Paths of all the GeoTIFF tiles found in the tile directory.
tile_list = glob(os.path.join(TILE_DIR, '*.tif'))

# Image attributes; the long 'undetermined' label is shortened to 'undet'.
images_gdf = gpd.read_file(IMAGE_FILE)
images_gdf.loc[images_gdf['CATEGORY'] == 'undetermined', 'CATEGORY'] = 'undet'

In [None]:
# Read every tile in memory, keyed by file name, together with its raster metadata.
image_data, meta_data = {}, {}
for tile_path in tile_list:
    with rio.open(tile_path) as src:
        tile_name = os.path.basename(tile_path)
        # Move the first axis of the raster array to the end: (a, b, c) -> (b, c, a).
        image_data[tile_name] = np.moveaxis(src.read(), 0, -1)
        meta_data[tile_name] = src.meta

## Color filters

In [None]:
# Display the unmodified tiles on the sample grid.
image_dict, vmax = image_data, 255
print_images(image_dict, vmax)

In [None]:
# Apply remove_black_border to every tile, then display the result.
cropped_images = {
    tile_name: remove_black_border(tile)
    for tile_name, tile in image_dict.items()
}
image_dict, vmax = image_data, 255
print_images(cropped_images, vmax)

In [None]:
def get_pixel_color(image):
    """Return the normalized color of every pixel of `image` as a nested list.

    The image is flattened to one row of three values per pixel; the values are
    then rescaled with a matplotlib `Normalize` fitted on the data itself.
    """
    n_pixels = np.shape(image)[0] * np.shape(image)[1]
    flat_pixels = image.reshape((n_pixels, 3))

    scaler = colors.Normalize(vmin=-1., vmax=1.)
    # autoscale() replaces the limits given above with the min/max of the data.
    scaler.autoscale(flat_pixels)

    return scaler(flat_pixels).tolist()


def plot_hsv(image, fig, spec, pos_x=1, pos_y=1):
    """Scatter the pixels of `image` in the HSV space on a 3D subplot.

    Args:
        image: RGB image passed to `rgb2hsv`; its pixel colors are used for the markers.
        fig: figure receiving the new subplot.
        spec: grid specification indexed with `[pos_x, pos_y]` to place the subplot.
        pos_x: first index into `spec`.
        pos_y: second index into `spec`.
    """
    marker_colors = get_pixel_color(image)
    hue, saturation, value = cv2.split(rgb2hsv(image))

    ax_3d = fig.add_subplot(spec[pos_x, pos_y], projection="3d")
    ax_3d.scatter(
        hue.flatten(), saturation.flatten(), value.flatten(),
        facecolors=marker_colors, marker=".",
    )
    ax_3d.set_xlabel("Hue")
    ax_3d.set_ylabel("Saturation")
    ax_3d.set_zlabel("Value")

In [None]:
# One HSV scatter plot per sample tile, laid out row by row on a 3x3 grid.
fig = plt.figure(figsize=(18, 16))
spec = gridspec.GridSpec(ncols=3, nrows=3, figure=fig)

hsv_sample_names = [
    '0_2570184_1148461.tif', '4_2569842_1149296.tif', '1_2571614_1152259.tif',
    '5_2569300_1148156.tif', '0_2570190_1148491.tif', '10_2580845_1165703.tif',
    '4_2569483_1149035.tif', '5_2569281_1148151.tif', '6_2567727_1147671.tif',
]
for grid_index, sample_name in enumerate(hsv_sample_names):
    # divmod turns the running index into the (row, column) of the grid.
    plot_hsv(image_dict[sample_name], fig, spec, *divmod(grid_index, 3))

plt.show()

In [None]:
def plot_rgb(image, fig, spec, pos_x=1, pos_y=1):
    """Scatter the pixels of `image` in the RGB space on a 3D subplot.

    Args:
        image: three-channel image; its pixel colors are used for the markers.
        fig: figure receiving the new subplot.
        spec: grid specification indexed with `[pos_x, pos_y]` to place the subplot.
        pos_x: first index into `spec`.
        pos_y: second index into `spec`.
    """
    marker_colors = get_pixel_color(image)
    red, green, blue = cv2.split(image)

    ax_3d = fig.add_subplot(spec[pos_x, pos_y], projection="3d")
    ax_3d.scatter(
        red.flatten(), green.flatten(), blue.flatten(),
        facecolors=marker_colors, marker=".",
    )
    ax_3d.set_xlabel("Red")
    ax_3d.set_ylabel("Green")
    ax_3d.set_zlabel("Blue")

In [None]:
# One RGB scatter plot per sample tile, laid out row by row on a 3x3 grid.
fig = plt.figure(figsize=(9, 8))
spec = gridspec.GridSpec(ncols=3, nrows=3, figure=fig)

rgb_sample_names = [
    '0_2570184_1148461.tif', '4_2569842_1149296.tif', '1_2571614_1152259.tif',
    '5_2569300_1148156.tif', '0_2570190_1148491.tif', '10_2580845_1165703.tif',
    '4_2569483_1149035.tif', '5_2569281_1148151.tif', '6_2567727_1147671.tif',
]
for grid_index, sample_name in enumerate(rgb_sample_names):
    # divmod turns the running index into the (row, column) of the grid.
    plot_rgb(image_dict[sample_name], fig, spec, *divmod(grid_index, 3))

In [None]:
# Boolean mask per cropped tile: True where all three bands are below their threshold.
binary_color_list = {
    tile_name: (tile[:, :, 0] < 250) & (tile[:, :, 1] < 225) & (tile[:, :, 2] < 225)
    for tile_name, tile in cropped_images.items()
}

In [None]:
# Display the RGB-threshold masks on the sample grid.
# This cell used to repeat the body of print_images line for line; calling the
# helper keeps every display cell of the notebook consistent.
image_dict = binary_color_list
vmax = 1
print_images(image_dict, vmax)

In [None]:
# Inspect the raw values of the band at index 1 for one of the sample tiles.
image_data['5_2569281_1148151.tif'][:,:,1]

## HSV Filters

In [None]:
# Convert every tile from RGB to HSV.
data_hsv = {
    tile_name: rgb2hsv(rgb_tile)
    for tile_name, rgb_tile in image_data.items()
}

In [None]:
# Display the HSV arrays on the sample grid.
image_dict, vmax = data_hsv, 255
print_images(image_dict, vmax)

In [None]:
# 1 where the third HSV channel (value) is below 0.90, 0 elsewhere.
binary_list = {
    tile_name: np.where(hsv[..., 2] < 0.90, 1, 0)
    for tile_name, hsv in data_hsv.items()
}


In [None]:
# Display the value-threshold masks on the sample grid.
image_dict, vmax = binary_list, 1
print_images(image_dict, vmax)

In [None]:
# 1 where the value is below 0.95 or the hue is below 0.1, 0 elsewhere.
binary_list_3 = {
    tile_name: np.where((hsv[..., 2] < 0.95) | (hsv[..., 0] < 0.1), 1, 0)
    for tile_name, hsv in data_hsv.items()
}

In [None]:
# Display the value-or-hue masks on the sample grid.
image_dict, vmax = binary_list_3, 1
print_images(image_dict, vmax)

In [None]:
# 1 where the value is below 0.95 or the saturation is above 0.3, 0 elsewhere.
binary_list_2 = {
    tile_name: np.where((hsv[..., 2] < 0.95) | (hsv[..., 1] > 0.3), 1, 0)
    for tile_name, hsv in data_hsv.items()
}


In [None]:
# Display the value-or-saturation masks on the sample grid.
image_dict, vmax = binary_list_2, 1
print_images(image_dict, vmax)

In [None]:
# Boolean mask: value below 0.95, or both hue below 0.1 and saturation above 0.3.
binary_list_final = {
    tile_name: (hsv[..., 2] < 0.95) | ((hsv[..., 0] < 0.1) & (hsv[..., 1] > 0.3))
    for tile_name, hsv in data_hsv.items()
}

In [None]:
# Display the combined first-round masks on the sample grid.
image_dict, vmax = binary_list_final, 1
print_images(image_dict, vmax)

## HSV filter - second round

In [None]:
# 1 where the hue lies outside [0.1, 0.45], 0 elsewhere.
binary_list = {
    tile_name: np.where((hsv[..., 0] < 0.1) | (hsv[..., 0] > 0.45), 1, 0)
    for tile_name, hsv in data_hsv.items()
}

In [None]:
# Display the hue masks on the sample grid.
image_dict, vmax = binary_list, 1
print_images(image_dict, vmax)

In [None]:
# Black and blue condition: value below 0.90 and hue outside [0.2, 0.45].
binary_list = {
    tile_name: np.where(
        (hsv[..., 2] < 0.90) & ((hsv[..., 0] < 0.2) | (hsv[..., 0] > 0.45)),
        1,
        0,
    )
    for tile_name, hsv in data_hsv.items()
}

In [None]:
# Display the black-and-blue masks on the sample grid.
image_dict, vmax = binary_list, 1
print_images(image_dict, vmax)

In [None]:
# Red condition: saturation above 0.15, value above 0.8 and hue below 0.05.
binary_list = {
    tile_name: np.where(
        (hsv[..., 1] > 0.15) & (hsv[..., 2] > 0.8) & (hsv[..., 0] < 0.05),
        1,
        0,
    )
    for tile_name, hsv in data_hsv.items()
}

In [None]:
# Display the red masks on the sample grid.
image_dict, vmax = binary_list, 1
print_images(image_dict, vmax)

In [None]:
# Final mask per tile: 1 where the pixel meets the red condition or the
# black-and-blue condition, 0 elsewhere.
binary_list_final = {}
for name, hsv_tile in data_hsv.items():
    h, s, v = (hsv_tile[..., band] for band in range(3))

    is_black_or_blue = ((h < 0.2) | (h > 0.45)) & (v < 0.9)
    is_red = (h < 0.05) & (s > 0.15) & (v > 0.8)

    binary_list_final[name] = np.where(is_black_or_blue | is_red, 1, 0)

In [None]:
# Display the final masks on the sample grid.
image_dict, vmax = binary_list_final, 1
print_images(image_dict, vmax)

## Test pixels under mask

In [None]:
# Blank every pixel outside the final mask and save the result as a new tile
# in a sibling directory of the tile directory.
filtered_tile_dir = os.path.join(os.path.dirname(TILE_DIR), 'filtered_symbols_2')
os.makedirs(filtered_tile_dir, exist_ok=True)

filtered_images = {}
for name, image in tqdm(image_data.items()):
    # Repeat the 2D mask three times along a new last axis to match the image bands.
    mask = np.repeat(binary_list_final[name][..., np.newaxis], repeats=3, axis=2)
    filtered_images[name] = np.where(mask, image, 0)

    filtered_tile_path = os.path.join(filtered_tile_dir, name)
    with rio.open(filtered_tile_path, 'w', **meta_data[name]) as src:
        # Move the band axis back to the front before writing.
        src.write(np.moveaxis(filtered_images[name], -1, 0))


In [None]:
# Display the masked tiles on the sample grid.
image_dict, vmax = filtered_images, 255
print_images(image_dict, vmax)

In [None]:
# Collect, for every tile, the pixel values under the final mask and the zonal
# statistics of the filtered tile, per band and per image category.
BAND_CORRESPONDENCE = {0: 'R', 1: 'G', 2: 'B'}
STAT_LIST = ['min', 'max', 'std', 'mean', 'median']
cat_list = ['1b', '1n', '1r', '2b', '3b', '3r', '5n', 'undet']
# Raw pixel values under the mask: band index -> category -> list of values.
pxl_values_dict = {
    0: {cat: [] for cat in cat_list}, 
    1: {cat: [] for cat in cat_list}, 
    2: {cat: [] for cat in cat_list}
}
# Per-image statistics are gathered in lists and concatenated once after the
# loop, instead of growing a DataFrame at every iteration.
stats_frames = {band: [] for band in BAND_CORRESPONDENCE.keys()}
ratio_stats_df = pd.DataFrame()

for name, image in tqdm(image_data.items(), desc="Extract pixel values from tiles"):
    # os.path.splitext removes the extension only; str.rstrip('.tif') strips any
    # trailing '.', 't', 'i' or 'f' character and can eat into the image name.
    image_name = os.path.splitext(name)[0]
    category = images_gdf.loc[images_gdf.image_name == image_name, 'CATEGORY'].iloc[0]

    # binary_list_final holds 0/1 integers. Indexing with an integer array is
    # fancy indexing (it would pick rows 0 and 1), so a boolean mask is required
    # to select the pixels under the mask.
    mask = binary_list_final[name].astype(bool)
    if not mask.any():
        continue

    # Polygonize mask
    geoms = ((shape(s), v) for s, v in shapes(mask.astype('uint8'), transform = meta_data[name]['transform']) if v == 1)
    mask_gdf = gpd.GeoDataFrame(geoms, columns=['geometry', 'class'], crs = meta_data[name]['crs'])
    # Merge all the polygons of the tile into a single geometry.
    # NOTE(review): `unary_union` is reported as deprecated by recent geopandas releases - confirm before upgrading.
    mask_gdf = gpd.GeoDataFrame([name], geometry = [mask_gdf.unary_union], columns=['geometry'], crs = meta_data[name]['crs'])  

    for band in BAND_CORRESPONDENCE.keys():
        # Get individual pixel value; setdefault tolerates a category missing from cat_list
        pxl_values_dict[band].setdefault(category, []).extend(image[:, :, band][mask].flatten())

        # Get category stats on each image
        tmp_stats = zonal_stats(mask_gdf, os.path.join(filtered_tile_dir, name), stats=STAT_LIST, band_num=band+1)
        tmp_stats_df = pd.DataFrame.from_records(tmp_stats)
        tmp_stats_df['CATEGORY'] = category
        tmp_stats_df['image_name'] = image_name
        # Keep only the rows for which a median could be computed.
        valid_stats_df = tmp_stats_df[tmp_stats_df['median'].notna()]
        if not valid_stats_df.empty:
            stats_frames[band].append(valid_stats_df)

# One statistics table per band; an empty DataFrame when no tile contributed.
stats_df_dict = {
    band: pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    for band, frames in stats_frames.items()
}

In [None]:
# Stack the per-band statistics in one table, tagged with the band letter,
# and export it to the output directory.
stats_df = pd.concat(
    [
        stats_df_dict[band_index].assign(band=letter)
        for band_index, letter in BAND_CORRESPONDENCE.items()
    ],
    ignore_index=True,
)
stats_df.to_csv(os.path.join(OUTPUT_DIR, 'stats_on_filtered_bands.csv'), index=False)

In [None]:
# Save, for each band, one boxplot of the raw pixel values per category and
# one boxplot per statistic of the per-image zonal statistics.
for band in tqdm(BAND_CORRESPONDENCE.keys(), desc='Produce boxplots for each band'):
    labels, data = [*zip(*pxl_values_dict[band].items())]

    plt.boxplot(data)
    plt.xticks(range(1, len(labels) + 1), labels)
    plt.title(f'Pixel values on the {BAND_CORRESPONDENCE[band]} band')
    plt.savefig(os.path.join(OUTPUT_DIR, f'boxplot_filtered_pixels_{BAND_CORRESPONDENCE[band]}.png'), bbox_inches='tight')
    plt.close()

    # Without statistics for this band there is no column to select below.
    if stats_df_dict[band].empty:
        continue

    for stat in STAT_LIST:
        # Dedicated name: rebinding `stats_df` here would silently replace the
        # combined table that the previous cell exported to CSV.
        band_stat_df = stats_df_dict[band].loc[: , ['CATEGORY', stat]].copy()
        band_stat_df.plot.box(by='CATEGORY')
        plt.title(f'{stat.title()} on the {BAND_CORRESPONDENCE[band]} band')
        # NOTE(review): 'filetered' in the file name looks like a typo; kept as is
        # because other tools may rely on the existing output names.
        plt.savefig(os.path.join(OUTPUT_DIR, f'boxplot_filetered_stats_{BAND_CORRESPONDENCE[band]}_{stat}.png'), bbox_inches='tight')
        plt.close()

## Improve the color

In [None]:
from skimage.morphology import binary_closing

In [None]:
# Apply a morphological closing to every final mask.
closed_filters = {
    tile_name: binary_closing(tile_mask)
    for tile_name, tile_mask in binary_list_final.items()
}

In [None]:
# Display the closed masks on the sample grid.
image_dict, vmax = closed_filters, 1
print_images(image_dict, vmax)

## Change brightness

In [None]:
# Gamma-corrected copies of every tile, one dictionary per gamma value.
gamma_1pt5, gamma_0pt85, gamma_0pt75, gamma_one_half = (
    {tile_name: adjust_gamma(tile, gamma=gamma_value) for tile_name, tile in image_data.items()}
    for gamma_value in (1.5, 0.85, 0.75, 1/2)
)

In [None]:
# Display the tiles corrected with gamma = 1.5.
image_dict, vmax = gamma_1pt5, 255
print_images(image_dict, vmax)

In [None]:
# Display the unmodified tiles again, for comparison with the gamma corrections.
image_dict, vmax = image_data, 255
print_images(image_dict, vmax)

In [None]:
# Display the tiles corrected with gamma = 0.85.
image_dict, vmax = gamma_0pt85, 255
print_images(image_dict, vmax)

In [None]:
# Display the tiles corrected with gamma = 0.75.
image_dict, vmax = gamma_0pt75, 255
print_images(image_dict, vmax)

In [None]:
# Display the tiles corrected with gamma = 1/2.
image_dict, vmax = gamma_one_half, 255
print_images(image_dict, vmax)

In [None]:
# Logarithmic corrections of every tile: default gain, then gains of 2 and 0.5.
log_default = {tile_name: adjust_log(tile) for tile_name, tile in image_data.items()}
log_two, log_half = (
    {tile_name: adjust_log(tile, gain=gain_value) for tile_name, tile in image_data.items()}
    for gain_value in (2, 0.5)
)

In [None]:
# Display the tiles after the default logarithmic correction.
image_dict, vmax = log_default, 255
print_images(image_dict, vmax)

In [None]:
# Adaptive histogram equalization of every tile with the default parameters.
eq_default = {
    tile_name: equalize_adapthist(tile)
    for tile_name, tile in image_data.items()
}

In [None]:
# Display the adaptively equalized tiles.
image_dict, vmax = eq_default, 255
print_images(image_dict, vmax)

In [None]:
# Adaptive histogram equalization (clip limit 0.003) of the gamma = 1/2 tiles.
eq_gamma_one_half = {
    tile_name: equalize_adapthist(tile, clip_limit=0.003)
    for tile_name, tile in gamma_one_half.items()
}
image_dict, vmax = eq_gamma_one_half, 255
print_images(image_dict, vmax)

In [None]:
# Global histogram equalization of the gamma = 1/2 tiles
# (rebinds the name used for the adaptive variant).
eq_gamma_one_half = {
    tile_name: equalize_hist(tile)
    for tile_name, tile in gamma_one_half.items()
}
image_dict, vmax = eq_gamma_one_half, 255
print_images(image_dict, vmax)

In [None]:
# Sigmoid correction of every tile with the default parameters.
sigm_default = {
    tile_name: adjust_sigmoid(tile)
    for tile_name, tile in image_data.items()
}
image_dict, vmax = sigm_default, 255
print_images(image_dict, vmax)

In [None]:
# Read the tiles from disk again, with their raster metadata, to start from fresh arrays.
image_data, meta_data = {}, {}
for tile_path in tile_list:
    with rio.open(tile_path) as src:
        tile_name = os.path.basename(tile_path)
        # Move the first axis of the raster array to the end: (a, b, c) -> (b, c, a).
        image_data[tile_name] = np.moveaxis(src.read(), 0, -1)
        meta_data[tile_name] = src.meta

In [None]:
# Display the freshly loaded tiles on the sample grid.
image_dict, vmax = image_data, 255
print_images(image_dict, vmax)

In [None]:
# Strengthen the red: where the first band is above 245, lower the two other bands by 20.
add_red = {}
for key, v in image_data.items():
    # Work on a real copy: np.array(v, copy=False) returned the tile itself, so
    # the edit below was also applied to image_data.
    new_image = v.copy()
    is_bright_red = new_image[:, :, :1] > 245
    other_bands = new_image[:, :, 1:3]
    # Subtract at most the current value, so that results stop at 0 instead of
    # wrapping around to large values on unsigned integer tiles.
    lowered_bands = other_bands - np.minimum(other_bands, 20)
    new_image[:, :, 1:3] = np.where(is_bright_red, lowered_bands, other_bands)
    add_red[key] = new_image
image_dict = add_red
vmax=255
print_images(image_dict, vmax)

In [None]:
# Read the tiles from disk again, with their raster metadata, to start from fresh arrays.
image_data, meta_data = {}, {}
for tile_path in tile_list:
    with rio.open(tile_path) as src:
        tile_name = os.path.basename(tile_path)
        # Move the first axis of the raster array to the end: (a, b, c) -> (b, c, a).
        image_data[tile_name] = np.moveaxis(src.read(), 0, -1)
        meta_data[tile_name] = src.meta

In [None]:
# Strengthen the blue: where the first two bands are above 225 and the third is
# below 235, lower the first two bands by 50 and raise the third by 20.
# The thresholds keep both shifts inside [0, 255] (assumes 8-bit tiles - TODO confirm).
add_blue = {}
for key, v in image_data.items():
    # Work on a real copy: np.array(v, copy=False) returned the tile itself, so
    # the edits below were also applied to image_data.
    new_image = v.copy()
    # Evaluate the condition once, before any band is edited. Re-evaluating it
    # after lowering the first two bands made it false everywhere, so the blue
    # shift was never applied.
    is_pale = (new_image[:, :, :1] > 225) & (new_image[:, :, 1:2] > 225) & (new_image[:, :, 2:3] < 255-20)
    new_image[:,:, 0:2] = np.where(is_pale, new_image[:, :, 0:2]-50, new_image[:, :, 0:2])
    new_image[:, :, 2:3] = np.where(is_pale, new_image[:, :, 2:3]+20, new_image[:, :, 2:3])
    add_blue[key] = new_image
image_dict = add_blue
vmax=255
print_images(image_dict, vmax)

In [None]:
# Check the shape of a combined condition built from single-band slices of the
# last processed tile (the `:1` / `2:3` slices keep the band axis).
((new_image[:, :, :1] > 245) & (new_image[:,:, 2:3] < 255-20)).shape