Skip to content

Commit

Permalink
Merge pull request #589 from bnavigator/drss-dt
Browse files Browse the repository at this point in the history
Return a discrete time system with drss()
  • Loading branch information
bnavigator committed Mar 25, 2021
2 parents ce4c2b6 + e50ce23 commit 8b900ca
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
36 changes: 22 additions & 14 deletions control/statesp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,11 +1434,11 @@ def _convert_to_statespace(sys, **kw):


# TODO: add discrete time option
def _rss_generate(states, inputs, outputs, type, strictly_proper=False):
def _rss_generate(states, inputs, outputs, cdtype, strictly_proper=False):
"""Generate a random state space.
This does the actual random state space generation expected from rss and
drss. type is 'c' for continuous systems and 'd' for discrete systems.
drss. cdtype is 'c' for continuous systems and 'd' for discrete systems.
"""

Expand All @@ -1465,6 +1465,8 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False):
if outputs < 1 or outputs % 1:
raise ValueError("outputs must be a positive integer. outputs = %g." %
outputs)
if cdtype not in ['c', 'd']:
raise ValueError("cdtype must be `c` or `d`")

# Make some poles for A. Preallocate a complex array.
poles = zeros(states) + zeros(states) * 0.j
Expand All @@ -1484,16 +1486,16 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False):
i += 2
elif rand() < pReal or i == states - 1:
# No-oscillation pole.
if type == 'c':
if cdtype == 'c':
poles[i] = -exp(randn()) + 0.j
elif type == 'd':
else:
poles[i] = 2. * rand() - 1.
i += 1
else:
# Complex conjugate pair of oscillating poles.
if type == 'c':
if cdtype == 'c':
poles[i] = complex(-exp(randn()), 3. * exp(randn()))
elif type == 'd':
else:
mag = rand()
phase = 2. * math.pi * rand()
poles[i] = complex(mag * cos(phase), mag * sin(phase))
Expand Down Expand Up @@ -1546,7 +1548,11 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False):
C = C * Cmask
D = D * Dmask if not strictly_proper else zeros(D.shape)

return StateSpace(A, B, C, D)
if cdtype == 'c':
ss_args = (A, B, C, D)
else:
ss_args = (A, B, C, D, True)
return StateSpace(*ss_args)


# Convert a MIMO system to a SISO system
Expand Down Expand Up @@ -1825,15 +1831,14 @@ def rss(states=1, outputs=1, inputs=1, strictly_proper=False):
Parameters
----------
states : integer
states : int
Number of state variables
inputs : integer
inputs : int
Number of system inputs
outputs : integer
outputs : int
Number of system outputs
strictly_proper : bool, optional
If set to 'True', returns a proper system (no direct term). Default
value is 'False'.
If set to 'True', returns a proper system (no direct term).
Returns
-------
Expand Down Expand Up @@ -1867,12 +1872,15 @@ def drss(states=1, outputs=1, inputs=1, strictly_proper=False):
Parameters
----------
states : integer
states : int
Number of state variables
inputs : integer
Number of system inputs
outputs : integer
outputs : int
Number of system outputs
strictly_proper: bool, optional
If set to 'True', returns a proper system (no direct term).
Returns
-------
Expand Down
36 changes: 35 additions & 1 deletion control/tests/statesp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from control.dtime import sample_system
from control.lti import evalfr
from control.statesp import (StateSpace, _convert_to_statespace, drss,
rss, ss, tf2ss, _statesp_defaults)
rss, ss, tf2ss, _statesp_defaults, _rss_generate)
from control.tests.conftest import ismatarrayout, slycotonly
from control.xferfcn import TransferFunction, ss2tf

Expand Down Expand Up @@ -855,6 +855,28 @@ def test_pole(self, states, outputs, inputs):
for z in p:
assert z.real < 0

@pytest.mark.parametrize('strictly_proper', [True, False])
def test_strictly_proper(self, strictly_proper):
"""Test that the strictly_proper argument returns a correct D."""
for i in range(100):
# The probability that drss(..., strictly_proper=False) returns an
# all zero D 100 times in a row is 0.5**100 = 7.89e-31
sys = rss(1, 1, 1, strictly_proper=strictly_proper)
if np.all(sys.D == 0.) == strictly_proper:
break
assert np.all(sys.D == 0.) == strictly_proper

@pytest.mark.parametrize('par, errmatch',
[((-1, 1, 1, 'c'), 'states must be'),
((1, -1, 1, 'c'), 'inputs must be'),
((1, 1, -1, 'c'), 'outputs must be'),
((1, 1, 1, 'x'), 'cdtype must be'),
])
def test_rss_invalid(self, par, errmatch):
"""Test invalid inputs for rss() and drss()."""
with pytest.raises(ValueError, match=errmatch):
_rss_generate(*par)


class TestDrss:
"""These are tests for the proper functionality of statesp.drss."""
Expand All @@ -873,6 +895,7 @@ def test_shape(self, states, outputs, inputs):
assert sys.nstates == states
assert sys.ninputs == inputs
assert sys.noutputs == outputs
assert sys.dt is True

@pytest.mark.parametrize('states', range(1, maxStates))
@pytest.mark.parametrize('outputs', range(1, maxIO))
Expand All @@ -884,6 +907,17 @@ def test_pole(self, states, outputs, inputs):
for z in p:
assert abs(z) < 1

@pytest.mark.parametrize('strictly_proper', [True, False])
def test_strictly_proper(self, strictly_proper):
"""Test that the strictly_proper argument returns a correct D."""
for i in range(100):
# The probability that drss(..., strictly_proper=False) returns an
# all zero D 100 times in a row is 0.5**100 = 7.89e-31
sys = drss(1, 1, 1, strictly_proper=strictly_proper)
if np.all(sys.D == 0.) == strictly_proper:
break
assert np.all(sys.D == 0.) == strictly_proper


class TestLTIConverter:
"""Test returnScipySignalLTI method"""
Expand Down

0 comments on commit 8b900ca

Please sign in to comment.