# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
import numpy as np
from torch.utils.data import DataLoader, Dataset

# -------------------------------
# 1. Dataset wrapper
# -------------------------------
class ZuluEnglishDataset(Dataset):
    """Parallel corpus of index-aligned Zulu/English sentence pairs.

    Item ``i`` is the tuple ``(zulu_sentences[i], english_sentences[i])``, so
    the two sequences must be aligned by position and equally long.

    Args:
        zulu_sentences: Indexable sequence of Zulu sentences.
        english_sentences: Indexable sequence of English sentences aligned
            with ``zulu_sentences`` by position.

    Raises:
        ValueError: If the two sequences have different lengths.
    """

    def __init__(self, zulu_sentences, english_sentences):
        # Fail fast: a length mismatch would otherwise silently drop trailing
        # English sentences (len() only looks at the Zulu side) or raise
        # IndexError part-way through an epoch.
        if len(zulu_sentences) != len(english_sentences):
            raise ValueError(
                "zulu_sentences and english_sentences must have the same length, "
                f"got {len(zulu_sentences)} and {len(english_sentences)}"
            )
        self.zulu_sentences = zulu_sentences
        self.english_sentences = english_sentences

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.zulu_sentences)

    def __getitem__(self, idx):
        """Return the ``(zulu, english)`` pair at position ``idx``."""
        return self.zulu_sentences[idx], self.english_sentences[idx]

# -------------------------------
# 2. Contrastive encoders
# -------------------------------
class ContrastiveEncoder(nn.Module):
    """Pretrained transformer backbone followed by a linear projection head.

    The hidden state at sequence position 0 (typically the [CLS] / <s>
    token, depending on the tokenizer) is projected to ``proj_dim``
    dimensions and L2-normalised, so dot products between two outputs are
    cosine similarities.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        proj_dim: Output embedding size of the projection head.
    """

    def __init__(self, model_name, proj_dim=256):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_name)
        width = backbone.config.hidden_size
        # Register the backbone before the head to keep parameter order stable.
        self.encoder = backbone
        self.projection = nn.Linear(width, proj_dim)

    def forward(self, input_ids, attention_mask):
        """Return unit-length embeddings of shape ``(batch, proj_dim)``."""
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden[:, 0]
        projected = self.projection(first_token)
        return nn.functional.normalize(projected, p=2, dim=1)

# -------------------------------
# 3. Contrastive Loss (InfoNCE)
# -------------------------------
def contrastive_loss(emb1, emb2, temperature=0.05, *, symmetric=False):
    """InfoNCE loss pairing row ``i`` of ``emb1`` with row ``i`` of ``emb2``.

    Logits are ``emb1 @ emb2.T / temperature`` (cosine similarities scaled by
    ``1 / temperature`` when the rows are L2-normalised). Each row of ``emb1``
    is classified against every row of ``emb2``: the same-index row is the
    positive and all other rows act as in-batch negatives.

    Args:
        emb1: Tensor of shape ``(batch_size, dim)``.
        emb2: Tensor of shape ``(n, dim)`` with ``n >= batch_size``. Rows
            beyond ``batch_size`` serve only as extra negatives.
        temperature: Positive scale; smaller values sharpen the softmax.
        symmetric: If True, also classify each row of ``emb2`` against
            ``emb1`` and return the mean of both directions (bidirectional,
            CLIP-style InfoNCE). Requires equal row counts. Defaults to False,
            which matches the one-directional loss.

    Returns:
        Scalar tensor holding the mean cross-entropy.

    Raises:
        ValueError: If ``temperature`` is not positive, if ``emb2`` has fewer
            rows than ``emb1``, or if ``symmetric`` is set and the row counts
            differ.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    batch_size = emb1.size(0)
    n_candidates = emb2.size(0)
    # Without this check the label for row i >= n_candidates is out of range,
    # which surfaces as an opaque error (or a device-side assert on CUDA).
    if n_candidates < batch_size:
        raise ValueError(
            f"emb2 must have at least as many rows as emb1, got {n_candidates} < {batch_size}"
        )
    if symmetric and n_candidates != batch_size:
        raise ValueError(
            f"symmetric loss requires equal row counts, got {batch_size} and {n_candidates}"
        )

    logits = torch.matmul(emb1, emb2.T) / temperature
    # Build the targets directly on the logits' device (no CPU alloc + copy).
    labels = torch.arange(batch_size, device=logits.device)
    loss = nn.functional.cross_entropy(logits, labels)
    if symmetric:
        loss = 0.5 * (loss + nn.functional.cross_entropy(logits.T, labels))
    return loss

# -------------------------------
# 4. Training loop skeleton
# -------------------------------
def train_contrastive(
    zulu_sentences,
    english_sentences,
    epochs=5,
    batch_size=16,
    proj_dim=256,
    *,
    zu_model_name="MoseliMotsoehli/zuBERTa",
    en_model_name="bert-base-uncased",
    lr=2e-5,
    max_length=128,
    temperature=0.05,
):
    """Jointly fine-tune a Zulu and an English encoder with an InfoNCE loss.

    Each batch of aligned pairs is tokenized per language, embedded by the
    corresponding ``ContrastiveEncoder``, and scored with
    ``contrastive_loss``. Both encoders share a single Adam optimizer.

    Args:
        zulu_sentences: Sequence of Zulu sentences.
        english_sentences: Sequence of English sentences aligned by index
            with ``zulu_sentences``.
        epochs: Number of passes over the data.
        batch_size: Pairs per batch; every other pair in a batch is a negative.
        proj_dim: Output size of both projection heads.
        zu_model_name: Hugging Face model id for the Zulu tokenizer/encoder.
        en_model_name: Hugging Face model id for the English tokenizer/encoder.
        lr: Adam learning rate.
        max_length: Token truncation length for both tokenizers.
        temperature: Softmax temperature passed to ``contrastive_loss``.

    Returns:
        Tuple ``(model_zu, model_en, tokenizer_zu, tokenizer_en, device)``.

    Raises:
        ValueError: If no sentence pairs are provided.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Validate the data before the (slow) tokenizer/model downloads.
    dataset = ZuluEnglishDataset(zulu_sentences, english_sentences)
    if len(dataset) == 0:
        raise ValueError("train_contrastive needs at least one sentence pair")
    # NOTE: a trailing batch holding a single pair has only one candidate, so
    # its loss is 0 and it slightly lowers the reported epoch average.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Load tokenizers and models
    tokenizer_zu = AutoTokenizer.from_pretrained(zu_model_name)
    tokenizer_en = AutoTokenizer.from_pretrained(en_model_name)
    model_zu = ContrastiveEncoder(zu_model_name, proj_dim).to(device)
    model_en = ContrastiveEncoder(en_model_name, proj_dim).to(device)

    optimizer = optim.Adam(
        list(model_zu.parameters()) + list(model_en.parameters()), lr=lr
    )

    # from_pretrained() hands back the backbone in eval mode (dropout off);
    # put both encoders, including the backbones, into training mode.
    model_zu.train()
    model_en.train()

    for epoch in range(epochs):
        total_loss = 0.0
        for zu_batch, en_batch in dataloader:
            # Tokenize each side with its own tokenizer, padded per batch.
            zu_inputs = tokenizer_zu(
                list(zu_batch),
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)
            en_inputs = tokenizer_en(
                list(en_batch),
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)

            # Forward
            emb_zu = model_zu(zu_inputs["input_ids"], zu_inputs["attention_mask"])
            emb_en = model_en(en_inputs["input_ids"], en_inputs["attention_mask"])

            # Loss
            loss = contrastive_loss(emb_zu, emb_en, temperature=temperature)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

    return model_zu, model_en, tokenizer_zu, tokenizer_en, device

# ========================
# Example usage
# ========================
# zulu_sentences, english_sentences = load_your_dataset()  # two index-aligned lists of strings
# model_zu, model_en, tokenizer_zu, tokenizer_en, device = train_contrastive(zulu_sentences, english_sentences)
