In [4]:
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from src.transformer import SANETokenAutoencoderWithRotation
from src.dataloader import SimpleTokenDataset

In [5]:
# 2) Path to the saved checkpoint
checkpoint_path = "checkpoints/sane_asteroid6_with_rotation_final.pt"
if not os.path.isfile(checkpoint_path):
    # Fail early with a clear message instead of a less obvious error inside torch.load.
    raise FileNotFoundError(f"Checkpoint {checkpoint_path} nicht gefunden.")

# 3) Model instance with exactly the same hyperparameters as the saved checkpoint;
#    load_state_dict below is strict by default and rejects mismatching keys/shapes.
model = SANETokenAutoencoderWithRotation(
    token_dim=2,
    d_model=64,
    nhead=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    level_embed_dim=16,
    num_rot_classes=6,
    rot_hidden=32
)

# 4) Load the weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=True restricts unpickling to tensors and plain containers. The file
# is fed straight into load_state_dict, so nothing else is needed; this avoids
# executing arbitrary pickled code and silences torch.load's FutureWarning.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode for the evaluation that follows
print("Checkpoint geladen und Modell im Eval‐Modus.")

# 5) Build the dataset & DataLoader for evaluation
#    (SimpleTokenDataset has to be instantiated so that it finds the same
#     6 asteroid6 files that were used for training.)
token_dir = "prepared_objects_first_4_levels"
model_ids = [
    "asteroid6__base_000_000_000__checkpoints__final",
    "asteroid6__compound_090_000_090__checkpoints__final",
    "asteroid6__x_180_000_000__checkpoints__final",
    "asteroid6__y_000_180_000__checkpoints__final",
    "asteroid6__z_000_000_120__checkpoints__final",
    "asteroid6__z_000_000_240__checkpoints__final"
]
dataset = SimpleTokenDataset(
    token_dir=token_dir,
    model_ids=model_ids,
    window_size=256,
    augment=False  # no augmentation needed when evaluating
)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,  # order does not matter for an aggregate accuracy
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; on a CPU-only run it is
    # useless and makes DataLoader emit a warning, so enable it only for CUDA.
    pin_memory=(device.type == "cuda")
)

# 6) Short evaluation loop (here only the accuracy of the rotation head)
correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    for batch in dataloader:
        tokens   = batch["tokens"].to(device)
        abs_norm = batch["abs_norm"].to(device)
        p_norm   = batch["p_norm"].to(device)
        levels   = batch["levels"].to(device)
        labels   = batch["label_rot"].to(device)

        # Forward pass: only the rotation-logits output is of interest here;
        # the first return value is discarded.
        _, logits_rot = model(tokens, abs_norm, p_norm, levels)  # [B, 6]
        preds = logits_rot.argmax(dim=1)                         # [B]
        correct += (preds == labels).sum().item()
        total += labels.size(0)

# An empty dataset would otherwise surface as a bare ZeroDivisionError below.
if total == 0:
    raise ValueError("Keine Samples im DataLoader; Rotation-Accuracy kann nicht berechnet werden.")

print(f"Rotation‐Accuracy: {100 * correct / total:.2f}%")


Checkpoint geladen und Modell im Eval‐Modus.


  model.load_state_dict(torch.load(checkpoint_path, map_location=device))
  tokens = torch.load(tpath).float()        # [N_i, 2]
  positions_raw = torch.load(ppath).float() # [N_i, 3]


Rotation‐Accuracy: 99.96%
