# Working code

In [20]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from ipywidgets import interact
import ipywidgets as widgets

# Read the merged dataset from disk into a DataFrame
file_path = 'Updated_Merged_Data.csv'
data = pd.read_csv(file_path)

# Remote sensing variables: every column whose name contains '_B'
remote_sensing_vars = list(filter(lambda column: '_B' in column, data.columns))

# Names of the averaged band columns
avg_vars = [
    'B2_AVG',
    'B3_AVG',
    'B4_AVG',
    'B8_AVG',
    'B8A_AVG',
    'B11_AVG',
    'B12_AVG',
]

def _cloud_filter_column(selected_vars, cloud_band):
    """Return the cloud-score column name for the selected variables, or None.

    Average variables (members of ``avg_vars``) take precedence and map to
    'cs_AVG' / 'cs_cdf_AVG'. Otherwise the first selected variable found in
    ``remote_sensing_vars`` supplies the point prefix (the text before its
    first underscore). None is returned when no selected variable belongs to
    either list.
    """
    if any(var in avg_vars for var in selected_vars):
        cs_col, cdf_col = 'cs_AVG', 'cs_cdf_AVG'
    else:
        # Use whichever selected variable is actually a remote sensing column;
        # it may be the Y variable rather than the X variable.
        source = next((var for var in selected_vars if var in remote_sensing_vars), None)
        if source is None:
            return None
        point_num = source.split('_')[0].replace('point', '')
        cs_col = f'point{point_num}_cs'
        cdf_col = f'point{point_num}_cs_cdf'
    return cs_col if cloud_band == 'cs' else cdf_col


def plot_data(plot_type, x_var, y_var, log_x, log_y, window, filter_cloud, cloud_band, cloud_threshold, custom_var_expr, drop_zero):
    """Draw a scatter or time-series plot of columns from the global ``data``.

    Args:
        plot_type: 'Scatter Plot' or 'Time Series Plot'; any other value
            produces an empty figure.
        x_var: Column plotted on the X axis (scatter) or left Y axis (time
            series). Replaced by 'custom_var' when ``custom_var_expr`` is set.
        y_var: Second column, or the string 'None' to plot ``x_var`` alone.
        log_x: If true, replace ``x_var`` by ``log(x_var + 1e-10)``.
        log_y: If true, replace ``y_var`` by ``log(y_var + 1e-10)``.
        window: Rolling-mean window (in rows) used for time-series plots.
        filter_cloud: If true, keep only rows whose cloud-score column is
            greater than or equal to ``cloud_threshold``.
        cloud_band: 'cs' selects the ``*_cs`` column; anything else selects
            the ``*_cs_cdf`` column.
        cloud_threshold: Minimum cloud score kept when filtering.
        custom_var_expr: Optional Python expression evaluated with ``data``
            and ``np`` in scope; its result becomes the X variable.
        drop_zero: If true, drop rows where a plotted variable equals 0.

    Returns:
        None. The figure is shown via ``plt.show()``; problems with the
        inputs are reported with ``print``.
    """
    # Work on a local view so the shared ``data`` frame is never mutated.
    frame = data
    if custom_var_expr:
        try:
            # NOTE(security): eval runs arbitrary Python typed into the text
            # box. Acceptable only while the notebook user is trusted; do not
            # expose this widget to untrusted input.
            custom_values = eval(custom_var_expr, {"data": data, "np": np})
            # Shallow copy: shares the existing columns, but the new column is
            # added to the copy only.
            frame = data.copy(deep=False)
            frame['custom_var'] = custom_values
        except Exception as e:
            print(f"Error in custom variable expression: {e}")
            return
        x_var = 'custom_var'

    has_y = y_var != 'None'
    selected_vars = [x_var, y_var] if has_y else [x_var]

    has_timestamp = 'timestamp_sentinel2' in frame.columns
    if plot_type == 'Time Series Plot' and not has_timestamp:
        print("Column 'timestamp_sentinel2' not found; cannot draw a time series plot.")
        return

    # Resolve the cloud-score column that belongs to the selected variables.
    filter_col = _cloud_filter_column(selected_vars, cloud_band)
    if filter_col is not None and filter_col not in frame.columns:
        print(f"Cloud score column '{filter_col}' not found; cloud filtering skipped.")
        filter_col = None

    cols_to_select = []
    if has_timestamp:
        cols_to_select.append('timestamp_sentinel2')
    cols_to_select.extend(selected_vars)
    if filter_col is not None:
        cols_to_select.append(filter_col)
    # De-duplicate while keeping order: selecting the same column twice (for
    # example X == Y) would make frame[col] return a DataFrame, not a Series.
    cols_to_select = list(dict.fromkeys(cols_to_select))

    # Rows with NaN in any selected column (including the cloud-score column,
    # when one applies) are dropped.
    subset = frame[cols_to_select].dropna()

    if filter_cloud and filter_col is not None:
        subset = subset[subset[filter_col] >= cloud_threshold]

    if drop_zero:
        for var in selected_vars:
            subset = subset[subset[var] != 0]

    # Own the data before assigning into it (avoids SettingWithCopyWarning).
    subset = subset.copy()

    # epsilon only guards against log(0); negative inputs still become NaN.
    epsilon = 1e-10
    log_targets = []
    if log_x:
        log_targets.append(x_var)
    if has_y and log_y:
        log_targets.append(y_var)
    # dict.fromkeys prevents taking the log twice when X and Y are the same column.
    for var in dict.fromkeys(log_targets):
        subset[var] = np.log(subset[var] + epsilon)

    # Exactly one figure per call, created only after all early returns.
    fig, ax = plt.subplots(figsize=(15, 6))

    if plot_type == 'Scatter Plot':
        if has_y:
            ax.scatter(subset[x_var], subset[y_var], alpha=0.5)
            ax.set_xlabel(x_var)
            ax.set_ylabel(y_var)

            # Correlation between the two plotted columns
            correlation = subset[[x_var, y_var]].corr().iloc[0, 1]
            ax.set_title(f'{plot_type} of {x_var} and {y_var}\nCorrelation: {correlation:.2f}')
        else:
            ax.scatter(subset.index, subset[x_var], alpha=0.5)
            ax.set_xlabel('Index')
            ax.set_ylabel(x_var)
            ax.set_title(f'{plot_type} of {x_var}')

    elif plot_type == 'Time Series Plot':
        times = subset['timestamp_sentinel2']
        use_date_axis = True
        if not pd.api.types.is_numeric_dtype(times):
            # The CSV is loaded without date parsing, so timestamps are
            # presumably strings; convert them so the date locator below is
            # meaningful. TODO confirm the timestamp format in the CSV.
            try:
                times = pd.to_datetime(times)
            except (ValueError, TypeError):
                # Unparseable timestamps: plot them as-is with default ticks.
                use_date_axis = False

        ax.set_xlabel('Time')
        if has_y:
            color = 'tab:blue'
            ax.set_ylabel(x_var, color=color)
            ax.plot(times, subset[x_var].rolling(window=window).mean(), color=color, label=f'{x_var} (Moving Average)')
            ax.tick_params(axis='y', labelcolor=color)

            # Second Y axis sharing the same time axis
            ax2 = ax.twinx()
            color = 'tab:red'
            ax2.set_ylabel(y_var, color=color)
            ax2.plot(times, subset[y_var].rolling(window=window).mean(), color=color, label=f'{y_var} (Moving Average)')
            ax2.tick_params(axis='y', labelcolor=color)

            fig.legend(loc='upper right')
            ax.set_title(f'{plot_type} of {x_var} and {y_var}')
        else:
            ax.plot(times, subset[x_var].rolling(window=window).mean(), label=f'{x_var} (Moving Average)')
            ax.set_ylabel(x_var)
            ax.legend()
            ax.set_title(f'{plot_type} of {x_var}')

        if use_date_axis:
            # One tick per month, labelled as year-month
            ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
            fig.autofmt_xdate()

    ax.grid(True)
    plt.show()

# Every column except the timestamp is offered as a plot variable
variables = [name for name in data.columns.tolist() if name != 'timestamp_sentinel2']

# Wire one widget to each plot_data argument; the trailing semicolon
# suppresses the notebook's echo of the returned object.
interact(
    plot_data,
    plot_type=widgets.RadioButtons(description='Plot Type', options=['Scatter Plot', 'Time Series Plot']),
    x_var=widgets.Dropdown(description='X Variable', options=variables),
    y_var=widgets.Dropdown(description='Y Variable', options=['None'] + variables),
    log_x=widgets.Checkbox(description='Log X Variable', value=False),
    log_y=widgets.Checkbox(description='Log Y Variable', value=False),
    window=widgets.IntSlider(description='Moving Avg Window', value=5, min=1, max=50, step=1),
    filter_cloud=widgets.Checkbox(description='Filter by Cloud Score', value=False),
    cloud_band=widgets.RadioButtons(description='Cloud Band', options=['cs', 'cs_cdf']),
    cloud_threshold=widgets.FloatSlider(description='Cloud Threshold', value=0.1, min=0.0, max=1.0, step=0.01),
    custom_var_expr=widgets.Text(description='Custom Var Expr', value=''),
    drop_zero=widgets.Checkbox(description='Drop Zero Values', value=False),
);


interactive(children=(RadioButtons(description='Plot Type', options=('Scatter Plot', 'Time Series Plot'), valu…