Testing in a higher frequency #3

Closed
Hatins opened this issue Apr 25, 2024 · 11 comments

@Hatins

Hatins commented Apr 25, 2024

Hi @NikolaZubic

Thanks for your nice work and for open-sourcing it!
I have a question about testing at a higher frequency: which parameters should we change?

I have found two places where 'step_scale' appears:

```python
class S5SSM(torch.nn.Module):
    def __init__(
        self,
        lambdaInit: torch.Tensor,
        V: torch.Tensor,
        Vinv: torch.Tensor,
        h: int,
        p: int,
        dt_min: float,
        dt_max: float,
        liquid: bool = False,
        factor_rank: Optional[int] = None,
        discretization: Literal["zoh", "bilinear"] = "bilinear",
        bcInit: Initialization = "factorized",
        degree: int = 1,
        bidir: bool = False,
        step_scale: float = 1.0,
        bandlimit: Optional[float] = None,
    ):
        ...  # constructor body not shown in this excerpt
```

and

```python
def forward(self, signal, prev_state, step_scale: float | torch.Tensor = 1.0):
    ...  # method body not shown in this excerpt
```

in the S5SSM module.

If I want to test at 200 Hz (after training at 20 Hz), how should I adjust these parameters?
Should I change both of the above step_scale values from 1.0 to 0.1?
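For context, here is a minimal sketch of what a step scale does to a single diagonal S5 mode under the bilinear discretization (the constructor's default). It assumes, like `step_rescale` in the reference S5 code, that the factor simply multiplies the learned timestep Δ. The helpers `discretize_bilinear` and `step_scale_for` are hypothetical, and how the constructor and `forward` arguments combine in this repository is not shown here.

```python
def discretize_bilinear(lam: complex, b: complex, delta: float, step_scale: float = 1.0):
    """Bilinear (Tustin) discretization of one diagonal mode x' = lam*x + b*u.

    Assumption: step_scale multiplies the learned timestep delta.
    """
    d = delta * step_scale
    denom = 1.0 - 0.5 * d * lam
    lam_bar = (1.0 + 0.5 * d * lam) / denom  # discrete-time eigenvalue
    b_bar = (d / denom) * b                  # discrete-time input weight
    return lam_bar, b_bar


def step_scale_for(train_hz: float, test_hz: float) -> float:
    """Higher input rate means shorter steps: 20 Hz -> 200 Hz gives 0.1."""
    return train_hz / test_hz


lam_bar_train, _ = discretize_bilinear(-1 + 2j, 1 + 0j, delta=0.5)
lam_bar_test, _ = discretize_bilinear(-1 + 2j, 1 + 0j, delta=0.5,
                                      step_scale=step_scale_for(20, 200))
```

With training at 20 Hz and testing at 200 Hz, `step_scale_for(20, 200)` gives 0.1, and the smaller step moves `lam_bar` closer to 1, so each step decays the state less.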

@Hatins
Author

Hatins commented Apr 26, 2024

I tried this, but it does not seem to work /(ㄒoㄒ)/~~

@looper99

@Hatins I also had problems with that and it didn't work; then I realized I hadn't changed the frequency of the event data, which means re-processing it. Did you do that?

@NikolaZubic @magehrig Can Mamba (Gu & Dao) also be used here to evaluate at different frequencies, and would it give even better results than S5? By the way, thanks for sharing S5 in PyTorch as a zip; it is really valuable.

@Hatins
Author

Hatins commented Apr 26, 2024

Hi @looper99
I trained the model on event data sampled at 20 Hz, and for evaluation I recreated the event frames at 200 Hz, so training used the 20 Hz event dataset and testing used the 200 Hz one. Additionally, when testing with the high-frequency event data, I changed the time_scale parameter from 1 to 10 (I also tried 0.1). Despite these adjustments, the results remained suboptimal, which is why I opened this issue.
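Recreating the frames at 200 Hz means slicing the same event stream into windows ten times shorter (5 ms instead of 50 ms). A minimal sketch, assuming timestamps in microseconds, with a hypothetical `bin_events` helper that only counts events per window (a real pipeline would build a histogram or voxel grid per window instead):

```python
from collections import Counter
from typing import List, Sequence


def bin_events(timestamps_us: Sequence[int], freq_hz: float) -> List[int]:
    """Count events per window of length 1/freq_hz seconds (hypothetical helper)."""
    window_us = 1_000_000 / freq_hz  # 20 Hz -> 50 ms windows, 200 Hz -> 5 ms windows
    counts = Counter(int(t // window_us) for t in timestamps_us)
    n_windows = int(max(timestamps_us) // window_us) + 1
    return [counts.get(i, 0) for i in range(n_windows)]


# Hypothetical event timestamps in microseconds.
ts = [0, 4_999, 5_000, 12_000, 49_999, 50_000, 99_999]
coarse = bin_events(ts, 20)  # 2 windows of 50 ms
fine = bin_events(ts, 200)   # 20 windows of 5 ms
```

Every 20 Hz window is covered exactly by ten 200 Hz windows, so the per-window event counts drop by roughly 10x. That shift in input statistics plausibly matters even when the step is rescaled.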

Actually, I haven't quite grasped your question. Are you suggesting evaluating high-frequency event data without recreating it?

@fengwei0907

I think what he means is: can your frequency generalization technique also be applied to the S6 (selective SSM) model, also known as Mamba, and does that model perform well at high frequencies?

@fengwei0907

Have you ever encountered this problem? When I change `training: precision` from 32 to 16, an error is reported. Do you know how to solve it? Thanks.

@Hatins
Author

Hatins commented Apr 27, 2024

> Have you ever encountered this problem? When I change `training: precision` from 32 to 16, an error is reported. Do you know how to solve it? Thanks.

Yes, the S5 model seems to need higher precision, i.e. 32-bit.
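One plausible reason SSMs are sensitive to half precision (not necessarily the cause of the specific error above) is that per-step decay factors sit very close to 1, especially with small steps. A stdlib-only sketch that rounds a hypothetical decay factor exp(-1e-4) to float16 and float32 using `struct`:

```python
import math
import struct


def round_trip(x: float, fmt: str) -> float:
    """Round x to float16 ('e') or float32 ('f') and back to a Python float."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


# Hypothetical per-step decay of a slow SSM mode with a small step.
decay = math.exp(-1e-4)
print(round_trip(decay, "f") < 1.0)  # True: float32 keeps the decay
print(round_trip(decay, "e"))        # 1.0: float16 rounds it away
```

In float16 this mode stops decaying entirely, so it never forgets; in float32 it keeps its timescale. Shrinking the step (e.g. with a smaller step_scale) pushes the factor even closer to 1 and makes the problem worse.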

@looper99

I did it in the same way as in the pendulum task outlined in the following link: GitHub - lindermanlab/S5 - Issue #13. It worked just fine for event data.

I experimented and found that there is a way to train with precision 16, but you have to wrap the combine function in jax_func.py in torch.cuda.amp.autocast(enabled=False):
GitHub - uzh-rpg/ssms_event_cameras - jax_func.py line 68.

However, when I train this way, I can't reach the mAP reported in the paper, which I do get with precision 32. I can only train on GEN1 since I have a single A100.
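The combine step matters for precision because the scan multiplies many decay factors together. Below is a minimal pure-Python sketch of the binary operator used by S5's parallel scan, where each element (a, b) stands for the affine map x -> a*x + b. It assumes the `combine` in jax_func.py applies an operator of this form; the names `combine` and `inclusive_scan` here are illustrative, not the repository's code.

```python
from typing import List, Tuple

Elem = Tuple[complex, complex]  # (a, b) represents the affine map x -> a*x + b


def combine(left: Elem, right: Elem) -> Elem:
    """S5-style binary operator: apply `left` first, then `right`."""
    a_i, b_i = left
    a_j, b_j = right
    return a_j * a_i, a_j * b_i + b_j


def inclusive_scan(elems: List[Elem]) -> List[Elem]:
    """Hillis-Steele scan: log2(n) rounds of mutually independent combines."""
    out = list(elems)
    offset = 1
    while offset < len(out):
        out = [out[k] if k < offset else combine(out[k - offset], out[k])
               for k in range(len(out))]
        offset *= 2
    return out


lam_bar = 0.9 + 0.1j                         # hypothetical discrete eigenvalue
bu = [1 + 0j, 0.5 - 0.2j, -0.3 + 0j, 2j, 0.7 + 0j]  # hypothetical B_bar * u_k
scanned = inclusive_scan([(lam_bar, v) for v in bu])

# Sequential recurrence x_k = lam_bar * x_(k-1) + B_bar * u_k for comparison.
x, sequential = 0j, []
for v in bu:
    x = lam_bar * x + v
    sequential.append(x)
```

The b component of each scanned element equals the sequentially computed state. In a mixed-precision run, the `a_j * a_i` products are where reduced precision would accumulate error, which is presumably what disabling autocast around this step protects, provided the tensors entering it are float32.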

@Hatins
Author

Hatins commented May 1, 2024

> I did it in the same way as in the pendulum task outlined in the following link: GitHub - lindermanlab/S5 - Issue #13. It worked just fine for event data.
>
> I experimented and found that there is a way to train with precision 16, but you have to wrap the combine function in jax_func.py in torch.cuda.amp.autocast(enabled=False): GitHub - uzh-rpg/ssms_event_cameras - jax_func.py line 68.
>
> However, when I train this way, I can't reach the mAP reported in the paper, which I do get with precision 32. I can only train on GEN1 since I have a single A100.

Hi @looper99
Do you mean you have successfully tested the S5 model at a higher input frequency? If so, could you tell me how you modified the original model in this project? I still cannot get it to work. Thanks.

@NikolaZubic
Member

NikolaZubic commented May 1, 2024

@Hatins :
Regarding the experiments with different frequencies.
Before building the 4-stage hierarchical backbone and converting the raw event stream into event representations, we had an S5 block that works directly on raw events. The model was trained like that and needed a lot of resources; unfortunately, there is still no smart solution for this. This SSM block was then tuned for different frequencies, for example step_scale=0.1 for 200 Hz. This means you need to process the raw events directly with SSMs, using a different step_scale, before converting them into event representations.
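Why step_scale=0.1 corresponds to 200 Hz for a model trained at 20 Hz can be checked on a single diagonal mode. Under zero-order hold, with the input held constant over the interval, ten steps of size Δ/10 land exactly where one step of size Δ does. A minimal sketch with hypothetical parameter values, assuming step_scale multiplies Δ:

```python
import cmath


def zoh_step(x: complex, u: float, lam: complex, b: complex, delta: float) -> complex:
    """One zero-order-hold step of x' = lam*x + b*u with u held constant."""
    lam_bar = cmath.exp(lam * delta)
    b_bar = (lam_bar - 1) / lam * b
    return lam_bar * x + b_bar * u


def run(x: complex, u: float, lam: complex, b: complex, delta: float,
        n_steps: int, step_scale: float) -> complex:
    """Apply n_steps ZOH steps, each of size delta * step_scale."""
    for _ in range(n_steps):
        x = zoh_step(x, u, lam, b, delta * step_scale)
    return x


lam, b, delta, u = -0.5 + 3j, 1.0 + 0j, 0.2, 0.8  # hypothetical values
coarse = run(0j, u, lam, b, delta, n_steps=1, step_scale=1.0)  # one 20 Hz step
fine = run(0j, u, lam, b, delta, n_steps=10, step_scale=0.1)   # ten 200 Hz steps
```

The identity is exact only for ZOH with a piecewise-constant input; the bilinear default is only approximately consistent. In both cases it assumes the finer input is a genuinely finer sampling of the same continuous signal, which matches the point above about applying the rescaled SSM to the raw events.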
If you try to work with raw events only, without converting them into event representations, you will easily fill 40 GB of VRAM even on small datasets; imagine what would happen on bigger ones.
So these experiments were more a test that we can do this, but unfortunately it is still not fully practical. There is no free lunch: continuous SSMs (S4/S4D/S5) are powerful but still require a lot of memory.
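A back-of-envelope sketch of why raw events are so memory-hungry: a parallel scan materializes the hidden state at every time step, and with raw events every event is a time step. The sizes below are hypothetical, and the estimate covers one layer and one sequence, before batching, gradients, or other activations.

```python
def scan_state_bytes(seq_len: int, state_dim: int, bytes_per_elem: int = 8) -> int:
    """Bytes needed to materialize every hidden state of one sequence in one layer.

    bytes_per_elem=8 assumes complex64 states (two float32 parts).
    """
    return seq_len * state_dim * bytes_per_elem


# Hypothetical sizes: one million raw events, 256 complex state dimensions.
gib = scan_state_bytes(1_000_000, 256) / 2**30
print(f"{gib:.2f} GiB per layer per sequence")  # 1.91 GiB per layer per sequence
```

Multiplying by the number of layers and the batch size, and keeping the states for the backward pass, quickly reaches tens of gigabytes.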

Mamba requires less memory, but it's not a continuous model anymore.

@NikolaZubic
Member

NikolaZubic commented May 1, 2024

@looper99 Regarding Mamba: it is a purely discrete model, in the sense that making B, C, and delta depend on the input (selectivity) is what lets it learn what to remember and what to forget. All of the continuous-time theory breaks down for it, so you cannot use Mamba for evaluation at different frequencies. What is good about Mamba is that its VRAM requirements are much lower, so it is more practical in that sense. No attention, no MLPs, but expressive.

Mamba is for discrete data.
S4/S4D/S5 is for continuous data.

I don't know why the authors decided to move away from the continuous setting with Mamba (probably to compete with Transformers on purely discrete data problems), but the ideal model would be one that handles both discrete and continuous data very well and can be adapted by rescaling delta. That itself is an open research question.
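The difference described above shows up in where Δ comes from. A scalar sketch with hypothetical parameters: in S4/S5, Δ is a learned parameter that is fixed at inference time, so multiplying it by step_scale is a clean change of time units. In Mamba, Δ_t is computed from the current input through a softplus, so the step itself is data-dependent. Per the comment above, a global rescaling then no longer corresponds to a change of sampling rate.

```python
import math


def softplus(z: float) -> float:
    return math.log1p(math.exp(z))


def selective_decay(u: float, a: float = -1.0, w: float = 2.0, bias: float = -1.0) -> float:
    """Mamba-style per-step decay: Delta is computed from the input itself.

    Hypothetical scalar parameters; in Mamba, Delta comes from a learned
    projection of the input followed by softplus.
    """
    delta = softplus(w * u + bias)
    return math.exp(delta * a)


def lti_decay(a: float = -1.0, delta: float = 0.5, step_scale: float = 1.0) -> float:
    """S4/S5-style decay: Delta is a fixed parameter, so step_scale rescales time."""
    return math.exp(a * delta * step_scale)
```

For the LTI case, ten steps with `step_scale=0.1` decay exactly as much as one step at the original scale. For the selective case, the decay changes with every input value, so there is no single clock to rescale.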

@fengwei0907

> I did it in the same way as in the pendulum task outlined in the following link: GitHub - lindermanlab/S5 - Issue #13. It worked just fine for event data.
>
> I experimented and found that there is a way to train with precision 16, but you have to wrap the combine function in jax_func.py in torch.cuda.amp.autocast(enabled=False): GitHub - uzh-rpg/ssms_event_cameras - jax_func.py line 68.
>
> However, when I train this way, I can't reach the mAP reported in the paper, which I do get with precision 32. I can only train on GEN1 since I have a single A100.

I tried your method, but I still encountered errors after making the modifications. I don't know why.


4 participants