Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lti squeeze: ndarray.ndim == 0 is also a scalar #595

Merged
merged 3 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion control/lti.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _process_frequency_response(sys, omega, out, squeeze=None):
if squeeze is None:
squeeze = config.defaults['control.squeeze_frequency_response']

if not hasattr(omega, '__len__'):
if np.asarray(omega).ndim < 1:
# received a scalar x, squeeze down the array along last dim
out = np.squeeze(out, axis=2)

Expand Down
57 changes: 39 additions & 18 deletions control/tests/lti_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import control as ct
from control import c2d, tf, tf2ss, NonlinearIOSystem
from control.lti import (LTI, common_timebase, damp, dcgain, isctime, isdtime,
issiso, pole, timebaseEqual, zero)
from control.lti import (LTI, common_timebase, evalfr, damp, dcgain, isctime,
isdtime, issiso, pole, timebaseEqual, zero)
from control.tests.conftest import slycotonly
from control.exception import slycot_check

Expand Down Expand Up @@ -179,11 +179,20 @@ def test_isdtime(self, objfun, arg, dt, ref, strictref):
[1, 1, 2, [0.1, 1, 10], None, (1, 2, 3)], # MISO
[2, 1, 2, [0.1, 1, 10], True, (2, 3)],
[3, 1, 2, [0.1, 1, 10], False, (1, 2, 3)],
[1, 1, 2, 0.1, None, (1, 2)],
[1, 1, 2, 0.1, True, (2,)],
[1, 1, 2, 0.1, False, (1, 2)],
[1, 2, 2, [0.1, 1, 10], None, (2, 2, 3)], # MIMO
[2, 2, 2, [0.1, 1, 10], True, (2, 2, 3)],
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)]
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)],
[1, 2, 2, 0.1, None, (2, 2)],
[2, 2, 2, 0.1, True, (2, 2)],
[3, 2, 2, 0.1, False, (2, 2)],
])
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
@pytest.mark.parametrize("omega_type", ["numpy", "native"])
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape,
omega_type):
"""Test correct behavior of frequencey response squeeze parameter."""
# Create the system to be tested
if fcn == ct.frd:
sys = fcn(ct.rss(nstate, nout, ninp), [1e-2, 1e-1, 1, 1e1, 1e2])
Expand All @@ -193,15 +202,23 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
else:
sys = fcn(ct.rss(nstate, nout, ninp))

# Convert the frequency list to an array for easy of use
isscalar = not hasattr(omega, '__len__')
omega = np.array(omega)
if omega_type == "numpy":
omega = np.asarray(omega)
isscalar = omega.ndim == 0
# keep the ndarray type even for scalars
s = np.asarray(omega * 1j)
else:
isscalar = not hasattr(omega, '__len__')
if isscalar:
s = omega*1J
else:
s = [w*1J for w in omega]

# Call the transfer function directly and make sure shape is correct
assert sys(omega * 1j, squeeze=squeeze).shape == shape
assert sys(s, squeeze=squeeze).shape == shape

# Make sure that evalfr also works as expected
assert ct.evalfr(sys, omega * 1j, squeeze=squeeze).shape == shape
assert ct.evalfr(sys, s, squeeze=squeeze).shape == shape

# Check frequency response
mag, phase, _ = sys.frequency_response(omega, squeeze=squeeze)
Expand All @@ -216,22 +233,22 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):

# Make sure the default shape lines up with squeeze=None case
if squeeze is None:
assert sys(omega * 1j).shape == shape
assert sys(s).shape == shape

# Changing config.default to False should return 3D frequency response
ct.config.set_defaults('control', squeeze_frequency_response=False)
mag, phase, _ = sys.frequency_response(omega)
if isscalar:
assert mag.shape == (sys.noutputs, sys.ninputs, 1)
assert phase.shape == (sys.noutputs, sys.ninputs, 1)
assert sys(omega * 1j).shape == (sys.noutputs, sys.ninputs)
assert ct.evalfr(sys, omega * 1j).shape == (sys.noutputs, sys.ninputs)
assert sys(s).shape == (sys.noutputs, sys.ninputs)
assert ct.evalfr(sys, s).shape == (sys.noutputs, sys.ninputs)
else:
assert mag.shape == (sys.noutputs, sys.ninputs, len(omega))
assert phase.shape == (sys.noutputs, sys.ninputs, len(omega))
assert sys(omega * 1j).shape == \
assert sys(s).shape == \
(sys.noutputs, sys.ninputs, len(omega))
assert ct.evalfr(sys, omega * 1j).shape == \
assert ct.evalfr(sys, s).shape == \
(sys.noutputs, sys.ninputs, len(omega))

@pytest.mark.parametrize("fcn", [ct.ss, ct.tf, ct.frd, ct.ss2io])
Expand All @@ -243,13 +260,17 @@ def test_squeeze_exceptions(self, fcn):

with pytest.raises(ValueError, match="unknown squeeze value"):
sys.frequency_response([1], squeeze=1)
sys([1], squeeze='siso')
evalfr(sys, [1], squeeze='siso')
with pytest.raises(ValueError, match="unknown squeeze value"):
sys([1j], squeeze='siso')
with pytest.raises(ValueError, match="unknown squeeze value"):
evalfr(sys, [1j], squeeze='siso')

with pytest.raises(ValueError, match="must be 1D"):
sys.frequency_response([[0.1, 1], [1, 10]])
sys([[0.1, 1], [1, 10]])
evalfr(sys, [[0.1, 1], [1, 10]])
with pytest.raises(ValueError, match="must be 1D"):
sys([[0.1j, 1j], [1j, 10j]])
with pytest.raises(ValueError, match="must be 1D"):
evalfr(sys, [[0.1j, 1j], [1j, 10j]])

with pytest.warns(DeprecationWarning, match="LTI `inputs`"):
ninputs = sys.inputs
Expand Down