In [81]:
#imports
import datasets
from datasets import load_dataset, load_metric
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, DataCollatorForSeq2Seq
import torch
import wandb

In [93]:
model_checkpoint = 't5-small'

# Number of WMT14 training rows to pull per language pair. Keep these small while
# iterating on the pipeline; raise them (or drop the slice) for a real run.
n_fr_en_rows = 10
n_ru_en_rows = 100  # we'll augment ru-en with backtranslation soon

# Load the data *before* building the master dict below: building the dict first
# only worked when these variables were left over from an earlier kernel session.
fr_en_raw_dataset = load_dataset('wmt14', 'fr-en', split=f'train[:{n_fr_en_rows}]')
ru_en_raw_dataset = load_dataset('wmt14', 'ru-en', split=f'train[:{n_ru_en_rows}]')

# Master dict keyed by language pair - the setup for multi-way translation.
# preprocess_fn is responsible for turning each pair's rows into model inputs.
all_langs = datasets.DatasetDict(
    {
        'fr-en': fr_en_raw_dataset,
        'ru-en': ru_en_raw_dataset,
    }
)

# Language code -> language name, used to build the T5 task prefix
# ("translate from <src> to <tgt>: ").
code_to_lang = {
    'fr': 'French',
    'en': 'English',
    'ru': 'Russian',
}

Reusing dataset wmt14 (/home/tina/.cache/huggingface/datasets/wmt14/fr-en/1.0.0/6aa64c5c4f2c1c217718c6d6266aad92d1229e761c57379c53752b8c0e55c93b)
Reusing dataset wmt14 (/home/tina/.cache/huggingface/datasets/wmt14/ru-en/1.0.0/6aa64c5c4f2c1c217718c6d6266aad92d1229e761c57379c53752b8c0e55c93b)


In [72]:
# Truncation is requested per call (see the tokenizer calls in preprocess_fn);
# it is not a tokenizer constructor setting, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
prefix = ""
max_input_length = 128
max_target_length = 128
source_lang = 'fr'
target_lang = 'en'


def preprocess_fn(examples):
    """Tokenize a batch of translation rows in one direction (source_lang -> target_lang).

    Args:
        examples: a batched `datasets` slice whose 'translation' column holds
            dicts keyed by language code (must contain `source_lang` and
            `target_lang`).

    Returns:
        The tokenizer output for the prefixed source sentences, plus a
        'labels' entry holding the token ids of the target sentences.

    NOTE(review): a later cell redefines preprocess_fn for multi-way
    translation; whichever definition ran last is the one `map` uses.
    """
    inputs = [prefix + ex[source_lang] for ex in examples['translation']]
    targets = [ex[target_lang] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    # Only the tokenization needs the target-tokenizer context.
    model_inputs['labels'] = labels['input_ids']

    return model_inputs

In [58]:
from datasets import Dataset

# Peek at one raw fr-en row to see the order of the language-code keys in its
# 'translation' dict; the value displayed below is the second key.
t = all_langs['fr-en']['translation'][0]
[*t.keys()][1]

'fr'

In [26]:
# Anyway, the fr_en dataset has over 4 million examples, so we'll just take 500 for now
sm_fr_en_raw = load_dataset('wmt14', 'fr-en', split='train[:500]')

# but this only takes the training data, so we need to split it up again.
# A fixed seed keeps the train/test split identical across kernel restarts.
sm_fr_en = sm_fr_en_raw.train_test_split(test_size=0.2, seed=42)
sm_fr_en['train']

Reusing dataset wmt14 (/home/tina/.cache/huggingface/datasets/wmt14/fr-en/1.0.0/6aa64c5c4f2c1c217718c6d6266aad92d1229e761c57379c53752b8c0e55c93b)


Dataset({
    features: ['translation'],
    num_rows: 400
})

In [94]:
def set_prefix(ex):
    """Build the T5 task prefix for a single translation row.

    The row's two keys, taken in iteration order, are treated as the
    (source, target) language codes and mapped to language names through
    `code_to_lang`, giving e.g. "translate from English to French: ".
    """
    source_code, target_code = ex.keys()
    source_name, target_name = code_to_lang[source_code], code_to_lang[target_code]
    return f"translate from {source_name} to {target_name}: "

def preprocess_fn(examples):
    """Tokenize a batch of translation pairs in BOTH directions for multi-way T5 training.

    Each row of ``examples['translation']`` is a dict with two language-code
    keys (for the fr-en data the key order is 'en', then 'fr'). For every row
    this emits two training examples, each with its own
    "translate from X to Y: " prefix:

    * first key -> second key
    * second key -> first key

    Over the fr-en and ru-en datasets this covers fr<->en and ru<->en. The
    previous version repeated the same first->second direction four times.
    The output batch has twice as many rows as the input, so the caller's
    ``map`` must drop the original columns (``remove_columns=...``).

    Returns:
        The tokenizer output for the prefixed source sentences, plus a
        'labels' entry holding the target-sentence token ids.
    """
    inputs = []
    targets = []

    for ex in examples['translation']:
        first, second = ex.keys()
        for src, tgt in ((first, second), (second, first)):
            # set_prefix reads the (source, target) order from the dict's key
            # order, so hand it a dict ordered src-first.
            inputs.append(set_prefix({src: ex[src], tgt: ex[tgt]}) + ex[src])
            targets.append(ex[tgt])

    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs['labels'] = labels['input_ids']

    return model_inputs

In [95]:
# Tokenize every language pair in all_langs. preprocess_fn may return a different
# number of rows than it receives, so the raw columns are dropped during map.
sm_tokenized = all_langs.map(
    preprocess_fn,
    batched=True,
    remove_columns=all_langs["fr-en"].column_names,
)

# instantiate the pretrained seq2seq model we'll fine-tune
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

Loading cached processed dataset at /home/tina/.cache/huggingface/datasets/wmt14/fr-en/1.0.0/6aa64c5c4f2c1c217718c6d6266aad92d1229e761c57379c53752b8c0e55c93b/cache-babb7a4083aff490.arrow
Loading cached processed dataset at /home/tina/.cache/huggingface/datasets/wmt14/ru-en/1.0.0/6aa64c5c4f2c1c217718c6d6266aad92d1229e761c57379c53752b8c0e55c93b/cache-4f4c0795af5528b9.arrow


In [6]:
# params for training
batch_size = 8
learning_rate = 1e-4
weight_decay = 0.001
num_train_epochs = 8
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# collator and dataloaders; passing `model` lets the collator also build
# decoder_input_ids from the labels
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# sm_tokenized from all_langs is keyed by language pair ('fr-en', 'ru-en'), not by
# split, so pool every pair and carve out a seeded train/test split. If it already
# carries train/test splits (e.g. when built from sm_fr_en), use those as-is.
if {'train', 'test'} <= set(sm_tokenized.keys()):
    sm = sm_tokenized
else:
    sm = datasets.concatenate_datasets(list(sm_tokenized.values())).train_test_split(
        test_size=0.2, seed=42
    )

# we need to get rid of any cols the model can't work with (e.g. a leftover raw
# 'translation' column); removing a column that is already gone would raise, so
# only drop what is actually present
model_columns = {'input_ids', 'attention_mask', 'labels'}
sm = sm.remove_columns([col for col in sm['train'].column_names if col not in model_columns])

train_dataloader = torch.utils.data.DataLoader(sm['train'], shuffle=True, batch_size=batch_size, collate_fn=data_collator)
test_dataloader = torch.utils.data.DataLoader(sm['test'], batch_size=batch_size, collate_fn=data_collator)

In [7]:
# one last helper fn to deal with postprocessing

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Returns a tuple (predictions, references): the predictions as a flat list
    of strings, and each reference wrapped in its own one-element list, the
    shape sacrebleu expects (multiple references per prediction allowed).
    """
    cleaned_preds = list(p.strip() for p in preds)
    wrapped_refs = list([ref.strip()] for ref in labels)
    return cleaned_preds, wrapped_refs

In [None]:
# ok, training time!
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
metric = load_metric('sacrebleu')

project_name = ""
eval_every = 500

# Pass None (wandb's default) rather than an empty string when no project is set.
run = wandb.init(project=project_name or None)

# Label value ignored by the loss; DataCollatorForSeq2Seq pads labels with it.
label_ignore_id = -100


def evaluate_bleu(model, dataloader):
    """Translate every batch in `dataloader` and score it with sacrebleu.

    Returns:
        dict with 'bleu' (sacrebleu corpus score) and 'gen_len' (average number
        of tokens, re-tokenized with special tokens, per generated prediction).
        The model is put back into train mode before returning.
    """
    n_generated_tokens = 0
    model.eval()

    with torch.no_grad():
        for eval_batch in dataloader:
            eval_labels = eval_batch['labels'].to(device)
            input_ids = eval_batch['input_ids'].to(device)
            attention_mask = eval_batch['attention_mask'].to(device)

            # Without an explicit max_length, generate() uses the generation
            # config default (20 tokens unless overridden), truncating translations.
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=5,
                max_length=max_target_length,
            )

            # -100 is not a real token id; map it back to pad so it can be decoded
            eval_labels = torch.where(eval_labels != label_ignore_id, eval_labels, tokenizer.pad_token_id)

            decoded_preds = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            decoded_labels = tokenizer.batch_decode(
                eval_labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )

            for pred in decoded_preds:
                n_generated_tokens += len(tokenizer(pred)['input_ids'])

            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)

    model.train()
    eval_metric = metric.compute()
    return {
        'bleu': eval_metric['score'],
        'gen_len': n_generated_tokens / max(len(dataloader.dataset), 1),
    }


model.to(device)
global_step = 0

for epoch in range(num_train_epochs):
    model.train()
    for batch in train_dataloader:
        labels = batch['labels'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # make sure padding tokens aren't counted in loss fn (the collator already
        # pads with -100; this also catches any pad ids inside the labels)
        labels[labels == tokenizer.pad_token_id] = label_ignore_id

        output = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        loss = output['loss']  # since we're working with a pretrained model, we can just grab the loss directly

        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm (no underscore) is deprecated; the in-place variant is the supported API
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        global_step += 1

        wandb.log(
            {
                'train_loss': loss.item(),
                'epoch': epoch,
            },
            step=global_step,
        )

        if global_step % 10 == 0:  # log every 10 steps, feel free to change
            # accuracy doesn't matter much for our model, but feel free to look anyway
            preds = output['logits'].argmax(-1)
            # Padded label positions hold -100 at this point (never pad_token_id),
            # so mask on -100; otherwise padding inflates the denominator.
            label_nonpad_mask = labels != label_ignore_id
            num_words_in_batch = label_nonpad_mask.sum().item()
            acc = (preds == labels).masked_select(label_nonpad_mask).sum().item() / max(num_words_in_batch, 1)

            wandb.log({'train_batch_word_accuracy': acc}, step=global_step)

        if global_step % eval_every == 0:  # you can change this at the top of the block
            wandb.log(evaluate_bleu(model, test_dataloader), step=global_step)

# On small datasets training can finish before reaching eval_every steps; always
# log a final score unless the last step already triggered an eval.
if global_step % eval_every != 0:
    wandb.log(evaluate_bleu(model, test_dataloader), step=global_step)
