In [2]:
import sys
sys.path.insert(0, 'C:\\Users\\Paarth Tandon\\Desktop\\repos\\personal-music-gen')

In [3]:
from audiocraft.models import MusicGen
from personal_musicgen.data.datasets import AudioDataset
from personal_musicgen.model_utils import *

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, Subset
import torchaudio

In [4]:
data = AudioDataset('../data/chunks_no_voice/')
model = MusicGen.get_pretrained('small')
model.lm = model.lm.to(torch.float32)



In [5]:
model.device

device(type='cuda', index=0)

In [6]:
dataloader = DataLoader(Subset(data, list(range(100))), batch_size=1)
acc_steps = 4
optimizer = AdamW(
    model.lm.parameters(),
    lr=1e-5,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
scaler = torch.cuda.amp.GradScaler()

In [9]:
from tqdm import tqdm


def train_step(
        model: MusicGen,
        optimizer: AdamW,
        scaler: GradScaler,
        dataloader: DataLoader,
        grad_acc_steps: int
) -> dict:
    device = model.device
    
    total_loss = 0

    for i, (audio_fns, label_fns) in tqdm(enumerate(dataloader), total=len(dataloader)):
        codes_l = []
        text_l = []

        for audio_fn, label_fn in zip(audio_fns, label_fns):
            codes = encode_audio(model, audio_fn)
            codes_l.append(codes)
            with open(label_fn, 'r') as label_f:
                text_l.append(label_f.read().strip())
        
        codes = torch.cat(codes_l, dim=0).to(device)

        attributes, _ = model._prepare_tokens_and_attributes(text_l, None)
        conditional_vector = get_contitional_vector(model, attributes)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            lm_output = model.lm.compute_predictions(
                codes=codes,
                conditions=[],
                condition_tensors=conditional_vector
            )

            loss = compute_masked_loss(lm_output, codes)
            scaler.scale(loss).backward()
            
            total_loss += loss.item()
            print(loss.item())

            if (i + 1) % grad_acc_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.lm.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

    return {
        'loss': total_loss / len(dataloader)
    }

In [10]:
train_step(model, optimizer, scaler, dataloader, acc_steps)

  1%|          | 1/100 [00:05<09:41,  5.88s/it]

3.240298271179199


  2%|▏         | 2/100 [00:09<07:11,  4.40s/it]

3.163398027420044


  3%|▎         | 3/100 [00:12<06:08,  3.80s/it]

1.828628420829773


  4%|▍         | 4/100 [00:15<05:35,  3.50s/it]

2.915003538131714


  5%|▌         | 5/100 [00:21<07:00,  4.43s/it]

4.784695625305176


  6%|▌         | 6/100 [00:24<06:16,  4.01s/it]

5.107187271118164


  7%|▋         | 7/100 [00:27<05:41,  3.67s/it]

4.934700965881348


  8%|▊         | 8/100 [00:30<05:28,  3.57s/it]

4.523026466369629


  9%|▉         | 9/100 [00:37<06:43,  4.44s/it]

3.7372982501983643


 10%|█         | 10/100 [00:40<06:07,  4.08s/it]

3.068359375


 11%|█         | 11/100 [00:43<05:25,  3.66s/it]

4.520743370056152


 12%|█▏        | 12/100 [00:46<05:03,  3.45s/it]

4.414567947387695


 13%|█▎        | 13/100 [00:52<06:18,  4.36s/it]

3.3525471687316895


 14%|█▍        | 14/100 [00:55<05:46,  4.03s/it]

3.097527027130127


 15%|█▌        | 15/100 [00:58<05:12,  3.67s/it]

3.004128932952881


 16%|█▌        | 16/100 [01:01<04:53,  3.49s/it]

4.071561336517334


 17%|█▋        | 17/100 [01:08<06:04,  4.40s/it]

3.316941976547241


 18%|█▊        | 18/100 [01:11<05:31,  4.04s/it]

3.182795763015747


 19%|█▉        | 19/100 [01:14<04:55,  3.65s/it]

4.207293510437012


 20%|██        | 20/100 [01:17<04:36,  3.46s/it]

4.255553722381592


 21%|██        | 21/100 [01:23<05:40,  4.31s/it]

4.413774490356445


 22%|██▏       | 22/100 [01:26<05:10,  3.99s/it]

4.387679576873779


 23%|██▎       | 23/100 [01:29<04:39,  3.62s/it]

3.030332565307617


 24%|██▍       | 24/100 [01:32<04:21,  3.44s/it]

4.5116868019104


 25%|██▌       | 25/100 [01:39<05:25,  4.33s/it]

5.005823612213135


 26%|██▌       | 26/100 [01:42<04:55,  3.99s/it]

4.141480922698975


 27%|██▋       | 27/100 [01:45<04:23,  3.61s/it]

2.549711227416992


 28%|██▊       | 28/100 [01:48<04:07,  3.43s/it]

4.2668843269348145


 29%|██▉       | 29/100 [01:54<05:05,  4.30s/it]

3.746333360671997


 30%|███       | 30/100 [01:57<04:36,  3.95s/it]

4.0812788009643555


 31%|███       | 31/100 [02:00<04:08,  3.60s/it]

4.4870195388793945


 32%|███▏      | 32/100 [02:03<03:53,  3.43s/it]

4.761338233947754


 33%|███▎      | 33/100 [02:09<04:50,  4.34s/it]

3.87001371383667


 34%|███▍      | 34/100 [02:13<04:24,  4.01s/it]

3.2783734798431396


 35%|███▌      | 35/100 [02:15<03:55,  3.63s/it]

3.9527554512023926


 36%|███▌      | 36/100 [02:18<03:41,  3.46s/it]

3.7286298274993896


 36%|███▌      | 36/100 [02:21<04:10,  3.92s/it]


KeyboardInterrupt: 