In [1]:
# Import necessary packages
!pip install --q git+https://github.com/m-bain/whisperx.git
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import whisperx
import torch
import logging
import warnings
import os
import numpy as np
from pyannote.audio import Pipeline
from pyannote.core import Segment

# Suppress specific logging messages and warnings
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message="No language specified")
warnings.filterwarnings("ignore", message="Model was trained with pyannote.audio 0.0.1, yours is 3.1.1")
warnings.filterwarnings("ignore", message="Model was trained with torch 1.10.0+cu102, yours is 2.3.1+cu121")

# Define CSS styles for better UI
styles = """
<style>
    body { background-color: #ffe6f0; }
    .container { background-color: #ffffff; padding: 20px; border-radius: 10px; box-shadow: 0 0 15px rgba(0, 0, 0, 0.1); margin: 20px auto; max-width: 600px; }
    .header { font-size: 28px; font-weight: bold; color: #d147a3; margin-bottom: 20px; text-align: center; }
    .subheader { font-size: 20px; color: #b03060; margin-top: 20px; margin-bottom: 10px; text-align: center; }
    .description { font-size: 16px; color: #660033; margin-bottom: 20px; text-align: center; }
    .form-group { margin-bottom: 20px; }
    label { display: block; font-weight: bold; margin-bottom: 5px; color: #b03060; }
    .form-item { margin-bottom: 20px; }
    input, select, .btn-process, .output-area { width: 100%; padding: 10px; border-radius: 5px; border: 1px solid #ccc; margin-bottom: 10px; }
    .btn-process { background-color: #d147a3; color: white; font-weight: bold; border: none; cursor: pointer; }
    .btn-process:hover { background-color: #b03060; }
    .output-area { background-color: #ffe6f0; border: 1px solid #d147a3; border-radius: 5px; padding: 10px; margin-top: 20px; }
</style>
"""

# Define widgets
file_upload = widgets.FileUpload(accept='.wav', multiple=False, description='Upload Audio', style={'description_width': 'initial'})
output_lang = widgets.Dropdown(
    options=['en', 'es', 'fr', 'de', 'zh'],
    value='en',
    description='Output Language:',
    style={'description_width': 'initial'}
)
transcription_type = widgets.RadioButtons(
    options=['Regular', 'Word-Level', 'Diarized'],
    value='Regular',
    description='Transcription Type:',
    style={'description_width': 'initial'}
)
convert_to_text = widgets.Checkbox(
    value=False,
    description='Convert to Text Script',
    style={'description_width': 'initial'}
)

# Button to trigger the processing
process_button = widgets.Button(
    description='Process Audio',
    disabled=False,
    button_style='info',  # Use a predefined style instead of a custom one
    tooltip='Click to process the audio file',
    icon='check'
)
process_button.add_class('btn-process')  # Add custom CSS class

output = widgets.Output()

# Function to split segments into word-level timestamps
def word_level_transcription(segments):
    word_segments = []
    for segment in segments:
        words = segment['text'].split()
        duration = segment['end'] - segment['start']
        word_duration = duration / len(words)
        for i, word in enumerate(words):
            word_start = segment['start'] + i * word_duration
            word_end = word_start + word_duration
            word_segments.append({'start': word_start, 'end': word_end, 'text': word})
    return word_segments

# Function to perform diarization using Pyannote
def perform_diarization(audio_file):
    # Load the Pyannote pipeline for speaker diarization
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token="YOUR_AUTH_TOKEN")

    # Apply diarization
    diarization = pipeline(audio_file)

    # Convert diarization output to a list of segments with speaker labels
    diarized_segments = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        diarized_segments.append({
            'start': turn.start,
            'end': turn.end,
            'speaker': speaker
        })
    return diarized_segments

# Function to process the file
def process_audio(change):
    with output:
        clear_output()
        if file_upload.value:
            try:
                # Save the uploaded file
                uploaded_file = list(file_upload.value.values())[0]
                file_path = 'uploaded_audio.wav'
                with open(file_path, 'wb') as f:
                    f.write(uploaded_file['content'])

                # Load and process the audio
                device = "cuda" if torch.cuda.is_available() else "cpu"
                compute_type = "float16" if torch.cuda.is_available() else "int8"
                whisper_arch = 'small'  # You can choose 'tiny', 'base', 'small', 'medium', 'large'

                # Load the model
                model = whisperx.load_model(whisper_arch, device=device, compute_type=compute_type)
                if not model:
                    raise RuntimeError("Failed to load the Whisper model.")

                # Load the audio
                audio = whisperx.load_audio(file_path)
                if audio is None:
                    raise RuntimeError("Failed to load the audio file.")

                # Perform transcription
                result = model.transcribe(audio, language=output_lang.value)
                if result is None or 'segments' not in result:
                    raise RuntimeError("Transcription failed. No segments found.")

                # Handle different transcription types
                if transcription_type.value == 'Diarized':
                    # Perform diarization using Pyannote
                    diarized_segments = perform_diarization(file_path)

                    # Match diarized segments with transcribed segments
                    for segment in result['segments']:
                        for diarized_segment in diarized_segments:
                            if Segment(segment['start'], segment['end']).overlaps(Segment(diarized_segment['start'], diarized_segment['end'])):
                                segment['speaker'] = diarized_segment['speaker']

                # Extract and display the transcription text
                if transcription_type.value == 'Word-Level':
                    # Display each word with timestamps
                    word_segments = word_level_transcription(result['segments'])
                    for word_segment in word_segments:
                        start = word_segment['start']
                        end = word_segment['end']
                        text = word_segment['text']
                        print(f"[{start:.2f}s - {end:.2f}s] {text}")
                else:
                    # Display regular or diarized transcription
                    transcription_text = " ".join([f"[{segment.get('speaker', 'Speaker 1')}] {segment['text']}" for segment in result['segments']])
                    if convert_to_text.value:
                        print("Transcription Result:")
                        print(transcription_text)
                    else:
                        print("Full Result:", result)

                # Clean up the audio file after processing
                os.remove(file_path)

            except Exception as e:
                print(f"An error occurred: {e}")
        else:
            print("Please upload an audio file.")

# Attach the function to the button click event
process_button.on_click(process_audio)

# Display the UI
display(HTML(styles))
display(HTML('<div class="container">'))
display(HTML('<div class="header">Audio Transcription Service</div>'))
display(HTML('<div class="description">Upload an audio file and choose your transcription preferences. The result will be displayed below.</div>'))
display(file_upload)
display(output_lang)
display(transcription_type)
display(convert_to_text)
display(process_button)
display(output)
display(HTML('</div>'))


  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.5 MB/s[0m eta [36

  torchaudio.set_audio_backend("soundfile")


FileUpload(value={}, accept='.wav', description='Upload Audio')

Dropdown(description='Output Language:', options=('en', 'es', 'fr', 'de', 'zh'), style=DescriptionStyle(descri…

RadioButtons(description='Transcription Type:', options=('Regular', 'Word-Level', 'Diarized'), style=Descripti…

Checkbox(value=False, description='Convert to Text Script', style=DescriptionStyle(description_width='initial'…

Button(button_style='info', description='Process Audio', icon='check', style=ButtonStyle(), tooltip='Click to …

Output()