
[ENH] Add temporal torchmodel #626

Open · wants to merge 10 commits into base: torch

Conversation

@tallyhawley (Member) commented Jun 11, 2024

Description

Adds torch backend for FadingTemporal, TorchFadingTemporal.

Refactors the TorchBaseModel forward signature to add a state kwarg for temporal models and to make the e_locs argument optional in the parent. Subclasses can choose which kwargs to use in forward.
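The refactored signature can be sketched roughly as follows. This is a minimal, torch-free illustration: the class name and the e_locs/state kwargs mirror the PR description, but the subclass names and method bodies are hypothetical, not the actual pulse2percept code.

```python
class TorchBaseModel:
    """Sketch of the refactored forward signature (hypothetical bodies)."""

    def forward(self, stim, e_locs=None, state=None):
        # The parent accepts both optional kwargs; subclasses pick
        # whichever ones they actually need.
        raise NotImplementedError


class SpatialSubclass(TorchBaseModel):
    def forward(self, stim, e_locs=None, state=None):
        # A spatial model uses e_locs and ignores state.
        return [s * len(e_locs) for s in stim]


class TemporalSubclass(TorchBaseModel):
    def forward(self, stim, e_locs=None, state=None):
        # A temporal model threads state between calls and ignores e_locs.
        state = 0.0 if state is None else state
        out = []
        for s in stim:
            state = 0.5 * state + s  # toy fading update, not the real model
            out.append(state)
        return out, state
```

The point of the shared signature is that the engine can call forward uniformly, passing the previous state back in for temporal models while spatial models simply ignore it.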

Type of Change


  • New feature (non-breaking change which adds functionality)

Checklist

  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

For detailed information on these and other aspects, see the contribution guidelines: https://pulse2percept.readthedocs.io/en/latest/developers/contributing.html

TODO: add offline fading, implement engine & torch model call from FadingTemporal
@tallyhawley tallyhawley changed the base branch from master to torch June 11, 2024 17:53
codecov bot commented Jun 22, 2024

Codecov Report

Attention: Patch coverage is 50.00000% with 25 lines in your changes missing coverage. Please review.

Please upload report for BASE (torch@2cb0772). Learn more about missing BASE report.

Files                             Patch %   Lines
pulse2percept/models/temporal.py  43.18%    25 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##             torch     #626   +/-   ##
========================================
  Coverage         ?   91.19%           
========================================
  Files            ?      128           
  Lines            ?    11946           
  Branches         ?        0           
========================================
  Hits             ?    10894           
  Misses           ?     1052           
  Partials         ?        0           
Flag       Coverage Δ
unittests  91.19% <50.00%> (?)

Flags with carried forward coverage won't be shown.


@tallyhawley tallyhawley marked this pull request as ready for review June 22, 2024 02:29
@jgranley (Member) commented:
I'm a little confused why this is failing right now. There's an easy fix for the NumPy array not being C-contiguous (just copy it with order='C' in predict_percept), but the test where this is failing shouldn't even be running. We pass --benchmark-disable to pytest (pyproject.toml), and it is skipping most of the other benchmarks, so I'm not sure why it's still running this one. Eventually I'll change the workflow to also include benchmarking, but maybe for now we can look into why this test is even being run?
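The contiguity fix mentioned in the comment can be illustrated generically. This is a standalone NumPy sketch, not the actual predict_percept code: a transposed array is a common way to end up with a non-C-contiguous buffer, which some C/torch interop paths reject.

```python
import numpy as np

# Transposing a 2-D array yields a view that is not C-contiguous.
arr = np.arange(12).reshape(3, 4).T
assert not arr.flags['C_CONTIGUOUS']

# Copying with order='C' produces a C-contiguous array with the same values,
# which is the kind of fix suggested for predict_percept.
fixed = arr.copy(order='C')
assert fixed.flags['C_CONTIGUOUS']
assert np.array_equal(fixed, arr)
```

np.ascontiguousarray(arr) is an equivalent one-liner that additionally avoids the copy when the input is already contiguous.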
