In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [15]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

In [25]:
def load_embeddings(file_path, max_vocab=100000):
    word2id, id2word, vectors = {}, {}, []
    with open(file_path, 'r', encoding='utf-8', newline='\n') as f:
        next(f)  # Skip header
        for idx, line in enumerate(f):
            if idx >= max_vocab:
                break
            tokens = line.strip().split()
            if len(tokens) < 10:
                continue
            try:
                word, vec = tokens[0], np.array(tokens[1:], dtype=np.float32)
            except:
                continue
            if vec.shape[0] != 300:
                continue
            word2id[word] = idx
            id2word[idx] = word
            vectors.append(vec)
    return torch.tensor(np.vstack(vectors), dtype=torch.float32).to(device), word2id, id2word

In [26]:
def normalize_embeddings(emb):
    emb = emb - emb.mean(dim=0, keepdim=True)
    emb = F.normalize(emb, p=2, dim=1)
    return emb

In [27]:
class Generator(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.W = nn.Parameter(torch.empty(embedding_dim, embedding_dim))
        nn.init.orthogonal_(self.W)

    def forward(self, x):
        return x @ self.W

    def orthogonalize(self, beta=0.01):
        with torch.no_grad():
            W = self.W.data
            self.W.data = (1 + beta) * W - beta * W @ W.T @ W


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.LeakyReLU(0.2),
            nn.Linear(4096, 1)
        )

    def forward(self, x):
        return self.net(x)

In [28]:
def compute_csls_similarity(src_vecs, tgt_vecs, k=10):
    src_vecs = F.normalize(src_vecs, dim=1, p=2)
    tgt_vecs = F.normalize(tgt_vecs, dim=1, p=2)

    sim_matrix = torch.matmul(src_vecs, tgt_vecs.T)
    k = min(k, src_vecs.size(0), tgt_vecs.size(0))
    src_knn_sim = torch.topk(sim_matrix, k, dim=1, largest=True).values.mean(dim=1)
    tgt_knn_sim = torch.topk(sim_matrix, k, dim=0, largest=True).values.mean(dim=0)

    csls_scores = 2 * sim_matrix - src_knn_sim[:, None] - tgt_knn_sim[None, :]
    return csls_scores

In [34]:
import math

In [38]:
def train(generator, discriminator, src_emb, tgt_emb, epochs=10, batch_size=128, lr=0.001, save_path="best_model.pth", patience=5):
    gen_opt = torch.optim.Adam(generator.parameters(), lr=lr)
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=lr * 2)
    bce_loss = nn.BCEWithLogitsLoss()

    best_loss = float("inf")
    no_improve_epochs = 0

    for epoch in range(epochs):
        generator.train()
        discriminator.train()
        epoch_loss_gen = 0
        epoch_loss_dis = 0

        for i in range(0, min(len(src_emb), len(tgt_emb)), batch_size):
            src_batch = src_emb[i:i+batch_size]
            tgt_batch = tgt_emb[i:i+batch_size]

            # ----- Train Discriminator -----
            for _ in range(5):
                gen_emb = generator(src_batch).detach()
                real = tgt_batch
                fake = gen_emb

                real_labels = torch.ones(real.size(0), 1).to(device)
                fake_labels = torch.zeros(fake.size(0), 1).to(device)

                dis_real = discriminator(real)
                dis_fake = discriminator(fake)

                loss_real = bce_loss(dis_real, real_labels)
                loss_fake = bce_loss(dis_fake, fake_labels)
                loss_dis = (loss_real + loss_fake) / 2

                dis_opt.zero_grad()
                loss_dis.backward()
                dis_opt.step()

            # ----- Train Generator -----
            gen_emb = generator(src_batch)
            pred = discriminator(gen_emb)
            loss_gen = bce_loss(pred, torch.ones_like(pred))

            gen_opt.zero_grad()
            loss_gen.backward()
            gen_opt.step()
            generator.orthogonalize()

            epoch_loss_gen += loss_gen.item()
            epoch_loss_dis += loss_dis.item()

        avg_loss_gen = epoch_loss_gen / (len(src_emb) // batch_size)
        avg_loss_dis = epoch_loss_dis / (len(src_emb) // batch_size)

        print(f"Epoch {epoch+1}: Gen Loss: {avg_loss_gen:.4f}, Dis Loss: {avg_loss_dis:.4f}")

        # ----- Early stopping & saving -----
        if avg_loss_gen < best_loss:
            best_loss = avg_loss_gen
            no_improve_epochs = 0
            print(f"✨ New best loss! Saving model to {save_path}")
            torch.save({
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
            }, save_path)
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"⏹ Early stopping at epoch {epoch+1} due to no improvement for {patience} epochs.")
                break


In [37]:
def translate_word(word, src_emb, tgt_emb, src_w2id, tgt_id2w, generator, top_k=1):
    idx = src_w2id.get(word, None)
    if idx is None:
        return ["<UNK>"]

    src_vec = src_emb[idx].unsqueeze(0)
    projected = generator(src_vec)
    csls_scores = compute_csls_similarity(projected, tgt_emb, k=10)
    best_match_ids = csls_scores[0].topk(top_k).indices.tolist()
    return [tgt_id2w[i] for i in best_match_ids]

In [31]:
en_words, hi_words  = [], []
with open("en-hi-test.txt", "r", encoding = "utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) != 2:
            continue
        en_word, hi_word = parts
        en_words.append(en_word)
        hi_words.append(hi_word)

In [39]:
if __name__ == "__main__":
    en_vecs, en_w2id, en_id2w = load_embeddings("wiki.en.vec", max_vocab=100000)
    hi_vecs, hi_w2id, hi_id2w = load_embeddings("wiki.hi.vec", max_vocab=100000)

    generator = Generator(300).to(device)
    discriminator = Discriminator(300).to(device)

    train(generator, discriminator, en_vecs, hi_vecs, epochs=50, batch_size=32, lr=0.0001, save_path="best_model.pth", patience=5)

    # Load best model
    print("🔄 Loading best saved model...")
    checkpoint = torch.load("best_model.pth", map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])

    # Evaluate
    translated_words = []
    for word in en_words:
        translated = translate_word(word, en_vecs, hi_vecs, en_w2id, hi_id2w, generator)
        translated_words.append(translated[0])

    correct = 0
    for pred, true in zip(translated_words, hi_words):
        if pred == true:
            correct += 1
    score = correct / len(translated_words)
    print(f"✅ Final Accuracy: {score:.4f}")


Epoch 1: Gen Loss: 7.6607, Dis Loss: 0.0332
✨ New best loss! Saving model to best_model.pth
Epoch 2: Gen Loss: 9.2987, Dis Loss: 0.0201
Epoch 3: Gen Loss: 9.8930, Dis Loss: 0.0171
Epoch 4: Gen Loss: 10.4937, Dis Loss: 0.0147
Epoch 5: Gen Loss: 11.0700, Dis Loss: 0.0129
Epoch 6: Gen Loss: 11.2328, Dis Loss: 0.0126
⏹ Early stopping at epoch 6 due to no improvement for 5 epochs.
🔄 Loading best saved model...
✅ Final Accuracy: 0.0000
