<a href="https://colab.research.google.com/github/sriramthota1/MSTR/blob/main/MSTR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Multilingual Speech Recognition Model for RAG**

---



In [52]:
import gradio as gr
from transformers import pipeline, RagRetriever, RagTokenForGeneration, RagTokenizer
from langdetect import detect, DetectorFactory
import torch
import os
from googletrans import Translator

# langdetect is randomized internally; pinning the seed makes detect() return
# the same language for the same text on every run.
DetectorFactory.seed = 0

def initialize_rag_model():
    """
    Load the tokenizer, retriever and generator of the facebook/rag-token-nq checkpoint.

    Returns:
        A ``(retriever, rag_model, rag_tokenizer)`` tuple. If anything goes
        wrong while loading, the error is printed and ``(None, None, None)``
        is returned instead of raising.
    """
    checkpoint = "facebook/rag-token-nq"
    try:
        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        # The generator needs the retriever, so the retriever is built first.
        passage_retriever = RagRetriever.from_pretrained(
            checkpoint, index_name="exact", use_dummy_dataset=True
        )
        generator = RagTokenForGeneration.from_pretrained(
            checkpoint, retriever=passage_retriever
        )
    except Exception as e:
        print(f"Error initializing RAG model: {e}")
        return None, None, None
    return passage_retriever, generator, tokenizer

# Load the RAG components once at start-up; generate_answer_with_rag() reads
# rag_model and rag_tokenizer as module-level globals.
retriever, rag_model, rag_tokenizer = initialize_rag_model()

# initialize_rag_model() signals failure with Nones, so stop here rather than
# fail later on the first query.
if not (retriever and rag_model and rag_tokenizer):
    raise RuntimeError("Failed to load RAG model and components.")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

In [53]:
# Single googletrans client shared by every translate_text() call.
translator = Translator()

# Lazily created Whisper pipeline; loading the model is expensive, so it is
# built on first use and then reused for every later request.
_asr_pipeline = None


def _get_asr_pipeline():
    """Return the shared Whisper ASR pipeline, creating it on first use."""
    global _asr_pipeline
    if _asr_pipeline is None:
        _asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base")
    return _asr_pipeline


def speech_to_text_with_lang_detection(audio_file):
    """
    Convert speech to text using the Whisper model and detect the language of the text.

    Args:
        audio_file: Path to the audio file to transcribe.

    Returns:
        A ``(text, detected_lang)`` tuple. ``detected_lang`` is the code
        returned by langdetect for the transcribed text, or ``"unknown"`` when
        detection fails. If transcription itself fails, the tuple is
        ``("Error during transcription", "unknown")``.

    Raises:
        FileNotFoundError: If ``audio_file`` is missing (``None``/empty) or
            does not point to an existing file.
    """
    # A missing path (e.g. nothing uploaded) is reported the same way as a
    # non-existent file instead of crashing inside os.path.exists().
    if not audio_file or not os.path.exists(audio_file):
        raise FileNotFoundError(f"Audio file not found: {audio_file}")

    # Model loading stays outside the try block so that a broken installation
    # surfaces as an exception rather than as a per-request "error" result.
    asr = _get_asr_pipeline()

    try:
        result = asr(audio_file)
        text = result['text']
    except Exception as e:
        print(f"Error in speech-to-text conversion: {e}")
        return "Error during transcription", "unknown"

    # Language detection is handled separately: a detection failure (for
    # example on empty text) must not throw away a successful transcription.
    try:
        detected_lang = detect(text)
    except Exception as e:
        print(f"Error in language detection: {e}")
        detected_lang = "unknown"

    return text, detected_lang

**Translation Function** using the Google Translate API

In [54]:
def translate_text(text, src_lang, tgt_lang='en'):
    """
    Translate ``text`` from ``src_lang`` into ``tgt_lang`` with the module-level googletrans client.

    This is best-effort: if the translation call fails for any reason, the
    error is printed and the input text is returned unchanged.
    """
    try:
        # Reading .text stays inside the try so a malformed response is
        # handled like any other translation failure.
        return translator.translate(text, src=src_lang, dest=tgt_lang).text
    except Exception as err:
        print(f"Translation error: {err}")
    return text

**RAG Function**

In [55]:
def generate_answer_with_rag(question):
    """
    Use RAG to generate an answer based on the input question.

    Args:
        question: The question text to send to the RAG model.

    Returns:
        The decoded answer string. An empty or whitespace-only question yields
        ``""`` without calling the model. Any failure during tokenization,
        generation or decoding is printed and reported as
        ``"Error generating answer with RAG."``.
    """
    # Nothing to ask: skip tokenization and generation entirely.
    if not question or not question.strip():
        return ""

    try:
        # Tokenization is inside the try so that a tokenizer failure produces
        # the same error string as a generation failure instead of escaping.
        inputs = rag_tokenizer(question, return_tensors="pt")
        print(f"Input to RAG: {inputs}")
        # Inference only; no gradients are needed.
        with torch.no_grad():
            outputs = rag_model.generate(**inputs)
        print(f"RAG outputs: {outputs}")
        answer = rag_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        print(f"RAG generation error: {e}")
        return "Error generating answer with RAG."

**Process Audio Query Function**

In [56]:
def process_audio_query(audio_file, target_lang='en'):
    """
    Process an audio file by transcribing it, detecting its language, translating it, and using RAG to generate an answer.

    Args:
        audio_file: Path to the audio file containing the spoken question.
        target_lang: Language code for the translated transcription and the
            translated answer. Blank or ``None`` falls back to ``'en'``.

    Returns:
        A multi-line report string with the detected language, the raw
        transcription, the transcription in ``target_lang``, the RAG answer,
        and the RAG answer in ``target_lang``.
    """
    # Normalize the language code so that blank input or 'EN' compares
    # correctly against the lowercase codes used below.
    target_lang = (target_lang or '').strip().lower() or 'en'

    text, detected_lang = speech_to_text_with_lang_detection(audio_file)
    print(f"Detected language: {detected_lang}")
    print(f"Transcribed text: {text}")

    lang_known = detected_lang != 'unknown'

    # Transcription shown to the user in the requested language.
    if lang_known and detected_lang != target_lang:
        translated_text = translate_text(text, src_lang=detected_lang, tgt_lang=target_lang)
    else:
        translated_text = text

    # The RAG answer is treated as English (it is translated from 'en' below),
    # so the question is always posed in English, whatever target_lang is.
    if not lang_known or detected_lang == 'en':
        rag_question = text
    elif target_lang == 'en':
        # translated_text is already the English translation.
        rag_question = translated_text
    else:
        rag_question = translate_text(text, src_lang=detected_lang, tgt_lang='en')

    answer = generate_answer_with_rag(rag_question)

    if answer and target_lang != 'en':
        rag_answer_translated = translate_text(answer, src_lang='en', tgt_lang=target_lang)
    else:
        rag_answer_translated = answer

    return (f"Detected Language: {detected_lang}\n\n"
            f"Transcription: {text}\n\n"
            f"Transcription (translated to {target_lang}): {translated_text}\n\n"
            f"RAG Answer: {answer}\n\n"
            f"RAG Answer (translated to {target_lang}): {rag_answer_translated}")


**Gradio Interface**

In [57]:
def gradio_interface(audio, target_lang='en'):
    """
    Gradio callback: validate the UI inputs and run the audio query pipeline.

    Args:
        audio: File path supplied by the audio component; falsy when the user
            submitted without providing audio.
        target_lang: Language code typed by the user. A blank value falls
            back to ``'en'``.

    Returns:
        The report string from ``process_audio_query``, or a short prompt
        asking for audio when none was provided.
    """
    # Submitting without audio would otherwise fail deep inside the pipeline.
    if not audio:
        return "No audio provided. Please upload or record an audio file."

    # The textbox sends an empty string when left blank, so the parameter
    # default alone never applies; fall back to English explicitly.
    target_lang = (target_lang or '').strip().lower() or 'en'
    return process_audio_query(audio, target_lang)

# Web UI: an audio input (passed to the callback as a file path) plus a
# free-text language code, with the result rendered as plain text.
iface = gr.Interface(
    title="Multilingual Speech Recognition Model for RAG",
    description=(
        "Upload an audio file, select the target language, and the system will "
        "automatically detect the language, transcribe it, translate it, and use "
        "RAG to generate an answer based on the input text."
    ),
    fn=gradio_interface,
    inputs=[
        gr.Audio(type="filepath"),
        gr.Textbox(label="Target Language (e.g., 'en', 'es', 'te')"),
    ],
    outputs="text",
)


# Start the app only when executed directly, not when imported.
if __name__ == "__main__":
    iface.launch()


Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://e1e3a8cff350b926a7.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
