In [None]:
# Mount shared drive (Colab only): uncomment these two lines when running on Colab.
#from google.colab import drive
#drive.mount('/content/drive')

# Change to the project directory so relative paths (acestep/, config/, data/)
# resolve, and make the project importable.
import os
import sys

# Project root. Set the HINDI_LORA_DIR environment variable to point elsewhere
# instead of editing this machine-specific default.
DRIVE_PATH = os.environ.get("HINDI_LORA_DIR", "/Users/aakashdhondiyal/Downloads/hindi-lora")
os.chdir(DRIVE_PATH)
# Guard so re-running the cell does not keep appending duplicate entries.
if DRIVE_PATH not in sys.path:
    sys.path.append(DRIVE_PATH)

print(f"Working directory: {os.getcwd()}")


In [None]:
# Install required packages
%pip install torch torchaudio transformers datasets accelerate wandb


In [None]:
# Import required modules
import json
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from pathlib import Path
import sys
import tqdm as tqdm_module
import os

# Progress bars: use the plain console class `tqdm.tqdm` (not tqdm.auto /
# tqdm.notebook widgets) so they render as text in any frontend.
tqdm = tqdm_module.tqdm

from acestep.models.ace_step_transformer import ACEStepTransformer2DModel as AceStepTransformer
from acestep.models.lyrics_utils.lyric_tokenizer import LyricTokenizer
from acestep.models.lyrics_utils.lyric_normalizer import LyricNormalizer
from acestep.music_dcae.music_log_mel import MusicLogMel
from acestep.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from peft import LoraConfig, get_peft_model

# Seed torch's global RNG (CPU and all CUDA devices). This makes model
# initialization, DataLoader shuffling and the random noise/timesteps drawn
# during training reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

# Check CUDA availability and setup device
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    print(f'GPU available! Using: {torch.cuda.get_device_name(0)}')
    torch.cuda.empty_cache()  # Clear GPU memory
else:
    print('Using CPU')

# Initialize text and audio preprocessing components
tokenizer = LyricTokenizer('acestep/models/lyrics_utils/vocab.json')
normalizer = LyricNormalizer()
audio_processor = MusicLogMel()

# Load model architecture configuration
with open('acestep/models/config.json', 'r') as f:
    model_config = json.load(f)

# Load LoRA training configuration
with open('config/hi_rap_lora_config.json', 'r') as f:
    lora_config_dict = json.load(f)

# Create a wrapper class to make our model compatible with PEFT
class ModelWrapper(torch.nn.Module):
    """Adapter that lets a PEFT task wrapper drive the ACE-Step transformer.

    PEFT calls ``forward`` with HuggingFace-style arguments (``input_ids``,
    ``attention_mask``, ``inputs_embeds``) plus bookkeeping flags such as
    ``labels`` and ``return_dict``. This wrapper maps them onto the keyword
    arguments the wrapped model takes and discards the bookkeeping flags.
    """

    # Keyword arguments consumed explicitly in ``forward`` (or deliberately
    # discarded). Anything else is forwarded to the base model untouched.
    _HANDLED_KWARGS = frozenset((
        'hidden_states', 'encoder_text_hidden_states', 'text_attention_mask',
        'speaker_embeds', 'lyric_token_idx', 'lyric_mask', 'timestep',
        'labels', 'output_attentions', 'output_hidden_states', 'return_dict',
    ))

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Re-exposed because PEFT inspects `.config` on the model it wraps.
        self.config = base_model.config

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Translate PEFT-style arguments and call the wrapped model.

        Fallbacks when a key is absent from ``kwargs``:
          - ``hidden_states`` falls back to ``inputs_embeds``;
          - ``lyric_token_idx`` falls back to ``input_ids``;
          - ``lyric_mask`` falls back to ``attention_mask``.

        ``attention_mask`` is always passed through as the main mask.
        ``return_dict`` is dropped, so the base model's own default applies.
        """
        pick = kwargs.get
        passthrough = {
            key: value
            for key, value in kwargs.items()
            if key not in self._HANDLED_KWARGS
        }
        return self.base_model(
            hidden_states=pick('hidden_states', inputs_embeds),
            attention_mask=attention_mask,
            encoder_text_hidden_states=pick('encoder_text_hidden_states'),
            text_attention_mask=pick('text_attention_mask'),
            speaker_embeds=pick('speaker_embeds'),
            lyric_token_idx=pick('lyric_token_idx', input_ids),
            lyric_mask=pick('lyric_mask', attention_mask),
            timestep=pick('timestep'),
            **passthrough,
        )

# Initialize the base transformer.
# LoRA trains small adapters on top of a frozen base model, so the base should
# carry pretrained weights. Set ACESTEP_TRANSFORMER_PATH to a directory holding
# the pretrained ACE-Step transformer (config + weights) to load them. Without
# it, the model is built from `model_config` with random weights, as before.
# NOTE(review): assumes AceStepTransformer provides diffusers-style
# `from_pretrained` (ModelMixin); confirm against acestep.models.ace_step_transformer.
base_model_path = os.environ.get('ACESTEP_TRANSFORMER_PATH')
if base_model_path:
    base_model = AceStepTransformer.from_pretrained(base_model_path)
    print(f"Loaded pretrained base transformer from: {base_model_path}")
else:
    base_model = AceStepTransformer(**model_config)
    print("WARNING: ACESTEP_TRANSFORMER_PATH is not set; the base transformer is "
          "randomly initialized, so the LoRA adapters will not fine-tune a pretrained model.")
model = ModelWrapper(base_model)

# Configure LoRA.
# NOTE(review): these hyperparameters are hard-coded. The LoRA config file loaded
# above (`lora_config_dict`) is only consulted for `num_train_timesteps`.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # Updated target modules to match model architecture
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS"  # Using SEQ_CLS to avoid generation method requirements
)

# Add LoRA adapters (get_peft_model leaves only adapter weights trainable),
# report how many parameters will train, then move to the training device.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
model = model.to(device)

# Initialize scheduler
scheduler = FlowMatchEulerDiscreteScheduler(
    num_train_timesteps=lora_config_dict.get('num_train_timesteps', 1000),
    shift=1.0  # Default value, you can adjust this based on your needs
)

print("Model initialized with LoRA configuration")


In [None]:
# Dataset class definition
class HindiRapDataset(Dataset):
    """Songs built from ``<name>_lyrics.txt`` / ``<name>_prompt.txt`` / ``<name>.mp3`` triples.

    Only songs whose three files all exist in ``song_dir`` are kept. The
    directory is scanned in sorted order, so sample indices are stable across
    runs and machines (``os.listdir`` order is arbitrary).

    Each item is a dict with:
      - ``audio_features``: float tensor of shape ``[8, 16, width]``;
      - ``lyrics_tokens`` / ``prompt_tokens``: 1-D int64 token tensors;
      - ``lyrics_mask``: 1-D tensor of ones, the same length as ``lyrics_tokens``;
      - ``prompt_text``: the raw prompt string.

    ``__getitem__`` relies on the notebook-level ``audio_processor``,
    ``tokenizer`` and ``normalizer`` objects created in the setup cell.
    """

    LYRICS_SUFFIX = '_lyrics.txt'
    PROMPT_SUFFIX = '_prompt.txt'
    AUDIO_SUFFIX = '.mp3'
    # Target layout for the transformer input: 8 channels (in_channels) and a
    # height of 16 (patch_size[0]), per the model config.
    NUM_CHANNELS = 8
    PATCH_HEIGHT = 16

    def __init__(self, song_dir='data/songs'):
        self.song_dir = song_dir
        self.samples = []

        for file in sorted(os.listdir(song_dir)):
            if not file.endswith(self.LYRICS_SUFFIX):
                continue
            base_name = file[:-len(self.LYRICS_SUFFIX)]
            lyrics_file = os.path.join(song_dir, file)
            prompt_file = os.path.join(song_dir, base_name + self.PROMPT_SUFFIX)
            audio_file = os.path.join(song_dir, base_name + self.AUDIO_SUFFIX)
            if not (os.path.exists(prompt_file) and os.path.exists(audio_file)):
                continue

            with open(lyrics_file, 'r', encoding='utf-8') as f:
                lyrics = f.read().strip()
            with open(prompt_file, 'r', encoding='utf-8') as f:
                prompt = f.read().strip()

            self.samples.append({
                'audio_path': audio_file,
                'lyrics': lyrics,
                'prompt': prompt
            })

    def __len__(self):
        return len(self.samples)

    def _load_audio(self, audio_path):
        """Return features for ``audio_path`` as an ``[8, 16, width]`` float tensor.

        Raises on any failure; ``__getitem__`` decides how to fall back.

        NOTE(review): assumes ``MusicLogMel`` exposes ``process(path)`` returning
        ``[channels, time]`` (or already 3-D) features; confirm. These are
        log-mel features, not DCAE latents, reshaped into the latent layout.
        """
        features = audio_processor.process(audio_path)
        if isinstance(features, torch.Tensor):
            audio_tensor = features.float()
        else:
            audio_tensor = torch.tensor(features, dtype=torch.float32)

        height = self.PATCH_HEIGHT
        if audio_tensor.dim() == 2:
            channels, time_steps = audio_tensor.shape
            if time_steps >= height:
                # Drop the trailing time_steps % 16 frames so time splits evenly.
                width = time_steps // height
                audio_tensor = audio_tensor[:, :height * width]
            else:
                # Too short: zero-pad up to a 16 x 16 grid.
                width = height
                audio_tensor = F.pad(audio_tensor, (0, height * width - time_steps))
            # Cut the time axis into `height` contiguous chunks of `width` frames.
            audio_tensor = audio_tensor.reshape(channels, height, width)

        if audio_tensor.dim() != 3:
            raise ValueError(
                f"expected 2-D or 3-D audio features, got shape {tuple(audio_tensor.shape)}"
            )

        channels = audio_tensor.shape[0]
        if channels < self.NUM_CHANNELS:
            # Tile channels cyclically (0, 1, ..., C-1, 0, 1, ...) up to 8.
            audio_tensor = audio_tensor[torch.arange(self.NUM_CHANNELS) % channels]
        elif channels > self.NUM_CHANNELS:
            audio_tensor = audio_tensor[:self.NUM_CHANNELS]
        return audio_tensor

    def __getitem__(self, idx):
        item = self.samples[idx]

        try:
            audio_tensor = self._load_audio(item['audio_path'])
        except Exception as e:
            print(f"Error processing audio {item['audio_path']}: {e}")
            # Best-effort fallback so one bad file does not stop training.
            # NOTE: the model then trains on random noise for this sample.
            audio_tensor = torch.randn(8, 16, 128, dtype=torch.float32)
            print(f"Using fallback shape: {audio_tensor.shape}")

        lyrics_tokens = tokenizer.encode(normalizer.normalize(item['lyrics']))
        prompt_tokens = tokenizer.encode(normalizer.normalize(item['prompt']))
        # Explicit int64: torch.tensor([]) would otherwise be float32. Mixed with
        # non-empty token tensors, that breaks torch.stack in collate_fn.
        lyrics_tensor = torch.tensor(lyrics_tokens, dtype=torch.long)
        prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long)

        return {
            'audio_features': audio_tensor,
            'lyrics_tokens': lyrics_tensor,
            'lyrics_mask': torch.ones(len(lyrics_tokens)),
            'prompt_tokens': prompt_tensor,
            'prompt_text': item['prompt']  # Keep original text for reference
        }

# Create dataset
train_dataset = HindiRapDataset()
# Fail fast with an actionable message. An empty dataset would otherwise surface
# later as an opaque RandomSampler "num_samples=0" error from the DataLoader.
if len(train_dataset) == 0:
    raise RuntimeError(
        f"No training songs found in {train_dataset.song_dir!r}: each song needs "
        "'<name>_lyrics.txt', '<name>_prompt.txt' and '<name>.mp3'."
    )

# Define custom collate function for padding
def collate_fn(batch):
    """Pad a list of dataset items into one batch dict.

    - ``audio_features``: each ``[C, H, W_i]`` tensor is right-padded with zeros
      along the last (width) axis to the batch max, then stacked to
      ``[B, C, H, W_max]``.
    - ``audio_mask``: ``[B, W_max]`` float tensor, 1.0 for real width columns and
      0.0 for padding, so the training loss can ignore the padded region.
    - ``lyrics_tokens`` / ``lyrics_mask`` / ``prompt_tokens``: 1-D tensors
      right-padded with 0 to that field's batch max, then stacked to ``[B, L_max]``.
    - ``prompt_texts``: list of the original prompt strings.
    """
    audio_widths = [item['audio_features'].shape[-1] for item in batch]
    max_audio_width = max(audio_widths)

    padded_audio_features = []
    audio_masks = []
    for item, width in zip(batch, audio_widths):
        audio = item['audio_features']
        pad = max_audio_width - width
        padded_audio_features.append(F.pad(audio, (0, pad)) if pad > 0 else audio)
        mask = torch.zeros(max_audio_width)
        mask[:width] = 1.0
        audio_masks.append(mask)

    def _pad_field(key):
        # Right-pad a 1-D per-item tensor field with zeros to the batch max length.
        tensors = [item[key] for item in batch]
        max_len = max(len(t) for t in tensors)
        return torch.stack([F.pad(t, (0, max_len - len(t))) for t in tensors])

    return {
        'audio_features': torch.stack(padded_audio_features),
        'audio_mask': torch.stack(audio_masks),
        'lyrics_tokens': _pad_field('lyrics_tokens'),
        'lyrics_mask': _pad_field('lyrics_mask'),
        'prompt_tokens': _pad_field('prompt_tokens'),
        'prompt_texts': [item['prompt_text'] for item in batch]  # Keep original texts
    }

# DataLoader: worker processes and pinned host memory only pay off when feeding
# a GPU, so both are tied to `use_cuda`. collate_fn pads variable-length fields.
loader_kwargs = dict(
    batch_size=4,
    shuffle=True,
    num_workers=2 if use_cuda else 0,
    pin_memory=use_cuda,
    collate_fn=collate_fn,
)
num_workers = loader_kwargs['num_workers']
train_dataloader = DataLoader(train_dataset, **loader_kwargs)

print(f"Dataset prepared with {len(train_dataset)} songs")


In [None]:
# Training setup
train_params = {
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'gradient_accumulation_steps': 2,
    'save_steps': 500
}

# After get_peft_model only the LoRA adapter weights require grad. Give AdamW
# just those so it does not carry the frozen base parameters.
optimizer = AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=train_params['learning_rate'],
)
checkpoint_dir = Path('checkpoints/hindi_rap_lora')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

accum_steps = train_params['gradient_accumulation_steps']
num_train_timesteps = scheduler.config.num_train_timesteps
# Conditioning widths. The defaults match the previously hard-coded sizes; the
# key names are assumed from the ACE-Step config (TODO confirm).
text_embed_dim = model_config.get('text_embedding_dim', 768)
speaker_embed_dim = model_config.get('speaker_embedding_dim', 512)
num_batches = len(train_dataloader)

# Training loop
global_step = 0
model.train()
optimizer.zero_grad()

# Clear GPU memory before training if available
if use_cuda:
    torch.cuda.empty_cache()

for epoch in range(train_params['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{train_params['num_epochs']}")
    epoch_loss = 0.0  # sum of unscaled per-batch losses this epoch

    # Use tqdm with console mode
    for batch_idx, batch in enumerate(tqdm_module.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", ascii=True)):
        # The DataLoader yields CPU tensors; move everything the model consumes
        # to the training device.
        # `audio_features` is treated as the clean sample x0.
        # NOTE(review): it is reshaped log-mel, not DCAE latents.
        clean_latents = batch['audio_features'].to(device)
        lyric_tokens = batch['lyrics_tokens'].to(device)
        # Use the dataset's own mask rather than `tokens != 0`, which would also
        # hide any real token whose id happens to be 0.
        lyric_mask = batch['lyrics_mask'].to(device).long()
        prompt_mask = (batch['prompt_tokens'] != 0).long().to(device)

        # With patch_size [16, 1], the sequence length after patch embedding is
        # the width. Prefer the padding-aware mask from collate_fn when present.
        batch_size, _, _, width = clean_latents.shape
        audio_mask = batch.get('audio_mask')
        if audio_mask is None:
            audio_mask = torch.ones(batch_size, width)
        audio_mask = audio_mask.to(device=device, dtype=clean_latents.dtype)

        # TODO: placeholder conditioning. The prompt text is NOT encoded; these
        # are random vectors, shaped like the prompt tokens, that only carry the
        # prompt's padding mask. Replace them with a real text encoder's hidden
        # states.
        text_embeddings = torch.randn((batch_size, prompt_mask.shape[1], text_embed_dim), device=device)
        # No speaker information in this dataset, so use zero speaker embeddings.
        speaker_embeds = torch.zeros((batch_size, speaker_embed_dim), device=device)

        # Flow-matching objective (rectified flow, shift=1): sample t, set
        # sigma = t / T, interpolate x_t = sigma * noise + (1 - sigma) * x0.
        # The model is trained to predict the velocity d x_t / d sigma = noise - x0.
        timestep = torch.randint(0, num_train_timesteps, (batch_size,), device=device)
        sigmas = (timestep.float() / num_train_timesteps).view(batch_size, 1, 1, 1)
        noise = torch.randn_like(clean_latents)
        noisy_latents = sigmas * noise + (1.0 - sigmas) * clean_latents
        target = noise - clean_latents

        output = model(
            input_ids=lyric_tokens,
            attention_mask=audio_mask,  # Use audio_mask for main attention_mask
            hidden_states=noisy_latents,
            encoder_text_hidden_states=text_embeddings,
            text_attention_mask=prompt_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_tokens,
            lyric_mask=lyric_mask,  # Use lyrics_mask for lyric-specific attention
            timestep=timestep,
            return_dict=True
        )
        # `sample` is the model's prediction, not a loss; compute the loss here.
        prediction = output.sample
        if prediction.shape != target.shape:
            raise RuntimeError(
                f"Model output shape {tuple(prediction.shape)} does not match "
                f"target shape {tuple(target.shape)}"
            )

        # Masked MSE: padded width columns (audio_mask == 0) do not contribute.
        per_element = F.mse_loss(prediction.float(), target.float(), reduction='none')
        loss_mask = audio_mask.float().view(batch_size, 1, 1, width).expand_as(per_element)
        loss = (per_element * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)

        # Scale only the gradient for accumulation; track the unscaled loss.
        (loss / accum_steps).backward()
        epoch_loss += loss.item()

        # Step every `accum_steps` batches, and also on the epoch's last batch so a
        # partial accumulation window never leaks gradients into the next epoch.
        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Log metrics
            avg_loss = epoch_loss / (batch_idx + 1)
            if global_step % 10 == 0:  # Print every 10 steps
                print(f"Step {global_step}, Loss: {avg_loss:.4f}")

            # Save checkpoint
            if global_step % train_params['save_steps'] == 0:
                checkpoint_path = checkpoint_dir / f'checkpoint-{global_step}.pt'
                torch.save({
                    'step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'epoch': epoch,
                }, checkpoint_path)
                print(f"\nSaved checkpoint: {checkpoint_path}")

                # Clear GPU memory after saving if available
                if use_cuda:
                    torch.cuda.empty_cache()

    print(f"Epoch {epoch + 1} mean loss: {epoch_loss / max(num_batches, 1):.4f}")

print("Training completed!")


In [None]:
# Save final LoRA weights.
# NOTE: PeftModel.save_pretrained writes a *directory* (adapter weights + adapter
# config); the `.pt` suffix is kept only so the output path stays unchanged.
final_weights_path = checkpoint_dir / 'final_lora_weights.pt'
model.save_pretrained(final_weights_path)
print(f"Saved final LoRA weights to: {final_weights_path}")


def _to_jsonable(obj):
    """json fallback: LoraConfig stores `target_modules` as a set, which JSON cannot encode."""
    if isinstance(obj, (set, frozenset)):
        return sorted(obj, key=str)
    return str(obj)


# Save training configuration. Serialize first, so an encoding error cannot
# leave a truncated file behind.
config_path = checkpoint_dir / 'training_config.json'
config_json = json.dumps({
    'train_params': train_params,
    'lora_config': lora_config.to_dict()  # Convert LoraConfig to dict for saving
}, indent=2, default=_to_jsonable)
config_path.write_text(config_json, encoding='utf-8')
print(f"Saved training configuration to: {config_path}")

# Clear GPU memory one last time if available
if use_cuda:
    torch.cuda.empty_cache()
