In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import random
import typing as tp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
RANDOM_SEED = 364298472
TRAIN = 0.8
VALIDATION = 0.1

dataset = torch.load("audio_dataset/codes/aug_codes.pt")
train_len = int(TRAIN * len(dataset))
val_len = int(VALIDATION * len(dataset))

random.seed(RANDOM_SEED)
random.shuffle(dataset)

train_codes, val_codes, test_codes = (
    dataset[:train_len],
    dataset[train_len : train_len + val_len],
    dataset[train_len + val_len :],
)

# FOR TESTING ONLY
train_codes = train_codes

len(train_codes), len(val_codes), len(test_codes)

(1368, 171, 171)

In [3]:
class SequenceDataset(TensorDataset):
    def __init__(
        self,
        data: tp.List[tp.Tuple[torch.Tensor, torch.Tensor]],
        device: torch.device,
        seq_len: int = 1500,
        stride: int = 750,
    ):
        src = []
        tgt = []
        for index, (backing, lead) in enumerate(data):
            if backing.shape[-1] < seq_len:
                print(f"Index {index} has seq_len {backing.shape[-1]}, skipping")
                continue
            src.append(backing.unfold(-1, seq_len, stride).transpose(0, 1))
            tgt.append(lead.unfold(-1, seq_len, stride).transpose(0, 1))
        src = torch.concat(src)
        tgt = torch.concat(tgt)
        src = torch.vmap(self.add_delay_interleaving)(src).to(device)
        tgt = torch.vmap(self.add_delay_interleaving)(tgt).to(device)
        return super().__init__(src, tgt)

    @staticmethod
    def add_delay_interleaving(
        streams: torch.Tensor, padding_idx: int = 2048
    ) -> torch.Tensor:
        num_streams = len(streams)
        new_streams = []
        for index, stream in enumerate(streams):
            new_streams.append(
                F.pad(stream, (index + 1, num_streams - index), value=padding_idx)
            )
        return torch.stack(new_streams)

    @staticmethod
    def remove_delay_interleaving(streams: torch.Tensor) -> torch.Tensor:
        num_streams = len(streams)
        stream_length = streams.shape[-1]
        new_streams = []
        for index, stream in enumerate(streams):
            new_streams.append(
                torch.narrow(
                    stream, -1, 1 + index, stream_length - (num_streams - 1) - 2
                )
            )
        return torch.stack(new_streams)


train_ds = SequenceDataset(train_codes, device=device)
val_ds = SequenceDataset(val_codes, device=device)
test_ds = SequenceDataset(test_codes, device=device)

Index 60 has seq_len 1320, skipping
Index 89 has seq_len 1264, skipping
Index 134 has seq_len 1432, skipping
Index 147 has seq_len 1320, skipping
Index 154 has seq_len 1288, skipping
Index 208 has seq_len 1408, skipping
Index 260 has seq_len 1248, skipping
Index 310 has seq_len 1288, skipping
Index 334 has seq_len 1432, skipping
Index 431 has seq_len 1392, skipping
Index 453 has seq_len 1392, skipping
Index 463 has seq_len 1488, skipping
Index 474 has seq_len 1320, skipping
Index 496 has seq_len 1248, skipping
Index 498 has seq_len 1368, skipping
Index 538 has seq_len 1496, skipping
Index 565 has seq_len 1304, skipping
Index 611 has seq_len 1488, skipping
Index 703 has seq_len 1264, skipping
Index 729 has seq_len 1248, skipping
Index 796 has seq_len 1384, skipping
Index 860 has seq_len 1304, skipping
Index 866 has seq_len 1304, skipping
Index 942 has seq_len 1376, skipping
Index 981 has seq_len 1376, skipping
Index 999 has seq_len 1408, skipping
Index 1026 has seq_len 1496, skipping
In

In [None]:
from model import MHAModel


# Save model state, optimizer state, and other info to a .pth file
def save_checkpoint(
    model, optimizer, epoch, loss, config, filepath="model_checkpoint/checkpoint.pth"
):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "config": config,
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(
    lr: int, optimizer=None, filepath="model_checkpoint/checkpoint.pth"
):
    checkpoint = torch.load(filepath)
    config = checkpoint["config"]
    model = MHAModel(**config)
    model.load_state_dict(checkpoint["model_state_dict"])  # Load model state
    if optimizer is None:
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(
        checkpoint["optimizer_state_dict"]
    )  # Load optimizer state
    epoch = checkpoint["epoch"]  # Get saved epoch
    loss = checkpoint["loss"]  # Get saved loss (optional)
    return model, optimizer, epoch, loss, config

In [None]:
FROM_CHECKPOINT = False
BATCH_SIZE = 2
LEARNING_RATE = 1e-2
NUM_EPOCHS = 10000
MODEL_CONFIG = dict(
    d_model=256,
    nhead=8,
    num_decoder_layers=24,
    num_encoder_layers=24,
    dim_feedforward=1024,
    device=device,
)
FILEPATH = "model_checkpoint/checkpoint.pth"

previous_epochs = 0
model_config = None

if FROM_CHECKPOINT:
    model, optimizer, previous_epochs, loss, model_config = load_checkpoint(
        lr=LEARNING_RATE, filepath=FILEPATH
    )
else:
    model = MHAModel(**MODEL_CONFIG)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    model_config = MODEL_CONFIG
criterion = nn.CrossEntropyLoss(ignore_index=2048).to(device)


# Learning rate scheduler (optional, for better convergence)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)



In [None]:
# Training loop
from torch.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (src, tgt) in enumerate(train_loader):
        optimizer.zero_grad()

        with autocast(device_type="cuda"):
            output: torch.Tensor = model(src, tgt[:, :, :-1])
            # Reshaping output and target tensors for loss computation
            output = output.view(-1, model.vocab_size)
            tgt = tgt[:, :, 1:].contiguous().view(-1)
            loss = criterion(output, tgt)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()

        predicted = output.argmax(-1)
        non_pad_mask = tgt != 2048
        correct = (predicted[non_pad_mask] == tgt[non_pad_mask]).sum().item()
        total = non_pad_mask.sum().item()

        accuracy = correct / (total + 1e-8)
        print(
            f"Epoch [{epoch + 1 + previous_epochs}/{NUM_EPOCHS + previous_epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], "
            f"Loss: {loss.item():.4f}, Accuracy: {100.0 * accuracy:.2f}%, Grad Norm: {grad_norm:.5f}"
        )

    # Scheduler step (if using)
    # scheduler.step()

    # Print the loss every epoch
    if (epoch + 1) % 1 == 0:
        print(
            f"Epoch [{epoch + previous_epochs + 1}/{NUM_EPOCHS + previous_epochs}], Loss: {loss.item():.4f}"
        )

    if (epoch + 1) % 1 == 0:
        save_checkpoint(
            model,
            optimizer,
            epoch + previous_epochs + 1,
            loss,
            model_config,
            filepath=FILEPATH,
        )


Epoch [45/10044], Batch [1/7], Loss: 3.3504, Accuracy: 35.77%, Grad Norm: 0.50841
Epoch [45/10044], Batch [2/7], Loss: 3.8057, Accuracy: 28.18%, Grad Norm: 0.88023
Epoch [45/10044], Batch [3/7], Loss: 4.2415, Accuracy: 19.26%, Grad Norm: 0.88402
Epoch [45/10044], Batch [4/7], Loss: 4.2268, Accuracy: 20.17%, Grad Norm: 0.59992
Epoch [45/10044], Batch [5/7], Loss: 3.8164, Accuracy: 26.83%, Grad Norm: 0.67382
Epoch [45/10044], Batch [6/7], Loss: 3.7653, Accuracy: 26.77%, Grad Norm: 0.89185
Epoch [45/10044], Batch [7/7], Loss: 3.4678, Accuracy: 31.58%, Grad Norm: 0.77221
Epoch [45/10044], Loss: 3.4678
Epoch [46/10044], Batch [1/7], Loss: 3.3603, Accuracy: 36.10%, Grad Norm: 0.76256
Epoch [46/10044], Batch [2/7], Loss: 3.7768, Accuracy: 28.57%, Grad Norm: 0.87446
Epoch [46/10044], Batch [3/7], Loss: 4.2225, Accuracy: 19.14%, Grad Norm: 0.84858
Epoch [46/10044], Batch [4/7], Loss: 4.2336, Accuracy: 19.58%, Grad Norm: 0.71968
Epoch [46/10044], Batch [5/7], Loss: 3.7969, Accuracy: 26.42%, Grad

KeyboardInterrupt: 

In [None]:
from transformers import EncodecModel
from audio_tokenization import convert_tensor_to_wav, convert_wav_to_tensor
import torchaudio

train_codes = train_codes
encodec = EncodecModel.from_pretrained("facebook/encodec_32khz")

i = 150

backing_wav = convert_tensor_to_wav(encodec, train_codes[i][0][:, :1500])
torchaudio.save(f"tmp/{i}_backing.wav", backing_wav, sample_rate=32000)

lead_wav = convert_tensor_to_wav(encodec, train_codes[i][1][:, :1500])
torchaudio.save(f"tmp/{i}_lead.wav", lead_wav, sample_rate=32000)

model, _, _, _, _ = load_checkpoint(
    lr=1e-5, filepath="model_checkpoint/float32/mini_256d_8h_24l_1024ff.pth"
)
src = model.add_delay_interleaving(train_codes[i][0][:, :1500].to(device)).unsqueeze(0)
print(src.shape)
model.eval()
model_out = model.generate(src)
model_out = model.remove_delay_interleaving(model_out.squeeze())
model_out_wav = convert_tensor_to_wav(encodec, model_out.cpu().squeeze())
torchaudio.save(f"tmp/{i}_model.wav", model_out_wav, sample_rate=32000)
torchaudio.save(
    f"tmp/{i}_model_combined.wav", model_out_wav + backing_wav, sample_rate=32000
)


  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


torch.Size([1, 4, 1505])


TypeError: get_save_func.<locals>.save() missing 1 required positional argument: 'sample_rate'