In [None]:
# Pinned PyTorch-family versions.
# NOTE(review): torchtext, torchvision and torchaudio are not imported by any visible cell.
!pip install torchtext==0.8.0
!pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
# NOTE(review): ftfy, transformers and sentencepiece are unpinned; pin versions
# known to work with torch 1.7.1 so the notebook stays reproducible.
!pip install ftfy --quiet
!pip install transformers --quiet 
!pip install sentencepiece --quiet

# SQuAD v1.1 train/dev splits, downloaded into the working directory
# (the run cell reads /content/train-v1.1.json, i.e. assumes a Colab-style cwd).
! wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
! wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

In [None]:
import pandas as pd
import spacy
from spacy.lang.en import English
from tqdm.notebook import tqdm
from transformers import MarianMTModel, MarianTokenizer
from ftfy import fix_encoding
import ftfy
import json
import warnings
warnings.filterwarnings("ignore")

import torch

# Run on the first GPU when one is available, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
# torch.cuda.get_device_name(0) raises when no CUDA device exists, so only
# query it on GPU runtimes.
if device.type == "cuda":
    print(dev, torch.cuda.get_device_name(0))
else:
    print(dev)

# Multilingual English -> Romance-languages Marian model; the output language
# is selected per sentence with a ">>xx<<" prefix token.
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
marian_tokenizer = MarianTokenizer.from_pretrained(model_name)
marian_model = MarianMTModel.from_pretrained(model_name)

In [4]:
# Blank English pipeline whose only component is the rule-based sentence splitter.
nlp = English()
# spaCy v3 adds built-in components by factory name and rejects component
# objects; spaCy v2 only accepts the object returned by create_pipe.
if int(spacy.__version__.split('.')[0]) >= 3:
    nlp.add_pipe('sentencizer')
else:
    nlp.add_pipe(nlp.create_pipe('sentencizer'))
def chunkstring_spacy(text, target_lang='pt_br'):
    """
    Split text into sentences and prefix each with a Marian target-language token.

    Args:
      text: Text to be translated (converted with ``str`` first).
      target_lang: Language code placed in the ``>>code<<`` token that tells
        the multilingual Marian model which language to produce.

    Returns:
      List of ``'>>code<< sentence'`` strings, one per non-blank sentence
      found by the module-level spaCy ``nlp`` pipeline.
    """
    prefix = '>>' + target_lang + '<< '
    # Whitespace-only "sentences" carry no content; sending them to the
    # translator would only produce spurious output.
    return [prefix + sent.text for sent in nlp(str(text)).sents if sent.text.strip()]

def translate(aux_sent, max_length=512, num_beams=1):
    """
    Translate text sentence by sentence with the Marian model.

    The text is split by ``chunkstring_spacy`` (which prefixes each sentence
    with the ``>>pt_br<<`` target token) and all sentences are translated as
    one padded batch.

    Args:
      aux_sent: Text to be translated.
      max_length: Maximum token length of each tokenized input sentence
        (longer ones are truncated) and of each generated translation.
      num_beams: Beam width for generation (1 means greedy decoding).

    Returns:
      The translated sentences, each passed through ``ftfy.fix_encoding``,
      joined with single spaces; ``''`` when there is nothing to translate.
    """
    sentences = chunkstring_spacy(aux_sent)
    if not sentences:
        # Tokenizing an empty batch fails; there is nothing to translate anyway.
        return ''

    marian_model.to(device)
    marian_model.eval()
    # Padding makes the batch rectangular; the attention mask tells the model
    # to ignore the pad positions of the shorter sentences.
    batch = marian_tokenizer(sentences,
                             max_length=max_length,
                             truncation=True,
                             padding=True,
                             return_tensors='pt').to(device)

    translated = marian_model.generate(input_ids=batch['input_ids'],
                                       attention_mask=batch['attention_mask'],
                                       max_length=max_length,
                                       num_beams=num_beams,
                                       early_stopping=True,
                                       do_sample=False)
    tgt_text = [fix_encoding(marian_tokenizer.decode(t, skip_special_tokens=True)) for t in translated]
    return ' '.join(tgt_text)

def insert_dash(string, index, mode):
    """
    Insert an answer-boundary marker into a context string.

    Args:
      string: Dataset context.
      index: Character position at which the marker is inserted.
      mode: Truthy inserts the answer-start marker ``' ###### '``;
        falsy inserts the answer-end marker ``' $$$ '``.

    Returns:
      ``string`` with the marker inserted at ``index``.
    """
    marker = ' ###### ' if mode else ' $$$ '
    return string[:index] + marker + string[index:]

def get_answer_2(tgt_text):
    """
    Extract the answer span marked in a translated context.

    The answer is made of the tokens between the first whitespace-separated
    token starting with ``##`` (start marker) and the first later token
    starting with ``$$`` (end marker). Offsets are computed against the
    context as ``remove_token`` builds it (every token starting with ``##`` or
    ``$$`` dropped, remaining tokens joined by single spaces), because that is
    the context stored next to ``answer_start``. Hence
    ``context[answer_start:answer_start + len(answer)] == answer``.

    Args:
      tgt_text: Translated context containing the marker tokens.

    Returns:
      ``(answer, answer_start)``. ``answer`` is empty when the markers are
      adjacent.

    Raises:
      ValueError: If there is no start marker, or no end marker after it.
    """
    def is_marker(token):
        return token.startswith(('##', '$$'))

    tokens = tgt_text.split()
    try:
        start = next(i for i, tok in enumerate(tokens) if tok.startswith('##'))
        end = next(i for i in range(start + 1, len(tokens)) if tokens[i].startswith('$$'))
    except StopIteration:
        raise ValueError('answer markers not found in translated text') from None

    prefix = [tok for tok in tokens[:start] if not is_marker(tok)]
    answer = ' '.join(tok for tok in tokens[start + 1:end] if not is_marker(tok))
    # In the cleaned context the answer follows the joined prefix and one space.
    answer_start = (len(' '.join(prefix)) + 1) if prefix else 0
    return answer, answer_start

def remove_token(tgt_text):
    """
    Strip the answer-boundary markers from a translated context.

    Every whitespace-separated token beginning with ``##`` or ``$$`` is
    dropped and the survivors are re-joined with single spaces, so runs of
    whitespace (including newlines) collapse as a side effect.

    Args:
      tgt_text: Translated context that may contain marker tokens.

    Returns:
      The context without marker tokens.
    """
    kept = (token for token in tgt_text.split() if not token.startswith(('##', '$$')))
    return ' '.join(kept)

def translate_squad(input, output_file='/content/squad_translated.json'):
    """
    Translate a SQuAD v1.1 file to Portuguese and write the result as JSON.

    For every question, the first answer span in the English context is
    wrapped in marker tokens (``######`` before, ``$$$`` after) so the span
    can be recovered from the translated context. Each question marks a
    different span, so each translated (context, question) pair is written as
    its own paragraph entry. Questions whose markers do not survive
    translation, or whose recovered answer is empty, are skipped.

    Args:
      input: Path to the SQuAD JSON file to translate.
      output_file: Path of the translated JSON file to write.

    Returns:
      None. The translated dataset is written to ``output_file``.
    """
    print('Translating SQuAD ...')
    # Load and close the source file before the long translation loop.
    with open(input, encoding='utf-8') as f:
        document = json.load(f)

    articles = []
    skipped = 0
    for article in tqdm(document['data']):
        paragraphs = []
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                translated = _translate_qa(paragraph['context'], qa)
                if translated is None:
                    skipped += 1
                else:
                    paragraphs.append(translated)
        articles.append({'title': article['title'], 'paragraphs': paragraphs})

    # SQuAD files store the version as a string ("1.1"); keep the source's value
    # instead of writing a float.
    dataset = {'data': articles, 'version': document.get('version', '1.1')}
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(dataset, f)
    print(f'Skipped {skipped} questions whose answer span could not be recovered.')


def _translate_qa(context, qa):
    """Translate one question with its marked context; return a paragraph dict, or None to skip it."""
    answer = qa['answers'][0]
    start = answer['answer_start']
    end = start + len(answer['text'])
    # Insert the end marker first so `start` is still a valid offset for the
    # start marker (no need to account for the start marker's length).
    marked = insert_dash(context, end, False)
    marked = insert_dash(marked, start, True).replace('  ', ' ')

    tgt_text = translate(marked)
    try:
        tgt_answer, tgt_start = get_answer_2(tgt_text)
    except (IndexError, ValueError):
        # The translation lost or mangled a marker: the span is unrecoverable.
        return None
    if not tgt_answer:
        return None

    qa_out = {
        'answers': [{'answer_start': tgt_start, 'text': tgt_answer}],
        'question': translate(qa['question']),
        'id': qa['id'],
    }
    return {'context': remove_token(tgt_text), 'qas': [qa_out]}


In [6]:
# SQuAD v1.1 training split downloaded by the setup cell; avoid shadowing the
# `input` builtin with a module-level variable.
train_file = '/content/train-v1.1.json'
translate_squad(train_file)