In [None]:


import re
from pprint import pp
from tqdm import tqdm

from sentence_transformers import SentenceTransformer

import functools
import platform
import argparse
from time import time
import os
from scipy.spatial.distance import cosine

from langchain_community.llms import HuggingFaceHub,LlamaCpp

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from modelscope.pipelines import pipeline as modelscope_pipeline
from modelscope.utils.constant import Tasks
import torch
from transformers import AutoProcessor,AutoModelForSpeechSeq2Seq,pipeline
from langchain.retrievers import BM25Retriever, EnsembleRetriever

from typing import List
from punctuators.models import PunctCapSegModelONNX

import numpy as np
from FlagEmbedding import FlagReranker
from sklearn.cluster import KMeans

import warnings
warnings.filterwarnings("ignore")

<h1><font color='blue'>1. Function : 移除重複出現的字串</font></h1>

In [2]:
def remove_repeated_words(text):
    """Collapse immediately repeated substrings until the text stops changing.

    Any run where a substring is directly followed by one or more copies of
    itself is reduced to a single copy. The substitution is re-applied until
    a pass leaves the text unchanged. ``.`` does not match newlines, so a
    repeated unit never spans a line break.

    Args:
        text: Input string.

    Returns:
        The string with consecutive repeats collapsed.
    """
    repeated_run = re.compile(r'(.+?)\1+')
    collapsed = repeated_run.sub(r'\1', text)
    # Keep substituting until a fixed point is reached.
    while collapsed != text:
        text = collapsed
        collapsed = repeated_run.sub(r'\1', text)
    return text

<h1><font color='blue'>2. Initialize Whisper Model</font></h1>

In [None]:

def strtobool(val):
    """Convert a textual truth value to a ``bool`` (case-insensitive).

    Args:
        val: One of y/yes/t/true/on/1 (true) or n/no/f/false/off/0 (false).

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: If the lower-cased value is not a recognised token.
    """
    val = val.lower()
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')
    if val in truthy:
        return True
    if val in falsy:
        return False
    raise ValueError("invalid truth value %r" % (val,))


def str_none(val):
    """Map the literal string ``'None'`` to ``None``; return anything else unchanged."""
    return None if val == 'None' else val


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on ``argparser``.

    ``bool`` options are parsed with ``strtobool`` and ``str`` options with
    ``str_none``; any other ``type`` is passed to argparse unchanged. The
    help text gets a ``Default: ...`` suffix.

    Args:
        argname: Option name without the leading dashes.
        type: Requested value type or converter.
        default: Default value for the option.
        help: Help text (the default-value suffix is appended).
        argparser: The ``argparse.ArgumentParser`` to extend.
        **kwargs: Forwarded verbatim to ``add_argument``.
    """
    # Swap the builtin types for the custom string converters.
    if type == bool:
        type = strtobool
    if type == str:
        type = str_none
    help_text = help + ' Default: %(default)s.'
    argparser.add_argument("--" + argname, default=default, type=type, help=help_text, **kwargs)

class infer_obj:
    """Whisper speech-to-text wrapper with repeat removal and punctuation restoration.

    On construction it builds the default configuration through an argparse
    parser, applies any non-``None`` constructor overrides, loads the Whisper
    processor and model, optionally loads an assistant model, creates the
    Hugging Face ASR pipeline and loads the punctuation model.

    Note: ``local_files_only`` is stored in ``self.args`` but is not forwarded
    to any loader by this class.
    """

    def __init__(self ,audio_path=None ,model_path=None ,use_gpu=None ,language=None,
                 num_beams=None ,batch_size=None ,use_compile=None ,task=None,
                 assistant_model_path=None ,local_files_only=None ,use_flash_attention_2=None ,
                 use_bettertransformer=None,punc_model_name=None):
        """Load all models; every ``None`` argument keeps its parser default."""
        parser = argparse.ArgumentParser(description=__doc__)
        add_arg = functools.partial(add_arguments, argparser=parser)
        add_arg("audio_path",  type=str,  default="dataset/test.wav", help="预测的音频路径")
        add_arg("model_path",  type=str,  default="openai/whisper-small", help="合并模型的路径，或者是huggingface上模型的名称")
        add_arg("use_gpu",     type=bool, default=True,      help="是否使用gpu进行预测")
        add_arg("language",    type=str,  default="Chinese", help="设置语言，如果为None则预测的是多语言")
        add_arg("num_beams",   type=int,  default=1,         help="解码搜索大小")
        add_arg("batch_size",  type=int,  default=16,        help="预测batch_size大小")
        add_arg("use_compile", type=bool, default=False,     help="是否使用Pytorch2.0的编译器")
        add_arg("task",        type=str,  default="transcribe", choices=['transcribe', 'translate'], help="模型的任务")
        add_arg("assistant_model_path",  type=str,  default=None,  help="助手模型，可以提高推理速度，例如openai/whisper-tiny")
        add_arg("local_files_only",      type=bool, default=True,  help="是否只在本地加载模型，不尝试下载")
        add_arg("use_flash_attention_2", type=bool, default=False, help="是否使用FlashAttention2加速")
        add_arg("use_bettertransformer", type=bool, default=False, help="是否使用BetterTransformer加速")
        add_arg("punc_model_name", type=str, default="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", help="")
        # Only the declared defaults are wanted, so parse an empty argument
        # list rather than overwriting the process-wide sys.argv.
        self.args = parser.parse_args([])

        # Apply every constructor argument that was explicitly provided.
        overrides = {
            "audio_path": audio_path,
            "model_path": model_path,
            "use_gpu": use_gpu,
            "language": language,
            "num_beams": num_beams,
            "batch_size": batch_size,
            "use_compile": use_compile,
            "task": task,
            "assistant_model_path": assistant_model_path,
            "local_files_only": local_files_only,
            "use_flash_attention_2": use_flash_attention_2,
            "use_bettertransformer": use_bettertransformer,
            "punc_model_name": punc_model_name,
        }
        for key, value in overrides.items():
            if value is not None:
                setattr(self.args, key, value)

        pp(self.args)

        # Select device and dtype: fp16 on CUDA, fp32 on CPU.
        use_cuda = torch.cuda.is_available() and self.args.use_gpu
        self.device = "cuda" if use_cuda else "cpu"
        self.torch_dtype = torch.float16 if use_cuda else torch.float32

        # Whisper processor (provides the tokenizer and feature extractor used below).
        self.processor = AutoProcessor.from_pretrained(self.args.model_path)

        # Speech-to-text model.
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.args.model_path, torch_dtype=self.torch_dtype, low_cpu_mem_usage=False, use_safetensors=True,
            use_flash_attention_2=self.args.use_flash_attention_2
        )
        if self.args.use_bettertransformer and not self.args.use_flash_attention_2:
            self.model = self.model.to_bettertransformer()
        # Optional torch.compile; skipped on Windows and on builds without it.
        # A feature check avoids comparing version strings lexicographically.
        if self.args.use_compile:
            if hasattr(torch, "compile") and platform.system().lower() != 'windows':
                self.model = torch.compile(self.model)
        self.model.to(self.device)

        # Optional assistant model, passed to the pipeline as a generation kwarg.
        self.generate_kwargs_pipeline = None
        if self.args.assistant_model_path is not None:
            # Imported here because the notebook's import cell does not bring
            # AutoModelForCausalLM into scope.
            from transformers import AutoModelForCausalLM
            self.assistant_model = AutoModelForCausalLM.from_pretrained(
                self.args.assistant_model_path, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
            )
            self.assistant_model.to(self.device)
            self.generate_kwargs_pipeline = {"assistant_model": self.assistant_model}

        # ASR pipeline.
        self.infer_pipe = pipeline("automatic-speech-recognition",
                                   model=self.model,
                                   tokenizer=self.processor.tokenizer,
                                   feature_extractor=self.processor.feature_extractor,
                                   max_new_tokens=128,
                                   chunk_length_s=30,
                                   batch_size=self.args.batch_size,
                                   torch_dtype=self.torch_dtype,
                                   generate_kwargs=self.generate_kwargs_pipeline,
                                   device=self.device)

        # Per-call generation parameters.
        self.generate_kwargs = {"task": self.args.task, "num_beams": self.args.num_beams}
        if self.args.language is not None:
            self.generate_kwargs["language"] = self.args.language

        self.punc_model = PunctCapSegModelONNX.from_pretrained(self.args.punc_model_name)

    def infer(self, audio_path=None):
        """Transcribe an audio file and post-process the text.

        Args:
            audio_path: File to transcribe; falls back to ``self.args.audio_path``.

        Returns:
            The transcript with replacement characters, spaces and newlines
            removed, repeated substrings collapsed and punctuation added.
            The value is also stored in ``self.result``.
        """
        self.result = self.infer_pipe(self.args.audio_path if audio_path is None else audio_path ,
                                      return_timestamps=False,
                                      generate_kwargs=self.generate_kwargs)

        self.result = self.result['text'].replace('�','').replace(' ','').replace('\n','')
        self.remove_repeated_words()
        self.add_punc()

        return self.result

    def remove_repeated_words(self):
        """Best-effort collapse of repeated substrings in ``self.result``."""
        try:
            self.result = remove_repeated_words(self.result)
        except Exception:
            # Deliberately best-effort: keep the raw transcript on failure.
            # KeyboardInterrupt/SystemExit still propagate so a slow regex
            # pass can be interrupted.
            pass

    def add_punc(self):
        """Replace ``self.result`` with the punctuation model's joined output."""
        self.result = ''.join(self.punc_model.infer( texts= [self.result] , apply_sbd=True,)[0])
  

# Checkpoint name passed to infer_obj as model_path.
whisper_model = 'openai/whisper-large-v2'
# Module-level transcriber instance; get_transcript() calls whisper.infer().
# Constructing it loads the Whisper and punctuation models.
whisper = infer_obj(model_path = whisper_model)


<h1><font color='blue'>3. Initialize LLM</font></h1>
<h2><font color='blue'>3-1. Original : Breeze-7B-Instruct-v1.0-Q5_K_M</font></h2>
<h2><font color='blue'>3-2. Fine-tuned : Breeze-7B-Instruct-v1_0_fine_tuning_20240318_Q5_K_M.gguf</font></h2>

In [None]:
def get_llm(llm_model , temperature=0. , n_gpu_layers=-1):
    """Create a ``LlamaCpp`` LLM from a local model file.

    The model is constructed with ``verbose=True`` and the flag is switched
    off on the returned object afterwards.

    Args:
        llm_model: Path to the model file.
        temperature: Sampling temperature.
        n_gpu_layers: Passed through to ``LlamaCpp``.

    Returns:
        The configured ``LlamaCpp`` instance.
    """
    llama_options = dict(
        model_path=llm_model,
        max_tokens=4096,
        n_gpu_layers=n_gpu_layers,
        n_batch=128,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        n_ctx=4096,
        verbose=True,
        temperature=temperature,
        streaming=False,
    )
    llm = LlamaCpp(**llama_options)
    llm.verbose = False
    return llm

# Local model files: the fine-tuned checkpoint and the original one.
llm_fine_tune = './Breeze-7B-Instruct-v1_0_fine_tuning_20240318_Q5_K_M.gguf'
llm_ori = './breeze-7b-instruct-v1_0-q5_k_m.gguf'

# langchain_llm (fine-tuned) is used by conference_assistant for topic and item
# summaries; langchain_llm_ori (original) is used there for the issue yes/no check.
langchain_llm = get_llm(llm_model = llm_fine_tune)
langchain_llm_ori = get_llm(llm_model = llm_ori)




<h1><font color='blue'>4. Initialize Embedding Model : bge-reranker-large</font></h1>

In [None]:
# NOTE(review): this loads a reranker checkpoint through HuggingFaceEmbeddings,
# and langchain_embeddings is not referenced by the other visible cells —
# confirm whether this cell is still needed.
model_name = 'BAAI/bge-reranker-large'
langchain_embeddings = HuggingFaceEmbeddings(model_name = model_name)

<h1><font color='blue'>5. Initialize Rerank Model : bge-reranker-large</font></h1>

In [6]:
# Reranker used by conference_assistant to score text pairs; scores are
# compared against rerank_threshold there. Note: model_name is reassigned here.
model_name = 'BAAI/bge-reranker-large'
reranker = FlagReranker(model_name , use_fp16=True)

<h1><font color='blue'>6. Function : 產生會議逐字稿</font></h1>

In [7]:
def get_transcript(audio_file):
    """Return the transcript of ``audio_file`` from the module-level ``whisper`` object."""
    return whisper.infer(audio_path=audio_file)

<h1><font color='blue'>7. Function : 產生會議記錄</font></h1>

In [8]:

def conference_assistant(texts_ori):
    """Turn a meeting transcript into formatted meeting minutes.

    Pipeline:
      1. Split the transcript into chunks at several chunk sizes.
      2. Ask the fine-tuned LLM for the topics of every chunk, remembering
         which chunk each topic came from.
      3. Keep topics that look like actionable issues (keyword match, or a
         yes/no answer from the original LLM).
      4. Merge similar topics using reranker scores plus an LLM re-summary
         (two passes), then drop merged topics that no source chunk supports.
      5. For each remaining topic, summarise its resolution, plan, deadline
         and owner from the related chunks and list the cited chunks.

    Relies on the module-level ``langchain_llm``, ``langchain_llm_ori`` and
    ``reranker`` objects.

    Args:
        texts_ori: Full transcript text.

    Returns:
        The formatted minutes as one string; ``''`` when the transcript is
        empty or no topic survives filtering.
    """
    # Nothing to summarise: skip all model calls.
    if not texts_ori or not texts_ori.strip():
        return ''

    rerank_threshold = 0

    def get_split_text(text , chunk_size , chunk_overlap):
        """Split ``text`` on blank lines, newlines and the full stop '。'."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n","。"])
        return text_splitter.split_text(text)

    # Patterns used to split and clean LLM answers.
    replace_punch , split_str = re.compile(r'\n|。| ') , re.compile(r'\d{1,2}、')

    def clean_str(tmp):
        """Strip newlines/'。'/spaces, split on 'N、' markers, drop '' and '無'."""
        parts = split_str.split(replace_punch.sub('' , tmp))
        return [part for part in parts if part not in ('' , '無')]

    def unique_in_order(values):
        """De-duplicate while keeping first-seen order (deterministic, unlike set())."""
        return list(dict.fromkeys(values))

    def rerank_scores(pairs):
        """Return reranker scores as a 1-D float array.

        Handles an empty pair list and a scalar return for a single pair.
        """
        if len(pairs) == 0:
            return np.zeros(0, dtype=float)
        return np.atleast_1d(np.asarray(reranker.compute_score(pairs), dtype=float))

    def merge_similar_items(idx , all_item , sub_prompt):
        """Group the items similar to ``all_item[idx]`` and summarise the group.

        Returns a mask of items NOT in the group and the (possibly
        LLM-summarised) group as a list.
        """
        similar_indexes = rerank_scores([[all_item[idx] , i] for i in all_item]) > rerank_threshold
        # The seed item always belongs to its own group, whatever its self-score,
        # so it can never be dropped silently.
        similar_indexes[idx] = True
        remaining_indexes = ~similar_indexes
        summarize_item = all_item[similar_indexes].tolist()
        if similar_indexes.sum() > 1:
            summarize_item = check_token_and_summarize(langchain_llm , summarize_item , sub_prompt)
        return remaining_indexes , summarize_item

    def check_token_and_summarize(langchain_llm , summarize_item , sub_prompt , prefix='<<<' , suffix='>>>'):
        """Summarise ``summarize_item`` in as many LLM calls as the token limit requires.

        Each call receives the previous call's answer plus the next slice of
        items (refine style). Returns the input unchanged when it is empty.
        """
        def build_prompt(start , stop):
            return f"{prefix}{','.join(tmp_topic + summarize_item[start:stop])}{suffix}{sub_prompt}"

        ttl_len , first_idx , tmp_topic = len(summarize_item) , 0 , []
        last_idx = ttl_len
        while first_idx < last_idx:
            prompt = build_prompt(first_idx , last_idx)
            # Shrink the slice until the prompt fits, but always keep at least one
            # item so every iteration makes progress (a single oversized item would
            # otherwise loop forever).
            while last_idx > first_idx + 1 and langchain_llm.get_num_tokens(prompt) >= langchain_llm.max_tokens:
                last_idx -= 1
                prompt = build_prompt(first_idx , last_idx)
            tmp_topic = clean_str(langchain_llm.invoke(prompt))
            first_idx = last_idx
            if first_idx == ttl_len:
                summarize_item = tmp_topic
                break
            last_idx = ttl_len
        return summarize_item

    def filter_duplicates(paragraphs):
        """Format the cited paragraphs, dropping any that contain another one.

        Exact duplicates are removed first; otherwise two identical
        paragraphs would each mark the other as redundant.
        """
        paragraphs = unique_in_order(paragraphs)
        final_paragraphs = []
        for num , paragraph in enumerate(paragraphs):
            included = any(num2 != num and check_paragraph in paragraph
                           for num2 , check_paragraph in enumerate(paragraphs))
            if not included:
                final_paragraphs.append(paragraph)
        return '    ◎ 摘要段落:\n' + '\n'.join([f'        {num+1}、{paragraph}。'.replace('、。','、') for num,paragraph in enumerate(final_paragraphs)]) + '\n'

    # 1. Split the transcript at several chunk sizes.
    texts = []
    for chunk_size,chunk_overlap in [[100,20],[300,50],[500,100]]:
        texts+=get_split_text(texts_ori , chunk_size , chunk_overlap)

    # 2. Topics of every chunk from the fine-tuned LLM. A chunk can yield zero or
    #    several topics, so the source chunk is recorded per topic to keep the
    #    two lists aligned.
    summary_items , summary_sources = [] , []
    for text in texts:
        topics = clean_str(langchain_llm.invoke(f"<<<{text}>>>這段句子中,提到的<<<主題>>>有哪些?"))
        summary_items += topics
        summary_sources += [text] * len(topics)

    # 3. Keep or drop topics using keywords and the original LLM.
    # Keywords that mark a topic as an issue.
    command = ['目標', '排程', '議案', '提示', '限時', '提醒', '計畫', '意圖', '負責單位', '提案', '議題',
               '布局', '命令', '案件', '宣導','規劃', '策略', '建議', '指示', '佈局', '限期', '決定', '構想',
               '負責人', '討論', '方案', '指令', '裁定', '評估', '分析', '活動','定案', '期限', '到期日',
               '執行者', '執行官', '負責部門', '決議', '截止日' ,'成立', '整合', '措施', '改善', '改進', '設置',
               '合作','舉辦','推動','計劃','應對','準備','建立' ,'協調','重建','應用','審議','推廣','發展','提升',
               '非法','支持','處理','管理','行動','設計','部署','任務','巡視','修正','調查','應變','預防','監控','處置','籌備',
               '開鑿','事宜','使用','防範','控制','作業','草案','事件','安排','閒置','啟動','分配','需求','協商','接管','修復',
               '引進','監督','通知','檢討','資源','利用','整備','工作','補助','津貼','預算','費用','調整','業務','檢查','事務',
               '列管','提交','展望','督導','原則','規則','健檢','影響','稽查','稽核','檢測','檢驗','安全','工安','服務','維護',
               '生產','預計','預估','估計','報廢','移除','拆除','解除','清除','企劃','編列','管控','專案','演練','加強','價格',
               '加強','強化','保養','良率','yield','稼動率','down','setup','巡檢']

    # Keywords that disqualify a keyword match.
    exclude_keyword = ['頒獎', '獻獎', '表彰', '表揚','獎勵','得獎','頒發']

    is_issue = []
    for item in summary_items:
        # Keep directly when a keyword is present and no exclusion keyword is.
        if any(i in item for i in command) and not any(i in item for i in exclude_keyword):
            is_issue.append(True)
        else:
            # Otherwise let the original LLM decide.
            infer_result = check_token_and_summarize(langchain_llm_ori , command , f'是否隱含在"""{item}"""資訊中?請回答是或否,不需回答其它敘述' , prefix='"""' , suffix='"""')
            is_issue.append('是' in infer_result)
    # An explicit bool dtype keeps the mask valid when the list is empty.
    is_issue = np.array(is_issue , dtype=bool)
    summary_items , texts_np = np.array(summary_items)[is_issue] , np.array(summary_sources)[is_issue]

    if summary_items.shape[0] == 0:
        return ''

    # The filtered topics stay as the index into texts_np; merging works on a copy.
    summary_items_merge_similar = summary_items.copy()

    # Merge similar topics by reranker score and re-summarise with the LLM.
    # Two passes are run for a cleaner result.
    for _ in range(2):
        summary_items_len = summary_items_merge_similar.shape[0]
        if summary_items_len <= 1:
            break
        reserve_indexes = np.ones(summary_items_len , dtype=bool)
        final_topics = []
        for idx in range(summary_items_len):
            if reserve_indexes[idx]:
                remaining_indexes , summarize_topic = merge_similar_items(idx , summary_items_merge_similar , "這段句子中,提到的<<<主題>>>有哪些?")
                reserve_indexes &= remaining_indexes
                final_topics += summarize_topic
        summary_items_merge_similar = np.array(unique_in_order(final_topics))

    # Merged topics may be hallucinated: keep only those that at least one
    # source chunk scores above the threshold.
    source_texts = unique_in_order(texts_np.tolist())
    summary_items_merge_similar = [
        topic for topic in unique_in_order(summary_items_merge_similar.tolist())
        if (rerank_scores([[topic , doc] for doc in source_texts]) > rerank_threshold).any()
    ]

    def get_similar_doc(topic):
        """Return the distinct source chunks whose pre-merge topic matches ``topic``."""
        mask = rerank_scores([[topic , i] for i in summary_items]) > rerank_threshold
        return unique_in_order(texts_np[mask].tolist())

    # The four sections produced for each topic.
    items = ['決議' , '計畫' , '期限' , '負責人']

    # Accumulates the formatted output.
    summarize_all = ''

    for topic in summary_items_merge_similar:

        # Source chunks related to this topic.
        topic_doc = get_similar_doc(topic)

        # Cited paragraphs shown to the user for verification.
        final_paragraphs = filter_duplicates(topic_doc)

        summary = {}
        for item in items:
            item_prompt = f'這段句子中,與<<<{topic}>>>相關的<<<{item}>>>有哪些?'

            # Summarise this section from the related chunks.
            summarize = check_token_and_summarize(langchain_llm , topic_doc , item_prompt)

            # With more than one entry, merge the similar ones.
            if len(summarize)>1:
                candidates = np.array(summarize)
                reserve_indexes = np.ones(len(candidates) , dtype=bool)
                final_summarize = []
                for idx in range(len(candidates)):
                    if reserve_indexes[idx]:
                        remaining_indexes , summarize_item = merge_similar_items(idx , candidates , item_prompt)
                        reserve_indexes &= remaining_indexes
                        final_summarize += summarize_item
                summarize = unique_in_order(final_summarize)

            # Drop entries the reranker scores as unrelated to the topic.
            if len(summarize)>0:
                summarize = np.array(summarize)
                summarize = summarize[rerank_scores([[topic , doc] for doc in summarize]) > rerank_threshold]

            # Format the section.
            if len(summarize)>0:
                numbered = [f'        {num+1}、{i}。' for num , i in enumerate(summarize)]
                summary[item] = "\n" + '\n'.join(numbered) + "\n"
            else:
                summary[item] = '無。\n'

        # Assemble the topic block.
        summarize_all += f'●  討論項目：{topic}\n' + ''.join(['    ◎ {}:{}'.format(k , v) for k , v in summary.items()]) + final_paragraphs + '\n'

    return summarize_all



<h1><font color='blue'>8. Function : 會議記錄Gradio </font></h1>

In [None]:
import gradio as gr
import os
# Set no_proxy for the local addresses before the Gradio app is launched.
os.environ["no_proxy"] = "localhost,0.0.0.0,::1"

def transcribe_fn(file):
    """Gradio callback: transcribe the uploaded audio and build the minutes.

    Args:
        file: Path of the uploaded audio file, or ``None`` when nothing was uploaded.

    Returns:
        ``(summary, transcript)``; both are a prompt message when no file is given.
    """
    if file is None:
        return 'Please Upload Audio File ...' ,'Please Upload Audio File ...'
    text = get_transcript(file)
    return conference_assistant(text), text


# Gradio UI: an audio upload, a trigger button and two text outputs
# (meeting minutes and raw transcript).
with gr.Blocks() as demo:

    gr.Markdown("<h1>會議紀錄小幫手AINotes</h1>")

    with gr.Row():
        # gr.Audio replaces the legacy gr.inputs.Audio namespace, which newer
        # Gradio releases no longer provide. type="filepath" hands the callback
        # a file path; no value (None) is handled inside transcribe_fn.
        # NOTE(review): the input source is left at the installed version's
        # default — confirm that upload-only is not required.
        inputs = gr.Audio(label="Audio file", type="filepath")
    with gr.Row():
        inbtw = gr.Button("生成會議紀錄")

    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            summary = gr.Textbox(label="會議記錄")
            transcribe = gr.Textbox(label="會議逐字稿")

    # Button click runs the full pipeline and fills both text boxes.
    inbtw.click(transcribe_fn, inputs=[inputs], outputs=[summary,transcribe])


demo.launch()

In [12]:
# Shut down every running Gradio app started from this kernel.
gr.close_all()