<a href="https://colab.research.google.com/github/pentius00/-personal-finacial-planner/blob/main/Paragraph_Multimodal_MoE_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# **FIX 1: Import torch.amp explicitly**
from torch.cuda.amp import GradScaler, autocast
import torch.amp # <-- Added this

from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import random
import math
import os
import gc

# Helper: Clear CUDA cache
def clear_cache():
    """Free cached CUDA allocator memory (if a GPU exists), then run Python's GC."""
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        torch.cuda.empty_cache()
    gc.collect()

# ================================================================================
# 1. MODEL ARCHITECTURE
# (Re-created from the earlier training logs: sparse MoE, character-level input.
#  Those logs reported ~223M parameters; the config in main() builds ~208M.)
# ================================================================================

class Expert(nn.Module):
    """Feed-forward expert: Linear(d -> 4d) -> ReLU -> Dropout -> Linear(4d -> d).

    The layer order inside ``self.net`` (indices 0..3) determines the
    state_dict keys, so it must stay fixed for saved checkpoints to load.
    """

    def __init__(self, embed_dim, dropout):
        super().__init__()
        inner_dim = embed_dim * 4
        layers = [
            nn.Linear(embed_dim, inner_dim),   # net.0
            nn.ReLU(),                         # net.1
            nn.Dropout(dropout),               # net.2
            nn.Linear(inner_dim, embed_dim),   # net.3
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the expert MLP over the last dimension of ``x``."""
        out = self.net(x)
        return out

class SparseMoELayer(nn.Module):
    """
    A Sparse Mixture of Experts layer.

    Each input row is routed to its ``top_k`` highest-scoring experts, ranked
    by the linear gate's logits. The selected experts' outputs are summed,
    weighted by a softmax taken over those ``top_k`` logits only. The layer
    also returns an auxiliary load-balancing loss.

    Args:
        embed_dim: Feature size of both input and output.
        num_experts: Number of ``Expert`` MLPs.
        top_k: Experts used per row; must satisfy ``1 <= top_k <= num_experts``.
        dropout: Dropout probability inside each expert.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``. Without this
            check, ``top_k == 0`` would silently produce an all-zero output,
            and ``top_k > num_experts`` would only fail later inside
            ``torch.topk``.
    """
    def __init__(self, embed_dim, num_experts, top_k, dropout):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.embed_dim = embed_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Gating network: one logit per expert.
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)

        # Expert modules
        self.experts = nn.ModuleList([Expert(embed_dim, dropout) for _ in range(num_experts)])

        # NOTE(review): `softplus` is not used by forward(). It has no
        # parameters and is kept only so existing attribute access keeps working.
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(dim=1)
        # NOTE(review): this buffer is saved in the state_dict, but forward()
        # never updates it. Keep it so existing checkpoints load with strict=True.
        self.register_buffer("mean_importance", torch.zeros(num_experts))

    def forward(self, x):
        """
        Route each row of ``x`` through its top-k experts.

        Args:
            x: ``[batch_size, embed_dim]``, or ``[batch_size, seq_len, embed_dim]``.
               A 3-D input is first mean-pooled over ``seq_len``.

        Returns:
            ``(output, aux_loss)``. ``output`` has shape ``[batch_size, embed_dim]``
            and the dtype of the (pooled) input. ``aux_loss`` is a scalar tensor.
        """
        if x.dim() == 3:
            x = x.mean(dim=1)

        # [batch_size, num_experts]
        gate_logits = self.gate(x)

        # Both outputs have shape [batch_size, top_k].
        top_k_gates, top_k_indices = torch.topk(gate_logits, self.top_k, dim=1)
        # Mixing weights: softmax over the selected logits only.
        top_k_gates = self.softmax(top_k_gates)

        # ---- Auxiliary load-balancing loss ----
        # P_i: mean router probability of expert i over the batch, using a
        # softmax over ALL experts. This is the differentiable term.
        gates_full = self.softmax(gate_logits)
        mean_importance = gates_full.mean(dim=0)  # [num_experts]

        # f_i: average number of top-k slots given to expert i per row. It is
        # built from integer indices, so no gradient flows through it. Summed
        # over experts it equals top_k.
        one_hot_indices = F.one_hot(top_k_indices, num_classes=self.num_experts).to(gate_logits.dtype)
        chosen_experts = one_hot_indices.sum(dim=1)   # [batch_size, num_experts], multi-hot
        fraction_chosen = chosen_experts.mean(dim=0)  # [num_experts]

        # num_experts * sum_i(P_i * f_i). Perfectly uniform routing gives top_k.
        aux_loss = (mean_importance * fraction_chosen).sum() * self.num_experts

        # ---- Dispatch rows to experts ----
        final_output = torch.zeros_like(x)
        flat_indices = top_k_indices.reshape(-1)  # [batch_size * top_k]
        flat_gates = top_k_gates.reshape(-1)      # [batch_size * top_k]
        # Row r of x is repeated top_k times at flat positions r*top_k ... r*top_k + top_k - 1.
        # This matches the row-major flattening of top_k_indices above.
        repeated_x = x.repeat_interleave(self.top_k, dim=0)

        for i, expert in enumerate(self.experts):
            expert_mask = flat_indices == i
            if not expert_mask.any():
                continue
            expert_outputs = expert(repeated_x[expert_mask])
            gated_outputs = expert_outputs * flat_gates[expert_mask].unsqueeze(1)
            # Map flat positions back to their source row.
            batch_indices = expert_mask.nonzero().squeeze(1) // self.top_k
            # Dtypes may differ under mixed precision; index_add_ needs them to match.
            final_output.index_add_(0, batch_indices, gated_outputs.to(final_output.dtype))

        return final_output, aux_loss


class CharacterEmbedder(nn.Module):
    """Encodes text from character IDs to a fixed embedding.

    A bidirectional GRU reads the embedded characters. Its final forward and
    backward hidden states are concatenated and projected back to
    ``embed_dim``.

    Sequences are packed by their true (non-pad) length. This keeps trailing
    padding from leaking into the final states: without packing, the same
    text would embed differently depending on how much padding its batch
    needed.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim * 2, embed_dim)  # Project back to embed_dim

    def forward(self, x):
        """
        Args:
            x: LongTensor ``[batch_size, seq_len]`` of character ids,
               right-padded with id 0 (the embedding's ``padding_idx``).

        Returns:
            Tensor ``[batch_size, embed_dim]``.
        """
        embedded = self.dropout(self.embedding(x))

        # True length = number of non-pad ids (padding is assumed to be trailing).
        # Clamp to 1 so all-pad rows (e.g. empty strings) are still accepted by
        # pack_padded_sequence. Lengths must live on the CPU.
        lengths = (x != 0).sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        # hidden: [num_directions, batch_size, hidden_dim], in the original batch order.
        _, hidden = self.gru(packed)

        # Concatenate the last forward state and the last backward state.
        final_hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.out_proj(final_hidden)


class ParagraphMoEModel(nn.Module):
    """
    Paragraph encoder: character GRU -> stack of sparse MoE layers -> projection head.

    ``forward`` returns ``(final_embedding, projection, mean_aux_loss)``. The
    projection head only feeds the contrastive training loss. ``embed_text``
    returns the pre-projection embedding for inference.

    ``config`` must provide: vocab_size, embed_dim, hidden_dim, dropout,
    num_experts, top_k, num_layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vocab_size = config['vocab_size']
        self.embed_dim = config['embed_dim']
        self.hidden_dim = config['hidden_dim']  # GRU hidden size

        # The attribute names text_encoder / moe_layers / projection_head are
        # the state_dict prefixes, so they must not be renamed.
        self.text_encoder = CharacterEmbedder(
            self.vocab_size, self.embed_dim, self.hidden_dim, config['dropout']
        )
        self.moe_layers = nn.ModuleList(
            SparseMoELayer(self.embed_dim, config['num_experts'], config['top_k'], config['dropout'])
            for _ in range(config['num_layers'])
        )

        # Training-only head for the contrastive objective; embed_text() skips it.
        half_dim = self.embed_dim // 2
        quarter_dim = self.embed_dim // 4
        self.projection_head = nn.Sequential(
            nn.Linear(self.embed_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, quarter_dim),
        )

    def _run_layers(self, x_char_ids):
        """Encode ids, then apply every MoE layer. Returns (embedding, list of aux losses)."""
        hidden = self.text_encoder(x_char_ids)
        aux_terms = []
        for moe in self.moe_layers:
            hidden, aux = moe(hidden)
            aux_terms.append(aux)
        return hidden, aux_terms

    def forward(self, x_char_ids):
        """Return (final embedding, contrastive projection, aux loss averaged over layers)."""
        final_embedding, aux_terms = self._run_layers(x_char_ids)
        projection = self.projection_head(final_embedding)
        mean_aux = sum(aux_terms, 0.0) / len(self.moe_layers)
        return final_embedding, projection, mean_aux

    def embed_text(self, x_char_ids):
        """
        Inference-only: the final MoE embedding (NOT the projection), computed
        without gradient tracking. The caller is responsible for eval() mode.
        """
        with torch.no_grad():
            embedding, _ = self._run_layers(x_char_ids)
        return embedding

# ================================================================================
# 2. CONTRASTIVE LOSS
# ================================================================================

class ContrastiveLoss(nn.Module):
    """
    NT-Xent loss (SimCLR).

    Given two batches of projections, where row k of ``z_i`` and row k of
    ``z_j`` are two views of the same item, each row's positive is its
    counterpart. All other non-self rows act as negatives.

    The similarity/softmax computation is done in float32 with autocast
    disabled. Under fp16 autocast, cosine similarities divided by a small
    temperature would otherwise be quantised coarsely.

    Args:
        temperature: Positive softmax temperature.

    Raises:
        ValueError: If ``temperature <= 0``.
    """
    def __init__(self, temperature=0.1):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        # NOTE(review): unused by forward(); kept (parameter-free) for attribute compatibility.
        self.cosine_sim = nn.CosineSimilarity(dim=-1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """
        Args:
            z_i, z_j: ``[batch_size, projection_dim]`` tensors of matching shape.

        Returns:
            Scalar float32 loss tensor.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` have different shapes.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}")
        batch_size = z_i.shape[0]

        with torch.amp.autocast(device_type=z_i.device.type, enabled=False):
            z_i = F.normalize(z_i.float(), p=2, dim=1)
            z_j = F.normalize(z_j.float(), p=2, dim=1)

            # [2 * batch_size, projection_dim]
            z = torch.cat([z_i, z_j], dim=0)
            # [2 * batch_size, 2 * batch_size] cosine similarities / temperature
            sim_matrix = torch.mm(z, z.T) / self.temperature

            # The positive for row k is row k + batch_size, and vice versa.
            base = torch.arange(batch_size, device=z.device)
            labels = torch.cat([base + batch_size, base])

            # A row must never match itself: mask the diagonal out of the softmax.
            self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
            sim_matrix = sim_matrix.masked_fill(self_mask, float('-inf'))

            return self.loss_fn(sim_matrix, labels)

# ================================================================================
# 3. DATA PREPARATION (Using Real Data)
# ================================================================================

# Simple character-level tokenizer
class CharTokenizer:
    """Character-level tokenizer over printable ASCII.

    Id 0 is ``<PAD>`` and id 1 is ``<UNK>``. The 95 printable characters
    (space through ``~``) map to ids 2..96, giving a vocab size of 97.
    """
    def __init__(self):
        self.vocab = {char: i + 2 for i, char in enumerate(
            " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"
        )}
        self.vocab['<PAD>'] = 0
        self.vocab['<UNK>'] = 1
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def text_to_ids(self, text, max_len):
        """Map each character of ``text[:max_len]`` to its id (unknown chars -> ``<UNK>``)."""
        unk_id = self.vocab['<UNK>']
        # Truncate before lookup, so very long texts don't get fully tokenised
        # only to be cut.
        return [self.vocab.get(char, unk_id) for char in text[:max_len]]

    def pad_batch(self, batch_ids):
        """
        Right-pad id lists with ``<PAD>`` into a ``[len(batch_ids), width]`` LongTensor.

        ``width`` is the longest list's length, but at least 1. A zero-width
        batch (empty list or all-empty strings) cannot be fed to the GRU
        encoder.
        """
        pad_id = self.vocab['<PAD>']
        width = max(1, max((len(ids) for ids in batch_ids), default=0))
        if not batch_ids:
            return torch.zeros((0, width), dtype=torch.long)
        padded = [list(ids) + [pad_id] * (width - len(ids)) for ids in batch_ids]
        return torch.tensor(padded, dtype=torch.long)

# Global tokenizer instance. It is shared by batch collation, ParagraphEmbedder
# and main() (for vocab_size), so train-time and inference-time ids agree.
TOKENIZER = CharTokenizer()

class TextAugmentation:
    """Character-dropout augmentation that yields one random 'view' of a string.

    Each character survives independently when ``random.random()`` exceeds
    ``p_drop``. With ``p_drop == 0`` the text is returned unchanged.
    """

    def __init__(self, p_drop=0.1):
        self.p_drop = p_drop

    def __call__(self, text):
        if self.p_drop == 0:
            return text
        survivors = []
        for ch in text:
            if random.random() > self.p_drop:
                survivors.append(ch)
        return "".join(survivors)

class ContrastiveDataset(Dataset):
    """
    Map-style dataset that yields two independently transformed views of each paragraph.

    Paragraphs whose stripped length is 50 characters or fewer are discarded.
    The kept paragraphs are stored unmodified (not stripped).
    """

    def __init__(self, paragraphs, transform_a, transform_b):
        self.paragraphs = [para for para in paragraphs if len(para.strip()) > 50]
        self.transform_a = transform_a
        self.transform_b = transform_b

    def __len__(self):
        return len(self.paragraphs)

    def __getitem__(self, idx):
        source = self.paragraphs[idx]
        # Tuple elements evaluate left to right: transform_a runs before transform_b.
        return self.transform_a(source), self.transform_b(source)

def _collate_views(batch, max_seq_len):
    """Tokenise and pad a list of (view_1, view_2) string pairs into two LongTensors."""
    view_1_list = [item[0] for item in batch]
    view_2_list = [item[1] for item in batch]
    tensor_1 = TOKENIZER.pad_batch([TOKENIZER.text_to_ids(t, max_seq_len) for t in view_1_list])
    tensor_2 = TOKENIZER.pad_batch([TOKENIZER.text_to_ids(t, max_seq_len) for t in view_2_list])
    return tensor_1, tensor_2


def build_dataloaders(config):
    """
    Build train/validation DataLoaders of augmented paragraph pairs from WikiText-103.

    Train and validation splits are pooled. The pool keeps stripped lines with
    100 < length < ``config['max_seq_len']`` and is shuffled with ``random``.
    It is then split 95/5 into train/val.

    Args:
        config: Must provide 'max_seq_len' and 'batch_size'.

    Returns:
        ``(train_loader, val_loader)`` yielding ``(tensor_1, tensor_2)`` LongTensor pairs.
    """
    import functools

    print("Loading real data from 'wikitext-103-raw-v1'...")
    # Note: the full dataset is used (no subsetting).
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

    # Combine train and validation for a larger pool, then split.
    all_text = list(dataset['train']['text']) + list(dataset['validation']['text'])

    max_seq_len = config['max_seq_len']
    stripped = (p.strip() for p in all_text)
    paragraphs = [p for p in stripped if 100 < len(p) < max_seq_len]
    random.shuffle(paragraphs)

    print(f"Loaded {len(paragraphs)} paragraphs.")

    split_idx = int(len(paragraphs) * 0.95)
    train_paras = paragraphs[:split_idx]
    val_paras = paragraphs[split_idx:]

    # Two independent character-dropout augmentations, one per view.
    transform_a = TextAugmentation(p_drop=0.1)
    transform_b = TextAugmentation(p_drop=0.1)

    train_dataset = ContrastiveDataset(train_paras, transform_a, transform_b)
    val_dataset = ContrastiveDataset(val_paras, transform_a, transform_b)

    # A partial of a module-level function can be pickled, so it also works when
    # DataLoader workers use the 'spawn' start method. A nested closure cannot.
    collate_fn = functools.partial(_collate_views, max_seq_len=max_seq_len)
    # Pinned host memory only helps (and is only supported) with a CUDA device.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader

# ================================================================================
# 4. TRAINING & EVALUATION LOOP
# ================================================================================

def train_one_epoch(model, loader, loss_fn, optimizer, scaler, scheduler, device, config):
    """
    Run one epoch of contrastive + load-balancing training.

    Each batch holds two augmented views of the same paragraphs. The loss is
    ``contrast_loss + config['aux_loss_weight'] * aux_loss``, where
    ``aux_loss`` is the MoE balance loss averaged over the two views. fp16
    autocast is enabled only on CUDA devices.

    The LR scheduler is stepped once per batch. This matches a scheduler
    whose horizon (e.g. ``T_max``) is counted in optimizer steps. If
    ``config['gradient_clip']`` is truthy, gradients are unscaled and clipped
    to that max norm before the optimizer step.

    Args:
        model: Returns ``(embedding, projection, aux_loss)`` for a batch of ids.
        loader: Yields ``(tensor_1, tensor_2)`` id-tensor pairs.
        loss_fn: Contrastive loss over two projection batches.
        optimizer, scaler, scheduler: Optimizer, AMP GradScaler, LR scheduler.
        device: Target device (``torch.device`` or string).
        config: Reads 'epoch', 'num_epochs', 'aux_loss_weight', 'gradient_clip'.

    Returns:
        Mean total loss over the epoch's batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_one_epoch: loader is empty")

    model.train()
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    max_grad_norm = config.get('gradient_clip')

    total_loss = 0.0
    total_contrast_loss = 0.0
    total_aux_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {config['epoch']}/{config['num_epochs']} [Train]")

    for tensor_1, tensor_2 in pbar:
        tensor_1, tensor_2 = tensor_1.to(device), tensor_2.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
            # Projections and aux losses for both views.
            _, proj_1, aux_loss_1 = model(tensor_1)
            _, proj_2, aux_loss_2 = model(tensor_2)

            contrast_loss = loss_fn(proj_1, proj_2)
            aux_loss = (aux_loss_1 + aux_loss_2) / 2
            # Contrastive term for semantics + weighted aux term for expert balance.
            loss = contrast_loss + config['aux_loss_weight'] * aux_loss

        scaler.scale(loss).backward()
        if max_grad_norm:
            # Unscale first so the clip threshold applies to the true gradient
            # norm, not the loss-scaled one.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        current_lr = scheduler.get_last_lr()[0]
        # Per-batch step: stepping once per epoch would leave a per-step
        # schedule almost unchanged for the whole run.
        scheduler.step()

        total_loss += loss.item()
        total_contrast_loss += contrast_loss.item()
        total_aux_loss += aux_loss.item()

        pbar.set_postfix(
            loss=f"{loss.item():.4f}",
            c_loss=f"{contrast_loss.item():.4f}",
            aux=f"{aux_loss.item():.4f}",
            lr=f"{current_lr:.1e}"
        )

    avg_loss = total_loss / num_batches
    avg_contrast = total_contrast_loss / num_batches
    avg_aux = total_aux_loss / num_batches

    print(f"  Train Loss: {avg_loss:.4f} (Contrast: {avg_contrast:.4f}, Aux: {avg_aux:.4f})")
    return avg_loss

@torch.no_grad()
def validate(model, loader, loss_fn, device, config):
    """
    Compute the mean contrastive loss over ``loader`` without gradient tracking.

    The model is switched to eval mode, and the MoE aux loss is ignored.
    fp16 autocast is enabled only on CUDA devices.

    Args:
        model: Returns ``(embedding, projection, aux_loss)`` for a batch of ids.
        loader: Yields ``(tensor_1, tensor_2)`` id-tensor pairs.
        loss_fn: Contrastive loss over two projection batches.
        device: Target device (``torch.device`` or string).
        config: Unused; accepted for call-site symmetry with the training step.

    Returns:
        Mean per-batch contrastive loss.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("validate: loader is empty")

    model.eval()
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    total_contrast_loss = 0.0

    for tensor_1, tensor_2 in tqdm(loader, desc="[Validate]"):
        tensor_1, tensor_2 = tensor_1.to(device), tensor_2.to(device)

        with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
            _, proj_1, _ = model(tensor_1)
            _, proj_2, _ = model(tensor_2)
            contrast_loss = loss_fn(proj_1, proj_2)

        total_contrast_loss += contrast_loss.item()

    avg_loss = total_contrast_loss / num_batches
    print(f"  Val Loss: {avg_loss:.4f}")
    return avg_loss

# ================================================================================
# 5. DEPLOYMENT & SEMANTIC EVALUATION
# ================================================================================

class ParagraphEmbedder:
    """
    Inference wrapper: loads trained weights and maps text to MoE embeddings.

    It uses ``model.embed_text`` (the pre-projection embedding), not the
    contrastive projection head.
    """
    def __init__(self, model_path, config):
        """
        Args:
            model_path: Path to a ``state_dict`` saved with ``torch.save``.
            config: The same config dict used to build the trained model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config

        # Load the architecture, then the weights.
        self.model = ParagraphMoEModel(config).to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # weights_only=True refuses arbitrary pickled objects (a state_dict needs none).
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print(f"Model loaded from {model_path} on {self.device}")

    def _ids_or_pad(self, text):
        """Tokenise ``text``. Empty text becomes a single <PAD> so the encoder gets >= 1 step."""
        ids = TOKENIZER.text_to_ids(text, self.config['max_seq_len'])
        return ids or [TOKENIZER.vocab['<PAD>']]

    def embed_text(self, text: str):
        """Embed a single string; returns a tensor of shape ``[embed_dim]``."""
        tensor = torch.tensor([self._ids_or_pad(text)], dtype=torch.long, device=self.device)
        embedding = self.model.embed_text(tensor)
        return embedding.squeeze(0)

    def embed_batch(self, texts: list):
        """Embed a list of strings; returns a tensor of shape ``[len(texts), embed_dim]``."""
        if not texts:
            return torch.empty((0, self.config['embed_dim']), device=self.device)
        tensor = TOKENIZER.pad_batch([self._ids_or_pad(t) for t in texts]).to(self.device)
        return self.model.embed_text(tensor)

def run_semantic_evaluation(embedder):
    """
    Print a qualitative similarity check for a trained embedder.

    An AI-themed anchor sentence is compared against a related AI sentence
    and an unrelated cooking sentence. The check passes when the related
    similarity beats the unrelated one by more than 0.1.

    Args:
        embedder: Object whose ``embed_text(str)`` returns a 1-D embedding tensor.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("RUNNING SEMANTIC SIMILARITY DEMO")
    print(banner)

    anchor = "Machine learning and artificial intelligence are changing the world."
    related = "Deep learning and neural networks are a key part of AI."
    unrelated = "My favorite cooking recipe is for lasagna and culinary arts."

    # Embed each as a [1, embed_dim] row, then L2-normalise.
    raw_rows = [embedder.embed_text(t).unsqueeze(0) for t in (anchor, related, unrelated)]
    anchor_vec, related_vec, unrelated_vec = (F.normalize(row) for row in raw_rows)

    sim_related = F.cosine_similarity(anchor_vec, related_vec).item()
    sim_unrelated = F.cosine_similarity(anchor_vec, unrelated_vec).item()

    for other, score in ((related, sim_related), (unrelated, sim_unrelated)):
        print(f"  '{anchor[:20]}...' <-> '{other[:20]}...': {score:.4f}")

    required_margin = 0.1
    if sim_related > (sim_unrelated + required_margin):
        print("\n  ✅ SUCCESS: Related texts are significantly more similar!")
    else:
        print("\n  ❌ FAILURE: Model cannot distinguish related/unrelated text.")
        print(f"  (Sim 1-2: {sim_related:.4f}, Sim 1-3: {sim_unrelated:.4f})")
    print(banner + "\n")


# ================================================================================
# 6. MAIN SCRIPT
# ================================================================================

def main():
    """
    End-to-end driver.

    Builds the WikiText dataloaders and trains the contrastive MoE model with
    early stopping on validation loss. It then reloads the best checkpoint
    and runs the semantic similarity demo.
    """
    clear_cache()

    # Config based on the earlier run's log, with corrections.
    config = {
        # Model Arch
        'embed_dim': 1024,
        'hidden_dim': 512,      # Hidden dim for the character GRU
        'num_experts': 8,
        'top_k': 2,
        'num_layers': 3,
        'vocab_size': TOKENIZER.vocab_size,

        # Training
        'batch_size': 32,       # Increased from 8, lower if OOM
        'learning_rate': 0.0001,
        'weight_decay': 1e-05,
        'num_epochs': 50,
        'warmup_epochs': 5,     # NOTE: not consumed here; CosineAnnealingLR has no warmup phase
        'dropout': 0.1,
        'patience': 10,
        'gradient_clip': 1.0,   # Max grad-norm; GradScaler does NOT clip gradients by itself

        # Loss settings
        'aux_loss_weight': 0.1, # Raised from 0.01 to push harder on expert balance
        'contrast_temp': 0.07,  # Temperature for contrastive loss

        # Data
        'max_seq_len': 512      # From the model card
    }

    print("="*80)
    print("CORRECTED PRODUCTION PARAGRAPH-BASED MULTIMODAL MoE WORLD MODEL")
    print("Training Objective: Contrastive Learning (NT-Xent)")
    print("="*80)
    print("Configuration:")
    for k, v in config.items():
        print(f"  {k}: {v}")
    print("-"*80)

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Data
    train_loader, val_loader = build_dataloaders(config)

    # Model
    model = ParagraphMoEModel(config).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Loss, Optimizer, Scheduler
    loss_fn = ContrastiveLoss(temperature=config['contrast_temp']).to(device)
    optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    # T_max is counted in optimizer steps (batches), so this schedule assumes
    # scheduler.step() is called once per batch.
    scheduler = CosineAnnealingLR(optimizer, T_max=(len(train_loader) * config['num_epochs']), eta_min=1e-7)

    # Loss scaling is only needed for fp16 autocast on CUDA. Disable it
    # elsewhere, so CPU runs don't warn about a missing CUDA device.
    scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

    # Training Loop
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_dir = "checkpoints_corrected"
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, "best_model.pth")

    print("\n" + "="*80)
    print("TRAINING START")
    print("="*80)

    for epoch in range(1, config['num_epochs'] + 1):
        config['epoch'] = epoch  # read by the training loop for its progress label

        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, scaler, scheduler, device, config)
        val_loss = validate(model, val_loader, loss_fn, device, config)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"  ✓ New best model saved! Val Loss: {val_loss:.4f}")
        else:
            epochs_no_improve += 1
            print(f"  Validation loss did not improve. Patience: {epochs_no_improve}/{config['patience']}")

        if epochs_no_improve >= config['patience']:
            print(f"\nEarly stopping triggered at epoch {epoch}")
            break

        clear_cache()

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Best model saved at: {best_model_path}")
    print("="*80)

    # Final evaluation needs a checkpoint. None exists if no epoch ran
    # (e.g. num_epochs == 0).
    if not os.path.exists(best_model_path):
        print("\nNo checkpoint was saved; skipping final evaluation.")
        return

    print("\nLoading best model for final evaluation...")
    # Pass the config used for training so the architecture matches the weights.
    embedder = ParagraphEmbedder(best_model_path, config)

    run_semantic_evaluation(embedder)

# Entry point: run the full pipeline (train -> validate -> semantic demo) only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

CORRECTED PRODUCTION PARAGRAPH-BASED MULTIMODAL MoE WORLD MODEL
Training Objective: Contrastive Learning (NT-Xent)
Configuration:
  embed_dim: 1024
  hidden_dim: 512
  num_experts: 8
  top_k: 2
  num_layers: 3
  vocab_size: 97
  batch_size: 32
  learning_rate: 0.0001
  weight_decay: 1e-05
  num_epochs: 50
  warmup_epochs: 5
  dropout: 0.1
  patience: 10
  gradient_clip: 1.0
  aux_loss_weight: 0.1
  contrast_temp: 0.07
  max_seq_len: 512
--------------------------------------------------------------------------------
Device: cuda
Loading real data from 'wikitext-103-raw-v1'...
Loaded 261014 paragraphs.
Model parameters: 208,003,840

TRAINING START


Epoch 1/50 [Train]:  85%|████████▌ | 6604/7749 [25:43<04:22,  4.37it/s, aux=2.9690, c_loss=0.2642, loss=0.5611, lr=1.0e-04]