In [4]:
# Use a pipeline as a high-level helper
from transformers import pipeline
import torch
import torch.nn.functional as F
import pandas as pd
import chromadb

  from .autonotebook import tqdm as notebook_tqdm


In [128]:
# Feature-extraction pipeline backed by the model stored at the local relative path.
pipe = pipeline(
    task="feature-extraction",
    model="./mitra-english-similarity",
)

### Inference Example:

In [112]:
# Two example queries for the similarity check.
# NOTE(review): q2 looks like the same verse as q1 with compounds/sandhi split
# into separate words — confirm with the data owner.
q1 = "tapaḥsvādhyāyanirataṃ tapasvī vāgvidāṃ varam nāradaṃ paripapraccha vālmīkir munipuṃgavam"

q2 = "tapaḥ svādhyāya niratam tapasvī vāc vidām varam nāradam paripapraccha vālmīkiḥ muni puṃgavam"

In [113]:
# Run both queries through the pipeline; the first element of each prediction
# is turned into a tensor (presumably one row per token — TODO confirm).
o1, o2 = (torch.tensor(pipe.predict(query)[0]) for query in (q1, q2))

# Average over the leading axis to get a single vector per query.
ao1 = o1.mean(dim=0)
ao2 = o2.mean(dim=0)

# Add a leading batch axis so the similarity is computed along dim=1.
cosine_similarity = F.cosine_similarity(ao1[None], ao2[None], dim=1)
print(f"Cosine Similarity: {cosine_similarity.item()}")

Cosine Similarity: 0.9583775997161865


### Load Ramayana Data:

In [5]:
# Read the raw corpus. A verse may span several physical lines; it ends on the
# line containing 'R_', whose last space-separated token is the verse code.
with open('sa_rAmAyaNa.txt', 'r', encoding='utf-8') as f:
    text = f.readlines()

data, codes, buffer = [], [], []
for line in text:
    stripped = line.strip()

    if 'R_' not in stripped:
        # Continuation line: keep accumulating until the code line shows up.
        buffer.append(stripped)
        continue

    # Split on the LAST space of the already-stripped line. Splitting the raw
    # line instead would turn trailing whitespace into an empty "code" and
    # leave the real code glued to the verse text.
    verse_tail, _, code = stripped.rpartition(' ')
    codes.append(code.strip())
    buffer.append(verse_tail)

    data.append(' '.join(buffer).strip())
    buffer = []

# Any lines left in `buffer` after the final code line are not part of a
# complete verse and are intentionally not added to `data`.

assert len(data) > 18000, 'Incorrect Preprocessing'
assert len(data) == len(codes), 'Incorrect Preprocessing'

print(len(data))

18761


In [7]:
# Create a ChromaDB client and the collection that will store all embeddings.
# get_or_create_collection keeps this cell re-runnable: create_collection
# raises when "sa_VR" already exists in the running kernel.

client = chromadb.Client()
collection = client.get_or_create_collection(name="sa_VR")

In [10]:
# Load the saved embedding tensor from disk and display its shape.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
embeds_path = 'Ramayana_Embeds.pt'
embeds = torch.load(embeds_path)
embeds.shape

torch.Size([18761, 1024])

In [17]:
# Codes, verse texts and embedding rows are matched purely by position, so
# verify they line up instead of relying on a hard-coded corpus size.
assert len(codes) == len(data) == len(embeds), 'codes, data and embeds are misaligned'

# Store one record per verse: the verse code is the id, the verse text is the
# document, and the matching row of `embeds` is the embedding.
for code, verse, vector in zip(codes, data, embeds):

    collection.add(
        ids=code,
        documents=verse,
        embeddings=vector.tolist()
    )

In [198]:
# Pick a random verse and look up its nearest neighbours in the collection.
N_RESULTS = 10
N_RESULTS += 1  # one extra hit, because the query verse itself is stored in the collection
# Sample the index from the actual corpus size rather than a hard-coded count.
seed = torch.randint(0, len(data), size=(1,)).item()

true_data = data[seed]
true_code = codes[seed]

print("Query Verse:\n\n", f"{true_code}: {true_data}")
print('-------------------------------')
true_embed = embeds[seed, :]

results = collection.query(
    query_embeddings=true_embed.tolist(),
    n_results=N_RESULTS
)

# Iterate over what was actually returned so a collection holding fewer than
# N_RESULTS items does not raise an IndexError.
hits = zip(results['ids'][0], results['distances'][0], results['documents'][0])
for i, (hit_id, distance, document) in enumerate(hits):
    print(f"{i}: {hit_id}: Distance:{distance:.2f}\n\t\t{document}")

Query Verse:

 R_3,036.008: bālo 'py eṣa mahātejāḥ samarthas tasya nigrahe gamiṣye rāmam ādāya svasti te 'stu paraṃtapaḥ
-------------------------------
0: R_3,036.008: Distance:0.00
		bālo 'py eṣa mahātejāḥ samarthas tasya nigrahe gamiṣye rāmam ādāya svasti te 'stu paraṃtapaḥ
1: R_3,011.021: Distance:144.11
		evam uktvā mahābāhur agastyaṃ sūryavarcasaṃ jagrāha paramaprītas tasya pādau paraṃtapaḥ
2: R_4,012.001: Distance:146.68
		etac ca vacanaṃ śrutvā sugrīveṇa subhāṣitam pratyayārthaṃ mahātejā rāmo jagrāha kārmukam
3: R_7,004.022: Distance:148.65
		saṃdhyāyās tanayāṃ labdhvā vidyutkeśo niśācaraḥ ramate sa tayā sārdhaṃ paulomyā maghavān iva
4: R_5,033.080: Distance:149.39
		hate 'sure saṃyati śambasādane kapipravīreṇa maharṣicodanāt tato 'smi vāyuprabhavo hi maithili prabhāvatas tatpratimaś ca vānaraḥ
5: R_6,083.021: Distance:151.61
		tasya tad vacanaṃ śrutvā mahāpārśvo 'bravīd vacaḥ balādhyakṣān sthitāṃs tatra balaṃ saṃtvaryatām iti
6: R_2,077.023: Distance:154.46
		niveśya gaṅgām an