In [4]:
import os
import random
import json
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist
import pickle

# ---- Data locations and output file ------------------------------------------
TRAIN_DIR = "F:\\KWS\\TRAIN\\TRAIN"
TEST_AUDIO_DIR = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL"
OUTPUT_FILE = "test_keyword_detection_results.json"

# ---- Detection / segmentation settings ---------------------------------------
# Exclusive upper bound on the cosine distance (1 - cosine similarity) between a
# segment embedding and a keyword prototype for the segment to count as a match.
# Smaller values are stricter; 0.8 accepts anything with similarity above 0.2.
THRESHOLD = 0.8
SAMPLING_RATE = 16000  # Hz; audio is resampled to this rate when loaded
WINDOW_SIZE = 1.0      # sliding-window length over test audio, in seconds
STEP_SIZE = 0.5        # hop between consecutive window starts, in seconds

# ---- Pickle caches for computed embeddings -----------------------------------
CLASS_EMBEDDINGS_FILE = "class_embeddings.pkl"
PROTOTYPES_FILE = "prototypes.pkl"
TEST_SEGMENTS_EMBEDDINGS_FILE = "test_segments_embeddings.pkl"

# ---- Model -------------------------------------------------------------------
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained Wav2Vec2 encoder, used in inference (eval) mode as a feature extractor.
# NOTE(review): the load-time warning that `wav2vec2.masked_spec_embed` is newly
# initialised is presumably harmless here. That parameter is used for SpecAugment-style
# masking, which should only apply in training mode or with explicit masks.
# Confirm this for the installed transformers version.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)
model.eval()

def extract_embedding(audio_path, sr=SAMPLING_RATE):
    """Embed an entire audio file as a single vector.

    The file is loaded (resampled to ``sr``) with librosa. It is then run through
    the module-level Wav2Vec2 ``processor``/``model`` on ``device`` with gradient
    tracking disabled. The model's ``last_hidden_state`` is averaged over
    dimension 1 (the time/frame axis).

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        sr: Target sampling rate in Hz.

    Returns:
        A numpy array holding the time-averaged hidden state, with singleton
        dimensions squeezed out (presumably 1-D of the model's hidden size).
    """
    waveform = librosa.load(audio_path, sr=sr)[0]
    batch = processor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
    batch = batch.to(device)
    with torch.no_grad():
        hidden_states = model(batch.input_values).last_hidden_state
        pooled = hidden_states.mean(dim=1).squeeze()
        return pooled.cpu().numpy()

def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    """Slide a fixed-length window over an audio file and embed every window.

    Each window is run through the module-level Wav2Vec2 ``processor``/``model``
    on ``device`` without gradient tracking. The model's ``last_hidden_state`` is
    mean-pooled over dimension 1 (time), giving one vector per window.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        window_size: Window length in seconds.
        step_size: Hop between consecutive window starts, in seconds.
        sr: Sampling rate (Hz) the audio is resampled to on load.

    Returns:
        A list of ``(start_time, end_time, embedding)`` tuples, with times in
        seconds. Windows start at multiples of ``step_size``. Special cases:

        * If that grid stops before the end of the clip, one extra window aligned
          to the clip's end is appended, so the tail is not dropped.
        * A clip shorter than one window yields a single segment. Its audio is
          zero-padded to the window length before embedding, and its
          ``end_time`` is the clip's real duration.
        * An empty clip yields ``[]``.

    Raises:
        ValueError: If ``window_size`` or ``step_size`` spans less than one
            sample at ``sr``. A zero hop would otherwise loop forever.
    """
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)
    if hop_length <= 0 or win_length <= 0:
        raise ValueError(
            "window_size and step_size must each span at least one sample "
            f"(got window_size={window_size}, step_size={step_size}, sr={sr})"
        )

    n_samples = len(audio)
    if n_samples == 0:
        return []

    def _embed(chunk):
        # One mean-pooled embedding for a chunk of raw audio samples.
        inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model(inputs.input_values)
            return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    if n_samples < win_length:
        # Too short for a full window: previously this produced no segments at all,
        # so nothing could ever be detected. Pad with silence so the model sees a
        # window-length input, but report the clip's true extent.
        padded = np.pad(audio, (0, win_length - n_samples))
        return [(0.0, n_samples / sr, _embed(padded))]

    # Every window start that leaves room for a full window (same grid as before).
    starts = list(range(0, n_samples - win_length + 1, hop_length))
    if starts[-1] + win_length < n_samples:
        # The hop grid ends short of the clip's end; cover the remaining tail.
        starts.append(n_samples - win_length)

    segments = []
    for start in starts:
        end = start + win_length
        window = audio[start:end]  # not named `segment_audio`: that would shadow this function
        segments.append((start / sr, end / sr, _embed(window)))
    return segments

def detect_keyword_in_segment(segment_emb, prototypes, threshold=THRESHOLD):
    """Return the label of the prototype nearest to ``segment_emb``, or None.

    Nearness is cosine distance, computed with ``scipy.spatial.distance.cdist``
    and ``metric='cosine'``. The nearest label is returned only when its
    distance is strictly below ``threshold``. The function returns None
    otherwise, or when ``prototypes`` is empty.

    Args:
        segment_emb: Embedding of one audio segment (a 1-D vector).
        prototypes: Mapping of label -> prototype embedding. Prototypes are
            stacked along a new first axis, so each must match
            ``segment_emb`` in shape.
        threshold: Exclusive upper bound on the accepted cosine distance.

    Returns:
        The best-matching label, or None when there is no acceptable match.
    """
    if not prototypes:
        return None
    labels, vectors = zip(*prototypes.items())
    distances = cdist([segment_emb], np.stack(vectors, axis=0), metric='cosine')[0]
    best = int(np.argmin(distances))
    return labels[best] if distances[best] < threshold else None

# ---- Load ground truth -------------------------------------------------------
GT_FILE = "F:\\KWS\\TEST_DUMMY_FINAL\\TEST_DUMMY_FINAL\\GT_dummy_final.pkl"

# NOTE(review): pickle.load can execute arbitrary code; only load ground-truth files you trust.
with open(GT_FILE, 'rb') as f:
    ground_truth_list = pickle.load(f)  # list of dicts: {audio file path: [detection, ...]}

# Sanity-check what was loaded.
print("Type of ground_truth:", type(ground_truth_list))
print("Number of entries:", len(ground_truth_list))

# Flatten into {file_id_no_ext: [detection, ...]}, keyed by each path's basename
# without extension. Detections for the same file ID from different entries are
# concatenated. In the sample printout, each detection is a dict with
# "keyword", "start_time" and "end_time".
ground_truth = {}
for entry in ground_truth_list:
    for file_path, detections in entry.items():
        file_id_no_ext = os.path.splitext(os.path.basename(file_path))[0]
        ground_truth.setdefault(file_id_no_ext, []).extend(detections)

# Debug: show the first two processed entries.
print("Processed Ground Truth:")
for key, value in list(ground_truth.items())[:2]:
    print(f"File ID: {key}, Keywords: {value}")

# The test clips evaluated in this run, in processing order.
test_audio_files = [os.path.join(TEST_AUDIO_DIR, name) for name in ("1.wav", "0.wav")]

# Segment embeddings cached from earlier runs, keyed by file ID (basename without
# extension). Start empty when no cache file exists yet.
# NOTE(review): entries are keyed only by file ID, with no record of WINDOW_SIZE /
# STEP_SIZE. Delete the cache file after changing segmentation settings.
# NOTE(review): pickle.load can execute arbitrary code; only load a cache you trust.
test_segments_embeddings = {}
if os.path.exists(TEST_SEGMENTS_EMBEDDINGS_FILE):
    with open(TEST_SEGMENTS_EMBEDDINGS_FILE, 'rb') as f:
        test_segments_embeddings = pickle.load(f)
        print(f"Loaded test segments embeddings from {TEST_SEGMENTS_EMBEDDINGS_FILE}")

results = {}

# `prototypes` (keyword -> prototype embedding) is not built in this cell; it was
# relied on from earlier notebook state, so a fresh kernel raised NameError here.
# Fall back to the on-disk prototype cache, and fail with a clear message if neither
# source is available.
if "prototypes" not in globals():
    if not os.path.exists(PROTOTYPES_FILE):
        raise RuntimeError(
            f"`prototypes` is not defined and {PROTOTYPES_FILE} does not exist; "
            "build the keyword prototypes before running detection."
        )
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache you trust.
    with open(PROTOTYPES_FILE, 'rb') as f:
        prototypes = pickle.load(f)  # presumably {keyword: embedding}; verify against the writer
    print(f"Loaded prototypes from {PROTOTYPES_FILE}")

for test_file in test_audio_files:
    file_id = os.path.basename(test_file)
    file_id_no_ext = os.path.splitext(file_id)[0]

    # Keywords the ground truth lists for this file.
    expected_keywords = [detection["keyword"] for detection in ground_truth.get(file_id_no_ext, [])]
    print(f"Processing File ID: {file_id_no_ext}, Expected Keywords: {expected_keywords}")

    # Restrict the candidate prototypes to this file's ground-truth keywords.
    # NOTE(review): this tells the detector which keywords to look for. That is only
    # valid if the task supplies per-file query keywords; otherwise it leaks labels
    # into the evaluation.
    expected_set = set(expected_keywords)
    filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_set}
    missing = expected_set.difference(filtered_prototypes)
    if missing:
        # Otherwise such keywords would silently never be detected.
        print(f"Warning: no prototype for expected keyword(s) {sorted(missing)} in {file_id_no_ext}")

    if file_id_no_ext in test_segments_embeddings:
        segments = test_segments_embeddings[file_id_no_ext]
        print(f"Loaded segments for {file_id_no_ext} from {TEST_SEGMENTS_EMBEDDINGS_FILE}")
    else:
        segments = segment_audio(test_file)
        test_segments_embeddings[file_id_no_ext] = segments
        # Persist right after each newly computed file so an interrupted run keeps its
        # work. Files served from the cache leave it unchanged, so no rewrite is needed.
        with open(TEST_SEGMENTS_EMBEDDINGS_FILE, 'wb') as f:
            pickle.dump(test_segments_embeddings, f)
        print(f"Computed and saved segments for {file_id_no_ext}")
        print(f"Updated test segments embeddings saved to {TEST_SEGMENTS_EMBEDDINGS_FILE}")

    # Suppress repeat detections of the same keyword from overlapping windows. The
    # last emitted start time is tracked per keyword. Tracking only the single most
    # recent keyword let an interleaved sequence (A, B, A) re-emit A inside the
    # suppression gap.
    file_results = []
    last_emitted_start = {}  # keyword -> start time of its most recent emitted detection
    for (start_t, end_t, seg_emb) in segments:
        kw = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
        if kw is None:
            continue
        previous_start = last_emitted_start.get(kw)
        if previous_start is not None and (start_t - previous_start) <= (STEP_SIZE * 2):
            continue
        file_results.append({
            "keyword": kw,
            "start_time": round(start_t, 2),
            "end_time": round(end_t, 2),
        })
        last_emitted_start[kw] = start_t

    results[file_id_no_ext] = file_results

# Save results (UTF-8, keeping non-ASCII keywords readable).
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print("Keyword detection complete. Results saved to:", OUTPUT_FILE)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Type of ground_truth: <class 'list'>
Number of entries: 250
Processed Ground Truth:
File ID: 0, Keywords: [{'keyword': 'رقیق', 'start_time': 1.6009375, 'end_time': 2.6009374999999997}, {'keyword': 'خوششون', 'start_time': 5.1909374999999995, 'end_time': 6.1909374999999995}]
File ID: 1, Keywords: [{'keyword': 'އޮފިސަރ', 'start_time': 2.65, 'end_time': 3.65}]
Loaded test segments embeddings from test_segments_embeddings.pkl
Processing File ID: 1, Expected Keywords: ['އޮފިސަރ']
Loaded segments for 1 from test_segments_embeddings.pkl
Updated test segments embeddings saved to test_segments_embeddings.pkl
Processing File ID: 0, Expected Keywords: ['رقیق', 'خوششون']
Loaded segments for 0 from test_segments_embeddings.pkl
Updated test segments embeddings saved to test_segments_embeddings.pkl
Keyword detection complete. Results saved to: test_keyword_detection_results.json
