In [1]:
%env CUDA_VISIBLE_DEVICES= 3,5,7

env: CUDA_VISIBLE_DEVICES=3,5,7


In [2]:
import os
os.environ['HF_HOME'] = '/home/sofia/cache_custom'

In [10]:
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm import tqdm
from torch.nn.functional import softmax


BATCH_SIZE = 16 # edited from 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
print(DEVICE)

cuda


In [4]:
import importlib
import possible_indic_relations as poss_indic_rel
# Reload the module to reflect changes
importlib.reload(poss_indic_rel)

pir= poss_indic_rel.possible_relations
pir

ambiguos_words = list(pir.keys())

In [5]:
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    if qconfig == None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()

    return tokenizer, model

In [6]:
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"  # ai4bharat/indictrans2-en-indic-dist-200M
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir,  quantization)

ip_en_ind = IndicProcessor(inference=True)

In [7]:
lang_script_list = [
                           'ory_Orya',
                     'pan_Guru', 'ben_Beng', 
                       'mal_Mlym',
                           'mar_Deva', 
                           'tam_Taml', 'guj_Gujr', 
                           'tel_Telu', 'hin_Deva', 
                           'kan_Knda', 
                           ]

In [11]:
# word_trl=[]
span_encodings = {}
for lang in lang_script_list:
#   if lang == 'ory_Orya':
    span_encodings[lang] = {}
    inputs = ambiguos_words
    for i in tqdm(range(0, len(inputs), BATCH_SIZE)):
        batch = inputs[i : i + BATCH_SIZE]
        print("Batch:", batch)  

        # batch = ip_en_ind.preprocess_batch(words_ids[lang].keys().tolist(), src_lang=lang, tgt_lang=lang)
        batch = ip_en_ind.preprocess_batch(batch, src_lang='eng_Latn', tgt_lang=lang)
        print("Batch:", batch)
        # # Tokenize the batch and generate input encodings
        inputs = en_indic_tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            # generated_tokens = model.generate(
            outputs = en_indic_model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1, # TODO temp
                output_scores=True,
                output_logits=True,
                return_dict_in_generate=True,

            )
            # print("Length of outputs.logits actual", len(outputs.logits))
            # print("Shape of outputs.logits actual", outputs.logits[0].shape)

            # print("Length of outputs.beam_indices actual", len(outputs.beam_indices))
            # print("Shape of outputs.beam_indices actual", outputs.beam_indices.shape)
            
            outputs.beam_indices = outputs.beam_indices.cpu()
            outputs.logits = tuple(logits.cpu() for logits in outputs.logits)               
        # Decode the generated tokens into text
        generated_tokens = outputs.sequences
        # print("len generated_tokens: ", (generated_tokens[0]).shape)
        print("1st generated token: ", generated_tokens[0])
        vector = generated_tokens.detach().cpu().tolist()
        # print("length of outputs vectors: ", len(vector), len(vector[0]))
        # print("vector of generated_tokens: ", vector)
        print("1st vector: ", vector[0])
        


        with en_indic_tokenizer.as_target_tokenizer():
            decoded_op = en_indic_tokenizer.batch_decode(
                vector,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        print("1st decoded_op: ", decoded_op[0])
        # Postprocess the translations, including entity replacement
        word_trl = ip_en_ind.postprocess_batch(decoded_op, lang=lang)

        print("translations: ", word_trl)
        for word in word_trl:
            word_index = word_trl.index(word)
            if word_index < len(vector):
                span_encodings[lang][word] = vector[word_index]
                # keep the items between '2' and '2' from span_encodings[lang][word]
                start_idx = vector[word_index].index(2)
                end_idx = vector[word_index].index(2, start_idx+1)
                span_encodings[lang][word] = vector[word_index][start_idx+1:end_idx]
            else:
                print(f"Index {word_index} out of range for vector of length {len(vector)}")


span_encodings

  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn ory_Orya grandmother', 'eng_Latn ory_Orya grandfather', 'eng_Latn ory_Orya uncle', 'eng_Latn ory_Orya aunt', 'eng_Latn ory_Orya brother-in-law', 'eng_Latn ory_Orya sister-in-law', 'eng_Latn ory_Orya cousin', 'eng_Latn ory_Orya child', 'eng_Latn ory_Orya nephew', 'eng_Latn ory_Orya niece']


100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


1st generated token:  tensor([    2, 41445,   241,    30,     2,     1,     1,     1,     1],
       device='cuda:0')
1st vector:  [2, 41445, 241, 30, 2, 1, 1, 1, 1]
1st decoded_op:  जेजेमा'
translations:  ["ଜେଜେମା'", 'ଜେଜେବାପା ', 'ମାମୁଁ। ', 'ମାଉସୀ। ', 'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ', 'ଶ୍ୱଶୁର। ', 'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ', 'ଶିଶୁ ', 'ପୁତୁରା ', 'ଭାଣିଜୀ ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn pan_Guru grandmother', 'eng_Latn pan_Guru grandfather', 'eng_Latn pan_Guru uncle', 'eng_Latn pan_Guru aunt', 'eng_Latn pan_Guru brother-in-law', 'eng_Latn pan_Guru sister-in-law', 'eng_Latn pan_Guru cousin', 'eng_Latn pan_Guru child', 'eng_Latn pan_Guru nephew', 'eng_Latn pan_Guru niece']


100%|██████████| 1/1 [00:00<00:00,  2.41it/s]


1st generated token:  tensor([    2, 29498,   640,     2,     1,     1,     1], device='cuda:0')
1st vector:  [2, 29498, 640, 2, 1, 1, 1]
1st decoded_op:  दादी मां 
translations:  ['ਦਾਦੀ ਮਾਂ ', 'ਦਾਦਾ ਜੀ ', 'ਚਾਚਾ ', 'ਮਾਸੀ ਜੀ। ', 'ਭਰਾ-ਸੱਸ ', 'ਭਰਜਾਈ ', 'ਚਚੇਰਾ ਭਰਾ ', 'ਬੱਚਾ ', 'ਭਤੀਜੇ ', 'ਭਤੀਜੀ ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn ben_Beng grandmother', 'eng_Latn ben_Beng grandfather', 'eng_Latn ben_Beng uncle', 'eng_Latn ben_Beng aunt', 'eng_Latn ben_Beng brother-in-law', 'eng_Latn ben_Beng sister-in-law', 'eng_Latn ben_Beng cousin', 'eng_Latn ben_Beng child', 'eng_Latn ben_Beng nephew', 'eng_Latn ben_Beng niece']


100%|██████████| 1/1 [00:00<00:00,  3.09it/s]


1st generated token:  tensor([    2, 48446,   241,     2,     1,     1], device='cuda:0')
1st vector:  [2, 48446, 241, 2, 1, 1]
1st decoded_op:  दिदिमा 
translations:  ['দিদিমা ', 'দাদা। ', 'চাচা ', 'আন্টি। ', 'শ্যালক ', 'শ্যালিকা ', 'চাচাত ভাই। ', 'শিশু। ', 'ভাগ্নে ', 'ভাগ্নি ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn mal_Mlym grandmother', 'eng_Latn mal_Mlym grandfather', 'eng_Latn mal_Mlym uncle', 'eng_Latn mal_Mlym aunt', 'eng_Latn mal_Mlym brother-in-law', 'eng_Latn mal_Mlym sister-in-law', 'eng_Latn mal_Mlym cousin', 'eng_Latn mal_Mlym child', 'eng_Latn mal_Mlym nephew', 'eng_Latn mal_Mlym niece']


100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


1st generated token:  tensor([    2, 18823,  2914,  2826,     2,     1,     1], device='cuda:0')
1st vector:  [2, 18823, 2914, 2826, 2, 1, 1]
1st decoded_op:  मुत्तश्शि 
translations:  ['മുത്തശ്ശി ', 'മുത്തച്ഛൻ ', 'അമ്മാവൻ ', 'അമ്മായി. ', 'ഭാര്യാസഹോദരൻ ', 'ഭാര്യാസഹോദരി ', 'കസിൻ ', 'കുട്ടി. ', 'അനന്തരവൻ ', 'അനന്തരവൾ ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn mar_Deva grandmother', 'eng_Latn mar_Deva grandfather', 'eng_Latn mar_Deva uncle', 'eng_Latn mar_Deva aunt', 'eng_Latn mar_Deva brother-in-law', 'eng_Latn mar_Deva sister-in-law', 'eng_Latn mar_Deva cousin', 'eng_Latn mar_Deva child', 'eng_Latn mar_Deva nephew', 'eng_Latn mar_Deva niece']


100%|██████████| 1/1 [00:00<00:00,  2.55it/s]


1st generated token:  tensor([    2, 32967,     2,     1,     1], device='cuda:0')
1st vector:  [2, 32967, 2, 1, 1]
1st decoded_op:  आजी 
translations:  ['आजी ', 'आजोबा. ', 'काका. ', 'मावशी. ', 'मेहुणा ', 'मेहुणी ', 'चुलत भाऊ ', 'बाळ. ', 'पुतण्या ', 'भाची ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn tam_Taml grandmother', 'eng_Latn tam_Taml grandfather', 'eng_Latn tam_Taml uncle', 'eng_Latn tam_Taml aunt', 'eng_Latn tam_Taml brother-in-law', 'eng_Latn tam_Taml sister-in-law', 'eng_Latn tam_Taml cousin', 'eng_Latn tam_Taml child', 'eng_Latn tam_Taml nephew', 'eng_Latn tam_Taml niece']


100%|██████████| 1/1 [00:00<00:00,  3.06it/s]


1st generated token:  tensor([  2, 511, 956,   2,   1], device='cuda:0')
1st vector:  [2, 511, 956, 2, 1]
1st decoded_op:  पाट्टि 
translations:  ['பாட்டி ', 'தாத்தா. ', 'மாமா ', 'அத்தை ', 'மைத்துனர் ', 'மைத்துனர் ', 'உறவினர். ', 'குழந்தை. ', 'மருமகன் ', 'மருமகள் ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn guj_Gujr grandmother', 'eng_Latn guj_Gujr grandfather', 'eng_Latn guj_Gujr uncle', 'eng_Latn guj_Gujr aunt', 'eng_Latn guj_Gujr brother-in-law', 'eng_Latn guj_Gujr sister-in-law', 'eng_Latn guj_Gujr cousin', 'eng_Latn guj_Gujr child', 'eng_Latn guj_Gujr nephew', 'eng_Latn guj_Gujr niece']


100%|██████████| 1/1 [00:00<00:00,  3.50it/s]


1st generated token:  tensor([    2, 29498,     2,     1,     1,     1], device='cuda:0')
1st vector:  [2, 29498, 2, 1, 1, 1]
1st decoded_op:  दादी 
translations:  ['દાદી ', 'દાદા ', 'કાકા ', 'કાકી ', 'સાળા ', 'સાસુ-સસરા ', 'પિતરાઇ ભાઇ ', 'બાળક ', 'ભત્રીજો ', 'ભત્રીજી ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn tel_Telu grandmother', 'eng_Latn tel_Telu grandfather', 'eng_Latn tel_Telu uncle', 'eng_Latn tel_Telu aunt', 'eng_Latn tel_Telu brother-in-law', 'eng_Latn tel_Telu sister-in-law', 'eng_Latn tel_Telu cousin', 'eng_Latn tel_Telu child', 'eng_Latn tel_Telu nephew', 'eng_Latn tel_Telu niece']


100%|██████████| 1/1 [00:00<00:00,  3.26it/s]


1st generated token:  tensor([   2, 1774, 1476,    2,    1,    1], device='cuda:0')
1st vector:  [2, 1774, 1476, 2, 1, 1]
1st decoded_op:  अम्मम्म 
translations:  ['అమ్మమ్మ ', 'తాతయ్య ', 'అంకుల్ ', 'అత్తగారు ', 'బావమరిది ', 'చెల్లెలు ', 'బంధువు ', 'పిల్లవాడు. ', 'మేనల్లుడు ', 'మేనకోడలు ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn hin_Deva grandmother', 'eng_Latn hin_Deva grandfather', 'eng_Latn hin_Deva uncle', 'eng_Latn hin_Deva aunt', 'eng_Latn hin_Deva brother-in-law', 'eng_Latn hin_Deva sister-in-law', 'eng_Latn hin_Deva cousin', 'eng_Latn hin_Deva child', 'eng_Latn hin_Deva nephew', 'eng_Latn hin_Deva niece']


100%|██████████| 1/1 [00:00<00:00,  3.76it/s]


1st generated token:  tensor([    2, 29498, 12075,     2,     1], device='cuda:0')
1st vector:  [2, 29498, 12075, 2, 1]
1st decoded_op:  दादी माँ 
translations:  ['दादी माँ ', 'दादा जी। ', 'चाचा ', 'चाची ', 'बहनोई ', 'ननद ', 'चचेरा भाई ', 'बच्चा। ', 'भतीजे ', 'भतीजी ']


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: ['grandmother', 'grandfather', 'uncle', 'aunt', 'brother-in-law', 'sister-in-law', 'cousin', 'child', 'nephew', 'niece']
Batch: ['eng_Latn kan_Knda grandmother', 'eng_Latn kan_Knda grandfather', 'eng_Latn kan_Knda uncle', 'eng_Latn kan_Knda aunt', 'eng_Latn kan_Knda brother-in-law', 'eng_Latn kan_Knda sister-in-law', 'eng_Latn kan_Knda cousin', 'eng_Latn kan_Knda child', 'eng_Latn kan_Knda nephew', 'eng_Latn kan_Knda niece']


100%|██████████| 1/1 [00:00<00:00,  3.17it/s]

1st generated token:  tensor([    2,  4565, 35330,     2,     1,     1], device='cuda:0')
1st vector:  [2, 4565, 35330, 2, 1, 1]
1st decoded_op:  अज्जि 
translations:  ['ಅಜ್ಜಿ ', 'ಅಜ್ಜ ', 'ಚಿಕ್ಕಪ್ಪ ', 'ಚಿಕ್ಕಮ್ಮ. ', 'ಸೋದರ ಸಂಬಂಧಿ ', 'ಅತ್ತಿಗೆ ', 'ಸೋದರಸಂಬಂಧಿ ', 'ಮಗು. ', 'ಸೋದರಳಿಯ ', 'ಸೋದರ ಸೊಸೆ ']





{'ory_Orya': {"ଜେଜେମା'": [41445, 241, 30],
  'ଜେଜେବାପା ': [41445, 1007, 1714],
  'ମାମୁଁ। ': [9971, 19212, 6],
  'ମାଉସୀ। ': [30261, 694, 6],
  'ଶ୍ୱଶୁର-ଶ୍ୱଶୁର ': [21405, 699, 22252, 13, 21405, 699, 22252],
  'ଶ୍ୱଶୁର। ': [21405, 699, 22252, 6],
  'ସମ୍ପର୍କୀଯ଼ ଭାଇ। ': [60824, 3991, 6],
  'ଶିଶୁ ': [3442],
  'ପୁତୁରା ': [4300, 5686],
  'ଭାଣିଜୀ ': [980, 9742, 795]},
 'pan_Guru': {'ਦਾਦੀ ਮਾਂ ': [29498, 640],
  'ਦਾਦਾ ਜੀ ': [15588, 613],
  'ਚਾਚਾ ': [34059],
  'ਮਾਸੀ ਜੀ। ': [65770, 613, 6],
  'ਭਰਾ-ਸੱਸ ': [8327, 13, 116, 19, 115],
  'ਭਰਜਾਈ ': [1144, 40842],
  'ਚਚੇਰਾ ਭਰਾ ': [49615, 2656, 8327],
  'ਬੱਚਾ ': [336, 19, 317],
  'ਭਤੀਜੇ ': [39136, 1386],
  'ਭਤੀਜੀ ': [39136, 795]},
 'ben_Beng': {'দিদিমা ': [48446, 241],
  'দাদা। ': [15588, 6],
  'চাচা ': [34059],
  'আন্টি। ': [5745, 102, 6],
  'শ্যালক ': [649, 4692, 75],
  'শ্যালিকা ': [649, 4692, 1525],
  'চাচাত ভাই। ': [35407, 359, 3991, 6],
  'শিশু। ': [3442, 6],
  'ভাগ্নে ': [291, 27065],
  'ভাগ্নি ': [291, 17389]},
 'mal_Mlym': {'മുത്തശ്ശി ': [18823, 2914

In [12]:
# write span_encoding into a file named "span_relations_encodings.json"
import json
with open("span_relations_encodings.json", "w") as f:
    json.dump(span_encodings, f)

    