In [1]:
import os
import torch
import torchaudio
import warnings
from typing import Dict, List, Any
import tempfile
import argparse

# --- Core AI Libraries ---
import assemblyai as aai
from pydub import AudioSegment
from speechbrain.pretrained import EncoderClassifier
from scipy.optimize import linear_sum_assignment
import numpy as np
from dotenv import load_dotenv

# Suppress user warnings for a cleaner output
warnings.filterwarnings("ignore", category=UserWarning)


class AssemblyAIHandler:
    """
    Handles all interactions with the AssemblyAI API, including transcription,
    diarization, and extraction of speaker audio clips into a temporary directory.
    """
    def __init__(self, api_key: str):
        if not api_key:
            raise ValueError("AssemblyAI API key is required.")
        aai.settings.api_key = api_key
        self.transcriber = aai.Transcriber()

    def transcribe_and_extract(self, audio_path: str, temp_dir: str) -> (aai.Transcript, Dict[str, str]):
        """
        Transcribes the audio, extracts each speaker's utterances into merged
        audio files within a temporary directory.
        
        Returns:
            - The full transcript object from AssemblyAI.
            - A dictionary mapping generic speaker labels to their temp audio file paths.
        """
        print(f"🎤 Starting AssemblyAI transcription for '{audio_path}'...")
        config = aai.TranscriptionConfig(speaker_labels=True)
        transcript = self.transcriber.transcribe(audio_path, config)

        if transcript.status == aai.TranscriptStatus.error:
            raise RuntimeError(f"Transcription failed: {transcript.error}")

        if not transcript.utterances:
            raise ValueError("Diarization failed. The audio might be too short or have only one speaker.")

        print("✅ Transcription complete. Extracting speaker audio clips...")
        
        try:
            original_audio = AudioSegment.from_file(audio_path)
        except Exception as e:
            raise RuntimeError(f"Error loading audio with pydub: {e}. Ensure FFmpeg is installed.")

        speaker_segments = {}
        for utterance in transcript.utterances:
            speaker = utterance.speaker
            clip = original_audio[utterance.start:utterance.end]
            if speaker not in speaker_segments:
                speaker_segments[speaker] = clip
            else:
                speaker_segments[speaker] += clip

        unknown_speaker_paths = {}
        for speaker, merged_audio in speaker_segments.items():
            speaker_file_path = os.path.join(temp_dir, f"SPEAKER_{speaker}.mp3")
            print(f"  -> Exporting merged audio for Speaker {speaker}...")
            merged_audio.export(speaker_file_path, format="mp3")
            unknown_speaker_paths[speaker] = speaker_file_path
            
        return transcript, unknown_speaker_paths


class SpeechBrainIdentifier:
    """
    Handles voiceprint creation and speaker identification using SpeechBrain's
    powerful ECAPA-TDNN model and the Hungarian algorithm for optimal assignment.
    """
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"\n🧠 Loading SpeechBrain speaker recognition model on '{self.device}'...")
        self.classifier = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir="pretrained_models/spkrec-ecapa-voxceleb",
            run_opts={"device": self.device}
        )
        print("✅ SpeechBrain model loaded.")

    def _create_voiceprint(self, audio_path: str) -> torch.Tensor:
        """Computes a speaker embedding from an audio file."""
        try:
            signal, fs = torchaudio.load(audio_path)
            if fs != 16000:
                signal = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(signal)
            if signal.shape[0] > 1:
                signal = torch.mean(signal, dim=0, keepdim=True)
            
            with torch.no_grad():
                embedding = self.classifier.encode_batch(signal)
                return torch.nn.functional.normalize(embedding, p=2, dim=2).squeeze()
        except Exception as e:
            print(f"  - WARNING: Could not create voiceprint for {audio_path}: {e}")
            return None

    def enroll_speakers(self, speaker_samples: Dict[str, str]) -> Dict[str, torch.Tensor]:
        """Creates a voiceprint database from enrolled speaker samples."""
        print("\n--- Starting Speaker Enrollment ---")
        voiceprint_db = {}
        for name, path in speaker_samples.items():
            print(f"Creating voiceprint for '{name}'...")
            embedding = self._create_voiceprint(path)
            if embedding is not None:
                voiceprint_db[name] = embedding
        print("--- Enrollment Complete ---\n")
        return voiceprint_db

    def identify_speakers(self, unknown_clips: Dict[str, str], enrolled_voiceprints: Dict[str, torch.Tensor]) -> Dict[str, str]:
        """
        Finds the optimal one-to-one mapping of unknown speakers to enrolled speakers.
        """
        print("--- Identifying Unknown Speakers ---")
        unknown_voiceprints = {}
        for speaker_label, path in unknown_clips.items():
            embedding = self._create_voiceprint(path)
            if embedding is not None:
                unknown_voiceprints[speaker_label] = embedding

        enrolled_names = list(enrolled_voiceprints.keys())
        unknown_labels = list(unknown_voiceprints.keys())
        
        if not unknown_labels:
            print("No unknown speakers to identify.")
            return {}

        # Create a similarity matrix (higher is better)
        similarity_matrix = np.zeros((len(enrolled_names), len(unknown_labels)))
        cosine_similarity = torch.nn.CosineSimilarity(dim=0)
        for i, name in enumerate(enrolled_names):
            for j, label in enumerate(unknown_labels):
                score = cosine_similarity(enrolled_voiceprints[name], unknown_voiceprints[label]).item()
                similarity_matrix[i, j] = score

        # Use the Hungarian algorithm on a cost matrix to find the optimal assignment
        row_ind, col_ind = linear_sum_assignment(1 - similarity_matrix)

        speaker_map = {}
        assigned_unknowns = set()
        confidence_threshold = 0.50 # This can be tuned

        for r, c in zip(row_ind, col_ind):
            enrolled_name = enrolled_names[r]
            unknown_label = unknown_labels[c]
            score = similarity_matrix[r, c]

            if score > confidence_threshold:
                speaker_map[unknown_label] = enrolled_name
                assigned_unknowns.add(unknown_label)
                print(f"  - Matched Speaker {unknown_label} -> {enrolled_name} (Confidence: {score:.2f})")
        
        unknown_count = 1
        for label in unknown_labels:
            if label not in assigned_unknowns:
                unknown_name = f"Unknown Speaker {unknown_count}"
                speaker_map[label] = unknown_name
                print(f"  - Could not confidently match Speaker {label}. Assigning as {unknown_name}.")
                unknown_count += 1
                
        return speaker_map


class TranscriptionPipeline:
    """The main orchestrator for the entire process."""
    def __init__(self, api_key: str):
        self.assembly_handler = AssemblyAIHandler(api_key)
        self.speechbrain_identifier = SpeechBrainIdentifier()

    def run(self, main_audio_path: str, speaker_samples: Dict[str, str]):
        # Use a temporary directory that cleans itself up automatically
        with tempfile.TemporaryDirectory() as temp_dir:
            # Step 1: Transcribe and extract audio clips
            transcript, unknown_clips = self.assembly_handler.transcribe_and_extract(main_audio_path, temp_dir)

            # Step 2: Enroll known speakers
            enrolled_voiceprints = self.speechbrain_identifier.enroll_speakers(speaker_samples)

            # Step 3: Identify the unknown speakers from the extracted clips
            speaker_map = self.speechbrain_identifier.identify_speakers(unknown_clips, enrolled_voiceprints)

            # Step 4: Generate and save the final, named transcript
            self._generate_final_transcript(transcript, speaker_map, main_audio_path)

    def _generate_final_transcript(self, transcript: aai.Transcript, speaker_map: Dict[str, str], audio_path: str):
        print("\n--- Generating Final Named Transcript ---")
        base_name, _ = os.path.splitext(os.path.basename(audio_path))
        output_filename = f"{base_name}_final_transcript.txt"

        with open(output_filename, "w", encoding="utf-8") as f:
            for utterance in transcript.utterances:
                speaker_label = utterance.speaker
                # Get the real name from our map, or keep the generic label if unmapped
                final_name = speaker_map.get(speaker_label, f"Unmapped Speaker {speaker_label}")
                line = f"Speaker {final_name}: {utterance.text}\n"
                f.write(line)
        
        print(f"✅ Final transcript saved to '{output_filename}'")


# --- Main Execution Block ---
if __name__ == "__main__":
    # Load environment variables from .env file
    load_dotenv()
    ASSEMBLYAI_API_KEY = os.getenv("ASSEMBLYAI_API_KEY")

    if not ASSEMBLYAI_API_KEY:
        print("!!! ERROR: ASSEMBLYAI_API_KEY not found. Please create a .env file. !!!")
    else:
        # --- Define Your Inputs Here ---
        
        # 1. A dictionary mapping the desired speaker names to their sample audio file paths.
        SPEAKER_SAMPLES = {
            "spk1": "/Users/sujanh/Downloads/data2/spk1.mp3",
            "spk2": "/Users/sujanh/Downloads/data2/spk2.mp3",
            "spk3": "/Users/sujanh/Downloads/data2/spk3.mp3",
            "spk4": "/Users/sujanh/Downloads/data2/spk4.mp3"
        }

        # 2. The path to the main meeting audio file you want to process.
        MAIN_AUDIO_FILE = "/Users/sujanh/Downloads/data2/Segment 2.mp3"

        # --- Execution ---
        try:
            pipeline = TranscriptionPipeline(api_key=ASSEMBLYAI_API_KEY)
            pipeline.run(main_audio_path=MAIN_AUDIO_FILE, speaker_samples=SPEAKER_SAMPLES)
        except (ValueError, RuntimeError) as e:
            print(f"\n--- A critical error occurred ---")
            print(e)
        except Exception as e:
            print(f"\n--- An unexpected error occurred ---")
            print(e)


  from .autonotebook import tqdm as notebook_tqdm
  available_backends = torchaudio.list_audio_backends()
  from speechbrain.pretrained import EncoderClassifier



🧠 Loading SpeechBrain speaker recognition model on 'cpu'...


  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)


✅ SpeechBrain model loaded.
🎤 Starting AssemblyAI transcription for '/Users/sujanh/Downloads/data2/Segment 2.mp3'...
✅ Transcription complete. Extracting speaker audio clips...
  -> Exporting merged audio for Speaker A...
  -> Exporting merged audio for Speaker B...
  -> Exporting merged audio for Speaker C...

--- Starting Speaker Enrollment ---
Creating voiceprint for 'spk1'...
Creating voiceprint for 'spk2'...
Creating voiceprint for 'spk3'...
Creating voiceprint for 'spk4'...
--- Enrollment Complete ---

--- Identifying Unknown Speakers ---
  - Matched Speaker B -> spk1 (Confidence: 0.71)
  - Matched Speaker A -> spk3 (Confidence: 0.78)
  - Matched Speaker C -> spk4 (Confidence: 0.53)

--- Generating Final Named Transcript ---
✅ Final transcript saved to 'Segment 2_final_transcript.txt'
