# Explore model once it has been trained

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from network.models.architectures.CIL_multiview.CIL_multiview import CIL_multiview
import os
from configs import g_conf, set_type_of_process, merge_with_yaml

exp_batch = 'CILv2'
exp_name = 'CILv2_3cam_vit_Town01Full'

In [2]:
merge_with_yaml(os.path.join('configs', exp_batch, f'{exp_name}.yaml'))
g_conf.PROCESS_NAME = 'train_val'
g_conf.DATASET_PATH = '/datatmp/Datasets/yixiao/CARLA'

os.environ['DATASET_PATH'] = g_conf.DATASET_PATH

model = CIL_multiview(g_conf.MODEL_CONFIGURATION)

Loading pretrained weights from: https://download.pytorch.org/models/vit_b_32-d86f8d99.pth...


In [3]:
model

CIL_multiview(
  (encoder_embedding_perception): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear_1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (dropout_1): Dropout(p=0.0, inplace=False)
            (linear_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout_2): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_

## Load a pretrained checkpoint

In [4]:
from _utils.training_utils import check_saved_checkpoints

latest_checkpoint = check_saved_checkpoints(os.path.join('/home/dporres/Documents/PRUEBA/CILv2_multiview/VisionTFM',
                                                         '_results',
                                                         g_conf.EXPERIMENT_BATCH_NAME,
                                                         g_conf.EXPERIMENT_NAME,
                                                         'checkpoints'))

In [5]:
latest_checkpoint

'/home/dporres/Documents/PRUEBA/CILv2_multiview/VisionTFM/_results/CILv2/CILv2_3cam_vit_Town01Full/checkpoints/CILv2_multiview_attention_80_353016.pth'

In [6]:
import torch

checkpoint = torch.load(latest_checkpoint)

In [7]:
checkpoint['model']

OrderedDict([('_model.tfx_class_token',
              tensor([[[-4.4972e-03, -1.2049e-03, -8.8942e-03, -1.0903e-02, -1.3768e-02,
                         2.7268e-02,  3.2536e-03,  6.5453e-02,  5.7828e-03,  8.2787e-04,
                         9.0422e-04,  5.1996e-03, -9.1170e-03,  1.9343e-02, -4.0995e-03,
                         6.5558e-03, -9.0981e-03, -6.5428e-03, -6.4278e-03,  1.2826e-02,
                        -6.9174e-03, -4.0944e-03, -8.1965e-03,  4.5711e-02, -2.3963e-03,
                         9.9305e-04, -5.6588e-03, -4.5699e-03, -6.3736e-03,  7.2764e-03,
                         3.3657e-03,  2.8413e-04, -7.3598e-03, -1.3300e-03,  8.2440e-03,
                        -5.7169e-03,  1.0094e-02, -7.9648e-03, -4.1698e-02,  5.9487e-02,
                        -9.4528e-03, -2.8337e-03, -1.2906e-02,  1.1221e-02, -7.0067e-03,
                         2.7589e-04, -1.8944e-02, -3.0013e-02,  1.0540e-02,  8.2764e-03,
                         6.4406e-03, -5.1002e-03, -6.1087e-03,  7.2190

In [8]:
# As we can see, the keys have _model at the beginning; let's remove it
new_state_dict = {}
for k, v in checkpoint['model'].items():
    new_state_dict[k[7:]] = v

In [9]:
new_state_dict

{'tfx_class_token': tensor([[[-4.4972e-03, -1.2049e-03, -8.8942e-03, -1.0903e-02, -1.3768e-02,
            2.7268e-02,  3.2536e-03,  6.5453e-02,  5.7828e-03,  8.2787e-04,
            9.0422e-04,  5.1996e-03, -9.1170e-03,  1.9343e-02, -4.0995e-03,
            6.5558e-03, -9.0981e-03, -6.5428e-03, -6.4278e-03,  1.2826e-02,
           -6.9174e-03, -4.0944e-03, -8.1965e-03,  4.5711e-02, -2.3963e-03,
            9.9305e-04, -5.6588e-03, -4.5699e-03, -6.3736e-03,  7.2764e-03,
            3.3657e-03,  2.8413e-04, -7.3598e-03, -1.3300e-03,  8.2440e-03,
           -5.7169e-03,  1.0094e-02, -7.9648e-03, -4.1698e-02,  5.9487e-02,
           -9.4528e-03, -2.8337e-03, -1.2906e-02,  1.1221e-02, -7.0067e-03,
            2.7589e-04, -1.8944e-02, -3.0013e-02,  1.0540e-02,  8.2764e-03,
            6.4406e-03, -5.1002e-03, -6.1087e-03,  7.2190e-03, -2.1260e-02,
           -3.4526e-04, -3.3051e-02,  4.7215e-03, -1.3939e-02, -9.6334e-03,
           -8.1722e-03, -2.1636e-02, -1.7888e-03,  1.2470e-02,  5.501

In [69]:
from typing import OrderedDict

new_state_dict = OrderedDict(new_state_dict)

In [70]:
new_state_dict

In [71]:
model.load_state_dict(new_state_dict)

<All keys matched successfully>