Skip to content

Commit

Permalink
implemented seeding, passes current test cases but not new ones yet
Browse files Browse the repository at this point in the history
  • Loading branch information
pseeth committed Jan 30, 2020
1 parent 43ff378 commit 622b969
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 28 deletions.
54 changes: 37 additions & 17 deletions scaper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@
from .util import _get_sorted_files
from .util import _validate_folder_path
from .util import _populate_label_list
from .util import _check_random_state
from .util import _trunc_norm
from .util import _uniform
from .util import _choose
from .util import _normal
from .util import _const
from .util import max_polyphony
from .util import polyphony_gini
from .util import is_real_number, is_real_array
from .audio import get_integrated_lufs
from .version import version as scaper_version

# TODO: for seeding, turn these into more complex functions in util?
SUPPORTED_DIST = {"const": lambda x: x,
"choose": lambda x: random.choice(x),
"uniform": random.uniform,
"normal": random.normalvariate,
SUPPORTED_DIST = {"const": _const,
"choose": _choose,
"uniform": _uniform,
"normal": _normal,
"truncnorm": _trunc_norm}

# Define single event spec as namedtuple
Expand Down Expand Up @@ -246,8 +251,7 @@ def trim(audio_infile, jams_infile, audio_outfile, jams_outfile, start_time,
# Copy result back to original file
shutil.copyfile(tmpfiles[-1].name, audio_outfile)

# TODO: this should take a np.RandomSeed object (default to None)
def _get_value_from_dist(dist_tuple):
def _get_value_from_dist(dist_tuple, random_state):
'''
Sample a value from the provided distribution tuple.
Expand Down Expand Up @@ -275,7 +279,7 @@ def _get_value_from_dist(dist_tuple):
'''
# Make sure it's a valid distribution tuple
_validate_distribution(dist_tuple)
return SUPPORTED_DIST[dist_tuple[0]](*dist_tuple[1:])
return SUPPORTED_DIST[dist_tuple[0]](*dist_tuple[1:], random_state=random_state)


def _validate_distribution(dist_tuple):
Expand Down Expand Up @@ -813,6 +817,7 @@ def _validate_event(label, source_file, source_time, event_time,
# Time stretch
_validate_time_stretch(time_stretch)


# TODO: add a seed parameter in init that defaults to None
class Scaper(object):
'''
Expand Down Expand Up @@ -841,7 +846,7 @@ class Scaper(object):
'''

def __init__(self, duration, fg_path, bg_path, protected_labels=[]):
def __init__(self, duration, fg_path, bg_path, protected_labels=[], random_state=None):
'''
Create a Scaper object.
Expand All @@ -865,6 +870,11 @@ def __init__(self, duration, fg_path, bg_path, protected_labels=[]):
whose semantic validity would be lost if the sound were trimmed
before the sound event ends, for example an animal vocalization
such as a dog bark.
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number
generator; If RandomState instance, random_state is the random number
generator; If None, the random number generator is the RandomState
instance used by np.random.
'''
# Duration must be a positive real number
Expand Down Expand Up @@ -901,6 +911,9 @@ def __init__(self, duration, fg_path, bg_path, protected_labels=[]):
# Copy list of protected labels
self.protected_labels = protected_labels[:]

# Get random number generator
self.random_state = _check_random_state(random_state)

def add_background(self, label, source_file, source_time):
'''
Add a background recording to the background specification.
Expand Down Expand Up @@ -1178,7 +1191,7 @@ def _instantiate_event(self, event, isbackground=False,
label_tuple = tuple(label_tuple)
else:
label_tuple = event.label
label = _get_value_from_dist(label_tuple)
label = _get_value_from_dist(label_tuple, self.random_state)

# Make sure we can use this label
if (not allow_repeated_label) and (label in used_labels):
Expand All @@ -1190,7 +1203,7 @@ def _instantiate_event(self, event, isbackground=False,
"allow_repeated_label=False.".format(label))
else:
while label in used_labels:
label = _get_value_from_dist(label_tuple)
label = _get_value_from_dist(label_tuple, self.random_state)

# Update the used labels list
if label not in used_labels:
Expand All @@ -1206,7 +1219,7 @@ def _instantiate_event(self, event, isbackground=False,
else:
source_file_tuple = event.source_file

source_file = _get_value_from_dist(source_file_tuple)
source_file = _get_value_from_dist(source_file_tuple, self.random_state)

# Make sure we can use this source file
if (not allow_repeated_source) and (source_file in used_source_files):
Expand All @@ -1219,7 +1232,7 @@ def _instantiate_event(self, event, isbackground=False,
"allow_repeated_source=False.".format(label))
else:
while source_file in used_source_files:
source_file = _get_value_from_dist(source_file_tuple)
source_file = _get_value_from_dist(source_file_tuple, self.random_state)

# Update the used source files list
if source_file not in used_source_files:
Expand All @@ -1239,7 +1252,9 @@ def _instantiate_event(self, event, isbackground=False,
# potentially be non-positive, hence the loop.
event_duration = -np.Inf
while event_duration <= 0:
event_duration = _get_value_from_dist(event.event_duration)
event_duration = _get_value_from_dist(
event.event_duration, self.random_state
)

# Check if chosen event duration is longer than the duration of the
# selected source file, if so adjust the event duration.
Expand All @@ -1260,7 +1275,9 @@ def _instantiate_event(self, event, isbackground=False,
else:
time_stretch = -np.Inf
while time_stretch <= 0:
time_stretch = _get_value_from_dist(event.time_stretch)
time_stretch = _get_value_from_dist(
event.time_stretch, self.random_state
)
# compute duration after stretching
event_duration_stretched = event_duration * time_stretch

Expand Down Expand Up @@ -1343,7 +1360,9 @@ def _instantiate_event(self, event, isbackground=False,
# foreground events it's not.
event_time = -np.Inf
while event_time < 0:
event_time = _get_value_from_dist(event.event_time)
event_time = _get_value_from_dist(
event.event_time, self.random_state
)

# Make sure the selected event time + event duration is not greater
# than the total duration of the soundscape, if it is adjust the event
Expand Down Expand Up @@ -1375,15 +1394,15 @@ def _instantiate_event(self, event, isbackground=False,
ScaperWarning)

# determine snr
snr = _get_value_from_dist(event.snr)
snr = _get_value_from_dist(event.snr, self.random_state)

# get role (which can only take "foreground" or "background" and
# is set internally, not by the user).
role = event.role

# determine pitch_shift
if event.pitch_shift is not None:
pitch_shift = _get_value_from_dist(event.pitch_shift)
pitch_shift = _get_value_from_dist(event.pitch_shift, self.random_state)
else:
pitch_shift = None

Expand All @@ -1397,6 +1416,7 @@ def _instantiate_event(self, event, isbackground=False,
role=role,
pitch_shift=pitch_shift,
time_stretch=time_stretch)

# Return
return instantiated_event

Expand Down
121 changes: 118 additions & 3 deletions scaper/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .scaper_exceptions import ScaperError
import scipy
import numpy as np
import numbers
import random


@contextmanager
Expand Down Expand Up @@ -148,8 +150,119 @@ def _populate_label_list(folder_path, label_list):
# ensure consistent ordering of labels
label_list.sort()

# TODO: add a seed parameter here for seeding
def _trunc_norm(mu, sigma, trunc_min, trunc_max):

def _check_random_state(seed):
    """
    Turn ``seed`` into a ``np.random.RandomState`` instance.

    Parameters
    ----------
    seed : None | int | instance of RandomState
        If seed is None (or the ``np.random`` module itself), return the
        RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it unchanged.

    Returns
    -------
    random_state : np.random.RandomState
        The random number generator corresponding to ``seed``.

    Raises
    ------
    ValueError
        If ``seed`` is none of the supported types.
    """
    if seed is None or seed is np.random:
        # Fall back on numpy's global generator.
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Wrap seed in a 1-tuple: a bare tuple on the right of ``%`` would be
    # unpacked as the format arguments, raising TypeError (or printing the
    # wrong value) instead of the intended ValueError.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))


def _const(item, random_state):
    '''
    Sample from a constant "distribution": always yields ``item`` itself.

    Parameters
    ----------
    item : any
        The value to hand back.
    random_state : mtrand.RandomState
        Unused. Accepted only so that this function shares the calling
        convention of the other distribution samplers.

    Returns
    -------
    value : any
        The very same ``item`` that was passed in.
    '''
    # No randomness is involved, so random_state is deliberately ignored.
    value = item
    return value


def _uniform(minimum, maximum, random_state):
    '''
    Draw one sample from the uniform distribution spanning ``minimum`` to
    ``maximum``.

    Parameters
    ----------
    minimum : float
        Lower bound of the uniform distribution.
    maximum : float
        Upper bound of the uniform distribution.
    random_state : mtrand.RandomState
        Generator the sample is drawn from.

    Returns
    -------
    value : float
        A single draw from the uniform distribution defined by
        ``minimum`` and ``maximum``.
    '''
    sample = random_state.uniform(low=minimum, high=maximum)
    return sample


def _normal(mu, sigma, random_state):
    '''
    Draw one sample from a normal distribution with mean ``mu`` and
    standard deviation ``sigma``.

    Parameters
    ----------
    mu : float
        The mean of the normal distribution.
    sigma : float
        The standard deviation of the normal distribution.
    random_state : mtrand.RandomState
        Generator the sample is drawn from.

    Returns
    -------
    value : float
        A single draw from the normal distribution defined by
        ``mu`` and ``sigma``.
    '''
    sample = random_state.normal(loc=mu, scale=sigma)
    return sample


def _choose(list_of_options, random_state):
    '''
    Pick one element of ``list_of_options`` at random using ``random_state``.

    Parameters
    ----------
    list_of_options : list
        The candidates to pick from.
    random_state : mtrand.RandomState
        Generator used to draw the index of the picked element.

    Returns
    -------
    value : any
        The randomly selected element of ``list_of_options``.
    '''
    # Draw an index in [0, len) and index the list directly, so the element
    # is returned as-is (no conversion to a numpy type).
    num_options = len(list_of_options)
    return list_of_options[random_state.randint(0, num_options)]


def _trunc_norm(mu, sigma, trunc_min, trunc_max, random_state):
'''
Return a random value sampled from a truncated normal distribution with
mean ```mu``` and standard deviation ```sigma``` whose values are limited
Expand All @@ -165,6 +278,8 @@ def _trunc_norm(mu, sigma, trunc_min, trunc_max):
The minimum value allowed for the distribution (lower boundary)
trunc_max : float
The maximum value allowed for the distribution (upper boundary)
random_state : mtrand.RandomState
RandomState object used to sample from this distribution.
Returns
-------
Expand All @@ -178,7 +293,7 @@ def _trunc_norm(mu, sigma, trunc_min, trunc_max):
# values for a standard normal distribution (mu=0, sigma=1), so we need
# to recompute a and b given the user specified parameters.
a, b = (trunc_min - mu) / float(sigma), (trunc_max - mu) / float(sigma)
return scipy.stats.truncnorm.rvs(a, b, mu, sigma)
return scipy.stats.truncnorm.rvs(a, b, mu, sigma, random_state=random_state)


def max_polyphony(ann):
Expand Down
15 changes: 8 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,40 +358,41 @@ def test_trim(atol=1e-5, rtol=1e-8):


def test_get_value_from_dist():

rng = scaper.util._check_random_state(0)
# const
x = scaper.core._get_value_from_dist(('const', 1))
x = scaper.core._get_value_from_dist(('const', 1), rng)
assert x == 1

# choose
for _ in range(10):
x = scaper.core._get_value_from_dist(('choose', [1, 2, 3]))
x = scaper.core._get_value_from_dist(('choose', [1, 2, 3]), rng)
assert x in [1, 2, 3]

# uniform
for _ in range(10):
x = scaper.core._get_value_from_dist(('choose', [1, 2, 3]))
x = scaper.core._get_value_from_dist(('choose', [1, 2, 3]), rng)
assert x in [1, 2, 3]

# normal
for _ in range(10):
x = scaper.core._get_value_from_dist(('normal', 5, 1))
x = scaper.core._get_value_from_dist(('normal', 5, 1), rng)
assert scaper.util.is_real_number(x)

# truncnorm
for _ in range(10):
x = scaper.core._get_value_from_dist(('truncnorm', 5, 10, 0, 10))
x = scaper.core._get_value_from_dist(('truncnorm', 5, 10, 0, 10), rng)
assert scaper.util.is_real_number(x)
assert 0 <= x <= 10

# COPY TESTS FROM test_validate_distribution (to ensure validation applied)
def __test_bad_tuple_list(tuple_list):
rng = scaper.util._check_random_state(0)
for t in tuple_list:
if isinstance(t, tuple):
print(t, len(t))
else:
print(t)
pytest.raises(ScaperError, scaper.core._get_value_from_dist, t)
pytest.raises(ScaperError, scaper.core._get_value_from_dist, t, random_state=rng)

# not tuple = error
nontuples = [[], 5, 'yes']
Expand Down
Loading

0 comments on commit 622b969

Please sign in to comment.