<a href="https://colab.research.google.com/github/pneuly/whisper-asr-colab/blob/main/faster_whisper_share.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import datetime
import fcntl
import importlib.util
import os
import re
import subprocess
import sys
import time
import pip
import numpy as np
import torch
from google.colab import files

# @title 自動文字起こし（Whisper）{ display-mode: "form" }
# @markdown 以下の設定項目を入力しセルを実行（Ctrl+Enter）<font color="red">※設定項目の説明は下にあります</font>

# @markdown #<b>設定</b>
# Colab form fields; the values below are the form defaults.
audiopath = 'https://www.youtube.com/watch?v=xAmEQOqtMvA'  # @param {type:"string"}
model_size = "large-v3" # @param ["large-v3", "large-v2", "large", "medium", "small", "base", "tiny"] {allow-input: true}
diarization = True  # @param {type:"boolean"}
# The optional string fields below default to None; the code treats None /
# empty as "not set" (it only checks their truthiness).
password = None  # @param {type:"string"}
start_time = None  # @param {type:"string"}
end_time = None  # @param {type:"string"}
timestamp_offset = None  # @param {type:"string"}
initial_prompt = "定刻 なりましたので、 です。 ます。"  # @param {type:"string"}
realtime = False  # @param {type:"boolean"}
# Settings not exposed in the form.
# chunk_size / batch_size for WhisperX's model.transcribe; only used on the
# WhisperX path (diarization on, realtime off).
CHUNK_SIZE = 20
BATCH_SIZE = 16
# Passed as use_auth_token to whisperx.DiarizationPipeline.
# NOTE(review): left empty here; diarization presumably needs a valid Hugging
# Face access token - confirm before running with diarization=True.
HUGGING_FACE_TOKEN = ""

# @markdown ###<br/><b>〔設定の説明〕</b>
# @markdown <b>audiopath:</b> 文字起こしする音声ファイルの場所<br/>
# @markdown 　　Youtubeの場合： https://www.youtube.com/......<br/>
# @markdown 　　手動で音声をアップロードした場合： 230401_1010.mp3 など<br/>
# @markdown 　　（アップロード完了まで待って実行してください）<br/>
# @markdown 　　<font color="red">空欄の場合はファイルアップロードボタンが表示されます</font>
# @markdown <br/><b>model_size:</b> 音声認識のモデルサイズ（mediumにすると少し精度が落ちるが早い）
# @markdown <br/><b>diarization:</b> 発言者別の文字起こしファイルを作成するか（Falseにすると早い）
# @markdown #### <br/><b><font color= "blue">以下は必要な場合のみ設定</font></b>
# @markdown <b>password:</b> パスワードを指定（Webexなど）
# @markdown <br/><b>start_time:</b> 開始時間 hh:mm:ss（指定しない場合は最初から）
# @markdown <br/><b>end_time:</b> 終了時間 hh:mm:ss（指定しない場合は最後まで）
# @markdown <br/><b>timestamp_offset:</b> タイムスタンプを指定の時間だけずらす hh:mm:ss（Noneの場合はstart_timeと連動）
# @markdown <br/><b>initial_prompt:</b> キーワード（です。 ます。は句読点を付けるために入れています）
# @markdown <br/><b>realtime: </b><font color="red">ストリーミングをリアルタイムで文字起こしをする場合のみオンにしてください。</font>

# ----- No changes needed below this line ------
# No path given: let the user pick a file with Colab's upload widget and use
# the name of the first uploaded file. Abort cleanly if nothing was uploaded
# (e.g. the dialog was cancelled) instead of failing with an IndexError.
if not audiopath:
    _uploaded = files.upload()
    if not _uploaded:
        sys.exit("ファイルがアップロードされていません。ファイルを選択して再度実行してください。")
    audiopath = next(iter(_uploaded))


def pip_install(module_name: str, install_name=None):
    """Install a package with pip unless its module is already importable.

    Args:
        module_name: Importable module name used to test for presence
            (e.g. ``"yt_dlp"``).
        install_name: Requirement passed to ``pip install`` (e.g. ``"yt-dlp"``
            or a ``git+https://...`` URL). Defaults to ``module_name``.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    if install_name is None:
        install_name = module_name
    if importlib.util.find_spec(module_name) is not None:
        return
    # pip has no supported in-process API (``pip.main`` is deprecated); run
    # it with the current interpreter so it installs into this environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", install_name])
    # The import system caches directory listings; refresh them so the
    # freshly installed module can be imported right away.
    importlib.invalidate_caches()


def subprocess_progress(cmd: list, poll_interval: float = 0.5):
    """Run a command and echo its combined stdout/stderr to ``sys.stdout`` live.

    The child's output pipe is made non-blocking and polled every
    ``poll_interval`` seconds, so progress output (e.g. from yt-dlp) shows up
    in the notebook while the command is still running.

    Args:
        cmd: Command and arguments for ``subprocess.Popen`` (no shell).
        poll_interval: Seconds to sleep between polls.

    Returns:
        The child's exit code.
    """
    import codecs

    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    fd = p.stdout.fileno()
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    # sys.stdout (notebook stream, StringIO, ...) only accepts str. Decode
    # incrementally so a multi-byte UTF-8 character (e.g. in a Japanese
    # title) split across two reads is not mangled.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def _echo(final=False):
        # Forward whatever is currently in the pipe; read() returns None
        # when no data is available in non-blocking mode.
        try:
            chunk = p.stdout.read()
        except BlockingIOError:
            chunk = None
        text = decoder.decode(chunk or b"", final=final)
        if text:
            sys.stdout.write(text)
            sys.stdout.flush()

    with p.stdout:
        while p.poll() is None:
            _echo()
            time.sleep(poll_interval)
        # The child has exited; drain the output it wrote before exiting.
        _echo(final=True)
    return p.returncode


def str2seconds(time_str):
    """Parse a time string into seconds.

    Accepted forms: ``H:M:S``, ``M:S`` or ``S``, each optionally with a
    fractional-seconds part (``.f``). Surrounding whitespace is ignored.

    Returns:
        The number of seconds as a float, or None (after printing an error)
        when the string matches none of the accepted forms.
    """
    text = str(time_str).strip()
    # strptime fills the date with 1900-01-01; subtracting that epoch leaves
    # only the time-of-day part.
    epoch = datetime.datetime(1900, 1, 1)
    for fmt in ("%H:%M:%S", "%M:%S", "%S", "%H:%M:%S.%f", "%M:%S.%f", "%S.%f"):
        try:
            return (datetime.datetime.strptime(text, fmt) - epoch).total_seconds()
        except ValueError:
            pass
    print(f"Error: Unable to parse time string '{time_str}'")
    return None


def dl_audio(url, password=""):
    """Download the audio of a URL with the yt-dlp CLI and return the file name.

    Args:
        url: Media page URL.
        password: Optional video password (passed as ``--video-password``).

    Returns:
        The output file name reported by ``yt-dlp --print filename``.
        NOTE(review): that name is resolved before post-processing; if ``-x``
        changes the extension, the returned path may not exist - confirm.

    Raises:
        RuntimeError: if yt-dlp cannot resolve an output file name.
    """
    pip_install("yt_dlp", "yt-dlp")
    # The external yt-dlp command is used because the YoutubeDL class
    # caused download errors.
    options = ["-x", "-S", "+acodec:mp4a", "-o", "%(title)s.%(ext)s"]
    if password:
        options += ["--video-password", password]
    probe = subprocess.run(
        ["yt-dlp", "--print", "filename"] + options + [url],
        capture_output=True,
        text=True,
    )
    outfilename = probe.stdout.strip()
    if not outfilename:
        # Fail here with yt-dlp's own message instead of continuing with an
        # empty path that breaks later steps confusingly.
        raise RuntimeError(
            f"yt-dlp could not resolve a file name for {url!r}: {probe.stderr.strip()}"
        )
    subprocess_progress(["yt-dlp"] + options + [url])
    return outfilename


def format_timestamp(seconds: float):
    """Format seconds as ``H:MM:SS.ss`` (e.g. ``3723.5`` -> ``"1:02:03.50"``)."""
    hours = seconds // 3600
    remain = seconds - (hours * 3600)
    minutes = remain // 60
    seconds = remain - (minutes * 60)
    return "{:01}:{:02}:{:05.2f}".format(int(hours), int(minutes), seconds)


def _timestamp_offset_seconds():
    """Return the offset (seconds) added to every segment timestamp.

    Uses the ``timestamp_offset`` setting when given. Otherwise it follows
    ``start_time``, as documented for the form field: a recording trimmed to
    start at ``start_time`` gets timestamps that match the original audio.
    An unparseable value falls back to 0 (str2seconds prints the error).
    """
    offset_str = timestamp_offset or start_time
    if not offset_str:
        return 0.0
    offset = str2seconds(offset_str)
    return offset if offset is not None else 0.0


def time_segment_text(segment):
    """Return ``"[start - end]"`` for a segment dict, shifted by the offset."""
    offset = _timestamp_offset_seconds()
    start = segment["start"] + offset
    end = segment["end"] + offset
    return f"[{format_timestamp(start)} - {format_timestamp(end)}]"


def add_timestamp(segment):
    """Return the segment's stripped text prefixed with its ``[start - end]`` range."""
    return f"{time_segment_text(segment)} {segment['text'].strip()}"


def fill_missing_speakers(segments):
    """Carry the most recent speaker label forward onto segments lacking one.

    Segments that have no ``"speaker"`` key get the label of the closest
    preceding segment that has one, or None if there is none yet. The dicts
    are updated in place, and the same list is returned.
    """
    last_speaker = None
    for seg in segments:
        if "speaker" not in seg:
            seg["speaker"] = last_speaker
        else:
            last_speaker = seg["speaker"]
    return segments


def combine_same_speaker(segments):
    """Merge runs of consecutive segments spoken by the same speaker.

    Missing speaker labels are filled in first (see fill_missing_speakers).
    Each run becomes one dict with the run's first ``start``, last ``end``,
    the texts joined by newlines (stripped), and the run's ``speaker``.
    """
    from itertools import groupby

    merged = []
    for _, run in groupby(fill_missing_speakers(segments), key=lambda seg: seg["speaker"]):
        run = list(run)
        first, last = run[0], run[-1]
        merged.append({
            "start": first["start"],
            "end": last["end"],
            "text": "\n".join(seg["text"] for seg in run).strip(),
            "speaker": first["speaker"],
        })
    return merged


def open_stream(url):
    """Start an ffmpeg process that decodes a stream's audio to raw PCM.

    The media URL is resolved with ``yt-dlp -g``. ffmpeg writes 16 kHz mono
    signed 16-bit little-endian PCM to its stdout pipe.

    Returns:
        The ``subprocess.Popen`` for ffmpeg; read audio from ``.stdout``.

    Raises:
        subprocess.CalledProcessError: if yt-dlp fails.
        RuntimeError: if yt-dlp prints no media URL.
    """
    command = ["yt-dlp", "-g", url, "-x", "-S", "+acodec:mp4a"]
    audio_url = subprocess.check_output(command).decode("utf-8").strip()
    if not audio_url:
        raise RuntimeError(f"yt-dlp returned no media URL for {url!r}")
    return subprocess.Popen(
        [
            "ffmpeg",
            # Global option first; options after the output are "trailing".
            "-loglevel", "quiet",
            "-i", audio_url,
            "-vn",
            "-f", "s16le",
            "-acodec", "pcm_s16le",
            "-ac", "1",
            "-ar", "16000",
            "-",
        ],
        stdout=subprocess.PIPE,
        # stderr was previously piped but never read, which can stall ffmpeg
        # once the pipe fills; discard it instead.
        stderr=subprocess.DEVNULL,
    )


def realtime_transcribe(model, data, outfh, initial_prompt="です。 ます。"):
    """Transcribe one chunk of raw int16 PCM and stream the text out.

    Each recognised segment's text is printed and appended (one line each) to
    ``outfh``, which is flushed after every line.

    Args:
        model: Object with a faster-whisper style ``transcribe`` method.
        data: Raw little-endian int16 samples (bytes).
        outfh: Writable text file handle.
        initial_prompt: Prompt passed to ``model.transcribe``.

    Returns:
        The text of the last segment, or ``""`` if none was produced.
    """
    # Scale int16 samples to floats in [-1, 1).
    waveform = np.frombuffer(data, np.int16).astype(np.float32) / 32768.0
    segment_iter, _info = model.transcribe(
        waveform,
        language='ja',
        initial_prompt=initial_prompt,
    )
    last_text = ""
    for seg in segment_iter:
        text = seg.text
        print(text)
        outfh.write(text + "\n")
        outfh.flush()
        last_text = text
    return last_text


# --- main ---
# Resolve the input audio: download remote media with yt-dlp, or check that a
# local/uploaded file has finished uploading; then optionally trim it.
dlaudio = False  # True when audiopath was downloaded; it is offered for download at the end
if not realtime:
    if re.match(r"^https?://.+", audiopath):
        dlaudio = True
        # download and return file name
        audiopath = dl_audio(audiopath, password)
    else:
        # A small file may still be uploading: if it is under 10 MB, wait 10 s
        # and abort when it has grown in the meantime.
        filesize = os.path.getsize(audiopath)
        if filesize < 10 ** 7:  # under 10 MB
            time.sleep(10)
            filesize2 = os.path.getsize(audiopath)
            if (filesize2 - filesize) > 0:
                sys.exit(
                    "ファイルのアップロードが終わっていない可能性があります。アップロード完了後再度実行してください。"
                    )
    # Trim to [start_time, end_time] with stream copy (no re-encode).
    if start_time or end_time:
        pip_install("ffmpeg", "ffmpeg-python")
        import ffmpeg
        trim_args = {}
        if start_time:
            trim_args["ss"] = start_time
        if end_time:
            trim_args["to"] = end_time
        # Not named `input`: a module-level binding would shadow the builtin
        # input() for the rest of the notebook session.
        trim_source = ffmpeg.input(audiopath, **trim_args)
        input_base, input_ext = os.path.splitext(audiopath)
        input_path = f"{input_base}_trimmed{input_ext}"
        print(f"trimming audio from {start_time} to {end_time}.")
        ffmpeg.output(trim_source, input_path, acodec="copy", vcodec="copy").run(
            overwrite_output=True
        )
    else:
        input_path = audiopath


# Transcribe
# WhisperX is used when a speaker-separated transcript is requested (the
# diarization step consumes its `result`); otherwise faster-whisper is used,
# either on the prepared file or, in realtime mode, on a live stream.
if diarization and not realtime:  # use WhisperX
    pip_install("whisper", "openai-whisper")
    pip_install("whisperx", "git+https://github.com/m-bain/whisperx.git")
    import whisperx
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisperx.load_model(
        model_size,
        device=device,
        compute_type="default",
        asr_options={"initial_prompt": initial_prompt},
    )
    audio = whisperx.load_audio(input_path)
    result = model.transcribe(
        audio,
        language="ja",
        print_progress=True,
        chunk_size=CHUNK_SIZE,
        batch_size=BATCH_SIZE,
    )
    segments = result["segments"]
else:  # use faster-whisper
    pip_install("whisper", "openai-whisper")
    pip_install("faster_whisper")
    from faster_whisper import WhisperModel
    model = WhisperModel(
        model_size, compute_type="default"  # default: equivalent, auto: fastest
    )
    if not realtime:
        import dataclasses
        segments = model.transcribe(
            input_path,
            language="ja",
            vad_filter=True,
            initial_prompt=initial_prompt,
            without_timestamps=False,
        )[0]
        # faster-whisper's Segment was a NamedTuple in older releases and is a
        # dataclass in newer ones; turn either form into a plain dict.
        segments = [
            seg._asdict() if hasattr(seg, "_asdict") else dataclasses.asdict(seg)
            for seg in segments
        ]
    else:
        # realtime transcription of a live stream
        pip_install("yt_dlp", "yt-dlp")
        process = open_stream(audiopath)
        # open_stream requests 16 kHz mono s16le from ffmpeg: 2 bytes/sample.
        _PCM_BYTES_PER_SEC = 16000 * 2
        previous_text = ""
        buffer = b""
        with open(
            datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ".txt",
            "w",
            encoding="utf-8",
        ) as fh1:
            while True:
                # Blocks until one second of audio is available or the stream
                # ends; an empty read means EOF (ffmpeg finished). Reading until
                # EOF, instead of polling the process, keeps the last chunk.
                audio_data = process.stdout.read(_PCM_BYTES_PER_SEC)
                if not audio_data:
                    break
                buffer += audio_data
                if len(buffer) >= _PCM_BYTES_PER_SEC * 30:  # 30 seconds
                    # Prompt with the configured prompt plus the last recognised
                    # text, so consecutive chunks keep their context without the
                    # prompt growing on every iteration.
                    previous_text = realtime_transcribe(
                        model, buffer, fh1, (initial_prompt or "") + previous_text
                    )
                    buffer = buffer[-(_PCM_BYTES_PER_SEC // 2):]  # 0.5 seconds overlap
            if buffer:
                # Transcribe whatever arrived after the last full 30 s chunk.
                realtime_transcribe(
                    model, buffer, fh1, (initial_prompt or "") + previous_text
                )
        process.stdout.close()
        process.wait()
        del model
        torch.cuda.empty_cache()
        sys.exit(0)

# Write the transcript as plain text and with per-segment timestamps, then
# offer both files for download. Context managers close the files even if a
# write fails; fh1/fh2 keep their names for the download calls.
_result_base = os.path.basename(input_path)
with open(f"{_result_base}.txt", "w", encoding="utf-8") as fh1:
    with open(f"{_result_base}（タイムスタンプ付）.txt", "w", encoding="utf-8") as fh2:
        for segment in segments:
            fh1.write(segment["text"] + "\n")
            fh2.write(add_timestamp(segment) + "\n")

files.download(fh1.name)
files.download(fh2.name)

if diarization:
    # Diarize: find who speaks when, attach speakers to the WhisperX segments,
    # and write a speaker-by-speaker transcript.
    # Imported from the submodule: newer whisperx releases no longer export
    # DiarizationPipeline at package top level, while whisperx.diarize has it.
    from whisperx.diarize import DiarizationPipeline, assign_word_speakers
    diarize_model = DiarizationPipeline(
        use_auth_token=HUGGING_FACE_TOKEN,
        device=device
    )

    diarize_segments = diarize_model(input_path)
    result = assign_word_speakers(diarize_segments, result)
    segments = [
        {k: v for k, v in d.items() if k != 'words'}
        for d in result["segments"]]

    segments = combine_same_speaker(segments)

    with open(f"{os.path.basename(input_path)}（発言者別）.txt", "w", encoding="utf-8") as fh3:
        for segment in segments:
            fh3.write(time_segment_text(segment) + " ")
            # Segments before the first labelled one carry speaker None; write
            # an empty label for them (previously done via a bare except).
            speaker = segment.get("speaker")
            fh3.write((speaker if isinstance(speaker, str) else "") + "\n")
            fh3.write(segment["text"].replace(" ", "") + "\n\n")
    files.download(fh3.name)

# Release GPU memory held by the models before the CPU-only docx export.
import gc
del model
if diarization:
    del diarize_model  # created by the diarization step above
gc.collect()
torch.cuda.empty_cache()

# Generate docx
pip_install("docx", "python-docx")
from pathlib import Path
from docx import Document
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_LINE_SPACING, WD_ALIGN_PARAGRAPH
from docx.shared import Pt, Mm, RGBColor

# Typography for the generated .docx.
DEFAULT_FONT_SIZE: int = 12  # points (used with Pt())
SERIF_FONT: str = "ＭＳ 明朝"  # body / transcript paragraphs
SANS_FONT: str = "ＭＳ ゴシック"  # speaker heading lines

def set_rFonts(style, key, value):
    """Set attribute ``w:<key>`` on the ``<w:rFonts>`` of a style's run properties."""
    r_fonts = style._element.rPr.rFonts
    attribute = qn(f'w:{key}')
    r_fonts.set(attribute, value)

def create_attribute(element, name, value):
    """Set a namespace-prefixed attribute (e.g. ``"w:fldCharType"``) on an XML element."""
    qualified_name = qn(name)
    element.set(qualified_name, value)

def add_page_number(run):
    """Append a ``PAGE`` field (begin marker, instruction, end marker) to a run."""
    field_begin = OxmlElement('w:fldChar')
    create_attribute(field_begin, 'w:fldCharType', 'begin')

    field_code = OxmlElement('w:instrText')
    create_attribute(field_code, 'xml:space', 'preserve')
    field_code.text = "PAGE"

    field_end = OxmlElement('w:fldChar')
    create_attribute(field_end, 'w:fldCharType', 'end')

    for node in (field_begin, field_code, field_end):
        run._r.append(node)

# Header lines of the speaker-separated transcript look like
# "[0:00:01.00 - 0:00:05.00] SPEAKER_00": a time range, then an optional
# speaker label. Matching the full time range (not just a leading "[") keeps
# transcript lines such as "[音楽]" from being mistaken for headers.
_SPEAKER_HEADER_RE = re.compile(r"^(\[[\d:.]+ - [\d:.]+\])(?:\s+(.*))?$")


def _use_font(style, font_name):
    """Make a style render both Latin and East Asian text in ``font_name``."""
    style.font.name = font_name  # sets w:ascii and w:hAnsi
    r_fonts = style._element.rPr.rFonts
    # Theme font attributes take precedence over explicit font names, so
    # drop any the style inherits from the template (e.g. "Heading 1").
    for theme_attr in ("asciiTheme", "hAnsiTheme", "eastAsiaTheme"):
        key = qn(f"w:{theme_attr}")
        if key in r_fonts.attrib:
            del r_fonts.attrib[key]
    # python-docx has no East Asian font property; set w:eastAsia directly.
    set_rFonts(style, "eastAsia", font_name)


def txt_to_word(filename):
    """Convert a speaker-separated transcript text file into a .docx.

    Header lines (``[start - end] speaker``) become "speaker [start - end]"
    headings. The first text line after a heading starts with "○　"; later
    lines use the continuation style. Blank lines are skipped. Pages are A4
    with a centered page number in the footer.

    Args:
        filename: Path of the UTF-8 text file to convert.

    Returns:
        Name of the written file, ``<stem of filename>.docx``, saved in the
        current working directory.
    """
    with open(filename, 'r', encoding='utf8') as f:
        lines = f.readlines()

    doc = Document()
    sec = doc.sections[-1]

    # A4 portrait page with fixed margins.
    sec.page_height = Mm(297)
    sec.page_width = Mm(210)
    sec.top_margin = Mm(20)
    sec.bottom_margin = Mm(15)
    sec.left_margin = Mm(25)
    sec.right_margin = Mm(25)
    sec.footer_distance = Mm(8)

    style_normal = doc.styles['Normal']
    _use_font(style_normal, SERIF_FONT)
    style_normal.font.size = Pt(DEFAULT_FONT_SIZE)

    # Speaker headings: black, not bold, sans-serif, body-sized.
    style_speaker = doc.styles["Heading 1"]
    style_speaker.font.color.rgb = RGBColor(0, 0, 0)
    style_speaker.font.bold = False
    _use_font(style_speaker, SANS_FONT)
    style_speaker.font.size = Pt(DEFAULT_FONT_SIZE)
    style_speaker.paragraph_format.space_before = Pt(DEFAULT_FONT_SIZE)
    style_speaker.paragraph_format.space_after = Pt(0)
    style_speaker.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE

    # First line of a speech: hanging indent so the "○" sits one character left.
    style_ts1 = doc.styles.add_style('発言1行目', WD_STYLE_TYPE.PARAGRAPH)
    _use_font(style_ts1, SERIF_FONT)
    style_ts1.font.size = Pt(DEFAULT_FONT_SIZE)
    style_ts1.paragraph_format.space_before = Pt(0)
    style_ts1.paragraph_format.space_after = Pt(0)
    style_ts1.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
    style_ts1.paragraph_format.first_line_indent = Pt(- DEFAULT_FONT_SIZE)

    # Continuation lines: same as the first line but indented by one character.
    style_ts2 = doc.styles.add_style('発言2行目', WD_STYLE_TYPE.PARAGRAPH)
    style_ts2.base_style = style_ts1
    style_ts2.paragraph_format.first_line_indent = Pt(DEFAULT_FONT_SIZE)

    footer_paragraph = sec.footer.paragraphs[0]
    footer_paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
    add_page_number(footer_paragraph.add_run())

    # pline_count: 1 right after a heading, incremented per transcript line.
    pline_count = 0
    for line in lines:
        text = line.strip()
        if not text:
            continue
        header = _SPEAKER_HEADER_RE.match(text)
        if header:
            stamp = header.group(1)
            speaker = header.group(2) or ""
            doc.add_paragraph(f"{speaker} {stamp}".strip(), style=style_speaker)
            pline_count = 1
            continue
        if pline_count == 1:
            paragraph = doc.add_paragraph("○　", style=style_ts1)
        else:
            paragraph = doc.add_paragraph("", style=style_ts2)
        paragraph.add_run(text)
        pline_count += 1

    outfile = f"{Path(filename).stem}.docx"
    doc.save(outfile)
    return outfile

# Offer the remaining artifacts for download: the Word version of the
# speaker-separated transcript (diarization only), then the audio that was
# fetched from the web (if any).
_pending_downloads = []
if diarization:
    docxfile = txt_to_word(fh3.name)
    _pending_downloads.append(docxfile)
if dlaudio:
    _pending_downloads.append(audiopath)
for _download_path in _pending_downloads:
    files.download(_download_path)
