Initial implementation of parameter scanning logic in ParameterGrid
shankar1729 committed Apr 23, 2024
1 parent ecf57a5 commit 14df2a9
Showing 1 changed file with 63 additions and 1 deletion.
src/qimpy/transport/geometry/_parameter_grid.py (64 changes: 63 additions & 1 deletion)
@@ -1,8 +1,11 @@
from __future__ import annotations
from typing import Optional

import numpy as np
import torch

from qimpy.io import CheckpointPath
from qimpy import rc, log
from qimpy.io import CheckpointPath, InvalidInputException
from qimpy.mpi import ProcessGrid
from qimpy.transport.material import Material
from . import TensorList, Geometry, QuadSet
@@ -11,11 +14,15 @@
class ParameterGrid(Geometry):
    """Geometry specification."""

    shape: tuple[int, int]  #: Dimensions of parameter grid

    def __init__(
        self,
        *,
        material: Material,
        shape: tuple[int, int],
        dimension1: Optional[dict[str, dict[str, list]]] = None,
        dimension2: Optional[dict[str, dict[str, list]]] = None,
        process_grid: ProcessGrid,
        checkpoint_in: CheckpointPath = CheckpointPath(),
    ) -> None:
@@ -26,8 +33,17 @@ def __init__(
        ----------
        shape
            :yaml:`Dimensions of parameter grid (always 2D).`
        dimension1
            :yaml:`Parameter names and values to sweep along dimension 1.`
            The values can be specified either as a "loop" over explicit
            values, or as a "sweep" varying linearly from an initial to a
            final value.
        dimension2
            :yaml:`Parameter names and values to sweep along dimension 2.`
            Specification is the same as for `dimension1`.
        """
        assert len(shape) == 2
        self.shape = shape

        # Create fake geometry for parameter grid
        quad_set = QuadSet(
            vertices=_GRID_VERTICES * np.array(shape),
@@ -50,6 +66,52 @@
        )
        self.dt_max = 0  # disable transport dt limit

        # Prepare all parameter values:
        parameters: dict[str, torch.Tensor] = {}
        for i_dim, dimension in enumerate([dimension1, dimension2]):
            if dimension is not None:
                for key, values in dimension.items():
                    parameters[key] = self.create_values(i_dim, **values)
        log.info(f"{parameters}")
        exit()

    def create_values(
        self,
        i_dim: int,
        *,
        loop: Optional[list] = None,
        sweep: Optional[list] = None,
    ) -> torch.Tensor:
        if (loop is None) == (sweep is None):
            raise InvalidInputException(
                "Exactly one of loop or sweep should be specified."
            )

        # Prepare scan of values based on loop or sweep:
        values: Optional[torch.Tensor] = None
        if loop is not None:
            values = torch.tensor(loop, device=rc.device)
            if len(values) != self.shape[i_dim]:
                raise InvalidInputException(
                    f"Number of entries in loop = {len(values)} should match"
                    f" length {self.shape[i_dim]} of dimension {i_dim + 1}"
                )
        if sweep is not None:
            limits = torch.tensor(sweep, device=rc.device)
            if len(limits) != 2:
                raise InvalidInputException(
                    f"Number of entries in sweep = {len(limits)} must be 2"
                )
            t = torch.linspace(0, 1, self.shape[i_dim], device=rc.device)
            values = limits[0] + torch.einsum("t,...->t...", t, limits[1] - limits[0])

        # Reshape to broadcast along appropriate dimension:
        assert values is not None
        if i_dim == 0:
            return values[:, None]
        else:  # i_dim == 1
            return values[None]

    def rho_dot(self, rho: TensorList, t: float) -> TensorList:
        return TensorList(self.material.rho_dot(rho_i, t) for rho_i in rho)
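
For context, here is a minimal standalone sketch of the loop/sweep construction introduced by this commit. The grid shape (4, 5) and the make_values helper are assumptions for illustration only (not part of the commit); the sketch mirrors create_values above but drops the qimpy device handling and input-validation exceptions so it runs on its own:

import torch

shape = (4, 5)  # hypothetical parameter-grid dimensions (always 2D)

def make_values(i_dim: int, *, loop=None, sweep=None) -> torch.Tensor:
    """Build parameter values along grid dimension i_dim (0 or 1)."""
    if (loop is None) == (sweep is None):
        raise ValueError("Exactly one of loop or sweep should be specified.")
    if loop is not None:
        values = torch.tensor(loop)  # explicit values, one per grid point
        assert len(values) == shape[i_dim]
    else:
        limits = torch.tensor(sweep)  # [initial, final] endpoints
        assert len(limits) == 2
        t = torch.linspace(0, 1, shape[i_dim])  # linear ramp from 0 to 1
        values = limits[0] + torch.einsum("t,...->t...", t, limits[1] - limits[0])
    # Reshape so the values broadcast along the selected grid dimension:
    return values[:, None] if i_dim == 0 else values[None]

# "loop" over 4 explicit values along dimension 1 -> shape (4, 1):
print(make_values(0, loop=[0.1, 0.2, 0.5, 1.0]).shape)
# "sweep" linearly between two endpoints along dimension 2 -> shape (1, 5):
print(make_values(1, sweep=[0.0, 2.0]))

The einsum with an ellipsis lets the sweep endpoints be scalars or tensors of any shape, applying the linear ramp along a new leading axis in either case, so that loop and sweep values can broadcast against each other across the two grid dimensions.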

