In [13]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm

import sys
sys.path.append('.')

from src.dataloader import SimpleTokenDataset
from src.transformer   import SANETokenAutoencoderWithRotation 

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): presumably meant to switch off flash attention. It is set after
# `import torch`, so confirm that whatever reads this variable does so lazily.
os.environ.update({"PYTORCH_ENABLE_FLASH_ATTN": "0"})

# Disable cuDNN autotuning and request deterministic cuDNN kernels so that
# repeated runs of the notebook are comparable.
torch.backends.cudnn.benchmark     = False
torch.backends.cudnn.deterministic = True


In [14]:
# Trained weights of the autoencoder; the path is relative to the notebook's
# working directory.
checkpoint_path = "checkpoints/sane_asteroid6_with_rotation_final.pt"
if not os.path.isfile(checkpoint_path):
    # Fail early with a clear message instead of a less specific error from torch.load.
    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' nicht gefunden.")

# load_state_dict() below is strict by default, so these architecture arguments
# have to match the ones the checkpoint was trained with.
model = SANETokenAutoencoderWithRotation(
    token_dim=2,
    d_model=64,
    nhead=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    level_embed_dim=16,
    num_rot_classes=6,
    rot_hidden=32
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the file cannot execute arbitrary pickled code. The result is passed
# straight to load_state_dict(), i.e. the file is expected to hold a state dict.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to inference behaviour (e.g. dropout disabled) before encoding.
# NOTE(review): the recorded cell output shows a status message that this cell
# does not print; re-run the cell to refresh the stored output.
model.eval()


Modell und Checkpoint geladen. Encoder+Decoder im Eval-Modus.


  state_dict = torch.load(checkpoint_path, map_location=device)


In [15]:
import torch

def compute_latent_from_batch(model, tokens, abs_norm, p_norm, levels):
    """
    Run only the encoder path of the autoencoder and mean-pool the result.

    The function relies on four attributes of `model`: `token_embed`,
    `level_emb`, `pos_proj` and the iterable `encoder_blocks`. All inputs are
    expected to already live on the same device as the model.

    Shapes as used in this notebook (d_model=64, level_embed_dim=16):
      tokens:   FloatTensor [B, W, 2]
      abs_norm: FloatTensor [B, W, 1]
      p_norm:   FloatTensor [B, W, 1]
      levels:   LongTensor  [B, W]

    Returns:
      FloatTensor [B, 64], the encoder output averaged over the window axis W.
    """
    # Content part: embed the raw tokens.
    embedded_tokens = model.token_embed(tokens)

    # Positional part: both normalised position features plus the level
    # embedding are concatenated along the feature axis and projected.
    positional_features = torch.cat(
        (abs_norm, p_norm, model.level_emb(levels)),
        dim=-1,
    )
    hidden = embedded_tokens + model.pos_proj(positional_features)

    # Pass the summed embedding through every encoder block in order.
    for encoder_block in model.encoder_blocks:
        hidden = encoder_block(hidden)

    # Collapse the window axis (dim 1) into a single vector per sample.
    return hidden.mean(dim=1)


In [16]:
# The one model whose token windows are encoded in this notebook.
single_model_id = "asteroid6__x_180_000_000__checkpoints__final"

# Build a dataset restricted to exactly this model_id.
dataset_obj = SimpleTokenDataset(
    token_dir="prepared_objects_first_4_levels",  # adjust to the local data directory
    model_ids=[single_model_id],
    window_size=256,
    augment=False
)

loader_obj = DataLoader(
    dataset_obj,
    batch_size=8,
    shuffle=False,  # keep the windows in dataset order
    num_workers=4,
    # Pinned host memory only speeds up host-to-GPU copies. The notebook also
    # supports a CPU fallback, where requesting it is pointless and makes the
    # DataLoader emit a warning, so enable it only when CUDA is available.
    pin_memory=torch.cuda.is_available()
)

print(f"Anzahl Fenster für {single_model_id}: {len(dataset_obj)}")


Anzahl Fenster für asteroid6__x_180_000_000__checkpoints__final: 125313


  tokens = torch.load(tpath).float()        # [N_i, 2]
  positions_raw = torch.load(ppath).float() # [N_i, 3]


In [17]:
# Per-batch latent tensors, collected on the CPU and concatenated after the loop.
latent_batches = []

# Gradients are not needed for encoding, so autograd tracking is switched off.
with torch.no_grad():
    progress = tqdm(loader_obj, desc=f"Latents sammeln für {single_model_id}", unit="Batch")
    for batch in progress:
        # Move the four model inputs to the compute device, in the order the
        # encoder helper expects them.
        tokens, abs_norm, p_norm, levels = (
            batch[key].to(device)
            for key in ("tokens", "abs_norm", "p_norm", "levels")
        )

        z_batch = compute_latent_from_batch(model, tokens, abs_norm, p_norm, levels)
        # Copy each result back to the CPU before storing it.
        latent_batches.append(z_batch.cpu())

# Stack the batches along the first axis: one row per window.
all_latents = torch.cat(latent_batches, dim=0)
print(f"Shape aller Latents: {all_latents.shape}")


Latents sammeln für asteroid6__x_180_000_000__checkpoints__final: 100%|██████████| 15665/15665 [00:52<00:00, 296.90Batch/s]

Shape aller Latents: torch.Size([125313, 64])





In [22]:
# View the collected latents as a NumPy array (one row per window).
latents_np = all_latents.numpy()

# One column per latent dimension, named latent_0, latent_1, ...
columns = [f"latent_{dim}" for dim in range(latents_np.shape[1])]
df_latents = pd.DataFrame(data=latents_np, columns=columns)

# Display (rows, columns) as a quick sanity check of the frame.
df_latents.shape


(125313, 64)

In [21]:
# Average every latent column over all windows to get a single global latent
# vector for the model; the result is a Series with one entry per column.
global_latent = df_latents.mean(axis="index")
global_latent.to_numpy()


array([-1.0787305 ,  0.86425096,  0.24380201, -0.19859868, -0.63663256,
        0.5050171 ,  0.228165  , -0.02080427,  0.95969355, -0.8193425 ,
       -1.5879287 ,  1.0069091 , -0.00849972,  1.8441646 ,  0.20862353,
       -1.1891987 , -0.23626281, -2.6892693 ,  1.4019604 ,  0.1969252 ,
        0.9531218 ,  1.4496051 , -0.692852  ,  1.3389231 ,  0.09656914,
        0.55910605,  0.4791876 , -0.5994338 , -1.4496104 , -0.423876  ,
        0.15592796, -1.3908839 ,  1.4891913 , -0.70935416,  0.2986975 ,
       -0.20084505, -2.055918  ,  0.08354307, -0.64897424, -0.6869845 ,
        1.1565031 ,  0.76106966, -1.4867622 ,  1.2699127 ,  0.6818915 ,
        2.135164  ,  0.05533303,  1.8550663 , -0.65478253, -1.012297  ,
       -0.88713187, -1.7164764 ,  1.4929034 ,  0.01869997, -1.3419614 ,
       -0.13675399, -3.0545068 ,  0.556637  ,  0.3505878 ,  1.5049071 ,
        0.27042785, -0.85351056,  0.5527474 , -1.6015583 ], dtype=float32)