Skip to content

Commit

Permalink
Merge pull request #595 from bnavigator/fix-594
Browse files Browse the repository at this point in the history
lti squeeze: ndarray.ndim == 0 is also a scalar
  • Loading branch information
bnavigator committed Apr 1, 2021
2 parents f1a9860 + 5646146 commit de87cc6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion control/lti.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _process_frequency_response(sys, omega, out, squeeze=None):
if squeeze is None:
squeeze = config.defaults['control.squeeze_frequency_response']

if not hasattr(omega, '__len__'):
if np.asarray(omega).ndim < 1:
# received a scalar x, squeeze down the array along last dim
out = np.squeeze(out, axis=2)

Expand Down
57 changes: 39 additions & 18 deletions control/tests/lti_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import control as ct
from control import c2d, tf, tf2ss, NonlinearIOSystem
from control.lti import (LTI, common_timebase, damp, dcgain, isctime, isdtime,
issiso, pole, timebaseEqual, zero)
from control.lti import (LTI, common_timebase, evalfr, damp, dcgain, isctime,
isdtime, issiso, pole, timebaseEqual, zero)
from control.tests.conftest import slycotonly
from control.exception import slycot_check

Expand Down Expand Up @@ -179,11 +179,20 @@ def test_isdtime(self, objfun, arg, dt, ref, strictref):
[1, 1, 2, [0.1, 1, 10], None, (1, 2, 3)], # MISO
[2, 1, 2, [0.1, 1, 10], True, (2, 3)],
[3, 1, 2, [0.1, 1, 10], False, (1, 2, 3)],
[1, 1, 2, 0.1, None, (1, 2)],
[1, 1, 2, 0.1, True, (2,)],
[1, 1, 2, 0.1, False, (1, 2)],
[1, 2, 2, [0.1, 1, 10], None, (2, 2, 3)], # MIMO
[2, 2, 2, [0.1, 1, 10], True, (2, 2, 3)],
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)]
[3, 2, 2, [0.1, 1, 10], False, (2, 2, 3)],
[1, 2, 2, 0.1, None, (2, 2)],
[2, 2, 2, 0.1, True, (2, 2)],
[3, 2, 2, 0.1, False, (2, 2)],
])
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
@pytest.mark.parametrize("omega_type", ["numpy", "native"])
def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape,
omega_type):
"""Test correct behavior of frequencey response squeeze parameter."""
# Create the system to be tested
if fcn == ct.frd:
sys = fcn(ct.rss(nstate, nout, ninp), [1e-2, 1e-1, 1, 1e1, 1e2])
Expand All @@ -193,15 +202,23 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):
else:
sys = fcn(ct.rss(nstate, nout, ninp))

# Convert the frequency list to an array for easy of use
isscalar = not hasattr(omega, '__len__')
omega = np.array(omega)
if omega_type == "numpy":
omega = np.asarray(omega)
isscalar = omega.ndim == 0
# keep the ndarray type even for scalars
s = np.asarray(omega * 1j)
else:
isscalar = not hasattr(omega, '__len__')
if isscalar:
s = omega*1J
else:
s = [w*1J for w in omega]

# Call the transfer function directly and make sure shape is correct
assert sys(omega * 1j, squeeze=squeeze).shape == shape
assert sys(s, squeeze=squeeze).shape == shape

# Make sure that evalfr also works as expected
assert ct.evalfr(sys, omega * 1j, squeeze=squeeze).shape == shape
assert ct.evalfr(sys, s, squeeze=squeeze).shape == shape

# Check frequency response
mag, phase, _ = sys.frequency_response(omega, squeeze=squeeze)
Expand All @@ -216,22 +233,22 @@ def test_squeeze(self, fcn, nstate, nout, ninp, omega, squeeze, shape):

# Make sure the default shape lines up with squeeze=None case
if squeeze is None:
assert sys(omega * 1j).shape == shape
assert sys(s).shape == shape

# Changing config.default to False should return 3D frequency response
ct.config.set_defaults('control', squeeze_frequency_response=False)
mag, phase, _ = sys.frequency_response(omega)
if isscalar:
assert mag.shape == (sys.noutputs, sys.ninputs, 1)
assert phase.shape == (sys.noutputs, sys.ninputs, 1)
assert sys(omega * 1j).shape == (sys.noutputs, sys.ninputs)
assert ct.evalfr(sys, omega * 1j).shape == (sys.noutputs, sys.ninputs)
assert sys(s).shape == (sys.noutputs, sys.ninputs)
assert ct.evalfr(sys, s).shape == (sys.noutputs, sys.ninputs)
else:
assert mag.shape == (sys.noutputs, sys.ninputs, len(omega))
assert phase.shape == (sys.noutputs, sys.ninputs, len(omega))
assert sys(omega * 1j).shape == \
assert sys(s).shape == \
(sys.noutputs, sys.ninputs, len(omega))
assert ct.evalfr(sys, omega * 1j).shape == \
assert ct.evalfr(sys, s).shape == \
(sys.noutputs, sys.ninputs, len(omega))

@pytest.mark.parametrize("fcn", [ct.ss, ct.tf, ct.frd, ct.ss2io])
Expand All @@ -243,13 +260,17 @@ def test_squeeze_exceptions(self, fcn):

with pytest.raises(ValueError, match="unknown squeeze value"):
sys.frequency_response([1], squeeze=1)
sys([1], squeeze='siso')
evalfr(sys, [1], squeeze='siso')
with pytest.raises(ValueError, match="unknown squeeze value"):
sys([1j], squeeze='siso')
with pytest.raises(ValueError, match="unknown squeeze value"):
evalfr(sys, [1j], squeeze='siso')

with pytest.raises(ValueError, match="must be 1D"):
sys.frequency_response([[0.1, 1], [1, 10]])
sys([[0.1, 1], [1, 10]])
evalfr(sys, [[0.1, 1], [1, 10]])
with pytest.raises(ValueError, match="must be 1D"):
sys([[0.1j, 1j], [1j, 10j]])
with pytest.raises(ValueError, match="must be 1D"):
evalfr(sys, [[0.1j, 1j], [1j, 10j]])

with pytest.warns(DeprecationWarning, match="LTI `inputs`"):
ninputs = sys.inputs
Expand Down

0 comments on commit de87cc6

Please sign in to comment.