In [None]:
from paraview.simple import *
import numpy as np
import pandas as pd
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import cuml
from ipywidgets import fixed, interact, interactive, VBox, HBox
from ipyparaview.widgets import PVDisplay
from ipyparaview.widgets import PVRenderer
import ipywidgets as widgets
import PVGeo

In [None]:
# Load scikit-learn's handwritten-digits dataset: each row of X is one
# flattened image (reshaped to 8x8 by plotDigit below), Y holds the digit labels.
X, Y = load_digits(return_X_y=True)

In [None]:
def plotDigit(linear_img, N=8, label=None):
    """Show one flattened square image with matplotlib.

    Parameters
    ----------
    linear_img : array-like with a ``reshape`` method
        Pixel values of a single image laid out as one row (N*N values).
    N : int, optional
        Side length of the square image (8 for the sklearn digits).
    label : optional
        If given, rendered as the plot title ("Digit: '<label>'").
    """
    square = linear_img.reshape((N, N))
    plt.imshow(square, cmap='binary')
    title = None if label is None else "Digit: '{}'".format(str(label))
    if title is not None:
        plt.title(title)
    plt.show()
    
# Preview a single sample (index k) together with its label.
k = 10
plotDigit(X[k], label=Y[k])

In [None]:
%%time
umap = cuml.manifold.UMAP(n_components=3)
X3 = umap.fit_transform(X, Y)

In [None]:
# Collect the 3-D embedding as X/Y/Z columns plus a "color" column holding the
# digit label, which the ParaView pipeline below uses for colouring.
Xdf = pd.DataFrame(X3, columns=["X", "Y", "Z"]).assign(color=Y)

In [None]:
# Build a ParaView render view and a pipeline that turns the embedding table
# (Xdf) into 3-D points coloured by their "color" (digit label) column.
renderView1 = CreateView('RenderView')
renderView1.AxesGrid.Visibility = 1
renderView1.AxesGrid.ShowGrid = 1
renderView1.ViewSize = [640, 480]

# A TrivialProducer lets an in-memory table (converted from the pandas
# DataFrame by PVGeo) act as the pipeline source.
M3p = TrivialProducer()
M3p.GetClientSideObject().SetOutput(PVGeo.interface.data_frame_to_table(Xdf))
M3p.UpdatePipeline()

# Interpret the X/Y/Z columns of the table as point coordinates.
tableToPoints1 = TableToPoints(Input=M3p)
tableToPoints1.XColumn = 'X'
tableToPoints1.YColumn = 'Y'
tableToPoints1.ZColumn = 'Z'

# Scale the colour map to the label range actually present in the data rather
# than a hard-coded 0..9.
colorLUT = GetColorTransferFunction('color')
colorLUT.ApplyPreset('jet', True)
colorLUT.RescaleTransferFunction(float(Xdf["color"].min()),
                                 float(Xdf["color"].max()))

tableToPoints1Display = Show(tableToPoints1, renderView1)
tableToPoints1Display.Representation = 'Surface'
tableToPoints1Display.ColorArrayName = ['POINTS', 'color']
tableToPoints1Display.LookupTable = colorLUT
tableToPoints1Display.RenderPointsAsSpheres = 1
tableToPoints1Display.PointSize = 6.0

def update_data(new_value):
    """Replace the table feeding the ParaView pipeline with a new one.

    Parameters
    ----------
    new_value : pandas.DataFrame
        Must contain ``X``, ``Y`` and ``Z`` columns (read by ``tableToPoints1``
        as point coordinates) and a ``color`` column (used for colouring).
        It is converted with ``PVGeo.interface.data_frame_to_table`` and set
        as the output of the ``M3p`` TrivialProducer.
    """
    M3p.GetClientSideObject().SetOutput(PVGeo.interface.data_frame_to_table(new_value))
    M3p.UpdatePipeline()
    # Re-assigns the unchanged input; presumably meant to make the downstream
    # filter pick up the new table -- TODO(review) confirm whether it is needed.
    tableToPoints1.Input = M3p


In [None]:
# Wrap the ParaView render view in ipyparaview widgets so it can be placed in
# the ipywidgets layout built in the last cell.
renderer = PVRenderer(renderView1)
pv_widget = PVDisplay(renderer)

In [None]:
def rayTracing(use_rt):
    """Toggle ray-traced rendering of ``renderView1``.

    When ``use_rt`` is truthy, ray tracing is enabled with the
    'OptiX pathtracer' back end at 7 samples per pixel. Otherwise ray tracing
    is disabled and the back end / sample count are left as they were.
    """
    renderView1.EnableRayTracing = 1 if use_rt else 0
    if not use_rt:
        return
    renderView1.BackEnd = 'OptiX pathtracer'
    renderView1.SamplesPerPixel = 7
        
def minDist(min_dist):
    """Recompute the 3-D UMAP embedding with a new ``min_dist`` and redraw it.

    Parameters
    ----------
    min_dist : float
        UMAP ``min_dist``. UMAP requires ``min_dist <= spread``, but the
        interactive slider allows values above the default spread of 1.0.
        In that case ``spread`` is raised to ``min_dist`` so the fit stays
        valid instead of failing. For ``min_dist <= 1.0`` the call is
        identical to using the default spread.
    """
    # Keep the default spread (1.0) unless min_dist exceeds it.
    spread = max(1.0, min_dist)
    umap = cuml.manifold.UMAP(n_components=3, min_dist=min_dist, spread=spread)
    X3 = umap.fit_transform(X, Y)
    Xdf = pd.DataFrame(X3, columns=["X", "Y", "Z"])
    Xdf["color"] = Y
    update_data(Xdf)

In [None]:
# Assemble the dashboard: the ParaView view on top, with the ray-tracing
# toggle and the UMAP min_dist slider side by side underneath.
widgets.VBox(children=[
    pv_widget,
    widgets.HBox(children=[
        interactive(rayTracing, use_rt=widgets.Checkbox(value=False)),
        interactive(
            minDist,
            min_dist=widgets.FloatSlider(value=0.1, min=0.01, max=2,
                                         continuous_update=False),
        ),
    ]),
])
      