<a href="https://colab.research.google.com/github/thias42/offline-speech-summarization/blob/main/whisper_diarization_summarization_offline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!wget https://raw.githubusercontent.com/thias42/offline-speech-summarization/refs/heads/main/requirements.txt
!pip install -r requirements.txt

In [8]:
import llm
import torch
import typer
import logging
import whisper
from os import environ
from dotenv import load_dotenv
from pyannote.audio import Pipeline

# Load settings from a .env file, if one is present; HF_AUTH_TOKEN and LLM_MODEL
# are read from the environment below.
load_dotenv()

# Load Whisper model
whisper_model = whisper.load_model("base")

# Load Pyannote.audio pipeline
# The Hugging Face token comes from HF_AUTH_TOKEN; None is passed when it is unset.
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
                                                use_auth_token=environ.get("HF_AUTH_TOKEN"))

# Prefer CUDA, then Apple MPS, and fall back to the CPU. Kept as a plain string
# so later cells can keep wrapping it with torch.device(device).
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# pyannote.audio 3.x rejects a plain string here (TypeError), so pass a torch.device.
diarization_pipeline.to(torch.device(device)) # switch to gpu if available

# Model used for summarization, selected by LLM_MODEL; None is passed when it is unset.
llm_model = llm.get_model(environ.get("LLM_MODEL"))

def transcribe_audio(audio_path):
    """Run the module-level Whisper model on an audio file.

    Args:
        audio_path: Location of the audio file, handed unchanged to
            ``whisper_model.transcribe``.

    Returns:
        Whatever ``whisper_model.transcribe`` returns, unmodified.
    """
    return whisper_model.transcribe(audio_path)

def diarize_audio(audio_path):
    """Run the module-level pyannote.audio pipeline on an audio file.

    Args:
        audio_path: Location of the audio file, handed unchanged to
            ``diarization_pipeline``.

    Returns:
        Whatever calling ``diarization_pipeline`` returns, unmodified.
    """
    return diarization_pipeline(audio_path)

def merge_transcription_and_diarization(transcription, diarization, margin=0.2):
    """Merge Whisper transcription with pyannote.audio diarization.

    For every speaker turn produced by ``diarization.itertracks(yield_label=True)``,
    collects the transcription segments that lie completely inside the turn
    (widened by ``margin`` on both sides) and emits one line for that turn.

    Args:
        transcription: Mapping with a ``"segments"`` list; each segment is a
            mapping with ``"start"``, ``"end"`` and ``"text"`` entries.
        diarization: Object whose ``itertracks(yield_label=True)`` yields
            ``(segment, track, speaker)`` triples, where ``segment`` exposes
            ``start`` and ``end``.
        margin: Tolerance added before the turn start and after the turn end
            when deciding whether a transcription segment belongs to the turn.
            Uses the same unit as the start/end values.

    Returns:
        A string with one ``"Speaker <label>: <text>"`` line per turn that
        matched at least one non-empty transcription segment, joined with
        newlines. Returns an empty string when nothing matched.

    Note:
        A transcription segment that straddles two turns by more than
        ``margin`` matches neither turn and is left out of the result.
    """
    transcript_segments = transcription["segments"]
    merged_output = []

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        earliest_start = turn.start - margin
        latest_end = turn.end + margin

        # Keep only transcription segments fully contained in the widened turn.
        turn_texts = []
        for seg in transcript_segments:
            if seg["start"] >= earliest_start and seg["end"] <= latest_end:
                # Segment texts typically carry leading whitespace; strip it so the
                # join below does not produce double spaces, and skip blank segments.
                text = seg["text"].strip()
                if text:
                    turn_texts.append(text)

        if turn_texts:
            merged_output.append(f"Speaker {speaker}: {' '.join(turn_texts)}")

    return "\n".join(merged_output)

def generate_summary_with_chatgpt(text):
    """Summarize a conversation with the module-level ``llm_model``.

    Despite the name, the request goes to whichever model ``llm_model`` refers to.

    Args:
        text: The conversation to summarize, sent as the prompt.

    Returns:
        The result of calling ``.text()`` on the model's response.
    """
    system_prompt = "You are a helpful assistant, who creates a summary of a given conversation. Capture the essence and summarize in bullet points."
    reply = llm_model.prompt(text, system=system_prompt)
    return reply.text()

  checkpoint = torch.load(fp, map_location=device)
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../root/.cache/torch/pyannote/models--pyannote--segmentation/snapshots/c4c8ceafcbb3a7a280c2d357aee9fbc9b0be7f9b/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.4.1+cu121. Bad things might happen unless you revert torch to 1.x.


  state_dict = torch.load(path, map_location=device)
  stats = torch.load(path, map_location=device)


In [9]:
# Recording to process, given as a relative path.
audio_path = 'audio_recording.wav'
# Speech-to-text step; the merge step later reads the "segments" entry of this result.
transcription = transcribe_audio(audio_path)

In [10]:
# Speaker-diarization step on the same recording; the merge step later iterates it via itertracks().
diarization = diarize_audio(audio_path)

It can be re-enabled by calling
   >>> import torch
   >>> torch.backends.cuda.matmul.allow_tf32 = True
   >>> torch.backends.cudnn.allow_tf32 = True
See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.



In [11]:
# Combine both results into newline-separated "Speaker <label>: <text>" lines (default margin).
full_text = merge_transcription_and_diarization(transcription, diarization)

In [15]:
# Ask the configured LLM for a bullet-point summary; as the cell's last expression the returned string is displayed.
generate_summary_with_chatgpt(full_text)

'* The speakers discuss their frustration with finding music that perfectly captures a certain mood or moment.\n* They compare it to trying to describe an unfamiliar color.\n* The speakers mention using genre filters as a solution but find them unsatisfactory.'

In [14]:
# Move the Whisper model to the device selected in the setup cell; the bare expression displays the module repr.
# NOTE(review): this cell sits after the transcription cell and its execution count (14) precedes the
# summary cell's (15), so on a fresh Run All the move happens only after transcription — consider
# relocating it before transcribe_audio() is called, or removing it if it is not needed.
whisper_model.to(torch.device(device))

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-5): 6 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=False)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (attn_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (mlp_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((512,), eps=1e-05,