In [24]:
from pathlib import Path
import copy

import torch
from torchviz import make_dot

from ssl_brainmet.models.build_nnunet_model import build_model_from_plans
from ssl_brainmet.utils import get_device

In [None]:
# weights_path = "/home/vincent/repos/ssl-bm/weights/cnn3d_nnunet_local_global_100ep_checkpoint.pth"
weights_path = "/home/vincent/repos/ssl-bm/weights/cnn3d_nnunet_local_global_1ep_checkpoint.pth"

In [3]:
device = get_device(0)

Using cuda:0


In [4]:
project_dir = Path(".").resolve().parents[0]

model = build_model_from_plans(
    project_dir / "ssl_brainmet/config/nnUNetPlans.json",
    project_dir / "ssl_brainmet/config/dataset.json",
    configuration="3d_fullres",
    deep_supervision=True,
).to(device)

In [5]:
initial_state_dict = copy.deepcopy(model.state_dict())

In [6]:
state_dict = torch.load(weights_path, map_location=device, weights_only=False)

In [7]:
state_dict.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scaler_state_dict', 'train_losses', 'train_sim_pos', 'train_sim_neg', 'train_sim_pos_d', 'train_sim_neg_d', 'val_accuracies', 'f1_scores', 'mf1_scores', 'mae_scores', 'r2_scores'])

In [8]:
model.decoder.seg_layers

ModuleList(
  (0): Conv3d(320, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (1): Conv3d(256, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (2): Conv3d(128, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (3): Conv3d(64, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (4): Conv3d(32, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)

In [9]:
state_dict["model_state_dict"].keys()

odict_keys(['queue', 'queue_ptr', 'local_queue', 'local_queue_ptr', 'encoder_q.encoder.stages.0.0.convs.0.conv.weight', 'encoder_q.encoder.stages.0.0.convs.0.conv.bias', 'encoder_q.encoder.stages.0.0.convs.0.norm.weight', 'encoder_q.encoder.stages.0.0.convs.0.norm.bias', 'encoder_q.encoder.stages.0.0.convs.0.all_modules.0.weight', 'encoder_q.encoder.stages.0.0.convs.0.all_modules.0.bias', 'encoder_q.encoder.stages.0.0.convs.0.all_modules.1.weight', 'encoder_q.encoder.stages.0.0.convs.0.all_modules.1.bias', 'encoder_q.encoder.stages.0.0.convs.1.conv.weight', 'encoder_q.encoder.stages.0.0.convs.1.conv.bias', 'encoder_q.encoder.stages.0.0.convs.1.norm.weight', 'encoder_q.encoder.stages.0.0.convs.1.norm.bias', 'encoder_q.encoder.stages.0.0.convs.1.all_modules.0.weight', 'encoder_q.encoder.stages.0.0.convs.1.all_modules.0.bias', 'encoder_q.encoder.stages.0.0.convs.1.all_modules.1.weight', 'encoder_q.encoder.stages.0.0.convs.1.all_modules.1.bias', 'encoder_q.encoder.stages.1.0.convs.0.conv.w

In [11]:
state_dict_encoder_q = {key[len("encoder_q."):]:item for key, item in state_dict["model_state_dict"].items() if key.startswith("encoder_q.")}
state_dict_encoder_q = {key[len("model."):]:item for key, item in state_dict_encoder_q.items() if key.startswith("model.")}
state_dict_encoder_q = {k: v for k, v in state_dict_encoder_q.items() 
                       if not k.startswith("decoder.seg_layers")}

In [12]:
state_dict_encoder_q.keys()

dict_keys(['encoder.stages.0.0.convs.0.conv.weight', 'encoder.stages.0.0.convs.0.conv.bias', 'encoder.stages.0.0.convs.0.norm.weight', 'encoder.stages.0.0.convs.0.norm.bias', 'encoder.stages.0.0.convs.0.all_modules.0.weight', 'encoder.stages.0.0.convs.0.all_modules.0.bias', 'encoder.stages.0.0.convs.0.all_modules.1.weight', 'encoder.stages.0.0.convs.0.all_modules.1.bias', 'encoder.stages.0.0.convs.1.conv.weight', 'encoder.stages.0.0.convs.1.conv.bias', 'encoder.stages.0.0.convs.1.norm.weight', 'encoder.stages.0.0.convs.1.norm.bias', 'encoder.stages.0.0.convs.1.all_modules.0.weight', 'encoder.stages.0.0.convs.1.all_modules.0.bias', 'encoder.stages.0.0.convs.1.all_modules.1.weight', 'encoder.stages.0.0.convs.1.all_modules.1.bias', 'encoder.stages.1.0.convs.0.conv.weight', 'encoder.stages.1.0.convs.0.conv.bias', 'encoder.stages.1.0.convs.0.norm.weight', 'encoder.stages.1.0.convs.0.norm.bias', 'encoder.stages.1.0.convs.0.all_modules.0.weight', 'encoder.stages.1.0.convs.0.all_modules.0.bias

In [13]:
load_info = model.load_state_dict(state_dict_encoder_q, strict=False)
print("Missing keys:", load_info.missing_keys)
print("Unexpected keys:", load_info.unexpected_keys)

Missing keys: ['decoder.seg_layers.0.weight', 'decoder.seg_layers.0.bias', 'decoder.seg_layers.1.weight', 'decoder.seg_layers.1.bias', 'decoder.seg_layers.2.weight', 'decoder.seg_layers.2.bias', 'decoder.seg_layers.3.weight', 'decoder.seg_layers.3.bias', 'decoder.seg_layers.4.weight', 'decoder.seg_layers.4.bias']
Unexpected keys: []


In [14]:
encoder_params = {k: v for k, v in model.encoder.state_dict().items()}
decoder_encoder_params = {k: v for k, v in model.decoder.encoder.state_dict().items()}

for key in encoder_params:
    assert torch.equal(encoder_params[key], decoder_encoder_params[key]), f"Mismatch in {key}"
print("Encoder weights are consistent.")


Encoder weights are consistent.


In [15]:
# Compare the weights after loading
for name, param in model.state_dict().items():
    # Compute the difference between the new and initial parameters
    diff = torch.norm(param - initial_state_dict[name])
    # A small tolerance is used to account for floating point differences
    if diff > 1e-6:
        print(f"{name} has changed (difference norm: {diff.item()})")
    else:
        print(f"{name} remains unchanged.")

encoder.stages.0.0.convs.0.conv.weight has changed (difference norm: 11.772833824157715)
encoder.stages.0.0.convs.0.conv.bias has changed (difference norm: 0.3675675094127655)
encoder.stages.0.0.convs.0.norm.weight has changed (difference norm: 0.772316575050354)
encoder.stages.0.0.convs.0.norm.bias has changed (difference norm: 1.0705313682556152)
encoder.stages.0.0.convs.0.all_modules.0.weight has changed (difference norm: 11.772833824157715)
encoder.stages.0.0.convs.0.all_modules.0.bias has changed (difference norm: 0.3675675094127655)
encoder.stages.0.0.convs.0.all_modules.1.weight has changed (difference norm: 0.772316575050354)
encoder.stages.0.0.convs.0.all_modules.1.bias has changed (difference norm: 1.0705313682556152)
encoder.stages.0.0.convs.1.conv.weight has changed (difference norm: 17.97256851196289)
encoder.stages.0.0.convs.1.conv.bias has changed (difference norm: 0.4190920293331146)
encoder.stages.0.0.convs.1.norm.weight has changed (difference norm: 0.7871344685554504

In [23]:
# model.state_dict()["decoder.stages.4.convs.0.conv.weight"]
initial_state_dict["encoder.stages.0.0.convs.0.norm.weight"]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       device='cuda:0')

In [27]:
x = torch.rand((1, 1, 64, 64, 64), device=device)
y = model(x)[0]

make_dot(y, params=dict(model.named_parameters())).render("model", format="pdf")

'model.pdf'

In [26]:
y

[tensor([[[[[-5.7778e-01, -8.2275e-01, -7.3934e-01,  ..., -6.8556e-01,
             -4.8030e-01, -7.3371e-01],
            [-7.0294e-01,  9.9052e-01, -1.0010e+00,  ...,  4.9746e-01,
             -3.4970e-02, -6.4559e-01],
            [-6.7027e-01, -2.0457e+00, -1.6103e+00,  ..., -9.8429e-01,
             -4.1108e-01, -8.1526e-01],
            ...,
            [ 7.1349e-01,  5.8210e-01, -5.7201e-01,  ...,  1.9105e-01,
             -6.3197e-01, -7.4499e-01],
            [-6.7234e-02, -5.3797e-01, -6.9846e-01,  ..., -5.1967e-01,
             -7.4954e-01, -6.2250e-01],
            [ 4.8861e-01,  2.7247e-01,  1.9501e-01,  ...,  1.3229e-01,
              9.8074e-02, -3.5016e-01]],
 
           [[ 8.4821e-01, -5.5637e-01, -4.1107e-01,  ..., -1.1152e+00,
             -9.9949e-01, -1.3037e+00],
            [-6.8005e-01, -1.0728e+00, -1.6292e-01,  ..., -1.2617e+00,
             -6.7991e-01, -7.5227e-01],
            [-8.1397e-02,  5.0459e-01,  4.4954e-02,  ..., -7.1294e-01,
             -9.6325e