In [1]:
# Same model but trains without using the mask values
# The mask is used to select what to train on
# Instead, we want to train on everything

In [None]:
import torchaudio
from audiocraft.models import MusicGen
from transformers import get_scheduler
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
import random
import wandb

from torch.utils.data import Dataset

from audiocraft.modules.conditioners import ClassifierFreeGuidanceDropout

import os

In [2]:
# Each element has
#    Left ear path
#    Right ear path
#    Label
#    Original path

# Code from: https://github.com/chavinlo/musicgen_trainer/blob/main/train.py
# Create a class to hold your data set
# Modified since we are not loading the same way as the original code. 
class AudioDataset(Dataset):
    """Index of training examples stored as ``.wav`` files in a single directory.

    Every ``.wav`` file whose name contains neither ``'recording'`` nor
    ``'EARS_1'`` yields one example. From a name of the form
    ``<f0>_<f1>_..._<index>`` the constructor derives:

    * ``left_target``  -> ``<f0>_<f1>_1_<index>.wav``
    * ``right_target`` -> ``<f0>_<f1>_2_<index>.wav``
    * ``original``     -> ``<baseline><index>.wav``
    * ``label``        -> the part of the name before the first ``'Deg'``
      (the whole name when ``'Deg'`` does not occur).

    Only paths are stored; no audio is read and the derived paths are not
    checked for existence.

    Args:
        data_dir: Directory that holds the audio files.
        baseline: File-name prefix of the baseline ("original") recordings,
            e.g. ``'recording_01_'``.

    Raises:
        ValueError: If a candidate file name has fewer than two
            ``'_'``-separated fields, so the ear file names cannot be built.
    """

    def __init__(self, data_dir, baseline):
        self.data_dir = data_dir
        self.data_map = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> example mapping reproducible between runs and machines.
        for file_name in sorted(os.listdir(data_dir)):
            name, ext = os.path.splitext(file_name)
            # Only one data point per recording: the baseline files and the
            # '_1_' ear files are reached through the remaining file's name.
            if 'recording' in name or 'EARS_1' in name:
                continue
            if ext != ".wav":
                continue

            fields = name.split('_')
            if len(fields) < 2:
                # Previously surfaced as a bare IndexError on fields[1].
                raise ValueError(
                    f"Cannot derive left/right ear files from {file_name!r}: "
                    "expected at least two '_'-separated fields in the name"
                )

            # str.split always returns a non-empty list, so a label always exists.
            label = name.split('Deg')[0]
            index_of_target = fields[-1]
            prefix = fields[0] + '_' + fields[1]
            left_ear = prefix + '_1_' + index_of_target + '.wav'
            right_ear = prefix + '_2_' + index_of_target + '.wav'
            orig = baseline + index_of_target + ".wav"
            self.data_map.append(
                {
                    "left_target": os.path.join(data_dir, left_ear),
                    "right_target": os.path.join(data_dir, right_ear),
                    "label": label,
                    "original": os.path.join(data_dir, orig),
                }
            )

    def __len__(self):
        """Return the number of indexed examples."""
        return len(self.data_map)

    def __getitem__(self, idx):
        """Return ``(left_path, right_path, label, original_path)`` for ``idx``."""
        data = self.data_map[idx]
        left_audio = data["left_target"]
        right_audio = data["right_target"]
        label = data.get("label", "")
        original = data.get("original", "")

        return left_audio, right_audio, label, original

NameError: name 'Dataset' is not defined

In [None]:
def fixnan(tensor: torch.Tensor):
    """Return a new tensor equal to ``tensor`` with every NaN entry set to zero.

    The input is left untouched; all non-NaN values (including infinities)
    are carried over unchanged.
    """
    return tensor.masked_fill(torch.isnan(tensor), 0)


def one_hot_encode(tensor, num_classes=2048):
    """One-hot encode a 2-D grid of integer class indices.

    Args:
        tensor: Integer indices with two leading dimensions ``(rows, cols)``.
        num_classes: Size of the one-hot axis appended as the last dimension.

    Returns:
        A float tensor of shape ``(rows, cols, num_classes)`` created with
        ``torch.zeros`` (so on the default device, as before) holding ``1`` at
        ``[i, j, tensor[i, j]]`` and ``0`` everywhere else.

    Raises:
        IndexError: If the input is not integer-valued or an index lies
            outside ``[-num_classes, num_classes)``.
    """
    rows, cols = tensor.shape[0], tensor.shape[1]
    one_hot = torch.zeros((rows, cols, num_classes))
    if rows == 0 or cols == 0:
        return one_hot

    values = torch.as_tensor(tensor)
    if values.is_floating_point() or values.is_complex():
        raise IndexError(
            f"one_hot_encode expects integer class indices, got dtype {values.dtype}"
        )

    # A single bulk transfer replaces the per-element .item() calls of the
    # former double Python loop.
    indices = values.detach().reshape(rows, cols).to(
        device=one_hot.device, dtype=torch.long
    )
    lowest, highest = int(indices.min()), int(indices.max())
    if lowest < -num_classes or highest >= num_classes:
        raise IndexError(
            f"class index out of range for num_classes={num_classes}: "
            f"min={lowest}, max={highest}"
        )

    # The modulo keeps Python-style negative indexing (-1 -> last class),
    # which plain element assignment supported.
    one_hot.scatter_(2, (indices % num_classes).unsqueeze(-1), 1.0)
    return one_hot

def _encode_waveform(wav, sr, model, duration):
    """Validate a ``(channels, samples)`` waveform and encode it with ``model.compression_model``."""
    # check if frequencies match
    assert sr == model.sample_rate, f"Sample rate is {sr} not {model.sample_rate}"

    # Audio should be exactly `duration` seconds long
    assert wav.shape[1] == duration * model.sample_rate, f"Audio is {wav.shape[1] / model.sample_rate} seconds long, not {duration} seconds long"

    # Follow the model's own device instead of a hard-coded .cuda(), so the
    # waveform always lands where the compression model lives.
    # NOTE(review): relies on MusicGen exposing `.device` - confirm for the
    # installed audiocraft version.
    wav = wav.to(model.device)
    # Add a batch dimension
    wav = wav.unsqueeze(0)

    # Encode using the model's compression method; no gradients are needed here
    with torch.no_grad():
        codes, scale = model.compression_model.encode(wav)

    assert scale is None

    return codes


def preprocess_stereo_audio(audio_path1, audio_path2, model, duration: int = 30):
    """Load two audio files, stack them along the channel axis and encode the result.

    Args:
        audio_path1: Path of the first file (its channels come first).
        audio_path2: Path of the second file (its channels come second).
        model: Object providing ``sample_rate``, ``device`` and
            ``compression_model.encode`` (a ``MusicGen`` at the call site).
        duration: Required clip length in seconds.

    Returns:
        The codes returned by ``model.compression_model.encode`` for the
        batched waveform (shape presumably ``(1, K, T)`` - TODO confirm).

    Raises:
        AssertionError: If the two files differ in sample rate or shape, the
            sample rate differs from ``model.sample_rate``, the clip is not
            exactly ``duration`` seconds long, or the encoder returns a scale.
    """
    wav1, sr1 = torchaudio.load(audio_path1)
    wav2, sr2 = torchaudio.load(audio_path2)

    assert sr1 == sr2, f"Sample rates are {sr1} and {sr2} not the same"
    assert wav1.shape == wav2.shape, f"Audio shapes are {wav1.shape} and {wav2.shape} not the same"

    # Concatenate the two recordings along dim 0 (the channel axis).
    wav = torch.cat((wav1, wav2), dim=0)

    return _encode_waveform(wav, sr1, model, duration)


def preprocess_audio(audio_path, model, duration: int = 30):
    """Load one audio file, duplicate its channels and encode the result.

    The loaded waveform is concatenated with itself along dim 0, so the
    encoder sees twice as many channels as the file holds (presumably a mono
    baseline turned into two identical channels - verify against the data).

    Args:
        audio_path: Path of the audio file.
        model: Object providing ``sample_rate``, ``device`` and
            ``compression_model.encode`` (a ``MusicGen`` at the call site).
        duration: Required clip length in seconds.

    Returns:
        The codes returned by ``model.compression_model.encode`` for the
        batched waveform (shape presumably ``(1, K, T)`` - TODO confirm).

    Raises:
        AssertionError: If the sample rate differs from ``model.sample_rate``,
            the clip is not exactly ``duration`` seconds long, or the encoder
            returns a scale.
    """
    wav, sr = torchaudio.load(audio_path)

    # Copy the audio to the other channel
    wav = torch.cat((wav, wav), dim=0)

    return _encode_waveform(wav, sr, model, duration)


# Counts number of nans 
def count_nans(tensor):
    """Return, as a Python integer, how many entries of ``tensor`` are NaN."""
    # isnan() marks NaN positions with True; counting the non-zero entries of
    # that mask gives the number of NaNs.
    return torch.count_nonzero(torch.isnan(tensor)).item()




In [None]:
import torch
import wandb  # Assuming wandb is used, it's being imported

def train(
    dataset_path: str,
    model: MusicGen,
    lr: float,
    epochs: int,
    use_wandb: bool,
    no_label: bool = False,
    tune_text: bool = False,
    save_step: int = None,
    grad_acc: int = 8,
    use_scaler: bool = False,
    weight_decay: float = 1e-5,
    warmup_steps: int = 10,
    batch_size: int = 10,
    use_cfg: bool = False,
    save_path: str='models/',
    baseline: str = 'recording_01_'
):
    """Fine-tune ``model.lm`` on the examples indexed by ``AudioDataset``.

    For every example the baseline recording is encoded as the LM input and
    the left/right recordings are encoded as the target. The loss is an MSE
    between the LM logits and the one-hot encoded target codes, computed
    without applying the prediction mask.

    Args:
        dataset_path: Directory handed to ``AudioDataset``.
        model: MusicGen model whose ``lm`` is trained in place (cast to float32).
        lr: AdamW learning rate.
        epochs: Number of passes over the dataset.
        use_wandb: Log loss and gradient norm to Weights & Biases. An already
            active run is reused, otherwise a new one is started.
        no_label: Accepted for interface compatibility; not used by this function.
        tune_text: Optimise only ``model.lm.condition_provider`` when True,
            otherwise every parameter of ``model.lm``.
        save_step: Save a checkpoint every ``save_step`` epochs; ``None``
            disables the periodic checkpoints.
        grad_acc: Number of micro-batches accumulated per optimizer step.
        use_scaler: Run backward/step through ``torch.cuda.amp.GradScaler``.
        weight_decay: AdamW weight decay.
        warmup_steps: Warm-up steps of the cosine schedule.
        batch_size: DataLoader batch size.
        use_cfg: Duplicate every example and append null conditions produced
            by ``ClassifierFreeGuidanceDropout(p=1.0)``.
        save_path: Directory for checkpoints (created when missing).
        baseline: File-name prefix of the baseline recordings.

    Side effects:
        Writes ``lm_epoch_<n>.pt`` (optional) and ``lm_final.pt`` into
        ``save_path`` and prints progress for every micro-batch.
    """
    model.lm = model.lm.to(torch.float32)  # Convert model to float32

    # Create dataset and dataloader
    dataset = AudioDataset(dataset_path, baseline)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.lm.train()

    scaler = torch.cuda.amp.GradScaler()

    if tune_text:
        print("Tuning text")
    else:
        print("Tuning everything")

    # from paper
    optimizer = AdamW(
        model.lm.condition_provider.parameters()
        if tune_text
        else model.lm.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=weight_decay,
    )
    scheduler = get_scheduler(
        "cosine",
        optimizer,
        warmup_steps,
        int(epochs * len(train_dataloader) / grad_acc),
    )

    criterion = nn.MSELoss()

    save_models = save_step is not None
    os.makedirs(save_path, exist_ok=True)

    # `run` used to be referenced without ever being defined; reuse the active
    # W&B run when the notebook already started one, otherwise start one here.
    run = None
    if use_wandb:
        run = wandb.run if wandb.run is not None else wandb.init()

    current_step = 0

    for epoch in range(epochs):
        # Start each epoch with clean gradients. A trailing, incomplete
        # accumulation window from the previous epoch is dropped, which keeps
        # the number of optimizer steps equal to what the scheduler was given.
        optimizer.zero_grad()

        for batch_idx, (left_targets, right_targets, labels, originals) in enumerate(train_dataloader):
            all_codes = []
            all_target_codes = []
            texts = []

            # The DataLoader yields file paths and labels; encode the audio here.
            for left_target, right_target, label, original in zip(left_targets, right_targets, labels, originals):
                inner_audio = preprocess_audio(original, model)  # input codes
                target_audio = preprocess_stereo_audio(left_target, right_target, model)  # target codes

                # Need both the input and a target
                if inner_audio is None or target_audio is None:
                    continue

                if use_cfg:
                    # NOTE(review): codes are duplicated per example (a, a, b, b, ...)
                    # while the conditions below are extended as a whole; for
                    # batch_size > 1 the two orderings look inconsistent - confirm.
                    codes = torch.cat([inner_audio, inner_audio], dim=0)
                    target_codes = torch.cat([target_audio, target_audio], dim=0)
                else:
                    codes = inner_audio
                    target_codes = target_audio

                all_codes.append(codes)
                all_target_codes.append(target_codes)
                texts.append(label)

            attributes, _ = model._prepare_tokens_and_attributes(texts, None)
            conditions = attributes
            if use_cfg:
                null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
                conditions = conditions + null_conditions
            tokenized = model.lm.condition_provider.tokenize(conditions)
            condition_tensors = model.lm.condition_provider(tokenized)

            # If we have no codes then no training :(
            if len(all_codes) == 0 or len(all_target_codes) == 0:
                continue

            codes = torch.cat(all_codes, dim=0)
            target_codes = torch.cat(all_target_codes, dim=0)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                # The conditioning is passed pre-computed via `condition_tensors`,
                # hence the empty `conditions` list.
                lm_output = model.lm.compute_predictions(
                    codes=codes, conditions=[], condition_tensors=condition_tensors
                )

                # NOTE(review): only element 0 of the batch contributes to the
                # loss; the remaining examples are encoded and forwarded but
                # ignored. Behaviour kept as-is - confirm whether intended.
                target_codes = target_codes[0]
                # Raw, unnormalised predictions for that element.
                # NOTE(review): the prediction mask is deliberately not applied;
                # check that the unmasked logits contain no NaN (see count_nans).
                logits = lm_output.logits[0]

                # One-hot encode the target codes to match the logits' last axis
                target_codes = one_hot_encode(target_codes, num_classes=2048)
                target_codes = target_codes.cuda()
                logits = logits.cuda()

                logits = logits.view(-1, 2048)
                target_codes = target_codes.view(-1, 2048)
                loss = criterion(logits, target_codes)
            current_step += 1 / grad_acc

            # Gradients are summed over `grad_acc` micro-batches, so each
            # contribution is scaled to make the accumulated gradient an average.
            backward_loss = loss / grad_acc
            (scaler.scale(backward_loss) if use_scaler else backward_loss).backward()

            # L2 norm of the condition-provider gradients, for monitoring only
            # (still multiplied by the loss scale when `use_scaler` is on).
            total_norm = 0
            for p in model.lm.condition_provider.parameters():
                if p.grad is None:
                    continue
                total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** (1.0 / 2)

            if run is not None:
                run.log(
                    {
                        "loss": loss.item(),
                        "total_norm": total_norm,
                    }
                )

            print(
                f"Epoch: {epoch}/{epochs}, Batch: {batch_idx}/{len(train_dataloader)}, Loss: {loss.item()}"
            )

            # Keep accumulating until the window of `grad_acc` micro-batches is full
            if batch_idx % grad_acc != grad_acc - 1:
                continue

            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.lm.parameters(), 0.5)

            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            # Reset only after the step. Zeroing at the top of every micro-batch
            # discarded all but the last micro-batch of each window.
            optimizer.zero_grad()
            print(f"Current step: {current_step}")

        # Save once every `save_step` epochs. `save_models` is tested first so
        # the modulo is never evaluated with save_step=None.
        if save_models and epoch % save_step == 0:
            torch.save(model.lm.state_dict(), f"{save_path}/lm_epoch_{epoch}.pt")

    torch.save(model.lm.state_dict(), f"{save_path}/lm_final.pt")


In [None]:
# --- Data -------------------------------------------------------------------
dataset_path = '/workspace/small_model_data2'
baseline = 'recording_01_'   # file-name prefix of the baseline recordings
no_label = False

# --- Model ------------------------------------------------------------------
model_id = 'facebook/musicgen-stereo-small'
tune_text = False            # False -> optimise every LM parameter
use_cfg = False

# --- Optimisation -----------------------------------------------------------
lr = 1e-5
weight_decay = 1e-5
warmup_steps = 16
epochs = 100
batch_size = 4
grad_acc = 2                 # micro-batches per optimizer step

# --- Logging and checkpoints ------------------------------------------------
use_wandb = False
save_step = 32               # checkpoint interval, in epochs
save_path = 'small_model_save_path/'

In [None]:
# Load the checkpoint named by `model_id` in the configuration cell, rather
# than repeating the literal here, so the two cannot drift apart.
model = MusicGen.get_pretrained(model_id)
# Continue from previous training
#model.lm.load_state_dict(torch.load('small_model_save_path/lm_896.0.pt'))
# The language model is trained in full precision.
model.lm = model.lm.to(torch.float32)


In [None]:

# Launch fine-tuning with the values from the configuration cell.
train(
    # data
    dataset_path=dataset_path,
    baseline=baseline,
    no_label=no_label,
    # model
    model=model,
    tune_text=tune_text,
    use_cfg=use_cfg,
    # optimisation
    lr=lr,
    weight_decay=weight_decay,
    warmup_steps=warmup_steps,
    epochs=epochs,
    batch_size=batch_size,
    grad_acc=grad_acc,
    use_scaler=True,
    # logging and checkpoints
    use_wandb=use_wandb,
    save_step=save_step,
    save_path=save_path,
)