# In [None]:
import geopandas as gpd
import rasterio
from rasterio.mask import mask
from rasterio.features import geometry_mask
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import os

# Input AOI boundary and LULC raster, plus the two output artefacts
shapefile_path = "../study-area/barishal.shp"
lulc_input_path = "../data/interim/barishal_lulc_map.tif"
lulc_clipped_tif = "../results/barishal_lulc_clipped.tif"
lulc_clipped_png = "../results/barishal_lulc_clipped.png"

# Read the AOI boundary and reproject it to WGS84 (EPSG:4326)
aoi = gpd.read_file(shapefile_path)
aoi = aoi.to_crs(epsg=4326)

# GeoJSON-like mapping of the FIRST AOI feature only, wrapped in a list
geometry = [aoi.geometry[0].__geo_interface__]

# Open the LULC raster and set every pixel outside the AOI to NoData (255).
# The output keeps the raster's full extent; only out-of-AOI pixels change.
with rasterio.open(lulc_input_path) as src:
    lulc_array = src.read(1)
    transform = src.transform
    profile = src.profile.copy()
    height, width = src.shape

    # geometry_mask rasterizes the shapes through `transform`, i.e. in the
    # raster's own coordinate system, so the AOI must be expressed in that CRS.
    # Feeding EPSG:4326 shapes to a projected raster (e.g. UTM) would yield a
    # wrong or empty mask. If the raster carries no CRS we cannot reproject;
    # the AOI is then assumed to already match it (NOTE(review): confirm).
    aoi_raster_crs = aoi.to_crs(src.crs) if src.crs is not None else aoi

    # Boolean mask with invert=False: True = pixel outside the AOI
    mask_arr = geometry_mask(
        [geom.__geo_interface__ for geom in aoi_raster_crs.geometry],
        transform=transform,
        invert=False,
        out_shape=(height, width)
    )

    # Pixels inside the AOI keep their class value; outside become NoData
    lulc_array[mask_arr] = 255  # NoData
    
# Output metadata: single-band uint8, LZW-compressed, with 255 flagged as NoData
profile.update({
    'dtype': 'uint8',
    'count': 1,
    'compress': 'lzw',
    'nodata': 255
})

# Cast to the dtype declared in the profile so the written data matches the
# file metadata (and the uint8 colour-LUT indexing used for plotting),
# whatever dtype the source raster had. No copy if it is already uint8.
lulc_array = lulc_array.astype(np.uint8, copy=False)

# rasterio cannot create missing parent directories; make sure it exists
os.makedirs(os.path.dirname(lulc_clipped_tif) or '.', exist_ok=True)

# Save clipped raster
with rasterio.open(lulc_clipped_tif, "w", **profile) as dest:
    dest.write(lulc_array, 1)

print(f"✅ Clipped GeoTIFF saved to: {lulc_clipped_tif}")

# One row per LULC class: (raster class code, legend label, hex colour).
# Row order defines the legend order.
_CLASS_STYLE = (
    (0, 'Built-up', '#e31a1c'),    # red
    (1, 'Cropland', '#b2df8a'),    # light green
    (2, 'Vegetation', '#33a02c'),  # forest green
    (3, 'Water', '#1f78b4'),       # blue
    (4, 'Wetland', '#a6cee3'),     # light blue
)

label_map = {code: label for code, label, _ in _CLASS_STYLE}
color_map = {code: colour for code, _, colour in _CLASS_STYLE}

# 256-entry RGB lookup table indexed by class code; unlisted codes stay black
lut = np.zeros((256, 3), dtype=np.uint8)
for code, hex_color in color_map.items():
    # '#rrggbb' -> [r, g, b] as 0-255 integers
    lut[code] = list(bytes.fromhex(hex_color.lstrip('#')))
lut[255] = 255  # NoData -> white (broadcast to all three channels)

# Colour each pixel by looking its class code up in the RGB table
rgb_image = lut[lulc_array]

# Render the classified map on an explicit Figure/Axes pair
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(rgb_image)
ax.set_title('Barishal Land Cover Classification (Clipped)', fontsize=14)
ax.set_axis_off()

# One legend swatch per class, in class-code order, placed under the map
legend_patches = [
    mpatches.Patch(color=color_map[code], label=name)
    for code, name in label_map.items()
]
ax.legend(
    handles=legend_patches,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.1),
    ncol=3,
    frameon=False,
)

fig.tight_layout()
fig.savefig(lulc_clipped_png, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ PNG map saved to: {lulc_clipped_png}")