In [7]:
from datapull import pull_data
from datasets import DatasetDict
task = "translation"

In [8]:
english_file_path = "./train.en"
hindi_file_path = "./train.hi"


In [9]:
dataset = pull_data(english_file_path, hindi_file_path,task)

## Train-Test-Valiation Split 
print("Train-Test-Valiation Split : ")
train_test_dataset = dataset.train_test_split(test_size=0.15)
test_valid = train_test_dataset['test'].train_test_split(test_size=0.5)
raw_datasets = DatasetDict({'train': train_test_dataset['train'],
                            'test': test_valid['test'],
                            'valid': test_valid['train']})


Loading English-Hindi Translation Data : 
Train-Test-Valiation Split : 


In [10]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 12750
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1125
    })
    valid: Dataset({
        features: ['translation'],
        num_rows: 1125
    })
})

In [11]:
input_text_list = [list(hi.values())[0] for hi in  raw_datasets['valid']['translation']][:10]
gt_list = [list(hi.values())[1] for hi in  raw_datasets['valid']['translation']][:10]

In [12]:
input_text_list[2]

'We expect this number to rise further.\n'

In [13]:
len(input_text_list)

10

In [14]:
from transformers import AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, AutoTokenizer
model_path = "/home/jupyter/duplicates_detection/intl-duplicates/det_lat/test_folder/en-hi-translation-finetuned-14apr/checkpoint-76500"

# Load the model checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Define the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)



In [15]:
# Tokenize the input text
input_tokens = tokenizer.batch_encode_plus(input_text_list, max_length=512, truncation=True, return_tensors="pt",padding=True)


In [16]:
tokenizer.decode(input_tokens['input_ids'][1])

'The police have arrested both the accused and they were presented before Court.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'

In [17]:
from transformers import AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, AutoTokenizer

def evaluate_model(model_path, input_text,ground_truth,task):
    print(f"{task.upper() } Evaluation : ")
    
    # Load the model checkpoint
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    # Define the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Tokenize the input text
    input_tokens = tokenizer.batch_encode_plus(input_text, max_length=512, truncation=True, return_tensors="pt",padding=True)

    # Perform inference
    outputs = model.generate(input_ids=input_tokens["input_ids"], attention_mask=input_tokens["attention_mask"],max_length= 128,early_stopping = True)

    # Decode the generated output tokens
    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    
    pred_list = []
    for inp,out,gt in zip(input_text,output_text,ground_truth) :
        print('*'*100)
        print()
        print(f"{task.upper() } Model input : {inp} ")
        print(f"{task.upper() } Model output : {out} ")
        print(f"Ground Truth : {gt} ")
        print()
        print('*'*100)
        pred_list.append(out)
        
    return pred_list

In [18]:
generated_text = evaluate_model(model_path,input_text_list,gt_list,task)



TRANSLATION Evaluation : 




****************************************************************************************************

TRANSLATION Model input : With a very high number of COVID patients recovering every day, Indias steady trend of posting high level of daily recoveries continues.89,154 recoveries have been registered in the last 24 hours in the country
 
TRANSLATION Model output : इस साल जून में कोरोना वायरस के अनुसार, देश में अब तक 7 से अधिक लोग ठीक हुए हैं, जबकि चीन के एक नए मरीज हैं, जो भारत में अब तक कुल कोरोना से ठीक हो चुके हैं। इसके साथ ही देश में अब तक कुल 29 लोगों की मौत हो गई है। वहीं भारतीय दंड संहिता,118,51,51,118,51,51,118,51,51,118,51,51,118,51,114 लाख से अब तक देश में अब तक देश में अब तक देश 
Ground Truth : हर दिन ठीक होने वाले कोविड रोगियों की अधिक संख्या के साथ ही भारत में रोजाना अधिक संख्या में लगातार रिकवरी  भी जारी है
 

****************************************************************************************************
**************************************************************

## Evaluation Metrics

In [19]:
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate import bleu_score
import nltk
from rouge import Rouge
from pycocoevalcap.cider.cider import Cider
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/jupyter/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [20]:
def calculate_bleu_scores(actual, generated):
    smoothie = SmoothingFunction().method4
    actual_tokenized = [nltk.word_tokenize(group) for group in actual]
    generated_tokenized = [nltk.word_tokenize(gen) for gen in generated]
    
    score = corpus_bleu(actual_tokenized, generated_tokenized, smoothing_function=smoothie)
    return score

def calculate_rouge_scores(actual, generated):
    rouge = Rouge()
    scores = rouge.get_scores(generated, actual)
    return scores

def calculate_cider_scores(actual, generated):
    act_dict = {idx: [line] for idx, line in enumerate(actual)}
    gen_dict = {idx: [line] for idx, line in enumerate(generated)}
    cider = Cider()
    (score, scores) = cider.compute_score(act_dict, gen_dict)
    return score

In [22]:
calculate_bleu_scores(gt_list, generated_text)

0.0006194463600572004

In [23]:
calculate_rouge_scores(gt_list, generated_text)

[{'rouge-1': {'r': 0.38095238095238093,
   'p': 0.18604651162790697,
   'f': 0.24999999559082037},
  'rouge-2': {'r': 0.09090909090909091,
   'p': 0.037037037037037035,
   'f': 0.05263157483379533},
  'rouge-l': {'r': 0.23809523809523808,
   'p': 0.11627906976744186,
   'f': 0.15624999559082045}},
 {'rouge-1': {'r': 0.3333333333333333,
   'p': 0.3333333333333333,
   'f': 0.3333333283333334},
  'rouge-2': {'r': 0.058823529411764705, 'p': 0.05, 'f': 0.05405404908692522},
  'rouge-l': {'r': 0.2777777777777778,
   'p': 0.2777777777777778,
   'f': 0.2777777727777779}},
 {'rouge-1': {'r': 0.18181818181818182,
   'p': 0.07692307692307693,
   'f': 0.10810810392987599},
  'rouge-2': {'r': 0.1, 'p': 0.025, 'f': 0.03999999680000026},
  'rouge-l': {'r': 0.18181818181818182,
   'p': 0.07692307692307693,
   'f': 0.10810810392987599}},
 {'rouge-1': {'r': 0.375, 'p': 0.13636363636363635, 'f': 0.19999999608888894},
  'rouge-2': {'r': 0.14285714285714285,
   'p': 0.03333333333333333,
   'f': 0.054054050

In [24]:
calculate_cider_scores(gt_list, generated_text)

2.819146274310125e-13