In [3]:
import torch
from transformers import MarianMTModel, MarianTokenizer

In [None]:
en_zh_df = pd.read_csv('news-commentary-v15.en-zh.tsv',sep = '\t', header=None).dropna()

In [4]:
model_name = "Helsinki-NLP/opus-mt-zh-en"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/805k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/807k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.62M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/312M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

In [6]:
# Sample data for demonstration
source_sentences = ["你好", "谢谢"]
target_sentences = ["Hello", "Thank you"]

In [9]:
source_anchor_points = ["你好", "学习"]  # Chinese anchor words
target_anchor_points = ["Hello", "study"]  # English equivalent anchor words

In [10]:
def embed_anchor_points(text, anchor_points):
    for anchor in anchor_points:
        text = text.replace(anchor, f"<<{anchor}>>")
    return text

source_sentences = [embed_anchor_points(sent, source_anchor_points) for sent in source_sentences]
target_sentences = [embed_anchor_points(sent, target_anchor_points) for sent in target_sentences]

In [7]:
from torch.utils.data import DataLoader, Dataset

# Define a custom dataset
class TranslationDataset(Dataset):
    def __init__(self, source_texts, target_texts, tokenizer):
        self.source_texts = source_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        source = self.source_texts[idx]
        target = self.target_texts[idx]
        source_enc = self.tokenizer(source, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
        target_enc = self.tokenizer(target, return_tensors="pt", padding="max_length", truncation=True, max_length=64)

        return {
            "input_ids": source_enc["input_ids"].squeeze(),
            "attention_mask": source_enc["attention_mask"].squeeze(),
            "labels": target_enc["input_ids"].squeeze()
        }

In [11]:
# Instantiate dataset and dataloader
dataset = TranslationDataset(source_sentences, target_sentences, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

In [12]:
import torch.nn.functional as F

def custom_loss_function(output_logits, target_ids, anchor_mask):
    """
    Compute custom loss with additional emphasis on anchor points.
    
    Parameters:
    - output_logits: Model's output logits
    - target_ids: Actual target token IDs
    - anchor_mask: Mask indicating anchor positions in the target sequence
    
    Returns:
    - Loss with penalties on anchor errors
    """
    loss = F.cross_entropy(output_logits.view(-1, output_logits.size(-1)), target_ids.view(-1), reduction="none")
    
    # Penalty: amplify loss for anchor point errors
    loss = loss.view(target_ids.shape) * (1 + anchor_mask * 2)  # Triple the loss weight for anchors
    return loss.mean()

In [13]:
def create_anchor_mask(tokenizer, sentences, anchor_points):
    """
    Create a mask to identify anchor point tokens in the target sequence.
    
    Parameters:
    - tokenizer: The tokenizer for the model
    - sentences: List of target sentences
    - anchor_points: List of anchor words
    
    Returns:
    - Mask tensor indicating anchor tokens
    """
    anchor_mask = []
    for sentence in sentences:
        mask = [1 if word in anchor_points else 0 for word in tokenizer.tokenize(sentence)]
        mask = mask + [0] * (64 - len(mask))  # pad to max length if needed
        anchor_mask.append(mask[:64])
    return torch.tensor(anchor_mask)

In [14]:
# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [15]:
# Training loop with custom loss function
model.train()
epochs = 3

for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        labels = batch["labels"].to(model.device)

        # Generate anchor mask
        anchor_mask = create_anchor_mask(tokenizer, target_sentences, target_anchor_points).to(model.device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        
        # Calculate custom loss with anchor points
        loss = custom_loss_function(outputs.logits, labels, anchor_mask)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

Epoch 1/3, Loss: 7.913210868835449
Epoch 2/3, Loss: 7.041543006896973
Epoch 3/3, Loss: 5.931280136108398


In [16]:
model.eval()
new_sentence = "你好，欢迎来到这里学习"  # Test sentence with anchor points
input_ids = tokenizer.encode(new_sentence, return_tensors="pt")

# Generate translation
with torch.no_grad():
    generated_ids = model.generate(input_ids)
    translation = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(f"Chinese: {new_sentence}")
print(f"English Translation: {translation}")

Chinese: 你好，欢迎来到这里学习
English Translation: Hello. Welcome to school.
