# Comparing state-dict keys of the pretrained and raw MVGGT models

There is a discrepancy between the actual Pi3 encoder/decoder dictionary keys and the way MVGGT handles them.

See the changes to encoder/decoder handling:
- Encoder: https://github.com/vasabi-root/mvggt/commit/0ce4d1f0248fbd1f1a160c600c0d01e1c6052329
- Decoder: https://github.com/vasabi-root/mvggt/commit/ff33729e9bc1d350972823b663de1250475455e7

Pi3 and Pi3x checkpoints were downloaded using the links from the official Pi3 repo (https://github.com/yyfz/Pi3/tree/08d7288aaf4b0c08c8498bea7bafedc4672bb006):
- Pi3: https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors
- Pi3x: https://huggingface.co/yyfz233/Pi3X/resolve/main/model.safetensors

The notebook below shows why these changes are necessary. Both MVGGT models were loaded with the modified encoder/decoder handling because, as the [Aggregators](#aggregators) section shows, neither Pi3 nor Pi3x has any 'aggregator.' keys.
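
When a checkpoint stores a submodule under a different prefix than the model expects, PyTorch's `load_state_dict` reports every affected key as missing or unexpected (it returns a named tuple with `missing_keys` and `unexpected_keys`; the `<All keys matched successfully>` lines printed below are its repr when both lists are empty). A common workaround is to rename prefixes before loading. The sketch below uses a hypothetical helper, `remap_prefix`, and made-up prefixes (`visual_encoder.` is illustrative only); it is not a reproduction of what the linked commits do.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def remap_prefix(state_dict: Dict[str, V], old: str, new: str) -> Dict[str, V]:
    """Return a copy of state_dict with a leading `old` prefix replaced by `new`.

    Hypothetical helper for illustration; keys without the prefix are kept unchanged.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old):
            key = new + key[len(old):]
        remapped[key] = value
    return remapped


# Toy "state dict": placeholder ints stand in for tensors.
ckpt = {"encoder.blocks.0.weight": 1, "decoder.blocks.0.weight": 2}
print(remap_prefix(ckpt, "encoder.", "visual_encoder."))
# {'visual_encoder.blocks.0.weight': 1, 'decoder.blocks.0.weight': 2}
```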

In [1]:
```python
import torch
from mvggt.models.mvggt_training import MVGGT
```



## Raw model

In [2]:
```python
model_raw = MVGGT(
    use_referring_segmentation=True,
    freeze_visual_modules=True,           # freeze the Pi3/VGGT visual modules
    freeze_encoder=True,
    num_multimodal_layers=12,             # value from the article
    multimodal_layer_selection='back',
    fusion_mode='pwa_only',               # 'interleaved' is also possible, but the article uses only point-wise add
    use_lang_vision_fusion=False,
    use_controlnet_injection=True,
    use_lora=False,                       # set True to fine-tune the decoder with LoRA
    text_model_name="roberta-base",
    load_vggt=True,                       # load pretrained Pi3/VGGT encoder and decoder weights
    use_pretrained_weights=False          # when True, loads the Pi3/VGGT decoder as multimodal_decoder
)
```

```
Loading vggt encoder <All keys matched successfully>
Loading vggt decoder <All keys matched successfully>
Freezing all visual modules for referring segmentation training.
```


## Pretrained

In [3]:
```python
ckpt_path = '/home/bashmac/MIPT/VLM/mvggt/ckpts/pytorch_model.bin'
model_pretrained = MVGGT(
    use_referring_segmentation=True,
    load_vggt=False,
    train_conf=True,
    ckpt=ckpt_path,
    use_pretrained_weights=False
)
```

```
Freezing the encoder.
[MVGGT] Load checkpoints from /home/bashmac/MIPT/VLM/mvggt/ckpts/pytorch_model.bin: <All keys matched successfully>
```


## Comparison

In [4]:
```python
from typing import Dict, List, Tuple


def compare_models_dicts(model1_dict: Dict, model2_dict: Dict) -> Tuple[List[str], List[str]]:
    """Return (keys only in model1_dict, keys only in model2_dict), preserving key order."""
    keys_model1 = model1_dict.keys()
    keys_model2 = model2_dict.keys()

    # Keys from model1 that are missing from model2
    keys_only_in_model1 = [key for key in keys_model1 if key not in keys_model2]

    # Keys from model2 that are missing from model1
    keys_only_in_model2 = [key for key in keys_model2 if key not in keys_model1]

    return keys_only_in_model1, keys_only_in_model2
```
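
Dictionary key views are set-like, so the same two lists can also be computed with set difference. This variant sorts the result, which makes long key lists easier to scan but drops the original key order; `diff_keys` is a name chosen here for illustration.

```python
from typing import List, Mapping, Tuple


def diff_keys(d1: Mapping, d2: Mapping) -> Tuple[List[str], List[str]]:
    """Return (keys only in d1, keys only in d2), each sorted alphabetically."""
    # keys() views support set operations such as `-` directly.
    only_in_1 = sorted(d1.keys() - d2.keys())
    only_in_2 = sorted(d2.keys() - d1.keys())
    return only_in_1, only_in_2


a = {"encoder.w": 0, "decoder.w": 0, "head.w": 0}
b = {"encoder.w": 0, "decoder.w": 0, "extra.w": 0}
print(diff_keys(a, b))  # (['head.w'], ['extra.w'])
```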

In [5]:
```python
keys_only_in_raw, keys_only_in_pretrained = compare_models_dicts(model_raw.state_dict(), model_pretrained.state_dict())

print('Q: How many keys from model_raw are not presented in model_pretrained?')
print(f'A: {len(keys_only_in_raw)}')
print()
print('Q: How many keys from model_pretrained are not presented in model_raw?')
print(f'A: {len(keys_only_in_pretrained)}')
```

```
Q: How many keys from model_raw are not presented in model_pretrained?
A: 0

Q: How many keys from model_pretrained are not presented in model_raw?
A: 0
```
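
Identical key sets do not by themselves guarantee that two checkpoints are interchangeable: a tensor can keep its name while its shape changes. A natural follow-up is to compare `.shape` for every shared key. The sketch below works on any values exposing a `.shape` attribute (such as `torch.Tensor`); the toy values use `types.SimpleNamespace` as a stand-in, and `diff_shapes` is a hypothetical helper.

```python
from types import SimpleNamespace
from typing import List, Mapping, Tuple


def diff_shapes(d1: Mapping, d2: Mapping) -> List[Tuple[str, tuple, tuple]]:
    """For keys present in both mappings, list (key, shape1, shape2) where shapes differ."""
    mismatches = []
    for key in d1.keys() & d2.keys():
        s1, s2 = tuple(d1[key].shape), tuple(d2[key].shape)
        if s1 != s2:
            mismatches.append((key, s1, s2))
    return sorted(mismatches)


def fake_tensor(*shape: int) -> SimpleNamespace:
    """Stand-in for a tensor: only the `.shape` attribute is used."""
    return SimpleNamespace(shape=shape)


raw = {"head.weight": fake_tensor(2, 768), "head.bias": fake_tensor(2)}
pretrained = {"head.weight": fake_tensor(1, 768), "head.bias": fake_tensor(2)}
print(diff_shapes(raw, pretrained))  # [('head.weight', (2, 768), (1, 768))]
```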


# Comparing state-dict keys of Pi3 and Pi3x

- Pi3: https://huggingface.co/yyfz233/Pi3/resolve/main/model.safetensors (renamed to `ckpts/VGGT-1B/pi3.safetensors`)
- Pi3x: https://huggingface.co/yyfz233/Pi3X/resolve/main/model.safetensors (renamed to `ckpts/VGGT-1B/pi3x.safetensors`)

In [6]:
```python
from safetensors.torch import load_file
```

In [7]:
```python
pi3 = load_file('ckpts/VGGT-1B/pi3.safetensors')
```

In [8]:
```python
pi3x = load_file('ckpts/VGGT-1B/pi3x.safetensors')
```

In [9]:
```python
def print_pi3s_comparison(pi3, pi3x):
    """Print how many keys each checkpoint has that the other lacks."""
    keys_only_in_pi3, keys_only_in_pi3x = compare_models_dicts(pi3, pi3x)

    print('Q: How many keys from pi3 are not presented in pi3x?')
    print(f'A: {len(keys_only_in_pi3)}')
    print()
    print('Q: How many keys from pi3x are not presented in pi3?')
    print(f'A: {len(keys_only_in_pi3x)}')
```

In [10]:
```python
print_pi3s_comparison(pi3, pi3x)
```

```
Q: How many keys from pi3 are not presented in pi3x?
A: 4

Q: How many keys from pi3x are not presented in pi3?
A: 667
```
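
The counts alone do not say where the 4 and 667 differing keys live. Grouping keys by their top-level prefix (the text before the first `.`) is a quick way to see which submodules they belong to. `count_by_prefix` is a hypothetical helper and the keys in the example are made up; applied to the second list returned by `compare_models_dicts(pi3, pi3x)`, it would show which top-level modules account for the extra Pi3x keys (that output is not reproduced here).

```python
from collections import Counter
from typing import Iterable


def count_by_prefix(keys: Iterable[str], depth: int = 1) -> Counter:
    """Count keys by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


keys = ["encoder.blocks.0.w", "encoder.blocks.1.w", "decoder.norm.w", "head.w"]
print(dict(count_by_prefix(keys)))  # {'encoder': 2, 'decoder': 1, 'head': 1}
print(count_by_prefix(keys, depth=2)["encoder.blocks"])  # 2
```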


## Encoders

In [11]:
```python
pi3_encoder = {key: value for key, value in pi3.items() if key.startswith('encoder.')}
pi3x_encoder = {key: value for key, value in pi3x.items() if key.startswith('encoder.')}

print_pi3s_comparison(pi3_encoder, pi3x_encoder)
```

```
Q: How many keys from pi3 are not presented in pi3x?
A: 0

Q: How many keys from pi3x are not presented in pi3?
A: 0
```


## Decoders

In [12]:
```python
pi3_decoder = {key: value for key, value in pi3.items() if key.startswith('decoder.')}
pi3x_decoder = {key: value for key, value in pi3x.items() if key.startswith('decoder.')}

print_pi3s_comparison(pi3_decoder, pi3x_decoder)
```

```
Q: How many keys from pi3 are not presented in pi3x?
A: 0

Q: How many keys from pi3x are not presented in pi3?
A: 0
```
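
Since the `encoder.` and `decoder.` key sets match between Pi3 and Pi3x, either checkpoint could supply those submodules. To load one into a standalone submodule, the prefix usually has to be stripped, because a submodule's own `state_dict()` keys do not include its parent's attribute name. A minimal sketch, with `extract_submodule` as a hypothetical helper name and toy keys:

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def extract_submodule(state_dict: Dict[str, V], prefix: str) -> Dict[str, V]:
    """Keep only keys under `prefix` and strip it, e.g. 'encoder.x' -> 'x'."""
    if not prefix.endswith("."):
        prefix += "."  # avoid matching 'encoder_extra.*' when asked for 'encoder'
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


sd = {"encoder.patch_embed.w": 1, "encoder_extra.w": 2, "decoder.norm.w": 3}
print(extract_submodule(sd, "encoder"))  # {'patch_embed.w': 1}
```

With PyTorch, the result could then be passed to something like `model.encoder.load_state_dict(...)`, assuming the model exposes the submodule under that attribute name.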


## Aggregators

In [13]:
```python
pi3_aggregator = {key: value for key, value in pi3.items() if key.startswith('aggregator.')}
pi3x_aggregator = {key: value for key, value in pi3x.items() if key.startswith('aggregator.')}

print(f'Len of pi3 "aggregator.": {len(pi3_aggregator)}')
print(f'Len of pi3x "aggregator.": {len(pi3x_aggregator)}')
```

```
Len of pi3 "aggregator.": 0
Len of pi3x "aggregator.": 0
```


In [14]:
```python
# Looser check: match 'aggregator' anywhere in the key, not only as a prefix
pi3_aggregator = {key: value for key, value in pi3.items() if 'aggregator' in key}
pi3x_aggregator = {key: value for key, value in pi3x.items() if 'aggregator' in key}

print(f'Keys in pi3 containing "aggregator": {len(pi3_aggregator)}')
print(f'Keys in pi3x containing "aggregator": {len(pi3x_aggregator)}')
```

```
Keys in pi3 containing "aggregator": 0
Keys in pi3x containing "aggregator": 0
```
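
These two checks confirm that neither checkpoint uses an `aggregator` name at all, so a loader that looks for `aggregator.`-prefixed keys would find nothing to load. A cheap guard is to verify the expected top-level prefixes before loading. `missing_prefixes` is a hypothetical helper, and the prefix list in the example is illustrative.

```python
from typing import Iterable, List, Mapping


def missing_prefixes(state_dict: Mapping[str, object], expected: Iterable[str]) -> List[str]:
    """Return the expected prefixes that no key in state_dict starts with."""
    keys = list(state_dict.keys())
    return [p for p in expected if not any(k.startswith(p) for k in keys)]


ckpt = {"encoder.w": 0, "decoder.w": 0}
print(missing_prefixes(ckpt, ["encoder.", "decoder.", "aggregator."]))  # ['aggregator.']
```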
