Transformer 기반 ASR 전체 코드

In [None]:
import os
import subprocess
import json
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

# Silence every warning category for the rest of the run.
warnings.filterwarnings("ignore")
# All paths below are resolved against the directory the script is started from.
ROOT_DIR = os.getcwd()

# 1) Download and extract the dataset
TAR_PATH = os.path.join(ROOT_DIR, "KoreanReadSpeechCorpus.tar.gz")
# Skip the download when a "KoreanReadSpeechCorpus" directory already exists
# (presumably the directory the archive unpacks to -- verify against the tarball).
if not os.path.isdir(os.path.join(ROOT_DIR, "KoreanReadSpeechCorpus")):
    # check=True makes a failed wget/tar raise CalledProcessError instead of continuing.
    subprocess.run(
        ["wget", "-O", TAR_PATH, "https://www.openslr.org/resources/97/KoreanReadSpeechCorpus.tar.gz"],
        check=True
    )
    # tar is run without -C, so it extracts into the current working directory.
    subprocess.run(["tar", "-xvzf", TAR_PATH], check=True)

# 2) Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    The table is stored as a buffer named ``pe`` with shape
    ``(max_len, 1, d_model)``; that layout is kept so previously saved
    ``state_dict`` files still load.

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        max_len: Longest sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even feature indices carry the sine, odd indices the cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the positional table.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` of the table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {self.pe.size(0)}"
            )
        # Index by the time axis (dim 1), not the batch axis: slicing with
        # x.size(0) would give every time step of a sample the same vector.
        return x + self.pe[:seq_len].transpose(0, 1)

# 3) Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed in parallel over several heads.

    Args:
        d_model: Feature size of the inputs and the output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.scale = math.sqrt(self.d_k)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch):
        """Reshape (batch, len, d_model) into (batch, heads, len, d_k)."""
        return projected.view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` (all batch-first).

        ``mask`` is unsqueezed twice at the front and positions where it
        equals 0 are filled with -1e9 before the softmax, so it is expected
        to broadcast against the trailing (query_len, key_len) axes.
        """
        batch = q.size(0)
        queries = self._split_heads(self.w_q(q), batch)
        keys = self._split_heads(self.w_k(k), batch)
        values = self._split_heads(self.w_v(v), batch)

        logits = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        if mask is not None:
            blocked = mask.unsqueeze(0).unsqueeze(0) == 0
            logits = logits.masked_fill(blocked, -1e9)

        weights = self.dropout(F.softmax(logits, dim=-1))
        # Merge the heads back into a single d_model-wide feature axis.
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.num_heads * self.d_k)
        return self.w_o(merged)

# 4) Feed-Forward Network
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# 5) Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by
    dropout, a residual connection and LayerNorm (post-norm arrangement)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, branch):
        """Return ``norm(residual + dropout(branch))``."""
        return norm(residual + self.drop(branch))

    def forward(self, x, mask=None):
        """Run the layer on ``x``; ``mask`` is passed through to self-attention."""
        attended = self._add_and_norm(self.norm1, x, self.self_attn(x, x, x, mask))
        return self._add_and_norm(self.norm2, attended, self.ff(attended))

# 6) Transformer Decoder Layer
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, then feed-forward. Each sub-layer is followed by dropout,
    a residual connection and LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, branch):
        """Return ``norm(residual + dropout(branch))``."""
        return norm(residual + self.drop(branch))

    def forward(self, x, memory, tgt_mask=None, mem_mask=None):
        """Run the layer.

        Args:
            x: Decoder-side input.
            memory: Encoder output used as keys and values in cross-attention.
            tgt_mask: Mask handed to the self-attention.
            mem_mask: Mask handed to the cross-attention.
        """
        state = self._add_and_norm(self.norm1, x, self.self_attn(x, x, x, tgt_mask))
        state = self._add_and_norm(
            self.norm2, state, self.cross_attn(state, memory, memory, mem_mask)
        )
        return self._add_and_norm(self.norm3, state, self.ff(state))

# 7) Audio Feature Extractor
class AudioFeatureExtractor(nn.Module):
    """Three-layer 1-D convolutional front end over 80-channel input.

    All convolutions use kernel size 3 with padding 1, so the time axis keeps
    its length; the channel count goes 80 -> 256 -> 512 -> ``d_model``.
    """

    def __init__(self, d_model=512):
        super().__init__()
        self.conv1 = nn.Conv1d(80, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(512, d_model, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, 80, frames)`` to ``(batch, frames, d_model)``."""
        # Dropout follows the first two activations only.
        for conv in (self.conv1, self.conv2):
            x = self.dropout(F.relu(conv(x)))
        features = F.relu(self.conv3(x))
        # Swap channel and time axes to get a batch-first sequence.
        return features.transpose(1, 2)

# 8) Korean ASR Transformer Model
class KoreanASRTransformer(nn.Module):
    """Encoder-decoder Transformer that maps audio features to token logits.

    The encoder consumes the output of ``AudioFeatureExtractor``; the decoder
    consumes embedded target tokens under a causal mask and attends over the
    encoder output. One ``PositionalEncoding`` instance and one dropout
    module are shared by both sides.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        self.audio_extractor = AudioFeatureExtractor(d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
             for _ in range(num_encoder_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
             for _ in range(num_decoder_layers)]
        )
        self.out_proj = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Return an (sz, sz) bool mask, True on and below the diagonal."""
        return torch.ones(sz, sz).tril().bool()

    def _encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Extract features, scale, add positions, and run the encoder stack."""
        hidden = self.audio_extractor(audio)
        hidden = self.pos_enc(hidden * math.sqrt(hidden.size(-1)))
        hidden = self.dropout(hidden)
        for block in self.encoder_layers:
            hidden = block(hidden)
        return hidden

    def _decode(self, target: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Embed the target tokens and run the decoder stack over ``memory``."""
        embedded = self.tgt_emb(target) * math.sqrt(self.tgt_emb.embedding_dim)
        hidden = self.dropout(self.pos_enc(embedded))
        # Causal mask so a position only sees itself and earlier positions.
        causal = self._generate_square_subsequent_mask(hidden.size(1)).to(hidden.device)
        for block in self.decoder_layers:
            # No mask is applied to the encoder memory.
            hidden = block(hidden, memory, causal, None)
        return hidden

    def forward(self, audio: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, target_len, vocab_size)``.

        Args:
            audio: Input for ``AudioFeatureExtractor``.
            target: Token ids of shape ``(batch, target_len)`` fed to the decoder.
        """
        memory = self._encode(audio)
        return self.out_proj(self._decode(target, memory))

# 9) Dataset for Korean Read Speech Corpus
class KoreanSpeechDataset(Dataset):
    """Korean read-speech dataset yielding log-mel features and token ids.

    The metadata JSON is read as ``{location: {utterance_id: {"text": ...}}}``
    and each utterance is expected at ``<root_dir>/<location>/<utterance_id>.wav``.

    Args:
        json_path: Path to the metadata JSON file.
        root_dir: Directory the per-location audio folders live in.
        max_syllables: How many precomposed Hangul syllables (counted from
            U+AC00) go into the vocabulary. The default of 1000 keeps the
            original vocabulary size, so existing checkpoints stay
            compatible; any syllable beyond the cut is encoded as ``<unk>``.
            Pass ``None`` to include all 11172 syllables.
    """

    def __init__(self, json_path: str, root_dir: str, max_syllables=1000):
        self.root_dir = root_dir
        self.max_syllables = max_syllables
        self.tokenizer = self._build_tokenizer()
        self.samples = self._load_metadata(json_path)

    def _build_tokenizer(self):
        """Build the character-level vocabulary and its two lookup tables."""
        chars = ['ㄱ','ㄴ','ㄷ','ㄹ','ㅁ','ㅂ','ㅅ','ㅇ','ㅈ','ㅊ','ㅋ','ㅌ','ㅍ','ㅎ',
                 'ㅏ','ㅑ','ㅓ','ㅕ','ㅗ','ㅛ','ㅜ','ㅠ','ㅡ','ㅣ']
        syllables = [chr(i) for i in range(0xAC00, 0xD7A4)]
        english = [*map(chr, range(ord('a'), ord('z')+1)), *map(chr, range(ord('A'), ord('Z')+1))]
        digits = [str(i) for i in range(10)]
        specials = [' ','.',',','!','?','-',"'" ]
        # getattr keeps the method usable on instances created without __init__.
        limit = getattr(self, 'max_syllables', 1000)
        vocab = (['<pad>','<sos>','<eos>','<unk>'] + chars + syllables[:limit]
                 + english + digits + specials)
        c2i = {c: i for i, c in enumerate(vocab)}
        i2c = {i: c for c, i in c2i.items()}
        return {'char_to_idx': c2i, 'idx_to_char': i2c, 'vocab_size': len(vocab)}

    def _load_metadata(self, json_path: str):
        """Return a list of ``{'audio_path', 'text'}`` dicts from the metadata file."""
        with open(json_path, 'r', encoding='utf-8') as f:
            meta = json.load(f)
        return [
            {'audio_path': os.path.join(self.root_dir, loc, f"{uid}.wav"),
             'text': info.get('text', '')}
            for loc, utts in meta.items()
            for uid, info in utts.items()
        ]

    def _load_features(self, audio_path: str) -> torch.Tensor:
        """Load one audio file and return its log-mel spectrogram (80, frames)."""
        wav, sr = torchaudio.load(audio_path)
        # Down-mix multi-channel audio; otherwise squeeze(0) below is a no-op
        # and a 3-D tensor would reach the collate function.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != 16000:
            wav = torchaudio.transforms.Resample(sr, 16000)(wav)
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_mels=80, n_fft=1024, hop_length=256, win_length=1024
        )(wav)
        # The small constant avoids log(0) on silent frames.
        return torch.log(mel + 1e-9).squeeze(0)

    def _encode_text(self, text: str) -> torch.Tensor:
        """Return ``<sos>`` + per-character ids + ``<eos>`` as a long tensor."""
        c2i = self.tokenizer['char_to_idx']
        unk = c2i['<unk>']
        ids = [c2i['<sos>'], *(c2i.get(ch, unk) for ch in text), c2i['<eos>']]
        return torch.tensor(ids, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``{'audio_features', 'text_sequence'}`` for sample ``idx``."""
        s = self.samples[idx]
        return {
            'audio_features': self._load_features(s['audio_path']),
            'text_sequence': self._encode_text(s['text']),
        }

# 10) Collate function
def collate_fn(batch):
    """Zero-pad a list of dataset items into two batch tensors.

    Args:
        batch: Non-empty list of dicts with ``'audio_features'`` of shape
            (n_mels, frames) and ``'text_sequence'`` of shape (length,).

    Returns:
        Dict with ``'audio_features'`` of shape (batch, n_mels, max_frames)
        and ``'text_sequences'`` of shape (batch, max_length). Text padding
        uses id 0.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    audios = [b['audio_features'] for b in batch]
    texts = [b['text_sequence'] for b in batch]
    # Take the mel dimension from the data instead of assuming 80 bins.
    n_mels = audios[0].shape[0]
    max_t = max(a.shape[1] for a in audios)
    max_l = max(t.size(0) for t in texts)
    pa = torch.zeros(len(batch), n_mels, max_t)
    pt = torch.zeros(len(batch), max_l, dtype=torch.long)
    for i, (a, t) in enumerate(zip(audios, texts)):
        pa[i, :, :a.shape[1]] = a
        pt[i, :t.size(0)] = t
    return {'audio_features': pa, 'text_sequences': pt}

# 11) Training loop
def train(model, dataloader, device, epochs=30, lr=1e-4):
    """Train ``model`` with teacher forcing and Adam, printing the mean loss per epoch.

    Args:
        model: Module called as ``model(audio, decoder_input)`` that returns
            logits of shape (batch, length, vocab).
        dataloader: Iterable of dicts with ``'audio_features'`` and
            ``'text_sequences'`` tensors.
        device: Device the model and each batch are moved to.
        epochs: Number of passes over ``dataloader`` (default 30).
        lr: Adam learning rate (default 1e-4).

    Returns:
        List with the mean loss of each epoch; ``nan`` for an epoch in which
        the dataloader produced no batch.
    """
    model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Id 0 is the padding id written by the collate function, so it is ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    history = []
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            audio = batch['audio_features'].to(device)
            text = batch['text_sequences'].to(device)
            # Decoder input drops the last token; the target drops the first.
            inp, tgt = text[:, :-1], text[:, 1:]
            logits = model(audio, inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches ourselves so an empty or length-less iterable does not fail.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch:02d} Loss: {mean_loss:.4f}")
        history.append(mean_loss)
    return history

# 12) Greedy Inference
def infer(model, dataset, audio_path, device, max_len=100):
    """Greedily decode the transcript of one audio file.

    Args:
        model: Module called as ``model(mel, token_ids)`` returning logits.
        dataset: Object exposing ``tokenizer`` with ``'char_to_idx'`` and
            ``'idx_to_char'`` mappings.
        audio_path: Path of the audio file to transcribe.
        device: Device used for the model and all tensors.
        max_len: Upper bound on the number of generated tokens.

    Returns:
        The decoded string, without the leading ``<sos>`` and without ``<eos>``.
    """
    model.to(device).eval()
    wav, sr = torchaudio.load(audio_path)
    # Down-mix multi-channel audio; a stereo file would otherwise leave a
    # 4-D tensor that the 1-D convolutions cannot take.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr,16000)(wav)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_mels=80, n_fft=1024, hop_length=256, win_length=1024
    )(wav)
    mel = torch.log(mel + 1e-9).squeeze(0).unsqueeze(0).to(device)

    char_to_idx = dataset.tokenizer['char_to_idx']
    idx_to_char = dataset.tokenizer['idx_to_char']
    eos_id = char_to_idx['<eos>']
    seq = torch.tensor([[char_to_idx['<sos>']]], dtype=torch.long).to(device)
    with torch.no_grad():
        for _ in range(max_len):
            # The whole prefix is re-run each step; only the last position is used.
            logits = model(mel, seq)
            nxt = logits[0, -1].argmax().unsqueeze(0).unsqueeze(0)
            if nxt.item() == eos_id:
                break
            seq = torch.cat([seq, nxt], dim=1)
    # <eos> is never appended to seq, so no end-of-sequence check is needed here.
    return "".join(idx_to_char.get(i, "") for i in seq[0, 1:].tolist())

# 13) Main
if __name__ == "__main__":
    # Use the GPU when PyTorch reports one as available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    JSON_PATH = os.path.join(ROOT_DIR, "Korean_Read_Speech_Corpus_sample.json")
    ds = KoreanSpeechDataset(JSON_PATH, ROOT_DIR)
    loader = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
    # The output layer size follows the dataset's character vocabulary.
    model = KoreanASRTransformer(vocab_size=ds.tokenizer['vocab_size'])
    train(model, loader, device)
    # Only the weights (state_dict) are saved, not the module object.
    torch.save(model.state_dict(), "korspeech_transformer.pth")
    # Sanity check on the first sample of the training set: print its label
    # next to the model's greedy transcription.
    sample = ds.samples[0]
    print("라벨 :", sample['text'])
    print("추론 결과:", infer(model, ds, sample['audio_path'], device))

Epoch 01 Loss: 1.4340
Epoch 01 Loss: 1.4340
Epoch 02 Loss: 1.2252
Epoch 02 Loss: 1.2252
Epoch 03 Loss: 1.1517
Epoch 03 Loss: 1.1517
Epoch 04 Loss: 1.0814
Epoch 04 Loss: 1.0814
Epoch 05 Loss: 1.0034
Epoch 05 Loss: 1.0034
Epoch 06 Loss: 0.9134
Epoch 06 Loss: 0.9134
Epoch 07 Loss: 0.8193
Epoch 07 Loss: 0.8193
Epoch 08 Loss: 0.7535
Epoch 08 Loss: 0.7535
Epoch 09 Loss: 0.6884
Epoch 09 Loss: 0.6884
Epoch 10 Loss: 0.6414
Epoch 10 Loss: 0.6414
Epoch 11 Loss: 0.6073
Epoch 11 Loss: 0.6073
Epoch 12 Loss: 0.5734
Epoch 12 Loss: 0.5734
Epoch 13 Loss: 0.5445
Epoch 13 Loss: 0.5445
Epoch 14 Loss: 0.5328
Epoch 14 Loss: 0.5328
Epoch 15 Loss: 0.5111
Epoch 15 Loss: 0.5111
Epoch 16 Loss: 0.4834
Epoch 16 Loss: 0.4834
Epoch 17 Loss: 0.4658
Epoch 17 Loss: 0.4658
Epoch 18 Loss: 0.4506
Epoch 18 Loss: 0.4506
Epoch 19 Loss: 0.4368
Epoch 19 Loss: 0.4368
Epoch 20 Loss: 0.4241
Epoch 20 Loss: 0.4241
Epoch 21 Loss: 0.4252
Epoch 21 Loss: 0.4252
Epoch 22 Loss: 0.4246
Epoch 22 Loss: 0.4246
Epoch 23 Loss: 0.4000
Epoch 23 L

사전 학습 모델

In [None]:
!pip install transformers torch librosa ipywidgets

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import librosa
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from IPython.display import Audio, display, HTML, Javascript
import ipywidgets as widgets
from google.colab import output
import io
import base64
import wave

# Module-level slots for the loaded model and processor; they stay None until
# one of the init functions below fills them.
model = None
processor = None

def init_model():
    """Load the default Korean Wav2Vec2 model into the module globals.

    Does nothing when ``model`` is already set. If loading raises, the
    exception is printed and a second load attempt is made.
    """
    global model, processor

    if model is not None:
        return

    print("Wav2Vec2 한국어 모델 로딩 중...")

    # Pre-trained models recommended by the original author, best first:
    # 1. w11wo/wav2vec2-xls-r-300m-korean (largest model, highest quality)
    # 2. kresnik/wav2vec2-large-xlsr-korean (the one used here)
    # 3. Hyuk/wav2vec2-korean-v2 (v2 release)
    # 4. hyyoka/wav2vec2-xlsr-korean-senior (tuned for senior speakers)
    primary_name = "kresnik/wav2vec2-large-xlsr-korean"

    try:
        processor = Wav2Vec2Processor.from_pretrained(primary_name)
        model = Wav2Vec2ForCTC.from_pretrained(primary_name)
    except Exception as e:
        print(f"메인 모델 로딩 실패: {e}")
        print("백업 모델로 전환...")
        # NOTE(review): the backup name is identical to the primary one, so
        # this is effectively a plain retry -- confirm whether a different
        # checkpoint was intended.
        backup_name = "kresnik/wav2vec2-large-xlsr-korean"
        processor = Wav2Vec2Processor.from_pretrained(backup_name)
        model = Wav2Vec2ForCTC.from_pretrained(backup_name)
        print(f"백업 모델 로딩 완료: {backup_name}")
    else:
        print(f"모델 로딩 완료! 사용 모델: {primary_name}")

def init_model_with_choice(model_choice="best"):
    """Load the Wav2Vec2 model selected by ``model_choice`` into the module globals.

    A model that is already loaded is kept when it is the requested
    checkpoint, or when its origin cannot be determined. When a different
    checkpoint is requested the model is reloaded; previously the request
    was silently ignored once any model had been loaded.

    Args:
        model_choice: One of
            - "best": highest-quality model (default)
            - "fast": fastest processing
            - "balanced": balanced quality and speed
            - "senior": tuned for senior speakers
            Unknown values fall back to "best".
    """
    global model, processor

    model_options = {
        "best": "w11wo/wav2vec2-xls-r-300m-korean",        # largest model, highest quality
        "fast": "Kkonjeong/wav2vec2-base-korean",          # fast processing
        "balanced": "kresnik/wav2vec2-large-xlsr-korean",  # balanced
        "senior": "hyyoka/wav2vec2-xlsr-korean-senior"     # senior speakers
    }
    model_name = model_options.get(model_choice, model_options["best"])

    if model is not None:
        # Transformers models record the checkpoint they were loaded from.
        loaded_name = getattr(model, "name_or_path", None)
        if loaded_name is None or loaded_name == model_name:
            return

    print("Wav2Vec2 한국어 모델 로딩 중...")

    try:
        new_processor = Wav2Vec2Processor.from_pretrained(model_name)
        new_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        print(f"모델 로딩 완료! 사용 모델: {model_name}")
    except Exception as e:
        print(f"선택한 모델 로딩 실패: {e}")
        print("기본 모델로 전환...")
        # Fall back to the default checkpoint; a failure here propagates.
        model_name = "kresnik/wav2vec2-large-xlsr-korean"
        new_processor = Wav2Vec2Processor.from_pretrained(model_name)
        new_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        print(f"기본 모델 로딩 완료: {model_name}")

    # Publish both objects together so the globals never hold a processor
    # from one checkpoint and a model from another.
    processor, model = new_processor, new_model

def wav2vec2_korean_stt(audio_data, sample_rate=16000):
    """Transcribe audio with the loaded Korean Wav2Vec2 model.

    Args:
        audio_data: Audio samples (numpy array).
        sample_rate: Sampling rate of ``audio_data``; anything other than
            16000 is resampled to 16000 first.

    Returns:
        The recognized text, or an empty string for empty input.
    """
    global model, processor

    # Load the default model on first use.
    if model is None:
        init_model()

    if len(audio_data) == 0:
        return ""

    target_rate = 16000
    if sample_rate != target_rate:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=target_rate)

    features = processor(
        audio_data,
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True
    )

    # Greedy decoding: pick the most likely token at every frame.
    with torch.no_grad():
        token_ids = model(features.input_values).logits.argmax(dim=-1)

    return processor.batch_decode(token_ids)[0]

def setup_microphone_recording():
    """Render the record button and its browser-side JavaScript in the cell output.

    The script records from the microphone with ``MediaRecorder``. When
    recording stops it base64-encodes the recorded blob and sends it to the
    kernel callback registered under the name ``process_audio``.
    """
    # Microphone recording is implemented in JavaScript.
    # The base64 step converts the bytes in fixed-size chunks: spreading the
    # whole Uint8Array into String.fromCharCode() exceeds the engine's
    # argument limit on longer recordings and throws a RangeError.
    js_code = """
    <script>
    let mediaRecorder;
    let audioChunks = [];
    let isRecording = false;

    async function startRecording() {
        try {
            const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
            mediaRecorder = new MediaRecorder(stream);
            audioChunks = [];

            mediaRecorder.ondataavailable = event => {
                audioChunks.push(event.data);
            };

            mediaRecorder.onstop = async () => {
                const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
                const arrayBuffer = await audioBlob.arrayBuffer();
                const bytes = new Uint8Array(arrayBuffer);
                const chunkSize = 0x8000;
                let binary = '';
                for (let i = 0; i < bytes.length; i += chunkSize) {
                    binary += String.fromCharCode.apply(null, bytes.subarray(i, i + chunkSize));
                }
                const base64Audio = btoa(binary);

                // Python으로 오디오 데이터 전송
                google.colab.kernel.invokeFunction('process_audio', [base64Audio], {});
            };

            mediaRecorder.start();
            isRecording = true;

            document.getElementById('recordBtn').innerText = '녹음 중... (클릭하여 중지)';
            document.getElementById('recordBtn').style.backgroundColor = '#ff4444';

        } catch (err) {
            console.error('마이크 접근 오류:', err);
            alert('마이크 접근 권한이 필요합니다.');
        }
    }

    function stopRecording() {
        if (mediaRecorder && isRecording) {
            mediaRecorder.stop();
            isRecording = false;

            document.getElementById('recordBtn').innerText = '음성 인식 중...';
            document.getElementById('recordBtn').disabled = true;

            // 스트림 중지
            mediaRecorder.stream.getTracks().forEach(track => track.stop());
        }
    }

    function toggleRecording() {
        if (isRecording) {
            stopRecording();
        } else {
            startRecording();
        }
    }

    function resetButton() {
        document.getElementById('recordBtn').innerText = ' 녹음 시작';
        document.getElementById('recordBtn').style.backgroundColor = '#4CAF50';
        document.getElementById('recordBtn').disabled = false;
    }
    </script>

    <div style="text-align: center; margin: 20px;">
        <button id="recordBtn" onclick="toggleRecording()"
                style="padding: 15px 30px; font-size: 16px; background-color: #4CAF50;
                       color: white; border: none; border-radius: 5px; cursor: pointer;">
             녹음 시작
        </button>
        <div id="result" style="margin-top: 20px; padding: 10px;
                              background-color: #f0f0f0; border-radius: 5px;
                              min-height: 50px; font-size: 14px;">
            인식 결과가 여기에 표시됩니다...
        </div>
    </div>
    """

    display(HTML(js_code))

def _render_result(label, message):
    """Show ``label`` and ``message`` in the page's result box and reset the button."""
    import html
    import json

    # Escape the message for HTML, then emit the markup as a JSON string,
    # which is also a valid JavaScript string literal. Quotes, backslashes
    # and newlines in the text can then neither break the script nor inject
    # markup. "</" is split so the literal can never close the <script> tag.
    markup = f"<strong>{label}</strong> {html.escape(str(message))}"
    js_literal = json.dumps(markup).replace("</", "<\\/")
    display(HTML(f"""
    <script>
        document.getElementById('result').innerHTML = {js_literal};
        resetButton();
    </script>
    """))


def process_audio(audio_base64):
    """Decode a base64 recording sent by the browser, transcribe it, and show the result.

    Args:
        audio_base64: Base64-encoded bytes of the recording (sent as WebM by
            the recording script).

    Any exception is caught, shown in the result box, and printed, so the
    callback itself does not raise.
    """
    import os
    import tempfile

    try:
        audio_data = base64.b64decode(audio_base64)

        # librosa loads from a path, so the bytes go through a temporary file.
        with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_filename = temp_file.name

        try:
            audio_array, _ = librosa.load(temp_filename, sr=16000)

            if len(audio_array) > 0:
                result_text = wav2vec2_korean_stt(audio_array)
                _render_result("인식 결과:", result_text)
                print(f"인식 결과: {result_text}")
            else:
                _render_result("오류:", "음성이 감지되지 않았습니다.")
        finally:
            # Always remove the temporary file, also when loading fails.
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)

    except Exception as e:
        _render_result("오류:", str(e))
        print(f"오류 발생: {e}")

def start_realtime_stt(model_choice="best"):
    """Set up microphone-driven speech recognition in the notebook.

    Initializes the model, registers ``process_audio`` as the
    ``'process_audio'`` kernel callback, renders the recording UI, and
    prints usage information.

    Args:
        model_choice: Which model to request:
            - "best": w11wo/wav2vec2-xls-r-300m-korean (highest quality)
            - "fast": Kkonjeong/wav2vec2-base-korean (fast processing)
            - "balanced": kresnik/wav2vec2-large-xlsr-korean (balanced)
            - "senior": hyyoka/wav2vec2-xlsr-korean-senior (senior speakers)
    """
    descriptions = {
        "best": "w11wo/wav2vec2-xls-r-300m-korean - 가장 높은 정확도 (300M 파라미터)",
        "fast": "Kkonjeong/wav2vec2-base-korean - 빠른 처리 속도",
        "balanced": "kresnik/wav2vec2-large-xlsr-korean - 균형잡힌 성능",
        "senior": "hyyoka/wav2vec2-xlsr-korean-senior - 시니어 음성 특화"
    }

    print("=== 실시간 한국어 음성인식 ===")
    print("1. 모델 초기화 중...")
    init_model_with_choice(model_choice)

    print("2. 마이크 녹음 인터페이스 설정...")
    # The callback must be registered before the UI can invoke it.
    output.register_callback('process_audio', process_audio)
    setup_microphone_recording()

    for line in (
        "3. 설정 완료!",
        "위의 '🎤 녹음 시작' 버튼을 클릭하여 음성인식을 시작하세요.",
        "녹음을 중지하려면 버튼을 다시 클릭하세요.",
        "\n 모델 성능 정보:",
    ):
        print(line)

    # Unknown choices are described as the "best" entry.
    if model_choice in descriptions:
        chosen = descriptions[model_choice]
    else:
        chosen = descriptions['best']
    print(f"현재 사용 모델: {chosen}")

    for line in (
        "\n💡 다른 모델 사용법:",
        "start_realtime_stt('fast')      # 빠른 처리",
        "start_realtime_stt('balanced')  # 균형잡힌 성능",
        "start_realtime_stt('senior')    # 시니어 특화",
    ):
        print(line)

def simple_file_stt(audio_file):
    """Transcribe an audio file: load it at 16 kHz and run the recognizer.

    Args:
        audio_file: Path handed to ``librosa.load``.

    Returns:
        The recognized text.
    """
    init_model()
    waveform, _ = librosa.load(audio_file, sr=16000)
    return wav2vec2_korean_stt(waveform)

# Usage example
if __name__ == "__main__":

    # Print the installation commands for the required libraries
    # (they are only printed here, not executed).
    print("필요한 라이브러리 설치:")
    print("!pip install transformers torch librosa ipywidgets")
    print("!apt-get update && apt-get install -y ffmpeg")  # ffmpeg added to the list
    print()

    # Start the microphone STT, requesting the highest-quality model.
    start_realtime_stt("best")

    print("\n 추천 모델 성능 순위:")
    print("1. w11wo/wav2vec2-xls-r-300m-korean (가장 높은 정확도)")
    print("2. kresnik/wav2vec2-large-xlsr-korean (현재 기본 모델)")
    print("3. Hyuk/wav2vec2-korean-v2 (개선된 v2 버전)")
    print("4. hyyoka/wav2vec2-xlsr-korean-senior (시니어 음성 특화)")

    print("\n💡 사용 예시:")
    print("start_realtime_stt('best')      # 최고 성능 모델")
    print("start_realtime_stt('fast')      # 빠른 처리 모델")
    print("start_realtime_stt('senior')    # 시니어 특화 모델")

    # Optional file-based STT example:
    # result = simple_file_stt("audio_file.wav")
    # print(f"파일 인식 결과: {result}")

필요한 라이브러리 설치:
!pip install transformers torch librosa ipywidgets
!apt-get update && apt-get install -y ffmpeg

=== 실시간 한국어 음성인식 ===
1. 모델 초기화 중...
Wav2Vec2 한국어 모델 로딩 중...
모델 로딩 완료! 사용 모델: w11wo/wav2vec2-xls-r-300m-korean
2. 마이크 녹음 인터페이스 설정...


3. 설정 완료!
위의 '🎤 녹음 시작' 버튼을 클릭하여 음성인식을 시작하세요.
녹음을 중지하려면 버튼을 다시 클릭하세요.

 모델 성능 정보:
현재 사용 모델: w11wo/wav2vec2-xls-r-300m-korean - 가장 높은 정확도 (300M 파라미터)

💡 다른 모델 사용법:
start_realtime_stt('fast')      # 빠른 처리
start_realtime_stt('balanced')  # 균형잡힌 성능
start_realtime_stt('senior')    # 시니어 특화

 추천 모델 성능 순위:
1. w11wo/wav2vec2-xls-r-300m-korean (가장 높은 정확도)
2. kresnik/wav2vec2-large-xlsr-korean (현재 기본 모델)
3. Hyuk/wav2vec2-korean-v2 (개선된 v2 버전)
4. hyyoka/wav2vec2-xlsr-korean-senior (시니어 음성 특화)

💡 사용 예시:
start_realtime_stt('best')      # 최고 성능 모델
start_realtime_stt('fast')      # 빠른 처리 모델
start_realtime_stt('senior')    # 시니어 특화 모델
