Skip to content

wsntxxn/AudioCaption

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

87 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Audio Captioning Recipe

This repository provides a simple and easy-to-use recipe for audio captioning: data pre-processing, training, evaluation and inference.

Since we refactor our code several times, codebase of previous challenges are maintained in separate branches: DCASE2021 and DCASE2022. Please check these branches for reference to challenge submissions.

Quick usage with Hugging Face🤗

For quick usage, we provide an inference interface with Hugging Face. You can use it easily. Only some necessary repositories need to be installed:

pip install numpy torch torchaudio einops transformers efficientnet_pytorch

Lightweight EffB2-Transformer model

For standard captioning, we recommend our latest lightweight model for fast inference:

import torch
from transformers import AutoModel, PreTrainedTokenizerFast
import torchaudio


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use the model trained on AudioCaps
model = AutoModel.from_pretrained(
    "wsntxxn/effb2-trm-audio-captioning",
    trust_remote_code=True
).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)

# inference on a single audio clip
wav, sr = torchaudio.load("/path/to/file.wav")
wav = torchaudio.functional.resample(wav, sr, model.config.sample_rate)
if wav.size(0) > 1:
    wav = wav.mean(0).unsqueeze(0)

with torch.no_grad():
    word_idxs = model(
        audio=wav,
        audio_length=[wav.size(1)],
    )

caption = tokenizer.decode(word_idxs[0], skip_special_tokens=True)
print(caption)

# inference on a batch
wav1, sr1 = torchaudio.load("/path/to/file1.wav")
wav1 = torchaudio.functional.resample(wav1, sr1, model.config.sample_rate)
wav1 = wav1.mean(0) if wav1.size(0) > 1 else wav1[0]

wav2, sr2 = torchaudio.load("/path/to/file2.wav")
wav2 = torchaudio.functional.resample(wav2, sr2, model.config.sample_rate)
wav2 = wav2.mean(0) if wav2.size(0) > 1 else wav2[0]

wav_batch = torch.nn.utils.rnn.pad_sequence([wav1, wav2], batch_first=True)

with torch.no_grad():
    word_idxs = model(
        audio=wav_batch,
        audio_length=[wav1.size(0), wav2.size(0)],
    )

captions = tokenizer.batch_decode(word_idxs, skip_special_tokens=True)
print(captions)

Alternatively, you can use the model trained on Clotho:

model = AutoModel.from_pretrained(
    "wsntxxn/effb2-trm-clotho-captioning",
    trust_remote_code=True
).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/clotho-simple-tokenizer"
)

Temporal-sensitive and controllable model

We also provide a temporal-enhanced captioning model for specific (simultaneous / sequential) temporal relationship description:

model = AutoModel.from_pretrained(
    "wsntxxn/cnn14rnn-tempgru-audiocaps-captioning",
    trust_remote_code=True
).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)

wav, sr = torchaudio.load("/path/to/file.wav")
wav = torchaudio.functional.resample(wav, sr, model.config.sample_rate)
if wav.size(0) > 1:
    wav = wav.mean(0).unsqueeze(0)

with torch.no_grad():
    word_idxs = model(
        audio=wav,
        audio_length=[wav.size(1)],
    )

caption = tokenizer.decode(word_idxs[0], skip_special_tokens=True)
print(caption)

You can also manually assign a temporal tag:

with torch.no_grad():
    word_idxs = model(
        audio=wav,
        audio_length=[wav.size(1)],
        temporal_tag=[2], # desribe "sequential" if there are sequential events, otherwise use the most complex relationship
    )

The temporal tag is defined as:

Temporal Tag Definition
0 Only 1 Event
1 Simultaneous Events
2 Sequential Events
3 More Complex Events

From 0 to 3, the relationship indicated by the tag becomes more and more complex. The model will infer the tag automatically. If temporal_tag is not provided as the input, the model will use the inferred tag. Otherwise, the model will try to follow the input tag if possible.

If you find the temporal model useful, please cite this paper:

@inproceedings{xie2023enhance,
    author = {Zeyu Xie and Xuenan Xu and Mengyue Wu and Kai Yu},
    title = {Enhance Temporal Relations in Audio Captioning with Sound Event Detection},
    year = 2023,
    booktitle = {Proc. INTERSPEECH},
    pages = {4179--4183},
}

The above instruction is for quick inference with pre-trained models. If you want to further develop your own captioning model, please refer to the following instructions.

Install

Checkout this repository and install the required packages:

$ git clone https://github.com/wsntxxn/AudioCaption
$ cd AudioCaption
$ pip install -r requirements.txt

Install the repository as a package:

$ pip install -e .

PTB tokenizer and SPICE evaluation also requires java, which can be installed by conda

$ conda install bioconda::java-jdk

Data pre-processing

We now support Clotho and AudioCaps. See details in data/README.md.

Training

The training configuration is written in a YAML file and passed to the training script. Examples are in eg_configs.

For example, train a model with Cnn14_Rnn encoder and Transformer decoder on Clotho:

$ python python_scripts/train_eval/run.py train \
    --config eg_configs/clotho_v2/waveform/cnn14rnn_trm.yaml

where the CNN14 is initialized by the PANNs checkpoint, which should be manually downloaded to experiments/pretrained_encoder/Cnn14_mAP=0.431.pth.

Evaluation

Assume the experiment directory is $EXP_PATH. Evaluation under the configuration in eg_configs/clotho_v2/waveform/test.yaml:

$ python python_scripts/train_eval/run.py evaluate \
    --experiment_path $EXP_PATH \
    --eval_config eg_configs/clotho_v2/waveform/test.yaml

Inference

Inference using the checkpoint $CKPT:

$ python python_scripts/inference/inference.py \
    --input input.wav 
    --output output.json
    --checkpoint $CKPT

Using off-the-shelf models

We release the models trained on Clotho and AudioCaps for easy use. They use contrastive pre-trained feature extractor:

$ mkdir pretrained_feature_extractors
$ wget https://github.com/wsntxxn/AudioCaption/releases/download/v0.0.2/contrastive_pretrain_cnn14_bertm.pth -O pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth

AudioCaps:

$ wget https://github.com/wsntxxn/AudioCaption/releases/download/v0.0.2/audiocaps_cntrstv_cnn14rnn_trm.zip
$ unzip audiocaps_cntrstv_cnn14rnn_trm.zip
$ python python_scripts/inference/inference.py \
    --input input.wav \
    --output output.json \
    --checkpoint audiocaps_cntrstv_cnn14rnn_trm/swa.pth

Clotho:

$ wget https://github.com/wsntxxn/AudioCaption/releases/download/v0.0.2/clotho_cntrstv_cnn14rnn_trm.zip
$ unzip clotho_cntrstv_cnn14rnn_trm.zip
$ python python_scripts/inference/inference.py \
    --input input.wav \
    --output output.json \
    --checkpoint clotho_cntrstv_cnn14rnn_trm/swa.pth

If you find these models useful, please cite our technical report:

@techreport{xu2022sjtu,
    author={Xu, Xuenan and Xie, Zeyu and Wu, Mengyue and Yu, Kai},
    title={The {SJTU} System for {DCASE2022} Challenge Task 6: Audio Captioning with Audio-Text Retrieval Pre-training},
    institution={DCASE2022 Challenge},
    year={2022}
}

Demo

The code to create a demo interface using gradio is provided:

$ python demo.py --ckpt $EXP_PATH/swa.pth

Then the demo is running on http://localhost:7860.

Related Papers

The following papers are related to this repository: