In [12]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm

import sys
sys.path.append('.')

from src.dataloader import SimpleTokenDataset
from src.transformer   import SANETokenAutoencoderWithRotation 

# NOTE(review): presumably intended to switch off flash attention; confirm that
# PyTorch actually reads this variable, and that setting it after `import torch`
# is early enough.
os.environ["PYTORCH_ENABLE_FLASH_ATTN"] = "0"

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.benchmark     = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [13]:
# Location of the trained weights; fail early with a clear message if the file is missing.
checkpoint_path = "checkpoints/9ObjectsModel.pt"
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' nicht gefunden.")

# Build the architecture. These hyperparameters have to match the ones the
# checkpoint was trained with, otherwise load_state_dict() below will reject
# the weights (missing/unexpected keys or shape mismatches).
model = SANETokenAutoencoderWithRotation(
    token_dim=2,
    d_model=64,
    nhead=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    level_embed_dim=16,
    num_rot_classes=6,
    rot_hidden=32
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to inference mode (disables dropout). As the last expression of the
# cell this also displays the module structure.
model.eval()


  state_dict = torch.load(checkpoint_path, map_location=device)


SANETokenAutoencoderWithRotation(
  (token_embed): Linear(in_features=2, out_features=64, bias=True)
  (level_emb): Embedding(5, 16)
  (pos_proj): Linear(in_features=18, out_features=64, bias=True)
  (encoder_blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=256, out_features=64, bias=True)
      )
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=256, out_features=2

In [14]:
import torch

def compute_latent_from_batch(model, tokens, abs_norm, p_norm, levels):
    """
    Run only the encoder path of the autoencoder and return one pooled latent
    vector per window.

    The function uses these submodules of `model`: `token_embed`, `level_emb`,
    `pos_proj` and every block in `encoder_blocks`.

    Args (all tensors are expected to live on the same device as `model`;
    B = batch size, W = window length):
      model:    module exposing the submodules listed above
      tokens:   FloatTensor [B, W, 2]
      abs_norm: FloatTensor [B, W, 1]
      p_norm:   FloatTensor [B, W, 1]
      levels:   LongTensor  [B, W]

    Returns:
      FloatTensor [B, d_model]: encoder output averaged over the window axis
      (d_model is 64 for the model configured in this notebook).
    """
    # Content part: project the raw tokens into the model dimension.
    content = model.token_embed(tokens)

    # Positional part: both normalised position features plus the learned
    # level embedding, concatenated on the feature axis and projected.
    level_vectors = model.level_emb(levels)
    positional_features = torch.cat((abs_norm, p_norm, level_vectors), dim=-1)
    positional = model.pos_proj(positional_features)

    # Feed the summed embeddings through the encoder stack, block by block.
    hidden = content + positional
    for block in model.encoder_blocks:
        hidden = block(hidden)

    # Collapse the window axis so each window yields a single vector.
    return hidden.mean(dim=1)


In [20]:
# The single object whose windows are encoded in the following cells.
single_model_id = "MainCharacterTPOSE__base_000_000_000__checkpoints__final"

# Build a dataset that is given exactly this one model_id.
dataset_obj = SimpleTokenDataset(
    token_dir="prepared_objects_first_4_levels",  # adjust the path if needed
    model_ids=[single_model_id],
    window_size=256,
    augment=False
)

# shuffle=False keeps the windows in dataset order. Pinned host memory only
# helps host-to-GPU copies, so it is requested only when running on CUDA;
# on a CPU-only machine this avoids the DataLoader warning about it.
loader_obj = DataLoader(
    dataset_obj,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=(device.type == "cuda")
)

print(f"Anzahl Fenster für {single_model_id}: {len(dataset_obj)}")


Anzahl Fenster für music_speaker__z_000_000_240__checkpoints__final: 125313


In [21]:
# Per-batch latents are gathered here first and merged after the loop.
all_latents = []
progress = tqdm(loader_obj, desc=f"Latents sammeln für {single_model_id}", unit="Batch")

# Pure inference: no gradients are needed.
with torch.no_grad():
    for batch in progress:
        # Move the four inputs of the encoder onto the model's device.
        tokens, abs_norm, p_norm, levels = (
            batch[key].to(device)
            for key in ("tokens", "abs_norm", "p_norm", "levels")
        )

        z_batch = compute_latent_from_batch(model, tokens, abs_norm, p_norm, levels)
        # Keep results on the CPU so they do not accumulate in device memory.
        all_latents.append(z_batch.cpu())

# Stack the per-batch results along the window axis: one row per window.
all_latents = torch.cat(all_latents, dim=0)
print(f"Shape aller Latents: {all_latents.shape}")


Latents sammeln für music_speaker__z_000_000_240__checkpoints__final: 100%|██████████| 15665/15665 [00:52<00:00, 299.01Batch/s]


Shape aller Latents: torch.Size([125313, 64])


In [25]:
# View the latent tensor as a NumPy array (one row per window).
latents_np = all_latents.numpy()

# One named column per latent dimension: latent_0, latent_1, ...
columns = ["latent_{}".format(dim) for dim in range(latents_np.shape[1])]
df_latents = pd.DataFrame(data=latents_np, columns=columns)

# Display the frame's dimensions as (rows, columns).
df_latents.shape


(125313, 64)

In [23]:
# Average each latent column over all windows: a Series with one value per column.
global_latent = df_latents.mean(axis="index")
# Show the averaged vector as a plain NumPy array.
global_latent.to_numpy()


array([-0.23162867,  1.2482456 ,  0.19822937, -1.9758267 , -2.8409946 ,
       -0.37627804,  0.26451525,  0.7105661 , -0.6235712 ,  1.2104625 ,
       -0.14439674, -0.99199474,  0.11293446, -0.1376312 ,  1.8751167 ,
        0.34313974,  0.9404803 , -0.8048887 , -1.8447216 ,  0.81312925,
        2.5209963 , -0.36012328,  0.47635847,  2.8187659 , -0.4012906 ,
       -0.20478582,  1.0413792 , -1.9827175 , -2.4068227 ,  0.22127733,
        1.6317469 ,  2.0718517 , -1.4016395 , -0.8116531 , -0.17636248,
       -0.46374115, -0.40898362,  1.2023895 , -1.0967721 , -0.12835597,
        0.04417903, -0.43669647,  1.2313483 , -0.7766703 ,  1.2298833 ,
        0.32547882,  0.51655334,  0.23126729,  0.62396985,  0.99789757,
       -1.0175884 ,  0.71796644,  1.2664471 , -2.8725252 , -0.4628302 ,
        1.5604985 ,  0.99411625,  0.25451323, -1.5814369 ,  0.27120447,
        1.5933465 , -2.27768   , -0.6604226 ,  0.4256525 ], dtype=float32)

In [26]:
# Convert the averaged latent (pandas Series) into a float32 tensor.
z_np = global_latent.values
z_tensor = torch.from_numpy(z_np).to(torch.float32)

# Add a leading batch axis and move the input onto the working device.
z_batch = z_tensor[None, :]
z_batch = z_batch.to(device)

# The rotation head of the loaded model is used as the classifier.
classifier = model.rotation_head.to(device)

# Forward pass without gradient tracking; the arg-max over the last axis
# selects the predicted class index.
with torch.no_grad():
    logits = classifier(z_batch)
    pred   = logits.argmax(dim=-1)

# Extract the single prediction as a Python int and report it.
pred_class = pred.item()
print(f"Vorhergesagte Rotations-Klasse: {pred_class}")

Vorhergesagte Rotations-Klasse: 2
