In [None]:
# SikhiToTheMax/Khalis libraries
import banidb
from anvaad_py import firstLetters

import requests # to get raags
from sentence_transformers import SentenceTransformer, util # for embeddings
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches # Import for creating legend patches

In [None]:
import os
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
# Suppress Hugging Face's symlinks warning. It wants to use symlinks to save disk space,
# but that requires you to turn on Windows developer mode, which comes with its own risks.
# NOTE(review): this cell runs after `sentence_transformers` was imported in the first cell.
# If huggingface_hub reads this variable at import time, setting it here may have no effect --
# TODO confirm, and if so move this assignment above that import.

**Exploration**

In [None]:
shabad = banidb.random()

In [None]:
shabad

In [None]:
shabad.keys()

In [None]:
print(shabad['shabad_id'])

In [None]:
print(shabad['source_uni'])

In [None]:
print(shabad['writer'])

In [None]:
print(shabad['ang'])

In [None]:
shabad['verses'][0]

In [None]:
shabad['verses'][1] # starting verse of the shabad

In [None]:
shabad['verses'][-1] # ending verse of the shabad

**Searching a shabad**

In [None]:
gurbani_text = "ਥਿਰੁ ਘਰਿ ਬੈਸਹੁ ਹਰਿ ਜਨ ਪਿਆਰੇ" # Thir ghar baiso har jan piaare

In [None]:
%%time
# Try searching 10 times across all sources, all angs, all raags, and all writers
# The default search options for banidb.search() are:
# banidb.search(query, searchtype=1, source='all', larivaar=False,
#              ang=None, raag=None, writer='all', page=1, results=None)
for i in range(10):
    ascii_query = firstLetters(gurbani_text)
    shabad_data = banidb.search(ascii_query)

In [None]:
print(ascii_query)
shabad_data

Now, let's find a shabad using a restricted search space

In [None]:
shabad_data['pages_data']['page_1'][0]

To restrict search space, the options in banidb.search are:
1. source
2. ang
3. raag
4. writer<br>
Let's try to extract these from shabad_data

*Source*

In [None]:
banidb.sources()

I think that the search space can be universally restricted to B (I don't know how this is different from S), D, G, N from https://banidbpy.readthedocs.io/en/latest/sources.html but it seems that the search function only allows for a string i.e. a single source instead of multiple ones: https://banidbpy.readthedocs.io/en/latest/searchdb.html

In [None]:
# banidb.search() accepts source ID, not the source in english or unicode (https://banidbpy.readthedocs.io/en/latest/sources.html)
# Map 'source_eng', which is what we get from banidb.search()'s output back to source ID so that we can use that in subsequent
# calls to banidb.search()
# One pass over the source catalogue: English source name -> source ID.
source_to_id_dict = {
    entry['source_eng']: entry['source_id']
    for entry in banidb.sources()
}
source_to_id_dict

In [None]:
# Extract source for the shabad
shabad_data['pages_data']['page_1'][0]['source']

In [None]:
# Sanity check
shabad_source = shabad_data['pages_data']['page_1'][0]['source']['en']
print(shabad_source)
banidb.search(ascii_query, source = source_to_id_dict.get(shabad_source))

*Ang*

In [None]:
shabad_ang = shabad_data['pages_data']['page_1'][0]['source']['ang']
print(shabad_ang)
# Sanity check
banidb.search(ascii_query, ang = shabad_ang)

*Raag*

In [None]:
def get_raags_directly(timeout=10):
    """
    Fetches the list of raags directly from the BaniDB API,
    bypassing the banidb.raags() function.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        A list of dicts with the keys 'raag_id', 'raag_uni' and 'raag_eng',
        or None if the request failed or the response could not be parsed.
    """
    # The API endpoint that the banidb library uses for raags
    url = "https://api.banidb.com/v2/raags"

    try:
        response = requests.get(url, timeout=timeout)
        # Raise an exception if the request returned an error (e.g., 404, 500)
        response.raise_for_status()

        # Convert the JSON response to a Python dictionary
        data = response.json()

        # The actual raag data is in the 'rows' key, skipping the header row
        return [
            {
                'raag_id': row.get('RaagID'),
                'raag_uni': row.get('RaagUnicode'),
                'raag_eng': row.get('RaagEnglish')
            }
            for row in data['rows'][1:]
        ]

    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts and HTTP error statuses
        print(f"A network error occurred: {e}")
        return None
    except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
        # Missing 'rows', rows that are not dicts, or a body that is not valid JSON
        print(f"The API response format may have changed. Error: {e}")
        return None


try:
    raags = banidb.raags() # If this fails, fall back to the direct API call below
except Exception as err:
    # `except Exception` (not a bare except) so Ctrl-C / kernel interrupts still propagate
    print("banidb.raags() doesn't work. Retrieving raags directly from the API")
    print(f"Reason: {err!r}")
    raags = get_raags_directly()

if raags:
    print("Successfully retrieved:")
    # Print every retrieved raag
    for raag in raags:
        print(raag)

In [None]:
# banidb.search() accepts raag ID, not the raag in english or unicode (same convention as source IDs: https://banidbpy.readthedocs.io/en/latest/sources.html)
# Map 'raag_eng', which is what we get from banidb.search()'s output back to 'raag_id' so that we can use that in subsequent
# calls to banidb.search()
# One pass over the raag list: English raag name -> raag ID.
raag_to_id_dict = {
    entry['raag_eng']: entry['raag_id']
    for entry in raags  # or banidb.raags() if it doesn't fail
}
raag_to_id_dict

In [None]:
shabad_raag = shabad_data['pages_data']['page_1'][0]['source']['raagen']
print(shabad_raag)
# Sanity check
banidb.search(ascii_query, raag = raag_to_id_dict.get(shabad_raag))

*Writer*

In [None]:
banidb.writers()

In [None]:
# banidb.search() accepts writer ID, not the writer name or unicode (https://banidbpy.readthedocs.io/en/latest/writers.html)
# Map 'writer_name', which is what we get from banidb.search()'s output back to 'writer_id' so that we can use that in subsequent
# calls to banidb.search()
# One pass over the writer catalogue: writer name -> writer ID.
writer_to_id_dict = {
    entry['writer_name']: entry['writer_id']
    for entry in banidb.writers()
}
writer_to_id_dict

In [None]:
shabad_writer = shabad_data['pages_data']['page_1'][0]['source']['writer']
print(shabad_writer)
# Sanity check
banidb.search(ascii_query, writer = writer_to_id_dict.get(shabad_writer))

*Checking speed for restricted space search*

In [None]:
shabad_source_id = source_to_id_dict.get(shabad_source)
shabad_raag_id   = raag_to_id_dict.get(shabad_raag)
shabad_writer_id = writer_to_id_dict.get(shabad_writer)

In [None]:
%%time
for i in range(10):
    ascii_query = firstLetters(gurbani_text)
    shabad_data = banidb.search(ascii_query, source=shabad_source_id, ang=shabad_ang, raag=shabad_raag_id, writer=shabad_writer_id)

In [None]:
banidb.search(ascii_query, source=shabad_source_id, ang=shabad_ang, raag=shabad_raag_id, writer=shabad_writer_id)

Wow. Seems like this didn't help at all. Also, while it is good to see how to improve the tuk/line retrieval time, at the end of the day, we don't even know if we can get ASR correct. If the ASR is incorrect, then the first letters would be incorrect too, and doing a banidb.search() would be futile no matter how optimized it is. Similarly, it is very common for ragis to repeat phrases e.g. "satgur tumre kaaj savare, kaaj savare, kaaj savare". This would also cause a banidb.search() to fail. So, a better approach is to take the ASR output and convert that into an embedding that can then be compared against the verses of the shabad. To do so, we first need to get all the verses of the shabad.

**Shabad verses and embeddings**

It is great that we can identify a shabad from across all Sikh scriptures. 
1. We have also found that restricting the search space in banidb.search() doesn't seem to improve the time required to identify.
2. In addition, banidb would require us to have the correct search query to find a shabad. Since we will rely on ASR for getting the text for the gurbani tuk/line, we can't be certain of that. <br>

So, we need to find a new method. Here is the proposed approach:
1. Identify the shabad by using banidb search via skip-gram approach. That is, take the whole text sequence and split it into multiple parts. Then, pass this into banidb.search(). If the text sequence's first letters using ASR turn out to be "tgvhjp" ("thir ghar baiso har jan piaare"), where "baiso" is incorrectly identified as "vaiso", then we can pass multiple queries by splitting - "tgvhjp", "gvhj", "tvhj", "vhjp" etc. The return from these queries can be used to identify shabad IDs. The median of these results can be picked as the relevant shabad. We can also store other results in memory in case the search query fails in the future.
2. Once the shabad is identified, we will get all its verses using banidb.shabad(). Then, we will no longer use the skip-gram approach for subsequent searches. Instead, we will use the more powerful method of embeddings. We will embed the ASR output as embeddings and compare them to the embeddings of the verses for that shabad. This search space should be quite small and we hope to be able to identify the verse fairly quickly.
3. We need to continuously monitor the words being sung and ensure that the verse/tuk identified is same as the previous one. If it changes (with high confidence), we need to switch to giving a new verse as the output.

In [None]:
import re

In [None]:
def clean_gurbani_verse(text):
    """
    Normalise a Gurbani line so that it resembles what a raagi actually sings.

    Verse numbers, the marker word 'Rahaao' and punctuation never appear in
    ASR output, so they are stripped here; the words that remain are joined
    back together with single spaces.
    """
    # Digits first (Gurmukhi and Arabic), then the word 'Rahaao' (ਰਹਾਉ)
    stripped = re.sub(r'[\u0A66-\u0A6F0-9]+', '', text).replace('ਰਹਾਉ', '')
    # Danda (॥), Visarg (ਃ), and other common punctuation
    stripped = re.sub(r'[॥।☬ਃ|]', '', stripped)
    # Collapse whitespace runs and trim both ends
    return ' '.join(stripped.split())

In [None]:
# Sanity check
shabad_id = shabad_data['pages_data']['page_1'][0]['shabad_id']
shabad_data2 = banidb.shabad(shabad_id)
[(clean_gurbani_verse(item['verse']), item['verse']) for item in shabad_data2['verses']]

Using `sentence-transformers/distiluse-base-multilingual-cased-v2` gives poor results. The same is true of `ai4bharat/indic-bert`, but that's because it is not a sentence-transformer model.

In [None]:
from sentence_transformers import SentenceTransformer, util, models
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

class ShabadMatcher:
    """
    Finds the verse of a shabad that is most similar to an ASR phrase,
    using cosine similarity between sentence embeddings.

    Depending on `model_name`, text is embedded with a raw Hugging Face model
    (IndicBERT), a SentenceTransformer assembled from a Transformer module plus
    a Pooling module (multilingual BERT), or a ready-made SentenceTransformer.

    Args:
        model_name: Identifier of one of the supported models.
        pooling: 'cls' or 'mean'. Only used for "bert-base-multilingual-cased".
        use_cls: Only used for "ai4bharat/indic-bert". True takes the hidden
            state of the first token; False mean-pools over non-padding tokens.

    Raises:
        ValueError: If `model_name` is not supported, or if `pooling` is not
            'cls' or 'mean' for "bert-base-multilingual-cased".
    """

    # Models that can be loaded directly as a SentenceTransformer
    _SBERT_MODELS = (
        "l3cube-pune/indic-sentence-similarity-sbert",
        "sentence-transformers/LaBSE",
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    )

    def __init__(self, model_name: str, pooling: str = 'cls', use_cls: bool = False):
        self.model_name = model_name
        self.pooling = pooling
        self.use_cls = use_cls
        # (tuple of cleaned verses, their embeddings) from the most recent match() call
        self._verse_cache = None

        if model_name == "ai4bharat/indic-bert":
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
            self.model = AutoModel.from_pretrained(model_name)
            self.encode_fn = self._encode_indic_bert

        elif model_name == "bert-base-multilingual-cased":
            # Any other value would build a Pooling layer with no pooling mode enabled
            if pooling not in ('cls', 'mean'):
                raise ValueError(f"Unsupported pooling: {pooling}")
            word_embedding_model = models.Transformer(model_name, max_seq_length=128)
            pooling_model = models.Pooling(
                word_embedding_model.get_word_embedding_dimension(),
                pooling_mode_mean_tokens=(pooling == 'mean'),
                pooling_mode_cls_token=(pooling == 'cls'),
                pooling_mode_max_tokens=False
            )
            self.model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
            self.encode_fn = self._encode_sbert

        elif model_name in self._SBERT_MODELS:
            self.model = SentenceTransformer(model_name)
            self.encode_fn = self._encode_sbert

        else:
            raise ValueError(f"Unsupported model: {model_name}")

    def _encode_sbert(self, texts):
        """Embed a string or list of strings with the SentenceTransformer; returns a tensor."""
        return self.model.encode(texts, convert_to_tensor=True)

    def _encode_indic_bert(self, texts):
        """
        Embed a string or list of strings with the raw IndicBERT model.

        Returns an L2-normalised tensor with one row per input text.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state
        if self.use_cls:
            # Hidden state of the first token of every sequence
            embeddings = hidden[:, 0]
        else:
            # Mean over real tokens only: padding positions are zeroed by the mask
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            # Clamp so a fully-masked row cannot cause a division by zero
            counts = mask.sum(dim=1).clamp(min=1e-9)
            embeddings = summed / counts
        return F.normalize(embeddings, p=2, dim=1)

    def _embed_verses(self, verses):
        """Embed the verse list, reusing the previous result when the verses are unchanged."""
        key = tuple(verses)
        # getattr: tolerate instances that were created without running __init__
        cached = getattr(self, '_verse_cache', None)
        if cached is not None and cached[0] == key:
            return cached[1]
        embeddings = self.encode_fn(verses)
        self._verse_cache = (key, embeddings)
        return embeddings

    def match(self, shabad_data, asr_output_phrase, min_words=3):
        """
        Return the verse of `shabad_data` most similar to `asr_output_phrase`.

        Args:
            shabad_data: Mapping with a 'verses' list whose items each have a
                'verse' string (as used with banidb.shabad() output in this notebook).
            asr_output_phrase: The recognised text to look up.
            min_words: Phrases with fewer whitespace-separated words are skipped.

        Returns:
            None if the phrase is too short or the shabad has no verses;
            otherwise a dict with 'best_verse' (cleaned verse text), 'score'
            (cosine similarity rounded to 2 decimals) and 'index' (position of
            the verse in shabad_data['verses']).
        """
        if len(asr_output_phrase.strip().split()) < min_words:
            print(f"[Skipping] Phrase too short (min_words={min_words}): '{asr_output_phrase}'")
            return None

        verses = [clean_gurbani_verse(item['verse']) for item in shabad_data['verses']]
        if not verses:
            # Nothing to compare against
            return None

        shabad_embeddings = self._embed_verses(verses)
        asr_embedding = self.encode_fn(asr_output_phrase)

        # Both encoders return tensors. The phrase embedding may be 1-D (sbert) or
        # (1, dim) (IndicBERT); make it (1, dim) so it broadcasts against (n_verses, dim).
        scores = F.cosine_similarity(asr_embedding.reshape(1, -1), shabad_embeddings, dim=1)
        best_idx = int(scores.argmax())

        return {
            'best_verse': verses[best_idx],
            'score': round(float(scores[best_idx]), 2),
            'index': best_idx
        }

In [None]:
# Matchers under comparison, keyed by the label shown in printed results and plot titles.
# Each spec is (label, model name, extra keyword arguments for ShabadMatcher).
models_to_test = {
    label: ShabadMatcher(model_name, **options)
    for label, model_name, options in [
        ("BBMC", "bert-base-multilingual-cased", {"pooling": "cls"}),
        ("L3Cube", "l3cube-pune/indic-sentence-similarity-sbert", {}),
        ("LaBSE", "sentence-transformers/LaBSE", {}),
        ("MiniLM", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", {}),
        ("IndicBERT (CLS)", "ai4bharat/indic-bert", {"use_cls": True}),
        ("IndicBERT (Mean)", "ai4bharat/indic-bert", {"use_cls": False}),
    ]
}

In [None]:
def test_phrases(asr_phrases, correct_indices):
    results_verse =  {key:[] for key in models_to_test.keys()}
    results_status =  {key:[] for key in models_to_test.keys()}
    results_score =  {key:[] for key in models_to_test.keys()}
    assert(len(asr_phrases) == len(correct_indices))
    for i, phrase in enumerate(asr_phrases):
        print(f"\n------For {phrase}------")
        for name, matcher in models_to_test.items():
            result = matcher.match(shabad_data2, phrase)
            status = False if result is None else int(result['index'] == correct_indices[i])
            results_status[name].append(status)
            verse = None if result is None else result['best_verse']
            results_verse[name].append(verse)
            score = None if result is None else result['score']
            results_score[name].append(score)
            print(f"{name}: {result}. Correct? {bool(status) if status is not None else status}")
    return results_status, results_verse, results_score

In [None]:
def plot_similarity_by_status(models_to_test, statuses, scores):
    """
    Creates a two-column grid of bar plots (3x2 for six models) showing
    similarity scores, grouped and sorted by status, with a clear legend.

    Args:
        models_to_test: Iterable of model labels (e.g. a dict keyed by label);
            one subplot is drawn per label.
        statuses: Dict of label -> list of statuses (1 = correct verse, 0 = incorrect).
        scores: Dict of label -> list of similarity scores, aligned with `statuses`.

    Instances without a score (phrases the matcher skipped) are left out of the
    plot. A model with nothing to plot gets an empty, labelled subplot.
    """
    model_names = list(models_to_test)
    n_cols = 2
    # Enough rows for every model; six models give the original 3x2 grid
    n_rows = max(1, (len(model_names) + n_cols - 1) // n_cols)
    # 13 inches of height per 3 rows keeps the original (15, 13) size for six models
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 13 * n_rows / 3), squeeze=False)
    axs = axs.flatten()

    for ax, model in zip(axs, model_names):
        status_list = statuses.get(model, [])
        score_list = scores.get(model, [])

        # (score, label, status) for every instance that actually has a score;
        # None scores cannot be sorted or drawn
        scored = [
            (score, f'Inst {j}', status)
            for j, (status, score) in enumerate(zip(status_list, score_list))
            if score is not None
        ]

        # Separate data by status, each group sorted by descending score
        status_1_data = sorted(
            [(score, label) for score, label, status in scored if status == 1],
            key=lambda x: x[0], reverse=True
        )
        status_0_data = sorted(
            [(score, label) for score, label, status in scored if status == 0],
            key=lambda x: x[0], reverse=True
        )

        # Combine for plotting: correct instances first, then incorrect ones
        combined = status_1_data + status_0_data
        combined_scores = [score for score, _ in combined]
        combined_labels = [label for _, label in combined]
        combined_colors = ['tab:blue'] * len(status_1_data) + ['tab:orange'] * len(status_0_data)

        if combined_scores:
            ax.bar(range(len(combined_scores)), combined_scores, color=combined_colors)
            # Zoom the y-axis onto the observed score range (only defined when there is data)
            ax.set_ylim(np.min(combined_scores)-0.01, np.max(combined_scores)+0.01)

        ax.set_title(model)
        ax.set_xlabel('Instance')
        ax.set_ylabel('Similarity Score')
        ax.set_xticks(range(len(combined_labels)))
        ax.set_xticklabels(combined_labels, rotation=45, ha='right')

    # Hide grid cells that have no model assigned (odd number of models)
    for unused_ax in axs[len(model_names):]:
        unused_ax.set_visible(False)

    # Create legend handles and add them to the figure just once
    legend_patch_1 = mpatches.Patch(color='tab:blue', label='Correct verse')
    legend_patch_0 = mpatches.Patch(color='tab:orange', label='Incorrect verse')
    fig.legend(handles=[legend_patch_1, legend_patch_0], loc='upper right')

    plt.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout to make space for the legend
    plt.show()

In [None]:
# One- and two-word phrases, paired position-by-position with the expected verse index.
# NOTE: every phrase here is shorter than match()'s default min_words=3, and test_phrases()
# does not override that default, so each matcher skips these phrases and returns None.
short_asr_phrases = [
    'ਦੁਸਟ', 'ਦੁਸ਼ਟ', 'ਦੁਸਟ ਦੂਤ', 'ਕਰ ਦੀਨੇ', 'ਕਰਤਾਰੇ', 'ਕੀਨੋ ਦਾਨ',
    'ਦਾਨ', 'ਨਿਰਭੌ', 'ਸਾਧ ਸੰਗ', 'ਸਾਧ', 'ਅੰਤਰ', 'ਜਾਮੀ', 'ਪਕੜੀ ਪ੍ਰਭ'
]
short_asr_correct = [3, 3, 3, 5, 4, 8,
                     8, 7, 8, 8, 9, 9, 10]
statuses, verses, scores = test_phrases(short_asr_phrases, short_asr_correct)

In [None]:
# 3 word phrases, all beginning and ending within the same verse (ideal use case)
long_asr_phrases = [
    'ਸਤਗੁਰ ਤੁਮਰੇ ਕਾਜ', 'ਤੁਮਰੇ ਕਾਜ ਸਵਾਰੇ', 'ਦੁਸ਼ਟ ਦੂਤ ਪਰਮੇਸਰ', 'ਦੂਤ ਪਰਮੇਸਰ ਮਾਰੇ', 'ਜਨ ਕੀ ਪੈਜ', 'ਕੀ ਪੈਜ ਰਖੀ', 'ਪੈਜ ਰਖੀ ਕਰਤਾਰੇ',
    'ਨਿਰਭੌ ਹੋ ਭਜੋ', 'ਹੋ ਭਜੋ ਭਗਵਾਨ', 'ਹੋ ਭਜੋ ਪਗਵਾਨ', 'ਬਾਦ ਸ਼ਾਹ ਸ਼ਾਹ', 'ਸ਼ਾਹ ਸ਼ਾਹ ਸਬ', 'ਵੱਸ ਕਰ ਦੀਨੇ', 'ਅੰਮ੍ਰਿਤ ਨਾਮ ਮਹਾ']
long_asr_correct = [2, 2, 3, 3, 4, 4, 4,
                       7, 7, 7, 5, 5, 5, 6]
    
long_statuses, long_verses, long_scores = test_phrases(long_asr_phrases, long_asr_correct)

In [None]:
plot_similarity_by_status(models_to_test, long_statuses, long_scores)

0 'ਗਉੜੀ ਮਹਲਾ',
1 'ਥਿਰੁ ਘਰਿ ਬੈਸਹੁ ਹਰਿ ਜਨ ਪਿਆਰੇ',
2 'ਸਤਿਗੁਰਿ ਤੁਮਰੇ ਕਾਜ ਸਵਾਰੇ',
3 'ਦੁਸਟ ਦੂਤ ਪਰਮੇਸਰਿ ਮਾਰੇ',
4 'ਜਨ ਕੀ ਪੈਜ ਰਖੀ ਕਰਤਾਰੇ',
5 'ਬਾਦਿਸਾਹ ਸਾਹ ਸਭ ਵਸਿ ਕਰਿ ਦੀਨੇ',
6 'ਅੰਮ੍ਰਿਤ ਨਾਮ ਮਹਾ ਰਸ ਪੀਨੇ',
7 'ਨਿਰਭਉ ਹੋਇ ਭਜਹੁ ਭਗਵਾਨ',
8 'ਸਾਧਸੰਗਤਿ ਮਿਲਿ ਕੀਨੋ ਦਾਨੁ',
9 'ਸਰਣਿ ਪਰੇ ਪ੍ਰਭ ਅੰਤਰਜਾਮੀ',
10 'ਨਾਨਕ ਓਟ ਪਕਰੀ ਪ੍ਰਭ ਸੁਆਮੀ'

In [None]:
# 3 word phrases, beginning in one verse and ending in another (edge case)
mixed_asr_phrases = [
    'ਜਨ ਪਿਆਰੇ ਸਤਗੁਰ', 'ਪਿਆਰੇ ਸਤਗੁਰ ਤੁਮਰੇ', 'ਕਾਜ ਸਵਾਰੇ ਦੁਸ਼ਟ', 'ਸਵਾਰੇ ਦੁਸ਼ਟ ਦੂਤ', 'ਪਰਮੇਸਰ ਮਾਰੇ ਜਨ', 
    'ਮਾਰੇ ਜਨ ਕੀ',  'ਰਖੀ ਕਰਤਾਰੇ ਬਾਦ', 'ਪੀਨੇ ਨਿਰਭੌ ਹੋ', 'ਭਜੋ ਪਗਵਾਨ ਸਾਧ', 'ਕੀਨੋ ਦਾਨ ਸਰਣ', 'ਦਾਨ ਸਰਣ ਪਰੇ', 'ਦਾਨ ਸਰਣ ਭਰੇ']
# detect the correct phrase/index as the one with the dominant number of words
mixed_asr_correct = [1, 2, 2, 3, 3, 
                     4, 4, 7, 7, 8, 9, 9]
mixed_statuses, mixed_verses, mixed_scores = test_phrases(mixed_asr_phrases, mixed_asr_correct)

In [None]:
plot_similarity_by_status(models_to_test, mixed_statuses, mixed_scores)

In [None]:
shabad_data['pages_data']['page_1'][0]['shabad_id']

In [None]:
#TODO:
# 1. Add logic so that if all the words of a verse have not been spoken, we will not switch to the next verse. This might be tricky because ragis could, 
# in principle, just sing a few words of a tuk and switch to another. Usually, these would only be the ending words of a verse i.e. they might sing
# 'kaaj savare' multiple times. But it is unlikely that they will sing the first few words of a verse multiple times without saying the next ones
# in the verse.
# 2. The 3-word method will fail when very similar lines are being spoken. For example, in Sukhmani Sahib, there are multiple lines starting with
#    'prabh ke simran'. Similarly, in Bhai Nand Lal's vaaran, there are repeated lines like 
#     'Nasro Mansoor Gur Gobind Singh,... Hak hak aaena Gur Gobind Singh' etc. where 'Gur Gobind Singh' is repeated. In such instances, picking 
#      one verse with the highest cosine similarity will fail. To fix this, we have to have a logic along the lines of:
#    a) Find cosine similarity with all verses
#    b) Pick the one with the highest similarity, provided that the cosine similarity of the highest one is above a certain threshold AND the difference
#       between the highest and second highest is more than a separate threshold. If both of these conditions aren't satisfied, wait to hear more words
#       and try verse matching with a larger dataset.
# 3. Make the initial shabad identification method more robust if need be

**Automatic Speech Recognition**

In [None]:
# TODO:
# Try https://huggingface.co/gagan3012/wav2vec2-xlsr-punjabi
# Need to search more for other models. Google speech to text is amazing and we can leverage their $300 free credits as well.