# Compare the fine-tuned model with the base model

In [1]:
import os

# GPU-related environment settings for this kernel.
_gpu_env = {
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # enumerate GPUs by PCI bus ID
    "CUDA_VISIBLE_DEVICES": "1",  # expose only GPU 1 to this process
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(_gpu_env)

# Echo the two device-selection variables so the cell output records them.
for _caption, _key in (
    ("CUDA_DEVICE_ORDER:", "CUDA_DEVICE_ORDER"),
    ("CUDA_VISIBLE_DEVICES:", "CUDA_VISIBLE_DEVICES"),
):
    print(_caption, os.environ.get(_key))

CUDA_DEVICE_ORDER: PCI_BUS_ID
CUDA_VISIBLE_DEVICES: 1


In [2]:
checkpoint = "jiwon65/whisper-small_korean-zeroth"

# The Whisper model outputs indices into its vocabulary;
# a tokenizer is used to map those indices back to actual text.

from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperFeatureExtractor

# Language/task settings shared by the tokenizer and the processor.
_whisper_kwargs = {"language": "Korean", "task": "transcribe"}

tokenizer = WhisperTokenizer.from_pretrained(checkpoint, **_whisper_kwargs)
feature_extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
processor = WhisperProcessor.from_pretrained(checkpoint, **_whisper_kwargs)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from transformers import WhisperForConditionalGeneration

# The fine-tuned weights come from a local training checkpoint directory.
model_tuned = WhisperForConditionalGeneration.from_pretrained("./test_trainer/checkpoint-150/")
# The baseline is the same hub checkpoint the processor was loaded from, so
# reuse `checkpoint` instead of repeating the model id (keeps them in sync).
model_base = WhisperForConditionalGeneration.from_pretrained(checkpoint)

In [6]:
import torch
gpu_count = torch.cuda.device_count()
print("Number of GPUs available:", gpu_count)

# Use the first visible CUDA device when CUDA is usable, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Number of GPUs available: 1


In [7]:
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

class WhisperTranscriber:
    """Transcribe a single audio file with a Whisper processor/model pair.

    The audio is loaded with torchaudio, down-mixed to one channel,
    resampled to 16 kHz when its native rate differs, turned into input
    features by the processor, and decoded from the token ids that the
    model generates.
    """

    # Sampling rate (Hz) the audio is converted to and reported to the processor.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, processor, model, device):
        """Store the components and move the model onto ``device``.

        Args:
            processor: Callable that converts a waveform array into an object
                with an ``input_features`` attribute, and that also provides
                ``batch_decode``.
            model: Model exposing ``to``, ``eval`` and ``generate``.
            device: Device the model and the input features are moved to.
        """
        self.processor = processor
        self.model = model
        self.device = device
        self.model.to(self.device)
        # This class only runs inference, so make sure training-only layers
        # (e.g. dropout) are disabled even if the model arrives in train mode.
        self.model.eval()

    def transcribe(self, wav_file_path):
        """Return the transcription of the audio file at ``wav_file_path``.

        Args:
            wav_file_path: Path of the audio file to load.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens removed.
        """
        # Load the audio file; torchaudio yields a (channels, frames) tensor.
        audio, sample_rate = torchaudio.load(wav_file_path)

        # Collapse the channel axis to a single mono waveform. squeeze() only
        # drops size-1 axes, so a stereo file would otherwise stay 2-D and be
        # handed to the processor as something other than one waveform.
        if audio.dim() > 1:
            audio = audio[0] if audio.size(0) == 1 else audio.mean(dim=0)

        # Resample the audio to the target rate if needed.
        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.TARGET_SAMPLE_RATE
            )
            audio = resampler(audio)

        # Preprocess the audio to get input features.
        input_features = self.processor(
            audio.numpy(),
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features

        # Generate the transcription without tracking gradients.
        with torch.no_grad():
            predicted_ids = self.model.generate(input_features.to(self.device))

        # Decode the predicted ids to text; only one clip was submitted, so
        # take the first (and only) decoded string.
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription

In [34]:
from IPython.display import Audio, display

import glob, os
from datasets import load_dataset
file_path = '/data/freetalk_senior/2.Validation/raw_data/1.AI챗봇/1.AI챗봇_자유대화(노인남여)_VALIDATION/'

# Collect every .wav file in the selected folder, in sorted order.
_wav_pattern = os.path.join(file_path, "노인남여_노인대화09_F_1533066829_61_수도권_실내/*.wav")
comparing_raw_data = glob.glob(_wav_pattern, recursive=True)
comparing_raw_data.sort()

# Number of raw recordings found; shown as the cell output.
MAX_LEN_RAW = len(comparing_raw_data)
MAX_LEN_RAW

242

In [35]:
file_path = '/data/freetalk_senior/2.Validation/labeled_data/1.AI챗봇/1.AI챗봇_라벨링_자유대화(노인남여)_VALIDATION'
folder_path = "노인남여_노인대화09_F_1533066829_61_수도권_실내/*.json"

# Collect every .json label file in the selected folder, in sorted order.
_label_pattern = os.path.join(file_path, folder_path)
comparing_labeled_data = glob.glob(_label_pattern, recursive=True)
comparing_labeled_data.sort()

In [43]:
# Which utterance to compare; wrapped below so any integer maps to a valid index.
data_no = 9

# Recordings and labels are paired purely by their position in the two sorted
# lists, so both lists must be non-empty and of equal length for one shared
# index to be valid.
# NOTE(review): assumes the sorted .wav and .json lists line up one-to-one — confirm.
assert MAX_LEN_RAW > 0, "no raw audio files were found"
assert len(comparing_labeled_data) == MAX_LEN_RAW, (
    f"raw/label count mismatch: {MAX_LEN_RAW} wav files vs "
    f"{len(comparing_labeled_data)} json files"
)
sample_idx = data_no % MAX_LEN_RAW

# Read the reference transcript from the label file for this sample.
label_dataset = load_dataset("json", data_files=comparing_labeled_data[sample_idx])
label = label_dataset['train']['발화정보'][0]['stt']

wav_file_path = comparing_raw_data[sample_idx]

Generating train split: 0 examples [00:00, ? examples/s]

In [44]:
# Transcribe the same recording with each model and print the results in
# aligned columns, followed by the reference label.
for _caption, _model in (
    ("Transcription_tuned:", model_tuned),
    ("Transcription_raw:  ", model_base),
):
    transcriber = WhisperTranscriber(processor, _model, device)
    transcription = transcriber.transcribe(wav_file_path)
    print(_caption, transcription)

print(f"label:               {label}")

Transcription_tuned: 그런데 지금 쓰레덴을 생각하면 치 안에 심각한 문제가 생기고 있다고 그러네
Transcription_raw:   그런데 지금 쓰레덴을 생각하면 취한의 심각한 문제가 생기고 있다고 그러네
label:               그런데 지금 스웨덴을 생각하면 치안에 심각한 문제가 생기고 있다고 그러네


In [45]:
# Play the exact file that was transcribed. Indexing comparing_raw_data with
# the unwrapped data_no could select a different file or go out of range.
audio = Audio(wav_file_path)
display(audio)