In [None]:
# Mount Google Drive (if needed)
from google.colab import drive
drive.mount('/content/drive')

# Clone the repository
!git clone https://github.com/paytm-temp/music.git
%cd music

# Install dependencies
%pip install -r requirements.txt
%pip install torch torchaudio transformers datasets accelerate wandb


In [None]:
import os
import json
import torch
from acestep.models.lyrics_utils.lyric_tokenizer import LyricTokenizer
from acestep.models.lyrics_utils.lyric_normalizer import LyricNormalizer
from acestep.models.lyrics_utils.lyric_encoder import LyricEncoder
from acestep.music_dcae.music_log_mel import MusicLogMel

# Load the lyric vocabulary shipped with the repository (read as UTF-8)
with open('acestep/models/lyrics_utils/vocab.json', 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)

# Text-side helpers: the tokenizer is built from the vocabulary, while the
# normalizer and encoder are constructed with their defaults.
# NOTE(review): `encoder` is not referenced in any other visible cell —
# confirm whether it is still needed.
tokenizer = LyricTokenizer(vocab)
normalizer = LyricNormalizer()
encoder = LyricEncoder()

# Audio processor handed to the training dataset later in the notebook
audio_processor = MusicLogMel()

def prepare_song_data(song_dir='data/songs'):
    """Collect training examples from a directory of songs.

    Every ``<name>.mp3`` directly inside ``song_dir`` that has both a
    ``<name>_lyrics.txt`` and a ``<name>_prompt.txt`` next to it becomes one
    example. Both text files are read as UTF-8, stripped of surrounding
    whitespace and passed through the module-level ``normalizer``.

    Args:
        song_dir: Directory to scan (not recursive).

    Returns:
        A list of dicts with the keys ``'audio_path'``, ``'lyrics'`` and
        ``'prompt'``, ordered by file name.

    Raises:
        OSError: If ``song_dir`` cannot be listed or a text file cannot be read.
    """
    dataset = []
    skipped = []

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # dataset order reproducible across runs and machines.
    for file in sorted(os.listdir(song_dir)):
        if not file.endswith('.mp3'):
            continue

        base_name = file[:-len('.mp3')]
        lyrics_file = os.path.join(song_dir, base_name + '_lyrics.txt')
        prompt_file = os.path.join(song_dir, base_name + '_prompt.txt')
        audio_file = os.path.join(song_dir, file)

        if not (os.path.exists(lyrics_file) and os.path.exists(prompt_file)):
            # Remember incomplete songs so they can be reported instead of
            # silently disappearing from the training set.
            skipped.append(file)
            continue

        with open(lyrics_file, 'r', encoding='utf-8') as f:
            lyrics = f.read().strip()
        with open(prompt_file, 'r', encoding='utf-8') as f:
            prompt = f.read().strip()

        # Lyrics and prompt go through the same normalizer
        dataset.append({
            'audio_path': audio_file,
            'lyrics': normalizer.normalize(lyrics),
            'prompt': normalizer.normalize(prompt)
        })

    if skipped:
        print(
            f"Skipped {len(skipped)} song(s) without both lyrics and prompt files: "
            f"{', '.join(skipped)}"
        )

    return dataset

# Prepare dataset
# Scans the default directory ('data/songs', relative to the current working
# directory) and reports how many songs had both a lyrics and a prompt file.
dataset = prepare_song_data()
print(f"Prepared {len(dataset)} songs for training")


In [None]:
import json
from acestep.models.ace_step_transformer import AceStepTransformer
from acestep.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

# Load configuration
# Read explicitly as UTF-8, like every other text file in this notebook, so
# the result does not depend on the platform's default encoding.
with open('config/hi_rap_lora_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Initialize model and scheduler
# The scheduler settings fall back to 1000 timesteps / 'linear' when the
# config file does not provide them.
# NOTE(review): confirm that this repository's FlowMatchEulerDiscreteScheduler
# accepts a `beta_schedule` keyword; an unsupported keyword would raise
# TypeError here.
model = AceStepTransformer(config)
scheduler = FlowMatchEulerDiscreteScheduler(
    num_train_timesteps=config.get('num_train_timesteps', 1000),
    beta_schedule=config.get('beta_schedule', 'linear')
)

# Configure LoRA parameters
# This dict is also written to training_config.json at the end of the notebook.
lora_config = {
    'r': 16,  # LoRA rank
    'alpha': 32,  # LoRA scaling factor
    'dropout': 0.1,
    'target_modules': ['q_proj', 'k_proj', 'v_proj', 'out_proj']  # Layers to apply LoRA
}

# Enable LoRA training
model.enable_lora_training(lora_config)

# Move model to GPU if available
# `torch` is imported by the data-preparation cell above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Model initialized on {device}")


In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from tqdm.notebook import tqdm

class HindiRapDataset(Dataset):
    """Map-style dataset that converts prepared song records into tensors.

    Args:
        dataset: Sequence of dicts with the keys ``'audio_path'``,
            ``'lyrics'`` and ``'prompt'``.
        tokenizer: Object exposing ``tokenize(text)``.
        audio_processor: Object exposing ``process(audio_path)``.
    """

    def __init__(self, dataset, tokenizer, audio_processor):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor

    def __len__(self):
        """Return the number of song records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the tensors for the record at ``idx``.

        The audio file is processed first, then the lyrics and the prompt are
        tokenized. Each result is wrapped with ``torch.tensor`` on the
        module-level ``device``.

        NOTE(review): creating tensors on a CUDA device inside ``__getitem__``
        is generally incompatible with multi-process data loading — confirm
        how this dataset is consumed.
        """
        record = self.dataset[idx]

        raw_fields = {
            'audio_features': self.audio_processor.process(record['audio_path']),
            'lyrics_tokens': self.tokenizer.tokenize(record['lyrics']),
            'prompt_tokens': self.tokenizer.tokenize(record['prompt']),
        }

        return {
            key: torch.tensor(value, device=device)
            for key, value in raw_fields.items()
        }

# Training parameters
# NOTE(review): 'warmup_steps' is not read by any visible cell (no learning-rate
# scheduler is created) — confirm whether warmup is intended.
train_params = {
    'batch_size': 4,
    'learning_rate': 1e-4,
    'num_epochs': 50,
    'gradient_accumulation_steps': 2,
    'save_steps': 500,
    'warmup_steps': 100
}

# Create dataset and dataloader
train_dataset = HindiRapDataset(dataset, tokenizer, audio_processor)

# HindiRapDataset.__getitem__ creates its tensors directly on `device`.
# CUDA tensors cannot be created inside forked DataLoader worker processes, so
# load in the main process when training on GPU; keep 2 workers on CPU.
loader_workers = 0 if device.type == 'cuda' else 2

# NOTE(review): the default collate function stacks the samples of a batch, so
# every sample must have matching shapes; variable-length token sequences would
# need a padding `collate_fn` — confirm against the tokenizer output.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_params['batch_size'],
    shuffle=True,
    num_workers=loader_workers
)

# Initialize optimizer
optimizer = AdamW(
    model.parameters(),
    lr=train_params['learning_rate'],
    weight_decay=0.01
)

print(f"Training setup complete with {len(train_dataset)} samples")


In [None]:
from pathlib import Path
import wandb  # Optional: for tracking experiments

# Initialize wandb (optional)
# wandb.init(project="hindi-rap-lora", config=train_params)

# Create checkpoint directory
checkpoint_dir = Path('checkpoints/hindi_rap_lora')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Training loop
accum_steps = train_params['gradient_accumulation_steps']
num_batches = len(train_dataloader)

global_step = 0
model.train()
# Start from clean gradients so that re-running this cell does not reuse
# gradients left over from an interrupted run.
optimizer.zero_grad()

for epoch in range(train_params['num_epochs']):
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
    epoch_loss = 0.0

    for batch_idx, batch in enumerate(progress_bar):
        # Make sure the batch lives on the model's device (a no-op when the
        # tensors are already there).
        batch = {key: value.to(device) for key, value in batch.items()}

        # Forward pass
        loss = model(
            audio_features=batch['audio_features'],
            lyrics_tokens=batch['lyrics_tokens'],
            prompt_tokens=batch['prompt_tokens']
        )

        # Track the unscaled loss so the reported average is the real
        # per-batch loss rather than loss / accumulation steps.
        epoch_loss += loss.item()

        # Backward pass: scale so the accumulated gradient is an average over
        # the micro-batches.
        (loss / accum_steps).backward()

        # Update weights when an accumulation window is complete, and also on
        # the last batch of the epoch so leftover gradients from an incomplete
        # window are applied instead of leaking into the next epoch.
        is_last_batch = (batch_idx + 1) == num_batches
        if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Log metrics
            avg_loss = epoch_loss / (batch_idx + 1)
            progress_bar.set_postfix({'loss': avg_loss})
            # wandb.log({'loss': avg_loss, 'epoch': epoch, 'step': global_step})

            # Save checkpoint
            if global_step % train_params['save_steps'] == 0:
                checkpoint_path = checkpoint_dir / f'checkpoint-{global_step}.pt'
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'epoch': epoch,
                }, checkpoint_path)
                print(f"\nSaved checkpoint: {checkpoint_path}")

print("Training completed!")


In [None]:
# Persist the trained LoRA weights next to the intermediate checkpoints
final_weights_path = checkpoint_dir / 'final_lora_weights.pt'
model.save_lora_weights(final_weights_path)
print(f"Saved final LoRA weights to: {final_weights_path}")

# Record the hyperparameters of this run alongside the weights
config_path = checkpoint_dir / 'training_config.json'
run_settings = {
    'train_params': train_params,
    'lora_config': lora_config
}
with config_path.open('w') as config_file:
    json.dump(run_settings, config_file, indent=2)
print(f"Saved training configuration to: {config_path}")

# Optional: Save to Google Drive
# drive_path = '/content/drive/MyDrive/hindi_rap_lora'
# !mkdir -p {drive_path}
# !cp -r {checkpoint_dir}/* {drive_path}/
# print(f"Copied weights to Google Drive: {drive_path}")
