# UMP interactive scatter plot
Check the relationship between two columns.

## Import modules

In [None]:
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interactive_output
from ipywidgets import Text, HBox, VBox, Select, fixed


## Prepare Dataset.

In [None]:
# Load the training set and drop the 'row_id' column.
train_data = pd.read_parquet('../input/ubiquant-parquet/train_low_mem.parquet')
train_data = train_data.drop('row_id', axis=1)
# Cast the id columns by rebinding the frame. Assigning the cast values back
# through `.loc[:, cols] = ...` sets them into the existing columns, which can
# leave the original dtype in place instead of int16.
train_data = train_data.astype({'time_id': np.int16, 'investment_id': np.int16})
# Order rows by time step, then by investment within each time step.
train_data = train_data.sort_values(['time_id', 'investment_id'], ascending=True)


## Define Interactive Scatter Plot

In [None]:
def scatter_hist(plot_data, col_x, col_y, hue=None, title=None):
    """Draw a scatter plot of two columns with marginal histograms.

    The figure is a 2x2 grid: the scatter plot occupies the large
    bottom-left cell, the histogram of ``col_x`` sits above it and the
    histogram of ``col_y`` to its right. The top-right cell stays empty.

    Args:
        plot_data: Data handed to seaborn as ``data``.
        col_x: Name of the column plotted on the x axis.
        col_y: Name of the column plotted on the y axis.
        hue: Optional grouping variable forwarded to seaborn as ``hue``.
        title: Optional figure title; applied only when it is a string.

    Returns:
        The matplotlib figure. It is closed in pyplot before being
        returned (presumably so the notebook does not render it a second
        time on its own), so callers must display it explicitly.
    """
    fig, axes = plt.subplots(
        2, 2, figsize=(10, 10), facecolor='white',
        gridspec_kw={'width_ratios': [9, 1], 'height_ratios': [1, 9]},
    )
    scatter_ax = axes[1, 0]
    sns.scatterplot(data=plot_data, x=col_x, y=col_y, ax=scatter_ax, hue=hue)
    sns.histplot(data=plot_data, x=col_x, ax=axes[0, 0], hue=hue, kde=True)
    sns.histplot(data=plot_data, y=col_y, ax=axes[1, 1], hue=hue, kde=True)

    for ax in axes.ravel():
        ax.grid(True)
        if ax is scatter_ax:
            # Only the scatter plot keeps its ticks, axis labels and frame.
            continue
        ax.tick_params(length=0)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        for loc in ['top', 'bottom', 'left', 'right']:
            ax.spines[loc].set_visible(False)

    if isinstance(title, str):
        fig.suptitle(title)
    fig.tight_layout(rect=[0, 0, 0.96, 1])
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return fig


def scatter_groupby(dataframe, col_x, agg_x, col_y, agg_y, groupby):
    """Aggregate two columns per group and display them as a scatter plot.

    ``col_x`` is aggregated with ``agg_x`` and ``col_y`` with ``agg_y``
    over the groups defined by ``groupby``; each group becomes one point
    in the figure produced by ``scatter_hist``, which is then displayed.

    Args:
        dataframe: Source DataFrame containing ``col_x``, ``col_y`` and
            the ``groupby`` column.
        col_x: Column whose aggregate is plotted on the x axis.
        agg_x: Aggregation applied to ``col_x`` (e.g. ``'nunique'``).
        col_y: Column whose aggregate is plotted on the y axis.
        agg_y: Aggregation applied to ``col_y`` (e.g. ``'mean'``).
        groupby: Column to group by.
    """
    x_label, y_label = col_x, col_y
    if col_x == col_y:
        # Same source column on both axes: the two aggregates need distinct
        # names, otherwise the combined frame has duplicate column labels and
        # the x/y column lookup becomes ambiguous.
        x_label = f'{col_x}({agg_x}) [X]'
        y_label = f'{col_y}({agg_y}) [Y]'

    # Aggregate only the two plotted columns instead of every selected column.
    grouped = dataframe.groupby(groupby)
    y_by_group = grouped[col_y].agg(agg_y).rename(y_label)
    x_by_group = grouped[col_x].agg(agg_x).rename(x_label)
    plot_data = pd.concat([y_by_group, x_by_group], axis=1)

    title = f'X: {col_x}({agg_x}) Y: {col_y}({agg_y}) / groupby: {groupby}'
    fig = scatter_hist(plot_data, x_label, y_label, title=title)
    display(fig)
                  

In [None]:
# The DataFrame is passed through unchanged on every redraw.
data = fixed(train_data)

# Selectable columns: everything after the first two columns of train_data.
value_columns = train_data.columns[2:]
aggregations = ('mean', 'std', 'nunique', 'min', 'max')

s_x = Select(description='X', options=value_columns, value='f_62', rows=4)
s_aggx = Select(
    description='X_aggregate',
    options=list(aggregations),
    value='nunique',
    rows=4,
)
s_y = Select(description='Y', options=value_columns, value='target', rows=4)
s_aggy = Select(
    description='Y_aggregate',
    options=list(aggregations),
    value='mean',
    rows=4,
)
s_gby = Select(
    description='groupby',
    options=['time_id', 'investment_id'],
    value='time_id',
    rows=2,
)

# Lay the selectors out in one row and bind them to scatter_groupby's arguments.
selector = HBox([s_x, s_aggx, s_y, s_aggy, s_gby])
widget_bindings = {
    'dataframe': data,
    'col_x': s_x,
    'agg_x': s_aggx,
    'col_y': s_y,
    'agg_y': s_aggy,
    'groupby': s_gby,
}
plot_output = interactive_output(scatter_groupby, widget_bindings)
display(selector, plot_output)