# Notebook 4: Ray tracing and interactive data visualization on the GPU

Here we illustrate how one can build custom widgets for interactive data visualization.
We use the Paicos CUDA GPU implementation of ray tracing to achieve the necessary speed.

In [1]:
# Interactive (ipympl) matplotlib backend, used by the widget below to redraw figure 1 in place
%matplotlib widget

## Load and check required packages
This example notebook requires the following Python GPU packages (importing them below checks that they are installed):

In [2]:
import cupy as cp
from numba import cuda

## Do a manual pre-selection
We start by loading a snapshot and selecting just a part of it.
We limit ourselves because the GPU has limited memory and the GPU ray
tracer class builds a binary tree spanning the entire snapshot.

In [3]:
import paicos as pa
import numpy as np

import os

pa.use_units(True)

snapnum = 247
# Directory with the simulation output. Defaults to the example data shipped
# with Paicos; set the environment variable PAICOS_SNAPSHOT_DIR to run the
# notebook on another simulation (e.g. a higher-resolution run).
snap_dir = os.environ.get('PAICOS_SNAPSHOT_DIR', pa.data_dir)
snap = pa.Snapshot(snap_dir, snapnum)
center = snap.Cat.Group['GroupPos'][0]
R200c = snap.Cat.Group['Group_R_Crit200'][0]
widths = np.array([10000, 10000, 10000]) * R200c.uq

# Create subset of snapshot: keep gas cells inside a sphere of radius
# cbrt(3) * max(widths) (~1.44 widths) around the center. This exceeds the
# half-diagonal of the projection cube (~0.87 widths), so the cube is covered
# for any rotation. Note: the widget sliders allow widths up to 2x these
# values, which would reach beyond the selected sphere.
index = pa.util.get_index_of_radial_range(snap['0_Coordinates'], center, 0, np.max(widths)*np.cbrt(3))
snap = snap.select(index, parttype=0)

# Pixel dimensions of image
nx = ny = 1024

## Initialize the GPU projector

Here we use a Paicos orientation class to initialize the view such that the 
width of the image is along the $x$-coordinate of the simulation and the height
of the image is along the $y$-coordinate. The depth of the image is in the $z$-direction.

The orientation class has methods for rotating the view around $x$, $y$, and $z$
or around the axes of its local coordinate system. When an orientation
instance has been passed to an ImageCreator (such as the projector below),
then calling these methods will result in a rotation around the center of
the image.

In [4]:
# Line of sight along z; image width along x (perp_vector1), height along y.
orientation = pa.Orientation(normal_vector=[0, 0, 1],
                             perp_vector1=[1, 0, 0])
projector = pa.GpuRayProjector(
    snap, center, widths, orientation,
    npix=nx,
    threadsperblock=8,
    # The manual pre-selection (snap.select) was already done above.
    do_pre_selection=False,
)

Attempting to get derived variable: 0_Volume...	[DONE]



In [5]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import ipywidgets as widgets

## Some code for sorting FoF and subfind catalogues

This is only used for marking the 20 most massive FoF groups or subhalos that are inside the projection cube.
It mainly supports the widget, so there is no need to read it in detail.

In [6]:
def get_group_and_sub_indices():
    """
    Collect the catalogue objects that lie inside the current projection volume.

    Membership is decided by
    ``pa.util.get_index_of_rotated_cubic_region_plus_thin_layer`` using the
    projector's center, widths, box and orientation, with each object's radius
    passed as the extra argument (presumably the thickness of the added layer).
    For each catalogue present in ``projector.snap.Cat`` (``Sub`` and/or
    ``Group``) the selected objects are sorted by decreasing mass.

    Returns
    -------
    dict
        ``Subhalo_ids``, ``SubhaloPos``, ``SubhaloHalfmassRad``, ``SubhaloMass``
        (when subhalos exist) and ``Group_ids``, ``GroupPos``,
        ``Group_R_Crit200``, ``Group_M_Crit200`` (when groups exist). The
        ``*_ids`` entries are indices into the full catalogue.
    """
    def select_sorted(table, ids_key, pos_key, radius_key, mass_key):
        # Boolean mask over the whole catalogue
        mask = pa.util.get_index_of_rotated_cubic_region_plus_thin_layer(
            table[pos_key], projector.center, projector.widths,
            table[radius_key], projector.snap.box, projector.orientation)
        subset = {
            ids_key: np.arange(mask.shape[0])[mask],
            pos_key: table[pos_key][mask],
            radius_key: table[radius_key][mask],
            mass_key: table[mass_key][mask],
        }
        # Most massive first
        by_mass = np.argsort(subset[mass_key])[::-1]
        return {key: values[by_mass] for key, values in subset.items()}

    info = {}
    if hasattr(projector.snap.Cat, 'Sub'):
        info.update(select_sorted(projector.snap.Cat.Sub, 'Subhalo_ids', 'SubhaloPos',
                                  'SubhaloHalfmassRad', 'SubhaloMass'))
    if hasattr(projector.snap.Cat, 'Group'):
        info.update(select_sorted(projector.snap.Cat.Group, 'Group_ids', 'GroupPos',
                                  'Group_R_Crit200', 'Group_M_Crit200'))
    return info

## The interactive widget
The rather long code below defines an interactive ipython widget with a
number of hopefully mostly self-explanatory buttons.

These can zoom in/out, rotate the image, change width, height and depth etc.

Ticking the 'Recording' checkbox will output a png/hdf5 file every time the
view changes (if those boxes are ticked). These will be saved in the
directory entered in the box just to the right of the png checkbox.
The recording also saves a .log file with a series of commands
that can be used to reproduce an interactive session.
This allows for using an interactive session as a starting point for
creating an animation of a simulation snapshot.

We have left all the code for the widget visible instead of saving it somewhere
else and importing it. We hope that it will in this way be easier for someone
to extend/modify the code to their own needs.

The data included with Paicos is very low resolution and this notebook does
not really showcase how well the GPU code works at higher resolution.
We have tried with a $12^3$ times better mass resolution simulation (equivalent
to the resolution in the TNG300 simulation, used in the version of this notebook
shown on readthedocs) and find that an A100 GPU is fast enough
to give a smooth user experience.

In [7]:

# Maximum number of catalogue objects (most massive first) marked in the image.
MAX_CATALOGUE_MARKERS = 20


def _overlay_catalogue(info, ids_key, pos_key, radius_key, mass_key, prefix):
    """
    Mark catalogue objects on the current axes and return their labels.

    Positions are taken relative to the image center, rotated into the image
    frame with the projector orientation and converted with the same unit
    checkboxes (physical / cgs / astro) as the image; the radii get the same
    unit conversion. A circle and a text label are drawn for at most
    MAX_CATALOGUE_MARKERS objects, stopping early at the first object with
    zero mass (get_group_and_sub_indices sorts objects by decreasing mass).

    Parameters
    ----------
    info : dict
        Output of get_group_and_sub_indices().
    ids_key, pos_key, radius_key, mass_key : str
        Keys of the catalogue indices, positions, radii and masses in ``info``.
    prefix : str
        Label prefix, 'G' for FoF groups and 'S' for subhalos.

    Returns
    -------
    list of str
        Labels such as 'G12' for the marked objects.
    """
    points = info[pos_key] - projector.center
    points = np.matmul(projector.orientation.inverse_rotation_matrix, points.T).T
    radii = info[radius_key]
    if to_physical.value:
        points = points.to_physical
        radii = radii.to_physical
    if to_cgs.value:
        points = points.cgs
        radii = radii.cgs
    if to_astro_units.value:
        points = points.astro
        radii = radii.astro

    ax = plt.gca()
    labels = []
    for ii in range(min(points.shape[0], MAX_CATALOGUE_MARKERS)):
        if info[mass_key][ii].value == 0:
            break
        x, y = points[ii, 0].value, points[ii, 1].value
        ax.add_patch(plt.Circle((x, y), radii[ii].value, color='k', fill=False))
        label = f"{prefix}{info[ids_key][ii]}"
        plt.text(x, y, label, fontsize=6)
        labels.append(label)
    return labels


def update():
    """
    Re-project the selected field and redraw figure 1 from the widget state.

    Projects ``var_str`` with the current projector and applies the unit
    checkboxes to image and extent. It then plots with the chosen colormap and
    color limits and optionally marks FoF groups / subhalos, refilling the
    "Center on" dropdown. While recording, it saves an hdf5 and/or png frame
    and advances the frame counter.
    """
    proj = projector.project_variable(var_str.value)
    extent = projector.centered_extent

    if to_physical.value:
        proj = proj.to_physical
        extent = extent.to_physical

    if to_cgs.value:
        proj = proj.cgs
        extent = extent.cgs

    if to_astro_units.value:
        proj = proj.astro
        extent = extent.astro

    plt.figure(1)
    plt.clf()

    # Color limits. With 'Fix climits' the user's vmin/vmax are used, falling
    # back to the data range when either is not positive (LogNorm needs
    # positive limits). Otherwise the data range is shown in the disabled
    # boxes and LogNorm autoscales.
    if fix_climits.value:
        vmin.disabled = False
        vmax.disabled = False
        if not (vmin.value > 0 and vmax.value > 0):
            vmin.value = proj.value.min()
            vmax.value = proj.value.max()
        norm = LogNorm(vmin.value, vmax.value)
    else:
        vmin.value = proj.value.min()
        vmax.value = proj.value.max()
        vmin.disabled = True
        vmax.disabled = True
        norm = LogNorm()
    plt.imshow(proj.value, extent=extent.value, origin='lower',
               norm=norm, cmap=cmap_str.value)

    # Labels
    plt.xlabel(extent.label())
    plt.ylabel(extent.label())

    # Colorbar. Underscores in the field name are escaped for mathtext; raw
    # strings avoid invalid escape sequences (SyntaxWarning since Python 3.12).
    cb = plt.colorbar()
    field_tex = var_str.value.replace('_', r'\_')
    cb.set_label(proj.label(r'\mathrm{' + field_tex + r'}\,'))

    # Title
    plt.title(f'Snapnum: {snapnum}, Age: {snap.age:1.2f}, Redshift: {snap.z:1.2f}')

    # FoF groups / subhalos. The "Center on" dropdown is only enabled while
    # one of them is shown; when both are shown it lists the subhalos.
    select_center.disabled = True
    if hasattr(projector.snap, 'Cat'):
        info = get_group_and_sub_indices()
        overlays = (
            (show_groups, 'Group_ids', 'GroupPos', 'Group_R_Crit200', 'Group_M_Crit200', 'G'),
            (show_subs, 'Subhalo_ids', 'SubhaloPos', 'SubhaloHalfmassRad', 'SubhaloMass', 'S'),
        )
        for toggle, ids_key, pos_key, radius_key, mass_key, prefix in overlays:
            if ids_key in info and toggle.value:
                select_center.disabled = False
                select_center.options = _overlay_catalogue(
                    info, ids_key, pos_key, radius_key, mass_key, prefix)

    if recording.value:
        frame_name = f'{basename.value}_{var_str.value}_frame_{frame_counter.value}'
        if hdf5.value:
            image_file = pa.ArepoImage(projector, basedir=outfolder.value,
                                       basename=frame_name)

            image_file.save_image(var_str.value, proj)

            # Move from temporary filename to final filename
            image_file.finalize()
            if frame_counter.value == 0:
                # Log header, written once with the first frame
                mylogger(image_file.filename)
                mylogger(outfolder.value)
                mylogger(basename.value)
                mylogger(var_str.value)
                org_center = projector.center - projector._diff_center
                mylogger(f"org_center,{org_center[0].value},{org_center[1].value},{org_center[2].value}")

        if png.value:
            plt.savefig(f'{outfolder.value}/{frame_name}_{projector.snap.snapnum}.png', dpi=700)
        frame_counter.value += 1

    plt.show()


def mylogger(string, line_ending='\n', mode='a'):
    """
    Write one entry to the session log ``<outfolder>/<basename>.log``.

    Parameters
    ----------
    string : str
        Text to write.
    line_ending : str
        Appended after ``string``. Callers pass ',#,' to put the next entry on
        the same line.
    mode : str
        File mode passed to ``open``: 'a' appends (default), 'w' starts a new log.
    """
    log_path = f'{outfolder.value}/{basename.value}.log'
    with open(log_path, mode) as log_file:
        log_file.write(string + line_ending)

# Sliders for the image extent, ranging up to twice the initial widths.
width = widgets.FloatSlider(value=projector.width.value, min=1, max=2*widths[0].value, step=1, description="width", continuous_update=False)
height = widgets.FloatSlider(value=projector.height.value, min=1, max=2*widths[1].value, step=1, description="height", continuous_update=False)
depth = widgets.FloatSlider(value=projector.depth.value, min=1, max=2*widths[2].value, step=1, description="depth", continuous_update=False)
zoom_slider = widgets.FloatSlider(value=1, min=0.1, max=10, step=0.1, description="Zoom factor")


def _center_offset_box(description, max_offset):
    """Text box, bounded to [-max_offset, max_offset], for moving the view center along one image axis."""
    return widgets.BoundedFloatText(
        value=0,
        min=-max_offset,
        max=max_offset,
        step=1,
        description=description,
        disabled=False
    )


# Offsets used by 'Move center'. Each box is bounded by the initial extent
# along its own image axis: horizontal -> width, vertical -> height,
# depth -> depth (previously all three used the width).
center_horizontal = _center_offset_box('Horizontal', widths[0].value)
center_vertical = _center_offset_box('Vertical', widths[1].value)
center_depth = _center_offset_box('Depth', widths[2].value)

def click_step_button(b):
    """
    Shift the view center by the offsets typed in the center boxes and redraw.

    The Horizontal, Vertical and Depth offsets move the center along
    perp_vector1, perp_vector2 and the normal vector respectively. All offsets
    are in the unit of the projector width, and zero offsets are skipped. The
    move is logged while recording.
    """
    moves = (
        (center_horizontal.value, projector.move_center_along_perp_vector1),
        (center_vertical.value, projector.move_center_along_perp_vector2),
        (center_depth.value, projector.move_center_along_normal_vector),
    )
    for offset, move_center in moves:
        if offset != 0:
            move_center(offset * projector.width.uq)

    if recording.value:
        mylogger(f'move_center,{center_horizontal.value},{center_vertical.value},{center_depth.value}')
    update()

def press_reset_center_button(b):
    """
    Return the view to the center the projector started from and redraw.

    ``projector._diff_center`` holds the offset accumulated by earlier
    re-centerings. It is logged while recording, subtracted from the current
    center, and then zeroed in place.
    """
    offset = projector._diff_center
    original_center = projector.center - offset
    if recording.value:
        mylogger(f"move_center_sim_coordinates,{offset[0].value},{offset[1].value},{offset[2].value}", line_ending=',#,')
        mylogger(f"reset_center_to_org_center,{original_center[0].value},{original_center[1].value},{original_center[2].value}")

    projector.center = original_center
    projector._diff_center[:] = 0
    update()

# Buttons that apply the center offsets / return to the original center
step_button = widgets.Button(description="Move center")
step_button.on_click(click_step_button)
recenter_button = widgets.Button(description="Reset center")
recenter_button.on_click(press_reset_center_button)


# Dropdown with all parttype 0 fields, i.e. keys starting with '0'.
# TODO: Remove vectors and tensors from list, or add box for selecting component.
avail_list = [key for key in snap._auto_list if key[0] == '0']
var_str = widgets.Dropdown(
    options=avail_list,
    value='0_Density',
    description='Field:',
)

# "Center on" dropdown; update() fills it with catalogue labels ('G<id>' /
# 'S<id>') while groups or subhalos are shown.
select_center = widgets.Dropdown(
    options=[],
    value=None,
    description='Center on:',
    disabled=True,
)

# Colormap passed to imshow in update()
cmap_str = widgets.Dropdown(
    options=plt.colormaps(),
    value='viridis',
    description='Cmap:',
)

# Color limits; update() enables these boxes only when 'Fix climits' is ticked
vmin, vmax = (
    widgets.FloatText(description=label, disabled=True)
    for label in ('vmin:', 'vmax:')
)

fix_climits = widgets.Checkbox(
    value=False,
    description='Fix climits',
    disabled=False,
    indent=False,
)

# Rotation buttons and the rotation step size used by them
button_left, button_right, button_up, button_down = (
    widgets.Button(description=label)
    for label in ("Pan left", "Pan right", "Pan up", "Pan down")
)
button_clock_wise, button_anti_clock_wise = (
    widgets.Button(description=label)
    for label in ("Clockwise", "Anti-clockwise")
)
step_size_in_degrees = widgets.FloatSlider(
    value=15, min=0, max=90, step=0.5, description="Step (Degrees)")

def call_update(b):
    """Widget callback (button click or trait change): redraw the image; the argument is ignored."""
    update()


button_update = widgets.Button(description="Update")
button_update.on_click(call_update)

def _apply_resolution_change(property_name):
    """
    Access ``projector.<property_name>``, log its name while recording, and redraw.

    The resolution attributes are used as bare attribute accesses, so they
    are presumably properties whose access changes the pixel resolution.
    """
    getattr(projector, property_name)
    if recording.value:
        mylogger(property_name)
    update()


def call_double_resolution(b):
    """Button callback: double the image resolution."""
    _apply_resolution_change('double_resolution')


def call_half_resolution(b):
    """Button callback: halve the image resolution."""
    _apply_resolution_change('half_resolution')


button_double = widgets.Button(description="Double res")
button_double.on_click(call_double_resolution)

button_half = widgets.Button(description="Half res")
button_half.on_click(call_half_resolution)

def call_zoom(b):
    """
    Zoom the projector by the factor on ``zoom_slider`` and redraw once.

    The width/height sliders are synced to the zoomed extent. While this
    happens, ``zoom_button_was_pressed`` is set so that change_width /
    change_height only update the projector and do not log or redraw. The
    flag is cleared in ``finally`` so that an exception cannot leave those
    sliders permanently muted.
    """
    zoom_button_was_pressed.value = True
    try:
        projector.zoom(zoom_slider.value)
        if recording.value:
            mylogger(f'zoom,{zoom_slider.value}')

        # Sync the sliders, widening their range first when needed.
        # Otherwise the slider clamps the value and its observer silently
        # overwrites the projector extent with the clamped one, so the image
        # would no longer match the logged zoom.
        # TODO: decide whether extreme zoom factors should be rejected instead.
        for slider, attr in ((width, 'width'), (height, 'height')):
            new_value = getattr(projector, attr).value
            slider.max = max(slider.max, new_value)
            slider.min = min(slider.min, new_value)
            slider.value = new_value
        update()
    finally:
        zoom_button_was_pressed.value = False

button_zoom = widgets.Button(description="Zoom in/out")
button_zoom.on_click(call_zoom)
# Internal flag (not displayed) that call_zoom sets while syncing the sliders.
zoom_button_was_pressed = widgets.Checkbox(
    value=False, description='internal boolean for avoiding calling call_update twice when zooming',
)

def _unit_checkbox(description):
    """Un-indented, initially unticked checkbox that redraws the image when toggled."""
    checkbox = widgets.Checkbox(
        value=False,
        description=description,
        disabled=False,
        indent=False
    )
    checkbox.observe(call_update, names=['value'])
    return checkbox


# Unit conversions applied (in this order) by update() to image and extent
to_physical = _unit_checkbox('physical units')
to_cgs = _unit_checkbox('cgs units')
to_astro_units = _unit_checkbox("'astro' units")

def change_width(change):
    """Apply the width slider to the projector; log and redraw unless call_zoom is syncing."""
    projector.width = width.value * projector.width.uq
    if zoom_button_was_pressed.value:
        return  # call_zoom logs the zoom and redraws once itself
    if recording.value:
        mylogger(f'width,{width.value}')
    update()

def change_height(change):
    """Apply the height slider to the projector; log and redraw unless call_zoom is syncing."""
    projector.height = height.value * projector.height.uq
    if zoom_button_was_pressed.value:
        return  # call_zoom logs the zoom and redraws once itself
    if recording.value:
        mylogger(f'height,{height.value}')
    update()

def change_depth(change):
    """Apply the depth slider to the projector, log it while recording, and redraw."""
    new_depth = depth.value
    projector.depth = new_depth * projector.depth.uq
    if recording.value:
        mylogger(f'depth,{new_depth}')
    update()

def change_var_str(change):
    """Log the newly selected field while recording and redraw."""
    if recording.value:
        mylogger(var_str.value)
    update()

def _rotate_and_redraw(rotate, degrees, log_entry):
    """Call ``rotate(degrees=degrees)``, append ``log_entry`` to the log while recording, and redraw."""
    rotate(degrees=degrees)
    if recording.value:
        mylogger(log_entry)
    update()

def pan_left(b):
    """Rotate around perp_vector2 (the image's vertical axis) by +step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_perp_vector2, step,
                       f'rotate_around_perp_vector2,{step}')

def pan_right(b):
    """Rotate around perp_vector2 (the image's vertical axis) by -step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_perp_vector2, -step,
                       f'rotate_around_perp_vector2,{-step}')

def pan_up(b):
    """Rotate around perp_vector1 (the image's horizontal axis) by +step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_perp_vector1, step,
                       f'rotate_around_perp_vector1,{step}')

def pan_down(b):
    """Rotate around perp_vector1 (the image's horizontal axis) by -step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_perp_vector1, -step,
                       f'rotate_around_perp_vector1,{-step}')

def clock_wise(b):
    """Rotate around the normal vector (line of sight) by +step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_normal_vector, step,
                       f'rotate_around_normal_vector,{step}')

def anti_clock_wise(b):
    """Rotate around the normal vector (line of sight) by -step."""
    step = step_size_in_degrees.value
    _rotate_and_redraw(projector.orientation.rotate_around_normal_vector, -step,
                       f'rotate_around_normal_vector,{-step}')

button_left.on_click(pan_left)
button_right.on_click(pan_right)

button_up.on_click(pan_up)
button_down.on_click(pan_down)

button_clock_wise.on_click(clock_wise)
button_anti_clock_wise.on_click(anti_clock_wise)

# Field changes are handled by change_var_str, which is registered with
# names=['value'] and logs the new field before redrawing. An extra
# unfiltered `var_str.observe(call_update)` would fire for every trait the
# Dropdown changes (value, index, label, ...). That re-projects the image
# several times per selection and writes several frames while recording.


# Toggles for marking FoF groups / subhalos in the image (redraw on change)
show_groups = widgets.Checkbox(value=False, description='Show FoF groups',
                               disabled=False, indent=False)
show_subs = widgets.Checkbox(value=False, description='Show subhalos',
                             disabled=False, indent=False)

for _overlay_toggle in (show_groups, show_subs):
    _overlay_toggle.observe(call_update, names=['value'])
del _overlay_toggle

# Save hdf5/png

# Index of the next recorded frame (part of the output file names). IntText
# is unbounded, unlike IntSlider, whose default max of 100 clamps the value so
# that from frame 100 onwards every new frame overwrites the same files.
frame_counter = widgets.IntText(value=0)
recording = widgets.Checkbox(
    value=False,
    description='Recording',
    disabled=False,
    indent=False
)

def reset_counter(b):
    """
    React to the 'Recording' checkbox.

    Starting a recording locks the file-name boxes, starts a fresh log file
    and renders the first frame. Stopping it unlocks the boxes and resets the
    frame counter.
    """
    is_recording = recording.value
    basename.disabled = is_recording
    outfolder.disabled = is_recording
    if is_recording:
        mylogger('', mode='w')
        update()
    else:
        frame_counter.value = 0

recording.observe(reset_counter, names=['value'])

# Output formats written for every frame while recording
hdf5, png = (
    widgets.Checkbox(value=False, description=label, disabled=False, indent=False)
    for label in ('hdf5', 'png')
)

# Output area in which update() draws the figure
out = widgets.interactive_output(update, {})

def change_center_using_cat(change):
    """
    Re-center the view on the group/subhalo picked in the "Center on" dropdown.

    Dropdown values look like 'G<index>' (FoF group) or 'S<index>' (subhalo),
    as produced by update(). The offset from the current center is added to
    ``projector._diff_center`` and logged while recording. The overlay is then
    switched off, which redraws through its observer, and the dropdown is
    cleared.

    Clearing the options, and the explicit ``value = None`` below, call this
    observer again without a selection; such calls return immediately.
    """
    choice = select_center.value
    if not choice:
        return

    kind, obj_id = choice[0], int(choice[1:])
    if kind == 'G':
        new_center = projector.snap.Cat.Group['GroupPos'][obj_id].T
        log_command, overlay_toggle = 'center_on_group', show_groups
    elif kind == 'S':
        new_center = projector.snap.Cat.Sub['SubhaloPos'][obj_id].T
        log_command, overlay_toggle = 'center_on_sub', show_subs
    else:
        return

    if recording.value:
        diff = new_center - projector._center
        mylogger(f"move_center_sim_coordinates,{diff[0].value},{diff[1].value},{diff[2].value}", line_ending=',#,')
        mylogger(f"{log_command},{obj_id},{new_center[0].value},{new_center[1].value},{new_center[2].value}")

    projector._diff_center += new_center - projector._center
    # NOTE(review): `.copy` is accessed without call parentheses, as in the
    # original; this presumably relies on Paicos quantities exposing `copy`
    # as a property. Confirm.
    projector.center = new_center.copy
    overlay_toggle.value = False
    select_center.options = []
    select_center.value = None

# File-name boxes for recording (locked by reset_counter while recording)
basename = widgets.Text(value='image')
outfolder = widgets.Text(value='./')

# Remaining observers: extent sliders, field selection and "Center on"
for _widget, _handler in ((width, change_width),
                          (height, change_height),
                          (depth, change_depth),
                          (var_str, change_var_str),
                          (select_center, change_center_using_cat)):
    _widget.observe(_handler, names=['value'])
del _widget, _handler

# Layout: the figure output, followed by one row of controls per HBox
control_rows = [
    [button_update, var_str, button_zoom, zoom_slider],
    [button_left, button_right, button_up, button_down, button_clock_wise, button_anti_clock_wise],
    [width, height, depth],
    [step_button, center_horizontal, center_vertical, center_depth, recenter_button],
    [fix_climits, vmin, vmax],
    [step_size_in_degrees, to_physical, to_cgs, to_astro_units],
    [cmap_str, button_double, button_half],
    [recording, hdf5, png, outfolder, basename],
    [show_groups, show_subs, select_center],
]
display(out, *(widgets.HBox(row) for row in control_rows))

Output()

HBox(children=(Button(description='Update', style=ButtonStyle()), Dropdown(description='Field:', index=4, opti…

HBox(children=(Button(description='Pan left', style=ButtonStyle()), Button(description='Pan right', style=Butt…

HBox(children=(FloatSlider(value=10000.0, continuous_update=False, description='width', max=20000.0, min=1.0, …

HBox(children=(Button(description='Move center', style=ButtonStyle()), BoundedFloatText(value=0.0, description…

HBox(children=(Checkbox(value=False, description='Fix climits', indent=False), FloatText(value=5.8146158647168…

HBox(children=(FloatSlider(value=15.0, description='Step (Degrees)', max=90.0, step=0.5), Checkbox(value=False…

HBox(children=(Dropdown(description='Cmap:', index=3, options=('magma', 'inferno', 'plasma', 'viridis', 'civid…

HBox(children=(Checkbox(value=False, description='Recording', indent=False), Checkbox(value=False, description…

HBox(children=(Checkbox(value=False, description='Show FoF groups', indent=False), Checkbox(value=False, descr…