# Dataloader


In [1]:
!pip install h5py
import json
import h5py




In [2]:
# ---------------- 1. SET-UP & IMPORTS ----------------
!pip install h5py tqdm --quiet   # tiny helper libs

import pandas as pd, re, json, h5py, numpy as np
from collections import Counter
from tqdm import tqdm
tqdm.pandas()

CSV_PATH   = "/kaggle/input/question-pairs-dataset/questions.csv"          # ← your upload
OUTPUT_DIR = "/kaggle/working/qqp-prepro/"          # files will appear here
MAX_SRC    = 30   # max tokens per source Q
MAX_TGT    = 30   # max tokens per target Q
MIN_FREQ   = 2    # drop singletons from vocab

PAD,  SOS, EOS = "<pad>", "<sos>", "<eos>"

# ---------------- 2. LOAD & BASIC CLEAN ----------------
df = pd.read_csv(CSV_PATH)
print(df.head(3))

# Keep only rows with both Q1 & Q2 text and (optionally) duplicate label==1
df = df.dropna(subset=["question1", "question2"])
# If you want only paraphrase pairs, uncomment:
# df = df[df["is_duplicate"] == 1]

df = df[["question1", "question2"]].reset_index(drop=True)
print("Total usable pairs:", len(df))

# ---------------- 3. SIMPLE TEXT NORMALISATION ----------------
def clean(text):
    text = str(text).lower().strip()
    text = re.sub(r"[^a-z0-9?.,' ]+", " ", text)   # keep basic chars
    text = re.sub(r"\s+", " ", text)               # collapse spaces
    return text

df["q1"] = df["question1"].progress_apply(clean)
df["q2"] = df["question2"].progress_apply(clean)

# ---------------- 4. TOKENISE ----------------
def tok(text): return text.split()                 # whitespace tokeniser
df["tok1"] = df["q1"].progress_apply(tok)
df["tok2"] = df["q2"].progress_apply(tok)

# ---------------- 5. BUILD VOCAB ----------------
all_tokens = [tok for row in df[["tok1","tok2"]].values.ravel() for tok in row]
vocab_cnt  = Counter(all_tokens)
# keep tokens with freq ≥ MIN_FREQ
vocab      = [PAD, SOS, EOS] + [w for w,c in vocab_cnt.items() if c >= MIN_FREQ]
w2idx      = {w:i for i,w in enumerate(vocab)}
idx2word   = {i:w for w,i in w2idx.items()}
print("Vocab size:", len(vocab))

# ---------------- 6. ENCODE & PAD/TRUNCATE ----------------
def encode(tok_list, max_len):
    ids = [w2idx.get(w, w2idx[PAD]) for w in tok_list][:max_len-1]
    ids = [w2idx[SOS]] + ids                         # prepend <sos>
    ids = ids + [w2idx[EOS]]
    ids = ids[:max_len]                              # truncate inc. <eos>
    pad_len = max_len - len(ids)
    return ids + [w2idx[PAD]]*pad_len, len(ids)      # ids, true length

src_ids, tgt_ids, src_len, tgt_len = [], [], [], []
for t1, t2 in tqdm(df[["tok1","tok2"]].values):
    ids1, l1 = encode(t1, MAX_SRC)
    ids2, l2 = encode(t2, MAX_TGT)
    src_ids.append(ids1); tgt_ids.append(ids2)
    src_len.append(l1);   tgt_len.append(l2)

src_ids = np.array(src_ids, dtype=np.int32)
tgt_ids = np.array(tgt_ids, dtype=np.int32)
src_len = np.array(src_len, dtype=np.int32)
tgt_len = np.array(tgt_len, dtype=np.int32)
print("Array shapes →", src_ids.shape, tgt_ids.shape)

# ---------------- 7. SAVE HDF5 ----------------
!mkdir -p "$OUTPUT_DIR"
h5_path = OUTPUT_DIR + "quora_data_prepro.h5"
with h5py.File(h5_path, "w") as h5f:
    h5f.create_dataset("source_ids", data=src_ids, compression="gzip")
    h5f.create_dataset("target_ids", data=tgt_ids, compression="gzip")
    h5f.create_dataset("source_len", data=src_len, compression="gzip")
    h5f.create_dataset("target_len", data=tgt_len, compression="gzip")
print("✅ HDF5 saved →", h5_path)

# ---------------- 8. SAVE VOCAB JSON ----------------
json_path = OUTPUT_DIR + "quora_data_prepro.json"
with open(json_path, "w") as f:
    json.dump({"ix_to_word": idx2word}, f)
print("✅ JSON saved  →", json_path)

# ---------------- 9. DONE ----------------
print("\nFinished preprocessing!  You can now mount",
      "'/kaggle/working/qqp-prepro' as input in a new notebook.")


   id  qid1  qid2                                          question1  \
0   0     1     2  What is the step by step guide to invest in sh...   
1   1     3     4  What is the story of Kohinoor (Koh-i-Noor) Dia...   
2   2     5     6  How can I increase the speed of my internet co...   

                                           question2  is_duplicate  
0  What is the step by step guide to invest in sh...             0  
1  What would happen if the Indian government sto...             0  
2  How can Internet speed be increased by hacking...             0  
Total usable pairs: 404348


100%|██████████| 404348/404348 [00:02<00:00, 152190.41it/s]
100%|██████████| 404348/404348 [00:02<00:00, 155039.34it/s]
100%|██████████| 404348/404348 [00:01<00:00, 305330.42it/s]
100%|██████████| 404348/404348 [00:01<00:00, 210205.17it/s]


Vocab size: 83223


100%|██████████| 404348/404348 [00:05<00:00, 72674.66it/s] 


Array shapes → (404348, 30) (404348, 30)
✅ HDF5 saved → /kaggle/working/qqp-prepro/quora_data_prepro.h5
✅ JSON saved  → /kaggle/working/qqp-prepro/quora_data_prepro.json

Finished preprocessing!  You can now mount '/kaggle/working/qqp-prepro' as input in a new notebook.


In [3]:
import json, h5py, numpy as np, torch
from torch.utils.data import Dataset, DataLoader, random_split

# ------------------------------------------------------------------
# 1. Paths
DATA_DIR   = "/kaggle/working/qqp-prepro"        # change if different
H5_PATH    = f"{DATA_DIR}/quora_data_prepro.h5"
JSON_PATH  = f"{DATA_DIR}/quora_data_prepro.json"

# ------------------------------------------------------------------
# 2. Load the vocab mapping (ix_to_word) so idx→word is available
with open(JSON_PATH, "r") as f:
    ix2word = json.load(f)["ix_to_word"]          # { "0":"<pad>", "1":"<sos>", ... }
idx2word   = [ix2word[str(i)] for i in range(len(ix2word))]
vocab      = {w:i for i,w in enumerate(idx2word)}  # word → index
PAD_IDX    = vocab["<pad>"]

print("Vocab size =", len(vocab))

# ------------------------------------------------------------------
# 3. Load pre-tokenised / padded arrays from HDF5
with h5py.File(H5_PATH, "r") as h5f:
    source_ids = h5f["source_ids"][:]     # shape (N, max_src_len)
    target_ids = h5f["target_ids"][:]     # shape (N, max_tgt_len)
    source_len = h5f["source_len"][:]     # shape (N,)
    target_len = h5f["target_len"][:]     # shape (N,)

print("Loaded arrays →", source_ids.shape, target_ids.shape)

# ------------------------------------------------------------------
# 4. Torch Dataset
class QQPDataset(Dataset):
    def __init__(self, src, tgt, sl, tl):
        # store numpy arrays (efficient) – convert to tensors per sample
        self.src, self.tgt, self.sl, self.tl = src, tgt, sl, tl
    def __len__(self): return len(self.src)
    def __getitem__(self, idx):
        return {
            "src":     torch.from_numpy(self.src[idx]).long(),
            "tgt":     torch.from_numpy(self.tgt[idx]).long(),
            "src_len": torch.tensor(self.sl[idx]).long(),
            "tgt_len": torch.tensor(self.tl[idx]).long()
        }

full_ds = QQPDataset(source_ids, target_ids, source_len, target_len)

# ------------------------------------------------------------------
# 5. Split into train / val
VAL_FRAC   = 0.4              # 10 % validation
val_size   = int(len(full_ds)*VAL_FRAC)
train_size = len(full_ds) - val_size
train_ds, val_ds = random_split(full_ds, [train_size, val_size],
                                generator=torch.Generator().manual_seed(42))

print(f"Train size = {len(train_ds)},  Val size = {len(val_ds)}")

# ------------------------------------------------------------------
# 6. DataLoaders
BATCH_SIZE = 32
NUM_WORKERS= 4                # 0 if running into multiprocessing issues

train_loader = DataLoader(train_ds,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          drop_last=True)

val_loader   = DataLoader(val_ds,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=NUM_WORKERS)

# ------------------------------------------------------------------
# 7. Quick sanity-check – pull one batch
batch = next(iter(train_loader))
print("Batch shapes:",
      batch["src"].shape,     # [B, max_src_len]
      batch["tgt"].shape,     # [B, max_tgt_len]
      batch["src_len"].shape) # [B]


Vocab size = 83223
Loaded arrays → (404348, 30) (404348, 30)
Train size = 242609,  Val size = 161739
Batch shapes: torch.Size([32, 30]) torch.Size([32, 30]) torch.Size([32])


### Grok


In [4]:
# import json
# import h5py
# import torch
# from torch.utils.data import Dataset, DataLoader
# import numpy as np

# # 1) Load JSON and build idx2word + vocab
# json_path = '/kaggle/input/qqp-processed/quora_data_prepro.json'
# with open(json_path, 'r') as f:
#     meta = json.load(f)

# # ix_to_word maps string indices (1-based) to tokens
# ix_to_word = {int(k): v for k, v in meta['ix_to_word'].items()}
# print(f"Original ix_to_word keys: min={min(ix_to_word.keys())}, max={max(ix_to_word.keys())}, count={len(ix_to_word)}")

# # Add <pad> and <sos> tokens
# idx2word = ['<pad>', '<sos>']  # Indices 0 and 1
# vocab = {'<pad>': 0, '<sos>': 1}
# # Shift original indices up by 2
# for idx in sorted(ix_to_word.keys()):
#     idx2word.append(ix_to_word[idx])
#     vocab[ix_to_word[idx]] = len(idx2word) - 1

# print(f"Loaded vocab of size {len(vocab)}")

# # 2) Load HDF5 file
# h5_path = '/kaggle/input/qqp-processed/quora_data_prepro.h5'
# with h5py.File(h5_path, 'r') as h5f:
#     print("HDF5 contains:", list(h5f.keys()))

#     # Load training sequences and lengths
#     src_key = 'ques1_train'
#     tgt_key = 'ques_train'
#     src_len_key = 'ques1_length_train'
#     tgt_len_key = 'ques_length_train'

#     if not all(k in h5f.keys() for k in [src_key, tgt_key, src_len_key, tgt_len_key]):
#         raise KeyError(f"Required keys {[src_key, tgt_key, src_len_key, tgt_len_key]} not found in {list(h5f.keys())}")

#     source_ids = np.array(h5f[src_key][:], dtype=np.int64)
#     target_ids = np.array(h5f[tgt_key][:], dtype=np.int64)
#     source_lens = np.array(h5f[src_len_key][:], dtype=np.int64)
#     target_lens = np.array(h5f[tgt_len_key][:], dtype=np.int64)

#     # Verify data shapes
#     print(f"Source IDs shape: {source_ids.shape}, Target IDs shape: {target_ids.shape}")
#     print(f"Source lengths shape: {source_lens.shape}, Target lengths shape: {target_lens.shape}")

#     # Adjust indices: HDF5 indices (1-based) → new indices (shifted by 2)
#     source_ids = np.where(source_ids == 0, 0, source_ids + 1)  # 0 stays <pad>, others shift
#     target_ids = np.where(target_ids == 0, 0, target_ids + 1)
#     # Clip to valid range
#     source_ids = np.clip(source_ids, 0, len(idx2word) - 1)
#     target_ids = np.clip(target_ids, 0, len(idx2word) - 1)

# # 3) Define Dataset
# class QQPDataset(Dataset):
#     def __init__(self, src_ids, tgt_ids, src_lens, tgt_lens):
#         self.src_ids = src_ids
#         self.tgt_ids = tgt_ids
#         self.src_lens = src_lens
#         self.tgt_lens = tgt_lens

#     def __len__(self):
#         return len(self.src_ids)

#     def __getitem__(self, idx):
#         return {
#             'src': torch.tensor(self.src_ids[idx], dtype=torch.long),
#             'tgt': torch.tensor(self.tgt_ids[idx], dtype=torch.long),
#             'src_len': torch.tensor(self.src_lens[idx], dtype=torch.long),
#             'tgt_len': torch.tensor(self.tgt_lens[idx], dtype=torch.long)
#         }

# # 4) Instantiate Dataset
# train_dataset = QQPDataset(source_ids, target_ids, source_lens, target_lens)

# # 5) Create Training DataLoader
# train_loader = DataLoader(
#     train_dataset,
#     batch_size=64,
#     shuffle=True,
#     num_workers=2,
#     drop_last=True
# )

# # 6) Create Validation DataLoader
# with h5py.File(h5_path, 'r') as h5f:
#     src_key = 'ques1_test'
#     tgt_key = 'ques_test'
#     src_len_key = 'ques1_length_test'
#     tgt_len_key = 'ques_length_test'

#     if not all(k in h5f.keys() for k in [src_key, tgt_key, src_len_key, tgt_len_key]):
#         raise KeyError(f"Validation keys {[src_key, tgt_key, src_len_key, tgt_len_key]} not found in {list(h5f.keys())}")

#     val_source_ids = np.array(h5f[src_key][:], dtype=np.int64)
#     val_target_ids = np.array(h5f[tgt_key][:], dtype=np.int64)
#     val_source_lens = np.array(h5f[src_len_key][:], dtype=np.int64)
#     val_target_lens = np.array(h5f[tgt_len_key][:], dtype=np.int64)

#     # Adjust indices
#     val_source_ids = np.where(val_source_ids == 0, 0, val_source_ids + 1)
#     val_target_ids = np.where(val_target_ids == 0, 0, val_target_ids + 1)
#     val_source_ids = np.clip(val_source_ids, 0, len(idx2word) - 1)
#     val_target_ids = np.clip(val_target_ids, 0, len(idx2word) - 1)

# val_dataset = QQPDataset(val_source_ids, val_target_ids, val_source_lens, val_target_lens)
# val_loader = DataLoader(
#     val_dataset,
#     batch_size=64,
#     shuffle=False,
#     num_workers=2,
#     drop_last=False
# )

# # 7) Test a batch
# for batch in train_loader:
#     print("Train batch keys:", batch.keys())
#     print("Src shape:", batch['src'].shape)
#     print("Tgt shape:", batch['tgt'].shape)
#     print("Src_len shape:", batch['src_len'].shape)
#     print("Tgt_len shape:", batch['tgt_len'].shape)
#     break

In [5]:
# import json
# import h5py

# # 1) Load the JSON and build idx2word + vocab
# json_path = '/kaggle/input/qqp-processed/quora_data_prepro.json'
# with open(json_path, 'r') as f:
#     meta = json.load(f)

# # meta only has 'ix_to_word', mapping string indices → tokens
# ix_to_word = {int(k): v for k, v in meta['ix_to_word'].items()}
# print(f"ix_to_word keys: min={min(ix_to_word.keys())}, max={max(ix_to_word.keys())}, count={len(ix_to_word)}")

# # Build a list so idx2word[i] = token, using sorted keys
# sorted_keys = sorted(ix_to_word.keys())
# idx2word = [ix_to_word[k] for k in sorted_keys]
# # And invert to get word→idx
# vocab = {w: i for i, w in enumerate(idx2word)}

# print(f"Loaded vocab of size {len(vocab)}")

# # 2) Inspect your HDF5 to see what datasets it contains
# h5_path = '/kaggle/input/qqp-processed/quora_data_prepro.h5'
# with h5py.File(h5_path, 'r') as h5f:
#     print("HDF5 contains:", list(h5f.keys()))

#     # 3) Automatically pick the two arrays for source/target sequences
#     all_keys = list(h5f.keys())
#     src_key = next((k for k in all_keys if 'source' in k.lower()), None)
#     tgt_key = next((k for k in all_keys if 'target' in k.lower()), None)

#     if src_key is None or tgt_key is None:
#         raise KeyError(f"Couldn't find 'source'/'target' in {all_keys}")

#     source_ids = h5f[src_key][:]
#     target_ids = h5f[tgt_key][:]

# # 4) Derive max lengths from their shapes
# max_src_len = source_ids.shape[1]
# max_tgt_len = target_ids.shape[1]

# print(f"Using HDF5 keys: src='{src_key}', tgt='{tgt_key}'")
# print(f"Max source length = {max_src_len},  max target length = {max_tgt_len}")

In [6]:
# json_path = '/kaggle/input/qqp-processed/quora_data_prepro.json'
# with open(json_path, 'r') as f:
#     meta = json.load(f)
# # Extract key info
# vocab      = meta['w2idx']        # word→index mapping
# idx2word   = meta['idx2w']        # list of tokens by index
# max_src_len = meta['max_src_len']
# max_tgt_len = meta['max_tgt_len']

# print(f"Vocab size: {len(vocab)}, max lengths: {max_src_len}/{max_tgt_len}")

In [7]:
# h5_path = '/kaggle/input/qqp-processed/quora_data_prepro.h5'
# with h5py.File(h5_path, 'r') as h5f:
#     print("Available datasets:", list(h5f.keys()))
#     source_ids = h5f['source_ids'][:]   # shape (N, max_src_len)
#     target_ids = h5f['target_ids'][:]   # shape (N, max_tgt_len)
#     source_lens = h5f['source_len'][:]  # actual lengths
#     target_lens = h5f['target_len'][:]
# print("Loaded:", source_ids.shape, target_ids.shape)

In [8]:
# import torch
# from torch.utils.data import Dataset

# class QQPDataset(Dataset):
#     def __init__(self, src_ids, tgt_ids, src_lens, tgt_lens):
#         self.src_ids   = src_ids
#         self.tgt_ids   = tgt_ids
#         self.src_lens  = src_lens
#         self.tgt_lens  = tgt_lens

#     def __len__(self):
#         return len(self.src_ids)

#     def __getitem__(self, idx):
#         return {
#             'src':     torch.tensor(self.src_ids[idx], dtype=torch.long),
#             'tgt':     torch.tensor(self.tgt_ids[idx], dtype=torch.long),
#             'src_len': torch.tensor(self.src_lens[idx], dtype=torch.long),
#             'tgt_len': torch.tensor(self.tgt_lens[idx], dtype=torch.long)
#         }

# # Instantiate
# dataset = QQPDataset(source_ids, target_ids, source_lens, target_lens)


In [9]:
# from torch.utils.data import DataLoader

# loader = DataLoader(
#     dataset,
#     batch_size=64,
#     shuffle=True,
#     num_workers=2,
#     drop_last=True
# )


# Generator

### Original

In [10]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class ParaphraseGenerator(nn.Module):
#     def __init__(self,
#                  vocab_size: int,
#                  emb_dim: int,
#                  enc_hidden: int,
#                  dec_hidden: int,
#                  latent_dim: int,
#                  max_tgt_len: int,
#                  pad_idx: int,
#                  dropout: float = 0.1):
#         """
#         Args:
#             vocab_size: size of the embedding vocabulary
#             emb_dim: dimensionality of word embeddings
#             enc_hidden: hidden size of the encoder GRU
#             dec_hidden: hidden size of the decoder GRU
#             latent_dim: size of the latent vector z
#             max_tgt_len: maximum target sequence length
#             pad_idx: padding token index
#             dropout: dropout probability
#         """
#         super().__init__()
#         # Embedding layer
#         self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
#         # Encoder: bidirectional GRU
#         self.encoder = nn.GRU(emb_dim,
#                               enc_hidden,
#                               batch_first=True,
#                               bidirectional=True)
#         # Project encoder’s bidirectional hidden to decoder hidden
#         self.enc_to_dec = nn.Linear(enc_hidden * 2 + latent_dim, dec_hidden)
#         # Decoder: unidirectional GRU
#         self.decoder = nn.GRU(emb_dim,
#                               dec_hidden,
#                               batch_first=True)
#         # Output projection to vocabulary
#         self.out_proj = nn.Linear(dec_hidden, vocab_size)
#         self.dropout = nn.Dropout(dropout)
#         self.max_tgt_len = max_tgt_len
#         self.latent_dim = latent_dim

#     def forward(self,
#                 src_ids: torch.Tensor,
#                 src_lens: torch.Tensor,
#                 tgt_ids: torch.Tensor = None,
#                 teacher_forcing_ratio: float = 0.5):
#         """
#         Args:
#             src_ids: [B, T_src] source token IDs
#             src_lens: [B] actual lengths of each source sequence
#             tgt_ids: [B, T_tgt] target token IDs (for teacher forcing)
#             teacher_forcing_ratio: probability to use ground-truth token
#         Returns:
#             logits: [B, T_tgt, V] pre-softmax scores over vocab
#         """
#         batch_size = src_ids.size(0)
#         device = src_ids.device

#         # 1) Embed and pack source sequences
#         embedded_src = self.dropout(self.embedding(src_ids))
#         packed = nn.utils.rnn.pack_padded_sequence(
#             embedded_src, src_lens.cpu(), batch_first=True, enforce_sorted=False)
#         # 2) Encode
#         enc_outputs, enc_hidden = self.encoder(packed)
#         # enc_hidden: [2, B, enc_hidden] from bidirectional GRU
#         # Concatenate forward & backward final states → [B, 2*enc_hidden]
#         enc_hidden = torch.cat([enc_hidden[0], enc_hidden[1]], dim=-1)

#         # 3) Sample latent code z ∼ N(0, I)
#         z = torch.randn(batch_size, self.latent_dim, device=device)

#         # 4) Initialize decoder hidden: project [enc_hidden; z] → [B, dec_hidden]
#         dec_init = torch.tanh(self.enc_to_dec(torch.cat([enc_hidden, z], dim=1)))
#         dec_hidden = dec_init.unsqueeze(0)  # [1, B, dec_hidden]

#         # 5) Decode step-by-step with optional teacher forcing
#         # Prepare first input token: assume <sos> token is index 1
#         inputs = torch.full((batch_size, 1), 1, dtype=torch.long, device=device)
#         logits = []

#         for t in range(self.max_tgt_len):
#             emb_t = self.dropout(self.embedding(inputs))  # [B, 1, emb_dim]
#             out, dec_hidden = self.decoder(emb_t, dec_hidden)
#             # out: [B, 1, dec_hidden]
#             step_logits = self.out_proj(out.squeeze(1))  # [B, V]
#             logits.append(step_logits.unsqueeze(1))       # accumulate

#             # Next input: either ground-truth or greedy sample
#             if tgt_ids is not None and torch.rand(1).item() < teacher_forcing_ratio:
#                 inputs = tgt_ids[:, t].unsqueeze(1)  # teacher forcing
#             else:
#                 inputs = step_logits.argmax(dim=-1).unsqueeze(1)

#         # Concatenate logits: [B, T_tgt, V]
#         return torch.cat(logits, dim=1)


### Grok Mod

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParaphraseGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        emb_dim: int,
        enc_hidden: int,
        dec_hidden: int,
        latent_dim: int,
        max_tgt_len: int,
        pad_idx: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.encoder = nn.GRU(emb_dim, enc_hidden, batch_first=True, bidirectional=True)
        self.enc_to_dec = nn.Linear(enc_hidden * 2 + latent_dim, dec_hidden)
        self.decoder = nn.GRU(emb_dim, dec_hidden, batch_first=True)
        self.out_proj = nn.Linear(dec_hidden, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.max_tgt_len = max_tgt_len
        self.latent_dim = latent_dim
        self.pad_idx = pad_idx
        self.sos_idx = 1  # Hard-coded as per original assumption

    def forward(
        self,
        src_ids: torch.Tensor,
        src_lens: torch.Tensor,
        tgt_ids: torch.Tensor = None,
        z: torch.Tensor = None,
        return_repr: bool = False,
        tf_ratio: float = 0.5
    ):
        batch_size = src_ids.size(0)
        device = src_ids.device

        # 1) Embed and pack source sequences
        embedded_src = self.dropout(self.embedding(src_ids))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded_src, src_lens.cpu(), batch_first=True, enforce_sorted=False
        )

        # 2) Encode
        enc_outputs, enc_hidden = self.encoder(packed)
        enc_hidden = torch.cat([enc_hidden[0], enc_hidden[1]], dim=-1)  # [B, 2*enc_hidden]

        # 3) Use provided z or sample new one
        if z is None:
            z = torch.randn(batch_size, self.latent_dim, device=device)

        # 4) Initialize decoder hidden
        dec_init = torch.tanh(self.enc_to_dec(torch.cat([enc_hidden, z], dim=1)))
        dec_hidden = dec_init.unsqueeze(0)  # [1, B, dec_hidden]

        # 5) Decode
        inputs = torch.full((batch_size, 1), self.sos_idx, dtype=torch.long, device=device)
        logits = []
        outputs = []

        for t in range(self.max_tgt_len):
            emb_t = self.dropout(self.embedding(inputs))  # [B, 1, emb_dim]
            out, dec_hidden = self.decoder(emb_t, dec_hidden)  # out: [B, 1, dec_hidden]
            step_logits = self.out_proj(out.squeeze(1))  # [B, V]
            logits.append(step_logits.unsqueeze(1))
            outputs.append(out.squeeze(1))  # Accumulate for repr

            # Next input
            if tgt_ids is not None and torch.rand(1).item() < tf_ratio:
                inputs = tgt_ids[:, t].unsqueeze(1) if t < tgt_ids.size(1) else inputs
            else:
                inputs = step_logits.argmax(dim=-1).unsqueeze(1)

        logits = torch.cat(logits, dim=1)  # [B, T_tgt, V]
        repr_out = torch.stack(outputs, dim=1)  # [B, T_tgt, dec_hidden]

        if return_repr:
            return logits, repr_out
        return logits

# Example usage in training loop
# logits, repr = model_G(src, src_len, z=z1, return_repr=True, tf_ratio=0.5)

# Discriminator

### Original

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, num_filters, filter_sizes, pad_idx, dropout=0.1):
        """
        CNN-based sentence encoder using multiple filter widths and max-over-time pooling.
        References:
          – Kim (2014): CNNs for sentence classification :contentReference[oaicite:6]{index=6}
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        # Convolutional filters of sizes e.g. [3,4,5] :contentReference[oaicite:7]{index=7}
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=emb_dim,
                      out_channels=num_filters,
                      kernel_size=fs)
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
          x: [B, T] token IDs
        Returns:
          emb: [B, num_filters * len(filter_sizes)] pooled features
        """
        # Embed tokens → [B, T, emb_dim], then to [B, emb_dim, T] for conv1d :contentReference[oaicite:8]{index=8}
        emb = self.embedding(x).permute(0, 2, 1)
        # Apply each conv → ReLU → max-over-time pooling :contentReference[oaicite:9]{index=9}
        pooled = [
            F.max_pool1d(F.relu(conv(emb)), conv(emb).shape[2]).squeeze(2)
            for conv in self.convs
        ]
        cat = torch.cat(pooled, dim=1)
        return self.dropout(cat)

class Discriminator(nn.Module):
    def __init__(self, vocab_size, emb_dim, num_filters, filter_sizes, pad_idx, dropout=0.1):
        """
        DivGAN Discriminator: q(x,y) = σ(w[C(x);C(y)] + b)
        As per Cao & Wan (2020) §3.1.2 :contentReference[oaicite:10]{index=10}.
        """
        super().__init__()
        # Shared CNN encoder for source & paraphrase :contentReference[oaicite:11]{index=11}
        self.encoder = CNNEncoder(vocab_size, emb_dim, num_filters, filter_sizes, pad_idx, dropout)
        feat_dim = num_filters * len(filter_sizes)
        # Linear scoring layer → binary quality score :contentReference[oaicite:12]{index=12}
        self.fc = nn.Linear(feat_dim * 2, 1)

    def forward(self, src_ids, tgt_ids):
        """
        Args:
          src_ids: [B, T_src] source token IDs
          tgt_ids: [B, T_tgt] paraphrase token IDs
        Returns:
          prob: [B, 1] q(x,y) ∈ [0,1]
        """
        # Encode both sentences with shared CNN :contentReference[oaicite:13]{index=13}
        src_feat = self.encoder(src_ids)
        tgt_feat = self.encoder(tgt_ids)
        # Concatenate and score → sigmoid :contentReference[oaicite:14]{index=14}
        combined = torch.cat([src_feat, tgt_feat], dim=1)
        logit = self.fc(combined)
        return torch.sigmoid(logit)


# Lossses

In [13]:
import torch
import torch.nn.functional as F

def discriminator_loss(D_real: torch.Tensor,
                       D_fake: torch.Tensor) -> torch.Tensor:
    """
    Adversarial loss for the discriminator:
      – D_real: discriminator output for real (x, y) pairs, shape [B,1], values in (0,1)
      – D_fake: discriminator output for generated (x, ŷ) pairs, shape [B,1]
    Returns:
      scalar loss = BCE(D_real, 1) + BCE(D_fake, 0)
    """
    real_labels = torch.ones_like(D_real)
    fake_labels = torch.zeros_like(D_fake)
    loss_real = F.binary_cross_entropy(D_real, real_labels)
    loss_fake = F.binary_cross_entropy(D_fake, fake_labels)
    return loss_real + loss_fake

def generator_adversarial_loss(D_fake: torch.Tensor) -> torch.Tensor:
    """
    Adversarial loss for the generator:
      – D_fake: discriminator output for generated (x, ŷ) pairs, shape [B,1]
    Returns:
      scalar loss = BCE(D_fake, 1)
    """
    real_labels = torch.ones_like(D_fake)
    return F.binary_cross_entropy(D_fake, real_labels)
    
## Original
# def diversity_loss(y1_repr: torch.Tensor,
#                    y2_repr: torch.Tensor,
#                    z1: torch.Tensor,
#                    z2: torch.Tensor,
#                    lambda_div: float = 1.0,
#                    eps: float = 1e-8) -> torch.Tensor:
#     """
#     Hinge-style diversity loss (Eq.6 in DivGAN):
#       y1_repr, y2_repr: [B, D] representation vectors of two samples ŷ^(1), ŷ^(2)
#       z1, z2:           [B, L] latent codes used to generate them
#       lambda_div:       slack margin λ
#     Returns:
#       mean over batch of max(λ − ||y1−y2||2 / (||z1−z2||2+eps), 0)
#     """
#     # L2 distances
#     dist_y = (y1_repr - y2_repr).norm(p=2, dim=1)        # [B]
#     dist_z = (z1   - z2  ).norm(p=2, dim=1).clamp(min=eps)  # [B]
#     # hinge penalty
#     hinge = F.relu(lambda_div - dist_y / dist_z)
#     return hinge.mean()

def diversity_loss(y1_repr: torch.Tensor, y2_repr: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor, lambda_div: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    """
    Hinge-style diversity loss (Eq.6 in DivGAN):
      y1_repr, y2_repr: [B, T_tgt, dec_hidden] representation vectors of two samples ŷ^(1), ŷ^(2)
      z1, z2:           [B, latent_dim] latent codes used to generate them
      lambda_div:       slack margin λ
      eps:              small constant to avoid division by zero
    Returns:
      mean over batch of max(λ − ||y1−y2||2 / (||z1−z2||2+eps), 0)
    """
    # Mean pool over sequence dimension: [B, T_tgt, dec_hidden] → [B, dec_hidden]
    y1_repr = y1_repr.mean(dim=1)
    y2_repr = y2_repr.mean(dim=1)
    # L2 distances
    dist_y = (y1_repr - y2_repr).norm(p=2, dim=1)  # [B]
    dist_z = (z1 - z2).norm(p=2, dim=1).clamp(min=eps)  # [B]
    # Hinge penalty
    hinge = F.relu(lambda_div - dist_y / dist_z)
    return hinge.mean()

# Evaluation Metrices

In [None]:
!pip install sacrebleu bert-score

import sacrebleu
from bert_score import score as bert_score

def compute_bleu4(references, hypotheses):
    """
    Compute corpus-level BLEU-4 using sacrebleu.
    Args:
      references: List[str] of ground-truth sentences.
      hypotheses: List[str] of generated sentences.
    Returns:
      bleu4_score: float
    """
    if not references or not hypotheses:
        return 0.0
    # Filter out empty strings
    refs = [r for r in references if r.strip()]
    hyps = [h for h in hypotheses if h.strip()]
    if not refs or not hyps:
        return 0.0
    # Wrap references for sacrebleu
    refs = [refs]
    bleu = sacrebleu.corpus_bleu(hyps, refs)
    return bleu.score  # BLEU score in percentage


# def compute_bertscore(references, hypotheses, lang='en', rescale_with_baseline=True):
#     """
#     Compute average BERTScore F1 across the corpus.
#     Args:
#       references: List[str] of ground-truth sentences.
#       hypotheses: List[str] of generated sentences.
#       lang: language for BERTScore ('en' for English).
#       rescale_with_baseline: whether to apply the authors’ baseline rescaling.
#     Returns:
#       P, R, F1: floats (precision, recall, F1)
#     """
#     P, R, F1 = bert_score(hypotheses, references, lang=lang, 
#                            rescale_with_baseline=rescale_with_baseline)
#     # bert_score returns torch tensors; convert to floats
#     return P.mean().item(), R.mean().item(), F1.mean().item()


Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m898.3 kB/s[0m eta [36m0:00:00[0m0:01[0m
[?25hCollecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.0->

# Training Loop

### Original

In [None]:
# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader
# from torch.optim import Adam

# # --- 1. Hyperparameters & Setup ---
# device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vocab_size  = len(vocab)         # from your quora_data_prepro.json
# emb_dim     = 300
# enc_hidden  = 256
# dec_hidden  = 256
# latent_dim  = 64
# max_tgt_len = max_tgt_len        # from your metadata
# pad_idx     = vocab['<pad>']     # or whatever pad token index is
# num_filters = 100
# filter_sizes= [3,4,5]
# dropout     = 0.1
# γ_div       = 1.0                # weight for diversity loss
# lr_G        = 1e-4
# lr_D        = 1e-4
# num_epochs  = 20
# batch_size  = 64

# # --- 2. DataLoader (you’ve already done this) ---
# dataset = QQPDataset(source_ids, target_ids, source_lens, target_lens)
# loader  = DataLoader(dataset,
#                      batch_size=batch_size,
#                      shuffle=True,
#                      num_workers=2,
#                      drop_last=True)

# # --- 3. Model Instantiation ---
# model_G = ParaphraseGenerator(vocab_size, emb_dim,
#                               enc_hidden, dec_hidden,
#                               latent_dim, max_tgt_len,
#                               pad_idx, dropout).to(device)
# model_D = Discriminator(vocab_size, emb_dim,
#                         num_filters, filter_sizes,
#                         pad_idx, dropout).to(device)

# # --- 4. Optimizers ---
# opt_G = Adam(model_G.parameters(), lr=lr_G, betas=(0.5, 0.999))
# opt_D = Adam(model_D.parameters(), lr=lr_D, betas=(0.5, 0.999))

# # --- 5. Training Loop ---
# for epoch in range(1, num_epochs+1):
#     model_G.train()
#     model_D.train()
#     total_D_loss = 0.0
#     total_G_loss = 0.0

#     for batch in loader:
#         # Move to device
#         src     = batch['src'].to(device)      # [B, T_src]
#         tgt     = batch['tgt'].to(device)      # [B, T_tgt]
#         src_len = batch['src_len'].to(device)
#         tgt_len = batch['tgt_len'].to(device)

#         B = src.size(0)

#         # --- 5.1 Generate two samples for diversity ---
#         z1 = torch.randn(B, latent_dim, device=device)
#         z2 = torch.randn(B, latent_dim, device=device)

#         # Modify your generator to return (logits, repr)
#         logits1, repr1 = model_G(src, src_len, z=z1, return_repr=True)
#         logits2, repr2 = model_G(src, src_len, z=z2, return_repr=True)

#         # Greedy decode for discriminator input
#         pred1 = logits1.argmax(dim=-1)
#         pred2 = logits2.argmax(dim=-1)

#         # --- 5.2 Discriminator update ---
#         opt_D.zero_grad()
#         D_real = model_D(src, tgt)
#         D_fake = model_D(src, pred1.detach())
#         loss_D = discriminator_loss(D_real, D_fake)
#         loss_D.backward()
#         opt_D.step()
#         total_D_loss += loss_D.item()

#         # --- 5.3 Generator update ---
#         opt_G.zero_grad()
#         # Fool the discriminator
#         D_fake_for_G = model_D(src, pred1)
#         loss_G_adv   = generator_adversarial_loss(D_fake_for_G)
#         # Diversity penalty
#         loss_div     = diversity_loss(repr1, repr2, z1, z2, lambda_div=1.0)
#         # (Optional) you could add a teacher-forcing MLE loss here
#         loss_G       = loss_G_adv + γ_div * loss_div
#         loss_G.backward()
#         opt_G.step()
#         total_G_loss += loss_G.item()

#     avg_D = total_D_loss / len(loader)
#     avg_G = total_G_loss / len(loader)
#     print(f"Epoch {epoch}/{num_epochs} → D_loss: {avg_D:.4f}, G_loss: {avg_G:.4f}")

#     # --- 5.4 Validation & Metrics (per epoch) ---
#     model_G.eval()
#     hyps, refs = [], []
#     with torch.no_grad():
#         for batch in val_loader:  # assume you created a val_loader similarly
#             src     = batch['src'].to(device)
#             src_len = batch['src_len'].to(device)
#             logits  = model_G(src, src_len, z=None, teacher_forcing_ratio=0.0)
#             preds   = logits.argmax(dim=-1).cpu().tolist()
#             targets = batch['tgt'].cpu().tolist()

#             # Convert IDs → words, strip pads
#             for p, t in zip(preds, targets):
#                 hyps.append(" ".join(idx2word[idx] for idx in p if idx!=pad_idx))
#                 refs.append(" ".join(idx2word[idx] for idx in t if idx!=pad_idx))

#     bleu4 = compute_bleu4(refs, hyps)
#     _, _, bert_f1 = compute_bertscore(refs, hyps)
#     print(f"  → Val BLEU-4: {bleu4:.2f}, BERTScore F1: {bert_f1:.4f}")



In [None]:
# # Number of epochs to train
# NUM_EPOCHS = 10

# for epoch in range(1, NUM_EPOCHS + 1):
#     # --- Training ---
#     G.train(); D.train()
#     total_D_loss, total_G_loss = 0.0, 0.0

#     for batch in loader:
#         src, tgt = batch['src'].to(DEVICE), batch['tgt'].to(DEVICE)
#         src_len, tgt_len = batch['src_len'].to(DEVICE), batch['tgt_len'].to(DEVICE)
#         B = src.size(0)

#         # Sample two latent codes for diversity
#         z1 = torch.randn(B, LATENT_DIM, device=DEVICE)
#         z2 = torch.randn(B, LATENT_DIM, device=DEVICE)

#         # Generator forward + repr
#         logits1, repr1 = G(src, src_len, z=z1, return_repr=True)
#         logits2, repr2 = G(src, src_len, z=z2, return_repr=True)
#         fake = logits1.argmax(dim=-1)

#         # Discriminator update
#         optD.zero_grad()
#         D_real = D(src, tgt)
#         D_fake = D(src, fake.detach())
#         loss_D = discriminator_loss(D_real, D_fake)
#         loss_D.backward()
#         optD.step()
#         total_D_loss += loss_D.item()

#         # Generator update
#         optG.zero_grad()
#         Df = D(src, fake)
#         loss_G_adv = generator_adversarial_loss(Df)
#         loss_div  = diversity_loss(repr1, repr2, z1, z2)
#         loss_G    = loss_G_adv + GAMMA_DIV * loss_div
#         loss_G.backward()
#         optG.step()
#         total_G_loss += loss_G.item()

#     avg_D = total_D_loss / len(loader)
#     avg_G = total_G_loss / len(loader)
#     print(f"Epoch {epoch}/{NUM_EPOCHS} — Train D_loss: {avg_D:.4f}, G_loss: {avg_G:.4f}")

#     # --- Validation ---
#     G.eval()
#     all_refs, all_hyps = [], []

#     with torch.no_grad():
#         for batch in val_loader:
#             src = batch['src'].to(DEVICE)
#             src_len = batch['src_len'].to(DEVICE)

#             # Generate with greedy decoding
#             logits = G(src, src_len, z=None, return_repr=False, tf_ratio=0.0)
#             preds  = logits.argmax(dim=-1).cpu().tolist()
#             targets= batch['tgt'].cpu().tolist()

#             for p, t in zip(preds, targets):
#                 # convert ID lists to strings, stripping PAD_IDX
#                 hyp = " ".join(idx2word[idx] for idx in p if idx != PAD_IDX)
#                 ref = " ".join(idx2word[idx] for idx in t if idx != PAD_IDX)
#                 all_hyps.append(hyp)
#                 all_refs.append(ref)

#     bleu4 = compute_bleu4(all_refs, all_hyps)
#     _, _, bert_f1 = compute_bertscore(all_refs, all_hyps)
#     print(f"           Val BLEU-4: {bleu4:.2f}, BERTScore F1: {bert_f1:.4f}")

# print("✅ Training complete!") 


### Grok Mod

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
import sacrebleu
from bert_score import score as bert_score

# Loss functions
def discriminator_loss(D_real: torch.Tensor, D_fake: torch.Tensor, label_smoothing: float = 0.1) -> torch.Tensor:
    real_labels = torch.full_like(D_real, 1.0 - label_smoothing)
    fake_labels = torch.full_like(D_fake, label_smoothing)
    loss_real = F.binary_cross_entropy(D_real, real_labels)
    loss_fake = F.binary_cross_entropy(D_fake, fake_labels)
    return loss_real + loss_fake

def generator_adversarial_loss(D_fake: torch.Tensor) -> torch.Tensor:
    real_labels = torch.ones_like(D_fake)
    return F.binary_cross_entropy(D_fake, real_labels)

def diversity_loss(y1_repr: torch.Tensor, y2_repr: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor, lambda_div: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    y1_repr = y1_repr.mean(dim=1)
    y2_repr = y2_repr.mean(dim=1)
    dist_y = (y1_repr - y2_repr).norm(p=2, dim=1)
    dist_z = (z1 - z2).norm(p=2, dim=1).clamp(min=eps)
    hinge = F.relu(lambda_div - dist_y / dist_z)
    return hinge.mean()

# # Metrics
# def compute_bleu4(references, hypotheses):
#     if not references or not hypotheses:
#         return 0.0
#     refs = [r for r in references if r.strip()]
#     hyps = [h for h in hypotheses if h.strip()]
#     if not refs or not hyps:
#         return 0.0
#     refs = [refs]
#     bleu = sacrebleu.corpus_bleu(hyps, refs)
#     return bleu.score

# def compute_bertscore(references, hypotheses, lang='en', rescale_with_baseline=True):
#     P, R, F1 = bert_score(hypotheses, references, lang=lang, rescale_with_baseline=rescale_with_baseline)
#     return P.mean().item(), R.mean().item(), F1.mean().item()

# Hyperparameters & Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(vocab)  # 27697
emb_dim = 300
enc_hidden = 256
dec_hidden = 256
latent_dim = 64
max_tgt_len = 25  # Will adjust after inspection
pad_idx = vocab['<pad>']  # 0
num_filters = 100
filter_sizes = [3, 4, 5]
dropout = 0.1
gamma_div = 1.0
lr_G = 1e-4
lr_D = 1e-4
num_epochs = 10

# Inspect data (add to data loading code or run separately)
# print(f"Max target length: {target_ids.shape[1]}")
# print(f"Source IDs min: {source_ids.min()}, max: {source_ids.max()}")
# print(f"Target IDs min: {target_ids.min()}, max: {target_ids.max()}")

# Models
model_G = ParaphraseGenerator(
    vocab_size, emb_dim, enc_hidden, dec_hidden, latent_dim, max_tgt_len, pad_idx, dropout
).to(device)
model_D = Discriminator(
    vocab_size, emb_dim, num_filters, filter_sizes, pad_idx, dropout
).to(device)

# Optimizers
opt_G = Adam(model_G.parameters(), lr=lr_G, betas=(0.5, 0.999))
opt_D = Adam(model_D.parameters(), lr=lr_D, betas=(0.5, 0.999))

# Training Loop
print("Training Starting")
for epoch in range(1, num_epochs + 1):
    model_G.train()
    model_D.train()
    total_D_loss = 0.0
    total_G_loss = 0.0

    for batch_idx, batch in enumerate(train_loader):
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)
        src_len = batch['src_len'].to(device)
        tgt_len = batch['tgt_len'].to(device)
        B = src.size(0)

        z1 = torch.randn(B, latent_dim, device=device)
        z2 = torch.randn(B, latent_dim, device=device)

        # Generator forward (no teacher forcing for GAN steps)
        logits1, repr1 = model_G(src, src_len, z=z1, return_repr=True, tf_ratio=0.0)
        logits2, repr2 = model_G(src, src_len, z=z2, return_repr=True, tf_ratio=0.0)
        pred1 = logits1.argmax(dim=-1)
        pred2 = logits2.argmax(dim=-1)


        # Debug: Print sample predictions
        if batch_idx == 0 and epoch == 1:
            sample_pred = pred1[0].cpu().tolist()
            sample_tgt = tgt[0].cpu().tolist()
            print("Sample pred:", " ".join(idx2word[idx] for idx in sample_pred if idx != pad_idx))
            print("Sample ref:", " ".join(idx2word[idx] for idx in sample_tgt if idx != pad_idx))

        opt_D.zero_grad()
        D_real = model_D(src, tgt)
        D_fake = model_D(src, pred1.detach())
        loss_D = discriminator_loss(D_real, D_fake, label_smoothing=0.1)
        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(model_D.parameters(), max_norm=1.0)
        opt_D.step()
        total_D_loss += loss_D.item()

        opt_G.zero_grad()
        D_fake_for_G = model_D(src, pred1)
        loss_G_adv = generator_adversarial_loss(D_fake_for_G)
        loss_div = diversity_loss(repr1, repr2, z1, z2, lambda_div=1.0)
        loss_G = loss_G_adv + gamma_div * loss_div
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(model_G.parameters(), max_norm=1.0)
        opt_G.step()
        total_G_loss += loss_G.item()

    avg_D = total_D_loss / len(train_loader)
    avg_G = total_G_loss / len(train_loader)
    print(f"Epoch {epoch}/{num_epochs} → Train D_loss: {avg_D:.4f}, G_loss: {avg_G:.4f}")

    model_G.eval()
    all_refs, all_hyps = [], []
    
    with torch.no_grad():
        for batch in val_loader:
            src     = batch['src'].to(device)
            src_len = batch['src_len'].to(device)
    
            # Greedy decode – no teacher forcing, no latent z (sample internally)
            logits = model_G(src, src_len, z=None, return_repr=False, tf_ratio=0.0)
            preds  = logits.argmax(dim=-1).cpu().tolist()
            targets= batch['tgt'].cpu().tolist()
    
            for p, t in zip(preds, targets):
                hyp = " ".join(idx2word[idx] for idx in p if idx != pad_idx)
                ref = " ".join(idx2word[idx] for idx in t if idx != pad_idx)
                all_hyps.append(hyp)
                all_refs.append(ref)

# BLEU-4 only
bleu4 = compute_bleu4(all_refs, all_hyps)
print(f"  → Val BLEU-4: {bleu4:.2f}")