<a href="https://colab.research.google.com/github/ww-6/youtube-chatbot/blob/main/youtube-chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# YouTube Chatbot

In [None]:
!pip install -q -U transformers accelerate bitsandbytes langchain chromadb jq \
sentence-transformers gradio yt_dlp

!pip install -q git+https://github.com/m-bain/whisperx.git

In [None]:
import torch
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import GenerationConfig
from transformers import pipeline

from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage

import whisperx
import yt_dlp
import json
import gc
import gradio as gr
import datetime
from time import time

## 1. Prepare data

In [None]:
audio_file_number = 1  # Counter that gives every downloaded audio file a distinct name


def yt_audio_to_text(url,
                     device = "cuda",
                     batch_size = 8,
                     compute_type = "float32"
                    ):
    """Download the audio of a YouTube video and transcribe it with WhisperX.

    The transcription result is written to `audio.json` in the current
    working directory; nothing is returned.

    Side effects on module globals:
      * `video_title` is set to the title reported by yt_dlp.
      * `audio_file_number` is incremented after each download.
      * `progress` is rebound to a new `gr.Progress()` and updated along the way.

    Args:
        url: URL of the video to download.
        device: device passed to `whisperx.load_model`.
        batch_size: batch size passed to `model.transcribe`.
        compute_type: compute type passed to `whisperx.load_model`.
    """
    global audio_file_number
    global progress
    global video_title

    progress = gr.Progress()
    progress(0.1)

    # NOTE(review): no post-processor is configured, so the '.mp3' name may not
    # reflect the real container of the downloaded stream -- confirm.
    audio_file = f'{audio_file_number}.mp3'

    with yt_dlp.YoutubeDL({'extract_audio': True,
                           'format': 'bestaudio',
                           'outtmpl': audio_file}) as video:

        info_dict = video.extract_info(url, download=False)
        video_title = info_dict['title']
        video.download(url)

    progress(0.3)
    audio_file_number += 1

    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
    try:
        progress(0.6)

        audio = whisperx.load_audio(audio_file)

        result = model.transcribe(audio, batch_size=batch_size)
        progress(0.9)
    finally:
        # Drop the reference *before* collecting: while `model` is still bound,
        # gc.collect() / empty_cache() cannot give its memory back. Doing this
        # in `finally` also releases the model when transcription fails.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    with open('audio.json', 'w') as f:
        json.dump(result, f)


## 2. Load data

In [None]:
def metadata_func(record: dict, metadata: dict) -> dict:
    """Add readable `start`, `end` and `source` entries to a segment's metadata.

    `record["start"]` and `record["end"]` are rounded to whole seconds and
    rendered with `str(datetime.timedelta(...))` (e.g. '0:01:06').
    `source` is the two joined as 'start->end'. `metadata` is updated in
    place and returned.
    """
    def as_clock(seconds) -> str:
        # Whole-second precision is enough for pointing at a spot in the video.
        return str(datetime.timedelta(seconds=round(seconds)))

    start, end = (as_clock(record.get(key)) for key in ("start", "end"))
    metadata.update(start=start, end=end, source=start + '->' + end)

    return metadata


def load_data():
    """Load the transcript segments from `audio.json` into the global `data`.

    Each element of the JSON `segments` array becomes one document: its
    `text` field is the content and `metadata_func` fills in the metadata.
    """
    global data

    loader_options = {
        'file_path': 'audio.json',
        'jq_schema': '.segments[]',
        'content_key': 'text',
        'metadata_func': metadata_func,
    }
    data = JSONLoader(**loader_options).load()

## 3. Create embeddings and vector store

In [None]:
# Sentence-embedding model shared by every vector store built in this notebook.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = dict(device="cuda")

embeddings = HuggingFaceEmbeddings(
    model_kwargs=model_kwargs,
    model_name=model_name,
)

def create_vectordb(k):
    """Build a fresh Chroma vector store from the global `data` and expose a retriever.

    Sets the globals `vectordb` (the store) and `retriever` (a similarity
    retriever over it).

    Args:
        k: number of documents the retriever returns per query.
    """
    global vectordb
    global retriever

    import uuid

    # Remove the collection built for a previously loaded video, if one is
    # still reachable, so its segments are not kept around.
    previous_db = globals().get('vectordb')
    if previous_db is not None:
        previous_db.delete_collection()

    # Use a unique collection name per call. Without one every call targets
    # Chroma's default collection, so segments from an earlier video could be
    # returned for questions about the current one.
    # NOTE(review): confirm against the installed chromadb/langchain versions.
    vectordb = Chroma.from_documents(
        documents=data,
        embedding=embeddings,
        collection_name=f'video-{uuid.uuid4().hex}',
    )
    retriever = vectordb.as_retriever(search_type="similarity",
                                      search_kwargs={"k": k})

## 4. Load LLM

In [None]:
from getpass import getpass
from langchain.llms import HuggingFaceHub
import os

# The Hugging Face API token is a secret and must never be committed to the
# notebook. Read it from the environment and prompt for it only when missing.
# (A token that was previously hard-coded here should be treated as leaked
# and revoked.)
if not os.environ.get('HUGGINGFACEHUB_API_TOKEN'):
    os.environ['HUGGINGFACEHUB_API_TOKEN'] = getpass('Hugging Face API token: ')

# Hosted instruction-tuned model used by every chain below.
repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"max_length": 1024})

## 5. Standalone question chain

In [None]:
# Prompt that asks the LLM to classify the latest question: 'yes' if it can be
# understood on its own, 'no' if it refers back to the chat history.
# Template variables: {chat_history} and {question}.
standalone_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question, \
    identify whether the question is a standalone question or the question \
    references the chat history. Answer 'yes' if the question is a standalone \
    question, and 'no' if the question references the chat history. Do not \
    answer anything other than 'yes' or 'no'.

    Chat history:
    {chat_history}

    Question:
    {question}

    Answer:
    """
)

def format_output(answer: str) -> str:
    """Normalise an LLM reply: lower-case it and drop every whitespace character."""
    lowered = answer.lower()
    return ''.join(char for char in lowered if not char.isspace())

# prompt -> LLM -> format_output: the chain yields the model's reply
# lower-cased and with all whitespace removed.
standalone_chain = standalone_prompt | llm | format_output

# Conversation so far, as a list of HumanMessage / AIMessage objects.
chat_history = []

# Minimal question/answer prompt, used below only to seed the chat history.
template = """
    Question:
    {question}

    Answer:
    """
prompt = PromptTemplate.from_template(template)
chain = prompt | llm


# Test
# Ask a first question and record the exchange so the follow-up has history.
question = "What is the capital of Australia?"
ai_msg = chain.invoke({'question':question})
chat_history.extend([HumanMessage(content=question),
                     AIMessage(content=ai_msg)])

# This follow-up only makes sense with the history above, so the chain is
# expected to print 'no' (not standalone).
second_question = "What is the population?"
answer = standalone_chain.invoke({'question': second_question,
                                  'chat_history': chat_history})
print(answer)

## 6. Condense question chain

In [None]:
# Prompt that asks the LLM to rewrite a follow-up question so it can be
# understood without the chat history (or to return it unchanged).
# Template variables: {chat_history} and {question}.
condense_q_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question \
    which might reference the chat history, formulate a standalone question \
    which can be understood without the chat history. Do NOT answer the question, \
    just reformulate it if needed and otherwise return it as is.

    Chat history: {chat_history}

    Question: {question}

    Answer:
    """
)

# prompt -> LLM: yields the reformulated question as returned by the model.
condense_q_chain = condense_q_prompt | llm


# Test
# Reuses `second_question` and `chat_history` defined by the earlier test cell.
answer = condense_q_chain.invoke({'question': second_question,
                                  'chat_history': chat_history})
print(answer)

## 7. Q&A chain

In [None]:
# Prompt for the final answer: the model may use only the retrieved context
# and must reply with {no_answer_msg} when it does not know the answer.
# Template variables: {no_answer_msg}, {context} and {question}.
# NOTE(review): the placeholder is followed by a literal period, so a
# no_answer_msg that already ends with '.' shows up as '..' in the prompt;
# format_answer compares the reply to no_answer_msg for exact equality.
qa_prompt = PromptTemplate.from_template(
    """You are an assistant for question-answering tasks. \
    Only use the following context to answer the question. \
    Do not answer with information that is not contained in \
    the context. If you don't know the answer, just say the \
    following in exact words: {no_answer_msg}.

    Context:
    {context}

    Question:
    {question}

    Answer:
    """
)


def format_docs(docs: list) -> str:
    '''
    Join the page content of the retrieved documents into one context
    string, separated by blank lines.

    Side effect: the global `sources` is replaced with the `start`
    metadata value of every document, in the same order.
    '''
    global sources

    starts = []
    for doc in docs:
        starts.append(doc.metadata['start'])
    sources = starts

    separator = "\n\n"
    return separator.join(map(lambda doc: doc.page_content, docs))


def standalone_question(input_: dict) -> str:
    '''
    Pick the question to use for retrieval and answering.

    If the question is a standalone question (`input_['standalone']` is
    'yes'), return it unchanged. Otherwise return `condense_q_chain`, so
    the surrounding pipeline runs it to rewrite the question using the
    chat history.

    Any classifier reply other than exactly 'yes' is treated as "not
    standalone"; the condense prompt is told to return a standalone
    question as is, which makes that the safe fallback.
    '''
    # The original branches were swapped: standalone questions were sent to
    # the condense chain and follow-ups were passed through unchanged.
    if input_['standalone'] == 'yes':
        return input_['question']
    return condense_q_chain


def format_answer(answer: str,
                  n_sources: int=1,
                  timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
    '''
    Turn the raw LLM reply into the message shown to the user.

    If the reply, with runs of whitespace collapsed to single spaces,
    equals the global `no_answer_msg`, that message is returned as is.
    Otherwise the stripped reply is returned with a sentence appended
    that lists the timestamps chosen by `filter_timestamps`.
    '''
    collapsed = ' '.join(answer.split())
    if collapsed == no_answer_msg:
        return no_answer_msg

    timestamps = filter_timestamps(n_sources, timestamp_interval)
    suffix = ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps))
    return answer.strip() + suffix


def filter_timestamps(n_sources: int,
                      timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> list:
    '''Returns a list of at most `n_sources` timestamps taken from the global `sources`.

    `sources` holds 'H:MM:SS' strings (the `start` metadata saved by
    `format_docs`). The earliest timestamp is always kept; each later one
    is kept only if it is more than `timestamp_interval` after the
    previously kept one. Timestamps are returned exactly as they appear
    in `sources`, in chronological order.

    Returns an empty list when `sources` is empty or `n_sources` is not
    positive.
    '''
    def to_timedelta(timestamp: str) -> datetime.timedelta:
        # Parse 'H:MM:SS' numerically so that e.g. '10:00:00' sorts after '2:00:00'.
        hours, minutes, seconds = (int(part) for part in timestamp.split(':'))
        return datetime.timedelta(hours=hours, minutes=minutes, seconds=seconds)

    if n_sources <= 0 or not sources:
        return []

    ordered = sorted(sources, key=to_timedelta)
    output = [ordered[0]]

    for candidate in ordered[1:]:
        if len(output) >= n_sources:
            break
        if to_timedelta(candidate) - to_timedelta(output[-1]) > timestamp_interval:
            # Keep the original string so every timestamp has the same format.
            output.append(candidate)

    return output


def setup_rag(url):
    """Prepare everything needed to chat about the video at `url`.

    Transcribes the video, loads the transcript, builds the vector store
    (using the global `k`) and assembles the global `rag_chain`.

    The chain expects an input dict with the keys `question`,
    `chat_history` and `no_answer_msg`.

    Returns:
        The `url` that was passed in.
    """
    global retriever
    global rag_chain

    yt_audio_to_text(url)
    load_data()
    create_vectordb(k)

    rag_chain = (
        # 1. Classify the question as standalone ('yes') or not.
        RunnablePassthrough.assign(standalone=standalone_chain)
        # 2. Resolve the question exactly once. Resolving it separately for
        #    'question' and 'context' would run the condense LLM call twice
        #    and could retrieve context for a differently worded question
        #    than the one being answered.
        | RunnablePassthrough.assign(question=standalone_question)
        # 3. Retrieve context for the resolved question and fill the prompt.
        | {'question': lambda input_: input_['question'],
           'context': RunnableLambda(lambda input_: input_['question']) | retriever | format_docs,
           'no_answer_msg': lambda input_: input_['no_answer_msg']
          }
        | qa_prompt
        | llm
    )

    return url



def get_answer(question: str) -> str:
    """Answer `question` with the global `rag_chain` and record the exchange.

    The raw reply is post-processed by `format_answer` using the globals
    `n_sources` and `timestamp_interval`. The question and the formatted
    answer are then appended to the global `chat_history`.
    """
    global chat_history
    global rag_chain

    chain_input = {
        "question": question,
        "chat_history": chat_history,
        "no_answer_msg": no_answer_msg,
    }
    raw_reply = rag_chain.invoke(chain_input)
    answer = format_answer(raw_reply, n_sources, timestamp_interval)

    chat_history += [HumanMessage(content=question), AIMessage(content=answer)]

    return answer

## 8. Test our chatbot

In [None]:
# Settings for the test run; read as globals by setup_rag, get_answer and
# format_answer.
no_answer_msg = "Sorry, I cannot find the answer to that question in the video."
n_sources = 3 # Number of sources provided in the answer
k = 10 # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=5) # Passed to format_answer as the minimum gap between listed timestamps
chat_history = [] # Start the test conversation with an empty history
url = 'https://www.youtube.com/watch?v=SZorAJ4I-sA'
# Download, transcribe and index the video, then build the global rag_chain.
setup_rag(url)

### 8.1 Test 1: A question that cannot be answered by the video alone.

In [None]:
# Run the first test question through the chatbot and show the formatted reply.
question = "Where is Canberra?"
answer = get_answer(question)
print(answer)

### 8.2 Test 2: A question that can be answered by the video.

In [None]:
# Run the second test question; get_answer also appends it to chat_history.
question = "What are transformers good at?"
answer = get_answer(question)
print(answer)

In [None]:
# Drop the globals created by the test run and collect garbage before the
# web app builds its own.
del rag_chain, retriever, vectordb, data
gc.collect()

## 9. Web app

In [None]:
# Chatbot settings
# These rebind the globals used by the test section above; the web app reads
# them through setup_rag, get_answer and format_answer.
no_answer_msg = "Sorry, I cannot find the answer to that question in the video."
n_sources = 3 # Number of sources provided in the answer
k = 10 # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=3) # Passed to format_answer as the minimum gap between listed timestamps
default_youtube_url = 'https://www.youtube.com/watch?v=SZorAJ4I-sA' # Pre-filled value of the URL textbox

In [None]:
def greet():
    """Append a bot-only greeting that names the video to the Gradio chat history.

    Reads the global `video_title` and returns the updated global
    `gradio_chat_history`.
    """
    global gradio_chat_history

    greeting = f'You can ask me anything about the video "{video_title}". I will do my best to answer!'
    # A `None` user side makes this a turn with only a bot message.
    bot_only_turn = (None, greeting)
    gradio_chat_history += [bot_only_turn]

    return gradio_chat_history

def respond(message):
    """Answer a user message and add the exchange to the Gradio chat history.

    Returns a pair: an empty string and the updated global
    `gradio_chat_history`.
    """
    global gradio_chat_history

    reply = get_answer(message)
    gradio_chat_history += [(message, reply)]

    return "", gradio_chat_history

def clear_chat_history():
    """Reset both the LLM-side and the Gradio-side chat histories to new empty lists."""
    global chat_history, gradio_chat_history
    chat_history, gradio_chat_history = [], []

In [None]:
# Start the web app with empty histories: `chat_history` is what get_answer
# passes to the chain, `gradio_chat_history` is what the Chatbot widget shows.
chat_history = []
gradio_chat_history = []

with gr.Blocks() as demo:

    # Structure
    # Top row: URL textbox plus a 'Go' button that triggers setup_rag.
    with gr.Row():
        url_input = gr.Textbox(value=default_youtube_url,
                               label='YouTube URL',
                               scale=5)
        button = gr.Button(value='Go', scale=1)

    chatbot = gr.Chatbot()
    user_message = gr.Textbox(label='Ask a question:')
    # Clears the question box and the chat widget.
    clear = gr.ClearButton([user_message, chatbot])


    # Actions
    # 'Go' runs setup_rag on the URL and then posts the greeting to the chat.
    # NOTE(review): trigger_mode='once' presumably ignores further clicks
    # while a run is in progress -- confirm in the Gradio docs.
    button.click(setup_rag,
                 inputs=[url_input],
                 outputs=[url_input],
                 trigger_mode='once').then(greet,
                                           inputs=[],
                                           outputs=[chatbot])

    # Submitting a question calls respond, which empties the textbox and
    # refreshes the chat widget.
    # NOTE(review): get_answer uses the global rag_chain, which only exists
    # after setup_rag has run -- submitting before pressing 'Go' would fail.
    user_message.submit(respond,
                        inputs=[user_message],
                        outputs=[user_message, chatbot])

    # Also reset the two history lists when the clear button is pressed.
    clear.click(clear_chat_history)



demo.launch(share=True, debug=True)