In [23]:
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer
import sacrebleu
import pandas as pd

In [24]:
# Load pretrained mBART model and tokenizer
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBartTokenizer.from_pretrained(model_name)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBart50Tokenizer'. 
The class this function is called from is 'MBartTokenizer'.


In [25]:
# Set source and target languages
tokenizer.src_lang = "zh_CN"
tokenizer.tgt_lang = "en_XX"

In [26]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [27]:
# Load parallel sentences (zh-en) and anchor word dictionary
train_df = pd.read_csv("training_data.csv")
parallel_sentences = list(zip(train_df['zh'], train_df['en']))[0:1000]

In [28]:
# Load anchor word dictionary
anchor_word_dict = {}
with open("final_anchors.txt", "r", encoding="utf-8") as f:
    for line in f:
        zh_anchor, en_anchor = line.strip().split()
        anchor_word_dict['<'+zh_anchor+'>'] = '<' + en_anchor + '>'

In [29]:
# Load data
tokens_to_be_added = []
with open('tokens_to_be_added.txt', 'r', encoding='utf-8') as f:
    for line in f:
        tokens_to_be_added.append(line.strip())
        
# Add custom tokens and resize model embeddings
tokenizer.add_tokens(tokens_to_be_added)
model.resize_token_embeddings(len(tokenizer))


MBartScaledWordEmbedding(255669, 1024, padding_idx=1)

In [30]:
# Reward weights
lambda_bleu = 0.7
lambda_anchor = 0.3

In [31]:
def compute_reward(generated, reference, input_sentence):
    # BLEU reward (Scaled up for more impact)
    bleu_score = sacrebleu.sentence_bleu(generated, [reference]).score / 100.0
    bleu_reward = bleu_score * lambda_bleu  # Larger weight for BLEU

    # Anchor word reward (Count all matches)
    anchor_reward = sum(1.0 for zh_anchor, en_anchor in anchor_word_dict.items() if zh_anchor in input_sentence and en_anchor in generated)

    # Optionally normalize or not (for now, no normalization to make it more impactful)
    # anchor_reward = anchor_reward / max(len(anchor_word_dict), 1)  # If you want normalization

    # Weighted total reward
    total_reward = bleu_reward + lambda_anchor * anchor_reward
    print(total_reward, bleu_reward, anchor_reward)
    return total_reward

In [32]:
# RL loop setup
def fine_tune_with_rl(model, tokenizer, parallel_sentences, num_epochs=1, log_interval=100):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # Increase the learning rate

    for epoch in range(num_epochs):
        total_loss = 0
        correct_translations = 0
        total_examples = 0

        for idx, (zh_sentence, en_reference) in enumerate(parallel_sentences):
            # Tokenize input sentence
            inputs = tokenizer(zh_sentence, return_tensors="pt", truncation=True, max_length=256).to(device)

            # Generate translation with forced BOS token
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=5,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )
            generated_tokens = outputs[0]
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

            # Compute reward
            reward = compute_reward(generated_text, en_reference, zh_sentence)

            # Get model logits (without softmax applied) to calculate log probabilities
            model_outputs = model(**inputs, labels=generated_tokens.unsqueeze(0))
            logits = model_outputs.logits  # Shape: (batch_size, seq_len, vocab_size)
            
            # Get the log probabilities of the generated tokens
            log_probs = torch.log_softmax(logits, dim=-1)  # Apply softmax to logits and then take the log
            
            # Get the log probability of the generated tokens
            log_prob = 0
            for t, token in enumerate(generated_tokens):
                log_prob += log_probs[0, t, token]  # Sum log probabilities of the generated tokens

            # Compute policy loss (-reward * log_prob)
            policy_loss = -reward * log_prob

            # Backpropagation
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

            total_loss += policy_loss.item()
            total_examples += 1
            if reward > 0.8:  # Example threshold for correct translations
                correct_translations += 1

            # Logging progress
            if (idx + 1) % log_interval == 0:
                accuracy = correct_translations / total_examples
                print(f"Epoch {epoch + 1}/{num_epochs}, Step {idx + 1}/{len(parallel_sentences)}, Loss: {total_loss / total_examples:.4f}, Accuracy: {accuracy:.4f}")

        # End of epoch log
        accuracy = correct_translations / total_examples
        print(f"Epoch {epoch + 1}/{num_epochs} completed. Average Loss: {total_loss / total_examples:.4f}, Accuracy: {accuracy:.4f}")

In [33]:
fine_tune_with_rl(model, tokenizer, parallel_sentences, num_epochs=3, log_interval=5)

0.06186229357998088 0.06186229357998088 0
0.12103485505975332 0.12103485505975332 0
0.07972513809448918 0.07972513809448918 0
0.04300622298506018 0.04300622298506018 0
0.02363345370051766 0.02363345370051766 0
Epoch 1/3, Step 5/1000, Loss: 2.1475, Accuracy: 0.0000
0.06405407836463438 0.06405407836463438 0
0.036273302413397705 0.036273302413397705 0
0.02074201918486772 0.02074201918486772 0
0.011853781513057045 0.011853781513057045 0
0.009221176051421096 0.009221176051421096 0
Epoch 1/3, Step 10/1000, Loss: 1.3629, Accuracy: 0.0000
0.001288341869479308 0.001288341869479308 0
0.003970076429902533 0.003970076429902533 0
0.00011022187891041029 0.00011022187891041029 0
3.230361633773976e-05 3.230361633773976e-05 0
7.873183048546209e-09 7.873183048546209e-09 0
Epoch 1/3, Step 15/1000, Loss: 0.9128, Accuracy: 0.0000
0.0003055892584092149 0.0003055892584092149 0
0.00060731716403617 0.00060731716403617 0
0.0007615969223167628 0.0007615969223167628 0
0.0 0.0 0
0.0 0.0 0
Epoch 1/3, Step 20/1000, 

KeyboardInterrupt: 