In [21]:
import warnings

import ipywidgets as widgets
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display

from allocation_forecast import (create_station_location_mapping,
                                 spatial_demand_extraction)
from utils import load_data, preprocess

warnings.filterwarnings('ignore')

In [12]:
# Read the raw data, then run the project's preprocessing step on it
df = load_data(path='../data/raw/')
df = preprocess(df)

# Derive the spatial demand table and the station -> location lookup
demand = spatial_demand_extraction(df)
station_location_dict = create_station_location_mapping(df)


def _coordinate_lookup(position):
    """Build a mapper from a station name to one element of its location tuple.

    Stations missing from ``station_location_dict`` fall back to
    ``(None, None)``, so the mapper returns ``None`` for them.
    """
    def lookup(station):
        return station_location_dict.get(station, (None, None))[position]
    return lookup


# Element 0 of each location tuple feeds 'latitudes', element 1 'longitudes'
demand['latitudes'] = demand['station_name'].map(_coordinate_lookup(0))
demand['longitudes'] = demand['station_name'].map(_coordinate_lookup(1))


In [22]:
# Create an interactive plot to visualize the net bike count by station
# location

# Column whose values drive the slider
feature = 'hour'
# Highest slider position. Together with the slider's minimum of 0 this presumes
# the feature takes the consecutive integer values 0..n-1 -- TODO confirm before
# switching `feature` to another column.
max_slider = demand[feature].nunique() - 1


def _value_range(values):
    """Return the ``(min, max)`` pair of a pandas Series."""
    return values.min(), values.max()


# Mean net bike count per station (with its coordinates) for every feature value
timelapsed_averaged = (
    demand
    .groupby(['station_name', 'latitudes', 'longitudes', feature])['net_bikes']
    .mean()
    .reset_index()
)

# Global extents, computed once over all feature values so they can be shared
# by every frame of the interactive plot
lon_min, lon_max = _value_range(timelapsed_averaged['longitudes'])
lat_min, lat_max = _value_range(timelapsed_averaged['latitudes'])
net_bike_min, net_bike_max = _value_range(timelapsed_averaged['net_bikes'])

# Fixed colour normalisation: hue is scaled over the constant range [-20, 20]
norm = mcolors.Normalize(vmin=-20, vmax=20)

# Define the function to update the plot based on the selected hour


def update_plot(time):
    """Draw the per-station average net bike count for one slider position.

    Reads the notebook-level names ``timelapsed_averaged``, ``feature``,
    ``norm``, the longitude/latitude extents and ``net_bike_min`` /
    ``net_bike_max``.

    Parameters
    ----------
    time : int
        Value of ``feature`` to display; rows of ``timelapsed_averaged`` whose
        ``feature`` column equals this value are plotted.
    """
    # Each call draws on a brand-new figure (nothing is cleared or reused)
    fig, ax = plt.subplots(figsize=(10, 8))

    # Keep only the rows for the selected feature value
    data = timelapsed_averaged[timelapsed_averaged[feature] == time]

    sns.scatterplot(
        x='longitudes',
        y='latitudes',
        size='net_bikes',
        sizes=(10, 100),              # smallest / largest marker area
        # Without an explicit size_norm seaborn rescales marker sizes to the
        # min/max of *this frame's* data, so the same net_bikes value would be
        # drawn at different sizes for different hours. Pin the size scale to
        # the global range instead.
        size_norm=(net_bike_min, net_bike_max),
        hue='net_bikes',
        hue_norm=norm,                # same colour scale for every frame
        palette='coolwarm',
        alpha=0.6,
        legend='brief',
        data=data,
        ax=ax,
    )

    # Same viewport for every frame so stations do not appear to move
    ax.set_xlim(lon_min, lon_max)
    ax.set_ylim(lat_min, lat_max)

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.set_title(f"Net Bike Count by Station Location - {feature}: {time}")
    plt.show()


# Create an interactive slider widget
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=max_slider,
    step=1,
    # Plain method call instead of an f-string: a single-quoted f-string whose
    # replacement field is split across lines is a SyntaxError before
    # Python 3.12, and the f-string added nothing here.
    description=feature.capitalize())

# Build the `interactive` widget that re-runs update_plot whenever the slider
# changes; as the cell's last expression it is displayed automatically
widgets.interactive(update_plot, time=slider)

interactive(children=(IntSlider(value=0, description='Hour', max=23), Output()), _dom_classes=('widget-interac…