In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

from tqdm import tqdm

import numpy as np
import scipy.interpolate

from sklearn.preprocessing import KBinsDiscretizer

import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from modules.utils.general_utils.utilities import group_wise_binning

from plotly.subplots import make_subplots
import plotly.graph_objects as go

KeyboardInterrupt: 

In [None]:
def interpolate_paths(z, x, y, c, rep_id, n_points=100):
    """Split a path into runs of consecutive ``z`` values and densify each run.

    The path is cut wherever two neighbouring ``z`` values do not differ by
    exactly 1. Every run holding more than one sample is resampled on
    ``n_points`` evenly spaced ``z`` positions, with ``x``, ``y`` and ``c``
    interpolated by ``scipy.interpolate.interp1d``. Runs holding a single
    sample (or none) are passed through unchanged.

    Parameters
    ----------
    z : array-like
        1-D positions along the path.
    x, y, c : array-like
        1-D values observed at each ``z``; same length as ``z``.
    rep_id : object
        Identifier of the path. Not used by the computation; kept so existing
        callers keep working.
    n_points : int, default 100
        Number of samples generated for every interpolated run.

    Returns
    -------
    list of tuple of numpy.ndarray
        One ``(z, x, y, c)`` tuple per run, in the order the runs occur.
    """
    # The interpolation order is limited by the run length: 2 samples allow a
    # straight line, 3 a quadratic; anything longer uses a cubic.
    interp_kind = {2: "linear", 3: "quadratic"}

    # Normalise to plain arrays so lists and pandas Series (whatever their
    # index) are split positionally, exactly like ndarrays.
    z, x, y, c = (np.asarray(values) for values in (z, x, y, c))

    # Positions at which a new run starts.
    run_starts = np.flatnonzero(np.diff(z) != 1) + 1
    runs = zip(*(np.split(values, run_starts) for values in (z, x, y, c)))

    paths = []
    for z_run, x_run, y_run, c_run in runs:
        if len(z_run) <= 1:
            # Nothing to interpolate between.
            paths.append((z_run, x_run, y_run, c_run))
            continue

        kind = interp_kind.get(len(z_run), "cubic")
        z_dense = np.linspace(np.min(z_run), np.max(z_run), n_points)
        x_dense, y_dense, c_dense = (
            scipy.interpolate.interp1d(z_run, values, kind=kind)(z_dense)
            for values in (x_run, y_run, c_run)
        )
        paths.append((z_dense, x_dense, y_dense, c_dense))

    return paths

In [None]:
# Load the saved data container.
# NOTE: pickle.load can execute arbitrary code stored in the file -- only load
# containers that were produced by this project.
# Forward slashes are used so the path resolves on Windows and POSIX alike.
with open('results/saved_data_containers/melchior.pkl', 'rb') as container:
    DATA_CONTAINER = pickle.load(container)

# Only the entries indexed 0..N_SELECTED-1 are used.
# NOTE(review): presumably one entry per context/group -- confirm.
N_SELECTED = 5

tar_activity = DATA_CONTAINER['prediction_ds']['tar_activity']
context_source = DATA_CONTAINER['context']

contexts = [context_source[i] for i in range(N_SELECTED)]
selected_activity = [tar_activity[i] for i in range(N_SELECTED)]

# Bin the selected predictions with the project helper (100 bins, grouped by
# `contexts`) and concatenate the per-entry results into one flat array.
predictions = np.hstack(
    group_wise_binning(selected_activity, n_bins=100, grouper=contexts)
)

In [None]:
# Read the saved dimensionality-reduction table once (it used to be read twice
# in a row). Forward slashes keep the path valid on Windows and POSIX alike.
embedding = pd.read_csv('results/saved_dim_reduction/melchior_eng_emb_temporal.csv')

# Attach the binned predictions row by row; pandas raises if the lengths differ.
embedding['Future Session Activity'] = predictions

# Keep a single context.
# NOTE(review): 6 is a hard-coded context id -- confirm it is the intended one.
context_rows = embedding[embedding['context'] == 6]

# Keep users whose highest `session` value is 3 (max + 1 == 4).
# NOTE(review): presumably sessions are zero-based, so these are the users
# with four sessions -- confirm.
last_session = context_rows.groupby('user_id')['session'].max()
users = last_session.index[(last_session + 1) == 4].values

df = context_rows[context_rows['user_id'].isin(users)]

In [None]:
# Camera distance multiplier, used when the scene cameras are configured.
zoom = 3.5

PANEL_RANKS = (0, 4, 8)      # `rank` value drawn in the left, middle, right panel
MAX_USERS_PER_PANEL = 700    # upper bound on users sampled for one panel
SAMPLING_SEED = 0            # fixed seed so the same users are drawn on every run

fig = make_subplots(
    rows=1,
    cols=3,
    specs=[
        [{'type': 'scatter3d'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]
    ],
    subplot_titles=(
        'Low Variability',
        'Medium Variability',
        'High Variability'
    ),
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
    shared_xaxes=True,
    shared_yaxes=True
)

# (row, col) of the subplot that receives each rank in PANEL_RANKS.
locations = [
    (1, 1),
    (1, 2),
    (1, 3)
]

rng = np.random.default_rng(SAMPLING_SEED)

for location, rank in zip(locations, PANEL_RANKS):

    # NOTE(review): `variability_rank` is not defined in any cell visible here;
    # it has to exist in the kernel before this cell runs -- confirm where it
    # is built.
    unique_ids = variability_rank[variability_rank['rank'] == rank]['user_id'].values

    # Sampling without replacement fails when more items are requested than
    # exist, so cap the request at the number of available users.
    n_sampled = min(MAX_USERS_PER_PANEL, len(unique_ids))
    unique_ids = rng.choice(unique_ids, n_sampled, replace=False)

    for unique_id in unique_ids:

        # Select this user's rows once instead of once per column.
        user_rows = df[df.user_id == unique_id]
        if user_rows.empty:
            # The user is absent from `df`; there is nothing to draw.
            continue

        user_paths = interpolate_paths(
            user_rows['session'].values,
            user_rows['UMAP_1'].values,
            user_rows['UMAP_2'].values,
            user_rows['Future Session Activity'].values,
            unique_id,
        )

        for path_z, path_x, path_y, path_c in user_paths:

            # The session axis is drawn along the plot's y axis and UMAP_2
            # along its z axis.
            trace = go.Scatter3d(
                x=path_x, y=path_z, z=path_y,
                mode='lines',
                line=dict(
                    color=path_c,
                    cmin=0,
                    cmid=50,
                    cmax=100,
                    cauto=False,
                    colorscale='RdBu',
                    colorbar=dict(),
                    width=0.6,
                ),
                opacity=0.6,
            )
            fig.add_trace(trace, row=location[0], col=location[1])

# Overall figure size and chrome.
fig.update_layout(
    width=1050,
    height=500,
    margin=dict(r=1, l=1),
    showlegend=False,
    autosize=False,
    template="plotly_white",
)

# Every 3-D subplot gets the same proportions and viewpoint, so configure all
# scenes in one call instead of repeating the block for scene, scene2, scene3.
fig.update_scenes(
    aspectmode='manual',
    aspectratio=dict(x=1, y=3, z=1),
    camera=dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0.5, y=0.5, z=0),
        eye=dict(x=(0.5)*zoom, y=(0.8)*zoom, z=(0.75)*zoom)
    )
)

# Axis titles; the y axis shows the values 0..3 labelled as 1..4.
# NOTE(review): LaTeX titles may not be rendered inside 3-D scenes -- confirm
# that "$t$" displays as intended.
fig.update_scenes(
    xaxis_title_text='UMAP 1',
    zaxis_title_text='UMAP 2',
    yaxis_title_text=r"$t$",
    yaxis=dict(
        tickmode='array',
        tickvals=[0, 1, 2, 3],
        ticktext=[1, 2, 3, 4]
    )
)

fig.show()