In [1]:
import os
import random
import json
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAIN_DIR = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TRAIN"
TEST_AUDIO_DIR = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TEST_DUMMY_FINAL"  # Directory containing test audio files
OUTPUT_FILE = "keyword_detection_results.json"
THRESHOLD = 0.8  # Similarity threshold for keyword detection
SAMPLING_RATE = 16000
WINDOW_SIZE = 1.0   # in seconds for test audio segmentation
STEP_SIZE = 0.5      # in seconds for test audio segmentation

In [11]:
def filter_dataset(train_dir, extensions=('.wav', '.mp3')):
    """
    Filter classes to ensure each class has between 1 and 7 files.
    If >7, randomly select 7.
    Return a dictionary {class_label: [list_of_file_paths]}
    """
    class_files = {}

    for root, dirs, files in os.walk(train_dir):
        # Identify audio files
        audio_files = [f for f in files if f.lower().endswith(extensions)]
        if audio_files:
            class_label = os.path.basename(root)
            if len(audio_files) > 7:
                audio_files = random.sample(audio_files, 7)
            if 1 <= len(audio_files) <= 7:
                file_paths = [os.path.join(root, f) for f in audio_files]
                class_files[class_label] = file_paths
    return class_files

# Filter dataset
class_files = filter_dataset(TRAIN_DIR)

In [8]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers import Wav2Vec2Model

# Specify model and tokenizer names
model_name = "facebook/wav2vec2-xls-r-300m"
tokenizer_name = "facebook/wav2vec2-large-960h-lv60"

# Load tokenizer and feature extractor
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tokenizer_name)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# Create a Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Load the XLS-R model
model = Wav2Vec2Model.from_pretrained(model_name)


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]


In [9]:
def extract_embedding(audio_path, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(inputs.input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
    return embedding

In [12]:
class_embeddings = {}
for cls, files in class_files.items():
    embeddings = []
    for file in files:
        emb = extract_embedding(file)
        embeddings.append(emb)
    class_embeddings[cls] = embeddings

# Create prototypes
prototypes = {}
for cls, embs in class_embeddings.items():
    if len(embs) > 0:
        prototype = np.mean(np.stack(embs), axis=0)
        prototypes[cls] = prototype

In [13]:
import pickle

# Paths to pickle files
GT_FILE = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TEST_DUMMY_FINAL/GT_dummy_final.pkl"


# Load Ground Truth
with open(GT_FILE, 'rb') as f:
    ground_truth = pickle.load(f)  # Currently a list

# Inspect the type and contents
print("Type of ground_truth:", type(ground_truth))
print("Number of entries:", len(ground_truth))

# Inspect the first few entries
for i, entry in enumerate(ground_truth[:5], 1):
    print(f"Entry {i}: {entry}")


Type of ground_truth: <class 'list'>
Number of entries: 250
Entry 1: {'./DATA/TEST_DUMMY_FINAL/0.wav': [{'keyword': 'رقیق', 'start_time': 1.6009375, 'end_time': 2.6009374999999997}, {'keyword': 'خوششون', 'start_time': 5.1909374999999995, 'end_time': 6.1909374999999995}]}
Entry 2: {'./DATA/TEST_DUMMY_FINAL/1.wav': [{'keyword': 'އޮފިސަރ', 'start_time': 2.65, 'end_time': 3.65}]}
Entry 3: {'./DATA/TEST_DUMMY_FINAL/2.wav': [{'keyword': 'принципиальной', 'start_time': 2.63, 'end_time': 3.63}, {'keyword': 'droit', 'start_time': 10.0063125, 'end_time': 11.0063125}, {'keyword': 'خواهم', 'start_time': 13.734874999999999, 'end_time': 14.734874999999999}, {'keyword': 'droit', 'start_time': 17.304875, 'end_time': 18.304875}]}
Entry 4: {'./DATA/TEST_DUMMY_FINAL/3.wav': [{'keyword': 'كذلك', 'start_time': 2.1591875, 'end_time': 3.1591875}]}
Entry 5: {'./DATA/TEST_DUMMY_FINAL/4.wav': [{'keyword': '配合', 'start_time': 2.2025, 'end_time': 3.2025}, {'keyword': 'مریضی', 'start_time': 10.4425, 'end_time': 11

In [23]:
# Load the new Ground Truth
GT_FILE = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TEST_DUMMY_FINAL/GT_dummy_final.pkl"
KW_TO_ID_FILE = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TEST_DUMMY_FINAL/kw_to_id.pkl"

with open(GT_FILE, 'rb') as f:
    ground_truth_data = pickle.load(f)

# Load Keyword to ID Mapping
with open(KW_TO_ID_FILE, 'rb') as f:
    kw_to_id = pickle.load(f)  # {keyword_str: keyword_id}

# Verify the loaded data structure
print("Type of ground_truth_data:", type(ground_truth_data))
print("Number of entries:", len(ground_truth_data))

# Example output should be a dictionary like:
# {"./DATA/TEST_DUMMY_FINAL/0.wav": [{"keyword": "رقیق", "start_time": 1.6, "end_time": 2.6}], ...}

# Initialize an empty ground truth dictionary
ground_truth = {}

for entry in ground_truth_data:
    for file_path, detections in entry.items():
        # Extract file ID without extension
        file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]

        # Initialize list for keywords if not already present
        if file_id_no_ext not in ground_truth:
            ground_truth[file_id_no_ext] = []

        # Iterate through each detection and extract keyword
        for detection in detections:
            keyword_str = detection['keyword']
            ground_truth[file_id_no_ext].append(keyword_str)

# Verify the processed ground truth
print("Processed Ground Truth:")
for file_id, keywords in list(ground_truth.items())[:2]:  # Print the first two entries
    print(f"File ID: {file_id}, Keywords: {keywords}")

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        segment_audio = audio[start:end]
        inputs = processor(segment_audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(inputs.input_values)
            seg_emb = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb))

        start += hop_length
        end = start + win_length

    return segments

def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    if not prototypes:
        return None
    proto_classes = list(prototypes.keys())
    proto_mat = np.stack(list(prototypes.values()), axis=0)
    distances = cdist([segment_emb], proto_mat, metric='cosine')[0]
    min_dist_idx = np.argmin(distances)
    min_dist = distances[min_dist_idx]
    print(min_dist)
    if min_dist < threshold:
        return proto_classes[min_dist_idx]
    else:
        return None

# Example test audio files
test_audio_files = [
    os.path.join(TEST_AUDIO_DIR, "0.wav"),
    os.path.join(TEST_AUDIO_DIR, "2.wav"),
    os.path.join(TEST_AUDIO_DIR, "40.wav")
]

results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Retrieve expected keywords from ground_truth
    expected_keywords = ground_truth.get(file_id_no_ext, [])

    # Debug: Print expected keywords
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Filter prototypes to only those keywords expected in ground truth for this file
    # This helps reduce false positives drastically.
    filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_keywords}

    # Debug: Print filtered prototypes
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    # Segment the audio and detect keywords
    segments = segment_audio(test_file)
    file_results = []

    last_detected_keyword = None
    last_detected_time = -999

    for (start_t, end_t, seg_emb) in segments:
        kw = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is not None:
            # Simple post-processing to avoid multiple detections of the same keyword in overlapping windows
            if kw != last_detected_keyword or (start_t - last_detected_time) > (STEP_SIZE * 2):
                kw_id = kw_to_id.get(kw, None)
                detection_entry = {
                    "keyword": kw,
                    "keyword_id": kw_id,
                    "start_time": start_t,
                    "end_time": end_t
                }
                file_results.append(detection_entry)
                last_detected_keyword = kw
                last_detected_time = start_t

    results[file_id_no_ext] = file_results

# Save Results
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)


Type of ground_truth_data: <class 'list'>
Number of entries: 250
Processed Ground Truth:
File ID: 0, Keywords: ['رقیق', 'خوششون']
File ID: 1, Keywords: ['އޮފިސަރ']
Processing File ID: 0, Expected Keywords: ['رقیق', 'خوششون']
Filtered Prototypes for File ID 0: ['رقیق', 'خوششون']
0.04346061335453555
0.022952265020089735
0.021797786099904015
0.02509059344259601
0.020082752233126278
0.012265566163336517
0.036610338663502606
0.02497882486916636
0.026673360445400585
0.014865326185279248
0.00903614368363126
0.01463374366887249
0.035827693340814704
0.053470204995486936
0.0556808753137602
0.06330459071907257
Processing File ID: 2, Expected Keywords: ['принципиальной', 'droit', 'خواهم', 'droit']
Filtered Prototypes for File ID 2: ['droit', 'خواهم']
0.07649254599400923
0.054283954921989785
0.04523939623866846
0.09235091953083541
0.07215834848088765
0.054621214363829496
0.034207219977515746
0.018926711816945274
0.0368634550449799
0.05921119071893188
0.03928054271967296
0.1017194131270982
0.0866240

In [None]:
print(embeddings)

In [None]:
print(prototypes)

In [16]:
import pickle
import json

file_path = "/Users/adityasangale/Desktop/SIH2024/wav2vecxlsr/TEST_DUMMY_FINAL/GT_dummy_final.pkl"

with open(file_path, "rb") as file:
    data = pickle.load(file)

# Pretty-print the dictionary as JSON
print(json.dumps(data, indent=2, ensure_ascii=False))

[
  {
    "./DATA/TEST_DUMMY_FINAL/0.wav": [
      {
        "keyword": "رقیق",
        "start_time": 1.6009375,
        "end_time": 2.6009374999999997
      },
      {
        "keyword": "خوششون",
        "start_time": 5.1909374999999995,
        "end_time": 6.1909374999999995
      }
    ]
  },
  {
    "./DATA/TEST_DUMMY_FINAL/1.wav": [
      {
        "keyword": "އޮފިސަރ",
        "start_time": 2.65,
        "end_time": 3.65
      }
    ]
  },
  {
    "./DATA/TEST_DUMMY_FINAL/2.wav": [
      {
        "keyword": "принципиальной",
        "start_time": 2.63,
        "end_time": 3.63
      },
      {
        "keyword": "droit",
        "start_time": 10.0063125,
        "end_time": 11.0063125
      },
      {
        "keyword": "خواهم",
        "start_time": 13.734874999999999,
        "end_time": 14.734874999999999
      },
      {
        "keyword": "droit",
        "start_time": 17.304875,
        "end_time": 18.304875
      }
    ]
  },
  {
    "./DATA/TEST_DUMMY_FINAL/3.wav": [
  