# Translating English news articles and highlights to German with T5

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time

In [2]:
# Batch size constant; note the translation DataLoaders below use batch_size=1 directly.
BATCH_SIZE = 16

# Shuffle buffer size. `SHUFFEL_SIZE` is the original (misspelled) name and is
# kept as an alias so any existing references keep working.
SHUFFLE_SIZE = 1024
SHUFFEL_SIZE = SHUFFLE_SIZE

# Run on the first GPU when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Learning rate for the Adam optimizer created with the model.
learning_rate = 3e-5

In [3]:
print(str(device))  # shows which device was selected, e.g. "cuda:0"

cuda:0


## Load the pretrained T5 model

In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Pretrained checkpoint to load ("t5-small" is a lighter alternative).
model_size = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_size)
model = T5ForConditionalGeneration.from_pretrained(model_size)
model = model.to(device)

# Copy the checkpoint's English->German generation settings into the model
# config; this is what sets `model.config.prefix` used by the dataset below.
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    en_to_de_params = task_specific_params.get("translation_en_to_de", {})
    model.config.update(en_to_de_params)

# NOTE(review): no cell visible here uses this optimizer (translation only).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.0001,
)

In [5]:
# `task_specific_params` can be None (the loading cell guards for that), so fall
# back to an empty dict instead of raising AttributeError on `.get`.
(task_specific_params or {}).get("translation_en_to_de", {}).get("max_length"), model.config.prefix

(300, 'translate English to German: ')

In [6]:
# model.config

## Define the PyTorch dataset

In [7]:
def read_files(name, data_dir="data"):
    """Load the parallel article/highlight files of one dataset split.

    Reads ``<data_dir>/<name>/article`` and ``<data_dir>/<name>/highlights``,
    one record per line, stripping trailing whitespace from every line.

    Args:
        name: split directory name, e.g. ``"train"``.
        data_dir: root directory that holds the split directories
            (default ``"data"``, matching the original hard-coded path).

    Returns:
        Tuple ``(articles, highlights)`` of equal-length lists of strings.

    Raises:
        AssertionError: if the two files have a different number of lines
            (only when assertions are enabled).
    """
    split_dir = Path(data_dir) / name

    def _read_lines(path):
        # `with` closes the handle; UTF-8 is explicit because the text contains
        # non-ASCII characters (e.g. curly apostrophes) and the platform
        # default encoding may differ.
        with open(path, encoding="utf-8") as handle:
            return [line.rstrip() for line in handle]

    articles = _read_lines(split_dir / "article")
    highlights = _read_lines(split_dir / "highlights")

    assert len(articles) == len(highlights), (
        f"{len(articles)} articles but {len(highlights)} highlights in {split_dir}"
    )
    return articles, highlights

In [8]:
class MyDataset(torch.utils.data.Dataset):
    """Dataset of texts, each encoded sentence-by-sentence for T5 translation.

    Item ``index`` is a list with one ``(input_ids, attention_mask)`` pair of
    1-D tensors per sentence of ``self.x[index]``. Each sentence is lower-cased
    and quote-stripped (``transfrom``), prefixed with ``model.config.prefix``,
    and encoded with ``tokenizer`` padded to length 100. If the first sentence
    starts with "By" (a byline in the sample data), it is printed and skipped.

    NOTE(review): relies on module-level ``tokenizer``, ``model`` and
    ``split_in_sentences``; the latter is defined in a later cell.
    """

    def __init__(self, articles):
        self.x = articles

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        encoded_pairs = []
        for position, sentence in enumerate(split_in_sentences(self.x[index])):
            if position == 0 and sentence.startswith("By"):
                # Byline: echo it and leave it out of the encoded output.
                print(sentence)
                continue
            source_text = model.config.prefix + self.transfrom(sentence)
            encoding = tokenizer.encode_plus(source_text, max_length=100, return_tensors="pt", pad_to_max_length=True)
            # Flatten the (1, L) tensors returned for a single input to 1-D.
            ids = encoding['input_ids'].view(-1)
            mask = encoding['attention_mask'].view(-1)
            encoded_pairs.append((ids, mask))
        return encoded_pairs

    @staticmethod
    def transfrom(x):
        """Lower-case ``x`` and drop the first and last single quote on each line (greedy match)."""
        return re.sub("'(.*)'", r"\1", x.lower())

In [9]:

from segtok.segmenter import split_single

def split_in_sentences(text):
    """Split ``text`` into sentences using segtok's ``split_single``.

    The result of ``split_single`` is returned unchanged (an iterable of
    sentence strings; presumably a lazy generator — TODO confirm).
    """
    sentences = split_single(text)
    return sentences

def transfrom(x):
    """Lower-case ``x`` and remove the first and last single quote on each line.

    The pattern is greedy, so quotes belonging to different words (e.g.
    contractions) are removed as a pair. Same logic as ``MyDataset.transfrom``.
    """
    return re.sub("'(.*)'", r"\1", x.lower())

In [10]:
# Split to translate; read_files loads the article and highlights files of this split.
ds_name = "train"
articles, highlights = read_files(name=ds_name)

In [11]:
# One dataset per side; each item is a list of (input_ids, attention_mask) pairs, one per kept sentence.
articles_ds, highlights_ds = MyDataset(articles), MyDataset(highlights)

In [12]:
# Presumably batch_size=1 because items are variable-length sentence lists.
# `shuffle` stays at its default (off), so output lines follow input order.
articles_loader = torch.utils.data.DataLoader(dataset=articles_ds, batch_size=1)
highlights_loader = torch.utils.data.DataLoader(dataset=highlights_ds, batch_size=1)

In [13]:
class FileWriter:
    """Write translated records, one per line, to ``<path>/<ds_name>/<name>_german``.

    The file is opened in write mode (truncating existing content) on
    construction. Call ``close()`` or use the writer in a ``with`` block to
    release the handle.
    """

    def __init__(self, ds_name, name, path="data/"):
        # Path joining also works when `path` has no trailing slash (plain
        # string concatenation would produce e.g. "datatrain/...").
        target = Path(path) / ds_name / (name + "_german")
        # Explicit UTF-8: German output may contain characters (umlauts, ß)
        # that the platform default encoding cannot represent.
        self.file = target.open("w", encoding="utf-8")

    def write_translated(self, list_str):
        """Write one record: each item followed by a space, newlines flattened, then a line break."""
        # Trailing space after every item is kept so the output format is unchanged.
        result = "".join(item_str + " " for item_str in list_str)
        # Embedded newlines would break the one-record-per-line alignment with the input.
        self.file.write(result.replace("\n", " ") + "\n")
        # Flush every record so partial progress survives an interrupted run.
        self.file.flush()

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

In [14]:
def translate(loader, ds_name, name, log_every=100):
    """Translate every item of ``loader`` and write one output line per item.

    Each item is a list of ``(input_ids, attention_mask)`` pairs (one per
    sentence, as produced by ``MyDataset``). Every sentence is translated with
    its own ``model.generate`` call. The decoded sentences of one item are
    written as a single line through ``FileWriter`` (by default to
    ``data/<ds_name>/<name>_german``).

    Args:
        loader: iterable yielding lists of (input_ids, attention_mask) tensor pairs.
        ds_name: dataset split directory, e.g. ``"train"``.
        name: output file stem; ``"_german"`` is appended.
        log_every: print the item index every ``log_every`` items (default 100;
            0 or None disables progress output).
    """
    file_writer = FileWriter(ds_name, name)
    try:
        for i, x_list in enumerate(loader):
            if log_every and i % log_every == 0:
                print(i)

            predictions = []
            for x, x_mask in x_list:
                translations = model.generate(
                    input_ids=x.to(device), attention_mask=x_mask.to(device)
                )
                # Only the first generated sequence is kept, so only it is decoded.
                predictions.append(
                    tokenizer.decode(
                        translations[0],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )
                )
            file_writer.write_translated(predictions)
    finally:
        # Release the output handle even if the loop is interrupted (e.g.
        # KeyboardInterrupt). The handle is closed via `.file` directly, so this
        # works with any FileWriter that exposes it under that name.
        file_writer.file.close()
          
# Use `ds_name` (the split passed to read_files) so the output is written next to the input it was read from.
translate(articles_loader, ds_name, "articles")

By . Associated Press . PUBLISHED: . 14:11 EST, 25 October 2013 . | . UPDATED: . 15:36 EST, 25 October 2013 . The bishop of the Fargo Catholic Diocese in North Dakota has exposed potentially hundreds of church members in Fargo, Grand Forks and Jamestown to the hepatitis A virus in late September and early October.
0
By . Daily Mail Reporter . PUBLISHED: . 01:15 EST, 30 November 2013 . | . UPDATED: . 01:23 EST, 30 November 2013 . More than two decades after Magic Johnson announced that he had HIV, the basketball player says he is still surprised at the impact the news had.
By . Daily Mail Reporter . This is the moment a train announcer stunned passengers by announcing over a tannoy as they pulled into a station to beware of pickpockets and gipsies.
By . Ellie Zolfagharifard . Take a look at a map today, and you’re likely to see that North America is larger than Africa, Alaska is larger than Mexico and China is smaller than Greenland.
By . Margot Peppers . Nigerian and Cameroonian pop st

KeyboardInterrupt: 