Adding recipe for Listenable Maps for Audio Classifiers #2538

Draft: wants to merge 271 commits into base: develop

271 commits
32c31fc
add NMF image logging for debug
fpaissan Dec 11, 2023
7a5a9c8
fix bug in viz L2I
fpaissan Dec 11, 2023
00dfdbe
log the number of finetuning masks
fpaissan Dec 11, 2023
8288675
lower crosscor thr
fpaissan Dec 11, 2023
3e541de
fix acc
fpaissan Dec 11, 2023
aff2f07
align L2I debugging w/ PIQ script
fpaissan Dec 12, 2023
e3b981a
fixed accuracy computation for L2I
fpaissan Dec 12, 2023
23b542e
L2I with variable number of components (K=200)
fpaissan Dec 12, 2023
da12c72
debugging l2i...
fpaissan Dec 17, 2023
f4fc9a9
update hparams
fpaissan Dec 18, 2023
024f64c
fixed oracle source
fpaissan Dec 19, 2023
3b3a8c4
fixed wrong sources and running finetuning experiments..
fpaissan Dec 19, 2023
ec01553
add AST as classifier
fpaissan Dec 21, 2023
7ea4972
hparams ast -- still not converging
fpaissan Dec 21, 2023
d0dc205
add ast augmentation
fpaissan Dec 21, 2023
d96bffd
synced merge
fpaissan Dec 21, 2023
69cc6e7
update training script after merge
fpaissan Dec 21, 2023
58117ab
with augmentations is better
fpaissan Dec 21, 2023
1dea3fc
just pushing hparams
fpaissan Dec 22, 2023
68d0d8e
classification with CE
fpaissan Dec 22, 2023
1935a7b
conv2d fix for CE
fpaissan Dec 22, 2023
eb120c8
playing with AST augmentation
fpaissan Dec 26, 2023
728fb0b
fixed thresholding
fpaissan Dec 26, 2023
1fe07e4
starting to experiment with no wham noise stuff
ycemsubakan Jan 3, 2024
ec23e86
add wham noise option in classifier training, dot prod correlation in…
ycemsubakan Jan 8, 2024
891d469
single mask training
ycemsubakan Jan 12, 2024
8f0b0c9
added zero grad
ycemsubakan Jan 12, 2024
99feb50
added the entropy loss
ycemsubakan Jan 12, 2024
2c617e0
implemented a psi function for cnn14
ycemsubakan Jan 14, 2024
57d0327
Update README.md
ycemsubakan Jan 14, 2024
e6add1f
added stft-mel transformation learning
ycemsubakan Jan 18, 2024
8fc8d8f
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Jan 18, 2024
4dd0594
add latest eval setup - working on gradient-based
fpaissan Jan 19, 2024
198ae6b
removed unused brain -- was causing issues in weights loading..
fpaissan Jan 19, 2024
5352014
training l2i on this classifier
fpaissan Jan 19, 2024
f5d6730
add l2i eval -- removing mosaic; not well defined in the case of L2I
fpaissan Jan 19, 2024
08b5fb6
removed old png file
fpaissan Jan 19, 2024
9f2873e
debugging eval weight loading..
fpaissan Jan 19, 2024
bd6c959
was always using vq
fpaissan Jan 19, 2024
bd40cf6
fixed eval AO
fpaissan Jan 19, 2024
db1ba5e
fixed eval -- now everything's fine also for L2I
fpaissan Jan 20, 2024
d77a822
better numerical stability
fpaissan Jan 20, 2024
431cd4d
handling quantus assertionerror
fpaissan Jan 20, 2024
ba16262
Merge pull request #3 from fpaissan/eval_in_cemstuff
ycemsubakan Jan 20, 2024
2f9dade
add saliency from captum
fpaissan Jan 20, 2024
ba674b2
updated smoothgrad for captum
fpaissan Jan 20, 2024
721ce4d
added norm to saliency
fpaissan Jan 20, 2024
f13dc47
IG from captum
fpaissan Jan 20, 2024
7053bb3
starting gradient-base eval on cnn14...
fpaissan Jan 20, 2024
17b27f5
Merge pull request #4 from fpaissan/eval_in_cemstuff
ycemsubakan Jan 20, 2024
410219d
commit before merge
ycemsubakan Jan 21, 2024
f687b60
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Jan 21, 2024
a659533
works on cnn14 -- but have a bad checkpoint
fpaissan Jan 22, 2024
f632940
fixed l2i as well
fpaissan Jan 22, 2024
596b56c
fixed acc in l2i
fpaissan Jan 22, 2024
163374b
fix not listenable
fpaissan Jan 22, 2024
5f30445
updated logging for eval
fpaissan Jan 22, 2024
1362d62
a bit less verbose
fpaissan Jan 22, 2024
a6b8581
printing at sample level
fpaissan Jan 22, 2024
5f7c5cc
fix logging - was missing avg
fpaissan Jan 22, 2024
61a2c00
was messing up in the forward
fpaissan Jan 23, 2024
399fc9f
now running train_piq.py
fpaissan Jan 23, 2024
c94e788
minor corrections
fpaissan Jan 23, 2024
dd63174
fix l2i training with wham!
fpaissan Jan 23, 2024
59f7d5c
fixed l2i computation
fpaissan Jan 23, 2024
a962410
linters
fpaissan Jan 23, 2024
af5c0a9
add check for wham usage in eval
fpaissan Jan 23, 2024
6047a48
add sample saving during eval
fpaissan Jan 23, 2024
ad8cf66
bug fixes
fpaissan Jan 23, 2024
b895803
added predictions info to the logging
fpaissan Jan 23, 2024
3ddc0b8
fixed id for overlap test
fpaissan Jan 23, 2024
a3ccf71
cutting sample before saving
fpaissan Jan 23, 2024
bd7aa57
fixed l2i sampling rate
fpaissan Jan 23, 2024
4bf6815
fixed random seed so eval will match
fpaissan Jan 23, 2024
dbc5b96
running on full set
fpaissan Jan 23, 2024
e77bfc7
faithfulness fix
ycemsubakan Jan 23, 2024
f36f3cd
remove pdb
ycemsubakan Jan 23, 2024
c891dd5
fix smoothgrad and IG
fpaissan Jan 24, 2024
a5ebd89
fix nmf for pre-training
fpaissan Jan 24, 2024
2529b85
removed nmf reconstructions
fpaissan Jan 24, 2024
af88a98
truncated gaussian fix for smoothgrad
fpaissan Jan 24, 2024
10685de
fix nans in sensitivity
fpaissan Jan 24, 2024
56a43a0
better l2i psi network
fpaissan Jan 24, 2024
d70f52e
saving to a different folder. helps not overriding experiments..
fpaissan Jan 25, 2024
3bf362b
fix l2i
fpaissan Jan 25, 2024
9119cb9
fix csv logging of exps
fpaissan Jan 25, 2024
db3d64a
add guided backprop
fpaissan Jan 25, 2024
f3e9bab
added gradcam. guided backprop and guided gradcam need debugging
fpaissan Jan 25, 2024
fe975d4
l2i encoder 1D
fpaissan Jan 26, 2024
123c340
mel only - ao
fpaissan Jan 26, 2024
3ac94a2
eval for mel only
fpaissan Jan 26, 2024
7905f9b
changed logging to simple write
fpaissan Jan 26, 2024
acb46e6
hardcoded checkpoint - to run on cc
fpaissan Jan 26, 2024
0d37e6c
save everything in one folder
fpaissan Jan 26, 2024
3d8b470
remove joblib import
fpaissan Jan 26, 2024
5b49c5f
fixed eval?
fpaissan Jan 26, 2024
9cac2d2
fix eval again..
fpaissan Jan 26, 2024
0370b24
maybe now?
fpaissan Jan 26, 2024
8ef85d1
trying on cc
fpaissan Jan 26, 2024
aab3e38
add eval_outdir
fpaissan Jan 27, 2024
0df00be
runs full eval
fpaissan Jan 27, 2024
edb2b9f
l2i with updated psi
fpaissan Jan 27, 2024
2558421
update gitignore
fpaissan Jan 27, 2024
6c65204
l2i logging different loss values
fpaissan Jan 27, 2024
45661aa
add us8k classifier
fpaissan Jan 27, 2024
783b40d
us8k interpretations
fpaissan Jan 27, 2024
31eb747
fixed guided backprop and guided gradcam
fpaissan Jan 27, 2024
275c125
add shap
fpaissan Jan 27, 2024
4a35624
normalizing shap attributions
fpaissan Jan 27, 2024
9d5cb9f
adding us8k prepare in interp..
fpaissan Jan 27, 2024
59bdb6d
eval on ID
fpaissan Jan 27, 2024
24bf013
fixed backward compatibility
fpaissan Jan 28, 2024
44f7680
added multiclass classification
ycemsubakan Jan 28, 2024
64bcc99
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Jan 28, 2024
bd3d512
eval xplorer v1
fpaissan Jan 29, 2024
eccd1ee
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
fpaissan Jan 29, 2024
65d64f3
eval xplorer v2
fpaissan Jan 29, 2024
44e9cc6
implemented multi label interpretation
ycemsubakan Jan 29, 2024
ba18b05
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Jan 29, 2024
591782f
update the loss function in multilabel interpretations
ycemsubakan Jan 29, 2024
ac3c13d
evaluation explorer - minor fixes
fpaissan Jan 30, 2024
da5cdba
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
fpaissan Jan 30, 2024
aaf68e7
add roar
fpaissan Jan 30, 2024
bb93285
roar test
fpaissan Jan 30, 2024
2b50755
just removing a print...
fpaissan Jan 30, 2024
ff4deb0
add roar script
fpaissan Jan 30, 2024
1cf6491
adding the user study parsing script
ycemsubakan Jan 31, 2024
fa343cb
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Jan 31, 2024
2038e8a
savefigs
ycemsubakan Jan 31, 2024
f72b647
fix to roar hparam
fpaissan Jan 31, 2024
0b3b823
minor
fpaissan Jan 31, 2024
2e17d3b
extract samples for user study
fpaissan Jan 31, 2024
f268ac3
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
fpaissan Jan 31, 2024
50e87ad
fix bug roar
fpaissan Jan 31, 2024
dba7269
fixed roar
fpaissan Jan 31, 2024
973aeac
fix another copy-paste error
fpaissan Jan 31, 2024
6b6e452
MRT eval
fpaissan Jan 31, 2024
ce9a038
roar with random baseline
fpaissan Jan 31, 2024
a9e984d
fix np seed
fpaissan Jan 31, 2024
42b1718
computes mrt metrics
fpaissan Jan 31, 2024
a5fced1
saving masks for mrt viz
fpaissan Jan 31, 2024
6c147f3
remove rand baseline roar
fpaissan Feb 1, 2024
1594993
abs
fpaissan Feb 1, 2024
bdbc946
gradcam eval
fpaissan Feb 1, 2024
969044c
fix class
Feb 1, 2024
16778b3
add mrt to l2i
fpaissan Feb 1, 2024
87609ea
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
fpaissan Feb 1, 2024
388c079
train piq us8k
fpaissan Feb 1, 2024
c46388d
param in mrt_evaluator
fpaissan Feb 1, 2024
cda9d46
add viz
fpaissan Feb 1, 2024
8490c0d
adding the latest
ycemsubakan Mar 22, 2024
769b479
Merge branch 'icml2024_cemstuff' of github.com:fpaissan/audio_interpr…
ycemsubakan Mar 22, 2024
d6407a5
fixing path problems for multilabelstuff
ycemsubakan Mar 22, 2024
79cedd4
changed the loss function to output 10 masks
ycemsubakan Mar 26, 2024
c641474
more standard maskout term
ycemsubakan Mar 26, 2024
f85ff88
changed encoder loading to local
ycemsubakan Mar 27, 2024
b6d80b7
added accuracy computation
ycemsubakan Mar 27, 2024
1f3515b
removed unnecessary evaluation methods
ycemsubakan Mar 27, 2024
ff1a332
added all ones mask and average energy computation
ycemsubakan Mar 28, 2024
e656f1c
fixed the bug for whitenoise
ycemsubakan Mar 28, 2024
6b5f34b
pushing eval later
ycemsubakan Mar 28, 2024
ae00b6c
l2i new ood
fpaissan Mar 28, 2024
974a15a
merge lmac and new focalnet stuff -new sb version
fpaissan Apr 17, 2024
476456d
removing useless files
fpaissan Apr 17, 2024
9080e7f
cleaning up classification as well
fpaissan Apr 17, 2024
de91a6b
removing useless hparams in interpret
fpaissan Apr 17, 2024
9d5b797
more useless files
fpaissan Apr 17, 2024
788a086
old linters
fpaissan Apr 17, 2024
8ce9887
fix paths
fpaissan Apr 17, 2024
b058284
fix paths
fpaissan Apr 17, 2024
ecdcffd
update Cnn14
fpaissan Apr 17, 2024
403f8cf
restored old piq file
fpaissan Apr 17, 2024
67fe5fa
wham on PIQ
fpaissan Apr 17, 2024
fb91eba
Adding LMAC - needs refactor (#5)
fpaissan Apr 17, 2024
038482c
removed useless code. needs to be modified to run with self.interpret…
fpaissan Apr 20, 2024
ce39d48
parent class and piq mods
fpaissan Apr 20, 2024
2547db0
fix fn names
fpaissan Apr 20, 2024
969c907
simplify viz
fpaissan Apr 20, 2024
2d86af1
move data prep function
fpaissan Apr 20, 2024
9745a8d
L2I with parent class
fpaissan Apr 20, 2024
56d314a
removed 1 decoderator
fpaissan Apr 20, 2024
a6f3b63
commenting viz_ints. need std
fpaissan Apr 20, 2024
eb1b6eb
unifying viz
fpaissan Apr 20, 2024
2238989
change fn call
fpaissan Apr 20, 2024
2e9594d
removed abstract class
fpaissan Apr 20, 2024
de5512c
disable viz_ints
fpaissan Apr 20, 2024
e58f857
rm bl comp
fpaissan Apr 20, 2024
880adf2
l2i viz
fpaissan Apr 20, 2024
976b84a
remove l2i fid
fpaissan Apr 20, 2024
b7f8d5e
add lens
fpaissan Apr 20, 2024
e7fdd24
removed some metrics
fpaissan Apr 20, 2024
17a2883
extra_metric fix
fpaissan Apr 20, 2024
947bb1d
removed another metric
fpaissan Apr 20, 2024
5beb812
removed another metric
fpaissan Apr 20, 2024
ba22ce6
add readme -- was falling behing somehow...
fpaissan Apr 22, 2024
9212a81
starting to std viz
fpaissan Apr 22, 2024
f0f069d
inp fid
fpaissan Apr 22, 2024
f75151d
fix ic
fpaissan Apr 22, 2024
608c35b
removing metrics as they will be compute elsewhere
fpaissan Apr 22, 2024
5e91d3c
viz piq
fpaissan Apr 22, 2024
dd8fbe9
viz piq remove mask_ll
fpaissan Apr 22, 2024
be21b5a
uniform piq viz
fpaissan Apr 22, 2024
7507ea5
PIQ fits parent class
fpaissan Apr 23, 2024
0ef3244
starting to unify metrics eval
fpaissan Apr 27, 2024
a8290ba
fixed metrics -- missing SPS and COMP
fpaissan Apr 28, 2024
5c76015
linters
fpaissan Apr 28, 2024
0b0e3da
lmac into template
fpaissan Apr 28, 2024
528777f
update lmac hparams
fpaissan Apr 28, 2024
70113af
minor
fpaissan Apr 28, 2024
da92665
not converging
fpaissan Apr 28, 2024
aa81451
converging now
fpaissan Apr 28, 2024
1dd6c1d
computing metrics
fpaissan Apr 28, 2024
2802d10
computing extra metrics
fpaissan Apr 28, 2024
3c8317f
extra metrics for l2i
fpaissan Apr 28, 2024
1c8264a
starting SPS and COMP
fpaissan Apr 28, 2024
7ed2b95
Adds quantus SPS and COMP metrics to the refactoring code (#6)
fpaissan Apr 28, 2024
5d0b2f6
removed unused file
fpaissan Apr 28, 2024
82b8d27
still throws strange error
fpaissan Apr 28, 2024
6200ff8
ood eval
fpaissan Apr 29, 2024
d102129
fixed paddedbatch stuff
fpaissan Apr 29, 2024
3d53cf2
eval L2I
fpaissan Apr 29, 2024
41a360d
remove useless files
fpaissan Apr 29, 2024
f41a28d
using right wham preparation
fpaissan Apr 29, 2024
c23db68
removing model wrapper as it is not needed
fpaissan Apr 29, 2024
e9ee9d6
fix ID samples
fpaissan Apr 29, 2024
7963ea9
fix linters
fpaissan Apr 30, 2024
1b7f7f6
model finetuning test
fpaissan Apr 30, 2024
05c59bd
pretrained_PIQ -> pretrained_interpreter
fpaissan Apr 30, 2024
ebfbc7d
update README.md
fpaissan Apr 30, 2024
09c3e8a
added README instructions for training with WHAM!
ycemsubakan Apr 30, 2024
a7cc35d
removing the dataset tag on experiment name
ycemsubakan Apr 30, 2024
88d24b4
Merge branch 'develop' of github.com:speechbrain/speechbrain into dev…
fpaissan Apr 30, 2024
bb7d888
Fix Checks (#8)
fpaissan Apr 30, 2024
0ded7fe
added wham hparams to vit.yaml
ycemsubakan Apr 30, 2024
75c47f1
Merge branch 'refactor_lmac' of github.com:fpaissan/audio_interpretab…
ycemsubakan Apr 30, 2024
c4cdf7d
added focalnet wham hyperparams
ycemsubakan May 1, 2024
d0fcc0b
Merge branch 'develop' of github.com:speechbrain/speechbrain into dev…
fpaissan May 2, 2024
c155871
Merge branch 'develop' of github.com:fpaissan/audio_interpretability …
fpaissan May 2, 2024
9125eee
add eval info
fpaissan May 2, 2024
8edcde5
add automatic wham download
fpaissan May 2, 2024
108fc92
additional instructions on README
ycemsubakan May 2, 2024
4659e24
Merge branch 'refactor_lmac' of github.com:fpaissan/audio_interpretab…
ycemsubakan May 2, 2024
0aef616
wham prepare uses explicit parameters
fpaissan May 2, 2024
2f4d174
Merge branch 'refactor_lmac' of github.com:fpaissan/audio_interpretab…
fpaissan May 2, 2024
56552b4
wham docstrings
fpaissan May 2, 2024
a5b0962
edited the instructions on different contamination types
ycemsubakan May 2, 2024
c5b29a7
Merge branch 'refactor_lmac' of github.com:fpaissan/audio_interpretab…
ycemsubakan May 2, 2024
20b6771
removing the table
fpaissan May 2, 2024
cff6d77
removing the table
fpaissan May 2, 2024
5e66454
revert changes to gitignore
fpaissan May 3, 2024
11 changes: 10 additions & 1 deletion recipes/ESC50/classification/README.md
@@ -61,6 +61,15 @@ python train.py hparams/vit.yaml --data_folder /yourpath/ESC50

---------------------------------------------------------------------------------------------------------

### To train with WHAM! noise

To train the classifier with WHAM! noise, use the following command:

```shell
python train.py hparams/modelofchoice.yaml --data_folder /yourpath/ESC50 --add_wham_noise True --wham_folder /yourpath/wham_noise
```
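The updated hparams files in this recipe also expose a `test_only` flag (default `False`). As a sketch, assuming the same command-line overrides shown above and a checkpoint already saved in the output folder, an evaluation-only run could look like:

```shell
# Hypothetical invocation: skip training and only evaluate the saved checkpoint
python train.py hparams/modelofchoice.yaml --data_folder /yourpath/ESC50 --test_only True
```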


## Results

| Hyperparams file | Accuracy (%) | Training time | HuggingFace link | Model link | GPUs |
@@ -139,4 +148,4 @@ If you use **SpeechBrain**, please cite:
- Code: https://github.com/speechbrain/speechbrain/
- HuggingFace: https://huggingface.co/speechbrain/

---------------------------------------------------------------------------------------------------------
---------------------------------------------------------------------------------------------------------
3 changes: 3 additions & 0 deletions recipes/ESC50/classification/extra_requirements.txt
@@ -1,3 +1,6 @@
matplotlib
pandas
scikit-learn
torchvision
transformers
wget
28 changes: 20 additions & 8 deletions recipes/ESC50/classification/hparams/cnn14.yaml
@@ -4,7 +4,7 @@
#
# Authors:
# * Cem Subakan 2022, 2023
# * Francesco Paissan 2022, 2023
# * Francesco Paissan 2022, 2023, 2024
# (based on the SpeechBrain UrbanSound8k recipe)
# #################################

@@ -16,11 +16,21 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
data_folder: !PLACEHOLDER # e.g., /localscratch/ESC-50-master
audio_data_folder: !ref <data_folder>/audio

experiment_name: cnn14-esc50
experiment_name: !ref cnn14-esc50
output_folder: !ref ./results/<experiment_name>/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

add_wham_noise: False
test_only: False

wham_folder: null
wham_audio_folder: !ref <wham_folder>/tr
wham_metadata: "metadata/wham_speechbrain.csv"


sample_rate: 16000
signal_length_s: 5

# Tensorboard logs
use_tensorboard: False
@@ -47,7 +57,6 @@ lr: 0.0002
base_lr: 0.00000001
max_lr: !ref <lr>
step_size: 65000
sample_rate: 44100

device: "cpu"

@@ -58,6 +67,7 @@ right_frames: 0
deltas: False

use_melspectra: True
use_log1p_mel: True

# Number of classes
out_n_neurons: 50
@@ -84,10 +94,9 @@ embedding_model: !new:speechbrain.lobes.models.Cnn14.Cnn14
mel_bins: !ref <n_mels>
emb_dim: 2048

classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
input_size: 2048
out_neurons: !ref <out_n_neurons>
lin_blocks: 1
classifier: !new:torch.nn.Linear
in_features: 2048
out_features: !ref <out_n_neurons>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
limit: !ref <number_of_epochs>
@@ -107,6 +116,7 @@ compute_fbank: !new:speechbrain.processing.features.Filterbank
n_mels: 80
n_fft: !ref <n_fft>
sample_rate: !ref <sample_rate>
log_mel: False

modules:
compute_stft: !ref <compute_stft>
@@ -145,7 +155,9 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
counter: !ref <epoch_counter>

use_pretrained: True
# If you do not want to use the pretrained encoder you can simply delete pretrained_encoder field.
# If you do not want to use the pretrained encoder,
# you can simply delete the pretrained_encoder field
# or set use_pretrained=False.
embedding_model_path: speechbrain/cnn14-esc50/embedding_model.ckpt
pretrained_encoder: !new:speechbrain.utils.parameter_transfer.Pretrainer
collect_in: !ref <save_folder>
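For instance, assuming the command-line overrides illustrated in the README, the pretrained encoder could be disabled without editing the YAML (a sketch; paths are placeholders):

```shell
# Hypothetical override: train the Cnn14 encoder from scratch
# instead of loading the pretrained checkpoint listed in the YAML
python train.py hparams/cnn14.yaml --data_folder /yourpath/ESC50 --use_pretrained False
```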
13 changes: 13 additions & 0 deletions recipes/ESC50/classification/hparams/conv2d.yaml
@@ -16,11 +16,16 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
data_folder: !PLACEHOLDER # e.g., /localscratch/ESC-50-master
audio_data_folder: !ref <data_folder>/audio

wham_folder: null
wham_audio_folder: !ref <wham_folder>/tr
wham_metadata: "metadata/wham_speechbrain.csv"

experiment_name: conv2dv2_classifier-16k
output_folder: !ref ./results/<experiment_name>/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

test_only: False

# Tensorboard logs
use_tensorboard: False
@@ -48,6 +53,9 @@ base_lr: 0.000002
max_lr: !ref <lr>
step_size: 65000
sample_rate: 16000
signal_length_s: 5

add_wham_noise: False

device: "cpu"

@@ -65,6 +73,7 @@ dataloader_options:

use_pretrained: True
use_melspectra: False
use_log1p_mel: False
embedding_model: !new:speechbrain.lobes.models.PIQ.Conv2dEncoder_v2
dim: 256

@@ -73,6 +82,10 @@ classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
out_neurons: !ref <out_n_neurons>
lin_blocks: 1

#classifier: !new:torch.nn.Linear
#in_features: 256
#out_features: !ref <out_n_neurons>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
limit: !ref <number_of_epochs>

12 changes: 12 additions & 0 deletions recipes/ESC50/classification/hparams/focalnet.yaml
@@ -22,6 +22,16 @@ output_folder: !ref ./results/<experiment_name>/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

add_wham_noise: True
test_only: False

wham_folder: null
wham_audio_folder: !ref <wham_folder>/tr
wham_metadata: "metadata/wham_speechbrain.csv"

use_melspectra: False
use_log1p_mel: False

# Tensorboard logs
use_tensorboard: False
tensorboard_logs_folder: !ref <output_folder>/tb_logs/
@@ -49,6 +59,8 @@ max_lr: !ref <lr>
step_size: 65000
sample_rate: 16000

signal_length_s: 5

# Number of classes
out_n_neurons: 50

11 changes: 11 additions & 0 deletions recipes/ESC50/classification/hparams/vit.yaml
@@ -22,6 +22,15 @@ output_folder: !ref ./results/<experiment_name>/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

add_wham_noise: True
use_melspectra: False
use_log1p_mel: False
test_only: False

wham_folder: null
wham_audio_folder: !ref <wham_folder>/tr
wham_metadata: "metadata/wham_speechbrain.csv"

# Tensorboard logs
use_tensorboard: False
tensorboard_logs_folder: !ref <output_folder>/tb_logs/
@@ -47,7 +56,9 @@ lr: 0.0002
base_lr: 0.00000001
max_lr: !ref <lr>
step_size: 65000

sample_rate: 16000
signal_length_s: 5

# Number of classes
out_n_neurons: 50
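Since vit.yaml enables `add_wham_noise` by default while leaving `wham_folder` unset, a WHAM! training run would presumably be launched with an explicit noise folder (a sketch, mirroring the README command; paths are placeholders):

```shell
# Hypothetical invocation: point the recipe at a local copy of the WHAM! noise corpus
python train.py hparams/vit.yaml --data_folder /yourpath/ESC50 --wham_folder /yourpath/wham_noise
```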
69 changes: 54 additions & 15 deletions recipes/ESC50/classification/train.py
@@ -20,11 +20,14 @@

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from confusion_matrix_fig import create_cm_fig
from esc50_prepare import prepare_esc50
from hyperpyyaml import load_hyperpyyaml
from sklearn.metrics import confusion_matrix
from wham_prepare import combine_batches, prepare_wham

import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
@@ -42,18 +45,23 @@ def compute_forward(self, batch, stage):
if hasattr(self.hparams, "augmentation") and stage == sb.Stage.TRAIN:
wavs, lens = self.hparams.augmentation(wavs, lens)

# Extract features
# augment batch with WHAM!
if hasattr(self.hparams, "add_wham_noise"):
if self.hparams.add_wham_noise:
wavs = combine_batches(wavs, iter(self.hparams.wham_dataset))

X_stft = self.modules.compute_stft(wavs)
X_stft_power = sb.processing.features.spectral_magnitude(
net_input = sb.processing.features.spectral_magnitude(
X_stft, power=self.hparams.spec_mag_power
)
if (
hasattr(self.hparams, "use_melspectra")
and self.hparams.use_melspectra
):
net_input = self.modules.compute_fbank(X_stft_power)
else:
net_input = torch.log1p(X_stft_power)
net_input = self.modules.compute_fbank(net_input)

if (not self.hparams.use_melspectra) or self.hparams.use_log1p_mel:
net_input = torch.log1p(net_input)

# Embeddings + sound classifier
if hasattr(self.modules.embedding_model, "config"):
@@ -80,11 +88,18 @@ def compute_forward(self, batch, stage):
else:
# SpeechBrain model
embeddings = self.modules.embedding_model(net_input)
if isinstance(embeddings, tuple):
embeddings, _ = embeddings

if embeddings.ndim == 4:
embeddings = embeddings.mean((-1, -2))

# run through classifier
outputs = self.modules.classifier(embeddings)

if outputs.ndim == 2:
outputs = outputs.unsqueeze(1)

return outputs, lens

def compute_objectives(self, predictions, batch, stage):
@@ -93,7 +108,15 @@ def compute_objectives(self, predictions, batch, stage):
uttid = batch.id
classid, _ = batch.class_string_encoded

loss = self.hparams.compute_cost(predictions, classid, lens)
# Target augmentation
N_augments = int(predictions.shape[0] / classid.shape[0])
classid = torch.cat(N_augments * [classid], dim=0)

# loss = self.hparams.compute_cost(predictions.squeeze(1), classid, lens)
target = F.one_hot(
classid.squeeze(), num_classes=self.hparams.out_n_neurons
)
loss = -(F.log_softmax(predictions.squeeze(), 1) * target).sum(1).mean()

if stage != sb.Stage.TEST:
if hasattr(self.hparams.lr_annealing, "on_batch_end"):
@@ -378,8 +401,6 @@ def label_pipeline(class_string):
hparams["tensorboard_logs_folder"]
)

from esc50_prepare import prepare_esc50

run_on_main(
prepare_esc50,
kwargs={
@@ -399,6 +420,20 @@ def label_pipeline(class_string):
datasets, label_encoder = dataio_prep(hparams)
hparams["label_encoder"] = label_encoder

if "wham_folder" in hparams:
hparams["wham_dataset"] = prepare_wham(
hparams["wham_folder"],
hparams["add_wham_noise"],
hparams["sample_rate"],
hparams["signal_length_s"],
hparams["wham_audio_folder"],
)

if hparams["wham_dataset"] is not None:
assert hparams["signal_length_s"] == 5, "Fix wham sig length!"

assert hparams["out_n_neurons"] == 50, "Fix number of outputs classes!"

class_labels = list(label_encoder.ind2lab.values())
print("Class Labels:", class_labels)

@@ -411,17 +446,21 @@ def label_pipeline(class_string):
)

# Load pretrained encoder if it exists in the yaml file
if not hasattr(ESC50_brain.modules, "embedding_model"):
ESC50_brain.hparams.embedding_model.to(ESC50_brain.device)

if "pretrained_encoder" in hparams and hparams["use_pretrained"]:
run_on_main(hparams["pretrained_encoder"].collect_files)
hparams["pretrained_encoder"].load_collected()

ESC50_brain.fit(
epoch_counter=ESC50_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["dataloader_options"],
valid_loader_kwargs=hparams["dataloader_options"],
)
if not hparams["test_only"]:
ESC50_brain.fit(
epoch_counter=ESC50_brain.hparams.epoch_counter,
train_set=datasets["train"],
valid_set=datasets["valid"],
train_loader_kwargs=hparams["dataloader_options"],
valid_loader_kwargs=hparams["dataloader_options"],
)

# Load the best checkpoint for evaluation
test_stats = ESC50_brain.evaluate(
1 change: 1 addition & 0 deletions recipes/ESC50/classification/wham_prepare.py