In [14]:
# Rough outline of the experiment (author's notes):
#   audio -> encoder -> downsampling -> "similar" (presumably the nearest-vocab
#   lookup performed further down; confirm) -> decoder -> loss (still undecided)

# Input clip; it is also handed directly to the project's Encoder further down.
audio_path = "../data/input.wav"

# Ground truth: the following cells compute `gt` as the DAC encoder output for this clip.

In [15]:
import dac
from audiotools import AudioSignal

# Download the "16khz" DAC model; `download` returns the path that `DAC.load`
# consumes on the next line.
model_path = dac.utils.download(model_type="16khz")
model = dac.DAC.load(model_path)

In [25]:
import torchaudio
# Read the clip with torchaudio; yields the samples and their sampling rate.
waveform, sampling_rate = torchaudio.load(audio_path)

In [28]:
# Path A: feed the torchaudio waveform to DAC. unsqueeze(1) inserts an axis at
# position 1 (assumes waveform is (channels, samples) -- TODO confirm).
x = model.preprocess(waveform.unsqueeze(1), sampling_rate)
# DAC encoder output; used as the regression target in the training loop.
gt = model.encoder(x)

# 320 looks like the encoder's overall downsampling factor: in the recorded
# output, 94720 / 320 = 296 matches the last dimension of gt.
x.shape, x.shape[-1]/320, gt.shape

(torch.Size([1, 1, 94720]), 296.0, torch.Size([1, 1024, 296]))

In [20]:
# NOTE(review): the recorded output for this cell is 2-D while the preprocess
# cells report a 3-D x; it presumably stems from an earlier run (execution
# count 20) with a different x -- re-run to refresh.
x.shape

torch.Size([1, 94720])

In [21]:
# Path B: load the same clip through audiotools instead of torchaudio.
# Reuse `audio_path` (defined at the top of the notebook) rather than repeating
# the literal, so both loading paths are guaranteed to read the same file.
signal = AudioSignal(audio_path)

# Encode audio signal as one long file
# (may run out of GPU memory on long files)
signal.to(model.device)

# Same preprocess -> encoder steps as the torchaudio cell; note that this
# overwrites the notebook-level `x` and `gt`.
x = model.preprocess(signal.audio_data, signal.sample_rate)
gt = model.encoder(x)

# Sanity check: number of latent frames should equal samples / 320.
x.shape, x.shape[-1]/320, gt.shape

(torch.Size([1, 1, 94720]), 296.0, torch.Size([1, 1024, 296]))

In [22]:
# Raw sample count before preprocess. The recorded 94480 here vs. 94720 after
# preprocess suggests preprocess pads the clip up to a multiple of 320 -- confirm.
signal.audio_data.shape


torch.Size([1, 1, 94480])

# Encoder

In [3]:
from encoder import Encoder

# Project-local Encoder; note that it is called with the file path, not a tensor.
encoder = Encoder()
enc_out = encoder(audio_path)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


# Vocabulary

In [4]:
import torch

# SECURITY: torch.load may unpickle arbitrary objects (depending on the PyTorch
# version's `weights_only` default); only load checkpoints from a trusted source.
checkpoint = torch.load("model_checkpoint.pth", map_location="cpu")
# Keep the first 29 embedding rows -- presumably one per character left after
# <pad>/<bos> are dropped below; TODO confirm against the checkpoint.
embedding = checkpoint["embedding"]['weight'][:29,:]
char_to_idx = checkpoint["char_to_idx"]
# pop() mutates the checkpoint's dict in place and raises KeyError if the token is absent.
char_to_idx.pop('<pad>')
char_to_idx.pop('<bos>')
# Inverse map, used later to turn predicted indices back into characters.
idx_to_char = {v: k for k, v in char_to_idx.items()}

In [5]:
import torch
import torch.nn.functional as F

def get_closest_vocab(encoder_output, vocab_embeddings, return_indices=False):
    """
    Match every time step of the encoder output to its most similar vocabulary entry.

    Similarity is the cosine similarity: both inputs are L2-normalised along their
    embedding axis before the dot products are taken.

    Args:
        encoder_output (torch.Tensor): Shape (batch_size, embed_dim, seq).
        vocab_embeddings (torch.Tensor): Shape (vocab_size, embed_dim).
        return_indices (bool): Select the return value (see below).

    Returns:
        torch.Tensor:
            return_indices=True  -> winning vocabulary indices, shape (batch_size, seq).
            return_indices=False -> the matching rows of the *L2-normalised* vocabulary
                                    table, shape (batch_size, embed_dim, seq).

    Raises:
        ValueError: If encoder_output does not have exactly three dimensions.
    """
    # The unpacking is only here as a rank check: anything but a 3-D tensor fails.
    _n_batch, _n_dim, _n_steps = encoder_output.shape

    unit_frames = F.normalize(encoder_output, dim=1)
    unit_vocab = F.normalize(vocab_embeddings, dim=-1)

    # (b, e, s) x (v, e) -> (b, s, v): one cosine score per frame and vocab entry.
    scores = torch.einsum('bes,ve->bsv', unit_frames, unit_vocab)
    nearest = scores.argmax(dim=-1)  # (batch_size, seq)

    if return_indices:
        return nearest

    # Gather gives (batch_size, seq, embed_dim); put the embedding axis back in the middle.
    gathered = unit_vocab[nearest]
    return gathered.permute(0, 2, 1)


# return_indices defaults to False, so this displays the matched embeddings, not indices.
get_closest_vocab(enc_out, embedding)

tensor([[[ 0.0066,  0.0066,  0.0066,  ..., -0.0637,  0.0066,  0.0066],
         [-0.0197, -0.0197, -0.0197,  ..., -0.0248, -0.0197, -0.0197],
         [ 0.0143,  0.0143,  0.0143,  ...,  0.0267,  0.0143,  0.0143],
         ...,
         [ 0.0422,  0.0422,  0.0422,  ..., -0.0130,  0.0422,  0.0422],
         [-0.0515, -0.0515, -0.0515,  ..., -0.0334, -0.0515, -0.0515],
         [-0.0329, -0.0329, -0.0329,  ..., -0.0109, -0.0329, -0.0329]]])

In [6]:
def merge_similar_indices(indices):
    """
    Collapse runs of consecutive equal entries in every row of a 2-D index array.

    Args:
        indices: 2-D array-like exposing ``.shape`` and ``[row, col].item()``
            (e.g. a NumPy array or a torch tensor).

    Returns:
        list[list]: One list per row in which each run of equal neighbouring values
            appears once. Equal values separated by a different value are all kept.
    """
    n_rows, n_steps = indices.shape
    collapsed = []
    for row in range(n_rows):
        # Pull the row out as plain Python scalars first.
        values = [indices[row, step].item() for step in range(n_steps)]
        # Keep the first element and every element that differs from its predecessor.
        collapsed.append(
            [v for pos, v in enumerate(values) if pos == 0 or v != values[pos - 1]]
        )
    return collapsed

# Nearest-vocab index per frame, moved to NumPy for the pure-Python merge step.
ind = get_closest_vocab(enc_out, embedding, return_indices=True).detach().cpu().numpy()
# Collapse repeated frames of the first batch item, map indices to characters,
# and render the "<sil>" token as a space.
"".join([idx_to_char[i] for i in merge_similar_indices(ind)[0]]).replace("<sil>", " ")

"LBKIBDQBHCBD'KHIGQEKLBHOE 'BUHBXGBGHTJUFALBOCKMZHOB'B'NBLHBPHTBVRPBDWBIGBDYPSOCXTZBKBITBMWBAIDBDXSYUDXMXRGDMCGOVE'CB'BCBCBCBCYL"

# Decoder

In [7]:
from decoder import decoder, calculate_params
# Shape reminder (sentence presumably pasted from the torch.nn.Conv1d docs): a 1-D
# convolution maps an input of size (N, C_in, L) to an output of size (N, C_out, L_out).

# NOTE: this rebinds `decoder` from the imported callable to the object it returns.
decoder = decoder(inp_dim=768, hidden_dim=256, out_dim=1024, num_blocks=1, kernel_size=3)
# calculate_params prints the totals itself; the recorded (None, None) shows it returns None.
calculate_params(decoder), calculate_params(encoder)

Total params: 0.854784 in millions
Total params: 94.962304 in millions


(None, None)

# Training loop

In [8]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Select the device and move both modules onto it *before* building the
# optimizer, as the torch.optim documentation recommends, so the optimizer is
# constructed over parameters that already live where training happens.
# NOTE(review): "cuda:2" is hard-coded; this fails on hosts with fewer than three GPUs.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
decoder = decoder.to(device)
encoder = encoder.to(device)

criterion = nn.MSELoss()
# A single Adam instance updates the decoder and the encoder jointly.
optimizer = optim.Adam(list(decoder.parameters()) + list(encoder.parameters()), lr=0.001)

decoder.train()  # Set the model to training mode
# The trailing semicolon suppresses the module repr IPython would otherwise display.
encoder.train();

1

In [9]:
# Move the fixed tensors to the training device. detach() returns tensors cut off
# from any autograd graph, so they act as constants during training.
gt = gt.to(device).detach()
embedding = embedding.to(device).detach()

In [10]:
# Training loop: snap the encoder output to the nearest vocabulary embedding and
# train encoder + decoder so that the decoder output matches the DAC latents `gt`.
num_epochs = 1000
for epoch in range(num_epochs):
    # Zero the gradients
    optimizer.zero_grad()

    enc_out = encoder(audio_path)
    enc_out = enc_out.to(device)

    # Forward pass with a straight-through estimator: the forward value equals the
    # closest vocab embedding, while the detached difference contributes no
    # gradient, so gradients reach enc_out as if no quantisation had happened.
    enc_out = enc_out + (get_closest_vocab(enc_out, embedding) - enc_out).detach()
    output = decoder(enc_out)  # presumably (batch, out_dim, time) -- it is compared to gt below

    # Ensure same sequence length for ground truth and output. The cropped target
    # is kept in its own name so the notebook-level `gt` is never overwritten.
    seq_len = min(gt.shape[-1], output.shape[-1])
    target = gt[:, :, :seq_len]
    output = output[:, :, :seq_len]
    # Compute the loss
    loss = criterion(output, target)

    # Backward pass
    loss.backward()
    # Update the parameters
    optimizer.step()

    torch.cuda.empty_cache()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Epoch 1/1000, Loss: 7.196202754974365
Epoch 2/1000, Loss: 6.967313289642334
Epoch 3/1000, Loss: 6.7494401931762695
Epoch 4/1000, Loss: 6.53893518447876
Epoch 5/1000, Loss: 6.330143928527832
Epoch 6/1000, Loss: 6.178974151611328
Epoch 7/1000, Loss: 5.982743740081787
Epoch 8/1000, Loss: 5.8326802253723145
Epoch 9/1000, Loss: 5.702139377593994
Epoch 10/1000, Loss: 5.583150863647461
Epoch 11/1000, Loss: 5.484498977661133
Epoch 12/1000, Loss: 5.419800758361816
Epoch 13/1000, Loss: 5.319036483764648
Epoch 14/1000, Loss: 5.253880500793457
Epoch 15/1000, Loss: 5.172684192657471
Epoch 16/1000, Loss: 5.118607521057129
Epoch 17/1000, Loss: 5.115623474121094
Epoch 18/1000, Loss: 5.023965358734131
Epoch 19/1000, Loss: 5.007221698760986
Epoch 20/1000, Loss: 4.947972297668457
Epoch 21/1000, Loss: 4.883514404296875
Epoch 22/1000, Loss: 4.87011194229126
Epoch 23/1000, Loss: 4.810446739196777
Epoch 24/1000, Loss: 4.7577362060546875
Epoch 25/1000, Loss: 4.740022659301758
Epoch 26/1000, Loss: 4.6945080757

In [11]:
# Shape of the last decoder output, after the length crop applied in the training loop.
output.shape

torch.Size([1, 1024, 295])

In [12]:
# `enc_out` is whatever the training loop assigned last (the straight-through
# tensor); look up its nearest vocabulary index per frame.
ind = get_closest_vocab(enc_out, embedding, return_indices=True).detach().cpu().numpy()
# Same decoding as before training: merge repeats, map to characters, "<sil>" -> space.
"".join([idx_to_char[i] for i in merge_similar_indices(ind)[0]]).replace("<sil>", " ")

"BQHARAD'D'QHUDXBKHAEX'B'BDBDUKSQHXOXGQPHMAD'D'FEB'BSBSBHYPQHARAWSIGQBPYHTRORCGXTGDQFKITGSWSIDKFSQXOCG'YHGQXCRG'E'B'B'BGB"

In [13]:
# Reload the DAC model from the same weights path used earlier.
model = dac.DAC.load(model_path)
# Decode audio signal: feed the trained decoder's output (moved to CPU) through
# the DAC decoder, then detach and convert to NumPy.
y = model.decode(output.cpu()).detach().numpy()

# play the numpy array as audio using ipython.display.Audio
import IPython.display as ipd
# First batch item, first channel. NOTE(review): rate is hard-coded to 16000,
# presumably matching the "16khz" model downloaded earlier -- confirm.
ipd.Audio(y[0,0,:], rate=16000)  # load a NumPy array

