
🐶 Bark JAX

This project is the JAX implementation of Bark.

It is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Roadmap

  • Model architecture
    • EnCodec (encodec_24khz), implemented in encodec jax
    • 3 transformer models
      • Text to semantic tokens
      • Semantic to coarse tokens
      • Coarse to fine tokens

Install

This project requires Python 3.12 and JAX 0.4.26.

Create a venv:

python3.12 -m venv venv

Install dependencies:

TPU VM:

. venv/bin/activate
pip install -U pip
pip install -U wheel
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers
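
To confirm the environment is set up correctly, a quick sanity check like the one below should list the TPU cores on a TPU VM (on a machine without a TPU, JAX falls back to CPU devices):

import jax
import jax.numpy as jnp

print(jax.devices())           # expect TpuDevice entries on a TPU VM
print(jnp.ones((2, 2)).sum())  # a trivial computation to confirm the backend runs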

Model Architecture

Bark turns text into audio through a pipeline of three transformer models followed by an EnCodec model.
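
Put together, generation runs the three transformers in sequence and then decodes with EnCodec. The sketch below only illustrates that flow; the four stage functions are passed in as callables and their names are hypothetical, not this project's API:

def generate_audio(text, text_to_semantic, semantic_to_coarse, coarse_to_fine, encodec_decode):
    # Hypothetical composition of the four stages named in this README.
    semantic_tokens = text_to_semantic(text)              # transformer 1: text -> semantic tokens
    coarse_tokens = semantic_to_coarse(semantic_tokens)   # transformer 2: semantic -> coarse codebooks
    fine_tokens = coarse_to_fine(coarse_tokens)           # transformer 3: coarse -> fine codebooks
    return encodec_decode(fine_tokens)                    # EnCodec decoder: codebooks -> waveform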

Transformer model: Text to semantic tokens

BarkSemanticModel(
  (input_embeds_layer): Embedding(129600, 1024)
  (position_embeds_layer): Embedding(1024, 1024)
  (drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-23): 24 x BarkBlock(
      (layernorm_1): BarkLayerNorm()
      (layernorm_2): BarkLayerNorm()
      (attn): BarkSelfAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (att_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (mlp): BarkMLP(
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
        (gelu): GELU(approximate='none')
      )
    )
  )
  (layernorm_final): BarkLayerNorm()
  (lm_head): Linear(in_features=1024, out_features=10048, bias=False)
)
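
The printouts in this section show the reference PyTorch module structure (as printed by the Hugging Face transformers implementation). As a rough illustration of how one such block translates to JAX, here is a minimal sketch of a single pre-norm BarkBlock; the parameter names, the 16-head split, and the bias-free layer norm are assumptions for illustration, not this repository's actual code:

import jax.numpy as jnp
from jax.nn import gelu, softmax

def layer_norm(x, scale, eps=1e-5):
    # Assumed here: BarkLayerNorm in the semantic/coarse models is a LayerNorm without bias.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale

def bark_block(params, x, n_heads=16):
    # x: (seq_len, 1024). Pre-norm causal self-attention + GELU MLP, mirroring the
    # BarkBlock printout above (fused QKV projection 1024 -> 3072, MLP 1024 -> 4096 -> 1024).
    # Dropout is p=0.0 in the printout, so it is omitted here.
    seq_len, d_model = x.shape
    head_dim = d_model // n_heads

    h = layer_norm(x, params["ln1_scale"])
    q, k, v = jnp.split(h @ params["att_proj"], 3, axis=-1)
    q = q.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)
    k = k.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)
    v = v.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)
    scores = (q @ k.transpose(0, 2, 1)) / jnp.sqrt(head_dim)     # (n_heads, seq, seq)
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # causal mask
    scores = jnp.where(causal, scores, -jnp.inf)
    attn = softmax(scores, axis=-1) @ v                          # (n_heads, seq, head_dim)
    attn = attn.transpose(1, 0, 2).reshape(seq_len, d_model)
    x = x + attn @ params["out_proj"]

    h = layer_norm(x, params["ln2_scale"])
    return x + gelu(h @ params["mlp_in"]) @ params["mlp_out"]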

Transformer model: Semantic to coarse tokens

BarkCoarseModel(
  (input_embeds_layer): Embedding(12096, 1024)
  (position_embeds_layer): Embedding(1024, 1024)
  (drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-23): 24 x BarkBlock(
      (layernorm_1): BarkLayerNorm()
      (layernorm_2): BarkLayerNorm()
      (attn): BarkSelfAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (att_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (mlp): BarkMLP(
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
        (gelu): GELU(approximate='none')
      )
    )
  )
  (layernorm_final): BarkLayerNorm()
  (lm_head): Linear(in_features=1024, out_features=12096, bias=False)
)
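
The coarse model reuses exactly the same 24-layer, 1024-dimensional block structure; only the embedding table and output head change. Read off the printouts in this section (the fine model is shown next), the vocabulary sizes are:

# Vocabulary sizes taken directly from the module printouts in this README.
BARK_VOCAB_SIZES = {
    "semantic": {"input_vocab": 129_600, "output_vocab": 10_048},
    "coarse":   {"input_vocab": 12_096,  "output_vocab": 12_096},
    "fine":     {"input_vocab": 1_056,   "output_vocab": 1_056},  # 8 codebook embeddings, 7 heads
}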

Transformer model: Coarse to fine tokens

BarkFineModel(
  (input_embeds_layers): ModuleList(
    (0-7): 8 x Embedding(1056, 1024)
  )
  (position_embeds_layer): Embedding(1024, 1024)
  (drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-23): 24 x BarkBlock(
      (layernorm_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (layernorm_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): BarkSelfAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (att_proj): Linear(in_features=1024, out_features=3072, bias=False)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (mlp): BarkMLP(
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
        (gelu): GELU(approximate='none')
      )
    )
  )
  (layernorm_final): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (lm_heads): ModuleList(
    (0-6): 7 x Linear(in_features=1024, out_features=1056, bias=False)
  )
)
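
The fine model differs from the first two: its attention is non-causal, its input embedding is the sum of the per-codebook embeddings available so far, and it has one output head per predicted fine codebook (the first codebook has no head, hence 7 heads for 8 codebooks). A minimal sketch of that embedding/head wiring, with hypothetical parameter names and a simplified view of which codebooks are summed:

import jax.numpy as jnp

def fine_logits(params, codes, codebook_idx, transformer):
    # codes: (seq_len, 8) integer codebook ids; predict codebook `codebook_idx`.
    # Assumed layout: params["embed_tables"] is a list of the 8 (1056, 1024) embedding
    # matrices and params["lm_heads"] the 7 (1024, 1056) projections from the printout.
    # Summing the embeddings of the codebooks seen so far is a simplification of the
    # reference implementation's exact slicing.
    embeds = sum(params["embed_tables"][i][codes[:, i]] for i in range(codebook_idx))
    h = embeds + params["position_embeds"][jnp.arange(codes.shape[0])]
    h = transformer(h)                                  # 24 non-causal BarkBlocks + final LayerNorm
    return h @ params["lm_heads"][codebook_idx - 1]     # (seq_len, 1056) logits for this codebook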

EnCodec

For details, see the EnCodec JAX implementation, the official EnCodec repository, and the paper High Fidelity Neural Audio Compression.
