In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import random
import typing as tp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
RANDOM_SEED = 364298472
TRAIN = 0.8
VALIDATION = 0.1

dataset = torch.load("audio_dataset/codes/aug_codes.pt")
train_len = int(TRAIN * len(dataset))
val_len = int(VALIDATION * len(dataset))

random.seed(RANDOM_SEED)
random.shuffle(dataset)

train_codes, val_codes, test_codes = (
    dataset[:train_len],
    dataset[train_len : train_len + val_len],
    dataset[train_len + val_len:],
)

# FOR TESTING ONLY
train_codes = train_codes

len(train_codes), len(val_codes), len(test_codes)

(1368, 171, 171)

In [3]:
class SequenceDataset(TensorDataset):
    def __init__(
        self,
        data: tp.List[tp.Tuple[torch.Tensor, torch.Tensor]],
        device: torch.device,
        seq_len: int = 1500,
        stride: int = 750,
    ):
        src = []
        tgt = []
        for index, (backing, lead) in enumerate(data):
            if backing.shape[-1] < seq_len:
                print(f"Index {index} has seq_len {backing.shape[-1]}, skipping")
                continue
            src.append(backing.unfold(-1, seq_len, stride).transpose(0, 1))
            tgt.append(lead.unfold(-1, seq_len, stride).transpose(0, 1))
        src = torch.concat(src)
        tgt = torch.concat(tgt)
        src = torch.vmap(self.add_delay_interleaving)(src).to(device)
        tgt = torch.vmap(self.add_delay_interleaving)(tgt).to(device)
        return super().__init__(src, tgt)

    @staticmethod
    def add_delay_interleaving(
        streams: torch.Tensor, padding_idx: int = 2048
    ) -> torch.Tensor:
        num_streams = len(streams)
        new_streams = []
        for index, stream in enumerate(streams):
            new_streams.append(
                F.pad(stream, (index + 1, num_streams - index), value=padding_idx)
            )
        return torch.stack(new_streams)

    @staticmethod
    def remove_delay_interleaving(streams: torch.Tensor) -> torch.Tensor:
        num_streams = len(streams)
        stream_length = streams.shape[-1]
        new_streams = []
        for index, stream in enumerate(streams):
            new_streams.append(
                torch.narrow(
                    stream, -1, 1 + index, stream_length - (num_streams - 1) - 2
                )
            )
        return torch.stack(new_streams)


train_ds = SequenceDataset(train_codes, device=device)
val_ds = SequenceDataset(val_codes, device=device)
test_ds = SequenceDataset(test_codes, device=device)

Index 60 has seq_len 1320, skipping
Index 89 has seq_len 1264, skipping
Index 134 has seq_len 1432, skipping
Index 147 has seq_len 1320, skipping
Index 154 has seq_len 1288, skipping
Index 208 has seq_len 1408, skipping
Index 260 has seq_len 1248, skipping
Index 310 has seq_len 1288, skipping
Index 334 has seq_len 1432, skipping
Index 431 has seq_len 1392, skipping
Index 453 has seq_len 1392, skipping
Index 463 has seq_len 1488, skipping
Index 474 has seq_len 1320, skipping
Index 496 has seq_len 1248, skipping
Index 498 has seq_len 1368, skipping
Index 538 has seq_len 1496, skipping
Index 565 has seq_len 1304, skipping
Index 611 has seq_len 1488, skipping
Index 703 has seq_len 1264, skipping
Index 729 has seq_len 1248, skipping
Index 796 has seq_len 1384, skipping
Index 860 has seq_len 1304, skipping
Index 866 has seq_len 1304, skipping
Index 942 has seq_len 1376, skipping
Index 981 has seq_len 1376, skipping
Index 999 has seq_len 1408, skipping
Index 1026 has seq_len 1496, skipping
In

In [4]:
from model import MHAModel

# Save model state, optimizer state, and other info to a .pth file
def save_checkpoint(
    model, optimizer, epoch, loss, config, filepath="model_checkpoint/checkpoint.pth"
):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "config": config,
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(lr: int, optimizer=None, filepath="model_checkpoint/checkpoint.pth"):
    checkpoint = torch.load(filepath)
    config = checkpoint["config"]
    model = MHAModel(**config)
    model.load_state_dict(checkpoint["model_state_dict"])  # Load model state
    if optimizer is None:
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(
        checkpoint["optimizer_state_dict"]
    )  # Load optimizer state
    epoch = checkpoint["epoch"]  # Get saved epoch
    loss = checkpoint["loss"]  # Get saved loss (optional)
    return model, optimizer, epoch, loss, config

In [5]:
FROM_CHECKPOINT = True
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
MODEL_CONFIG = dict(
    d_model=128,
    nhead=8,
    num_decoder_layers=8,
    num_encoder_layers=8,
    dim_feedforward=512,
    device=device,
)

previous_epochs = 0
model_config = None

if FROM_CHECKPOINT:
    model, optimizer, previous_epochs, loss, model_config = load_checkpoint(lr=LEARNING_RATE)
else:
    model = MHAModel(**MODEL_CONFIG)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    model_config = MODEL_CONFIG
criterion = nn.CrossEntropyLoss(ignore_index=2048).to(device)


# Learning rate scheduler (optional, for better convergence)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)



In [6]:
# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (src, tgt) in enumerate(train_loader):
        optimizer.zero_grad()

        output: torch.Tensor = model(src, tgt[:, :, :-1])

        # Reshaping output and target tensors for loss computation
        output = output.view(-1, model.vocab_size)
        tgt = tgt[:, :, 1:].contiguous().view(-1)
        loss = criterion(output, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()
        
        predicted = output.argmax(-1)
        non_pad_mask = tgt != 2048
        correct = (predicted[non_pad_mask] == tgt[non_pad_mask]).sum().item()
        total = non_pad_mask.sum().item()
        
        accuracy = correct / (total + 1e-8)
        print(
            f"Epoch [{epoch + 1 + previous_epochs}/{NUM_EPOCHS + previous_epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], "
            f"Loss: {loss.item():.4f}, Accuracy: {100.0 * accuracy:.2f}%"
        )

    # Scheduler step (if using)
    # scheduler.step()

    # Print the loss every epoch
    if (epoch + 1) % 1 == 0:
        print(f"Epoch [{epoch + previous_epochs + 1}/{NUM_EPOCHS + previous_epochs}], Loss: {loss.item():.4f}")

    if (epoch + 1) % 1 == 0:
        save_checkpoint(model, optimizer, epoch + previous_epochs + 1, loss, model_config)


Epoch [359/368], Batch [1/249], Loss: 3.4811, Accuracy: 28.49%
Epoch [359/368], Batch [2/249], Loss: 4.2160, Accuracy: 19.97%
Epoch [359/368], Batch [3/249], Loss: 4.5531, Accuracy: 15.07%
Epoch [359/368], Batch [4/249], Loss: 4.2650, Accuracy: 18.81%
Epoch [359/368], Batch [5/249], Loss: 4.3478, Accuracy: 18.24%
Epoch [359/368], Batch [6/249], Loss: 3.9985, Accuracy: 21.84%
Epoch [359/368], Batch [7/249], Loss: 4.2303, Accuracy: 19.98%
Epoch [359/368], Batch [8/249], Loss: 3.6564, Accuracy: 30.61%
Epoch [359/368], Batch [9/249], Loss: 3.7484, Accuracy: 26.22%
Epoch [359/368], Batch [10/249], Loss: 3.5105, Accuracy: 28.09%
Epoch [359/368], Batch [11/249], Loss: 4.1679, Accuracy: 19.06%
Epoch [359/368], Batch [12/249], Loss: 3.7263, Accuracy: 25.02%
Epoch [359/368], Batch [13/249], Loss: 4.0366, Accuracy: 20.09%
Epoch [359/368], Batch [14/249], Loss: 3.9178, Accuracy: 22.78%
Epoch [359/368], Batch [15/249], Loss: 3.8929, Accuracy: 22.79%
Epoch [359/368], Batch [16/249], Loss: 4.1418, Ac

In [None]:
from transformers import EncodecModel
from audio_tokenization import convert_tensor_to_wav
import torchaudio

encodec = EncodecModel.from_pretrained("facebook/encodec_32khz")
for i, (backing, lead) in enumerate(train_codes):
    backing_wav = convert_tensor_to_wav(encodec, backing)
    torchaudio.save(f"tmp/{i}_backing.wav", backing_wav, sample_rate=32000)
    lead_wav = convert_tensor_to_wav(encodec, lead)
    torchaudio.save(f"tmp/{i}_lead.wav", lead_wav, sample_rate=32000)
    torchaudio.save(
        f"tmp/{i}.wav", torch.clip(backing_wav + lead_wav, -1.0, 1.0), sample_rate=32000
    )


  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [35]:
torchaudio.save(
        f"tmp/{i}.wav", torch.clip(backing_wav + lead_wav, -1.0, 1.0), sample_rate=32000
    )

In [6]:
src = model.add_delay_interleaving(train_codes[0][0][:, :1400].to(device)).unsqueeze(0)
print(src.shape)
B, num_streams, T = src.shape
first_token = torch.Tensor([[2048 for _ in range(4)]])
tgt = torch.zeros_like(src)
tgt[:, :, 0] = first_token
index = 1

while index < T:
    print(src.shape, tgt[:, :, :index].shape)
    tokens = model._get_next_token(
        src,
        tgt[:, :, :index].detach(),
    )
    print(tokens.shape)
    tgt[:, :, index] = tokens.squeeze(-1)
    index += 1

torch.Size([1, 4, 1405])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 1])
src embed: torch.Size([1, 1405, 128])
tgt emb: torch.Size([1, 1, 128])
transf out: torch.Size([1, 1, 128])
torch.Size([1, 4, 1])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 2])
src embed: torch.Size([1, 1405, 128])
tgt emb: torch.Size([1, 2, 128])
transf out: torch.Size([1, 2, 128])
torch.Size([1, 4, 1])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 3])
src embed: torch.Size([1, 1405, 128])
tgt emb: torch.Size([1, 3, 128])
transf out: torch.Size([1, 3, 128])
torch.Size([1, 4, 1])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 4])
src embed: torch.Size([1, 1405, 128])
tgt emb: torch.Size([1, 4, 128])
transf out: torch.Size([1, 4, 128])
torch.Size([1, 4, 1])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 5])
src embed: torch.Size([1, 1405, 128])
tgt emb: torch.Size([1, 5, 128])
transf out: torch.Size([1, 5, 128])
torch.Size([1, 4, 1])
torch.Size([1, 4, 1405]) torch.Size([1, 4, 6])
src embed: torch.Size([1, 1405, 128])
tgt emb: t

In [None]:
model.pe.pe.shape

torch.Size([1, 1600, 128])

In [23]:
model.add_delay_interleaving(train_codes[0][0][:, :1400].to(device)).shape

torch.Size([4, 1405])