In [9]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import vtk
from vtkmodules.util import numpy_support
from data.data import fitTransPipeline

In [14]:
# Read the structured grid (.vts) from disk; `vtkSG` is reused by later cells.
data_path = "data/bc80-45000-down.vts"
vtkSGR = vtk.vtkXMLStructuredGridReader()
vtkSGR.SetFileName(data_path)
vtkSGR.Update()
vtkSG = vtkSGR.GetOutput()
# Report dimensions, extent and bounds, one per line.
print(vtkSG.GetDimensions(), vtkSG.GetExtent(), vtkSG.GetBounds(), sep="\n")

(88, 51, 21)
(0, 87, 0, 50, 0, 20)
(0.0, 696.0, -250.0, 150.0, 0.7504535071217414, 429.4370158281313)


In [29]:
# Sanity check of the point ordering: the flat VTK point list, reshaped to
# (nz, ny, nx, 3), should match vtkSG.GetPoint() at flat index i + j * nx.
# The reshaped view and the row stride are derived here from `vtkSG` itself so
# the cell does not depend on variables defined in later cells or on a
# hard-coded grid width.
nx, ny, nz = vtkSG.GetDimensions()
points_zyx = numpy_support.vtk_to_numpy(vtkSG.GetPoints().GetData()).reshape(nz, ny, nx, 3)
i = 0  # only the first x-index is inspected
for j in range(10):
  print(points_zyx[0, j, i], vtkSG.GetPoint(i + j * nx))

[   0.         -250.          191.09518731] (0.0, -250.0, 191.09518730899723)
[   0.         -242.          191.09518731] (0.0, -242.0, 191.09518730899723)
[   0.         -234.          191.09518731] (0.0, -234.0, 191.09518730899723)
[   0.         -226.          191.09518731] (0.0, -226.0, 191.09518730899723)
[   0.         -218.          191.09518731] (0.0, -218.0, 191.09518730899723)
[   0.         -210.          191.09518731] (0.0, -210.0, 191.09518730899723)
[   0.         -202.          191.09518731] (0.0, -202.0, 191.09518730899723)
[   0.         -194.          191.09518731] (0.0, -194.0, 191.09518730899723)
[   0.         -186.          191.09518731] (0.0, -186.0, 191.09518730899723)
[   0.         -178.          191.09518731] (0.0, -178.0, 191.09518730899723)


In [7]:
# Point container of the structured grid held in `vtkSG`.
vtkPoints = vtkSG.GetPoints()

In [10]:
# Convert the data array behind `vtkPoints` into a NumPy array.
coord = numpy_support.vtk_to_numpy(vtkPoints.GetData())

In [20]:
# Reshape the flat point array to (nz, ny, nx, 3). The grid shape is taken from
# the loaded grid (GetDimensions() returns (nx, ny, nz), hence the reversal)
# instead of being hard-coded for one particular file.
tmp = coord.reshape(*vtkSG.GetDimensions()[::-1], 3)

In [34]:
# Preview 88 evenly spaced samples covering [0, 1] (endpoints included).
np.linspace(start=0, stop=1, num=88)

array([0.        , 0.01149425, 0.02298851, 0.03448276, 0.04597701,
       0.05747126, 0.06896552, 0.08045977, 0.09195402, 0.10344828,
       0.11494253, 0.12643678, 0.13793103, 0.14942529, 0.16091954,
       0.17241379, 0.18390805, 0.1954023 , 0.20689655, 0.2183908 ,
       0.22988506, 0.24137931, 0.25287356, 0.26436782, 0.27586207,
       0.28735632, 0.29885057, 0.31034483, 0.32183908, 0.33333333,
       0.34482759, 0.35632184, 0.36781609, 0.37931034, 0.3908046 ,
       0.40229885, 0.4137931 , 0.42528736, 0.43678161, 0.44827586,
       0.45977011, 0.47126437, 0.48275862, 0.49425287, 0.50574713,
       0.51724138, 0.52873563, 0.54022989, 0.55172414, 0.56321839,
       0.57471264, 0.5862069 , 0.59770115, 0.6091954 , 0.62068966,
       0.63218391, 0.64367816, 0.65517241, 0.66666667, 0.67816092,
       0.68965517, 0.70114943, 0.71264368, 0.72413793, 0.73563218,
       0.74712644, 0.75862069, 0.77011494, 0.7816092 , 0.79310345,
       0.8045977 , 0.81609195, 0.82758621, 0.83908046, 0.85057

In [63]:
# Regular reference grid on the unit cube, axes ordered (z, y, x).
dims = [21, 51, 88]
z, y, x = np.meshgrid(*(np.linspace(0, 1, n) for n in dims), indexing="ij")
# Stack the three coordinate volumes on a new trailing axis -> (21, 51, 88, 3).
xyz = np.stack((z, y, x), axis=-1)

In [95]:
class Phys2CompDataset(Dataset):
  """Point-wise dataset pairing physical grid coordinates with computational ones.

  The structured grid stored at ``data_path`` is read with VTK. Every grid
  point yields one sample ``(phys, comp)``:

  * ``phys`` - the point's coordinates as stored in the file,
  * ``comp`` - the point's position on a regular grid spanning [0, 1] along
    each axis, with components ordered (z, y, x).

  Args:
    data_path: path of the ``.vts`` file handed to ``vtkXMLStructuredGridReader``.
    intrans: only tested against ``None``. When not ``None`` the physical
      coordinates are passed through ``fitTransPipeline``; the value itself
      is not used.
    outtrans: same as ``intrans`` but for the computational coordinates.

  Attributes set by ``__init__``:
    dims: grid dimensions as a NumPy array ordered (nz, ny, nx).
    phys: array of shape ``(*dims, 3)`` with the physical coordinates.
    comp: array of shape ``(*dims, 3)`` with the computational coordinates.
    phys_prep / comp_prep: ``torch.Tensor`` versions flattened to ``(-1, 3)``
      (transformed when ``intrans`` / ``outtrans`` is given).
    inpp / outpp: second return value of ``fitTransPipeline`` - presumably the
      fitted pipeline (TODO confirm). Only present when the corresponding
      transform was requested.
  """

  def __init__(self, data_path, intrans=None, outtrans=None):
    self.data_path = data_path
    vtkSGR = vtk.vtkXMLStructuredGridReader()
    vtkSGR.SetFileName(data_path)
    vtkSGR.Update()
    vtkSG = vtkSGR.GetOutput()

    # VTK reports dimensions as (nx, ny, nz); reverse them to (nz, ny, nx) so
    # they match the axis order of the reshaped point array below.
    self.dims = np.array(vtkSG.GetDimensions())[::-1].copy()
    # Print the instance's own dimensions (as a plain list), not a global.
    print(self.dims.tolist())

    # get phys mesh: flat point list reshaped to (nz, ny, nx, 3)
    self.phys = numpy_support.vtk_to_numpy(vtkSG.GetPoints().GetData())
    self.phys = self.phys.reshape((*self.dims, 3))

    # get comp mesh - regular grid on [0, 1] along every axis
    z, y, x = np.meshgrid(
      np.linspace(0, 1, self.dims[0]),
      np.linspace(0, 1, self.dims[1]),
      np.linspace(0, 1, self.dims[2]),
      indexing="ij"
    )
    self.comp = np.stack((z, y, x), axis=-1)

    # preprocessing: one row per grid point, optionally transformed
    if intrans is not None:
      print("transforming inputs")
      self.phys_prep, self.inpp = fitTransPipeline(self.phys.reshape(-1, 3))
      self.phys_prep = torch.Tensor(self.phys_prep)
    else:
      self.phys_prep = torch.Tensor(self.phys.reshape(-1, 3))

    if outtrans is not None:
      print("transforming outputs")
      self.comp_prep, self.outpp = fitTransPipeline(self.comp.reshape(-1, 3))
      self.comp_prep = torch.Tensor(self.comp_prep)
    else:
      self.comp_prep = torch.Tensor(self.comp.reshape(-1, 3))

  def __len__(self):
    """Return the number of grid points (rows of ``phys_prep``)."""
    return len(self.phys_prep)

  def __getitem__(self, idx):
    """Return the ``(phys, comp)`` pair stored at row ``idx``."""
    return self.phys_prep[idx], self.comp_prep[idx]
    
# Build the dataset from the grid file without any input/output transform.
ds = Phys2CompDataset(data_path=data_path)

[21, 51, 88]
