In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
import re
from pprint import pp
from tqdm import tqdm
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer

import functools
import platform
import argparse
from time import time
import os
from scipy.spatial.distance import cosine

<h1><font color='blue'>0. 設定whisper物件</font></h1>

In [3]:

def strtobool(val):
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    elif val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    else:
        raise ValueError("invalid truth value %r" % (val,))


def str_none(val):
    if val == 'None':
        return None
    else:
        return val


def add_arguments(argname, type, default, help, argparser, **kwargs):
    type = strtobool if type == bool else type
    type = str_none if type == str else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' Default: %(default)s.',
                           **kwargs)

class infer_obj:
    def __init__(self ,audio_path=None ,model_path=None ,use_gpu=None ,language=None,
                 num_beams=None ,batch_size=None ,use_compile=None ,task=None,
                 assistant_model_path=None ,local_files_only=None ,use_flash_attention_2=None ,use_bettertransformer=None):
        import sys
        sys.argv=['']

        parser = argparse.ArgumentParser(description=__doc__)
        add_arg = functools.partial(add_arguments, argparser=parser)
        add_arg("audio_path",  type=str,  default="dataset/test.wav", help="预测的音频路径")
        add_arg("model_path",  type=str,  default="openai/whisper-small", help="合并模型的路径，或者是huggingface上模型的名称")
        add_arg("use_gpu",     type=bool, default=True,      help="是否使用gpu进行预测")
        add_arg("language",    type=str,  default="Chinese", help="设置语言，如果为None则预测的是多语言")
        add_arg("num_beams",   type=int,  default=1,         help="解码搜索大小")
        add_arg("batch_size",  type=int,  default=16,        help="预测batch_size大小")
        add_arg("use_compile", type=bool, default=False,     help="是否使用Pytorch2.0的编译器")
        add_arg("task",        type=str,  default="transcribe", choices=['transcribe', 'translate'], help="模型的任务")
        add_arg("assistant_model_path",  type=str,  default=None,  help="助手模型，可以提高推理速度，例如openai/whisper-tiny")
        add_arg("local_files_only",      type=bool, default=False,  help="是否只在本地加载模型，不尝试下载")
        add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速")
        add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速")
        self.args = parser.parse_args()
        
        if not audio_path is None: self.args.audio_path = audio_path
        if not model_path is None: self.args.model_path = model_path
        if not use_gpu is None: self.args.use_gpu = use_gpu
        if not language is None: self.args.language = language
        if not num_beams is None: self.args.num_beams = num_beams
        if not batch_size is None: self.args.batch_size = batch_size
        if not use_compile is None: self.args.use_compile = use_compile
        if not task is None: self.args.task = task
        if not assistant_model_path is None: self.args.assistant_model_path = assistant_model_path
        if not local_files_only is None: self.args.local_files_only = local_files_only
        if not use_flash_attention_2 is None: self.args.use_flash_attention_2 = use_flash_attention_2
        if not use_bettertransformer is None: self.args.use_bettertransformer = use_bettertransformer
        
        pp(self.args)
                
        # 设置设备
        self.device = "cuda" if torch.cuda.is_available() and self.args.use_gpu else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() and self.args.use_gpu else torch.float32
        
        # 获取Whisper的特征提取器、编码器和解码器
        self.processor = AutoProcessor.from_pretrained(self.args.model_path)
        
        # 获取模型
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.args.model_path, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True,
            use_flash_attention_2=self.args.use_flash_attention_2
        )
        if self.args.use_bettertransformer and not self.args.use_flash_attention_2:
            self.model = self.model.to_bettertransformer()
        # 使用Pytorch2.0的编译器
        if self.args.use_compile:
            if torch.__version__ >= "2" and platform.system().lower() != 'windows':
                self.model = torch.compile(self.model)
        self.model.to(self.device)
        
        # 获取助手模型
        self.generate_kwargs_pipeline = None
        if self.args.assistant_model_path is not None:
            self.assistant_model = AutoModelForCausalLM.from_pretrained(
                self.args.assistant_model_path, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
            )
            self.assistant_model.to(self.device)
            self.generate_kwargs_pipeline = {"assistant_model": self.assistant_model}
        
        # 获取管道
        self.infer_pipe = pipeline("automatic-speech-recognition",
                                   model=self.model,
                                tokenizer=self.processor.tokenizer,
                                feature_extractor=self.processor.feature_extractor,
                                max_new_tokens=128,
                                chunk_length_s=30,
                                batch_size=self.args.batch_size,
                                torch_dtype=self.torch_dtype,
                                generate_kwargs=self.generate_kwargs_pipeline,
                                device=self.device)
        
        # 推理参数
        self.generate_kwargs = {"task": self.args.task, "num_beams": self.args.num_beams}
        if self.args.language is not None:
            self.generate_kwargs["language"] = self.args.language
           
    def infer(self, audio_path=None):
        self.result = self.infer_pipe(self.args.audio_path, return_timestamps=False, generate_kwargs=self.generate_kwargs)
        return self.result['text']


<h1><font color='blue'>1. 執行whisper, 取得逐字稿</font></h1>

In [4]:

from glob import glob

mp3_files = glob('./*.mp3')
mp3_files = [mp3 for mp3 in mp3_files if '南市政府' in mp3]
whisper_model = 'openai/whisper-large-v2'

transcripts = []
for mp3 in mp3_files:
    if '南市政府' in mp3:
        transcripts.append(infer_obj(audio_path = mp3 , model_path = whisper_model).infer())
        file = mp3.replace('.mp3' , '.txt')
        with open(file , 'w' , encoding='utf8') as f:
            f.write(transcripts[-1])

Namespace(audio_path='.\\20221025 臺南市政府第566次市政會議.mp3', model_path='openai/whisper-large-v2', use_gpu=True, language='Chinese', num_beams=1, batch_size=16, use_compile=False, task='transcribe', assistant_model_path=None, local_files_only=False, use_flash_attention_2=False, use_bettertransformer=False)


This may lead to errors when urllib3 tries to modify verify_mode.
Please report an issue at https://gitlab.com/alelec/pip-system-certs with your
python version included in the description

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Namespace(audio_path='.\\20230606台南市政府第597市政會議 直播.mp3', model_path='openai/whisper-large-v2', use_gpu=True, language='Chinese', num_beams=1, batch_size=16, use_compile=False, task='transcribe', assistant_model_path=None, local_files_only=False, use_flash_attention_2=False, use_bettertransformer=False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Namespace(audio_path='.\\20230829台南市政府第609次市政會議 直播.mp3', model_path='openai/whisper-large-v2', use_gpu=True, language='Chinese', num_beams=1, batch_size=16, use_compile=False, task='transcribe', assistant_model_path=None, local_files_only=False, use_flash_attention_2=False, use_bettertransformer=False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Namespace(audio_path='.\\20231225 台南市政府 第626次市政會議 直播.mp3', model_path='openai/whisper-large-v2', use_gpu=True, language='Chinese', num_beams=1, batch_size=16, use_compile=False, task='transcribe', assistant_model_path=None, local_files_only=False, use_flash_attention_2=False, use_bettertransformer=False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Namespace(audio_path='.\\20240130 台南市政府 第631次市政會議 直播.mp3', model_path='openai/whisper-large-v2', use_gpu=True, language='Chinese', num_beams=1, batch_size=16, use_compile=False, task='transcribe', assistant_model_path=None, local_files_only=False, use_flash_attention_2=False, use_bettertransformer=False)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<h1><font color='blue'>2. 移除重複出現的字串</font></h1>

In [10]:

def remove_repeated_words(text): 
    pattern = r'(\w{2,100})\1'
    while True:
        new_text = re.sub(pattern, r'\1', text) 
        if new_text == text: break 
        text = new_text 
    return text

new_transcript = []
for transcript in transcripts:
    new_transcript.append(remove_repeated_words(transcript))

<h1><font color='blue'>3. 取得會議記錄</font></h1>

In [11]:
def conference_assistant(text , model_name='breeze' , embedding_model = None , max_new_tokens=1000 , mp3_name = None):
    t0 = time()
    assert model_name in ['gemma' , 'breeze'] , 'error model name'
    
    if model_name=='gemma':
        llm_model = r"gemma path"
        pattern = re.compile(r'<end_of_turn>([\W\w]*)[<eos>]?')
    elif model_name=='breeze':
        llm_model = r"Breeze path"
        pattern = re.compile(r'\[/INST\]([\W\w]*)[</s>]?')

    embedding_model = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2' if embedding_model is None else embedding_model
        
    def get_prompt(main_prompt , text):
        if model_name=='gemma':
            prompt = f'''<start_of_turn>user
                    {main_prompt}
                    ```{text}```<end_of_turn>'''
        elif model_name=='breeze':
            prompt = f'''<s>You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan.
                         [INST]{main_prompt}
                         text : ```{text}``` [/INST]'''
        return prompt
        
    def get_fix_prompt(text):
        main_prompt = '你是一位非常專業的逐字稿校正專家,以下三個引號 ``` 所包含的文字為逐字稿的一部分,請檢視是否有亂碼、錯別字,若有就進行修正或刪除,在修正完畢後,加入標點符號,除了更正後的文字以外,不需要有額外敘述。'
        prompt = get_prompt(main_prompt , text)

        return prompt

    def get_split_text(text , chunk_size , chunk_overlap):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap, 
            separators=["\n\n", "\n", " ", ""])
        return text_splitter.split_text(text)
    
    def get_llm_tokenizer(llm_model):
        model = AutoModelForCausalLM.from_pretrained(
            llm_model,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(llm_model)
        return model ,tokenizer
    
    def get_fix_text(split_text , model ,tokenizer):
        split_text_after_verify = ''
        for text in tqdm(split_text):
            query_struct = get_fix_prompt(text)
            inputs = tokenizer([query_struct], return_tensors="pt",padding=True,truncation=True)
            output = pattern.findall(tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=max_new_tokens))[0])[0]
            split_text_after_verify += output
        return split_text_after_verify
    
    def get_vectordb(split_text , embedding_model):
        embedding = HuggingFaceEmbeddings(model_name=embedding_model,
                                          model_kwargs={'device': 'cpu'})
        vectordb = Chroma.from_texts(split_text, embedding=embedding)
        return vectordb
    
    def summarize(text , model ,tokenizer):
        main_prompt = '''你是一位非常專業的會議記錄專家,以下三個引號 ``` 所包含的文字為會議記錄重點,請基於這些資訊,進行條列式回答,
                         不要提供錯誤資訊,並確保回答資訊正確,限制在1000字以內。
                         [會議記錄]
                        1. 主要討論點：
                        2. 決策事項：
                        3. 未來行動計畫及截止日期：
                        4. 各項行動計畫的負責人：'''
        prompt = get_prompt(main_prompt , text)
        inputs = tokenizer([prompt], return_tensors="pt",padding=True,truncation=True)
        output = pattern.findall(tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=max_new_tokens))[0])[0]
        return output
    
    each_stage_text = {'input' : text}
   
    # 1. initial model,tokenizer
    model ,tokenizer = get_llm_tokenizer(llm_model)
    
    # 2. use RecursiveCharacterTextSplitter to get split text
    split_text = get_split_text(text , 1000, 0)

    # 3. fix text
    new_text = get_fix_text(split_text , model ,tokenizer)
    each_stage_text['after_fix'] = new_text
    
    # 4. split for RAG
    new_text = get_split_text(new_text , 500 , 50)
    
    # 5. get vectordb
    vectordb = get_vectordb(new_text , embedding_model)
    
    # 6. question to vectordb
    question = '有哪些討論主題?'
    answers = vectordb.max_marginal_relevance_search(question ,k=10 ,fetch_k=10)
    
    # 7. merge all topics
    text = ''.join([answer.page_content for answer in answers])
    each_stage_text['RAG_answers'] = text

    # 8. summary
    summary_text = summarize(text , model ,tokenizer)
    each_stage_text['summary_text'] = summary_text
    
    if not mp3_name is None:
        for k in each_stage_text.keys():
            with open('{}_summary_{}_{}'.format(model_name , k , mp3.replace('.mp3' , '.txt')) , 'w' , encoding='utf8' ) as f:
                f.write(each_stage_text[k])
        
    return each_stage_text , time() - t0 

# using breeze 
each_stage_texts = {}
for transcript , mp3 in zip(new_transcript , mp3_files):
    mp3 = os.path.basename(mp3)
    each_stage_text , spend_time = conference_assistant(transcript , model_name='breeze' , mp3_name = mp3)
    each_stage_texts[mp3] = each_stage_text
    pp(f'{mp3} , spend_time : {spend_time:.2f}s')

        

   

   

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|                                                                                           | 0/19 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  5%|████▎                                                                              | 1/19 [00:27<08:21, 27.83s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 11%|████████▋                                                                          | 2/19 [00:52<07:25, 26.18s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 16%|█████████████                                                                      | 3/19 [01:11<06:02, 22.67s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end gener

'20221025 臺南市政府第566次市政會議.mp3 , spend_time : 387.27s'


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|                                                                                           | 0/20 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  5%|████▏                                                                              | 1/20 [00:22<07:02, 22.25s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 10%|████████▎                                                                          | 2/20 [00:32<04:37, 15.41s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 15%|████████████▍                                                                      | 3/20 [00:54<05:06, 18.03s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end gener

'20230606台南市政府第597市政會議 直播.mp3 , spend_time : 387.18s'


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|                                                                                           | 0/17 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  6%|████▉                                                                              | 1/17 [00:16<04:26, 16.63s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 12%|█████████▊                                                                         | 2/17 [00:40<05:11, 20.74s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 18%|██████████████▋                                                                    | 3/17 [00:54<04:08, 17.74s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end gener

'20230829台南市政府第609次市政會議 直播.mp3 , spend_time : 327.30s'


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|                                                                                           | 0/28 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  4%|██▉                                                                                | 1/28 [00:13<06:04, 13.50s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  7%|█████▉                                                                             | 2/28 [00:24<05:10, 11.93s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 11%|████████▉                                                                          | 3/28 [00:45<06:43, 16.13s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end gener

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'20231225 台南市政府 第626次市政會議 直播.mp3 , spend_time : 434.31s'


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|                                                                                           | 0/18 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  6%|████▌                                                                              | 1/18 [00:19<05:38, 19.89s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 11%|█████████▏                                                                         | 2/18 [00:41<05:38, 21.14s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
 17%|█████████████▊                                                                     | 3/18 [00:52<04:01, 16.10s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end gener

'20240130 台南市政府 第631次市政會議 直播.mp3 , spend_time : 382.03s'


<h1><font color='blue'>4. 計算平均cosine similarity</font></h1>

In [12]:

get_cosine_similarity = lambda GT , pred: 1 - cosine(GT , pred)

model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

cosine_similarity = []
for k in each_stage_texts.keys():
    GT_file = k.replace('.mp3' , '_GT.txt')
    if '南市政府' in k and os.path.exists(GT_file):
        pred = model.encode(each_stage_texts[k]['summary_text'])
        with open(GT_file , 'r' , encoding = 'utf8') as f:
            GT = model.encode(f.read())
        cosine_similarity.append(get_cosine_similarity(GT , pred))

pp(f'AVG cosine similarity : {sum(cosine_similarity) / len(cosine_similarity):.2f}')



'AVG cosine similarity : 0.40'


In [13]:
pp(cosine_similarity)

[0.29021766781806946,
 0.49101749062538147,
 0.3836284577846527,
 0.3403230905532837,
 0.5067058205604553]
