In [1]:
%env CUDA_VISIBLE_DEVICES= 3,5,7


env: CUDA_VISIBLE_DEVICES=3,5,7


In [2]:
import os
os.environ['HF_HOME'] = '/home/sofia/cache_custom'

## IndicTrans2

In [3]:
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm import tqdm
from torch.nn.functional import softmax


BATCH_SIZE = 4 # edited from 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
print(DEVICE)

cuda


In [4]:
import transformers
print(transformers.__version__)

4.44.2


In [5]:
import importlib
import possible_indic_relations as poss_indic_rel
import span_encodings as sp_enc
# Reload the module to reflect changes
importlib.reload(poss_indic_rel)
importlib.reload(sp_enc)

pir= poss_indic_rel.possible_relations
pir

ambiguos_words = list(pir.keys())
span_encodings = sp_enc.span_encodings

In [6]:
span_encodings['ory_Orya']

{"ଜେଜେମା'": [41445, 241, 30],
 'ଜେଜେବାପା ': [41445, 1007, 1714],
 'ମାମୁଁ। ': [9971, 19212, 6],
 'ମାଉସୀ। ': [30261, 694, 6],
 'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ': [21405, 699, 22252, 13, 21405, 699, 22252],
 'ଶ୍ୱଶୁର। ': [21405, 699, 22252, 6],
 'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ': [60824, 3991, 6],
 'ଶିଶୁ ': [3442],
 'ପୁତୁରା ': [4300, 5686],
 'ଭାଣିଜୀ ': [980, 9742, 795],
 'None': [2]}

In [7]:
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    if qconfig == None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()

    return tokenizer, model

In [8]:
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"  # ai4bharat/indictrans2-en-indic-dist-200M
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir,  quantization)

ip_en_ind = IndicProcessor(inference=True)

# Util Functions

In [9]:


def resolve_logits_for_best_beam(outputs, num_beams):
    """ Resolve the logits from the best beam, using model output from a generate call.
        For a shape [tokens?, batch_size*num_beams, vocab], returns [tokens?, batch_size, vocab]

        Assumes num_return_sequences=1."""

    # print("length of output beam", len(outputs.beam_indices))
    # print("shape of beam_indices", outputs.beam_indices.shape)
    # print("shape of logits", (outputs.logits[0].shape))
    # print("length of logits", len(outputs.logits))
    # print("length od outputs", len(outputs))
    best_logits  = []
    beam_indices = [ outputs.beam_indices[:,i].tolist() for i in range(len(outputs.logits)-1) ]
    # print("length of beam_indices", len(beam_indices))

    for beam_index, logits in zip(beam_indices, outputs.logits):
        beam_index = [ idx if idx != -1 else ((num_beams*(i+1))-1) for i, idx in enumerate(beam_index) ]
        best_logits.append(logits[beam_index,:])

    return best_logits




def get_logits_for_span(logits, translations, tokenizer, search_spans, span_encodings, lang):
    """ Given search spans, returns the logits before the span was generated.

    Args:
        logits (tuple[Tensor]): Tuple of tensors, of shape [tokens?, batch_size, vocab]
        sequences (tuple[list[int]]): Tokenized output sequences.
        tokenizer (PreTrainedTokenizerBase): Tokenizer for the model.
        search_spans (list[str]): batch_size spans to search for. Must be present in the generated sequences.

    Returns:
        Tensor: Tensor of shape [batch_size, vocab] indicating the logits before the span for each batch element.
    """
    if isinstance(search_spans, str):
        search_spans = [ search_spans ] * len(translations)

    # with tokenizer.as_target_tokenizer():   
    #     detok_outputs = tokenizer.batch_decode(translations, skip_special_tokens=True)

    # positions = [ output.index(span) for output, span in zip(translations, search_spans) ]
    logit_pos = [  ]

    for seq,  span,  in zip(translations,  search_spans): 
        print("Span to search for:", span, span_encodings[lang])
        key = next((k for k in span_encodings[lang].keys() if span in k), "None")
        subtokens = span_encodings[lang][key]
        print("subtokens", subtokens)
        print("for span:", span, "subtokens", subtokens)
        print("seq", seq)
        idx = 0
        while idx < len(seq)-1:
            if any(seq[idx+i] == tok for i, tok in enumerate(subtokens)): 
                print( "found", idx)
                break
            idx += 1
        logit_pos.append(idx-1)
    
    print("logit_pos", logit_pos)
    print("Dimensions of logits", logits[0].shape, len(logits))

    selected_logits = []
    # Iterate over each batch and corresponding token position
    for batch, token in enumerate(logit_pos):
        # print("batch, token", batch, token)
        # print("len(logits)", len(logits))
        # print("Shape of logits", logits[0].shape)
        # # print("logits token", logits[token])
        # print("logits 1st item 1st row",logits[0][0])
        # Extract logits for the specific token position in the current batch
        current_logit = logits[token][batch, :]
        # current_logit = logits[token][0]
        
        # Append the selected logit to the list
        selected_logits.append(current_logit)

    # Stack the list of selected logits into a tensor
    selected_logits = torch.stack(selected_logits)

    return selected_logits
        
    # return selected_logits
    # return torch.stack([ logits[token][batch,:] for batch, token in enumerate(logit_pos) ])


In [10]:
def get_search_spans(inp_sents, tgt_lang, translations):
    search_spans= []
    root_amb_words=[]
    # print("inp_sents", inp_sents)
    for idx, inp in enumerate(inp_sents):
        # check which word from ambiguos_Wwords in present in the input sentence
        for word in ambiguos_words:
            if word in inp:
                curr_amb_word= word
        root_amb_words.append(curr_amb_word)

        # get the possible relations for the current ambiguous word
        possible_relations= pir[curr_amb_word][tgt_lang].keys()
        # print("Possible relations for the word", curr_amb_word, "in", tgt_lang, "are", possible_relations)

        # find the word in the translation from the possible relations. if not found print the translation
        for rel in possible_relations:
            if rel in translations[idx]:
                # print("Relation found in the translation for the word", curr_amb_word, "in", tgt_lang, "in the sentence", translations[idx], "at ", translations[idx].index(rel), "\nso sentece is", translations[idx][translations[idx].index(rel):])
                search_spans.append(rel)
                break
        else:
            print("No relation found in the translation for the word", curr_amb_word, "in", tgt_lang, "in the sentence", translations[idx])
            search_spans.append(translations[idx][0])
    print("Search spans are", search_spans)
    return root_amb_words, search_spans
       



In [11]:



def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip, span_encodings):
    translations = []
    start_logits = []
    root_amb_words = []
    searched_spans = []
    for i in tqdm(range(0, len(input_sentences), BATCH_SIZE)):
        batch = input_sentences[i : i + BATCH_SIZE]
        # Preprocess the batch and extract entity mappings
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        with torch.no_grad():
            # generated_tokens = model.generate(
            outputs = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1, # TODO temp
                output_scores=True,
                output_logits=True,
                return_dict_in_generate=True,

            )
            print("Length of outputs.logits actual", len(outputs.logits))
            print("Shape of outputs.logits actual", outputs.logits[0].shape)

            print("Length of outputs.beam_indices actual", len(outputs.beam_indices))
            print("Shape of outputs.beam_indices actual", outputs.beam_indices.shape)
            
            outputs.beam_indices = outputs.beam_indices.cpu()
            outputs.logits = tuple(logits.cpu() for logits in outputs.logits)               
        # Decode the generated tokens into text
        generated_tokens = outputs.sequences
        # print("len generated_tokens: ", (generated_tokens[0]).shape)
        print("1st generated token: ", generated_tokens[0])
        vector = generated_tokens.detach().cpu().tolist()
        # print("length of outputs vectors: ", len(vector), len(vector[0]))
        # print("vector of generated_tokens: ", vector)
        print("1st vector: ", vector[0])



        with tokenizer.as_target_tokenizer():
            decoded_op = tokenizer.batch_decode(
                vector,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        print("1st decoded_op: ", decoded_op[0])
        # Postprocess the translations, including entity replacement
        transl = ip.postprocess_batch(decoded_op, lang=tgt_lang)

        print("translations: ", transl)
        

        root_words, search_spans = get_search_spans(batch,  tgt_lang, transl)
        searched_spans += search_spans
        root_amb_words += root_words
        print("length search_spans: ", len(search_spans))
        best_logits = resolve_logits_for_best_beam(outputs, num_beams=5)
        # print("length best_logits: ",len(best_logits))
        print("length of best_logits: ", len(best_logits), (best_logits[0]).shape)

        start_logits += get_logits_for_span(best_logits, outputs.sequences, tokenizer, search_spans, span_encodings, tgt_lang)
        # start_logits += get_logits_for_span(best_logits, translations, tokenizer, search_spans)
        translations += transl
        del inputs
        torch.cuda.empty_cache()

    return translations, start_logits, root_amb_words, searched_spans

In [12]:
lang_script_list = [
                           'ory_Orya',
                     'pan_Guru', 'ben_Beng', 
                       'mal_Mlym',
                           'mar_Deva', 
                           'tam_Taml', 'guj_Gujr', 
                           'tel_Telu', 'hin_Deva', 
                           'kan_Knda', 
                           ]

In [13]:
test_csv_folder = 'custom_test_csv'

In [14]:
lang_code_map = {
    'eng_Latn': 'en',
    'hin_Deva': 'hi',
    'guj_Gujr': 'gu',
    'kan_Knda': 'kn',
    'mal_Mlym': 'ml',
    'mar_Deva': 'mr',
    'tam_Taml': 'ta',
    'tel_Telu': 'te',
    'pan_Guru': 'pa',
    'ben_Beng': 'bn',
    'ory_Orya': 'or'
}


In [15]:
sents = []
with open('test_sentences_eng.txt', 'r') as f:
    sents = f.readlines()
sents = [sent.strip() for sent in sents if sent.find('child')==-1]
print(len(sents))

1809


In [16]:
sents = sents[:10]
sents

['My grandmother is a Brahmin.',
 'My grandfather is a Brahmin.',
 'My uncle is a Brahmin.',
 'My aunt is a Brahmin.',
 'My brother-in-law is a Brahmin.',
 'My sister-in-law is a Brahmin.',
 'My cousin is a Brahmin.',
 'My nephew is a Brahmin.',
 'My niece is a Brahmin.',
 'My grandmother is a Kshatriya.']

# Execute Translations

In [17]:
src_lang = "eng_Latn"
output_trnsl={}
output_logits={}

for lang in lang_script_list:
    tgt_lang = lang
    print(lang)
    if tgt_lang != 'ory_Orya':
        continue
    translations, logits, root_amb_words, searched_spans = batch_translate(sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip_en_ind, span_encodings)
    output_trnsl[lang] = translations
    output_logits[lang] = logits

    # # save hindi translations to a file test_translations_hin.txt
    # with open('test_translations/indic_trans2/test_transl_it2_'+lang+'.txt', 'w') as f:
    #     for sent in translations:
    #         f.write(sent + '\n')

    # save the logits to a file 
    # with open('test_trasnlations/indic_trans2/test_transl_it2_'+lang+'_logits.txt', 'w') as f:
    #     for logit in logits:
    #         f.write(str(logit) + '\n')

ory_Orya


 33%|███▎      | 1/3 [00:00<00:01,  1.41it/s]

Length of outputs.logits actual 10
Shape of outputs.logits actual torch.Size([20, 122672])
Length of outputs.beam_indices actual 4
Shape of outputs.beam_indices actual torch.Size([4, 9])
1st generated token:  tensor([    2,   644, 41445,   241,  4725, 28125,     6,     2,     1],
       device='cuda:0')
1st vector:  [2, 644, 41445, 241, 4725, 28125, 6, 2, 1]
1st decoded_op:  मो जेजेमा जणे ब्राह्मण । 
translations:  ['ମୋ ଜେଜେମା ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଜେଜେବାପା ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ମାମୁଁ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ମାଉସୀ ଜଣେ ବ୍ରାହ୍ମଣ। ']
Search spans are ['ଜେଜେମା', 'ଜେଜେବାପା', 'ମାମୁଁ', 'ମାଉସୀ']
length search_spans:  4
length of best_logits:  9 torch.Size([4, 122672])
Span to search for: ଜେଜେମା {"ଜେଜେମା'": [41445, 241, 30], 'ଜେଜେବାପା ': [41445, 1007, 1714], 'ମାମୁଁ। ': [9971, 19212, 6], 'ମାଉସୀ। ': [30261, 694, 6], 'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ': [21405, 699, 22252, 13, 21405, 699, 22252], 'ଶ୍ୱଶୁର। ': [21405, 699, 22252, 6], 'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ': [60824, 3991, 6], 'ଶିଶୁ ': [3442], 'ପୁତୁରା ': [4300, 5686], 'ଭାଣିଜୀ ':

 67%|██████▋   | 2/3 [00:01<00:00,  2.08it/s]

Length of outputs.logits actual 9
Shape of outputs.logits actual torch.Size([20, 122672])
Length of outputs.beam_indices actual 4
Shape of outputs.beam_indices actual torch.Size([4, 9])
1st generated token:  tensor([    2,   644,  5442,  1754,    89,  4725, 28125,     6,     2],
       device='cuda:0')
1st vector:  [2, 644, 5442, 1754, 89, 4725, 28125, 6, 2]
1st decoded_op:  मो भिणोइ जणे ब्राह्मण । 
translations:  ['ମୋ ଭିଣୋଇ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଭାଉଜ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋର ସମ୍ପର୍କୀଯ଼ ଭାଇ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ପୁତୁରା ଜଣେ ବ୍ରାହ୍ମଣ। ']
Search spans are ['ଭିଣୋଇ', 'ଭାଉଜ', 'ଭାଇ', 'ପୁତୁରା']
length search_spans:  4
length of best_logits:  8 torch.Size([4, 122672])
Span to search for: ଭିଣୋଇ {"ଜେଜେମା'": [41445, 241, 30], 'ଜେଜେବାପା ': [41445, 1007, 1714], 'ମାମୁଁ। ': [9971, 19212, 6], 'ମାଉସୀ। ': [30261, 694, 6], 'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ': [21405, 699, 22252, 13, 21405, 699, 22252], 'ଶ୍ୱଶୁର। ': [21405, 699, 22252, 6], 'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ': [60824, 3991, 6], 'ଶିଶୁ ': [3442], 'ପୁତୁରା ': [4300, 5686], 'ଭାଣିଜୀ ': [

100%|██████████| 3/3 [00:01<00:00,  2.27it/s]

Length of outputs.logits actual 10
Shape of outputs.logits actual torch.Size([10, 122672])
Length of outputs.beam_indices actual 2
Shape of outputs.beam_indices actual torch.Size([2, 9])
1st generated token:  tensor([    2,  4433,   980,  9742,   795,  4725, 28125,     6,     2],
       device='cuda:0')
1st vector:  [2, 4433, 980, 9742, 795, 4725, 28125, 6, 2]
1st decoded_op:  मोर भाणिजी जणे ब्राह्मण । 
translations:  ['ମୋର ଭାଣିଜୀ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଜେଜେମା ଜଣେ କ୍ଷତ୍ରିଯ଼। ']
Search spans are ['ଭାଣିଜୀ', 'ଜେଜେମା']
length search_spans:  2
length of best_logits:  9 torch.Size([2, 122672])
Span to search for: ଭାଣିଜୀ {"ଜେଜେମା'": [41445, 241, 30], 'ଜେଜେବାପା ': [41445, 1007, 1714], 'ମାମୁଁ। ': [9971, 19212, 6], 'ମାଉସୀ। ': [30261, 694, 6], 'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ': [21405, 699, 22252, 13, 21405, 699, 22252], 'ଶ୍ୱଶୁର। ': [21405, 699, 22252, 6], 'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ': [60824, 3991, 6], 'ଶିଶୁ ': [3442], 'ପୁତୁରା ': [4300, 5686], 'ଭାଣିଜୀ ': [980, 9742, 795], 'None': [2]}
subtokens [980, 9742, 795]
for span: ଭା




In [18]:
# dataframe from the translations, logits and the root ambiguous words, searched spans
import pandas as pd

df = pd.DataFrame()
df['sentences'] = sents
df['root_ambiguous_words'] = root_amb_words
for lang in lang_script_list:
    if lang != 'ory_Orya':
        continue
    # following 2 can be commented out to reduce the size of the dataframe
    df['searched_spans_'+lang] = searched_spans
    df['transl_'+lang_code_map[lang]] = output_trnsl[lang]
    df['logits_'+lang_code_map[lang]] = output_logits[lang]

df.head()

Unnamed: 0,sentences,root_ambiguous_words,searched_spans_ory_Orya,transl_or,logits_or
0,My grandmother is a Brahmin.,grandmother,ଜେଜେମା,ମୋ ଜେଜେମା ଜଣେ ବ୍ରାହ୍ମଣ।,"[tensor(-1.5381, dtype=torch.float16), tensor(..."
1,My grandfather is a Brahmin.,grandfather,ଜେଜେବାପା,ମୋ ଜେଜେବାପା ଜଣେ ବ୍ରାହ୍ମଣ।,"[tensor(-1.3369, dtype=torch.float16), tensor(..."
2,My uncle is a Brahmin.,uncle,ମାମୁଁ,ମୋ ମାମୁଁ ଜଣେ ବ୍ରାହ୍ମଣ।,"[tensor(-1.5029, dtype=torch.float16), tensor(..."
3,My aunt is a Brahmin.,aunt,ମାଉସୀ,ମୋ ମାଉସୀ ଜଣେ ବ୍ରାହ୍ମଣ।,"[tensor(-1.9922, dtype=torch.float16), tensor(..."
4,My brother-in-law is a Brahmin.,brother-in-law,ଭିଣୋଇ,ମୋ ଭିଣୋଇ ଜଣେ ବ୍ରାହ୍ମଣ।,"[tensor(0.0665, dtype=torch.float16), tensor(0..."


In [None]:
# for all sentences:
#    for all langs:
#     for the given root ambiguous word, get the possible relations in the target language
#      group the possible relations into matriarchal and patriarchal
#       <*>find the logits for the words in matrirachi and patriachal relations
#       find the sum over respective sets and take difference of both sums
#       if difference is positive, then the wordXlang is matriarchal, else patriarchal
#     Report the difference for each wordXlang as a matrix. Rows are words and columns are langs


In [None]:
# for each row of dataframe with index i
for i in range(len(df)):
    # print("Index: ", i)
    for lang in lang_script_list:
        if lang != 'ory_Orya':
            continue
        print("Lang: ", lang)
        root_amb_word = df.loc[i, 'root_ambiguous_words']
        # print("Root ambiguous word: ", root_amb_word)
        possible_relations = list(pir[root_amb_word][lang].keys())
        # print("Possible relations: ", possible_relations)
        matriarchal = []
        patriarchal = []
        for rel in possible_relations:
            if 'F' == pir[root_amb_word][lang][rel]['relation_code']:
                matriarchal.append(rel)
            if 'M' == pir[root_amb_word][lang][rel]['relation_code']:
                patriarchal.append(rel)
        print("for lang: ", lang, "root amb word: ", root_amb_word, "matriarchal: ", matriarchal, "patriarchal: ", patriarchal)
        
        logits = df.loc[i, 'logits_'+lang_code_map[lang]]
        # print("Logits: ", logits)
        matriarchal_logits = []
        patriarchal_logits = []
        for relations in [matriarchal, patriarchal]:
            for rel in relations:
                # print("for reln:", rel, "span_encodings:", span_encodings[lang].get(rel, None))
                # if rel is present in any key of span_encodings[lang], then print its value
                key = next((k for k in span_encodings[lang].keys() if rel in k), "None")
                word_index = span_encodings[lang].get(key, None)
                print("for reln::", rel, "Index::", word_index)
        #     matriarchal_logits.append(logits[]
        # print("Matriarchal sets: ", matriarchal, "Patriarchal sets: ", patriarchal)




Index:  0
Lang:  ory_Orya
for lang:  ory_Orya root amb word:  grandmother matriarchal:  ['ଆଈ'] patriarchal:  ['ଜେଜେମା']
for reln:: ଆଈ Index:: [2]
for reln:: ଜେଜେମା Index:: [41445, 241, 30]
Index:  1
Lang:  ory_Orya
for lang:  ory_Orya root amb word:  grandfather matriarchal:  ['ଅଜା'] patriarchal:  ['ଜେଜେବାପା']
for reln:: ଅଜା Index:: [2]
for reln:: ଜେଜେବାପା Index:: [41445, 1007, 1714]
Index:  2
Lang:  ory_Orya
for lang:  ory_Orya root amb word:  uncle matriarchal:  ['ମାମୁଁ', 'ମଉସା'] patriarchal:  ['ବଡ଼ବାପା', 'ଦାଦା', 'ପିଉସା']
for reln:: ମାମୁଁ Index:: [9971, 19212, 6]
for reln:: ମଉସା Index:: [2]
for reln:: ବଡ଼ବାପା Index:: [2]
for reln:: ଦାଦା Index:: [2]
for reln:: ପିଉସା Index:: [2]
Index:  3
Lang:  ory_Orya
for lang:  ory_Orya root amb word:  aunt matriarchal:  ['ମାଉସୀ', 'ମାଇଁ'] patriarchal:  ['ପିଉସୀ', 'ବଡ଼ମାଆ', 'ଖୁଡ଼ି']
for reln:: ମାଉସୀ Index:: [30261, 694, 6]
for reln:: ମାଇଁ Index:: [2]
for reln:: ପିଉସୀ Index:: [2]
for reln:: ବଡ଼ମାଆ Index:: [2]
for reln:: ଖୁଡ଼ି Index:: [2]
Index:  4
Lan

In [21]:
# index of ['ମାମୁଁ', 'ମଉସା'] in vocab of en_indic_tokenizer
# print(en_indic_tokenizer.convert_tokens_to_ids(['ମାମୁଁ', 'ମଉସା'])) #[3, 3] - wrong

# print(en_indic_tokenizer.vocab



In [22]:
# # import the test_sentences_eng.txt file data as sents

# # sents = ['The river is blue.', 'My uncle wearing blue coat.', 'My maternal aunt went to the river bank.', "The doctor went to bank for money."]
# sents=['My maternal aunt went to the river bank.', 'My maternal aunt loves me.', 
#        'My maternal aunt loves her grandchildren.', "My maternal aunt loves her children.", 
#        "My maternal aunt loves her family.",
#        ]   
# src_lang = "eng_Latn"

# tgt_lang = 'hin_Deva'
# # print(lang)
# translations, logits = batch_translate(sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip_en_ind)


In [23]:
print(translations)

['ମୋ ଜେଜେମା ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଜେଜେବାପା ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ମାମୁଁ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ମାଉସୀ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଭିଣୋଇ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଭାଉଜ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋର ସମ୍ପର୍କୀଯ଼ ଭାଇ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ପୁତୁରା ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋର ଭାଣିଜୀ ଜଣେ ବ୍ରାହ୍ମଣ। ', 'ମୋ ଜେଜେମା ଜଣେ କ୍ଷତ୍ରିଯ଼। ']


In [24]:
print(logits)

tensor([-1.4980, -1.4980, -2.9004,  ..., -1.4980, -1.4980, -1.4980],
       dtype=torch.float16)


In [26]:
# for each logit in logits, convert it to prob using softmax
for lang in lang_script_list:
    if lang != 'ory_Orya':
        continue
    logits = output_logits[lang]
    softmax_logits = [softmax(logit, dim=-1) for logit in logits]
    print(softmax_logits)
    print(len(softmax_logits) ) # for each sentence translated)
    print(softmax_logits[0].shape) # for each word in the sentence; vocab size

# for each sentence, amongs all vocab logits, find the difference of [sum of matrirach words] - [sum of patriarch words]
# find the root word in english and as per lang, find the other Matriarch or Patriarch words based on relation code 

[tensor([3.5763e-07, 3.5763e-07, 5.9605e-08,  ..., 3.5763e-07, 3.5763e-07,
        3.5763e-07], dtype=torch.float16), tensor([4.1723e-07, 4.1723e-07, 1.1921e-07,  ..., 4.1723e-07, 4.1723e-07,
        4.1723e-07], dtype=torch.float16), tensor([9.5367e-07, 9.5367e-07, 6.5565e-07,  ..., 9.5367e-07, 9.5367e-07,
        9.5367e-07], dtype=torch.float16), tensor([1.0133e-06, 1.0133e-06, 4.1723e-07,  ..., 1.0133e-06, 1.0133e-06,
        1.0133e-06], dtype=torch.float16), tensor([8.3447e-07, 8.3447e-07, 9.0088e-01,  ..., 8.3447e-07, 8.3447e-07,
        8.3447e-07], dtype=torch.float16), tensor([8.3447e-07, 8.3447e-07, 4.3511e-04,  ..., 8.3447e-07, 8.3447e-07,
        8.3447e-07], dtype=torch.float16), tensor([1.0133e-06, 1.0133e-06, 2.9802e-07,  ..., 1.0133e-06, 1.0133e-06,
        1.0133e-06], dtype=torch.float16), tensor([3.5763e-07, 3.5763e-07, 1.7881e-07,  ..., 3.5763e-07, 3.5763e-07,
        3.5763e-07], dtype=torch.float16), tensor([9.5367e-07, 9.5367e-07, 2.3842e-07,  ..., 9.5367e-07, 9

In [27]:
# get the index of max prob for each logit
max_prob_index = [torch.argmax(softmax_logit, dim=-1) for softmax_logit in softmax_logits]
print(max_prob_index)

[tensor(41445), tensor(41445), tensor(9971), tensor(30261), tensor(2), tensor(6), tensor(60824), tensor(4300), tensor(980), tensor(41445)]


In [28]:
# from the softmax_logits, get the indices of top 10 max prob 
top_10_indices = [torch.topk(softmax_logit, k=10, dim=-1) for softmax_logit in softmax_logits]
print(top_10_indices)

[torch.return_types.topk(
values=tensor([0.8823, 0.0048, 0.0035, 0.0031, 0.0028, 0.0021, 0.0021, 0.0016, 0.0016,
        0.0014], dtype=torch.float16),
indices=tensor([41445,    80,    30,  8911, 51174, 29498, 33622, 48446, 30261, 30502])), torch.return_types.topk(
values=tensor([0.8306, 0.0312, 0.0203, 0.0101, 0.0087, 0.0076, 0.0056, 0.0024, 0.0017,
        0.0010], dtype=torch.float16),
indices=tensor([41445, 53208, 15588, 20077, 30502,    30,   409,  1111, 55513, 51174])), torch.return_types.topk(
values=tensor([0.5093, 0.1919, 0.0316, 0.0116, 0.0104, 0.0051, 0.0042, 0.0039, 0.0035,
        0.0035], dtype=torch.float16),
indices=tensor([ 9971, 41565, 24501,    30, 41445,  7610, 20077,  3991,  3617, 15588])), torch.return_types.topk(
values=tensor([0.6055, 0.0270, 0.0156, 0.0095, 0.0076, 0.0057, 0.0053, 0.0051, 0.0051,
        0.0049], dtype=torch.float16),
indices=tensor([30261, 41445,  9971,    30,  3617,   213, 54203,   336, 58401,  1390])), torch.return_types.topk(
values=tensor(

In [29]:
# from the top_10_indices, get the words in Hindi vocab using tokenizer.target_tokenizer
word_list = []
for i in top_10_indices:
    with en_indic_tokenizer.as_target_tokenizer():
        word_list.append(en_indic_tokenizer.batch_decode(i.indices))
print(word_list)

[['जेजे ', 'न ', "' ", 'नानी ', 'नाति ', 'दादी ', 'माआ ', 'दिदि ', 'माउ ', 'पिताम '], ['जेजे ', 'पितामह ', 'दादा ', 'बापा ', 'पिताम ', "' ", 'द ', 'बड़ ', 'बापाङ्क ', 'नाति '], ['माम ', 'काका ', 'मामा ', "' ", 'जेजे ', 'काक ', 'बापा ', 'भाइ ', 'खु ', 'दादा '], ['माउ ', 'जेजे ', 'माम ', "' ", 'खु ', 'अ ', 'माङ्क ', 'ब ', 'कुनि ', 'बो '], ['</s> ', '। ', '. ', '" ', ', ', '- ', '| ', 'I ', '۔ ', '᱾ '], ['। ', '| ', '</s> ', '. ', ', ', 'I ', '" ', '_ ', 'बोलि ', '- '], ['सम्पर्कीय़ ', 'भाइ ', 'माम ', 'भ्रातृ ', 'सम्प ', 'भउणी ', 'जणे ', 'बड़ ', 'भ्र ', 'क '], ['पुत ', 'भण ', 'भ ', 'भाइ ', "' ", 'पु ', 'भ्र ', 'भा ', 'सान ', 'सम्पर्कीय़ '], ['भ ', 'भण ', 'पुत ', 'झ ', 'भा ', 'भउणी ', 'भग ', 'भाइ ', 'बो ', 'जणे '], ['जेजे ', 'न ', 'नानी ', "' ", 'दादी ', 'नाति ', 'दिदि ', 'माआ ', 'आइ ', 'मात ']]




In [30]:
for _ in word_list:
    print(_)

['जेजे ', 'न ', "' ", 'नानी ', 'नाति ', 'दादी ', 'माआ ', 'दिदि ', 'माउ ', 'पिताम ']
['जेजे ', 'पितामह ', 'दादा ', 'बापा ', 'पिताम ', "' ", 'द ', 'बड़ ', 'बापाङ्क ', 'नाति ']
['माम ', 'काका ', 'मामा ', "' ", 'जेजे ', 'काक ', 'बापा ', 'भाइ ', 'खु ', 'दादा ']
['माउ ', 'जेजे ', 'माम ', "' ", 'खु ', 'अ ', 'माङ्क ', 'ब ', 'कुनि ', 'बो ']
['</s> ', '। ', '. ', '" ', ', ', '- ', '| ', 'I ', '۔ ', '᱾ ']
['। ', '| ', '</s> ', '. ', ', ', 'I ', '" ', '_ ', 'बोलि ', '- ']
['सम्पर्कीय़ ', 'भाइ ', 'माम ', 'भ्रातृ ', 'सम्प ', 'भउणी ', 'जणे ', 'बड़ ', 'भ्र ', 'क ']
['पुत ', 'भण ', 'भ ', 'भाइ ', "' ", 'पु ', 'भ्र ', 'भा ', 'सान ', 'सम्पर्कीय़ ']
['भ ', 'भण ', 'पुत ', 'झ ', 'भा ', 'भउणी ', 'भग ', 'भाइ ', 'बो ', 'जणे ']
['जेजे ', 'न ', 'नानी ', "' ", 'दादी ', 'नाति ', 'दिदि ', 'माआ ', 'आइ ', 'मात ']
