Skip to content

Commit

Permalink
Refactor usage-vtk-cfd to not mutate object
Browse files Browse the repository at this point in the history
This commit changes how the Viz object is initialized by moving the
expensive computations (loading bike_mesh and tunnelReader) inside the
initial loading, such that we can safely and quickly create a new Viz
object every time the callback is fired.

The previous approach simply mutated the `Viz` object every time a calllback
is fired, which could cause inteference when more than one user is using
the application
  • Loading branch information
xhlulu committed Feb 2, 2021
1 parent 76fdf96 commit a69ecaf
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions demos/usage-vtk-cfd/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import random

import time

import dash
import dash_bootstrap_components as dbc
Expand All @@ -14,33 +14,42 @@

import vtk

random.seed(42)

# -----------------------------------------------------------------------------
# VTK Pipeline
# -----------------------------------------------------------------------------


def load_tunnel_reader(data_directory):
tunnel_filename = os.path.join(data_directory, "tunnel.vtu")
tunnelReader = vtk.vtkXMLUnstructuredGridReader()
tunnelReader.SetFileName(tunnel_filename)
tunnelReader.Update()

return tunnelReader


def load_bike_mesh(data_directory):
bike_filename = os.path.join(data_directory, "bike.vtp")
bikeReader = vtk.vtkXMLPolyDataReader()
bikeReader.SetFileName(bike_filename)
bikeReader.Update()
bike_mesh = to_mesh_state(bikeReader.GetOutput())

return bike_mesh


class Viz:
def __init__(self, data_directory):
def __init__(self, tunnelReader, bike_mesh):
t1 = time.time()
self.color_range = [0, 1]
bike_filename = os.path.join(data_directory, "bike.vtp")
tunnel_filename = os.path.join(data_directory, "tunnel.vtu")

# Seeds settings
self.resolution = 10
self.point1 = [-0.4, 0, 0.05]
self.point2 = [-0.4, 0, 1.5]

# VTK Pipeline setup
bikeReader = vtk.vtkXMLPolyDataReader()
bikeReader.SetFileName(bike_filename)
bikeReader.Update()
self.bike_mesh = to_mesh_state(bikeReader.GetOutput())

tunnelReader = vtk.vtkXMLUnstructuredGridReader()
tunnelReader.SetFileName(tunnel_filename)
tunnelReader.Update()
self.bike_mesh = bike_mesh

self.lineSeed = vtk.vtkLineSource()
self.lineSeed.SetPoint1(*self.point1)
Expand Down Expand Up @@ -68,6 +77,9 @@ def __init__(self, data_directory):
self.tubeFilter.CappingOn()
self.tubeFilter.Update()

t2 = time.time()
print(f"Created Viz object in {t2-t1:.4f}s.")

def updateSeedPoints(self, p1_y, p2_y, resolution):
self.point1[1] = p1_y
self.point2[1] = p2_y
Expand Down Expand Up @@ -98,13 +110,18 @@ def getSeedState(self):
}



# -----------------------------------------------------------------------------
# GUI setup
# -----------------------------------------------------------------------------

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
server = app.server
viz = Viz(os.path.join(os.path.dirname(__file__), "data"))

data_dir = os.path.join(os.path.dirname(__file__), "data")
tunnelReader = load_tunnel_reader(data_dir)
bike_mesh = load_bike_mesh(data_dir)
viz = Viz(tunnelReader, bike_mesh)

# -----------------------------------------------------------------------------
# 3D Viz
Expand Down Expand Up @@ -145,17 +162,19 @@ def getSeedState(self):
dbc.CardHeader("Seeds"),
dbc.CardBody(
[
html.P("Seed line:"),
html.P("Top starting position:"),
dcc.Slider(
id="point-1",
id="point-2",
min=-1,
max=1,
step=0.01,
value=0,
marks={-1: "-1", 1: "+1"},
),
html.Br(),
html.P("Bottom starting position:"),
dcc.Slider(
id="point-2",
id="point-1",
min=-1,
max=1,
step=0.01,
Expand Down Expand Up @@ -193,13 +212,15 @@ def getSeedState(self):
{"label": "k", "value": "k"},
],
value="p",
clearable=False,
),
html.Br(),
html.P("Color Preset"),
dcc.Dropdown(
id="preset",
options=preset_as_options,
value="erdc_rainbow_bright",
clearable=False,
),
]
),
Expand Down Expand Up @@ -252,7 +273,9 @@ def getSeedState(self):
],
)
def update_seeds(y1, y2, resolution, colorByField, presetName):
viz = Viz(tunnelReader, bike_mesh)
viz.updateSeedPoints(y1, y2, resolution)

return [
viz.getSeedState(),
viz.getTubesMesh(colorByField),
Expand Down

0 comments on commit a69ecaf

Please sign in to comment.