In [None]:
import copy
import os
from pathlib import Path

import torch
from torchviz import make_dot

from ssl_brainmet.models.build_nnunet_model import build_model_from_plans
from ssl_brainmet.utils import get_device

In [None]:
# Checkpoint to inspect. Set the SSL_BM_WEIGHTS environment variable to point at a
# different file (e.g. cnn3d_nnunet_local_global_1ep_checkpoint.pth) instead of
# editing this machine-specific default.
DEFAULT_WEIGHTS_PATH = "/home/vincent/repos/ssl-bm/weights/cnn3d_nnunet_local_global_100ep_checkpoint.pth"
weights_path = os.environ.get("SSL_BM_WEIGHTS", DEFAULT_WEIGHTS_PATH)

In [None]:
# Device index handed to the project helper `get_device`.
# NOTE(review): presumably selects CUDA device 0 (with some fallback) — confirm in ssl_brainmet.utils.
DEVICE_INDEX = 0
device = get_device(DEVICE_INDEX)

In [None]:
# Repository root, taken as the parent of the notebook's working directory
# (assumes the notebook is launched from a direct subfolder of the repo — TODO confirm).
project_dir = Path.cwd().resolve().parents[0]
config_dir = project_dir / "ssl_brainmet" / "config"

# Build the "3d_fullres" nnU-Net described by the plans/dataset JSON files.
# deep_supervision=True: the forward pass is indexed with [0] later in this
# notebook, so it presumably returns one output per decoder resolution — confirm.
model = build_model_from_plans(
    config_dir / "nnUNetPlans.json",
    config_dir / "dataset.json",
    configuration="3d_fullres",
    deep_supervision=True,
)
model = model.to(device)

In [None]:
# Snapshot of the freshly initialised weights, taken before the checkpoint is
# loaded, so a later cell can tell which tensors the checkpoint overwrote.
# A deep copy is needed: in PyTorch the tensors returned by state_dict() alias
# the live parameters, which load_state_dict() later updates in place.
_live_state = model.state_dict()
initial_state_dict = copy.deepcopy(_live_state)
del _live_state

In [None]:
# Load the training checkpoint onto `device`. Despite the name, this is a full
# checkpoint dict, not a bare state dict: the weights are read from its
# "model_state_dict" entry further down.
# NOTE(review): weights_only=False unpickles arbitrary Python objects and can run
# code embedded in the file — only load checkpoints you trust. If the checkpoint
# holds nothing but tensors/containers, prefer weights_only=True.
state_dict = torch.load(
    weights_path,
    weights_only=False,
    map_location=device,
)

In [None]:
# Inspect the top-level entries of the loaded checkpoint.
state_dict.keys()

In [None]:
# Inspect the decoder's segmentation layers; their weights are filtered out of
# the checkpoint before it is loaded into `model`.
model.decoder.seg_layers

In [None]:
# List the weight names stored in the checkpoint; the transfer below keeps only
# those prefixed with "encoder_q." followed by "model.".
state_dict["model_state_dict"].keys()

In [None]:
# Keep only the "encoder_q.model.*" branch of the checkpoint, re-keyed (prefix
# stripped) to match this notebook's `model`. Entries under
# "decoder.seg_layers" are dropped so those layers keep their fresh
# initialisation (NOTE(review): presumably because the output heads differ from
# the pretraining setup — confirm).
_ENCODER_Q_PREFIX = "encoder_q." + "model."
_SKIPPED_PREFIX = _ENCODER_Q_PREFIX + "decoder.seg_layers"
state_dict_encoder_q = {
    full_key[len(_ENCODER_Q_PREFIX):]: tensor
    for full_key, tensor in state_dict["model_state_dict"].items()
    if full_key.startswith(_ENCODER_Q_PREFIX) and not full_key.startswith(_SKIPPED_PREFIX)
}

In [None]:
# Weight names that the next cell loads into `model`.
state_dict_encoder_q.keys()

In [None]:
# Load the transferred weights. strict=False only tolerates the
# "decoder.seg_layers" weights that were deliberately left out of
# `state_dict_encoder_q`. Any other mismatch (wrong checkpoint, wrong key prefix)
# is treated as an error, so the network cannot silently stay randomly initialised.
load_info = model.load_state_dict(state_dict_encoder_q, strict=False)
print("Missing keys:", load_info.missing_keys)
print("Unexpected keys:", load_info.unexpected_keys)

unexpected_missing = [
    key for key in load_info.missing_keys if not key.startswith("decoder.seg_layers")
]
assert not unexpected_missing, f"Checkpoint provides no weights for: {unexpected_missing}"
assert not load_info.unexpected_keys, (
    f"Checkpoint weights not used by the model: {load_info.unexpected_keys}"
)

In [None]:
# Sanity check: the model exposes its encoder both as `model.encoder` and as
# `model.decoder.encoder`; after loading, both must hold identical weights.
# Key sets are compared first, so a tensor present on only one side is reported
# clearly instead of being skipped or surfacing as a bare KeyError.
encoder_params = dict(model.encoder.state_dict())
decoder_encoder_params = dict(model.decoder.encoder.state_dict())

assert encoder_params.keys() == decoder_encoder_params.keys(), (
    "Encoder key mismatch: "
    f"only in encoder={sorted(encoder_params.keys() - decoder_encoder_params.keys())}, "
    f"only in decoder.encoder={sorted(decoder_encoder_params.keys() - encoder_params.keys())}"
)
for key, encoder_tensor in encoder_params.items():
    assert torch.equal(encoder_tensor, decoder_encoder_params[key]), f"Mismatch in {key}"
print("Encoder weights are consistent.")


In [None]:
# Compare every tensor in the model against the pre-load snapshot
# `initial_state_dict` to see which ones the checkpoint actually changed.
WEIGHT_CHANGE_TOL = 1e-6  # absolute tolerance on the L2 norm of the difference

current_state = model.state_dict()
changed_names = []
for name, param in current_state.items():
    # Cast to float: torch.norm rejects integer/bool tensors (e.g. counter
    # buffers), which would otherwise abort the whole report.
    diff = torch.norm(param.float() - initial_state_dict[name].float())
    if diff > WEIGHT_CHANGE_TOL:
        changed_names.append(name)
        print(f"{name} has changed (difference norm: {diff.item()})")
    else:
        print(f"{name} remains unchanged.")
print(f"{len(changed_names)} of {len(current_state)} tensors changed.")

In [None]:
# Spot-check a single tensor before (snapshot) and after loading the checkpoint.
# Change probe_key to inspect another tensor, e.g. "decoder.stages.4.convs.0.conv.weight".
probe_key = "encoder.stages.0.0.convs.0.norm.weight"
initial_state_dict[probe_key], model.state_dict()[probe_key]

In [None]:
# Run one forward pass on a random dummy volume and render its autograd graph to
# "model.pdf" in the working directory. The 64^3 size is assumed to be compatible
# with the network's downsampling — confirm against the plans.
x = torch.rand(1, 1, 64, 64, 64, device=device)
outputs = model(x)
y = outputs[0]  # first output; assumed to be the full-resolution prediction — confirm

graph = make_dot(y, params=dict(model.named_parameters()))
graph.render("model", format="pdf")