-
Notifications
You must be signed in to change notification settings - Fork 4
/
grid.py
218 lines (189 loc) · 8.36 KB
/
grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
try:
import cupy as cp
except ImportError:
import numpy as cp
def _check_valid_tuple(points: tuple, grid_spacings: tuple) -> None:
    """Validates a tuple of grid points against its grid spacings.

    :param points: Number of grid points in each spatial dimension.
    :param grid_spacings: Grid spacing in each spatial dimension; must have
        the same length as ``points``.
    :raises ValueError: If the lengths differ, the dimensionality is not 1, 2
        or 3, or ``points`` contains anything other than positive integers.
    """
    if len(points) != len(grid_spacings):
        raise ValueError(f"{points} and {grid_spacings} are not of same length")
    # An empty tuple would otherwise produce a grid with no meshes at all.
    if not 1 <= len(points) <= 3:
        raise ValueError(f"{points} is not a valid dimensionality")
    # ``bool`` subclasses ``int`` but is never a meaningful point count.
    if not all(
        isinstance(point, int) and not isinstance(point, bool) for point in points
    ):
        raise ValueError(f"{points} contains non-integer values")
    # Zero or negative counts give empty/ill-defined axes downstream.
    if any(point < 1 for point in points):
        raise ValueError(f"{points} contains non-positive values")
class Grid:
"""An object representing the numerical grid.
It contains information on the number of grid points, the shape, the
dimensionality, and lengths of the grid.
:param points: Number of points in each spatial dimension.
:type points: int or tuple of ints
:param grid_spacings: Numerical spacing between grid points in each
spatial dimension.
:type grid_spacings: float or tuple of floats
:ivar shape: Shape of the grid.
:ivar ndim: Dimensionality of the grid.
:ivar total_num_points: Total number of grid points across all dimensions.
:ivar num_points_x: Number of points in the x-direction.
:ivar num_points_y: (2D and 3D only) Number of points in the y-direction.
:ivar num_points_z: (3D only) Number of points in the z-direction.
:ivar length_x: Length of the grid in the x-direction.
:ivar length_y: (2D and 3D only) Length of the grid in the y-direction.
:ivar length_z: (3D only) Length of the grid in the z-direction.
:ivar x_mesh: The x meshgrid. The dimensionality matches that of `ndim`.
:ivar y_mesh: (2D and 3D only) The y meshgrid. The dimensionality matches
that of `ndim`.
:ivar z_mesh: (3D only) The z meshgrid. The dimensionality matches that of
`ndim`.
:ivar grid_spacing_x: Grid spacing in the x-direction.
:ivar grid_spacing_y: (2D and 3D only) Grid spacing in the y-direction.
:ivar grid_spacing_z: (3D only) Grid spacing in the z-direction.
:ivar grid_spacing_product: The product of the grid spacing for each
dimension.
:ivar fourier_x_mesh: The Fourier-space x meshgrid. The dimensionality
matches that of `ndim`.
:ivar fourier_y_mesh: (2D and 3D only) The Fourier-space y meshgrid. The
dimensionality matches that of `ndim`.
:ivar fourier_z_mesh: (3D only) The Fourier-space z meshgrid. The
dimensionality matches that of `ndim`.
:ivar fourier_spacing_x: Fourier grid spacing in the x-direction.
:ivar fourier_spacing_y: (2D and 3D only) Fourier grid spacing in the
y-direction.
:ivar fourier_spacing_z: (3D only) Fourier grid spacing in the z-direction.
"""
def __init__(
self,
points: int | tuple[int, ...],
grid_spacings: float | tuple[float, ...],
):
"""Constructs the grid object."""
self.shape = points
if isinstance(points, tuple):
_check_valid_tuple(points, grid_spacings)
self.ndim = len(points)
self.total_num_points = 1
for point in points:
self.total_num_points *= point
elif isinstance(points, int):
self.ndim = 1
self.total_num_points = points
else:
raise ValueError(
f"{points} is of unsupported type. Use int or tuple of ints."
)
if self.ndim == 1:
self._generate_1d_grids(points, grid_spacings)
elif self.ndim == 2:
self._generate_2d_grids(points, grid_spacings)
elif self.ndim == 3:
self._generate_3d_grids(points, grid_spacings)
def _generate_1d_grids(self, points: int, grid_spacing: float):
"""Generates meshgrid for a 1D grid."""
self.num_points_x = points
self.grid_spacing_x = grid_spacing
self.grid_spacing_product = self.grid_spacing_x
self.length_x = self.num_points_x * self.grid_spacing_x
self.x_mesh = (
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.grid_spacing_x
)
self.fourier_spacing_x = cp.pi / (self.num_points_x // 2 * self.grid_spacing_x)
self.fourier_x_mesh = cp.fft.fftshift(
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.fourier_spacing_x
)
# Defined on device for use in evolution
self.wave_number = cp.asarray(self.fourier_x_mesh**2)
def _generate_2d_grids(
self, points: tuple[int, ...], grid_spacings: tuple[float, ...]
):
"""Generates meshgrid for a 2D grid."""
self.num_points_x, self.num_points_y = points
self.grid_spacing_x, self.grid_spacing_y = grid_spacings
self.grid_spacing_product = self.grid_spacing_x * self.grid_spacing_y
self.length_x = self.num_points_x * self.grid_spacing_x
self.length_y = self.num_points_y * self.grid_spacing_y
x = (
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.grid_spacing_x
)
y = (
cp.arange(-self.num_points_y // 2, self.num_points_y // 2)
* self.grid_spacing_y
)
self.x_mesh, self.y_mesh = cp.meshgrid(x, y)
# Generate Fourier space variables
self.fourier_spacing_x = cp.pi / (self.num_points_x // 2 * self.grid_spacing_x)
self.fourier_spacing_y = cp.pi / (self.num_points_y // 2 * self.grid_spacing_y)
fourier_x = (
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.fourier_spacing_x
)
fourier_y = (
cp.arange(-self.num_points_y // 2, self.num_points_y // 2)
* self.fourier_spacing_y
)
self.fourier_x_mesh, self.fourier_y_mesh = cp.meshgrid(fourier_x, fourier_y)
self.fourier_x_mesh = cp.fft.fftshift(self.fourier_x_mesh)
self.fourier_y_mesh = cp.fft.fftshift(self.fourier_y_mesh)
# Defined on device for use in evolution
self.wave_number = cp.asarray(
self.fourier_x_mesh**2 + self.fourier_y_mesh**2
)
def _generate_3d_grids(
self, points: tuple[int, ...], grid_spacings: tuple[float, ...]
):
"""Generates meshgrid for a 3D grid."""
self.num_points_x, self.num_points_y, self.num_points_z = points
(
self.grid_spacing_x,
self.grid_spacing_y,
self.grid_spacing_z,
) = grid_spacings
self.grid_spacing_product = (
self.grid_spacing_x * self.grid_spacing_y * self.grid_spacing_z
)
self.length_x = self.num_points_x * self.grid_spacing_x
self.length_y = self.num_points_y * self.grid_spacing_y
self.length_z = self.num_points_z * self.grid_spacing_z
x = (
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.grid_spacing_x
)
y = (
cp.arange(-self.num_points_y // 2, self.num_points_y // 2)
* self.grid_spacing_y
)
z = (
cp.arange(-self.num_points_z // 2, self.num_points_z // 2)
* self.grid_spacing_z
)
self.x_mesh, self.y_mesh, self.z_mesh = cp.meshgrid(x, y, z)
# Generate Fourier space variables
self.fourier_spacing_x = cp.pi / (self.num_points_x // 2 * self.grid_spacing_x)
self.fourier_spacing_y = cp.pi / (self.num_points_y // 2 * self.grid_spacing_y)
self.fourier_spacing_z = cp.pi / (self.num_points_z // 2 * self.grid_spacing_z)
fourier_x = (
cp.arange(-self.num_points_x // 2, self.num_points_x // 2)
* self.fourier_spacing_x
)
fourier_y = (
cp.arange(-self.num_points_y // 2, self.num_points_y // 2)
* self.fourier_spacing_y
)
fourier_z = (
cp.arange(-self.num_points_z // 2, self.num_points_z // 2)
* self.fourier_spacing_z
)
(
self.fourier_x_mesh,
self.fourier_y_mesh,
self.fourier_z_mesh,
) = cp.meshgrid(fourier_x, fourier_y, fourier_z)
self.fourier_x_mesh = cp.fft.fftshift(self.fourier_x_mesh)
self.fourier_y_mesh = cp.fft.fftshift(self.fourier_y_mesh)
self.fourier_z_mesh = cp.fft.fftshift(self.fourier_z_mesh)
# Defined on device for use in evolution
self.wave_number = cp.asarray(
self.fourier_x_mesh**2
+ self.fourier_y_mesh**2
+ self.fourier_z_mesh**2
)