In [2]:
import os
import random
import json
import torch
import librosa
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from scipy.spatial.distance import cdist
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
#model.eval()

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
TRAIN_DIR = "F:\KWS\TRAIN\TRAIN"
TEST_AUDIO_DIR = "F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL"  # Directory containing test audio files
OUTPUT_FILE = "results.json"
GT_FILE = "F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\GT_dummy_final.pkl"
KW_TO_ID_FILE = "F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\kw_to_id.pkl"
THRESHOLD = 0.8  # Similarity threshold for keyword detection
SAMPLING_RATE = 16000
WINDOW_SIZE = 1.0   # in seconds for test audio segmentation
STEP_SIZE = 0.5      # in seconds for test audio segmentation
ALPHA = 0.9 

In [5]:
# 1. Filter Dataset
def filter_dataset(train_dir, extensions=('.wav', '.mp3')):
    class_files = {}
    for root, _, files in os.walk(train_dir):
        audio_files = [f for f in files if f.lower().endswith(extensions)]
        if audio_files:
            class_label = os.path.basename(root)
            if len(audio_files) > 7:
                audio_files = random.sample(audio_files, 7)
            if 1 <= len(audio_files) <= 7:
                file_paths = [os.path.join(root, f) for f in audio_files]
                class_files[class_label] = file_paths
    return class_files

In [6]:

# 2. Extract Embeddings
def extract_embedding(audio_path, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(inputs.input_values)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
    return embedding

In [18]:
# 3. Segment Audio
def segment_audio(audio_path, window_size=WINDOW_SIZE, step_size=STEP_SIZE, sr=SAMPLING_RATE):
    audio, _ = librosa.load(audio_path, sr=sr)
    hop_length = int(step_size * sr)
    win_length = int(window_size * sr)

    segments = []
    start = 0
    end = win_length
    while end <= len(audio):
        segment_audio = audio[start:end]
        inputs = processor(segment_audio, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(inputs.input_values)
            seg_emb = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        segment_start_time = start / sr
        segment_end_time = end / sr
        segments.append((segment_start_time, segment_end_time, seg_emb))

        start += hop_length
        end = start + win_length

    return segments

In [26]:
# 4. Detect Keyword in Segment
def detect_keyword_in_segment(segment_emb, prototypes, start_t, end_t, threshold=THRESHOLD):
    if not prototypes:
        return None, 0.0, start_t, end_t
    # existing logic...
    
    proto_classes = list(prototypes.keys())
    proto_mat = np.stack(list(prototypes.values()), axis=0)
    distances = cdist([segment_emb], proto_mat, metric='cosine')[0]
    min_dist_idx = np.argmin(distances)
    min_dist = distances[min_dist_idx]
    confidence = 1 - min_dist
    if min_dist < threshold:
        return proto_classes[min_dist_idx], confidence, start_t, end_t
    else:
        return None, 0.0, start_t, end_t

In [20]:
# 5. Group Overlapping Detections
def group_overlapping_detections(detections, step_size):
    grouped_detections = {}
    for det in detections:
        keyword = det["keyword"]
        start_time = det["start_time"]
        end_time = det["end_time"]
        confidence = det["confidence"]

        if keyword not in grouped_detections:
            grouped_detections[keyword] = []

        if grouped_detections[keyword] and start_time <= grouped_detections[keyword][-1]["end_time"]:
            prev_det = grouped_detections[keyword][-1]
            prev_det["end_time"] = max(prev_det["end_time"], end_time)
            prev_det["confidence"] = max(prev_det["confidence"], confidence)
        else:
            grouped_detections[keyword].append({
                "start_time": start_time,
                "end_time": end_time,
                "confidence": confidence
            })

    merged_detections = []
    for keyword, dets in grouped_detections.items():
        for det in dets:
            merged_detections.append({
                "keyword": keyword,
                "start_time": det["start_time"],
                "end_time": det["end_time"],
                "confidence": det["confidence"]
            })

    return merged_detections

In [21]:
# 6. Update Prototypes
def update_prototypes(prototypes, false_positives, false_negatives, segment_embeddings, alpha=ALPHA):
    for file_id, fps in false_positives.items():
        for fp in fps:
            if fp in prototypes:
                prototype = prototypes[fp]
                for seg_emb in segment_embeddings[file_id]:
                    prototype = alpha * prototype - (1 - alpha) * seg_emb
                prototypes[fp] = prototype

    for file_id, fns in false_negatives.items():
        for fn in fns:
            if fn in prototypes:
                prototype = prototypes[fn]
                for seg_emb in segment_embeddings[file_id]:
                    prototype = alpha * prototype + (1 - alpha) * seg_emb
                prototypes[fn] = prototype
    return prototypes

In [22]:
# # 7. Main Workflow
# if __name__ == "__main__":
#     # Filter dataset and compute prototypes
#     class_files = filter_dataset(TRAIN_DIR)
#     class_embeddings = {cls: [extract_embedding(file) for file in files] for cls, files in class_files.items()}
#     prototypes = {cls: np.mean(embs, axis=0) for cls, embs in class_embeddings.items()}

#     # Load Ground Truth and Keyword-to-ID Mapping
#     with open(GT_FILE, 'rb') as f:
#         ground_truth_list = pickle.load(f)
#     with open(KW_TO_ID_FILE, 'rb') as f:
#         kw_to_id = pickle.load(f)
#     id_to_kw = {v: k for k, v in kw_to_id.items()}
#     ground_truth = {os.path.splitext(os.path.basename(k))[0]: [id_to_kw.get(d['keyword'], None) for d in v.get(k, [])] for entry in ground_truth_list for k, v in entry.items()}

#     results = {}
#     for test_file in os.listdir(TEST_AUDIO_DIR):
#         file_id = os.path.splitext(test_file)[0]
#         test_path = os.path.join(TEST_AUDIO_DIR, test_file)
#         expected_keywords = ground_truth.get(file_id, [])
#         filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_keywords}

#         segments = segment_audio(test_path)
#         file_results = []
#         for (start_t, end_t, seg_emb) in segments:
#             kw, confidence = detect_keyword_in_segment(seg_emb, filtered_prototypes, threshold=THRESHOLD)
#             if kw is not None:
#                 file_results.append({
#                     "keyword": kw,
#                     "start_time": start_t,
#                     "end_time": end_t,
#                     "confidence": confidence
#                 })

#         file_results = group_overlapping_detections(file_results, STEP_SIZE)
#         results[file_id] = file_results

#     with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
#         json.dump(results, f, ensure_ascii=False, indent=2)

#     print("Keyword detection complete. Results saved to:", OUTPUT_FILE)

In [1]:
if __name__ == "__main__":
    # Filter dataset and compute prototypes
    class_files = filter_dataset(TRAIN_DIR)
    class_embeddings = {cls: [extract_embedding(file) for file in files] for cls, files in class_files.items()}
    prototypes = {cls: np.mean(embs, axis=0) for cls, embs in class_embeddings.items()}

    # Load Ground Truth and Keyword-to-ID Mapping
    with open(GT_FILE, 'rb') as f:
        ground_truth_list = pickle.load(f)
    with open(KW_TO_ID_FILE, 'rb') as f:
        kw_to_id = pickle.load(f)
    id_to_kw = {v: k for k, v in kw_to_id.items()}

    # Process ground truth into a usable dictionary format
    ground_truth = {}
    for entry in ground_truth_list:
        for file_path, data in entry.items():
            # Extract file ID without extension
            file_id = os.path.splitext(os.path.basename(file_path))[0]
            if file_id not in ground_truth:
                ground_truth[file_id] = []
            
            # If data is a list of detection dictionaries, just use it directly
            detections = data  # Assuming data is already a list
            for detection in detections:
                keyword_id = detection.get('keyword')
                keyword_str = id_to_kw.get(keyword_id)
                if keyword_str:
                    ground_truth[file_id].append(keyword_str)


    # Test only on the specified files
    test_audio_files = [
        r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\0.wav",
        r"F:\KWS\TEST_DUMMY_FINAL\TEST_DUMMY_FINAL\1.wav"
    ]

    results = {}
    for test_file in test_audio_files:
        file_id = os.path.splitext(os.path.basename(test_file))[0]
        expected_keywords = ground_truth.get(file_id, [])
        filtered_prototypes = {k: v for k, v in prototypes.items() if k in expected_keywords}

        segments = segment_audio(test_file)
        file_results = []
        for (start_t, end_t, seg_emb) in segments:
            kw, confidence, seg_start, seg_end = detect_keyword_in_segment(seg_emb, filtered_prototypes, start_t, end_t, threshold=THRESHOLD)
            if kw is not None:
                file_results.append({
                    "keyword": kw,
                    "start_time": seg_start,
                    "end_time": seg_end,
                    "confidence": confidence
                })


        file_results = group_overlapping_detections(file_results, STEP_SIZE)
        results[file_id] = file_results

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print("Keyword detection complete. Results saved to:", OUTPUT_FILE)


NameError: name 'filter_dataset' is not defined