In [61]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [62]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [70]:
def load_embeddings(file_path, max_vocab=100000, expected_dim=300):
    """Loads FastText embeddings from a file, ignoring malformed lines."""
    word2id, id2word, vectors = {}, {}, []
    
    with open(file_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        next(f)  # Skip header
        for idx, line in enumerate(f):
            if idx >= max_vocab:
                break
            tokens = line.strip().split()
            if len(tokens) != expected_dim + 1:  # 1 for the word, rest for vector
                continue  # skip bad lines
            word = tokens[0]
            vec = np.array(tokens[1:], dtype=np.float32)
            word2id[word] = idx
            id2word[idx] = word
            vectors.append(vec)
    
    embedding_tensor = torch.tensor(np.vstack(vectors), dtype=torch.float32).to(device)
    return embedding_tensor, word2id, id2word


In [71]:
class Generator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Parameter(torch.eye(dim).to(device))

    def forward(self, x):
        return x @ self.W

    def orthogonalize(self):
        with torch.no_grad():
            W_np = self.W.data.cpu().numpy()
            u, _, vt = np.linalg.svd(W_np)
            self.W.data.copy_(torch.from_numpy(u @ vt).to(self.W.device))

In [114]:
class Discriminator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(2048, 4096),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(4096, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.net(x)

In [115]:
def compute_csls_similarity(src_vecs, tgt_vecs, k=5):
    src_vecs = F.normalize(src_vecs, dim=1, p=2)
    tgt_vecs = F.normalize(tgt_vecs, dim=1, p=2)

    sim_matrix = torch.matmul(src_vecs, tgt_vecs.T)

    k = min(k, tgt_vecs.shape[0], src_vecs.shape[0])  # make sure k isn't too big

    src_knn_sim = torch.topk(sim_matrix, k, dim=1, largest=True).values.mean(dim=1)
    tgt_knn_sim = torch.topk(sim_matrix, k, dim=0, largest=True).values.mean(dim=0)

    csls_scores = 2 * sim_matrix - src_knn_sim[:, None] - tgt_knn_sim[None, :]
    return csls_scores

In [116]:
def translate_word(word, en_vecs, hi_vecs, en_w2id, hi_id2w, generator, top_k=1):
    if word not in en_w2id:
        return ["<UNK>"]
    with torch.no_grad():
        idx = en_w2id[word]
        src_vec = en_vecs[idx].unsqueeze(0)
        projected = generator(src_vec)
        csls_scores = compute_csls_similarity(projected, hi_vecs, k=10)
        best_match_ids = csls_scores[0].topk(top_k).indices.tolist()
        return [hi_id2w[i] for i in best_match_ids]

In [117]:
def train_adversarial(en_vecs, hi_vecs, generator, discriminator, epochs = 100, batch_size = 128, lr = 0.0005):

    gen_opt = torch.optim.Adam(generator.parameters(), lr = lr)
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr = lr)
    bce_loss = nn.BCELoss()

    for epoch in range(epochs):
        for _ in range(100):

            en_idx = torch.randint(0, en_vecs.shape[0], (batch_size,))
            hi_idx = torch.randint(0, hi_vecs.shape[0], (batch_size,))
            x_en = en_vecs[en_idx]
            x_hi = hi_vecs[hi_idx]


            x_gen = generator(x_en).detach()
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            dis_real = discriminator(x_hi)
            dis_fake = discriminator(x_gen)

            loss_real = bce_loss(dis_real, real_labels)
            loss_fake = bce_loss(dis_fake, fake_labels)

            dis_loss = loss_real + loss_fake
            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            x_gen = generator(x_en)
            fool_labels = torch.ones(batch_size, 1).to(device)
            gen_loss = bce_loss(discriminator(x_gen), fool_labels)

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

            generator.orthogonalize()

        print(f"Epoch {epoch+1}/{epochs} - Dis Loss: {dis_loss.item():.4f} | Gen Loss: {gen_loss.item():.4f}")

In [118]:
# Load your embeddings
en_vecs, en_w2id, en_id2w = load_embeddings("wiki.en.vec", max_vocab=100000)
hi_vecs, hi_w2id, hi_id2w = load_embeddings("wiki.hi.vec", max_vocab=100000)

# Initialize models
dim = en_vecs.shape[1]
generator = Generator(dim).to(device)
discriminator = Discriminator(dim).to(device)

# Train
train_adversarial(en_vecs, hi_vecs, generator, discriminator, epochs=10)

# Test
print("EN → HI")
for word in ["dog", "king", "water", "school", "city"]:
    print(f"{word} → {translate_word(word, en_vecs, hi_vecs, en_w2id, hi_id2w, generator)[0]}")

Epoch 1/10 - Dis Loss: 0.4733 | Gen Loss: 4.9549
Epoch 2/10 - Dis Loss: 0.4222 | Gen Loss: 4.4872
Epoch 3/10 - Dis Loss: 0.1841 | Gen Loss: 7.3492
Epoch 4/10 - Dis Loss: 0.1369 | Gen Loss: 7.3759
Epoch 5/10 - Dis Loss: 0.1471 | Gen Loss: 7.8462
Epoch 6/10 - Dis Loss: 0.1830 | Gen Loss: 8.4681
Epoch 7/10 - Dis Loss: 0.1605 | Gen Loss: 7.1727
Epoch 8/10 - Dis Loss: 0.1459 | Gen Loss: 7.7234
Epoch 9/10 - Dis Loss: 0.1592 | Gen Loss: 7.9330
Epoch 10/10 - Dis Loss: 0.1195 | Gen Loss: 6.9259
EN → HI
dog → टाइपसेटिंग
king → बकरपुर
water → विशय
school → राजपत्रित
city → नेवला


In [122]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Use MPS if available
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# ------------------------- Load Embeddings ---------------------------
def load_embeddings(file_path, max_vocab=20000):
    word2id, id2word, vectors = {}, {}, []
    with open(file_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        next(f)  # Skip header
        for idx, line in enumerate(f):
            if idx >= max_vocab:
                break
            tokens = line.strip().split()
            if len(tokens) < 10:
                continue
            word, vec = tokens[0], np.array(tokens[1:], dtype=np.float32)
            if vec.shape[0] != 300:
                continue
            word2id[word] = idx
            id2word[idx] = word
            vectors.append(vec)
    return torch.tensor(np.vstack(vectors), dtype=torch.float32).to(device), word2id, id2word

# ------------------------- Normalize Embeddings ---------------------------
def normalize_embeddings(emb):
    emb = emb - emb.mean(dim=0, keepdim=True)
    emb = F.normalize(emb, p=2, dim=1)
    return emb

# ------------------------- Generator ---------------------------
class Generator(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.W = nn.Parameter(torch.empty(embedding_dim, embedding_dim))
        nn.init.orthogonal_(self.W)

    def forward(self, x):
        return x @ self.W

    def orthogonalize(self, beta=0.01):
        with torch.no_grad():
            W = self.W.data
            self.W.data = (1 + beta) * W - beta * W @ W.T @ W

# ------------------------- Discriminator ---------------------------
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(4096, 4096),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(4096, 1)
        )

    def forward(self, x):
        return self.net(x)

# ------------------------- CSLS Similarity ---------------------------
def compute_csls_similarity(src_vecs, tgt_vecs, k=10):
    src_vecs = F.normalize(src_vecs, dim=1, p=2)
    tgt_vecs = F.normalize(tgt_vecs, dim=1, p=2)

    sim_matrix = torch.matmul(src_vecs, tgt_vecs.T)
    k = min(k, src_vecs.size(0), tgt_vecs.size(0))
    src_knn_sim = torch.topk(sim_matrix, k, dim=1, largest=True).values.mean(dim=1)
    tgt_knn_sim = torch.topk(sim_matrix, k, dim=0, largest=True).values.mean(dim=0)

    csls_scores = 2 * sim_matrix - src_knn_sim[:, None] - tgt_knn_sim[None, :]
    return csls_scores

# ------------------------- Training Loop ---------------------------
def train(generator, discriminator, src_emb, tgt_emb, epochs=10, batch_size=128, lr=0.001):
    gen_opt = torch.optim.Adam(generator.parameters(), lr=lr)
    dis_opt = torch.optim.Adam(discriminator.parameters(), lr=lr * 2)
    bce_loss = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        for i in range(0, min(len(src_emb), len(tgt_emb)), batch_size):
            src_batch = src_emb[i:i+batch_size]
            tgt_batch = tgt_emb[i:i+batch_size]

            # Discriminator
            for _ in range(5):
                gen_emb = generator(src_batch).detach()
                real = tgt_batch
                fake = gen_emb

                real_labels = torch.ones(real.size(0), 1).to(device)
                fake_labels = torch.zeros(fake.size(0), 1).to(device)

                dis_real = discriminator(real)
                dis_fake = discriminator(fake)

                loss_real = bce_loss(dis_real, real_labels)
                loss_fake = bce_loss(dis_fake, fake_labels)
                loss_dis = (loss_real + loss_fake) / 2

                dis_opt.zero_grad()
                loss_dis.backward()
                dis_opt.step()

            # Generator
            gen_emb = generator(src_batch)
            pred = discriminator(gen_emb)
            loss_gen = bce_loss(pred, torch.ones_like(pred))

            gen_opt.zero_grad()
            loss_gen.backward()
            gen_opt.step()
            generator.orthogonalize()

        print(f"Epoch {epoch+1}: Gen Loss: {loss_gen.item():.4f}, Dis Loss: {loss_dis.item():.4f}")

# ------------------------- Translate Word ---------------------------
def translate_word(word, src_emb, tgt_emb, src_w2id, tgt_id2w, generator, top_k=1):
    idx = src_w2id.get(word, None)
    if idx is None:
        return ["<UNK>"]

    src_vec = src_emb[idx].unsqueeze(0)
    projected = generator(src_vec)
    csls_scores = compute_csls_similarity(projected, tgt_emb, k=10)
    best_match_ids = csls_scores[0].topk(top_k).indices.tolist()
    return [tgt_id2w[i] for i in best_match_ids]

# ------------------------- Main Pipeline ---------------------------
if __name__ == "__main__":
    en_vecs, en_w2id, en_id2w = load_embeddings("wiki.en.vec", max_vocab=20000)
    hi_vecs, hi_w2id, hi_id2w = load_embeddings("wiki.hi.vec", max_vocab=20000)

    en_vecs = normalize_embeddings(en_vecs)
    hi_vecs = normalize_embeddings(hi_vecs)

    generator = Generator(300).to(device)
    discriminator = Discriminator(300).to(device)

    train(generator, discriminator, en_vecs, hi_vecs, epochs=50, batch_size=32, lr=0.0001)

    for word in ["dog", "king", "water", "school", "city"]:
        translated = translate_word(word, en_vecs, hi_vecs, en_w2id, hi_id2w, generator)
        print(f"{word} → {translated[0]}")


Epoch 1: Gen Loss: 9.2581, Dis Loss: 0.0130
Epoch 2: Gen Loss: 10.9131, Dis Loss: 0.0028
Epoch 3: Gen Loss: 11.0009, Dis Loss: 0.0034
Epoch 4: Gen Loss: 12.7152, Dis Loss: 0.0036
Epoch 5: Gen Loss: 12.1210, Dis Loss: 0.0017
Epoch 6: Gen Loss: 12.0103, Dis Loss: 0.0012
Epoch 7: Gen Loss: 12.5229, Dis Loss: 0.0047
Epoch 8: Gen Loss: 13.5786, Dis Loss: 0.0011
Epoch 9: Gen Loss: 12.3497, Dis Loss: 0.0008
Epoch 10: Gen Loss: 12.4586, Dis Loss: 0.0005
Epoch 11: Gen Loss: 11.1901, Dis Loss: 0.0013
Epoch 12: Gen Loss: 11.0986, Dis Loss: 0.0016
Epoch 13: Gen Loss: 12.4666, Dis Loss: 0.0008
Epoch 14: Gen Loss: 13.1001, Dis Loss: 0.0012
Epoch 15: Gen Loss: 11.4810, Dis Loss: 0.0005
Epoch 16: Gen Loss: 11.8347, Dis Loss: 0.0021
Epoch 17: Gen Loss: 14.3655, Dis Loss: 0.0001
Epoch 18: Gen Loss: 13.9018, Dis Loss: 0.0008
Epoch 19: Gen Loss: 11.9682, Dis Loss: 0.0011
Epoch 20: Gen Loss: 13.5093, Dis Loss: 0.0041
Epoch 21: Gen Loss: 15.2070, Dis Loss: 0.0035
Epoch 22: Gen Loss: 12.8941, Dis Loss: 0.000