In [2]:
import os
import random
import json
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
# Raw strings keep the Windows backslashes literal. In a normal string, "\K" and
# "\T" are invalid escape sequences that Python warns about; the values are unchanged.
TRAIN_DIR = r"F:\KWS\TRAIN"
TEST_AUDIO_DIR = r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL"
OUTPUT_FILE = "dummy_final_results.json"
# Upper bound on cosine *distance*: detect_keyword_in_segment reports a keyword when
# the distance to the nearest prototype is < THRESHOLD (cosine similarity > 1 - THRESHOLD).
THRESHOLD = 0.8
SAMPLING_RATE = 16000  # Hz; rate passed to librosa.load and to the processor
WINDOW_SIZE = 1.0   # in seconds for test audio segmentation
STEP_SIZE = 0.5      # in seconds for test audio segmentation

In [15]:

def filter_dataset(train_dir, extensions=('.wav', '.mp3'), max_files=7, rng=None):
    """
    Collect up to ``max_files`` audio files per class directory.

    Every directory under ``train_dir`` that directly contains at least one file
    with a matching extension becomes a class, labelled by the directory's base
    name. If a directory holds more than ``max_files`` audio files, ``max_files``
    of them are chosen at random.

    Args:
        train_dir: Root directory that is walked recursively.
        extensions: A suffix or tuple of suffixes (matched case-insensitively).
        max_files: Maximum number of files kept per class (must be >= 1).
        rng: Optional ``random.Random`` instance used for sampling. When omitted,
            the global ``random`` module is used, as before.

    Returns:
        A dictionary {class_label: [list_of_file_paths]}. Directories that share
        a base name overwrite each other; the one visited last wins.

    Raises:
        ValueError: If ``max_files`` is smaller than 1.
    """
    if max_files < 1:
        raise ValueError("max_files must be at least 1")

    # str.endswith accepts a str or a tuple; normalise so a bare string also works.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    sampler = random if rng is None else rng
    class_files = {}

    for root, _dirs, files in os.walk(train_dir):
        # Sort so the candidate order does not depend on the filesystem; with a
        # seeded ``rng`` the sampled subset is then reproducible.
        audio_files = sorted(f for f in files if f.lower().endswith(suffixes))
        if not audio_files:
            continue
        if len(audio_files) > max_files:
            audio_files = sampler.sample(audio_files, max_files)
        class_files[os.path.basename(root)] = [os.path.join(root, f) for f in audio_files]
    return class_files

# Filter dataset: build {class_label: [file paths]} from the training directory
class_files = filter_dataset(TRAIN_DIR)

In [16]:
# Both the feature processor and the encoder come from the same checkpoint.
checkpoint_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(checkpoint_name)
model = Wav2Vec2Model.from_pretrained(checkpoint_name)
# model.eval() is intentionally not called here.
# NOTE(review): from_pretrained() is documented to return the model in eval mode - confirm.

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
def extract_embedding(audio_path, sr=SAMPLING_RATE):
    """Return a single embedding vector for one audio file.

    The file is loaded with librosa at ``sr``, passed through the module-level
    ``processor`` and ``model``, and the model's ``last_hidden_state`` is averaged
    over dim 1 and squeezed.

    Args:
        audio_path: Path of the audio file to embed.
        sr: Sampling rate (Hz) used for loading and reported to the processor.

    Returns:
        numpy.ndarray holding the pooled embedding (on the CPU).
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    # Another cell of this notebook moves `model` to a CUDA device. Send the inputs
    # to wherever the model currently lives and copy the result back before
    # converting to numpy, so this works on both CPU and GPU.
    input_values = inputs.input_values.to(model.device)
    with torch.no_grad():
        outputs = model(input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding

In [18]:
# Embed every selected training file, grouped by class label.
class_embeddings = {}
for cls in class_files:
    embeddings = [extract_embedding(audio_path) for audio_path in class_files[cls]]
    class_embeddings[cls] = embeddings

# Create prototypes: one vector per class, the element-wise mean of its embeddings.
prototypes = {
    label: np.mean(np.stack(vectors), axis=0)
    for label, vectors in class_embeddings.items()
    if len(vectors) > 0
}

In [34]:
import pickle

# Path to the pickled ground truth. A raw string keeps the backslashes literal;
# "\K", "\T" and "\G" are invalid escapes in a normal string (value is unchanged).
GT_FILE = r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\GT_dummy_final.pkl"

# Load Ground Truth
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open(GT_FILE, 'rb') as f:
    ground_truth = pickle.load(f)  # Currently a list

# Inspect the type and contents
print("Type of ground_truth:", type(ground_truth))
print("Number of entries:", len(ground_truth))

# Inspect the first few entries
for i, entry in enumerate(ground_truth[:5], 1):
    print(f"Entry {i}: {entry}")


Type of ground_truth: <class 'list'>
Number of entries: 250
Entry 1: {'./DATA/TEST_DUMMY_FINAL/0.wav': [{'keyword': 'رقیق', 'start_time': 1.6009375, 'end_time': 2.6009374999999997}, {'keyword': 'خوششون', 'start_time': 5.1909374999999995, 'end_time': 6.1909374999999995}]}
Entry 2: {'./DATA/TEST_DUMMY_FINAL/1.wav': [{'keyword': 'އޮފިސަރ', 'start_time': 2.65, 'end_time': 3.65}]}
Entry 3: {'./DATA/TEST_DUMMY_FINAL/2.wav': [{'keyword': 'принципиальной', 'start_time': 2.63, 'end_time': 3.63}, {'keyword': 'droit', 'start_time': 10.0063125, 'end_time': 11.0063125}, {'keyword': 'خواهم', 'start_time': 13.734874999999999, 'end_time': 14.734874999999999}, {'keyword': 'droit', 'start_time': 17.304875, 'end_time': 18.304875}]}
Entry 4: {'./DATA/TEST_DUMMY_FINAL/3.wav': [{'keyword': 'كذلك', 'start_time': 2.1591875, 'end_time': 3.1591875}]}
Entry 5: {'./DATA/TEST_DUMMY_FINAL/4.wav': [{'keyword': '配合', 'start_time': 2.2025, 'end_time': 3.2025}, {'keyword': 'مریضی', 'start_time': 10.4425, 'end_time': 11

In [35]:
# Paths to pickle files. Raw strings keep the backslashes literal; "\K", "\T", "\G"
# and "\k" are invalid escapes in a normal string (the values are unchanged).
GT_FILE = r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\GT_dummy_final.pkl"
KW_TO_ID_FILE = r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\kw_to_id.pkl"

# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.

# Load Keyword to ID Mapping
with open(KW_TO_ID_FILE, 'rb') as f:
    kw_to_id = pickle.load(f)  # {keyword_str: keyword_id}

# Invert the mapping to get ID to Keyword
id_to_kw = {v: k for k, v in kw_to_id.items()}

# Load Ground Truth
with open(GT_FILE, 'rb') as f:
    ground_truth_list = pickle.load(f)  # List of dictionaries

# Verify Ground Truth Type
print("Type of ground_truth:", type(ground_truth_list))
print("Number of entries:", len(ground_truth_list))

# Build {file_id_without_extension: [keyword_str, ...]} from the list of
# {file_path: [detection_dict, ...]} entries.
ground_truth = {}

for entry in ground_truth_list:
    for file_path, detections in entry.items():
        # Extract file ID without extension
        file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]

        # Create the keyword list on first sight of this file ID
        file_keywords = ground_truth.setdefault(file_id_no_ext, [])

        # 'detections' is already a list of dictionaries
        for detection in detections:
            keyword_str = detection.get('keyword')
            if keyword_str:
                file_keywords.append(keyword_str)
            else:
                print(f"Warning: No keyword found in detection: {detection}")


# for idx, entry in enumerate(ground_truth_list, 1):
#     for file_path, nested_dict in entry.items():
#         # Extract file ID without extension
#         file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]
        
#         # Initialize list for keywords if not already present
#         if file_id_no_ext not in ground_truth:
#             ground_truth[file_id_no_ext] = []
        
#         # Check if nested_dict is a list
#         if isinstance(nested_dict, list):
#             for detection in nested_dict:
#                 keyword_id = detection.get('keyword')
#                 keyword_str = id_to_kw.get(keyword_id, None)
#                 if keyword_str:
#                     ground_truth[file_id_no_ext].append(keyword_str)
#                 else:
#                     print(f"Warning: Keyword ID {keyword_id} not found in mapping.")
#         else:
#             print(f"Unexpected data structure in entry {idx}: {nested_dict}")


Type of ground_truth: <class 'list'>
Number of entries: 250


In [36]:

# # Paths to pickle files
# GT_FILE = "F:\\KWS\\TEST_DUMMY_CORRECTED\\TEST_DUMMY_CORRECTED\\GT_dummy.pickle"
# KW_TO_ID_FILE = "F:\\KWS\\TEST_DUMMY_CORRECTED\\TEST_DUMMY_CORRECTED\\kw_to_id.pkl"

# # Load Keyword to ID Mapping
# with open(KW_TO_ID_FILE, 'rb') as f:
#     kw_to_id = pickle.load(f)  # {keyword_str: keyword_id}

# # Invert the mapping to get ID to Keyword
# id_to_kw = {v: k for k, v in kw_to_id.items()}

# # Load Ground Truth
# with open(GT_FILE, 'rb') as f:
#     ground_truth_list = pickle.load(f)  # List of dictionaries

# # Verify Ground Truth Type
# print("Type of ground_truth:", type(ground_truth_list))
# print("Number of entries:", len(ground_truth_list))

# # Process Ground Truth to create a mapping: {file_id_no_ext: [keyword_str, ...]}
# ground_truth = {}  # Initialize empty dictionary

# for idx, entry in enumerate(ground_truth_list, 1):
#     for file_path, nested_dict in entry.items():
#         # Extract file ID without extension
#         file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]
        
#         # Initialize list for keywords if not already present
#         if file_id_no_ext not in ground_truth:
#             ground_truth[file_id_no_ext] = []
        
#         # Extract detections
#         detections = nested_dict.get(file_path, [])
        
#         for detection in detections:
#             keyword_id = detection.get('keyword')
#             keyword_str = id_to_kw.get(keyword_id, None)
#             if keyword_str:
#                 ground_truth[file_id_no_ext].append(keyword_str)
#             else:
#                 print(f"Warning: Keyword ID {keyword_id} not found in mapping.")

# Debug: show the first two processed ground-truth entries
print("Processed Ground Truth:")
preview_items = list(ground_truth.items())[:2]
for file_key, keywords in preview_items:
    print(f"File ID: {file_key}, Keywords: {keywords}")

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    """Slide a fixed-length window over an audio file and embed each window.

    Args:
        audio_path: Path of the audio file to segment.
        window_size: Window length in seconds.
        step_size: Hop between consecutive window starts, in seconds.
        sr: Sampling rate (Hz) used for loading and reported to the processor.

    Returns:
        List of ``(start_time, end_time, embedding)`` tuples, times in seconds.
        Only full windows are produced: audio shorter than one window yields an
        empty list, and samples after the last full window are not covered.

    Raises:
        ValueError: If the window or hop rounds down to zero samples (a zero hop
            would otherwise never advance the loop).
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)
    if hop_length <= 0 or win_length <= 0:
        raise ValueError("window_size and step_size must each span at least one sample")

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        # Named `window` so the local does not shadow this function's own name.
        window = audio[start:end]
        inputs = processor(window, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            # Run on the model's current device and copy back to the CPU before
            # converting to numpy, so this also works after model.to("cuda").
            outputs = model(inputs.input_values.to(model.device))
            seg_emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb))

        start += hop_length
        end = start + win_length

    return segments

def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    """Return the label of the nearest prototype, or None if nothing is close enough.

    Closeness is the cosine distance between ``segment_emb`` and each prototype
    vector; the nearest label is returned only when that distance is strictly
    below ``threshold``. An empty ``prototypes`` mapping yields None.
    """
    if not prototypes:
        return None

    labels = list(prototypes)
    proto_matrix = np.stack([prototypes[label] for label in labels], axis=0)
    cosine_dists = cdist([segment_emb], proto_matrix, metric='cosine')[0]
    nearest = np.argmin(cosine_dists)

    return labels[nearest] if cosine_dists[nearest] < threshold else None

# Restrict this run to two test recordings.
test_audio_files = [os.path.join(TEST_AUDIO_DIR, name) for name in ("0.wav", "1.wav")]

results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext, _ext = os.path.splitext(file_id)

    # Keywords the ground truth lists for this file (empty if the file is unknown).
    expected_keywords = ground_truth.get(file_id_no_ext, [])
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Only prototypes of keywords the ground truth expects in this file are
    # considered, which cuts down false positives.
    filtered_prototypes = {
        label: vector
        for label, vector in prototypes.items()
        if label in expected_keywords
    }
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    segments = segment_audio(test_file)
    file_results = []

    last_detected_keyword = None  # not read by the loop below
    last_detected_time = -999

    # State of the detection currently being extended across overlapping windows.
    current_keyword = None
    current_keyword_start_time = None

    for start_t, end_t, seg_emb in segments:
        kw = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)

        if kw is not None and kw == current_keyword:
            # Same keyword as the previous window: just extend its end time.
            last_detected_time = end_t
            continue

        # The run (if any) ends here, because a different keyword or nothing was found.
        if current_keyword is not None:
            file_results.append({
                "keyword": current_keyword,
                "keyword_id": kw_to_id.get(current_keyword, None),
                "start_time": current_keyword_start_time,
                "end_time": last_detected_time
            })

        if kw is None:
            current_keyword = None
            current_keyword_start_time = None
        else:
            # Start tracking the newly detected keyword.
            current_keyword = kw
            current_keyword_start_time = start_t
            last_detected_time = end_t

    # Flush a detection that was still open when the audio ended.
    if current_keyword is not None:
        file_results.append({
            "keyword": current_keyword,
            "keyword_id": kw_to_id.get(current_keyword, None),
            "start_time": current_keyword_start_time,
            "end_time": last_detected_time
        })

    results[file_id_no_ext] = file_results

# Write all detections as UTF-8 JSON (keywords are kept unescaped).
with open(OUTPUT_FILE, 'w', encoding='utf-8') as out_file:
    json.dump(results, out_file, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)

Processed Ground Truth:
File ID: 0, Keywords: ['رقیق', 'خوششون']
File ID: 1, Keywords: ['އޮފިސަރ']
Processing File ID: 0, Expected Keywords: ['رقیق', 'خوششون']
Filtered Prototypes for File ID 0: ['خوششون', 'رقیق']
Processing File ID: 1, Expected Keywords: ['އޮފިސަރ']
Filtered Prototypes for File ID 1: ['އޮފިސަރ']
Keyword detection complete. Results saved to: dummy_final_results.json


In [37]:
def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    """Slide a fixed-length window over an audio file and embed each window.

    Args:
        audio_path: Path of the audio file to segment.
        window_size: Window length in seconds.
        step_size: Hop between consecutive window starts, in seconds.
        sr: Sampling rate (Hz) used for loading and reported to the processor.

    Returns:
        List of ``(start_time, end_time, seg_emb, frame_embeddings)`` tuples.
        ``frame_embeddings`` is the model's ``last_hidden_state`` with the batch
        dimension removed, and ``seg_emb`` is its mean over axis 0. Only full
        windows are produced: audio shorter than one window yields an empty
        list, and samples after the last full window are not covered.

    Raises:
        ValueError: If the window or hop rounds down to zero samples (a zero hop
            would otherwise never advance the loop).
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)
    if hop_length <= 0 or win_length <= 0:
        raise ValueError("window_size and step_size must each span at least one sample")

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        # Named `window` so the local does not shadow this function's own name.
        window = audio[start:end]
        inputs = processor(window, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            # Run on the model's current device and copy back to the CPU before
            # converting to numpy, so this also works after model.to("cuda").
            outputs = model(inputs.input_values.to(model.device))
            # Store frame-level embeddings
            frame_embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()  # shape: [frames, hidden_dim]
            seg_emb = frame_embeddings.mean(axis=0)  # average over frames
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb, frame_embeddings))

        start += hop_length
        end = start + win_length

    return segments


def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    """Find the prototype nearest to a segment embedding.

    Returns ``(label, confidence)`` for the prototype with the smallest cosine
    distance when that distance is strictly below ``threshold``; confidence is
    ``1 - distance``. Returns ``(None, None)`` when ``prototypes`` is empty or no
    prototype is close enough.
    """
    if not prototypes:
        return None, None

    labels = list(prototypes)
    stacked = np.stack([prototypes[name] for name in labels], axis=0)
    dists = cdist([segment_emb], stacked, metric='cosine')[0]
    nearest = np.argmin(dists)
    nearest_dist = dists[nearest]

    if nearest_dist < threshold:
        return labels[nearest], 1 - nearest_dist
    return None, None


def refine_boundaries(frame_embeddings, prototype_vector, threshold=THRESHOLD, sr=SAMPLING_RATE, segment_start=0.0, segment_end=1.0):
    """
    Narrow a segment's time span to its longest run of prototype-matching frames.

    A frame matches when its cosine distance to ``prototype_vector`` is below
    ``threshold``. Among the runs of consecutive matching frames, the one
    spanning the most frames is selected (the first one on ties) and converted
    to seconds, assuming the frames divide the segment evenly. If no frame
    matches, the unrefined ``(segment_start, segment_end)`` is returned.
    ``sr`` is accepted but not used.
    """
    frame_dists = cdist(frame_embeddings, [prototype_vector], metric='cosine').squeeze(1)  # one distance per frame
    hits = np.where(frame_dists < threshold)[0]

    if len(hits) == 0:
        return segment_start, segment_end

    # A gap larger than one between neighbouring hit indices separates two runs.
    gaps = np.where(np.diff(hits) > 1)[0]
    run_starts = np.insert(hits[gaps + 1], 0, hits[0])
    run_ends = np.append(hits[gaps], hits[-1])

    # end - start is (run length - 1); that is enough to rank the runs.
    widest = np.argmax(run_ends - run_starts)
    first_frame = run_starts[widest]
    last_frame = run_ends[widest]

    # Frame duration is derived from the segment span rather than the model's hop size.
    frame_duration = (segment_end - segment_start) / frame_embeddings.shape[0]
    refined_start_time = segment_start + first_frame * frame_duration
    refined_end_time = segment_start + (last_frame + 1) * frame_duration  # +1: end of the last frame

    return refined_start_time, refined_end_time


# Requires prototypes, ground_truth, kw_to_id and test_audio_files from earlier cells.

results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Keywords the ground truth lists for this file
    expected_keywords = ground_truth.get(file_id_no_ext, [])
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Keep only the prototypes of those expected keywords
    filtered_prototypes = {
        label: vector
        for label, vector in prototypes.items()
        if label in expected_keywords
    }
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    segments = segment_audio(test_file)

    # Highest-confidence window seen so far for each keyword:
    # {keyword: (confidence, seg_start, seg_end, frame_embeddings, prototype_vec)}
    keyword_best_detection = {}

    for segment_start_time, segment_end_time, seg_emb, frame_embeddings in segments:
        kw, kw_confidence = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is None:
            continue

        incumbent = keyword_best_detection.get(kw)
        if incumbent is None or kw_confidence > incumbent[0]:
            keyword_best_detection[kw] = (
                kw_confidence,
                segment_start_time,
                segment_end_time,
                frame_embeddings,
                filtered_prototypes[kw],
            )

    # Refine the boundaries of each keyword's best window and build the output entries.
    file_results = []
    for kw, best in keyword_best_detection.items():
        conf, seg_start, seg_end, frame_emb, proto_vec = best
        refined_start, refined_end = refine_boundaries(
            frame_emb,
            proto_vec,
            threshold=THRESHOLD,
            sr=SAMPLING_RATE,
            segment_start=seg_start,
            segment_end=seg_end,
        )
        file_results.append({
            "keyword": kw,
            "keyword_id": kw_to_id.get(kw, None),
            "start_time": refined_start,
            "end_time": refined_end,
            "confidence": conf
        })

    results[file_id_no_ext] = file_results

# Save only the best occurrence per keyword, with refined times
with open(OUTPUT_FILE, 'w', encoding='utf-8') as out_file:
    json.dump(results, out_file, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)


Processing File ID: 0, Expected Keywords: ['رقیق', 'خوششون']
Filtered Prototypes for File ID 0: ['خوششون', 'رقیق']
Processing File ID: 1, Expected Keywords: ['އޮފިސަރ']
Filtered Prototypes for File ID 1: ['އޮފިސަރ']
Keyword detection complete. Results saved to: dummy_final_results.json


In [27]:
import os
import glob
import json
import numpy as np
import librosa
import torch
from scipy.spatial.distance import cdist

# glob returns matches in arbitrary order; sort so the files are processed
# (and logged) in a stable order on every run.
test_audio_files = sorted(glob.glob(os.path.join(TEST_AUDIO_DIR, "*.wav")))

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    """Slide a fixed-length window over an audio file and embed each window.

    Args:
        audio_path: Path of the audio file to segment.
        window_size: Window length in seconds.
        step_size: Hop between consecutive window starts, in seconds.
        sr: Sampling rate (Hz) used for loading and reported to the processor.

    Returns:
        List of ``(start_time, end_time, seg_emb, frame_embeddings)`` tuples.
        ``frame_embeddings`` is the model's ``last_hidden_state`` with the batch
        dimension removed, and ``seg_emb`` is its mean over axis 0. Only full
        windows are produced: audio shorter than one window yields an empty
        list, and samples after the last full window are not covered.

    Raises:
        ValueError: If the window or hop rounds down to zero samples (a zero hop
            would otherwise never advance the loop).
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)
    if hop_length <= 0 or win_length <= 0:
        raise ValueError("window_size and step_size must each span at least one sample")

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        # Named `window` so the local does not shadow this function's own name.
        window = audio[start:end]
        inputs = processor(window, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            # Run on the model's current device and copy back to the CPU before
            # converting to numpy, so this also works after model.to("cuda").
            outputs = model(inputs.input_values.to(model.device))
            frame_embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()  # [frames, hidden_dim]
            seg_emb = frame_embeddings.mean(axis=0)  # average over frames
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb, frame_embeddings))

        start += hop_length
        end = start + win_length

    return segments


def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    """Match a segment embedding against the class prototypes.

    The cosine distance from ``segment_emb`` to every prototype is computed and
    the closest class is chosen. Returns ``(class_label, 1 - distance)`` if that
    distance is strictly less than ``threshold``; otherwise, or when
    ``prototypes`` is empty, returns ``(None, None)``.
    """
    if not prototypes:
        return None, None

    class_names = list(prototypes.keys())
    prototype_rows = np.stack([prototypes[name] for name in class_names], axis=0)
    distance_row = cdist([segment_emb], prototype_rows, metric='cosine')[0]

    best_idx = np.argmin(distance_row)
    best_dist = distance_row[best_idx]
    if not best_dist < threshold:
        return None, None

    return class_names[best_idx], 1 - best_dist


def refine_boundaries(frame_embeddings, prototype_vector, threshold=THRESHOLD, sr=SAMPLING_RATE, segment_start=0.0, segment_end=1.0):
    """
    Tighten a detection to the longest run of frames that match the prototype.

    Each frame's cosine distance to ``prototype_vector`` is compared with
    ``threshold``. Consecutive matching frames form a run; the run covering the
    most frames (the first such run on ties) is mapped to seconds by dividing
    the segment span evenly across all frames. When no frame matches, the
    original ``(segment_start, segment_end)`` is returned. ``sr`` is accepted
    but not used.
    """
    per_frame_dist = cdist(frame_embeddings, [prototype_vector], metric='cosine').squeeze(1)  # [frames,]
    matching = np.where(per_frame_dist < threshold)[0]

    if len(matching) == 0:
        # Nothing under the threshold: fall back to the entire segment
        return segment_start, segment_end

    # Positions where consecutive matching indices are not adjacent mark run boundaries.
    run_breaks = np.where(np.diff(matching) > 1)[0]
    starts = np.insert(matching[run_breaks + 1], 0, matching[0])
    ends = np.append(matching[run_breaks], matching[-1])

    # Pick the run with the largest extent (end - start).
    longest = np.argmax(ends - starts)
    start_frame, end_frame = starts[longest], ends[longest]

    # Seconds covered by one frame, derived from the segment span.
    seconds_per_frame = (segment_end - segment_start) / frame_embeddings.shape[0]
    refined_start_time = segment_start + start_frame * seconds_per_frame
    refined_end_time = segment_start + (end_frame + 1) * seconds_per_frame

    return refined_start_time, refined_end_time


results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Ground-truth keywords for this file; unknown files get an empty list.
    expected_keywords = ground_truth.get(file_id_no_ext, [])
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Search only among the prototypes of the expected keywords.
    filtered_prototypes = {
        name: proto
        for name, proto in prototypes.items()
        if name in expected_keywords
    }
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    segments = segment_audio(test_file)

    # Track best detection per keyword
    # {kw: (confidence, seg_start, seg_end, frame_embeddings, prototype_vec)}
    keyword_best_detection = {}

    for segment_start_time, segment_end_time, seg_emb, frame_embeddings in segments:
        kw, kw_confidence = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is None:
            continue

        # Replace the stored window only when this one is strictly more confident.
        is_first = kw not in keyword_best_detection
        if is_first or kw_confidence > keyword_best_detection[kw][0]:
            keyword_best_detection[kw] = (
                kw_confidence,
                segment_start_time,
                segment_end_time,
                frame_embeddings,
                filtered_prototypes[kw],
            )

    file_results = []
    for kw, best_window in keyword_best_detection.items():
        conf, seg_start, seg_end, frame_emb, proto_vec = best_window
        refined_start, refined_end = refine_boundaries(
            frame_emb,
            proto_vec,
            threshold=THRESHOLD,
            sr=SAMPLING_RATE,
            segment_start=seg_start,
            segment_end=seg_end,
        )
        file_results.append({
            "keyword": kw,
            "keyword_id": kw_to_id.get(kw, None),
            "start_time": refined_start,
            "end_time": refined_end,
            "confidence": conf
        })

    results[file_id_no_ext] = file_results

# Save results as UTF-8 JSON
with open(OUTPUT_FILE, 'w', encoding='utf-8') as out_file:
    json.dump(results, out_file, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)


Processing File ID: 0, Expected Keywords: ['رقیق', 'خوششون']
Filtered Prototypes for File ID 0: ['خوششون', 'رقیق']
Processing File ID: 1, Expected Keywords: ['އޮފިސަރ']
Filtered Prototypes for File ID 1: ['އޮފިސަރ']
Processing File ID: 10, Expected Keywords: ['klepněte', 'خواهم', '缓慢', 'بازرسی']
Filtered Prototypes for File ID 10: ['klepněte', 'بازرسی', 'خواهم', '缓慢']
Processing File ID: 100, Expected Keywords: ['خواهم', 'ސީނިއަރ']
Filtered Prototypes for File ID 100: ['ސީނިއަރ', 'خواهم']
Processing File ID: 101, Expected Keywords: ['各种']
Filtered Prototypes for File ID 101: ['各种']
Processing File ID: 102, Expected Keywords: ['خواهم', 'odehrává']
Filtered Prototypes for File ID 102: ['odehrává', 'خواهم']
Processing File ID: 103, Expected Keywords: ['گروهی', 'خواهم']
Filtered Prototypes for File ID 103: ['خواهم', 'گروهی']
Processing File ID: 104, Expected Keywords: ['தின', 'séance', 'خواهم']
Filtered Prototypes for File ID 104: ['خواهم', 'séance', 'தின']
Processing File ID: 105, Expecte

KeyboardInterrupt: 

In [25]:
# Debug: `entry` is the loop variable left over from the ground-truth loops, so this
# shows whichever entry was iterated last; it only works after one of those cells has run.
print(entry)

{'./DATA/TEST_DUMMY_FINAL/249.wav': [{'keyword': 'tedy', 'start_time': 1.848625, 'end_time': 2.848625}, {'keyword': 'droit', 'start_time': 5.68925, 'end_time': 6.68925}, {'keyword': 'میتوانید', 'start_time': 8.4469375, 'end_time': 9.4469375}, {'keyword': 'droit', 'start_time': 12.118500000000001, 'end_time': 13.118500000000001}]}


In [26]:
# Debug: NOTE(review) - `nested_dict` is only assigned in commented-out code in this
# notebook, so this cell presumably relies on stale kernel state and would raise
# NameError on a fresh kernel - confirm and remove.
print(nested_dict)

[{'keyword': 'tedy', 'start_time': 1.848625, 'end_time': 2.848625}, {'keyword': 'droit', 'start_time': 5.68925, 'end_time': 6.68925}, {'keyword': 'میتوانید', 'start_time': 8.4469375, 'end_time': 9.4469375}, {'keyword': 'droit', 'start_time': 12.118500000000001, 'end_time': 13.118500000000001}]


In [92]:
# Debug: `embeddings` is the per-class list left over from the class-embedding loop,
# i.e. the embeddings of the last class processed. Prints full arrays.
print(embeddings)

[array([-4.18751389e-02,  3.53388190e-02, -1.39812157e-01, -4.33332883e-02,
        4.09340225e-02, -7.54778907e-02,  5.45662716e-02, -4.34737690e-02,
        1.79194007e-02, -3.19845229e-01, -7.36276582e-02, -9.73511115e-02,
        6.46035746e-02,  4.43768688e-02, -6.67822883e-02, -4.32628095e-02,
       -3.01845849e-01,  2.63552725e-01,  1.45699810e-02,  9.41181183e-02,
       -1.27641678e-01,  5.30654751e-02,  1.35900065e-01,  1.74343288e-02,
        4.83447939e-01, -3.24696712e-02, -3.19065064e-01,  4.32119407e-02,
        1.43292308e-01, -1.28692165e-01,  2.44765535e-01, -9.68862884e-03,
       -4.84450012e-02, -5.46687245e-02, -2.87910610e-01,  7.88955465e-02,
        4.95807491e-02, -2.65641302e-01, -8.59601647e-02,  5.86813241e-02,
       -1.21404804e-01, -1.89780444e-02, -1.40521049e-01, -1.20388381e-02,
       -1.15862601e-01, -2.86114891e-03, -6.60270900e-02, -1.80117100e-01,
       -3.09299696e-02,  2.96610929e-02, -1.11592568e-01, -2.49166079e-02,
        2.29526713e-01, 

In [16]:
# Debug: dumps the whole {class_label: prototype_vector} mapping, full arrays included.
print(prototypes)

{'poušti': array([-5.83259854e-03,  2.55051740e-02,  3.59268375e-02, -5.12781665e-02,
        7.80151039e-02, -1.15963563e-01,  7.06832483e-02, -2.60463990e-02,
        4.16785702e-02, -2.84854323e-01, -3.89134474e-02, -2.57166140e-02,
        5.37563674e-02,  3.83820049e-02, -1.97520629e-02, -1.39310453e-02,
       -3.48632395e-01,  2.99872547e-01,  2.44159959e-02,  9.10632536e-02,
       -1.54806703e-01,  8.58023986e-02,  1.41568884e-01,  1.25933867e-02,
        2.43039772e-01,  3.08111147e-03, -3.92104536e-01,  3.98511812e-02,
        2.33306773e-02, -1.90186918e-01,  1.41262367e-01, -1.01808151e-02,
       -6.61576837e-02, -9.71952975e-02, -2.10650057e-01,  1.16081730e-01,
       -4.39525023e-02, -2.22683579e-01, -1.50440156e-01,  9.56525654e-02,
       -1.40965968e-01, -1.41415119e-01, -3.16491313e-02,  3.22392941e-01,
       -1.78120285e-01,  1.31282538e-01, -3.13538536e-02, -7.23644495e-02,
       -2.15356108e-02,  1.72202028e-02, -1.03934422e-01, -6.16248325e-02,
        7.032

In [29]:
# Cache file locations for the computed artefacts
CLASS_EMBEDDINGS_FILE = "class_embeddings.pkl"
PROTOTYPES_FILE = "prototypes.pkl"

# Requires class_embeddings and prototypes to have been computed already.

# Persist the per-class embedding lists
with open(CLASS_EMBEDDINGS_FILE, 'wb') as f:
    pickle.dump(class_embeddings, f)
print(f"Saved class embeddings to {CLASS_EMBEDDINGS_FILE}")

# Persist the per-class prototype vectors
with open(PROTOTYPES_FILE, 'wb') as f:
    pickle.dump(prototypes, f)
print(f"Saved prototypes to {PROTOTYPES_FILE}")


Saved class embeddings to class_embeddings.pkl
Saved prototypes to prototypes.pkl


In [18]:
# NOTE: this cell was scrambled in the source (its import/setup header was spliced
# into the middle of the `else:` branch, making it a syntax error). The statements
# below are the same ones, restored to a runnable order.
import os
import pickle
import random  # used by filter_dataset below
import numpy as np
import torch
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist

# Define paths
CLASS_EMBEDDINGS_FILE = "class_embeddings.pkl"
PROTOTYPES_FILE = "prototypes.pkl"

# Initialize Wav2Vec2 Processor and Model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

# Device configuration (optional: use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def extract_embedding(audio_path, sr=16000):
    """Return a single embedding vector for one audio file.

    The file is loaded with librosa at ``sr``, the processed inputs are moved to
    ``device``, and the model's ``last_hidden_state`` is averaged over dim 1,
    squeezed, and copied back to the CPU as a numpy array.
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(inputs.input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding

# Function to filter dataset
def filter_dataset(train_dir, extensions=('.wav', '.mp3')):
    """Map each class directory under ``train_dir`` to at most 7 of its audio files.

    A directory becomes a class (labelled by its base name) when it directly
    contains at least one file whose lower-cased name ends with one of
    ``extensions``. Directories with more than 7 such files are randomly
    sampled down to 7.

    Returns:
        {class_label: [file paths]}
    """
    class_files = {}
    for root, dirs, files in os.walk(train_dir):
        audio_files = [f for f in files if f.lower().endswith(extensions)]
        if audio_files:
            class_label = os.path.basename(root)
            if len(audio_files) > 7:
                audio_files = random.sample(audio_files, 7)
            if 1 <= len(audio_files) <= 7:
                file_paths = [os.path.join(root, f) for f in audio_files]
                class_files[class_label] = file_paths
    return class_files

# Paths
TRAIN_DIR = "F:\\KWS\\TRAIN\\TRAIN"

# Filter dataset
class_files = filter_dataset(TRAIN_DIR)

# Check if embeddings are already saved
# NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote.
if os.path.exists(CLASS_EMBEDDINGS_FILE) and os.path.exists(PROTOTYPES_FILE):
    # Load class_embeddings
    with open(CLASS_EMBEDDINGS_FILE, 'rb') as f:
        class_embeddings = pickle.load(f)
        print(f"Loaded class embeddings from {CLASS_EMBEDDINGS_FILE}")
    
    # Load prototypes
    with open(PROTOTYPES_FILE, 'rb') as f:
        prototypes = pickle.load(f)
        print(f"Loaded prototypes from {PROTOTYPES_FILE}")
else:
    # Compute class_embeddings
    class_embeddings = {}
    for cls, file_paths in class_files.items():
        embeddings = []
        for file_path in file_paths:
            try:
                emb = extract_embedding(file_path)
                embeddings.append(emb)
            except Exception as e:
                # Best effort: report the unreadable file and continue with the rest.
                print(f"Error processing {file_path}: {e}")
        # Classes for which every file failed are left out entirely.
        if embeddings:
            class_embeddings[cls] = embeddings
    
    # Save class_embeddings
    with open(CLASS_EMBEDDINGS_FILE, 'wb') as f:
        pickle.dump(class_embeddings, f)
        print(f"Saved class embeddings to {CLASS_EMBEDDINGS_FILE}")
    
    # Create prototypes: the element-wise mean of each class's embeddings
    prototypes = {}
    for cls, embs in class_embeddings.items():
        if len(embs) > 0:
            prototype = np.mean(np.stack(embs), axis=0)
            prototypes[cls] = prototype
    
    # Save prototypes
    with open(PROTOTYPES_FILE, 'wb') as f:
        pickle.dump(prototypes, f)
        print(f"Saved prototypes to {PROTOTYPES_FILE}")


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded class embeddings from class_embeddings.pkl
Loaded prototypes from prototypes.pkl
