Skip to content

Commit

Permalink
Merge pull request #7 from berndie/train_and_predict
Browse files Browse the repository at this point in the history
Fix for the final training of the train method for backwards models
  • Loading branch information
OleBialas committed May 24, 2023
2 parents a3d030c + 266bb01 commit 2fb82c2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mtrf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def train(
verbose=verbose,
)
best_regularization = list(regularization)[np.argmin(mse)]
self._train(stimulus, response, fs, tmin, tmax, best_regularization)
self._train(xs, ys, fs, tmin, tmax, best_regularization)
return r, mse

def _train(self, xs, ys, fs, tmin, tmax, regularization):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def test_predict():
prediction, r, mse = trf.predict(stimulus, response, average=False)
assert r.shape[-1] == mse.shape[-1] == trf.weights.shape[-1]

# Backwards prediction
trf = TRF(-1)
trf.train(stimulus, response, fs, tmin, tmax, regularization)
for average in [True, list(range(randint(stimulus[0].shape[-1])))]:
prediction, r, mse = trf.predict(stimulus, response, average=average)
assert len(prediction) == len(stimulus)
assert all([p[0].shape == s[0].shape for p, s in zip(prediction, stimulus)])
assert np.isscalar(r) and np.isscalar(mse)
prediction, r, mse = trf.predict(stimulus, response, average=False)
assert r.shape[-1] == mse.shape[-1] == trf.weights.shape[-1]


def test_test():
tmin = np.random.uniform(-0.1, 0.05)
Expand Down

0 comments on commit 2fb82c2

Please sign in to comment.