Skip to content

Commit

Permalink
Merge pull request #138 from cle1109/drop_return_values
Browse files Browse the repository at this point in the history
Various methods in ooapi.py now return self
  • Loading branch information
cbrnr committed Feb 16, 2016
2 parents 4c0ff49 + 6302ff5 commit ba939d7
Showing 1 changed file with 54 additions and 11 deletions.
65 changes: 54 additions & 11 deletions scot/ooapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ def set_locations(self, locations):
----------
locations : array_like
3D Electrode locations. Each row holds the x, y, and z coordinates of an electrode.
Returns
-------
self : Workspace
The Workspace object.
"""
self.locations_ = locations
return self

def set_premixing(self, premixing):
""" Set premixing matrix.
Expand All @@ -149,8 +155,14 @@ def set_premixing(self, premixing):
----------
premixing : array_like, shape = [n_signals, n_channels]
Matrix that maps data signals to physical channels.
Returns
-------
self : Workspace
The Workspace object.
"""
self.premixing_ = premixing
return self

def set_data(self, data, cl=None, time_offset=0):
""" Assign data to the workspace.
Expand All @@ -166,6 +178,11 @@ def set_data(self, data, cl=None, time_offset=0):
Class labels associated with each trial.
time_offset : float, optional
Trial starting time; used for labelling the x-axis of time/frequency plots.
Returns
-------
self : Workspace
The Workspace object.
"""
self.data_ = atleast_3d(data)
self.cl_ = np.asarray(cl if cl is not None else [None]*self.data_.shape[0])
Expand All @@ -179,6 +196,8 @@ def set_data(self, data, cl=None, time_offset=0):
if self.unmixing_ is not None:
self.activations_ = dot_special(self.unmixing_.T, self.data_)

return self

def set_used_labels(self, labels):
""" Specify which trials to use in subsequent analysis steps.
Expand All @@ -188,11 +207,17 @@ def set_used_labels(self, labels):
----------
labels : list of class labels
Marks all trials that have a label that is in the `labels` list for further processing.
Returns
-------
self : Workspace
The Workspace object.
"""
mask = np.zeros(self.cl_.size, dtype=bool)
for l in labels:
mask = np.logical_or(mask, self.cl_ == l)
self.trial_mask_ = mask
return self

def do_mvarica(self, varfit='ensemble'):
""" Perform MVARICA
Expand All @@ -209,8 +234,8 @@ def do_mvarica(self, varfit='ensemble'):
Returns
-------
result : class
see :func:`mvarica` for a description of the return value.
self : Workspace
The Workspace object.
Raises
------
Expand All @@ -232,7 +257,7 @@ def do_mvarica(self, varfit='ensemble'):
self.activations_ = dot_special(self.unmixing_.T, self.data_)
self.mixmaps_ = []
self.unmixmaps_ = []
return result
return self

def do_cspvarica(self, varfit='ensemble'):
""" Perform CSPVARICA
Expand All @@ -249,8 +274,8 @@ def do_cspvarica(self, varfit='ensemble'):
Returns
-------
result : class
see :func:`cspvarica` for a description of the return value.
self : Workspace
The Workspace object.
Raises
------
Expand Down Expand Up @@ -278,7 +303,7 @@ def do_cspvarica(self, varfit='ensemble'):
self.activations_ = dot_special(self.unmixing_.T, self.data_)
self.mixmaps_ = []
self.unmixmaps_ = []
return result
return self

def do_ica(self):
""" Perform ICA
Expand All @@ -287,8 +312,8 @@ def do_ica(self):
Returns
-------
result : class
see :func:`plainica` for a description of the return value.
self : Workspace
The Workspace object.
Raises
------
Expand All @@ -306,7 +331,7 @@ def do_ica(self):
self.connectivity_ = None
self.mixmaps_ = []
self.unmixmaps_ = []
return result
return self

def remove_sources(self, sources):
""" Remove sources from the decomposition.
Expand All @@ -319,6 +344,11 @@ def remove_sources(self, sources):
sources : {slice, int, array of ints}
Indices of components to remove.
Returns
-------
self : Workspace
The Workspace object.
Raises
------
RuntimeError
Expand All @@ -335,9 +365,15 @@ def remove_sources(self, sources):
self.connectivity_ = None
self.mixmaps_ = []
self.unmixmaps_ = []
return self

def fit_var(self):
""" Fit a var model to the source activations.
""" Fit a VAR model to the source activations.
Returns
-------
self : Workspace
The Workspace object.
Raises
------
Expand All @@ -348,9 +384,15 @@ def fit_var(self):
raise RuntimeError("VAR fitting requires source activations (run do_mvarica first)")
self.var_.fit(data=self.activations_[self.trial_mask_, :, :])
self.connectivity_ = Connectivity(self.var_.coef, self.var_.rescov, self.nfft_)
return self

def optimize_var(self):
""" Optimize the var model's hyperparameters (such as regularization).
""" Optimize the VAR model's hyperparameters (such as regularization).
Returns
-------
self : Workspace
The Workspace object.
Raises
------
Expand All @@ -361,6 +403,7 @@ def optimize_var(self):
raise RuntimeError("VAR fitting requires source activations (run do_mvarica first)")

self.var_.optimize(self.activations_[self.trial_mask_, :, :])
return self

def get_connectivity(self, measure_name, plot=False):
""" Calculate spectral connectivity measure.
Expand Down

0 comments on commit ba939d7

Please sign in to comment.