<a href="https://colab.research.google.com/github/ww-6/youtube-chatbot/blob/main/youtube-chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# YouTube Chatbot

In [None]:
!pip install -q -U transformers accelerate bitsandbytes langchain chromadb jq \
sentence-transformers gradio yt_dlp pydantic

In [2]:
import torch
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import GenerationConfig
from transformers import pipeline

from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain

import yt_dlp
import json
import gc
import gradio as gr
from gradio_client import Client
import datetime
from time import time
from operator import itemgetter

## 1. Prepare data

In [None]:
# Remote transcription backend: the public Whisper JAX Hugging Face Space,
# called through gradio_client (see `transcribe_audio`).
whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
# NOTE(review): the Client presumably contacts the Space when constructed,
# so this cell fails if the Space is paused, renamed or unreachable.
whisper_jax = Client(whisper_jax_api)

def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
    '''
    Send a local audio file to the Whisper JAX Space and return the transcript.

    `task` and `return_timestamps` are forwarded positionally to the Space's
    '/predict_1' endpoint, which answers with a (transcript, runtime) pair;
    only the transcript text is returned.
    '''
    transcript, _elapsed = whisper_jax.predict(
        audio_path, task, return_timestamps, api_name='/predict_1'
    )
    return transcript

In [4]:
def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) -> list:
    '''
    Merge Whisper JAX transcript lines into paragraphs of roughly `max_duration` seconds.

    Each input line is expected to look like ``[MM:SS.mmm -> MM:SS.mmm]  text``.
    An ``HH:`` prefix on the timestamps (audio longer than an hour) is also
    accepted. Lines without this shape, such as a blank trailing line, are
    skipped. If a line's end timestamp cannot be parsed, its start timestamp
    is used instead.

    Segments are appended to the current paragraph until one ends more than
    `max_duration` seconds after the paragraph's start. That segment closes
    the paragraph, and the final segment always closes the last paragraph.
    For example, with `max_duration`=60 each paragraph covers roughly 60
    seconds of audio.

    Returns a list of dicts with keys 'start' and 'end' ('HH:MM:SS', with
    fractions of a second truncated) and 'text' (the segment texts joined by
    single spaces). An empty transcript yields an empty list.
    '''
    import re

    # A well-formed timestamp: digits separated by ':' with optional fraction.
    timestamp_pattern = re.compile(r'\d+(?::\d+)*(?:\.\d+)?')
    # '[<start> -> <end>] <text>'; the text is everything after the first ']'
    # so brackets inside the spoken text are preserved.
    line_pattern = re.compile(r'^\[\s*([^\]]*?)\s*->\s*([^\]]*?)\s*\]\s*(.*)$')

    def to_seconds(stamp: str):
        '''Convert '[HH:]MM:SS[.fff]' to whole seconds, or None if malformed.'''
        if not timestamp_pattern.fullmatch(stamp):
            return None
        total = 0.0
        for field in stamp.split(':'):
            total = total * 60 + float(field)
        return int(total)

    def to_hms(total_seconds: int) -> str:
        '''Format whole seconds as zero-padded 'HH:MM:SS'.'''
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f'{hours:02d}:{minutes:02d}:{seconds:02d}'

    # Parse every well-formed line into (end_seconds, text).
    segments = []
    for line in whisper_jax_output.split('\n'):
        match = line_pattern.match(line.strip())
        if match is None:
            continue
        start_stamp, end_stamp, text = match.groups()
        end = to_seconds(end_stamp)
        if end is None:
            end = to_seconds(start_stamp)
        if end is None:
            continue
        segments.append((end, text.strip()))

    final_output = []
    current_start = 0
    current_texts = []
    last_index = len(segments) - 1

    for i, (end, text) in enumerate(segments):
        if text:
            current_texts.append(text)
        if end - current_start > max_duration or i == last_index:
            # Max duration exceeded or last segment: close this paragraph.
            final_output.append({'start': to_hms(current_start),
                                 'end': to_hms(end),
                                 # Texts are stripped, so join with a space
                                 # to keep words/sentences apart.
                                 'text': ' '.join(current_texts)
                                })
            current_start = end
            current_texts = []

    return final_output

In [5]:
audio_file_number = 1
def yt_audio_to_text(url: str,
                     max_duration: int = 60
                    ):
    '''
    Download a YouTube video's audio, transcribe it and save the transcript.

    The merged transcript paragraphs (see `format_whisper_jax_output`) are
    written to 'audio.json'. `max_duration` is forwarded to that function and
    sets the approximate paragraph length in seconds.

    Side effects on globals:
    - sets `video_title` to the video's title;
    - replaces `progress` with a fresh `gr.Progress()` reporter;
    - increments `audio_file_number`, so each download gets its own file
      name ('1.mp3', '2.mp3', ...).
    '''
    global audio_file_number
    global progress
    global video_title

    progress = gr.Progress()
    progress(0.1)

    audio_file = f'{audio_file_number}.mp3'
    # NOTE(review): 'extract_audio' does not appear to be a YoutubeDL API
    # option (the CLI's -x is implemented as an FFmpegExtractAudio
    # postprocessor). The 'bestaudio' stream is therefore presumably saved
    # in its original container under an '.mp3' name. Confirm this is intended.
    ydl_options = {'extract_audio': True,
                   'format': 'bestaudio',
                   'outtmpl': audio_file}

    with yt_dlp.YoutubeDL(ydl_options) as video:
        # One extraction that also downloads. This avoids resolving the URL
        # once for the title and a second time inside `download`.
        info_dict = video.extract_info(url, download=True)
        video_title = info_dict['title']

    progress(0.4)
    audio_file_number += 1

    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open('audio.json', 'w', encoding='utf-8') as f:
        json.dump(result, f)


## 2. Load data

In [6]:
def metadata_func(record: dict, metadata: dict) -> dict:
    '''
    Copy a transcript record's 'start'/'end' timestamps into `metadata`.

    Also adds 'source' as '<start>-><end>'. Used as the `metadata_func` hook
    of the JSON loader; `metadata` is updated in place and returned.
    '''
    start, end = record.get('start'), record.get('end')
    metadata.update(start=start, end=end)
    metadata['source'] = '->'.join((start, end))
    return metadata


def load_data():
    '''Load 'audio.json' into documents, one per element of its top-level list.'''
    json_loader = JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',              # iterate over the top-level list
        content_key='text',           # paragraph text becomes the content
        metadata_func=metadata_func,  # keep start/end timestamps
    )
    return json_loader.load()

## 3. Create embeddings and vector store

In [17]:
# Sentence-transformers model used to embed transcript paragraphs; passed to
# Chroma in `create_vectordb`. Runs on the GPU when CUDA is available.
embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}

embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)

def create_vectordb(data, k: int, collection_name=None):
    '''
    Embed `data` into a Chroma vector store and return (vectordb, retriever).

    `k` is the number of retrieved documents.

    `collection_name` defaults to a fresh, unique name on every call. Without
    an explicit name, the LangChain Chroma wrapper always uses the same
    default collection. In-memory Chroma clients created in the same process
    can share storage, so loading a second video could add its chunks to the
    first video's, and retrieval could cite the wrong video.
    '''
    import uuid

    if collection_name is None:
        collection_name = f'youtube-{uuid.uuid4().hex}'

    vectordb = Chroma.from_documents(documents=data,
                                     embedding=embeddings,
                                     collection_name=collection_name)
    retriever = vectordb.as_retriever(search_type='similarity',
                                      search_kwargs={'k': k})

    return vectordb, retriever

## 4. Load LLM

In [None]:
from langchain.llms import HuggingFaceHub
import os
from getpass import getpass

# Never commit an access token to a notebook: a token previously hard-coded
# here is exposed and should be revoked. Read it from the environment (e.g.
# a Colab secret exported before this cell) or prompt for it interactively.
if not os.environ.get('HUGGINGFACEHUB_API_TOKEN'):
    os.environ['HUGGINGFACEHUB_API_TOKEN'] = getpass('Hugging Face API token: ')

repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
# NOTE(review): confirm the hosted endpoint honours `max_length`; generation
# length is often controlled by `max_new_tokens`, and long answers in the
# tests below appear to be cut off.
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_length': 1024})

## 5. Summarisation
We will use the map-reduce method for summarisation. Each document is first summarised individually (map step). The individual summaries are then combined and reduced into a single global summary (reduce step).

In [9]:
# Map step: summarise each transcript document on its own.
# `{docs}` receives one document's text (MapReduceDocumentsChain is built
# with document_variable_name="docs").
map_template = """Summarise the following text:
{docs}

Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)

In [10]:
# Reduce step: merge the per-document summaries into one summary.
reduce_template = """The following is a set of summaries:
{docs}

Take these and distill it into a final, consolidated summary of the main themes.
Answer:"""

reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Takes a list of documents, concatenates ("stuffs") them into a single
# string bound to `{docs}`, and passes that to the reduce LLM chain.
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)

After the individual documents are summarised, the combined summaries may still exceed the token limit (`token_max`). In that case, the summaries are collapsed in batches into intermediate summaries, and this repeats until their combined length is below the limit. The remaining summaries are then passed to the LLM together to produce the final summary.

In [11]:
# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    # This is the final chain that is called.
    combine_documents_chain=combine_documents_chain,
    # If the documents exceed `token_max`, they are first collapsed in
    # batches with this chain (here the same stuff-and-summarise chain).
    collapse_documents_chain=combine_documents_chain,
    # The maximum number of tokens to group documents into.
    # NOTE(review): confirm 4000 tokens plus the prompt fits the hosted
    # model's context window.
    token_max=4000,
)

Finally, we combine our map and reduce chains into one.

In [12]:
# Combining documents by mapping a chain over them, then combining results.
# `get_summary` runs this over the loaded transcript documents.
map_reduce_chain = MapReduceDocumentsChain(
    # Map chain: applied to every document individually
    llm_chain=map_chain,
    # Reduce chain: merges the per-document summaries
    reduce_documents_chain=reduce_documents_chain,
    # The variable name in the llm_chain to put the documents in
    document_variable_name="docs",
    # Only the final summary is returned, not the per-document summaries
    return_intermediate_steps=False,
)

def get_summary():
    '''Summarise the currently loaded transcript documents (global `data`) with map-reduce.'''
    return map_reduce_chain.run(data)

## 6. Q&A
### 6.1 Contextualising the question
The latest user question may refer to information in the chat history, but we don't want to use the entire history to search for an answer in the database because not all of it is relevant. Instead, we reformulate the question so that it contains only the relevant information and can be understood without the chat history.

In [13]:
# Prompt that rewrites a follow-up question into a standalone one using the
# chat history. The backslash-newlines join the instruction lines inside the
# string literal.
# NOTE(review): `chat_history` is passed as a list of message objects, so the
# prompt receives its str() representation; confirm this formatting suits
# the model.
contextualise_q_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question \
    which might reference the chat history, formulate a standalone question \
    which can be understood without the chat history. Do NOT answer the question, \
    just reformulate it if needed and otherwise return it as is.

    Chat history: {chat_history}

    Question: {question}

    Answer:
    '''
)

# prompt -> LLM; the output is the reformulated question text.
contextualise_q_chain = contextualise_q_prompt | llm


# Test: 'it' in the follow-up should be resolved to Canberra from the history.
chat_history = []
first_question = 'What is the capital of Australia?'
ai_msg = 'Canberra'
chat_history.extend([HumanMessage(content=first_question),
                     AIMessage(content=ai_msg)])

second_question = 'How far is it from Sydney?'
answer = contextualise_q_chain.invoke({'question': second_question,
                                       'chat_history': chat_history})
print(answer)


    What is the distance between Sydney and Canberra?


### 6.2 Standalone question chain
Reformulating the question takes time and not all questions need contextualising. To speed up the process, we add a sub-chain which determines whether a question needs contextualising or not. If the question is a standalone question, then we can use the user input directly without modifications.

In [14]:
# Classifier prompt: asks the LLM for a bare 'yes' (the question is
# standalone) or 'no' (the question refers to the chat history). The reply
# is normalised by `format_output` in `standalone_chain`.
standalone_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question, \
    identify whether the question is a standalone question or the question \
    references the chat history. Answer 'yes' if the question is a standalone \
    question, and 'no' if the question references the chat history. Do not \
    answer anything other than 'yes' or 'no'.

    Chat history:
    {chat_history}

    Question:
    {question}

    Answer:
    '''
)

def format_output(answer: str) -> str:
    '''Normalise an LLM reply: lower-case it and drop every whitespace character.'''
    lowered = answer.lower()
    return ''.join(ch for ch in lowered if not ch.isspace())

# yes/no verdict chain: prompt -> LLM -> reply lower-cased with all
# whitespace removed, so it can be compared against 'yes' / 'no'.
standalone_chain = standalone_prompt | llm | format_output


# Test: 'How far is it from Sydney?' refers to the history, so the expected
# verdict is 'no'.
answer = standalone_chain.invoke({'question': second_question,
                                  'chat_history': chat_history})
print(answer)

no


### 6.3 Q&A chain
Finally, we can build our Q&A chain. The process goes as follows:


1.   Check whether the latest user question needs contextualising or not, using the `standalone_chain`.
2.   If the question is a standalone question, use it to retrieve documents from the database. Otherwise, reformulate the question using `contextualise_q_chain` to get a contextualised question and use it to retrieve documents from the database.
3.   Pass the retrieved documents as `context`, together with the (possibly reformulated) question, to the LLM to get an answer.


We do not want the LLM to use outside knowledge so we tell the LLM that it can only use the information given in `context` to answer the question.



In [15]:
# Answering prompt: restricts the LLM to the retrieved transcript passages in
# {context}. The refusal sentence is detected by `format_answer` (it checks
# for 'cannot find the answer') to decide whether to append timestamps.
qa_prompt = PromptTemplate.from_template(
    '''You are an assistant for question-answering tasks. \
    ONLY use the following context to answer the question. \
    Do NOT answer with information that is not contained in \
    the context. If you don't know the answer, just say:\
    "Sorry, I cannot find the answer to that question in the video."

    Context:
    {context}

    Question:
    {question}

    Answer:
    '''
)


def format_docs(docs: list) -> str:
    '''
    Join the retrieved documents' text, separated by blank lines.

    Side effect: stores each document's 'start' timestamp, in retrieval
    order, in the global `sources` so answers can cite timestamps.
    '''
    global sources
    starts = [d.metadata['start'] for d in docs]
    sources = starts
    separator = '\n\n'
    return separator.join(d.page_content for d in docs)


def standalone_question(input_: dict) -> str:
    '''
    Route the question according to the `standalone_chain` verdict.

    If `input_['standalone']` says the question is standalone ('yes'), the
    question is returned unchanged. Otherwise `contextualise_q_chain` is
    returned. When this function runs as a LangChain runnable, a returned
    runnable is invoked with the same input, so the question is rewritten
    using the chat history.

    Only the prefix is checked, so a verdict such as 'yes.' still counts as
    standalone. Any other verdict ('no' or unexpected output) falls back to
    contextualising, which is always safe, just slower.
    '''
    if input_['standalone'].startswith('yes'):
        return input_['question']
    return contextualise_q_chain


def format_answer(answer: str,
                  n_sources: int=1,
                  timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
    '''
    Strip the LLM answer and, unless it is the refusal message, append timestamps.

    Timestamps come from `filter_timestamps(n_sources, timestamp_interval)`.
    Answers containing 'cannot find the answer' are returned stripped only.
    '''
    cleaned = answer.strip()
    if 'cannot find the answer' not in answer:
        timestamps = filter_timestamps(n_sources, timestamp_interval)
        cleaned += f' You can find more information at these timestamps: {", ".join(timestamps)}.'
    return cleaned


def filter_timestamps(n_sources: int,
                      timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5),
                      timestamps=None) -> list:
    '''Returns a list of at most `n_sources` 'HH:MM:SS' timestamps.

    Timestamps are scanned in chronological order starting from the earliest.
    A timestamp is kept only if it is more than `timestamp_interval` after the
    previously kept one, so the returned timestamps are not too close
    together. Fewer than `n_sources` are returned when not enough timestamps
    are far enough apart.

    `timestamps` defaults to the global `sources` recorded by the most recent
    retrieval. An empty input, or a non-positive `n_sources`, yields an empty
    list instead of raising.
    '''
    if timestamps is None:
        timestamps = sources
    if n_sources <= 0 or not timestamps:
        return []

    # Zero-padded 'HH:MM:SS' strings sort chronologically as plain strings.
    ordered = sorted(timestamps)
    selected = [ordered[0]]
    last_kept = datetime.datetime.strptime(ordered[0], '%H:%M:%S')

    for candidate in ordered[1:]:
        if len(selected) >= n_sources:
            break
        current = datetime.datetime.strptime(candidate, '%H:%M:%S')
        if current - last_kept > timestamp_interval:
            selected.append(str(current.time()))
            last_kept = current

    return selected


def setup_rag(url):
    '''Given a YouTube url, set up the vector database and the RAG chain.

    Downloads and transcribes the audio, then sets these globals:
    - `data`: the transcript documents;
    - `retriever`: built with the global `k`;
    - `rag_chain`: the question-answering chain.

    The LLM chat history is reset, since it refers to the previously loaded
    video. Returns `url` unchanged so it can be written back to the textbox.
    '''
    global data, retriever, rag_chain, chat_history

    yt_audio_to_text(url)

    data = load_data()

    _, retriever = create_vectordb(data, k)

    # Follow-up questions about a new video must not be contextualised with
    # the previous video's conversation.
    chat_history = []

    rag_chain = (
        RunnablePassthrough.assign(standalone=standalone_chain)
        # Resolve the (possibly LLM-reformulated) question once and store it.
        # Otherwise the reformulation runs twice, and the question used for
        # retrieval could differ from the one passed to the LLM.
        | RunnablePassthrough.assign(resolved_question=standalone_question)
        | {'question': itemgetter('resolved_question'),
           'context': itemgetter('resolved_question') | retriever | format_docs
          }
        | qa_prompt
        | llm
    )

    return url



def get_answer(question: str) -> str:
    '''
    Answer `question` with the RAG chain and record the exchange.

    The raw LLM output is passed through `format_answer` (using the global
    `n_sources` and `timestamp_interval`). The question and the formatted
    answer are then appended to the global `chat_history`.
    '''
    global chat_history

    raw_answer = rag_chain.invoke({'question': question, 'chat_history': chat_history})
    answer = format_answer(raw_answer, n_sources, timestamp_interval)
    chat_history += [HumanMessage(content=question), AIMessage(content=answer)]
    return answer

### 6.4 Test our chatbot

In [None]:
# Chatbot settings (read as globals by `setup_rag` and `get_answer`)
n_sources = 3 # Maximum number of timestamps cited in an answer
k = 5 # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=2) # Minimum gap between cited timestamps
chat_history = []
url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'
setup_rag(url)

#### 6.4.1 Test 1: A question that cannot be answered by the video alone.

In [19]:
# Expected: the refusal message, since the answer is not in the video.
question = 'What is the capital city of Australia?'
answer = get_answer(question)
print(answer)

Sorry, I cannot find the answer to that question in the video.


#### 6.4.2 Test 2: A question that can be answered by the video.

In [20]:
# Expected: an answer grounded in the transcript, followed by timestamps.
question = 'In what ways did transformers improve upon RNN?'
answer = get_answer(question)
print(answer)

Transformers improved upon RNN in several ways. Firstly, they leverage the power of the attention mechanism, which allows them to have an infinite window to reference from, while RNNs suffer from short-term memory and can only reference a limited window. This makes transformers better at encoding and generating longer sequences. Secondly, transformers can stack multiple layers, each layer taking in inputs from the encoder and the layers before it, allowing them to learn to extract and focus You can find more information at these timestamps: 00:00:00, 00:02:15, 00:13:51.


In [None]:
# Drop the test run's chain, retriever and documents before the web app
# rebuilds them via `setup_rag`.
# NOTE(review): this only releases the Python references; whether the
# underlying Chroma collection is freed depends on chromadb. Confirm.
del rag_chain, retriever, data
gc.collect()

## 7. Web app

In [None]:
# Chatbot settings for the web app (read as globals by `setup_rag` and
# `get_answer`)
n_sources = 3 # Maximum number of timestamps cited in an answer
k = 5 # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=2) # Minimum gap between cited timestamps
default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'

In [None]:
def greet():
    '''
    Post the video summary and a greeting into the Gradio chat.

    Appends three bot-only turns to the global `gradio_chat_history` (an
    intro naming `video_title`, the summary, and the greeting) and returns
    that history for the Chatbot component.
    '''
    global gradio_chat_history
    summary = get_summary()
    intro = f'Here is a summary of the video "{video_title}":'
    closing = 'You can ask me anything about the video. I will do my best to answer!'
    gradio_chat_history.extend([(None, intro), (None, summary), (None, closing)])
    return gradio_chat_history

def question(user_message):
    '''Add the user's message to the Gradio history as a not-yet-answered turn.'''
    global gradio_chat_history
    gradio_chat_history += [(user_message, None)]
    return gradio_chat_history

def respond():
    '''
    Answer the most recent user turn.

    Appends the bot reply to the global `gradio_chat_history`. Returns an
    empty string (to clear the input textbox) and the updated history.
    '''
    global gradio_chat_history
    latest_user_message = gradio_chat_history[-1][0]
    gradio_chat_history.append((None, get_answer(latest_user_message)))
    return '', gradio_chat_history

def clear_chat_history():
    '''Reset both the LLM message history and the Gradio chat transcript.'''
    global chat_history, gradio_chat_history
    chat_history, gradio_chat_history = [], []

In [None]:
# Python-side histories shared by the callbacks:
# - `chat_history`: LangChain Human/AI messages fed to the LLM;
# - `gradio_chat_history`: (user, bot) tuples shown in the Chatbot.
chat_history = []
gradio_chat_history = []

with gr.Blocks() as demo:

    # Structure
    with gr.Row():
        url_input = gr.Textbox(value=default_youtube_url,
                               label='YouTube URL',
                               scale=5)
        button = gr.Button(value='Go', scale=1)

    chatbot = gr.Chatbot()
    user_message = gr.Textbox(label='Ask a question:')
    # Clears the textbox and chatbot in the UI only; see `clear.click` below.
    clear = gr.ClearButton([user_message, chatbot])


    # Actions
    # Go: build the RAG pipeline for the URL, then post summary + greeting.
    # NOTE(review): `greet` appends to `gradio_chat_history`, so loading a
    # second video keeps the previous transcript on screen unless cleared.
    button.click(setup_rag,
                 inputs=[url_input],
                 outputs=[url_input],
                 trigger_mode='once').then(greet,
                                           inputs=[],
                                           outputs=[chatbot])

    # Submit: show the question immediately, then fill in the answer and
    # clear the textbox (`respond` returns '' for it).
    user_message.submit(question,
                        inputs=[user_message],
                        outputs=[chatbot]).then(respond,
                                                inputs=[],
                                                outputs=[user_message, chatbot])

    # Also reset the Python-side histories when the UI is cleared.
    clear.click(clear_chat_history)



# share=True requests a public link; debug=True keeps the cell running and
# shows errors in the notebook.
demo.launch(share=True, debug=True)