In [None]:
# Install the notebook's dependencies (IPython shell magics).
# NOTE(review): only snowflake-ml-python is version-pinned here; every other package resolves to
# whatever is newest at run time, so a fresh run may not reproduce earlier results — consider pinning.
!pip install uv
!uv pip install -r  requirements.txt
!uv pip install streamlit
!uv pip install -U ipywidgets
!uv pip install shap snowflake-ml-python==1.19.0

In [None]:
# Project-level settings used throughout the notebook.
# Bump VERSION_NUM to version your features, models, etc.
VERSION_NUM = "0"
DB = "EY_DATA_CHALLENGE"
SCHEMA = "DATA_SCHEMA"
# NOTE(review): ROLE is not referenced in the visible part of this notebook; wherever it is used,
# prefer a least-privilege role over ACCOUNTADMIN.
ROLE = "ACCOUNTADMIN"

In [None]:
import pandas as pd
import numpy as np
import sklearn
import math
import pickle
import shap
from datetime import datetime
import streamlit as st
from xgboost import XGBClassifier

# Snowpark ML
from snowflake.ml.registry import Registry
from snowflake.ml.modeling.tune import get_tuner_context
from snowflake.ml.modeling import tune
from entities import search_algorithm

#Snowflake feature store
from snowflake.ml.feature_store import FeatureStore, FeatureView, Entity, CreationMode

# Snowpark session
from snowflake.snowpark import DataFrame
from snowflake.snowpark.functions import col, to_timestamp, min, max, month, dayofweek, dayofyear, avg, date_add, sql_expr,year,quarter,date_trunc
from snowflake.snowpark.types import IntegerType
from snowflake.snowpark import Window

# Attach to the notebook's active Snowpark session and scope it to the project database and schema.
from snowflake.snowpark.context import get_active_session

session = get_active_session()
for _use_scope, _scope_name in ((session.use_database, DB), (session.use_schema, SCHEMA)):
    _use_scope(_scope_name)

# Bare last expression so Jupyter renders the session object.
session

In [None]:
# Supress Warnings 
import warnings
warnings.filterwarnings('ignore')

# Import common GIS tools
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# Import Planetary Computer tools
import pystac_client
import planetary_computer as pc
from odc.stac import stac_load

In [None]:
# Study area in South Africa, given as its (latitude, longitude) centre in decimal degrees.
# Earlier centres, kept for reference:
#   (-26.510278, 28.351389)  presumably the Wilge River region with water-quality sample sites #184 / #186
#   (-28.760833, 17.730278)  sample data set validation point
lat_long = (-26.9847222, 26.63227778)

# Edge length of the square search box drawn around the centre, in degrees.
box_size_deg = 0.15

In [None]:
# Square bounding box of side box_size_deg centred on lat_long, in the
# (min_lon, min_lat, max_lon, max_lat) order passed to the STAC `bbox` search below.
center_lat, center_lon = lat_long
half_side = box_size_deg / 2
min_lon, max_lon = center_lon - half_side, center_lon + half_side
min_lat, max_lat = center_lat - half_side, center_lat + half_side

for edge_name, edge_value in (
    ("min_lon", min_lon),
    ("min_lat", min_lat),
    ("max_lon", max_lon),
    ("max_lat", max_lat),
):
    print(f"{edge_name}:{edge_value}")

bounds = (min_lon, min_lat, max_lon, max_lat)
print(f"bounds:{bounds}")

In [None]:
# Acquisition date range for the STAC search, as an ISO-8601 "start/end" interval string.
START_DATE, END_DATE = "2011-01-01", "2015-12-31"
time_window = f"{START_DATE}/{END_DATE}"


In [None]:
# Search the Planetary Computer STAC catalogue for Landsat 7/8 Collection-2 Level-2 scenes that
# intersect the bounding box within the time window and report < 10% cloud cover.
stac = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
search = stac.search(
    collections=["landsat-c2-l2"],
    bbox=bounds,
    datetime=time_window,
    query={"platform": {"in": ["landsat-7", "landsat-8"]}, "eo:cloud_cover": {"lt": 10}},
)
# item_collection() is the supported replacement for the deprecated get_all_items().
items = list(search.item_collection())
print('This is the number of scenes that touch our region:',len(items))


Next, we'll load the data into an xarray Dataset using the Open Data Cube (ODC) `odc-stac` library, which is included with the Planetary Computer. The ODC is an open-source geospatial data management and analysis project used globally by many initiatives (e.g., Digital Earth Africa). The `odc-stac` code loads the items returned by the catalog search, selects the desired spectral bands (including the "qa_pixel" cloud-filtering band), reprojects them into Lat-Lon coordinates (EPSG:4326) at roughly 30-meter resolution (typical of Landsat pixels), and clips the result to the spatial bounding box.

In [None]:
# Target pixel size. The output grid is EPSG:4326 (lat/lon), so the loader needs the resolution
# in degrees: convert metres using ~111,320 m per degree. That factor is exact only at the
# equator; a degree of longitude is shorter at this study area's latitude.
METERS_PER_DEGREE = 111320.0
resolution = 30  # metres per pixel
scale = resolution / METERS_PER_DEGREE  # degrees per pixel

In [None]:
# Bands to load: six spectral bands plus the qa_pixel QA band.
LANDSAT_BANDS = ["red", "green", "blue", "nir08", "swir16", "swir22", "qa_pixel"]

xx = stac_load(
    items,
    bands=LANDSAT_BANDS,
    bbox=bounds,                     # clip to the study-area bounding box
    crs="EPSG:4326",                 # latitude-longitude grid
    resolution=scale,                # degrees, matching the CRS
    chunks={"x": 2048, "y": 2048},   # chunked (lazy) loading
    patch_url=pc.sign,               # sign Planetary Computer asset URLs before reading
)

In [None]:
# Apply the Landsat Collection-2 Level-2 surface-reflectance scale and offset (reference below)
# to the spectral bands ONLY; qa_pixel is a bit-flag band and must keep its raw integer values.
# https://planetarycomputer.microsoft.com/dataset/landsat-c2-l2
SR_SCALE = 0.0000275
SR_OFFSET = 0.2
SR_NODATA = 0  # raw value Collection-2 L2 uses for "no data"; masked to NaN before scaling
SPECTRAL_BANDS = ["red", "green", "blue", "nir08", "swir16", "swir22"]

for band in SPECTRAL_BANDS:
    # Idempotence guard: re-running this cell must not apply the scale/offset a second time.
    if xx[band].attrs.get("sr_scaled"):
        continue
    raw = xx[band]
    # Without the mask, no-data pixels would become a bogus reflectance of -0.2 and leak
    # into everything downstream (plots, the Snowflake table).
    scaled = (raw.where(raw != SR_NODATA) * SR_SCALE) - SR_OFFSET
    scaled.attrs["sr_scaled"] = True
    xx[band] = scaled

In [None]:
# Show the loaded Dataset's dimensions, coordinates and data variables (rich repr in Jupyter).
display(xx)

In [None]:
# Smoke-test that the installed xarray / dask / zarr versions work together.
import xarray as xr
import dask
import dask.array as da
import zarr
import numpy as np

print("Package Versions:")
print(f"  xarray: {xr.__version__}")
print(f"  dask: {dask.__version__}")
print(f"  zarr: {zarr.__version__}")

# Test dask chunk manager: wrapping a chunked dask array in xarray should expose its chunks.
print("\nTesting dask chunk manager...")
try:
    test_array = da.ones((10, 10), chunks=(5, 5))
    test_xr = xr.DataArray(test_array)
    print(f"✓ Dask chunks working: {test_xr.chunks}")
except Exception as e:
    # Best-effort check: report the failure rather than aborting the notebook.
    print(f"✗ Error: {e}")
    print("\n✗ Environment check failed; see the error above.")
else:
    # Only claim success when the check actually passed.
    print("\n✓ Ready to use!")


In [None]:
# True-colour preview: one RGB panel per acquisition time, four panels per row, with
# reflectance clipped to [0, 0.3] for display.
# NOTE: with both vmin and vmax given, robust=True does not change the colour scaling.
plot_xx = xx[["red", "green", "blue"]].to_array()
plot_xx.plot.imshow(
    col="time",
    col_wrap=4,
    robust=True,
    vmin=0,
    vmax=0.3,
)
plt.show()

In [None]:
# Stack the RGB bands into one array named "value" (a new leading "variable" dimension holds
# the band name) and force it into memory in case the data is dask-backed.
data = xx[["red", "green", "blue"]].to_array(name="value").compute()
display(data)

# Flatten to a long table, then pivot the innermost index level out into columns.
df = data.to_dataframe().unstack()
print(df)



In [None]:
import pandas as pd  # re-importing an already-loaded module is a no-op

# NOTE(review): time_slice is not used in this cell; presumably meant for a later time-step selection.
time_slice = 1

# Plain-text summary of the RGB subset of the Dataset.
rgb_subset = xx[["red", "green", "blue"]]
print(rgb_subset)


In [None]:
 
# Sample the RGB reflectances of the first acquisition at the centre of the study area.
# Use the configured centre rather than a hard-coded point: the previous literal
# (-26.5, 28.3) lies outside the current bounding box, so method="nearest" silently
# snapped it to an edge pixel.
result = xx.sel(
    time=xx.time[0],           # First time step
    latitude=lat_long[0],      # Study-area centre latitude
    longitude=lat_long[1],     # Study-area centre longitude
    method="nearest",          # Find closest coordinate
)

print("Red:", result["red"].values)
print("Green:", result["green"].values)
print("Blue:", result["blue"].values)

In [None]:
# Report how many entries each dimension of the loaded Dataset has.
for dim_name, noun in (("time", "dimensions"), ("latitude", "values"), ("longitude", "values")):
    print(f"Total {dim_name} {noun}: {len(xx[dim_name])}")

In [None]:
# Persist the RGB reflectances to Snowflake as LANDSAT_FETCH (one row per time/latitude/longitude).
# The frame is staged with OBS_TIME as text, then copied with CREATE OR REPLACE ... AS SELECT so
# OBS_TIME lands as a proper TIMESTAMP.
#
# LANDSAT_FETCH is intentionally NOT dropped up front: CREATE OR REPLACE swaps it in a single
# statement, so a failure while building or uploading the frame keeps the previous table.
session.sql("DROP TABLE IF EXISTS LANDSAT_FETCH_TEMP").collect()

# Flatten the Dataset to one row per pixel and time step.
df = xx[["red","green","blue"]].to_dataframe()
df_reset = df.reset_index()

# Convert time to STRING format that Snowflake can parse
df_reset['time'] = pd.to_datetime(df_reset['time']).dt.strftime('%Y-%m-%d %H:%M:%S.%f')

# Rename, drop the spatial_ref coordinate column, and upper-case all column names so the
# identifiers write_pandas creates (quoted by default) match the unquoted names in the SQL below.
df_reset = df_reset.rename(columns={'time': 'OBS_TIME'})
df_reset = df_reset.drop(columns=['spatial_ref'], errors='ignore')
df_reset.columns = df_reset.columns.str.upper()

print("Dtypes after string conversion:")
print(df_reset.dtypes)
print("\nSample:")
print(df_reset.head())

CREATE_LANDSAT_FETCH_SQL = """
    CREATE OR REPLACE TABLE LANDSAT_FETCH AS
    SELECT 
        TO_TIMESTAMP(OBS_TIME, 'YYYY-MM-DD HH24:MI:SS.FF') AS OBS_TIME,
        LATITUDE,
        LONGITUDE,
        RED,
        GREEN,
        BLUE
    FROM LANDSAT_FETCH_TEMP
"""

try:
    # Write to the staging table as strings
    session.write_pandas(
        df_reset,
        table_name="LANDSAT_FETCH_TEMP",
        auto_create_table=True,
        overwrite=True
    )
    # Create final table with proper TIMESTAMP and cast
    session.sql(CREATE_LANDSAT_FETCH_SQL).collect()
finally:
    # Always remove the staging table, even if the upload or the CTAS failed.
    session.sql("DROP TABLE IF EXISTS LANDSAT_FETCH_TEMP").collect()



In [None]:
# Verify the load: print the table schema, then a few sample rows.
for verify_query in (
    "DESCRIBE TABLE LANDSAT_FETCH",
    "SELECT OBS_TIME, LATITUDE, LONGITUDE,RED,GREEN,BLUE FROM LANDSAT_FETCH LIMIT 5",
):
    session.sql(verify_query).show()

In [None]:
# Print RGB values for a small corner sample: the first 2 time steps x 5 latitudes x 5 longitudes
# (50 rows in total), to eyeball the scaled reflectances.
N_TIMES, N_LATS, N_LONS = 2, 5, 5

# Materialise the whole sample in one computation instead of three separate lazy
# `.isel(...).values` reads per pixel.
sample = xx[["red", "green", "blue"]].isel(
    time=slice(0, N_TIMES),
    latitude=slice(0, N_LATS),
    longitude=slice(0, N_LONS),
).compute()

for t in range(N_TIMES):
    for lat in range(N_LATS):
        for lon in range(N_LONS):
            pixel = sample.isel(time=t, latitude=lat, longitude=lon)
            time_val = pixel.time.values
            lat_val = pixel.latitude.values
            lon_val = pixel.longitude.values
            red_val = pixel['red'].values
            green_val = pixel['green'].values
            blue_val = pixel['blue'].values

            print(f"T: {time_val}, Lat: {lat_val:.4f}, Lon: {lon_val:.4f}, "
                  f"R: {red_val:.2f}, G: {green_val:.2f}, B: {blue_val:.2f}")
    


In [None]:
# Feature-engineering spec: maps each output feature name to the Snowpark column
# expression that computes it.

feature_eng_dict = dict()

# Timestamp features
feature_eng_dict["MONTH"] = month("SAMPLE_DATE")
feature_eng_dict["QUARTER"] = quarter("SAMPLE_DATE")
feature_eng_dict["YEAR"] = year("SAMPLE_DATE")
feature_eng_dict["QUARTER_DATE"] = date_trunc("quarter", col("SAMPLE_DATE"))

## Spectral Indexes

# NDMI (Normalized Difference Moisture Index) - useful for detecting wetland conditions affecting water quality
#   Formula: (NIR - SWIR) / (NIR + SWIR)
#   Measures water content in vegetation and soil moisture

# MNDWI (Modified Normalized Difference Water Index) - better for turbid water identification
#   Formula: (Green - SWIR) / (Green + SWIR)
#   Enhances water body detection, suppresses soil/vegetation noise

# 1. NDWI: (Green - NIR) / (Green + NIR) - Water body delineation
# 2. NDTI: (Red - Green) / (Red + Green) - Turbidity measurement

## EC Indexes
# 3. Salinity Index: (Red - NIR) / (Red + NIR)
# 4. Band ratios: Blue/Red, SWIR/NIR combinations

## Useful Indices for DRP Prediction:

# 5. Chlorophyll Index (CI): (NIR / Red) - 1, indicates algae from phosphorus.
#    `col()` requires a column name, so the denominator must be spelled out; the "- 1"
#    completes the formula above.
# NOTE(review): "NIR"/"RED" are assumed to be the column names in the table this dict is
# applied to; the Landsat loader in this notebook names the NIR band nir08 — confirm.
feature_eng_dict["CI"] = (col("NIR") / col("RED")) - 1
