In [2]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint holds a fully pickled model object (not a bare state dict), so it
# has to be loaded with weights_only=False. Passing it explicitly silences the
# FutureWarning and keeps this cell working on PyTorch versions where the default
# is weights_only=True.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints that
# come from a trusted source.
model = torch.load(
    'logs/best_checkpoints/SWaT_parameters.pt',
    map_location=device,
    weights_only=False,
)

# Some GELU modules in the unpickled model may lack the `approximate` attribute
# (presumably the checkpoint was pickled with a PyTorch version that predates it).
# Add it so that the module can be used and printed by the current version.
for m in model.modules():
    if isinstance(m, torch.nn.GELU) and not hasattr(m, 'approximate'):
        m.approximate = 'none'  # or 'tanh' if that was the setting used for training

# Last expression of the cell: moves the model to `device`, switches it to eval
# mode and displays its structure.
model.to(device).eval()

  model = torch.load('logs/best_checkpoints/SWaT_parameters.pt', map_location=device)


AnomalyTransformer(
  (linear_embedding): Linear(in_features=700, out_features=512, bias=True)
  (transformer_encoder): TransformerEncoder(
    (encoder_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention_layer): MultiHeadAttentionLayer(
          (word_fc_layers): ModuleList(
            (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
          )
          (output_fc_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward_layer): PositionWiseFeedForwardLayer(
          (first_fc_layer): Linear(in_features=512, out_features=2048, bias=True)
          (second_fc_layer): Linear(in_features=2048, out_features=512, bias=True)
          (activation_layer): GELU(approximate='none')
          (dropout_layer): Dropout(p=0.1, inplace=False)
        )
        (norm_layers): ModuleList(
          (0-1): 2 x LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        )
        (dropout_layer): Dropout(p=0.1, inplace=False)
 

In [5]:
import os

# Export only the weights (state dict) of the loaded model so that it can later be
# restored without unpickling the whole model object.
# torch.save does not create missing directories, so make sure `weights/` exists.
os.makedirs("weights", exist_ok=True)
torch.save(model.state_dict(), "weights/SWaT_state_dict.pt")

In [None]:
import math
import torch
from collections import OrderedDict
from models.anomaly_transformer import get_anomaly_transformer

def build_from_state_dict_path(state_path, device=None):
    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sd = torch.load(state_path, map_location='cpu',weights_only=True)
    if isinstance(sd, dict) and 'state_dict' in sd:
        sd = sd['state_dict']
    # retire 'module.' si besoin
    sd = OrderedDict((k.replace('module.', ''), v) for k,v in sd.items())

    # --- inférences depuis les shapes ---
    lin_w = sd['linear_embedding.weight']          # [d_embed, input_d_data*patch_size]
    d_embed = lin_w.shape[0]
    in_prod = lin_w.shape[1]

    mlp0_w = sd['mlp_layers.0.weight']             # [hidden_dim, d_embed]
    hidden_dim = mlp0_w.shape[0]

    mlp2_w = sd['mlp_layers.2.weight']             # [output_d_data*patch_size, hidden_dim]
    out_prod = mlp2_w.shape[0]

    # relative position embedding -> [2*L-1, n_head]
    rpe_key = next(k for k in sd.keys() if 'relative_position_embedding_table' in k)
    rpe = sd[rpe_key]
    max_seq_len = (rpe.shape[0] + 1) // 2
    n_head = rpe.shape[1]

    # n_layer
    layer_ids = set()
    for k in sd.keys():
        if k.startswith('transformer_encoder.encoder_layers.'):
            parts = k.split('.')
            if len(parts) > 3 and parts[3].isdigit():
                layer_ids.add(int(parts[3]))
    n_layer = (max(layer_ids) + 1) if layer_ids else 6

    # patch_size, input_d_data, output_d_data, type de sortie
    g = math.gcd(in_prod, out_prod)
    patch_size = g
    input_d_data = in_prod // patch_size
    output_d_data = out_prod // patch_size
    loss_type = 'bce' if output_d_data == 1 and out_prod == patch_size else 'reconstruction'

    print(f"[inféré] d_embed={d_embed}, hidden_dim={hidden_dim}, n_head={n_head}, "
          f"n_layer={n_layer}, max_seq_len={max_seq_len}")
    print(f"[inféré] patch_size={patch_size}, input_d_data={input_d_data}, "
          f"output_d_data={output_d_data} ({loss_type})")

    # --- reconstruction exacte ---
    model = get_anomaly_transformer(
        input_d_data=input_d_data,
        output_d_data=output_d_data,
        patch_size=patch_size,
        d_embed=d_embed,
        hidden_dim_rate=hidden_dim / d_embed,   # 4.0 normalement
        max_seq_len=max_seq_len,
        positional_encoding=None,
        relative_position_embedding=True,
        transformer_n_layer=n_layer,
        transformer_n_head=n_head,
        dropout=0.1
    )

    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("missing:", missing)
    print("unexpected:", unexpected)

    model.to(device).eval()
    return model

# --- utilisation ---
# Pick the GPU when one is available, otherwise stay on the CPU, then rebuild the
# model from the exported state dict.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = build_from_state_dict_path(state_path="weights/SWaT_state_dict.pt", device=device)

[inféré] d_embed=512, hidden_dim=2048, n_head=8, n_layer=6, max_seq_len=512
[inféré] patch_size=14, input_d_data=50, output_d_data=1 (bce)
missing: []
unexpected: []


  sd = torch.load(state_path, map_location='cpu')


In [7]:
# Smoke test: push a random batch through the rebuilt model and print the input
# and output shapes.
model.eval()

# Dummy batch geometry: (B, n_token, d_data) with n_token = max_seq_len * patch_size.
n_token = model.max_seq_len * model.patch_size
d_data = model.linear_embedding.in_features // model.patch_size

with torch.no_grad():
    # Create the input on the same device as the model's parameters.
    x = torch.rand((2, n_token, d_data), device=next(model.parameters()).device)
    y = model(x)

print(x.shape, "→", y.shape)

torch.Size([2, 7168, 50]) → torch.Size([2, 7168, 1])


In [None]:
# Write the weights of the rebuilt model back to disk under a separate file name.
torch.save(obj=model.state_dict(), f="weights/SWaT_state_dict_clean.pt")