In [None]:
!pip install -q "transformers>=4.40.0" "datasets" "accelerate" "bitsandbytes" "sentencepiece" "tqdm"


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m41.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import math
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm.auto import tqdm

# Choose the accelerator once; later cells move tensors onto `device`.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print("Using device:", device)

# Seed both RNG sources this notebook draws from (Python `random` and torch).
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Allow TF32 matmuls on Ampere+ GPUs (faster, slightly lower precision).
_set_matmul_precision = getattr(torch, "set_float32_matmul_precision", None)
if _set_matmul_precision is not None:
    _set_matmul_precision("high")


Using device: cuda


  _C._set_float32_matmul_precision(precision)


In [None]:
BASE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# --- tokenizer ---------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
# Batched padding needs a pad token; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- frozen base LLM, used only as a feature extractor -----------------------
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,  # newer transformers versions warn and prefer `dtype=`
    device_map="auto",  # let accelerate place the weights on the available device(s)
)
base_model.eval()
base_model.requires_grad_(False)  # the LLM is not trained in this prototype

hidden_size = base_model.config.hidden_size
print("Hidden size:", hidden_size)


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Hidden size: 4096


In [None]:
from datasets import load_dataset

# Corpus / encoding settings sized for an A100.
MAX_SEQ_LEN = 256         # tokens per article (truncate / pad to this length)
NUM_SAMPLES = 300_000     # number of articles to encode
BATCH_SIZE_ENC = 32       # texts per base-LLM forward pass; raise if VRAM allows

# English Wikipedia, 2023-11-01 dump (parquet-backed `wikimedia/wikipedia`).
# Twice NUM_SAMPLES rows are read so that, after the length filter below,
# NUM_SAMPLES articles are very likely still available.
wiki_ds = load_dataset(
    "wikimedia/wikipedia",
    "20231101.en",
    split=f"train[:{NUM_SAMPLES*2}]",
)

# Keep the first NUM_SAMPLES articles whose stripped text is longer than 80
# characters; stop scanning as soon as that many have been collected.
texts = []
for row in wiki_ds:
    if len(texts) >= NUM_SAMPLES:
        break
    article = row["text"]
    if len(article.strip()) > 80:
        texts.append(article)
print("Num texts used:", len(texts))


def encode_batch_to_hidden(batch_texts):
    """Embed each text as one mean-pooled vector of the base LLM's last hidden state.

    Tokenizes ``batch_texts`` (padded/truncated to ``MAX_SEQ_LEN``), runs the
    frozen ``base_model`` and averages ``hidden_states[-1]`` over the positions
    the attention mask marks as real tokens.

    Args:
        batch_texts: list of strings.

    Returns:
        CPU float32 tensor of shape [B, H] (H = model hidden size).

    Pooling is done in float32. Averaging up to ``MAX_SEQ_LEN`` bfloat16
    activations and rounding the mean back to bfloat16 (7-bit mantissa)
    discarded precision that the float16 cache built from this output can
    otherwise keep.
    """
    enc = tokenizer(
        batch_texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(device)
    attn_mask = enc["attention_mask"].to(device)
    with torch.no_grad():
        outputs = base_model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            output_hidden_states=True,
        )
        # Up-cast before pooling so the masked sum and the division run in fp32.
        last_hidden = outputs.hidden_states[-1].float()  # [B, L, H]
        mask = attn_mask.unsqueeze(-1).to(last_hidden.dtype)  # [B, L, 1]
        summed = (last_hidden * mask).sum(dim=1)  # [B, H]
        # clamp keeps an all-padding row from dividing by zero
        counts = mask.sum(dim=1).clamp(min=1)  # [B, 1]
        pooled = summed / counts  # [B, H]
    return pooled.cpu()


# Run the frozen LLM over every text in BATCH_SIZE_ENC-sized chunks. Each pooled
# chunk is kept as float16 on the CPU to halve host RAM; training casts back to
# float32 on the GPU.
all_hidden_vectors = torch.cat(
    [
        encode_batch_to_hidden(texts[start : start + BATCH_SIZE_ENC]).half()
        for start in tqdm(
            range(0, len(texts), BATCH_SIZE_ENC),
            desc="Encoding to hidden (Wikipedia)",
        )
    ],
    dim=0,
)  # [N, H], float16 on CPU
print("Hidden dataset shape:", all_hidden_vectors.shape, all_hidden_vectors.dtype)


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Num texts used: 300000


Encoding to hidden (Wikipedia):   0%|          | 0/9375 [00:00<?, ?it/s]

Hidden dataset shape: torch.Size([300000, 4096]) torch.float16


In [None]:
class HiddenDataset(Dataset):
    """Map-style dataset over a precomputed matrix of pooled hidden vectors.

    Item ``i`` is row ``i`` of the wrapped tensor, returned without copying
    or casting.
    """

    def __init__(self, hidden_tensor: torch.Tensor):
        self.hidden = hidden_tensor

    def __len__(self):
        return self.hidden.shape[0]

    def __getitem__(self, idx):
        return self.hidden[idx]

hidden_dataset = HiddenDataset(all_hidden_vectors)
# The dataset is a single in-memory tensor, so worker processes mostly add
# process/IPC overhead. Under Jupyter (Python 3.12, fork start method) the
# multi-process iterator's __del__ also ends up running in non-parent
# processes and spams "AssertionError: can only test a child process".
# Load batches in the main process instead.
vae_loader = DataLoader(
    hidden_dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
)
print("Batches per epoch (VAE):", len(vae_loader))


Batches per epoch (VAE): 1171


In [None]:
class LatentVAE(nn.Module):
    """Gaussian VAE over pooled LLM hidden vectors.

    Encoder: two ReLU MLP layers of width ``hidden_dim`` feeding separate
    linear heads for the posterior mean and log-variance (``latent_dim`` each).
    Decoder: the mirror image, projecting back to ``input_dim``.
    """

    def __init__(self, input_dim, latent_dim=1024, hidden_dim=8192):
        super().__init__()
        # Creation order is significant: it fixes both the parameter-init RNG
        # draws and the state_dict layout of saved checkpoints.
        self.enc_fc1 = nn.Linear(input_dim, hidden_dim)
        self.enc_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)
        self.dec_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.dec_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out_head = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the diagonal Gaussian posterior q(z|x)."""
        features = x
        for layer in (self.enc_fc1, self.enc_fc2):
            features = F.relu(layer(features))
        return self.mu_head(features), self.logvar_head(features)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``, differentiable in mu/logvar."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes back to the input (hidden-vector) space."""
        features = z
        for layer in (self.dec_fc1, self.dec_fc2):
            features = F.relu(layer(features))
        return self.out_head(features)

    def forward(self, x):
        """Encode, sample once, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar


def vae_loss(recon_x, x, mu, logvar, kl_weight: float = 1e-3):
    """Weighted VAE objective: reconstruction MSE + ``kl_weight`` * KL(q(z|x) || N(0, I)).

    Both terms are per-element means (MSE over all entries; KL averaged over
    batch and latent dimensions), so ``kl_weight`` is relative to that scale.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target vectors.
        mu, logvar: posterior mean and log-variance from the encoder.
        kl_weight: weight of the KL term (default 1e-3, the previously
            hard-coded value).

    Returns:
        ``(total_loss, recon_loss.detach(), kld.detach())``; only the total
        carries gradients.
    """
    recon_loss = F.mse_loss(recon_x, x, reduction="mean")
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kld, recon_loss.detach(), kld.detach()


latent_dim = 1024
vae = LatentVAE(input_dim=hidden_size, latent_dim=latent_dim, hidden_dim=8192)
vae = vae.to(device)
print(vae)


LatentVAE(
  (enc_fc1): Linear(in_features=4096, out_features=8192, bias=True)
  (enc_fc2): Linear(in_features=8192, out_features=8192, bias=True)
  (mu_head): Linear(in_features=8192, out_features=1024, bias=True)
  (logvar_head): Linear(in_features=8192, out_features=1024, bias=True)
  (dec_fc1): Linear(in_features=1024, out_features=8192, bias=True)
  (dec_fc2): Linear(in_features=8192, out_features=8192, bias=True)
  (out_head): Linear(in_features=8192, out_features=4096, bias=True)
)


In [None]:
VAE_EPOCHS = 20
VAE_LR = 1e-3

opt_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR)

for epoch in range(1, VAE_EPOCHS + 1):
    vae.train()
    loss_sum = recon_sum = kld_sum = 0.0
    step_count = 0
    for batch in tqdm(vae_loader, desc=f"VAE epoch {epoch}"):
        x = batch.to(device).float()  # cached fp16 on CPU -> fp32 on GPU

        opt_vae.zero_grad()
        recon, mu, logvar = vae(x)
        loss, recon_l, kld_l = vae_loss(recon, x, mu, logvar)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
        opt_vae.step()

        loss_sum += loss.item()
        recon_sum += recon_l.item()
        kld_sum += kld_l.item()
        step_count += 1

    mean_loss = loss_sum / step_count
    mean_recon = recon_sum / step_count
    mean_kld = kld_sum / step_count
    print(
        f"[VAE] epoch {epoch} | loss={mean_loss:.4f} | "
        f"recon={mean_recon:.4f} | kld={mean_kld:.4f}"
    )

vae.eval()
torch.save(vae.state_dict(), "vae_mistral_wiki_heavy.pt")
print("Saved VAE to vae_mistral_wiki_heavy.pt")


VAE epoch 1:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 1 | loss=95919488849.1544 | recon=95919362800.0846 | kld=124349443.6143


VAE epoch 2:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 2 | loss=1.0886 | recon=1.0851 | kld=3.5205


VAE epoch 3:   0%|          | 0/1171 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

[VAE] epoch 3 | loss=0.9573 | recon=0.9541 | kld=3.2331


VAE epoch 4:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 4 | loss=0.8962 | recon=0.8931 | kld=3.1299


VAE epoch 5:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 5 | loss=0.8567 | recon=0.8536 | kld=3.0733


VAE epoch 6:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 6 | loss=0.8283 | recon=0.8253 | kld=3.0105


VAE epoch 7:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 7 | loss=0.8073 | recon=0.8043 | kld=2.9428


VAE epoch 8:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 8 | loss=0.7914 | recon=0.7886 | kld=2.8639


VAE epoch 9:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 9 | loss=0.7779 | recon=0.7751 | kld=2.7840


VAE epoch 10:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 10 | loss=0.7672 | recon=0.7645 | kld=2.7013


VAE epoch 11:   0%|          | 0/1171 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

[VAE] epoch 11 | loss=0.7582 | recon=0.7556 | kld=2.6320


VAE epoch 12:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 12 | loss=0.7500 | recon=0.7475 | kld=2.5654


VAE epoch 13:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 13 | loss=0.7437 | recon=0.7412 | kld=2.5075


VAE epoch 14:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 14 | loss=0.7375 | recon=0.7351 | kld=2.4576


VAE epoch 15:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 15 | loss=0.7327 | recon=0.7303 | kld=2.4176


VAE epoch 16:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 16 | loss=0.7287 | recon=0.7263 | kld=2.3837


VAE epoch 17:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 17 | loss=0.7241 | recon=0.7218 | kld=2.3524


VAE epoch 18:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 18 | loss=0.7207 | recon=0.7184 | kld=2.3274


VAE epoch 19:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 19 | loss=0.7171 | recon=0.7148 | kld=2.3037


VAE epoch 20:   0%|          | 0/1171 [00:00<?, ?it/s]

[VAE] epoch 20 | loss=0.7139 | recon=0.7116 | kld=2.2863
Saved VAE to vae_mistral_wiki_heavy.pt


In [None]:
# Encode every cached hidden vector exactly once, in its original order.
# Iterating the training `vae_loader` here shuffled the rows and dropped the
# final partial batch (224 rows with the current settings), which also broke
# the row-i <-> texts[i] correspondence.
ENCODE_BATCH_SIZE = 256
latent_list = []

vae.eval()
with torch.no_grad():
    for start in tqdm(
        range(0, all_hidden_vectors.size(0), ENCODE_BATCH_SIZE),
        desc="Encoding hidden → latent mu (heavy)",
    ):
        # fp16 on CPU -> fp32 on GPU, matching the cast used during VAE training
        x = all_hidden_vectors[start : start + ENCODE_BATCH_SIZE].to(device).float()
        mu, _ = vae.encode(x)  # keep only the posterior mean as the data latent
        latent_list.append(mu.cpu())

all_latents = torch.cat(latent_list, dim=0)  # [N, latent_dim]
assert all_latents.size(0) == all_hidden_vectors.size(0)
print("Latent dataset shape:", all_latents.shape)


class LatentDataset(Dataset):
    """Map-style dataset over precomputed data latents; item ``i`` is row ``i``."""

    def __init__(self, latents: torch.Tensor):
        self.latents = latents

    def __len__(self):
        return self.latents.shape[0]

    def __getitem__(self, idx):
        # z0: the data-side endpoint of the flow-matching path
        return self.latents[idx]

latent_dataset = LatentDataset(all_latents)
# Same reasoning as for the VAE loader: the latents are an in-memory tensor,
# and forked workers in Jupyter produce the repeated
# "AssertionError: can only test a child process" shutdown tracebacks.
flow_loader = DataLoader(
    latent_dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
)
print("Batches per epoch (Flow):", len(flow_loader))


Encoding hidden → latent mu (heavy):   0%|          | 0/1171 [03:00<?, ?it/s]

Latent dataset shape: torch.Size([299776, 1024])
Batches per epoch (Flow): 585


In [None]:
class FlowMatchingNet(nn.Module):
    """Time-conditioned MLP predicting the flow-matching velocity field.

    The input is ``z_t`` concatenated with the scalar time ``t``. It passes
    through ``num_layers`` ReLU-activated linear layers of width
    ``hidden_dim``, then a final linear projection back to ``latent_dim``.
    """

    def __init__(self, latent_dim, hidden_dim=4096, num_layers=4):
        super().__init__()
        # Layer widths: z_t plus one time channel in, hidden_dim afterwards.
        widths = [latent_dim + 1] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [nn.Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )
        self.final = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z_t, t):
        """Predict the velocity at ``(z_t, t)``.

        z_t: [B, latent_dim]
        t:   [B, 1] scaled to [0,1]
        Returns a [B, latent_dim] velocity.
        """
        h = torch.cat([z_t, t], dim=-1)
        for linear in self.layers:
            h = F.relu(linear(h))
        return self.final(h)


flow_model = FlowMatchingNet(latent_dim=latent_dim, hidden_dim=4096, num_layers=4)
flow_model = flow_model.to(device)
print(flow_model)


FlowMatchingNet(
  (layers): ModuleList(
    (0): Linear(in_features=1025, out_features=4096, bias=True)
    (1-3): 3 x Linear(in_features=4096, out_features=4096, bias=True)
  )
  (final): Linear(in_features=4096, out_features=1024, bias=True)
)


In [None]:
FM_EPOCHS = 40
FM_LR = 1e-3

opt_flow = torch.optim.AdamW(flow_model.parameters(), lr=FM_LR)


def _flow_matching_step(z_data):
    """Run one optimizer step of flow matching on a batch of data latents; return the loss."""
    # Path convention: t = 0 is pure noise, t = 1 is data. The straight line
    # between them has constant velocity (data - noise).
    noise = torch.randn_like(z_data)  # N(0, I)
    t = torch.rand(z_data.shape[0], 1, device=device)  # U[0, 1]
    z_t = t * z_data + (1.0 - t) * noise
    v_target = z_data - noise

    opt_flow.zero_grad()
    loss = F.mse_loss(flow_model(z_t, t), v_target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(flow_model.parameters(), 1.0)
    opt_flow.step()
    return loss.item()


for epoch in range(1, FM_EPOCHS + 1):
    flow_model.train()
    batch_losses = [
        _flow_matching_step(batch.to(device).float())
        for batch in tqdm(flow_loader, desc=f"Flow epoch {epoch}")
    ]
    avg_loss = sum(batch_losses) / len(flow_loader)
    print(f"[Flow] epoch {epoch} | loss={avg_loss:.4f}")

flow_model.eval()
torch.save(flow_model.state_dict(), "flow_mistral_wiki_heavy.pt")
print("Saved flow model to flow_mistral_wiki_heavy.pt")


Flow epoch 1:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 1 | loss=1.2316


Flow epoch 2:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 2 | loss=1.1682


Flow epoch 3:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 3 | loss=1.1584


Flow epoch 4:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0> 
^^^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^    self._shutdown_workers()^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^^    if w.is_alive():^
^ Exception ignored in: ^ 
 <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 
     Traceback (most rece

[Flow] epoch 4 | loss=1.1540


Flow epoch 5:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
   Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
       self._shutdown_workers() 
^Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^<function _MultiProcessingDat

[Flow] epoch 5 | loss=1.1515


Flow epoch 6:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
        self._shutdown_workers()    self._shutdown_workers()
self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
  F

[Flow] epoch 6 | loss=1.1488


Flow epoch 7:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 7 | loss=1.1468


Flow epoch 8:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 8 | loss=1.1451


Flow epoch 9:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 9 | loss=1.1438


Flow epoch 10:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 10 | loss=1.1428


Flow epoch 11:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 11 | loss=1.1420


Flow epoch 12:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 12 | loss=1.1412


Flow epoch 13:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 13 | loss=1.1408


Flow epoch 14:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^^    ^if w.is_alive():^

   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process'
      Exception ignored in:   <function 

[Flow] epoch 14 | loss=1.1396


Flow epoch 15:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()Exception ignored in:     
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>if w.is_alive():    
if w.is_alive():
 
 Traceback (most recent call last):
     File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

[Flow] epoch 15 | loss=1.1391


Flow epoch 16:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>



Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()            
self._shutdown_workers()self._shutdown_workers(

[Flow] epoch 16 | loss=1.1386


Flow epoch 17:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 17 | loss=1.1384


Flow epoch 18:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 18 | loss=1.1376


Flow epoch 19:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 19 | loss=1.1370


Flow epoch 20:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 20 | loss=1.1372


Flow epoch 21:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 21 | loss=1.1366


Flow epoch 22:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 22 | loss=1.1365


Flow epoch 23:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 23 | loss=1.1361


Flow epoch 24:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 24 | loss=1.1358


Flow epoch 25:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  Exception ignored in:    <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^^    ^self._shutdown_workers()
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^^^^    ^if w.is_alive():^^
^ 
Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0> 
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
     Traceback (most rece

[Flow] epoch 25 | loss=1.1356


Flow epoch 26:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>

Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__

        self._shutdown_workers()self._shutdown_workers()Traceback (most recent call last):


  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
          File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
self._shutdown_workers()

[Flow] epoch 26 | loss=1.1354


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>Exception ignored in: Exception ignored in: Traceback (most recent call last):

<function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
Traceback (most recent call last):


      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
Traceback (most recent call last):
Traceback (most recent call last):
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__

    self._shutdown_workers()          File "/usr/local/lib/

Flow epoch 27:   0%|          | 0/585 [00:02<?, ?it/s]

[Flow] epoch 27 | loss=1.1349


Flow epoch 28:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 28 | loss=1.1351


Flow epoch 29:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 29 | loss=1.1345


Flow epoch 30:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 30 | loss=1.1340


Flow epoch 31:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 31 | loss=1.1339


Flow epoch 32:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 32 | loss=1.1338


Flow epoch 33:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 33 | loss=1.1335


Flow epoch 34:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^^    ^^self._shutdown_workers()^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    ^if w.is_alive():^
^^ 
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
 Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0> assert self._parent_pid

[Flow] epoch 34 | loss=1.1336


Flow epoch 35:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>if w.is_alive():

Exception ignored in: Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0> 
     Traceback (most recent call last):
 self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
 
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
     

[Flow] epoch 35 | loss=1.1332


Flow epoch 36:   0%|          | 0/585 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7879fe545ee0>



Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
            self._shutdown_workers()    self._shutdown_workers()self._shutdown_workers()

[Flow] epoch 36 | loss=1.1332


Flow epoch 37:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 37 | loss=1.1329


Flow epoch 38:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 38 | loss=1.1331


Flow epoch 39:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 39 | loss=1.1329


Flow epoch 40:   0%|          | 0/585 [00:00<?, ?it/s]

[Flow] epoch 40 | loss=1.1327
Saved flow model to flow_mistral_wiki_heavy.pt


In [None]:
@torch.no_grad()
def sample_latent_trajectories(
    flow_model,
    num_trajectories: int = 4,
    num_steps: int = 50,
    diversity_weight: float = 0.1,
    latent_dim: int = latent_dim,
):
    """Sample latents by Euler-integrating the learned flow from noise to data.

    The flow-matching training loop builds ``z_t = t * z_data + (1 - t) * z_noise``
    and regresses the velocity ``z_data - z_noise``. So ``t = 0`` is pure
    Gaussian noise and ``t = 1`` is the data distribution. Sampling therefore
    starts from N(0, I) at ``t = 0`` and steps forward in time with
    ``z <- z + v(z, t) * dt``. Starting noise at ``t = 1`` and stepping with
    ``z - v * dt`` would query the model at times whose training inputs were
    data, and would push samples toward noise.

    Args:
        flow_model: velocity network called as ``flow_model(z, t)`` with ``z``
            of shape [K, latent_dim] and ``t`` of shape [K, 1].
        num_trajectories: number of latents K to draw.
        num_steps: number of Euler steps between t = 0 and t = 1.
        diversity_weight: strength of the pairwise repulsion applied after
            each step; 0 disables it.
        latent_dim: dimensionality of the latent space.

    Returns: [K, latent_dim] sampled latents.
    """
    flow_model.eval()
    dt = 1.0 / num_steps

    # Noise lives at t = 0 under the training convention.
    z = torch.randn(num_trajectories, latent_dim, device=device)

    for step in range(num_steps):
        t = torch.full((num_trajectories, 1), step / num_steps, device=device)

        # Forward Euler step along the learned velocity: z_{t+dt} = z_t + v * dt
        z = z + flow_model(z, t) * dt

        # Heuristic diversity push: each trajectory moves along the sum of unit
        # vectors pointing away from every other trajectory (not an exact gradient).
        if diversity_weight > 0.0 and num_trajectories > 1:
            diff = z.unsqueeze(1) - z.unsqueeze(0)  # [K, K, D]; diff[i, j] = z_i - z_j
            dist = diff.norm(p=2, dim=-1, keepdim=True) + 1e-6  # [K, K, 1]
            # The i == j terms are exactly zero, so no explicit masking is needed.
            rep = (diff / dist).sum(dim=1) / num_trajectories
            z = z + diversity_weight * rep * dt

    return z  # [K, D]


# Draw 4 latents from the flow: 50 Euler steps with a 0.3 repulsion weight.
sampled_latents = sample_latent_trajectories(
    flow_model,
    latent_dim=latent_dim,
    num_trajectories=4,
    diversity_weight=0.3,
    num_steps=50,
)
print("Sampled latents shape:", sampled_latents.shape)


Sampled latents shape: torch.Size([4, 1024])


In [None]:
@torch.no_grad()
def decode_latents_to_hidden(vae, latents):
    """
    Map flow-sampled latents back into LLM hidden-state space via the VAE decoder.

    Switches ``vae`` to eval mode, moves ``latents`` to the global ``device`` as
    float32, and returns the result of ``vae.decode`` (shape [K, hidden_size]
    in this notebook's output).
    """
    vae.eval()
    latents_f32 = latents.to(device).float()
    return vae.decode(latents_f32)


# Project the sampled flow latents back into the LLM's hidden-state space.
decoded_hidden = decode_latents_to_hidden(vae=vae, latents=sampled_latents)
print("Decoded hidden shape:", decoded_hidden.shape)


Decoded hidden shape: torch.Size([4, 4096])


In [None]:
import torch
import torch.nn.functional as F

def apply_repetition_penalty(logits: torch.Tensor, input_ids: torch.Tensor, penalty=None) -> torch.Tensor:
    """
    Penalize every token id that already appears in ``input_ids`` (CTRL-style).

    Positive logits are divided by ``penalty`` and negative logits are multiplied
    by it, so a repeated token always becomes *less* likely. Dividing a negative
    logit would instead raise it toward zero and make the repeat *more* likely.

    Args:
        logits: [B, V] next-token logits.
        input_ids: [B, L] token ids seen so far (prompt included).
        penalty: penalty strength; ``None`` or any value <= 1.0 disables it.

    Returns:
        The penalized [B, V] logits. ``logits`` is never modified in place; when
        the penalty is disabled the same tensor object is returned.
    """
    if penalty is None or penalty <= 1.0:
        return logits
    input_ids = input_ids.to(logits.device)
    seen = torch.gather(logits, 1, input_ids)
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    # Duplicate ids write identical values, so the scatter is well defined.
    return logits.scatter(1, input_ids, seen)


def filter_top_k_top_p(logits: torch.Tensor, top_k=None, top_p=None) -> torch.Tensor:
    """
    Mask logits outside the top-k set and/or the top-p nucleus with -inf.

    Args:
        logits: [B, V] logits (already temperature-scaled if desired).
        top_k: keep the k highest logits (ties with the k-th are kept); ``None``
            or <= 0 disables it. Values larger than V are clamped to V.
        top_p: keep the smallest prefix of tokens, by descending probability,
            whose cumulative mass reaches ``top_p``; active only when
            0 < top_p < 1. The most likely token is always kept.

    Returns:
        A new [B, V] tensor of filtered logits.
    """
    if top_k is not None and top_k > 0:
        k = min(int(top_k), logits.size(-1))
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]  # [B, 1]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Shift right by one so the token that crosses top_p is still kept.
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)

    return logits


def select_next_token(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
) -> torch.Tensor:
    """
    Pick the next token id: repetition penalty -> temperature -> top-k -> top-p -> sample.

    Logits are upcast to float32 first; softmax/cumsum over a large vocabulary
    in bf16 is too coarse for a reliable top-p cutoff.

    Args:
        logits: [B, V] next-token logits.
        input_ids: [B, L] tokens seen so far (used by the repetition penalty).
        temperature: ``None`` leaves logits unscaled; a value <= 0 selects the
            argmax (greedy decoding); otherwise logits are divided by it.
        top_k, top_p: see ``filter_top_k_top_p``.
        repetition_penalty: see ``apply_repetition_penalty``.

    Returns:
        [B, 1] tensor of selected token ids.
    """
    logits = apply_repetition_penalty(logits.float(), input_ids, repetition_penalty)
    if temperature is not None and temperature <= 0.0:
        return logits.argmax(dim=-1, keepdim=True)
    if temperature is not None:
        logits = logits / temperature
    logits = filter_top_k_top_p(logits, top_k, top_p)
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)


@torch.no_grad()
def run_sampling_loop(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 64,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    logit_bias=None,
    bias_steps: int = 0,
):
    """
    Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt``.

    Reuses the model's KV cache (``past_key_values``), so each step feeds only
    the newest token instead of re-encoding the whole sequence. If
    ``logit_bias`` (broadcastable to [1, V]) is given, it is added to the raw
    logits for the first ``bias_steps`` generated tokens, before the repetition
    penalty and the other filters. Stops early when the EOS token is sampled.

    Returns:
        Decoded text (prompt + continuation) with special tokens removed.
    """
    model.eval()
    enc = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = enc["input_ids"]  # [1, L]

    past_key_values = None
    step_input = input_ids  # the first forward pass consumes the whole prompt

    for step in range(max_new_tokens):
        outputs = model(input_ids=step_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :].float()  # [1, vocab_size]

        if logit_bias is not None and step < bias_steps:
            next_token_logits = next_token_logits + logit_bias.to(next_token_logits.device)

        next_token_id = select_next_token(
            next_token_logits,
            input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ).to(input_ids.device)  # [1, 1]

        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        step_input = next_token_id

        if next_token_id.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


@torch.no_grad()
def generate_baseline(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 64,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
):
    """
    Sample a continuation of ``prompt`` with no latent steering.

    Uses the same sampling pipeline as ``latent_logit_steered_generate``, so the
    two outputs differ only by the injected logit bias.

    Returns:
        Decoded text (prompt + continuation) with special tokens removed.
    """
    return run_sampling_loop(
        model,
        tokenizer,
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )


@torch.no_grad()
def latent_logit_steered_generate(
    model,
    tokenizer,
    prompt: str,
    injected_hidden_vector: torch.Tensor,
    max_new_tokens: int = 64,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    logit_alpha: float = 0.3,
    inject_steps: int = 12,
    normalize_latent_logits: bool = False,
):
    """
    Sample from ``model`` while biasing early tokens toward a latent hidden vector.

    The steering bias is ``latent_logits = lm_head(injected_hidden_vector)``,
    mean-centered so it acts as a relative preference. For the first
    ``inject_steps`` tokens:

        final_logits = base_logits + logit_alpha * latent_logits

    Repetition penalty, temperature, top-k and top-p are then applied exactly
    as in ``generate_baseline``.

    Args:
        injected_hidden_vector: hidden vector of shape [hidden_size] or
            [1, hidden_size].
        logit_alpha: steering strength; 0 disables steering.
        inject_steps: number of leading generated tokens that receive the bias.
        normalize_latent_logits: if True, also divide the centered latent
            logits by their standard deviation, so ``logit_alpha`` is in units
            of standard deviations regardless of the vector's norm.

    Returns:
        Decoded text (prompt + continuation) with special tokens removed.
    """
    injected_vec = injected_hidden_vector.to(device).to(model.dtype).reshape(-1)  # [hidden_size]
    latent_logits = model.lm_head(injected_vec).float()  # [vocab_size]
    # center logits to make it a relative bias
    latent_logits = latent_logits - latent_logits.mean()
    if normalize_latent_logits:
        latent_logits = latent_logits / latent_logits.std().clamp_min(1e-6)
    latent_logits = latent_logits.unsqueeze(0)  # [1, vocab_size]

    logit_bias = logit_alpha * latent_logits if logit_alpha != 0.0 else None

    return run_sampling_loop(
        model,
        tokenizer,
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        logit_bias=logit_bias,
        bias_steps=inject_steps,
    )


# ========= Usage =========
# Both runs are re-seeded with the same value so they share sampling
# randomness; differences between the outputs then reflect the latent steering
# rather than different random draws.
GENERATION_SEED = 42

# NOTE(review): the prompt is passed raw (no tokenizer.apply_chat_template);
# an instruct model may follow it better in its chat format — worth trying.
prompt = "Explain why climate change is real in a concise, fact-based way without flattery."

# pick one of your latent-derived hidden vectors
injected_vec = decoded_hidden[0]  # [hidden_size]

# Shared sampling settings so the two runs differ only in the steering.
sampling_kwargs = dict(
    max_new_tokens=128,
    temperature=0.9,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1,
)

print("=== Baseline (no latent) ===")
torch.manual_seed(GENERATION_SEED)
baseline_out = generate_baseline(base_model, tokenizer, prompt, **sampling_kwargs)
print(baseline_out)

print("\n=== Latent-steered (logit injection) ===")
torch.manual_seed(GENERATION_SEED)
latent_out = latent_logit_steered_generate(
    base_model,
    tokenizer,
    prompt,
    injected_vec,
    logit_alpha=0.3,
    inject_steps=12,   # steer only the first 12 generated tokens
    **sampling_kwargs,
)
print(latent_out)


=== Baseline (no latent) ===
Explain why climate change is real in a concise, fact-based way without flattery. Climate change is a long-term alteration of temperature and typical weather patterns in a place. The planet's climate has been changing throughout history, but the current trend is especially concerning because human activities, such as burning fossil fuels for energy, deforestation, and agriculture, have significantly increased greenhouse gas concentrations in the Earth's atmosphere since the Industrial Revolution. These gases trap heat from the sun, leading to rising global temperatures, melting glaciers and ice caps, more frequent extreme weather events, sea level rise, and other adverse effects. Over 97% of climate scientists agree that

=== Latent-steered (logit injection) ===
Explain why climate change is real in a concise, fact-based way without flattery. and A  and O and D and and and and and and and and

Clclimate change is a long-term trend caused by human activities