
Commit: SaShiMi release

krandiash committed Feb 22, 2022
1 parent 147f590 commit 74d2706
Showing 52 changed files with 6,964 additions and 82 deletions.
40 changes: 32 additions & 8 deletions README.md
@@ -2,27 +2,43 @@

This repository provides implementations and experiments for the following papers.

## S4
## SaShiMi (arXiv)

![SaShiMi](assets/sashimi.png "SaShiMi Architecture")
> **It's Raw! Audio Generation with State-Space Models**\
> Karan Goel, Albert Gu, Chris Donahue, Christopher Ré\
> Paper: https://arxiv.org/abs/xxxx.yyyyy

@lucidrains (Feb 22, 2022)

👀

@krandiash (Author, Feb 22, 2022)

Literally just came in: updated 🏃

## S4 (ICLR 2022 Oral)

![Structured State Spaces](assets/properties.png "Properties of Structured State Spaces")
> **Efficiently Modeling Long Sequences with Structured State Spaces**\
> Albert Gu, Karan Goel, Christopher Ré\
> Paper: https://arxiv.org/abs/2111.00396
## LSSL
## LSSL (NeurIPS 2021)

![Linear State Space Layer](assets/splash.png "Properties of Sequential State Spaces")
> **Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer**\
> Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré\
> Paper: https://arxiv.org/abs/2110.13985
## HiPPO
## HiPPO (NeurIPS 2020 Spotlight)
![HiPPO Framework](assets/hippo.png "HiPPO Framework")
> **HiPPO: Recurrent Memory with Optimal Polynomial Projections**\
> Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré\
> Paper: https://arxiv.org/abs/2008.07669

## Table of Contents
- [Repository Setup](#setup)
- S4
- [Experiments](#s4-experiments)
- [Training](#training)
- [Models](#models)
- [SaShiMi](sashimi/README)
- [Repository Structure](#overall-repository-structure)
- [Citation](#citation)
## Setup

### Requirements
@@ -257,6 +273,7 @@ src/ main source code for models, datasets, etc.
sequence/ sequence model backbones and layers including RNNs and S4/LSSL
tasks/ encoder/decoder modules to interface between data and model backbone
utils/
sashimi/ SaShiMi README and additional code (generation, metrics, MTurk)
train.py training loop entrypoint
```

@@ -265,11 +282,18 @@ train.py training loop entrypoint
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@article{gu2021efficiently,
@article{goel2022sashimi,
title={It's Raw! Audio Generation with State-Space Models},
author={Goel, Karan and Gu, Albert and Donahue, Chris and R{\'e}, Christopher},
journal={arXiv preprint arXiv:xxxx.yyyyy},
year={2022}
}
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2111.00396},
year={2021}
author={Gu, Albert and Goel, Karan and R\'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}
@article{gu2021combining,
@@ -282,7 +306,7 @@ If you use this codebase, or otherwise found our work valuable, please cite:
@article{gu2020hippo,
title={HiPPO: Recurrent Memory with Optimal Polynomial Projections},
author={Gu, Albert and Dao, Tri and Ermon, Stefano and Rudra, Atri and Re, Christopher},
author={Gu, Albert and Dao, Tri and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={Advances in neural information processing systems},
volume={33},
year={2020}
Binary file added assets/sashimi.png
5 changes: 4 additions & 1 deletion configs/config.yaml
@@ -29,7 +29,9 @@ train:
ignore_warnings: False # Disable python warnings
# These control state
state:
mode: null # [ None | 'none' | 'reset' | 'bptt' ]
mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
chunk_len: null # [ None | int ] chunk length for tbptt (used by TBPTTDataLoader)
overlap_len: null # [ None | int ] overlap length for tbptt (used by TBPTTDataLoader)
n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context
n_context_eval: ${.n_context}
# Convenience keys to allow grouping runs
@@ -41,6 +43,7 @@ train:
benchmark_step_T: 1 # Number of additional repeats to benchmark the step function
checkpoint_path: null # Path to checkpoint file: only used for visualization at the moment
visualizer: 'filters' # Which visualizer to use: [ 'filters' | 'forecasting' ]
disable_dataset: False # Disable dataset loading

# We primarily use wandb so this is moved to top level for convenience
# Set ~wandb or wandb=null or wandb.mode=disabled to disable logging
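The `mode: tbptt` option and the new `chunk_len` / `overlap_len` keys above configure truncated backpropagation through time. As a rough, hypothetical sketch (not the code in this commit; the actual `TBPTTDataLoader` may slice differently), chunking a long waveform could look like this:

```python
# Illustrative only: one way chunk_len / overlap_len could split a long
# sequence for truncated BPTT. The repository's TBPTTDataLoader may differ.
import torch

def tbptt_chunks(x: torch.Tensor, chunk_len: int, overlap_len: int):
    """Yield chunks of a (batch, length, ...) tensor, each carrying
    overlap_len samples of left context before the chunk_len samples
    that receive gradients."""
    L = x.shape[1]
    for start in range(0, L - overlap_len, chunk_len):
        end = min(start + overlap_len + chunk_len, L)
        yield x[:, start:end]
        if end == L:
            break

batch = torch.randn(2, 16000, 1)  # e.g. one second of 16 kHz audio
for chunk in tbptt_chunks(batch, chunk_len=1024, overlap_len=32):
    pass  # forward/backward on each chunk, detaching state in between
print(chunk.shape)  # last (possibly shorter) chunk
```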
10 changes: 10 additions & 0 deletions configs/dataset/beethoven.yaml
@@ -0,0 +1,10 @@
_name_: qautoaudio
path: beethoven
bits: 8
sample_len: 128000
train_percentage: 0.88
quantization: linear
drop_last: true
context_len: null
pad_len: null
__l_max: ${.sample_len}
5 changes: 5 additions & 0 deletions configs/dataset/sc09.yaml
@@ -0,0 +1,5 @@
_name_: sc09
bits: 8
quantization: mu-law
pad_len: null
__l_max: 16000
10 changes: 10 additions & 0 deletions configs/dataset/youtubemix.yaml
@@ -0,0 +1,10 @@
_name_: qautoaudio
path: youtube_mix
bits: 8
sample_len: 131072
train_percentage: 0.88
quantization: mu-law
drop_last: true
context_len: null
pad_len: null
__l_max: ${.sample_len}
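The dataset configs above select 8-bit quantization, either `linear` (beethoven) or `mu-law` (sc09, youtubemix). A minimal sketch of the two schemes, assuming waveforms normalized to [-1, 1] (the repo's own preprocessing may differ):

```python
# Hedged sketch of the two quantization options named in these configs.
import numpy as np

def linear_quantize(x: np.ndarray, bits: int = 8) -> np.ndarray:
    """Map a waveform in [-1, 1] uniformly to classes 0 .. 2**bits - 1."""
    x = np.clip(x, -1.0, 1.0)
    return ((x + 1.0) / 2.0 * (2**bits - 1)).round().astype(np.int64)

def mu_law_quantize(x: np.ndarray, bits: int = 8) -> np.ndarray:
    """Compress amplitude with the mu-law curve, then quantize uniformly."""
    mu = 2**bits - 1
    x = np.clip(x, -1.0, 1.0)
    compressed = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return ((compressed + 1.0) / 2.0 * mu).round().astype(np.int64)

audio = np.sin(np.linspace(0, 100, 16000))  # stand-in for 1 s at 16 kHz
print(mu_law_quantize(audio).min(), mu_law_quantize(audio).max())  # 0 255
```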
51 changes: 51 additions & 0 deletions configs/experiment/samplernn-beethoven.yaml
@@ -0,0 +1,51 @@
# @package _global_
defaults:
- /trainer: default
- /loader: torch
- /dataset: beethoven
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /model: samplernn

dataset:
quantization: linear

model:
bits: 8
quantization: linear
n_rnn: 1
frame_sizes:
- 8
- 2
- 2

train:
monitor: val/loss # Needed for plateau scheduler
mode: min
state:
mode: tbptt
chunk_len: 1024
overlap_len: 32 # this is model dependent (product of model.frame_sizes here)

task:
metrics:
- bpb
- accuracy
- accuracy@3
- accuracy@5
- accuracy@10


encoder: null
decoder: null

loader:
batch_size: 128
train_resolution: 1
eval_resolutions:
- 1

trainer:
gradient_clip_val: 1.0
gradient_clip_algorithm: value
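As the comment on `overlap_len` above notes, the value 32 is simply the product of `model.frame_sizes`; a one-line check, for illustration:

```python
from math import prod

frame_sizes = [8, 2, 2]   # from the model config above
overlap_len = prod(frame_sizes)
assert overlap_len == 32
```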
47 changes: 47 additions & 0 deletions configs/experiment/samplernn-sc09.yaml
@@ -0,0 +1,47 @@
# @package _global_
defaults:
- /trainer: default
- /loader: torch
- /dataset: sc09
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /model: samplernn

model:
bits: 8
quantization: mu-law
n_rnn: 1
frame_sizes:
- 8
- 2
- 2

train:
monitor: val/loss # Needed for plateau scheduler
mode: min
state:
mode: tbptt
chunk_len: 1024
overlap_len: 32 # this is model dependent (product of model.frame_sizes here)

task:
metrics:
- bpb
- accuracy
- accuracy@3
- accuracy@5
- accuracy@10

encoder: null
decoder: null

loader:
batch_size: 128
train_resolution: 1
eval_resolutions:
- 1

trainer:
gradient_clip_val: 1.0
gradient_clip_algorithm: value
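Among the metrics listed above, `bpb` (bits per byte) is, under the usual convention, the mean negative log-likelihood converted from nats to bits; for 8-bit audio one sample is one byte. A hypothetical sketch:

```python
import math

def bits_per_byte(nll_nats: float) -> float:
    """Convert a mean NLL in nats to bits per (8-bit) sample."""
    return nll_nats / math.log(2)

print(bits_per_byte(1.386))  # ~2.0 bits per sample
```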
51 changes: 51 additions & 0 deletions configs/experiment/samplernn-youtubemix.yaml
@@ -0,0 +1,51 @@
# @package _global_
defaults:
- /trainer: default
- /loader: torch
- /dataset: youtubemix
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /model: samplernn

dataset:
quantization: mu-law

model:
bits: 8
quantization: mu-law
n_rnn: 1
frame_sizes:
- 8
- 2
- 2

train:
monitor: val/loss # Needed for plateau scheduler
mode: min
state:
mode: tbptt
chunk_len: 1024
overlap_len: 32 # this is model dependent (product of model.frame_sizes here)

task:
metrics:
- bpb
- accuracy
- accuracy@3
- accuracy@5
- accuracy@10


encoder: null
decoder: null

loader:
batch_size: 128
train_resolution: 1
eval_resolutions:
- 1

trainer:
gradient_clip_val: 1.0
gradient_clip_algorithm: value
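The `accuracy@k` metrics in these configs are top-k accuracies: a prediction counts if the target class appears among the k highest logits. A small sketch (the repo's metric code may be implemented differently):

```python
import torch

def accuracy_at_k(logits: torch.Tensor, targets: torch.Tensor, k: int) -> float:
    """logits: (N, C) scores, targets: (N,) integer labels."""
    topk = logits.topk(k, dim=-1).indices           # (N, k)
    hits = (topk == targets.unsqueeze(-1)).any(-1)  # (N,)
    return hits.float().mean().item()

logits = torch.randn(4, 256)            # 256 classes for 8-bit audio
targets = torch.randint(0, 256, (4,))
print(accuracy_at_k(logits, targets, k=5))
```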
59 changes: 59 additions & 0 deletions configs/experiment/sashimi-beethoven.yaml
@@ -0,0 +1,59 @@
# @package _global_
defaults:
- /trainer: default
- /loader: torch
- /dataset: beethoven
- /task: multiclass_classification
- /optimizer: adamw
- /scheduler: plateau
- /model: sashimi

dataset:
quantization: linear
drop_last: False

model:
n_layers: 8
pool:
- 4
- 4
dropout: 0.0
prenorm: True

layer:
hurwitz: True
postact: glu


train:
monitor: val/loss
mode: min

task:
metrics:
- bpb
- accuracy
- accuracy@3
- accuracy@5
- accuracy@10

encoder: embedding

decoder:
_name_: sequence
mode: ragged

loader:
batch_size: 1
train_resolution: 1
eval_resolutions:
- 1

trainer:
max_epochs: 1000

optimizer:
lr: 0.004

scheduler:
patience: 20
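In the SaShiMi config above, `pool: [4, 4]` defines two pooling stages, so the backbone operates at three temporal resolutions. A back-of-the-envelope sketch of the resulting sequence lengths for the beethoven `sample_len` (illustrative only; the actual pooling code lives in the model source):

```python
sample_len = 128000  # from configs/dataset/beethoven.yaml
pool = [4, 4]

lengths = [sample_len]
for p in pool:
    lengths.append(lengths[-1] // p)
print(lengths)  # [128000, 32000, 8000]
```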

1 comment on commit 74d2706

@lucidrains

🍣
