forked from neuraloperator/neuraloperator
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Spectrum analysis of datasets. (neuraloperator#193)
Adds Spectrum analysis function and example.
- Loading branch information
1 parent
cb62f8b
commit 7c54261
Showing
2 changed files
with
162 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
""" | ||
A simple Darcy-Flow spectrum analysis | ||
=========================== | ||
In this example, we demonstrate how to use the spectrum analysis function on the small Darcy-Flow example. | ||
For more details on spectrum analysis, users can take a look at this reference: https://www.astronomy.ohio-state.edu/ryden.1/ast825/ch7.pdf | ||
Short summary | ||
=========================== | ||
Spectral analysis is useful because it allows researchers to study the distribution of energy across different scales in a fluid flow. By examining the energy spectrum, one can gain insights into the behavior of turbulence or any other dataset and the underlying physical processes. The energy spectrum is analysed through the Fourier transform which is a mathematical tool that decomposes a function or signal into its constituent frequencies. In a fluid flow, it is used to analyze the distribution of energy across different scales in a flow. Specifically, the Fourier transform is applied to the velocity field of the flow, converting it into a frequency domain representation. Higher the wavenumber corresponds to higher frequency and higher energy and is a much harder task to solve as we need higher modes to capture the high-frequency behavior of the flow. Overall this allows researchers to study the energy spectrum, which provides insights into the behavior of turbulence and the underlying physical processes. | ||
""" | ||
# Original Author: Zongyi Li | ||
# Modified by: Robert Joseph George | ||
# %% | ||
# Import the library | ||
# ------------------ | ||
# We first import our `neuralop` library and required dependencies. | ||
import numpy as np | ||
import torch | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
from neuralop.utils import spectrum_2d | ||
from neuralop.datasets import load_darcy_flow_small | ||
|
||
font = {'size' : 28} | ||
matplotlib.rc('font', **font) | ||
|
||
torch.manual_seed(0) | ||
np.random.seed(0) | ||
|
||
# %% | ||
# Define some variables | ||
T = 500 # number of time steps | ||
samples = 50 | ||
s = 16 # resolution of the dataset | ||
|
||
# additional paramaters for the dataset | ||
Re = 5000 | ||
index = 1 | ||
T = 100 | ||
dataset_name = "Darcy Flow" | ||
|
||
# %% | ||
# Loading the Navier-Stokes dataset in 128x128 resolution | ||
train_loader, test_loaders, output_encoder = load_darcy_flow_small( | ||
n_train=50, batch_size=50, | ||
test_resolutions=[16, 32], n_tests=[50], | ||
test_batch_sizes=[32], positional_encoding=False, | ||
encode_output=False | ||
) | ||
|
||
# This is highly depending on your dataset and its structure ['x', 'y'] (In Darcy flow) | ||
print("Original dataset shape", train_loader.dataset[:samples]['x'].shape) # check the shape | ||
|
||
# It is important to note that we want the last two dimensions to represent the spatial dimensions | ||
# So in some cases one might have to permute the dataset after squeezing the initial dimensions as well | ||
dataset_pred = train_loader.dataset[:samples]['x'].squeeze() # squeeze the dataset to remove the batch dimension or other dimensions | ||
|
||
# Shape of the dataset | ||
shape = dataset_pred.shape | ||
|
||
# Define the grid size - in our case its a 2d Grid repeating, for higher dimensions this will change | ||
# Example for 3d grid | ||
""" | ||
batchsize, size_x, size_y, size_z = 1, shape[0], shape[1], shape[2] | ||
gridx = torch.tensor(np.linspace(-1, 1, size_x), dtype=torch.float) | ||
gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) | ||
gridy = torch.tensor(np.linspace(-1, 1, size_y), dtype=torch.float) | ||
gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) | ||
gridz = torch.tensor(np.linspace(-1, 1, size_z), dtype=torch.float) | ||
gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) | ||
grid = torch.cat((gridx, gridy, gridz), dim=-1) | ||
""" | ||
batchsize, size_x, size_y = 1, shape[1], shape[2] | ||
gridx = torch.tensor(np.linspace(-1, 1, size_x), dtype=torch.float) | ||
gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, size_y]) | ||
gridy = torch.tensor(np.linspace(-1, 1, size_y), dtype=torch.float) | ||
gridy = gridy.reshape(1, 1, size_y).repeat([batchsize, size_x, 1]) | ||
grid = torch.cat((gridx, gridy), dim=-1) | ||
|
||
# %% | ||
# ############################################################## | ||
### FFT plot | ||
############################################################## | ||
|
||
# Generate the spectrum of the dataset | ||
# Again only the last two dimensions have to be resolution and the first dimension is the reshaped product of all the other dimensions | ||
truth_sp = spectrum_2d(dataset_pred.reshape(samples * batchsize, s, s), s) | ||
|
||
# Generate the spectrum plot and set all the settings | ||
fig, ax = plt.subplots(figsize=(10,10)) | ||
|
||
linewidth = 3 | ||
ax.set_yscale('log') | ||
|
||
length = 16 # typically till the resolution length of the dataset | ||
buffer = 10 # just add a buffer to the plot | ||
k = np.arange(length + buffer) * 1.0 | ||
ax.plot(truth_sp, 'k', linestyle=":", label="NS", linewidth=4) | ||
|
||
ax.set_xlim(1,length+buffer) | ||
ax.set_ylim(10, 10^10) | ||
plt.legend(prop={'size': 20}) | ||
plt.title('Spectrum of {} Datset'.format(dataset_name)) | ||
|
||
plt.xlabel('wavenumber') | ||
plt.ylabel('energy') | ||
|
||
# show the figure | ||
leg = plt.legend(loc='best') | ||
leg.get_frame().set_alpha(0.5) | ||
plt.show() | ||
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters