<a href="https://colab.research.google.com/github/sahith2004/Indic-Codecs-Evaluation/blob/main/Finetuning_Encodec_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Installing Requirements**

In [None]:
!pip install datasets==2.16.0
!pip install huggingface-hub==0.34.0
!pip install transformers[torch]
!pip install accelerate -U

**Training class for fine-tuning the EnCodec model**
## Multi-scale STFT loss (adapted from the EnCodec reference code)


In [None]:
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import EncodecModel

class MultiScaleSTFTLoss(nn.Module):
    """Multi-resolution STFT loss between a predicted and a target waveform.

    For every FFT size in ``fft_sizes`` two terms are computed on STFT
    magnitudes (Hann window, hop = n_fft // 4) and then averaged over the
    resolutions:

    * spectral convergence: ``||S_y - S_x||_F / (||S_y||_F + eps)``
    * log-magnitude L1:     ``mean |log(S_x + eps) - log(S_y + eps)|``

    The spectral-convergence term is normalised by the spectrum of ``y``, so
    call the loss as ``loss(prediction, target)``.

    Args:
        fft_sizes: FFT sizes (also used as window lengths) of the resolutions.
        eps: numerical floor added before the division and the ``log``.

    Raises:
        ValueError: if ``fft_sizes`` is empty.
    """

    def __init__(self, fft_sizes=(512, 1024, 2048), eps=1e-8):
        super().__init__()
        self.fft_sizes = tuple(fft_sizes)
        if not self.fft_sizes:
            # forward() averages over the resolutions; an empty tuple would divide by zero.
            raise ValueError("fft_sizes must contain at least one FFT size")
        self.eps = eps

    def stft_mag(self, x, n_fft):
        """Return |STFT(x)| for a (batch, time) or (batch, channels, time) waveform.

        A 3-D input is down-mixed to mono by averaging over dim 1.
        """
        if x.dim() == 3:
            x = x.mean(1)
        # Match the window to the input's dtype/device so half- or double-precision
        # waveforms do not hit a dtype mismatch inside torch.stft.
        window = torch.hann_window(n_fft, device=x.device, dtype=x.dtype)
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=n_fft // 4,
            win_length=n_fft,
            window=window,
            return_complex=True,
        )
        return spec.abs()

    def forward(self, x, y):
        """Compare prediction ``x`` against target ``y``.

        Returns:
            ``(spectral_convergence, log_magnitude_l1)``, each a scalar tensor
            averaged over ``fft_sizes``.
        """
        sc, mag = 0, 0
        for n_fft in self.fft_sizes:
            mx = self.stft_mag(x, n_fft)
            my = self.stft_mag(y, n_fft)
            sc += torch.norm(mx - my, p='fro') / (torch.norm(my, p='fro') + self.eps)
            mag += F.l1_loss(torch.log(mx + self.eps), torch.log(my + self.eps))
        return sc / len(self.fft_sizes), mag / len(self.fft_sizes)




class TrainEncodecModel(EncodecModel):
    """EnCodec model whose ``forward`` also attaches a reconstruction training loss.

    The loss is a fixed-weight sum of:

    * ``waveform``: MSE between the reconstruction and the input waveform,
    * ``ms_sc`` / ``ms_mag``: multi-scale STFT spectral-convergence and
      log-magnitude terms (``MultiScaleSTFTLoss``),
    * ``mel``: L1 between ``log1p`` mel spectrograms (only when torchaudio can
      be imported; otherwise this term is 0).

    The total is stored under ``outputs["loss"]``, which is where
    ``transformers.Trainer`` looks for the loss.

    Args:
        config: EnCodec model configuration.
        processor: optional processor; its ``sampling_rate`` configures the mel
            transform. Falls back to ``config.sampling_rate`` (or 24000).
    """

    def __init__(self, config, processor=None):
        super().__init__(config)
        self.processor = processor

        self.mse = nn.MSELoss()
        self.ms_stft = MultiScaleSTFTLoss()

        try:
            from torchaudio.transforms import MelSpectrogram
        except (ImportError, OSError) as exc:
            # torchaudio is optional: train without the mel term, but say so.
            warnings.warn(f"torchaudio unavailable ({exc}); mel loss is disabled.")
            self.mel = None
        else:
            if processor is not None:
                sr = processor.sampling_rate
            else:
                sr = getattr(config, "sampling_rate", 24000)
            self.mel = MelSpectrogram(
                sample_rate=sr, n_fft=1024, hop_length=256, n_mels=80
            )

        # Fixed weights of the individual loss terms.
        self.weights = {
            "waveform": 1.0,
            "ms_sc": 1.0,
            "ms_mag": 1.0,
            "mel": 1.0,
        }

    def forward(self, input_values, padding_mask=None, **kwargs):
        """Run the codec and attach the combined reconstruction loss.

        Args:
            input_values: waveform batch passed unchanged to ``EncodecModel.forward``
                (presumably (batch, channels, time) as produced by the processor).
            padding_mask: optional mask forwarded to ``EncodecModel.forward``.
            **kwargs: extra arguments forwarded to ``EncodecModel.forward``.

        Returns:
            The ``EncodecModel`` output with an additional ``"loss"`` entry.
        """
        # Recent transformers.Trainer versions may pass loss-normalisation kwargs to
        # models whose forward accepts **kwargs; EncodecModel.forward does not take it.
        kwargs.pop("num_items_in_batch", None)
        outputs = super().forward(input_values, padding_mask=padding_mask, **kwargs)
        recon = outputs["audio_values"]

        waveform_loss = self.mse(recon, input_values)

        # Order is (prediction, target): spectral convergence is normalised by the
        # spectrum of the original audio, not by the reconstruction.
        sc_loss, mag_loss = self.ms_stft(recon, input_values)

        mel_loss = torch.tensor(0.0, device=input_values.device)
        if self.mel is not None:
            mel_x = self.mel(input_values.squeeze(1))
            mel_y = self.mel(recon.squeeze(1))
            mel_loss = F.l1_loss(torch.log1p(mel_y), torch.log1p(mel_x))

        total_loss = (
            self.weights["waveform"] * waveform_loss
            + self.weights["ms_sc"] * sc_loss
            + self.weights["ms_mag"] * mag_loss
            + self.weights["mel"] * mel_loss
        )

        outputs["loss"] = total_loss

        return outputs


**Loading the model**

In [None]:
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor
import torch.nn as nn

# Processor and weights both come from the same 24 kHz EnCodec checkpoint.
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
model = TrainEncodecModel.from_pretrained("facebook/encodec_24khz")

Some weights of CustomEncodecModel were not initialized from the model checkpoint at facebook/encodec_24khz and are newly initialized: ['mel.mel_scale.fb', 'mel.spectrogram.window']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


**Loading the data**

In [None]:
# Authentication: no token is hard-coded. huggingface_hub picks one up from a cached
# `huggingface-cli login` or the HF_TOKEN environment variable when the dataset needs it.
# The audio column is decoded at the codec's sampling rate, so the
# `sampling_rate=processor.sampling_rate` passed during preprocessing is accurate
# (no resampling happens if the data is already at that rate).
hindi_data = load_dataset("SPRINGLab/IndicTTS-Hindi", split="train").cast_column(
    "audio", Audio(sampling_rate=processor.sampling_rate)
)




In [None]:
# Preprocessing below reads the "audio" column; fail fast if the schema differs.
assert "audio" in hindi_data.column_names, "expected an 'audio' column"
hindi_data

Dataset({
    features: ['audio', 'text', 'gender'],
    num_rows: 11825
})


**Preprocessing the data for training**

In [None]:
import numpy as np
import torch


def full_preprocess(examples, max_length=123840, feature_extractor=None):
    """Truncate/zero-pad a batch of clips to ``max_length`` samples and featurize it.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict whose ``"audio"`` entry is a list of decoded audio
            dicts, each with an ``"array"`` field (assumed mono).
        max_length: target length in samples (123840 samples is about 5.16 s at 24 kHz).
        feature_extractor: processor to call; defaults to the notebook-level
            ``processor``. It must expose ``sampling_rate``.

    Returns:
        dict with the processor's ``"input_values"`` and a ``"padding_mask"``
        that is zeroed over the samples added as padding.
    """
    if feature_extractor is None:
        feature_extractor = processor

    audio_batch, valid_lengths = [], []
    for audio in examples["audio"]:
        clip = np.asarray(audio["array"], dtype=np.float32)[:max_length]
        valid_lengths.append(len(clip))
        audio_batch.append(np.pad(clip, (0, max_length - len(clip))))

    proc_out = feature_extractor(
        raw_audio=audio_batch,
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="pt",
    )

    # The processor only sees equal-length, already padded clips, so it cannot know
    # which samples are padding; mark the zero-padded tail of each clip ourselves.
    padding_mask = proc_out["padding_mask"].clone()
    for row, n_valid in enumerate(valid_lengths):
        padding_mask[row, ..., n_valid:] = 0

    return {
        "input_values": proc_out["input_values"],
        "padding_mask": padding_mask,
    }


# Train on the first 10k utterances only.
hindi_data_10k = hindi_data.select(range(10_000))

# Truncate/pad and featurize in batches of 32 across 4 worker processes. The raw
# columns are dropped, so the result holds only the model inputs.
tokenized_datasets = hindi_data_10k.map(
    full_preprocess,
    remove_columns=hindi_data_10k.column_names,
    num_proc=4,
    batch_size=32,
    batched=True,
)

tokenized_datasets


Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['input_values', 'padding_mask'],
    num_rows: 10000
})

**Training setup**

In [None]:
from transformers import TrainingArguments
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, AutoTokenizer

training_args = TrainingArguments(
    # Checkpointing: one checkpoint per epoch under ./results, keeping at most 5.
    output_dir="./results",
    save_strategy="epoch",
    save_total_limit=5,
    # Optimisation. The learning rate is very small, presumably to stay close to
    # the pretrained codec weights.
    learning_rate=5e-7,
    weight_decay=0.01,
    num_train_epochs=5,
    # Per-device batch sizes.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
)


In [None]:
# Evaluate on the utterances NOT selected for training (the rows after the training
# prefix) instead of re-using the training set. Otherwise trainer.evaluate() would
# only report training loss.
heldout_dataset = hindi_data.select(range(len(hindi_data_10k), len(hindi_data))).map(
    full_preprocess,
    batched=True,
    batch_size=32,
    num_proc=4,
    remove_columns=hindi_data.column_names,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=heldout_dataset,
)

**Training**

In [None]:
import torch

# Release cached-but-unused GPU memory (e.g. left over from earlier cells) before
# starting the long training run.
torch.cuda.empty_cache()
# Runs the training loop; its return value is displayed as the cell output.
trainer.train()

Step,Training Loss


Step,Training Loss
500,2.5405


KeyboardInterrupt: 