In [None]:
import sys
sys.path.append("/Users/yash/Desktop/yash-mtp/src/common")
from Model import *
import os
from SilenceRemover import *
from datasets import Dataset
from multiprocess import set_start_method
import numpy as np
import torch
from transformers import  AutoConfig, Wav2Vec2Processor
import librosa
from torch import mps
import torch.nn.functional as F
from sklearn.metrics import classification_report


def whatDevice():
    """Return the name of the preferred torch backend.

    Backends are probed in priority order — CUDA first, then Apple MPS — and
    "cpu" is returned when neither is available.
    """
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for backend_name, is_available in probes:
        if is_available():
            return backend_name
    return "cpu"
# Detect the best available backend and report it.
device = whatDevice()
print(f"Device: {device}")
# Release cached MPS memory, but only when the MPS backend exists: calling
# torch.mps.empty_cache() on a build without MPS raises instead of being a no-op.
if torch.backends.mps.is_available():
    mps.empty_cache()
# NOTE(review): the detected device is deliberately overridden, so all inference below
# runs on CPU (presumably to avoid accelerator issues with this model — TODO confirm).
# Remove the next line to use the detected accelerator instead.
device = "cpu"
print(f"Running inference on: {device}")

In [None]:
# Absolute local path to the test clips — adjust for your machine.
directory = "/Users/yash/Desktop/MTP-2k23-24/Bhashini_Test_Data"

### Initializing models
## wav2vec2 language classifier: config, processor and weights all come from one checkpoint.
model_name_or_path = "yashcode00/wav2vec2-large-xlsr-indian-language-classification-featureExtractor"
config = AutoConfig.from_pretrained(model_name_or_path)
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
model_wave2vec2 = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path).to(device)
# Clips are resampled to this rate when loaded, so every sample count below is at this rate.
target_sampling_rate = processor.feature_extractor.sampling_rate
# Make the processor emit attention masks; they are forwarded to the model at inference time.
processor.feature_extractor.return_attention_mask = True

# Class ids are positions in label_list; the evaluation compares classifier argmax indices
# against these ids.
label_list = ['asm', 'ben', 'eng', 'guj', 'hin', 'kan', 'mal', 'mar', 'odi', 'tam', 'tel']
# Derive both lookup tables from label_list so they cannot drift out of sync with it.
id2lang = dict(enumerate(label_list))
lang2id = {lang: idx for idx, lang in id2lang.items()}
# NOTE(review): 'pun' has no class of its own; this alias gives Punjabi the same id as 'tel',
# so a 'pun' label encoded through lang2id is indistinguishable from Telugu.
# Kept for backward compatibility — confirm whether this is intended.
lang2id['pun'] = lang2id['tel']

input_column = 'path'
output_column = 'true_label'

# Inference framing: 1-second windows started every 1 second (no overlap), expressed in
# samples at target_sampling_rate instead of a hard-coded 16000.
window_length_seconds = 1
hop_length_seconds = 1
window_size = int(window_length_seconds * target_sampling_rate)  # samples per window
hop_size = int(hop_length_seconds * target_sampling_rate)  # samples between window starts

In [None]:
def extractLabel(name: str):
    return name.split("_")[0]

In [None]:
##### Loading the data
# Each visible file in `directory` is one test clip; its language label is the filename
# prefix before the first "_" (see extractLabel). os.listdir returns entries in arbitrary
# order, so the names are sorted to make the dataset order and all printouts reproducible.
test_records = {input_column: [], output_column: []}
for audio_name in sorted(os.listdir(directory)):
    audio_path = os.path.join(directory, audio_name)
    # Skip hidden entries (e.g. .DS_Store) and anything that is not a regular file.
    if audio_name.startswith(".") or not os.path.isfile(audio_path):
        continue
    test_records[input_column].append(audio_path)
    test_records[output_column].append(extractLabel(audio_name))

## Convert this dict into a Hugging Face Dataset for batched .map() processing
df_test = Dataset.from_dict(test_records)
print(f"The Test Data looks like: \n{df_test}")

In [None]:
# Peek at the first record to sanity-check the extracted path / label pair.
df_test[0]

In [None]:
def speech_file_to_array_fn(path: str):
    """Load the audio at ``path`` with librosa, resampled to ``target_sampling_rate``.

    Only the sample array is returned; the rate librosa reports is discarded because
    the load already resamples to ``target_sampling_rate``.
    """
    samples, _reported_rate = librosa.load(path, sr=target_sampling_rate)
    # Silence removal is currently disabled; to re-enable it:
    #   samples = RemoveSilenceFromArray(samples, target_sampling_rate)
    return samples

def label_to_id(label, label_list):
    """Return the index of ``label`` in ``label_list``.

    Returns -1 when ``label`` is not in a non-empty ``label_list``. When
    ``label_list`` is empty, ``label`` is returned unchanged.
    """
    if len(label_list) == 0:
        return label
    if label not in label_list:
        return -1
    return label_list.index(label)

def preprocess_function(examples):
    """Batched ``Dataset.map`` hook that decodes every path in ``examples[input_column]``.

    Adds a ``speech`` column holding one float32 NumPy array per path, then returns
    the same (mutated) batch.
    """
    decoded = []
    for audio_path in examples[input_column]:
        waveform = speech_file_to_array_fn(audio_path)
        decoded.append(np.array(waveform, dtype=np.float32))
    examples['speech'] = decoded
    return examples

In [None]:
# Decode every clip into a `speech` column, handing 16 paths at a time to preprocess_function.
result = df_test.map(preprocess_function, batched=True, batch_size=16)
print(f"After preprocessing: {result}")

In [None]:
## Run the wav2vec2 classifier on a batch of audio frames and return its logits.
def predictOneSecond(frames):
    """Classify a batch of audio frames with ``model_wave2vec2``.

    Args:
        frames: sequence of 1-D waveforms at the processor's sampling rate.
            ``padding=True`` pads shorter frames to the longest one in the batch.

    Returns:
        The model's ``logits`` tensor for the batch (presumably one row of class
        scores per frame — confirm against the classifier head).

    Raises:
        Re-raises any exception from the forward pass after printing the input
        sizes, so the real error surfaces rather than an unbound ``logits``.
    """
    features = processor(frames, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    try:
        with torch.no_grad():
            logits = model_wave2vec2(input_values, attention_mask=attention_mask).logits
    except Exception as err:
        last_frame_len = len(frames[-1]) if len(frames) else 0
        print(f"Error -> {err} \nInput Length was: {last_frame_len} and features len was : {input_values.shape}")
        raise
    return logits


def predictOne(x):
    """Predict a clip's language id by majority vote over fixed-length windows.

    The waveform is cut into full ``window_size``-sample windows starting every
    ``hop_size`` samples. A trailing remainder shorter than one window is not
    scored. A clip shorter than one window is classified whole, as a single frame.

    Args:
        x: 1-D waveform (sliceable sequence of samples).

    Returns:
        The most frequent per-window class id. ``np.argmax`` picks the first
        maximum, so ties go to the lowest id.

    Raises:
        ValueError: if ``x`` contains no samples.
    """
    if len(x) == 0:
        raise ValueError("predictOne received an empty waveform")
    # The range stops at len(x) - window_size, so every slice is exactly window_size long.
    frames = [x[start:start + window_size] for start in range(0, len(x) - window_size + 1, hop_size)]
    if not frames:
        # Clip shorter than one window: score it as-is instead of failing on an empty list.
        frames = [x]
    logits = predictOneSecond(frames)
    frame_preds = torch.argmax(logits, dim=-1).detach().cpu().numpy()
    return np.argmax(np.bincount(frame_preds))

def predict(batches):
    """Batched ``Dataset.map`` hook: add a ``predicted`` column, one language id per clip."""
    batches['predicted'] = list(map(predictOne, batches['speech']))
    return batches

# Classify every decoded clip, 16 clips per batched call; multiprocessing (num_proc) stays off.
result2 = result.map(predict, batched=True, batch_size=16)
print(f"After predictions: {result2}")

In [None]:
# Pair each clip's ground-truth id with its prediction. Both come from result2, so they
# stay aligned and this cell does not depend on `result`, which a later cell may rebind.
# Only languages the classifier can output (label_list) are scored; any other language
# (e.g. 'pun') would otherwise be counted under another language's id and distort its metrics.
scorable_langs = set(label_list)
label_pred_pairs = [
    (lang2id[lang], pred)
    for lang, pred in zip(result2[output_column], result2["predicted"])
    if lang in scorable_langs
]
n_excluded = len(result2) - len(label_pred_pairs)
if n_excluded:
    print(f"Excluded {n_excluded} clip(s) whose language is not in label_list")
y_true = [true_id for true_id, _ in label_pred_pairs]
y_pred = [pred_id for _, pred_id in label_pred_pairs]

print(y_true[:15])
print(y_pred[:15])

In [None]:
# Per-language precision / recall / F1.
# - `labels` pins the rows to every class id in label_list, so target_names stays aligned
#   even when some language has no test clips.
# - The text goes into `report`, not `result`, so the preprocessed dataset `result` survives
#   and earlier cells can be re-run.
report = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(label_list))),
    target_names=label_list,
)
print(report)

  precision    recall  f1-score   support

         asm       0.80      1.00      0.89         4
         ben       1.00      0.20      0.33         5
         eng       0.00      0.00      0.00         4
         guj       0.80      1.00      0.89         4
         hin       0.25      0.50      0.33         4
         kan       0.50      1.00      0.67         4
         mal       1.00      0.60      0.75         5
         mar       0.43      0.60      0.50         5
         odi       1.00      0.60      0.75         5
         tam       0.71      1.00      0.83         5
         tel       0.50      0.38      0.43         8

    accuracy                           0.60        53
   macro avg       0.64      0.62      0.58        53

In [None]:
# Single sample clip (absolute local path); not used by any of the cells above.
path = (
    "/Users/yash/Desktop/MTP-2k23-24/TTS_data_SilenceRemovedData/hin/train_hindifullfemale_04219.wav"
)
