Support multiple layer extraction
leo19941227 committed Jun 14, 2021
2 parents 6531904 + 7246614 commit 39b4087
Showing 59 changed files with 1,047 additions and 1,738 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -52,14 +52,20 @@ If you find this toolkit helpful to your research, please do consider to cite [o
* Install **sox** on your OS
* Install generally used packages for *pretrain*, *upstream* and *downstream*:

-```
+```bash
git clone https://github.com/s3prl/s3prl.git
cd s3prl/
pip install -r requirements.txt
cd ../

git clone https://github.com/pytorch/fairseq.git
cd fairseq/

+# The version used by the repo maintainer currently.
+# Please do not use the stable version 0.10.2, as it
+# contains known bugs in wav2vec2 inference and ASR decoding
+git checkout 8df9e3a4

pip install -e ./
cd ../
```
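
With both packages installed, the headline feature of this commit (extracting representations from multiple upstream layers) can be exercised roughly as below. This is a minimal sketch, not the repo's documented usage: it assumes the s3prl torch.hub entry point and an upstream output dict that exposes a "hidden_states" list with one tensor per layer, which may differ from the exact interface at this commit.

```python
import torch

# Minimal sketch (assumptions: s3prl's hubconf exposes 'wav2vec2', and the
# upstream returns a dict whose "hidden_states" entry is a list of per-layer
# features). Not guaranteed to match the exact interface at this commit.
upstream = torch.hub.load('s3prl/s3prl', 'wav2vec2')
upstream.eval()

# s3prl upstreams take a list of variable-length 16 kHz waveforms.
wavs = [torch.randn(16000 * 2), torch.randn(16000 * 3)]
with torch.no_grad():
    outputs = upstream(wavs)

hidden_states = outputs["hidden_states"]  # one (batch, time, dim) tensor per layer
print(len(hidden_states), hidden_states[-1].shape)
```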
3 changes: 3 additions & 0 deletions downstream/emotion/expert.py
@@ -41,7 +41,10 @@ def __init__(self, upstream_dim, downstream_expert, downstream_variant, expdir,
dataset = IEMOCAPDataset(DATA_ROOT, train_path, self.datarc['pre_load'])
trainlen = int((1 - self.datarc['valid_ratio']) * len(dataset))
lengths = [trainlen, len(dataset) - trainlen]

+torch.manual_seed(0)
self.train_dataset, self.dev_dataset = random_split(dataset, lengths)

self.test_dataset = IEMOCAPDataset(DATA_ROOT, test_path, self.datarc['pre_load'])

model_cls = eval(self.modelrc['select'])
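
For context: the added torch.manual_seed(0) pins the IEMOCAP train/dev random_split so every run sees the same partition. An alternative (illustrative only, not what this commit does) is to hand random_split a dedicated generator, which keeps the split reproducible without reseeding the global RNG:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Illustrative sketch with a toy dataset: a dedicated generator fixes the
# split without touching global RNG state (PyTorch >= 1.6 accepts `generator`).
dataset = TensorDataset(torch.arange(100))
trainlen = int(0.8 * len(dataset))
lengths = [trainlen, len(dataset) - trainlen]

split_a, _ = random_split(dataset, lengths, generator=torch.Generator().manual_seed(0))
split_b, _ = random_split(dataset, lengths, generator=torch.Generator().manual_seed(0))
assert split_a.indices == split_b.indices  # identical train subset every run
```

Either way, the point of the change is that the emotion downstream's dev split no longer drifts between runs.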
6 changes: 3 additions & 3 deletions downstream/example/dataset.py
@@ -6,7 +6,7 @@

SAMPLE_RATE = 16000
EXAMPLE_WAV_MIN_SEC = 5
-EXAMPLE_WAV_MAX_SEC = 15
+EXAMPLE_WAV_MAX_SEC = 20
EXAMPLE_DATASET_SIZE = 1000


@@ -15,8 +15,8 @@ def __init__(self, **kwargs):
self.class_num = 48

def __getitem__(self, idx):
-wav_sec = random.randint(EXAMPLE_WAV_MIN_SEC, EXAMPLE_WAV_MAX_SEC)
-wav = torch.randn(SAMPLE_RATE * wav_sec)
+samples = random.randint(EXAMPLE_WAV_MIN_SEC * SAMPLE_RATE, EXAMPLE_WAV_MAX_SEC * SAMPLE_RATE)
+wav = torch.randn(samples)
label = random.randint(0, self.class_num - 1)
return wav, label

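The example dataset change above draws the dummy waveform length directly in samples instead of whole seconds, so lengths are no longer multiples of SAMPLE_RATE, presumably to exercise padding and truncation paths more realistically. A small sketch of the difference, reusing the constants from the diff:

```python
import random
import torch

SAMPLE_RATE = 16000
EXAMPLE_WAV_MIN_SEC = 5
EXAMPLE_WAV_MAX_SEC = 20

# Old behaviour: lengths were exact multiples of SAMPLE_RATE (whole seconds).
old_len = SAMPLE_RATE * random.randint(EXAMPLE_WAV_MIN_SEC, EXAMPLE_WAV_MAX_SEC)
assert old_len % SAMPLE_RATE == 0

# New behaviour: any sample count between 5 s and 20 s of audio.
new_len = random.randint(EXAMPLE_WAV_MIN_SEC * SAMPLE_RATE, EXAMPLE_WAV_MAX_SEC * SAMPLE_RATE)
wav = torch.randn(new_len)
assert EXAMPLE_WAV_MIN_SEC * SAMPLE_RATE <= len(wav) <= EXAMPLE_WAV_MAX_SEC * SAMPLE_RATE
```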