In [73]:
import os
import random
import json
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist
import pickle

In [None]:
TRAIN_DIR = "F:\KWS\TRAIN\TRAIN"
TEST_AUDIO_DIR = "F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL"  # Directory containing test audio files
OUTPUT_FILE = "keyword_detection_results.json"
THRESHOLD = 0.8  # Similarity threshold for keyword detection
SAMPLING_RATE = 16000
WINDOW_SIZE = 1.0   # in seconds for test audio segmentation
STEP_SIZE = 0.5      # in seconds for test audio segmentation

In [75]:

def filter_dataset(train_dir, extensions=('.wav', '.mp3')):
    """
    Filter classes to ensure each class has between 1 and 7 files.
    If >7, randomly select 7.
    Return a dictionary {class_label: [list_of_file_paths]}
    """
    class_files = {}

    for root, dirs, files in os.walk(train_dir):
        # Identify audio files
        audio_files = [f for f in files if f.lower().endswith(extensions)]
        if audio_files:
            class_label = os.path.basename(root)
            if len(audio_files) > 7:
                audio_files = random.sample(audio_files, 7)
            if 1 <= len(audio_files) <= 7:
                file_paths = [os.path.join(root, f) for f in audio_files]
                class_files[class_label] = file_paths
    return class_files

# Filter dataset
class_files = filter_dataset(TRAIN_DIR)

In [79]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
#model.eval()

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [80]:
def extract_embedding(audio_path, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(inputs.input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
    return embedding

In [84]:
class_embeddings = {}
for cls, files in class_files.items():
    embeddings = []
    for file in files:
        emb = extract_embedding(file)
        embeddings.append(emb)
    class_embeddings[cls] = embeddings

# Create prototypes
prototypes = {}
for cls, embs in class_embeddings.items():
    if len(embs) > 0:
        prototype = np.mean(np.stack(embs), axis=0)
        prototypes[cls] = prototype

In [None]:
import pickle

# Paths to pickle files
GT_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\GT_dummy.pickle"

# Load Ground Truth
with open(GT_FILE, 'rb') as f:
    ground_truth = pickle.load(f)  # Currently a list

# Inspect the type and contents
print("Type of ground_truth:", type(ground_truth))
print("Number of entries:", len(ground_truth))

# Inspect the first few entries
for i, entry in enumerate(ground_truth[:5], 1):
    print(f"Entry {i}: {entry}")


Type of ground_truth: <class 'list'>
Number of entries: 250
Entry 1: {'./DATA/TEST_DUMMY_CORRECTED/0.wav': {'./DATA/TEST_DUMMY_CORRECTED/0.wav': [{'keyword': 116, 'start_time': 2.43, 'end_time': 3.43}]}}
Entry 2: {'./DATA/TEST_DUMMY_CORRECTED/1.wav': {'./DATA/TEST_DUMMY_CORRECTED/1.wav': [{'keyword': 3, 'start_time': 3.409375, 'end_time': 4.409375}, {'keyword': 156, 'start_time': 6.859375, 'end_time': 7.859375}]}}
Entry 3: {'./DATA/TEST_DUMMY_CORRECTED/2.wav': {'./DATA/TEST_DUMMY_CORRECTED/2.wav': [{'keyword': 156, 'start_time': 2.45, 'end_time': 3.45}, {'keyword': 156, 'start_time': 5.86, 'end_time': 6.86}, {'keyword': 156, 'start_time': 9.3611875, 'end_time': 10.3611875}]}}
Entry 4: {'./DATA/TEST_DUMMY_CORRECTED/3.wav': {'./DATA/TEST_DUMMY_CORRECTED/3.wav': [{'keyword': 124, 'start_time': 2.5360625, 'end_time': 3.5360625}, {'keyword': 175, 'start_time': 5.530749999999999, 'end_time': 6.530749999999999}]}}
Entry 5: {'./DATA/TEST_DUMMY_CORRECTED/4.wav': {'./DATA/TEST_DUMMY_CORRECTED/4.

In [None]:

# Paths to pickle files
GT_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\GT_dummy.pickle"
KW_TO_ID_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\kw_to_id.pkl"

# Load Keyword to ID Mapping
with open(KW_TO_ID_FILE, 'rb') as f:
    kw_to_id = pickle.load(f)  # {keyword_str: keyword_id}

# Invert the mapping to get ID to Keyword
id_to_kw = {v: k for k, v in kw_to_id.items()}

# Load Ground Truth
with open(GT_FILE, 'rb') as f:
    ground_truth_list = pickle.load(f)  # List of dictionaries

# Verify Ground Truth Type
print("Type of ground_truth:", type(ground_truth_list))
print("Number of entries:", len(ground_truth_list))

# Process Ground Truth to create a mapping: {file_id_no_ext: [keyword_str, ...]}
ground_truth = {}  # Initialize empty dictionary

for idx, entry in enumerate(ground_truth_list, 1):
    for file_path, nested_dict in entry.items():
        # Extract file ID without extension
        file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]
        
        # Initialize list for keywords if not already present
        if file_id_no_ext not in ground_truth:
            ground_truth[file_id_no_ext] = []
        
        # Extract detections
        detections = nested_dict.get(file_path, [])
        
        for detection in detections:
            keyword_id = detection.get('keyword')
            keyword_str = id_to_kw.get(keyword_id, None)
            if keyword_str:
                ground_truth[file_id_no_ext].append(keyword_str)
            else:
                print(f"Warning: Keyword ID {keyword_id} not found in mapping.")

# Debug: Inspect the processed ground_truth
print("Processed Ground Truth:")
for key, value in list(ground_truth.items())[:2]:  # Print first two entries
    print(f"File ID: {key}, Keywords: {value}")

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        segment_audio = audio[start:end]
        inputs = processor(segment_audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(inputs.input_values)
            seg_emb = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb))

        start += hop_length
        end = start + win_length

    return segments

def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    if not prototypes:
        return None
    proto_classes = list(prototypes.keys())
    proto_mat = np.stack(list(prototypes.values()), axis=0)
    distances = cdist([segment_emb], proto_mat, metric='cosine')[0]
    min_dist_idx = np.argmin(distances)
    min_dist = distances[min_dist_idx]
    
    if min_dist < threshold:
        return proto_classes[min_dist_idx]
    else:
        return None

# Only these two files
test_audio_files = [
    os.path.join(TEST_AUDIO_DIR, "1.wav"),
    os.path.join(TEST_AUDIO_DIR, "0.wav"),
    os.path.join(TEST_AUDIO_DIR, "10.wav")
]

results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Retrieve expected keywords from ground_truth
    expected_keywords = ground_truth.get(file_id_no_ext, [])
    
    # Debug: Print expected keywords
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Filter prototypes to only those keywords expected in ground truth for this file
    # This helps reduce false positives drastically.
    filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_keywords}

    # Debug: Print filtered prototypes
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    segments = segment_audio(test_file)
    file_results = []

    last_detected_keyword = None
    last_detected_time = -999

    for (start_t, end_t, seg_emb) in segments:
        kw = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is not None:
            # Simple post-processing to avoid multiple detections of the same keyword in overlapping windows
            if kw != last_detected_keyword or (start_t - last_detected_time) > (STEP_SIZE * 2):
                kw_id = kw_to_id.get(kw, None)
                detection_entry = {
                    "keyword": kw,
                    "keyword_id": kw_id,
                    "start_time": start_t,
                    "end_time": end_t
                }
                file_results.append(detection_entry)
                last_detected_keyword = kw
                last_detected_time = start_t

    results[file_id_no_ext] = file_results

# Save Results
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)

Type of ground_truth: <class 'list'>
Number of entries: 250
Processed Ground Truth:
File ID: 0, Keywords: ['tomu']
File ID: 1, Keywords: ['أطول', 'خواهم']
Processing File ID: 1, Expected Keywords: ['أطول', 'خواهم']
Filtered Prototypes for File ID 1: ['أطول', 'خواهم']
Processing File ID: 0, Expected Keywords: ['tomu']
Filtered Prototypes for File ID 0: ['tomu']
Processing File ID: 10, Expected Keywords: ['خواهم', 'خواهم', 'kampanya', 'apporter']
Filtered Prototypes for File ID 10: ['خواهم', 'apporter', 'kampanya']
Keyword detection complete. Results saved to: keyword_detection_results.json


In [None]:

# Paths to pickle files
GT_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\GT_dummy.pickle"
KW_TO_ID_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\kw_to_id.pkl"

# Load Keyword to ID Mapping
with open(KW_TO_ID_FILE, 'rb') as f:
    kw_to_id = pickle.load(f)  # {keyword_str: keyword_id}

# Invert the mapping to get ID to Keyword
id_to_kw = {v: k for k, v in kw_to_id.items()}

# Load Ground Truth
with open(GT_FILE, 'rb') as f:
    ground_truth_list = pickle.load(f)  # List of dictionaries

# Verify Ground Truth Type
print("Type of ground_truth:", type(ground_truth_list))
print("Number of entries:", len(ground_truth_list))

# Process Ground Truth to create a mapping: {file_id_no_ext: [keyword_str, ...]}
ground_truth = {}  # Initialize empty dictionary

for idx, entry in enumerate(ground_truth_list, 1):
    for file_path, nested_dict in entry.items():
        # Extract file ID without extension
        file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]
        
        # Initialize list for keywords if not already present
        if file_id_no_ext not in ground_truth:
            ground_truth[file_id_no_ext] = []
        
        # Extract detections
        detections = nested_dict.get(file_path, [])
        
        for detection in detections:
            keyword_id = detection.get('keyword')
            keyword_str = id_to_kw.get(keyword_id, None)
            if keyword_str:
                ground_truth[file_id_no_ext].append(keyword_str)
            else:
                print(f"Warning: Keyword ID {keyword_id} not found in mapping.")

# Debug: Inspect the processed ground_truth
print("Processed Ground Truth:")
for key, value in list(ground_truth.items())[:2]:  # Print first two entries
    print(f"File ID: {key}, Keywords: {value}")

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        segment_audio = audio[start:end]
        inputs = processor(segment_audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(inputs.input_values)
            seg_emb = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb))

        start += hop_length
        end = start + win_length

    return segments

def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    if not prototypes:
        return None
    proto_classes = list(prototypes.keys())
    proto_mat = np.stack(list(prototypes.values()), axis=0)
    distances = cdist([segment_emb], proto_mat, metric='cosine')[0]
    min_dist_idx = np.argmin(distances)
    min_dist = distances[min_dist_idx]
    print(min_dist)
    if min_dist < threshold:
        return proto_classes[min_dist_idx]
    else:
        return None

# Only these two files
test_audio_files = [
    os.path.join(TEST_AUDIO_DIR, "1.wav"),
    os.path.join(TEST_AUDIO_DIR, "0.wav"),
    os.path.join(TEST_AUDIO_DIR, "10.wav")
]

results = {}

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Retrieve expected keywords from ground_truth
    expected_keywords = ground_truth.get(file_id_no_ext, [])
    
    # Debug: Print expected keywords
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Filter prototypes to only those keywords expected in ground truth for this file
    # This helps reduce false positives drastically.
    filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_keywords}

    # Debug: Print filtered prototypes
    print(f"Filtered Prototypes for File ID {file_id_no_ext}: {list(filtered_prototypes.keys())}")

    segments = segment_audio(test_file)
    file_results = []

    last_detected_keyword = None
    last_detected_time = -999

    for (start_t, end_t, seg_emb) in segments:
        kw = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is not None:
            # Simple post-processing to avoid multiple detections of the same keyword in overlapping windows
            if kw != last_detected_keyword or (start_t - last_detected_time) > (STEP_SIZE * 2):
                kw_id = kw_to_id.get(kw, None)
                detection_entry = {
                    "keyword": kw,
                    "keyword_id": kw_id,
                    "start_time": start_t,
                    "end_time": end_t
                }
                file_results.append(detection_entry)
                last_detected_keyword = kw
                last_detected_time = start_t

    results[file_id_no_ext] = file_results

# Save Results
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)

Type of ground_truth: <class 'list'>
Number of entries: 250
Processed Ground Truth:
File ID: 0, Keywords: ['tomu']
File ID: 1, Keywords: ['أطول', 'خواهم']
Processing File ID: 1, Expected Keywords: ['أطول', 'خواهم']
Filtered Prototypes for File ID 1: ['أطول', 'خواهم']
0.07051869920153109
0.17319115337973545
0.10145593520980611
0.11868912073247517
0.19436597073449735
0.10747415201120403
0.04026823181235173
0.030736552959383046
0.055807606005913946
0.08594619493804478
0.09794407895906776
0.12714596782928933
0.08967828845043957
0.053163032820762846
0.02928215632920972
0.08377951950354756
0.10083323621937368
0.15006426559992325
0.14517853848352635
0.11256725568751458
Processing File ID: 0, Expected Keywords: ['tomu']
Filtered Prototypes for File ID 0: ['tomu']
0.05503677862329226
0.13631161785949564
0.1260196971661366
0.05202979778395889
0.028660709641973536
0.048233272322326215
0.04020098004177641
0.04694717466655218
0.12860382418813765
0.18091750704822618
0.1443687873076327
Processing Fil

In [92]:
print(embeddings)

[array([-4.18751389e-02,  3.53388190e-02, -1.39812157e-01, -4.33332883e-02,
        4.09340225e-02, -7.54778907e-02,  5.45662716e-02, -4.34737690e-02,
        1.79194007e-02, -3.19845229e-01, -7.36276582e-02, -9.73511115e-02,
        6.46035746e-02,  4.43768688e-02, -6.67822883e-02, -4.32628095e-02,
       -3.01845849e-01,  2.63552725e-01,  1.45699810e-02,  9.41181183e-02,
       -1.27641678e-01,  5.30654751e-02,  1.35900065e-01,  1.74343288e-02,
        4.83447939e-01, -3.24696712e-02, -3.19065064e-01,  4.32119407e-02,
        1.43292308e-01, -1.28692165e-01,  2.44765535e-01, -9.68862884e-03,
       -4.84450012e-02, -5.46687245e-02, -2.87910610e-01,  7.88955465e-02,
        4.95807491e-02, -2.65641302e-01, -8.59601647e-02,  5.86813241e-02,
       -1.21404804e-01, -1.89780444e-02, -1.40521049e-01, -1.20388381e-02,
       -1.15862601e-01, -2.86114891e-03, -6.60270900e-02, -1.80117100e-01,
       -3.09299696e-02,  2.96610929e-02, -1.11592568e-01, -2.49166079e-02,
        2.29526713e-01, 

In [93]:
print(prototypes)

{'أذهب': array([-2.52930708e-02, -4.21816745e-04, -3.31777744e-02, -4.40373309e-02,
        4.85753641e-02, -1.11618519e-01,  6.02274239e-02, -1.98440254e-02,
        6.99318796e-02, -3.19634348e-01, -9.97075718e-03, -1.08551895e-02,
        5.71031868e-02,  5.25325164e-02,  2.28018463e-02, -5.96911553e-03,
       -3.15417320e-01,  3.07224721e-01,  1.80651862e-02,  4.44613174e-02,
       -2.12319270e-01,  7.39759952e-02,  3.36918771e-01,  1.19639309e-02,
        7.66948834e-02, -2.60738116e-02, -3.55677277e-01,  5.08983135e-02,
       -3.73687074e-02, -1.23734459e-01,  9.89444181e-02, -1.04906876e-02,
       -3.70370857e-02, -7.73759708e-02, -2.02745765e-01,  9.60097089e-02,
        2.65213717e-02, -2.35525861e-01, -1.27828091e-01,  4.74905334e-02,
       -1.55729055e-01, -1.97207332e-01, -1.06978104e-01,  1.54836178e-01,
       -1.68310985e-01, -2.37420034e-02, -4.69452664e-02,  2.06732694e-02,
        4.51109186e-03,  8.54670908e-03, -1.10402457e-01, -1.61775779e-02,
        5.380353

In [96]:
# Define paths to save embeddings
CLASS_EMBEDDINGS_FILE = "class_embeddings.pkl"
PROTOTYPES_FILE = "prototypes.pkl"

# After computing class_embeddings and prototypes
# Save class_embeddings
with open(CLASS_EMBEDDINGS_FILE, 'wb') as f:
    pickle.dump(class_embeddings, f)
    print(f"Saved class embeddings to {CLASS_EMBEDDINGS_FILE}")

# Save prototypes
with open(PROTOTYPES_FILE, 'wb') as f:
    pickle.dump(prototypes, f)
    print(f"Saved prototypes to {PROTOTYPES_FILE}")


Saved class embeddings to class_embeddings.pkl
Saved prototypes to prototypes.pkl


In [97]:
import os
import pickle
import numpy as np
import torch
import librosa
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist

# Define paths
CLASS_EMBEDDINGS_FILE = "class_embeddings.pkl"
PROTOTYPES_FILE = "prototypes.pkl"

# Initialize Wav2Vec2 Processor and Model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

# Device configuration (optional: use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def extract_embedding(audio_path, sr=16000):
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(inputs.input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding

# Function to filter dataset
def filter_dataset(train_dir, extensions=('.wav', '.mp3')):
    class_files = {}
    for root, dirs, files in os.walk(train_dir):
        audio_files = [f for f in files if f.lower().endswith(extensions)]
        if audio_files:
            class_label = os.path.basename(root)
            if len(audio_files) > 7:
                audio_files = random.sample(audio_files, 7)
            if 1 <= len(audio_files) <= 7:
                file_paths = [os.path.join(root, f) for f in audio_files]
                class_files[class_label] = file_paths
    return class_files

# Paths
TRAIN_DIR = "F:\\KWS\\TRAIN\\TRAIN"

# Filter dataset
class_files = filter_dataset(TRAIN_DIR)

# Check if embeddings are already saved
if os.path.exists(CLASS_EMBEDDINGS_FILE) and os.path.exists(PROTOTYPES_FILE):
    # Load class_embeddings
    with open(CLASS_EMBEDDINGS_FILE, 'rb') as f:
        class_embeddings = pickle.load(f)
        print(f"Loaded class embeddings from {CLASS_EMBEDDINGS_FILE}")
    
    # Load prototypes
    with open(PROTOTYPES_FILE, 'rb') as f:
        prototypes = pickle.load(f)
        print(f"Loaded prototypes from {PROTOTYPES_FILE}")
else:
    # Compute class_embeddings
    class_embeddings = {}
    for cls, file_paths in class_files.items():
        embeddings = []
        for file_path in file_paths:
            try:
                emb = extract_embedding(file_path)
                embeddings.append(emb)
            except Exception as e:
                print(f"Error processing {file_path}: {e}")
        if embeddings:
            class_embeddings[cls] = embeddings
    
    # Save class_embeddings
    with open(CLASS_EMBEDDINGS_FILE, 'wb') as f:
        pickle.dump(class_embeddings, f)
        print(f"Saved class embeddings to {CLASS_EMBEDDINGS_FILE}")
    
    # Create prototypes
    prototypes = {}
    for cls, embs in class_embeddings.items():
        if len(embs) > 0:
            prototype = np.mean(np.stack(embs), axis=0)
            prototypes[cls] = prototype
    
    # Save prototypes
    with open(PROTOTYPES_FILE, 'wb') as f:
        pickle.dump(prototypes, f)
        print(f"Saved prototypes to {PROTOTYPES_FILE}")


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded class embeddings from class_embeddings.pkl
Loaded prototypes from prototypes.pkl
