In [3]:
import os
import torch
import torchaudio
from transformers import ASTForAudioClassification
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
from transformers import Wav2Vec2Model, Wav2Vec2Processor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Function to load and preprocess MP3 or WAV audio using librosa
def load_audio(file_path):
    # Load audio using librosa (handles both MP3 and WAV)
    waveform, sample_rate = librosa.load(file_path, sr=16000)  # Resampling to 16kHz
    return waveform, sample_rate

# Function to extract embeddings from audio
def extract_embeddings(audio_path):
    # Load the audio file using librosa
    waveform, sample_rate = load_audio(audio_path)

    # Ensure the waveform is in the shape (batch_size, sequence_length)
    waveform = torch.tensor(waveform).squeeze()  # Add batch dimension

    # Process the audio to get features using the processor
    input_values = processor(waveform, return_tensors="pt", padding=True).input_values

    # Extract the hidden states (embeddings) from the model
    with torch.no_grad():
        outputs = model(input_values)

    # Extract embeddings (output of the last hidden layer)
    embeddings = outputs.last_hidden_state
    return embeddings

In [6]:
# Function to process all audio files in subfolders and extract embeddings
def process_subfolders(directory_path):
    embeddings_dict = {}  # Store embeddings by subfolder and filename
    for subfolder in os.listdir(directory_path):
        subfolder_path = os.path.join(directory_path, subfolder)
        
        if os.path.isdir(subfolder_path):
            subfolder_embeddings = {}  # Store embeddings for the current subfolder
            for filename in os.listdir(subfolder_path):
                if filename.endswith(".mp3") or filename.endswith(".wav"):
                    audio_path = os.path.join(subfolder_path, filename)
                    print(f"Processing {filename} in {subfolder}...")
                    embeddings = extract_embeddings(audio_path)
                    subfolder_embeddings[filename] = embeddings
                    print(f"Extracted embeddings for {filename}")
            embeddings_dict[subfolder] = subfolder_embeddings
    return embeddings_dict


In [10]:
# Example usage
audio_file = "TRAIN\TRAIN\ar\أذهب\common_voice_ar_19068306.wav"  # Replace with your file path
embeddings = extract_embeddings(audio_file)

# You can now use the embeddings to create your prototypical network
print(embeddings.shape)

  waveform, sample_rate = librosa.load(file_path, sr=16000)  # Resampling to 16kHz


OSError: [Errno 22] Invalid argument: 'TRAIN\\TRAIN\x07r\\أذهب\\common_voice_ar_19068306.wav'

In [69]:
print(embeddings)

tensor([[[ 0.0024, -0.0283,  0.0913,  ..., -0.2030,  0.0268, -0.1202],
         [ 0.0042, -0.0230,  0.0943,  ..., -0.2063,  0.0253, -0.1277],
         [ 0.0125, -0.0256,  0.0967,  ..., -0.2084,  0.0285, -0.1333],
         ...,
         [ 0.0038, -0.0256,  0.0846,  ..., -0.2283,  0.0202, -0.1008],
         [ 0.0025, -0.0263,  0.0845,  ..., -0.2261,  0.0172, -0.0968],
         [-0.0023, -0.0280,  0.0848,  ..., -0.2272,  0.0160, -0.0964]]])


In [71]:
# Example usage
dataset_directory = 'Few_shots\Tamil'  # Replace with your dataset directory path
embeddings_dict = process_subfolders(dataset_directory)

Processing common_voice_ta_22102795.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22102795.mp3
Processing common_voice_ta_22119939.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22119939.mp3
Processing common_voice_ta_22136198.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22136198.mp3
Processing common_voice_ta_21689633.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21689633.mp3
Processing common_voice_ta_21711106.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21711106.mp3
Processing common_voice_ta_21721820.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21721820.mp3
Processing common_voice_ta_21896827.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896827.mp3
Processing common_voice_ta_22018504.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018504.mp3
Processing common_voice_ta_22024485.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22024485.mp3
Processing common_voice_ta_21896817.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896817.mp3
Processing common_voice_ta_21971682.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971682.mp3
Processing common_voice_ta_22018498.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018498.mp3
Processing common_voice_ta_21896819.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896819.mp3
Processing common_voice_ta_21971728.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971728.mp3
Processing common_voice_ta_22018511.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018511.mp3
Processing common_voice_ta_21714254.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21714254.mp3
Processing common_voice_ta_21896829.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896829.mp3
Processing common_voice_ta_21971726.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971726.mp3
Processing common_voice_ta_21843468.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21843468.mp3
Processing common_voice_ta_21971677.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971677.mp3
Processing common_voice_ta_22018509.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018509.mp3
Processing common_voice_ta_22136234.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22136234.mp3
Processing common_voice_ta_22280772.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22280772.mp3
Processing common_voice_ta_22305206.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22305206.mp3
Processing common_voice_ta_21689632.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21689632.mp3
Processing common_voice_ta_21746487.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21746487.mp3
Processing common_voice_ta_21829429.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21829429.mp3
Processing common_voice_ta_21708919.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21708919.mp3
Processing common_voice_ta_21746221.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21746221.mp3
Processing common_voice_ta_21896820.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896820.mp3
Processing common_voice_ta_21829428.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21829428.mp3
Processing common_voice_ta_21896824.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896824.mp3
Processing common_voice_ta_21971674.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971674.mp3
Processing common_voice_ta_21896828.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896828.mp3
Processing common_voice_ta_21971675.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971675.mp3
Processing common_voice_ta_22018500.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018500.mp3
Processing common_voice_ta_21896823.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21896823.mp3
Processing common_voice_ta_21971725.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_21971725.mp3
Processing common_voice_ta_22018552.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22018552.mp3
Processing common_voice_ta_22102692.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22102692.mp3
Processing common_voice_ta_22280779.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Extracted embeddings for common_voice_ta_22280779.mp3
Processing common_voice_ta_22305162.mp3 in ஹே...
Extracted embeddings for common_voice_ta_22305162.mp3


In [73]:
print(embeddings_dict)

{'ஃபயர்ஃபாக்ஸ்': {'common_voice_ta_22102795.mp3': tensor([[[-0.0569,  0.0471,  0.0583,  ..., -0.0502,  0.0683, -0.0767],
         [-0.0554,  0.0480,  0.0624,  ..., -0.0581,  0.0681, -0.0791],
         [-0.0548,  0.0489,  0.0634,  ..., -0.0625,  0.0676, -0.0782],
         ...,
         [-0.0528,  0.0437,  0.0467,  ..., -0.0622,  0.0665, -0.0556],
         [-0.0557,  0.0436,  0.0440,  ..., -0.0658,  0.0689, -0.0476],
         [-0.0587,  0.0413,  0.0525,  ..., -0.0608,  0.0683, -0.0507]]]), 'common_voice_ta_22119939.mp3': tensor([[[-0.0207,  0.0283,  0.0356,  ..., -0.1737,  0.0563, -0.0175],
         [-0.0276,  0.0181,  0.0477,  ..., -0.1859,  0.0596, -0.0164],
         [-0.0332,  0.0055,  0.0433,  ..., -0.1968,  0.0695, -0.0233],
         ...,
         [-0.0371, -0.0009,  0.0534,  ..., -0.1940,  0.0691, -0.0302],
         [-0.0323, -0.0031,  0.0550,  ..., -0.1982,  0.0671, -0.0195],
         [-0.0281,  0.0034,  0.0532,  ..., -0.1959,  0.0642, -0.0225]]]), 'common_voice_ta_22136198.mp3': 

In [76]:
# import torch
# import learn2learn as l2l
# import torch.nn.functional as F

# # Define the FewShotModel (similar to the one you have)
# class FewShotModel(torch.nn.Module):
#     def __init__(self):
#         super(FewShotModel, self).__init__()
#         self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
#         self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

#     def forward(self, input_values):
#         outputs = self.model(input_values)
#         return outputs.last_hidden_state

# # Function to compute the prototypes from the embeddings
# def compute_prototypes(embeddings_dict):
#     prototypes = {}
#     for class_name, embeddings in embeddings_dict.items():
#         prototypes[class_name] = torch.mean(torch.stack(embeddings), dim=0)  # Mean embedding
#     return prototypes

# # MAML training loop with Prototypical Networks
# def meta_train(model, tasks, num_iterations=1000, meta_lr=1e-3, inner_lr=1e-2):
#     meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
    
#     for iteration in range(num_iterations):
#         # Sample a few tasks
#         task = tasks.sample()  # Assume tasks is a batch of few-shot learning tasks
        
#         # Split into support and query set for the current task
#         support_set, query_set = task.get_support_set(), task.get_query_set()

#         # Compute embeddings for support set using the model
#         support_embeddings = model(support_set)  # (support_samples, embedding_dim)

#         # Compute prototypes from support set embeddings (Prototypical Networks)
#         prototypes = compute_prototypes(support_embeddings)

#         # Now fine-tune the model with MAML: compute gradients based on query set
#         query_embeddings = model(query_set)  # (query_samples, embedding_dim)
        
#         # Calculate loss (using query set embeddings and prototypes)
#         loss = compute_loss(query_embeddings, prototypes)
        
#         # Perform a gradient update using MAML's meta-learning approach
#         meta_optimizer.zero_grad()
#         loss.backward()
#         meta_optimizer.step()
        
#         print(f"Iteration {iteration}, Loss: {loss.item()}")

# # Function to compute loss between query set and prototypes
# def compute_loss(query_embeddings, prototypes):
#     # Compute the Euclidean distance between query samples and prototypes
#     distances = {}
#     for class_name, prototype in prototypes.items():
#         distance = F.pairwise_distance(query_embeddings, prototype.unsqueeze(0))  # Compute Euclidean distance
#         distances[class_name] = distance
#     return torch.mean(torch.stack(list(distances.values())))

# # Example inference for new keyword using prototypical networks
# def infer_with_prototypes(model, audio_path, embeddings_dict):
#     new_audio_embedding = extract_embeddings(audio_path)  # Extract embedding for the new sample
#     prototypes = compute_prototypes(embeddings_dict)  # Compute prototypes from the training data
#     predicted_class = infer_with_prototypes(new_audio_embedding, prototypes)
#     print(f"Predicted class for the new audio sample: {predicted_class}")
    
# def infer_with_prototypes(new_audio_embedding, prototypes):
#     # Compute the Euclidean distance from the new embedding to each prototype
#     distances = {}
#     for class_name, prototype in prototypes.items():
#         distance = F.pairwise_distance(new_audio_embedding, prototype.unsqueeze(0))  # Compute Euclidean distance
#         distances[class_name] = distance.item()
#     predicted_class = min(distances, key=distances.get)
#     return predicted_class


In [None]:
import torch
import os
import librosa
import torch.nn as nn
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch.nn.functional as F
import numpy as np

# 1. Define the model
class FewShotModel(nn.Module):
    def __init__(self):
        super(FewShotModel, self).__init__()
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    def forward(self, input_values):
        outputs = self.model(input_values)
        # Extract the embeddings from the first element of the output tuple
        return outputs[0]  # This corresponds to the last hidden state (embeddings)

# 2. Load audio file using librosa (mp3 or wav)
def load_audio(file_path):
    waveform, sample_rate = librosa.load(file_path, sr=16000)  # Resampling to 16kHz
    return waveform, sample_rate

# 3. Extract embeddings from audio
def extract_embeddings(audio_path, model, processor):
    # Load the audio file
    waveform, sample_rate = load_audio(audio_path)

    # Process the audio to get features using the processor
    input_values = processor(waveform, return_tensors="pt", padding=True).input_values

    # Extract the hidden states (embeddings) from the model
    with torch.no_grad():
        outputs = model(input_values)

    # Return the embeddings from the last hidden layer
    embeddings = outputs  # Since the output is a tensor now
    return embeddings

# 4. Compute prototypes for each class
def compute_prototypes(embeddings_dict):
    prototypes = {}
    for class_name, embeddings in embeddings_dict.items():
        prototypes[class_name] = torch.mean(torch.stack(embeddings), dim=0)  # Mean embedding
    return prototypes

# 5. Inference using Prototypical Networks (find nearest prototype)
def infer_with_prototypes(embedding, prototypes):
    # Compute the distances to each prototype (class center)
    distances = {}
    for class_name, prototype in prototypes.items():
        distance = F.pairwise_distance(embedding, prototype.unsqueeze(0))  # Euclidean distance
        distances[class_name] = distance.item()

    # Classify based on the minimum distance
    predicted_class = min(distances, key=distances.get)
    return predicted_class

# 6. Meta-Training with MAML and Prototypical Networks
def meta_train(model, tasks, num_iterations=1000, lr_inner=0.01, lr_outer=0.001):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_outer)

    for iteration in range(num_iterations):
        task = tasks.sample()  # Sample a batch of tasks

        support_set, query_set = task.get_support_set(), task.get_query_set()

        # Step 1: Perform an inner gradient update for the support set using MAML
        support_embeddings = model(support_set)
        support_loss = F.cross_entropy(support_embeddings, task.support_labels)
        model.zero_grad()
        support_loss.backward()

        # Inner loop update
        optimizer.step()

        # Step 2: Compute the prototypes for each class (using the support set)
        prototypes = compute_prototypes(support_embeddings)

        # Step 3: Inference with prototypes (using the query set)
        query_embeddings = model(query_set)
        predicted_class = infer_with_prototypes(query_embeddings, prototypes)

        print(f"Iteration {iteration}, Predicted class: {predicted_class}")

# 7. Process all subfolders and extract embeddings for meta-training
def process_subfolders(directory_path, model, processor):
    embeddings_dict = {}  # Store embeddings by subfolder and filename
    for subfolder in os.listdir(directory_path):
        subfolder_path = os.path.join(directory_path, subfolder)
        
        if os.path.isdir(subfolder_path):
            subfolder_embeddings = []  # Store embeddings for the current subfolder
            for filename in os.listdir(subfolder_path):
                if filename.endswith(".mp3") or filename.endswith(".wav"):
                    audio_path = os.path.join(subfolder_path, filename)
                    print(f"Processing {filename} in {subfolder}...")
                    embeddings = extract_embeddings(audio_path, model, processor)
                    subfolder_embeddings.append(embeddings)
            embeddings_dict[subfolder] = subfolder_embeddings
    return embeddings_dict

# 8. Inference with new keywords
def inference_with_new_keywords(model, audio_path, embeddings_dict, processor):
    # Extract the embedding for the new audio sample
    new_audio_embedding = extract_embeddings(audio_path, model, processor)

    # Compute prototypes from the embeddings dictionary (training data)
    prototypes = compute_prototypes(embeddings_dict)

    # Classify the new audio sample by finding the nearest prototype
    predicted_class = infer_with_prototypes(new_audio_embedding, prototypes)
    print(f"Predicted class for the new audio sample: {predicted_class}")

# Example Usage:

# Define dataset directory path
dataset_directory = 'cv-corpus-7.0-singleword\ta'  # Replace with your dataset directory path

# Initialize model and processor
model = FewShotModel()
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Process the dataset (subfolders and audio files)
embeddings_dict = process_subfolders(dataset_directory, model, processor)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22102795.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22119939.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22136198.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21689633.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21711106.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21721820.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896827.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018504.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22024485.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896817.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971682.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018498.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896819.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971728.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018511.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21714254.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896829.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971726.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21843468.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971677.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018509.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22136234.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22280772.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22305206.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21689632.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21746487.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21829429.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21708919.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21746221.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896820.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21829428.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896824.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971674.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896828.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971675.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018500.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896823.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971725.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018552.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22102692.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22280779.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22305162.mp3 in ஹே...


In [92]:
import torch
import os
import librosa
import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# 1. Load audio file using librosa (mp3 or wav)
def load_audio(file_path):
    waveform, sample_rate = librosa.load(file_path, sr=16000)  # Resampling to 16kHz
    return waveform, sample_rate

# 2. Extract embeddings from audio
def extract_embeddings(audio_path, model, processor):
    # Load the audio file
    waveform, sample_rate = load_audio(audio_path)

    # Process the audio to get features using the processor
    input_values = processor(waveform, return_tensors="pt", padding=True).input_values

    # Extract the hidden states (embeddings) from the model
    with torch.no_grad():
        outputs = model(input_values)  # Wav2Vec2BaseModelOutput

    # Access last_hidden_state and average across the time dimension
    embeddings = torch.mean(outputs.last_hidden_state, dim=1)
    return embeddings.squeeze(0)  # Remove batch dimension

# 3. Compute prototypes for each class
def compute_prototypes(embeddings_dict):
    prototypes = {}
    for class_name, embeddings in embeddings_dict.items():
        embeddings_tensor = torch.stack(embeddings)  # Stack the embeddings
        prototypes[class_name] = torch.mean(embeddings_tensor, dim=0)  # Mean embedding
    return prototypes

# 4. Inference using Prototypical Networks (find nearest prototype)
def infer_with_prototypes(embedding, prototypes):
    # Compute the distances to each prototype (class center)
    distances = {}
    for class_name, prototype in prototypes.items():
        distance = F.pairwise_distance(embedding.unsqueeze(0), prototype.unsqueeze(0))  # Euclidean distance
        distances[class_name] = distance.item()

    # Classify based on the minimum distance
    predicted_class = min(distances, key=distances.get)
    return predicted_class

# 5. Process all subfolders and extract embeddings for meta-training
def process_subfolders(directory_path, model, processor):
    embeddings_dict = {}  # Store embeddings by subfolder and filename
    for subfolder in os.listdir(directory_path):
        subfolder_path = os.path.join(directory_path, subfolder)
        
        if os.path.isdir(subfolder_path):
            subfolder_embeddings = []  # Store embeddings for the current subfolder
            for filename in os.listdir(subfolder_path):
                if filename.endswith(".mp3") or filename.endswith(".wav"):
                    audio_path = os.path.join(subfolder_path, filename)
                    print(f"Processing {filename} in {subfolder}...")
                    embeddings = extract_embeddings(audio_path, model, processor)
                    subfolder_embeddings.append(embeddings)
            embeddings_dict[subfolder] = subfolder_embeddings
    return embeddings_dict

# 6. Inference with new keywords
def inference_with_new_keywords(model, audio_path, embeddings_dict, processor):
    # Extract the embedding for the new audio sample
    new_audio_embedding = extract_embeddings(audio_path, model, processor)

    # Compute prototypes from the embeddings dictionary (training data)
    prototypes = compute_prototypes(embeddings_dict)

    # Classify the new audio sample by finding the nearest prototype
    predicted_class = infer_with_prototypes(new_audio_embedding, prototypes)
    print(f"Predicted class for the new audio sample: {predicted_class}")

# Example Usage:

# Define dataset directory path
dataset_directory = 'Few_shots/Tamil'  # Replace with your dataset directory path

# Initialize model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# Process the dataset (subfolders and audio files)
embeddings_dict = process_subfolders(dataset_directory, model, processor)

# Example: Inference with a new keyword
audio_file = "Few_shots/Tamil/ஆம்/common_voice_ta_21689633.mp3"  # Replace with your new audio file path
inference_with_new_keywords(model, audio_file, embeddings_dict, processor)


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22102795.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22119939.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22136198.mp3 in ஃபயர்ஃபாக்ஸ்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21689633.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21711106.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21721820.mp3 in ஆம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896827.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018504.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22024485.mp3 in ஆறு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896817.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971682.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018498.mp3 in இரண்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896819.mp3 in இல்லை...
Processing common_voice_ta_21971728.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018511.mp3 in இல்லை...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21714254.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896829.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971726.mp3 in எட்டு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21843468.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971677.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018509.mp3 in ஏழு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22136234.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22280772.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22305206.mp3 in ஐந்து...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21689632.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21746487.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21829429.mp3 in ஒன்பது...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21708919.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21746221.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896820.mp3 in ஒன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21829428.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896824.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971674.mp3 in நான்கு...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896828.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971675.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018500.mp3 in பூஜ்யம்...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21896823.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_21971725.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22018552.mp3 in மூன்று...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22102692.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22280779.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Processing common_voice_ta_22305162.mp3 in ஹே...


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Predicted class for the new audio sample: ஆம்


In [None]:
# Example: Inference with a new keyword
audio_file = "dataset\one_word_tamil\ஏழு\common_voice_ta_21843468.mp3"  # Replace with your new audio file path
inference_with_new_keywords(model, audio_file, embeddings_dict, processor)

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Predicted class for the new audio sample: ஏழு


In [103]:
# Example: Inference with a new keyword
audio_file = "dataset\\cv-corpus-5.1-2020-06-22-ta\\cv-corpus-5.1-2020-06-22\\ta\\clips\\common_voice_ta_20319087.mp3"  # Replace with your new audio file path
inference_with_new_keywords(model, audio_file, embeddings_dict, processor)

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Predicted class for the new audio sample: மூன்று
