In [1]:
import torch 
import os
import yaml 
from pathlib import Path
from dp_gs.util.args import ExperimentConfig
from transformers import AutoProcessor
from dp_gs.policy.model import Dinov2DiscretePolicy
from dp_gs.dataset.utils import default_vision_transform, aug_vision_transform
from dp_gs.dataset.image_dataset import SequenceDataset, VideoSampler, CollateFunction
from dp_gs.dataset.image_dataset_sim import SequenceDataset

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
model_ckpt_folder = "/shared/projects/icrl/dp_outputs/250208_1027"

ckpt1_path = "/shared/projects/icrl/dp_outputs/250208_1027/checkpoint_50.pt"
ckpt2_path = "/shared/projects/icrl/dp_outputs/250208_1027/checkpoint_95.pt"

ckpt1 = torch.load(ckpt1_path, weights_only=True)
ckpt2 = torch.load(ckpt2_path, weights_only=True)

In [4]:
# Compare the keys of both checkpoints
keys_ckpt1 = set(ckpt1.keys())
keys_ckpt2 = set(ckpt2.keys())

if keys_ckpt1 != keys_ckpt2:
    print("The checkpoints have different keys.")
else:
    print("The checkpoints have the same keys.")

# Compare the values of both checkpoints
same_values = []
different_values = []

for key in keys_ckpt1:
    if torch.equal(ckpt1[key], ckpt2[key]):
        same_values.append(key)
    else:
        different_values.append(key)

print("Keys with the same values:")
for key in same_values:
    print(key)

print("\nKeys with different values:")
for key in different_values:
    print(key)

The checkpoints have the same keys.
Keys with the same values:
module.dino.norm.bias
module.dino.blocks.6.norm2.weight
module.dino.blocks.1.mlp.fc1.bias
module.dino.blocks.11.ls2.gamma
module.dino.blocks.4.ls2.gamma
module.decoder.blocks.1.cross_attention.rope.emb
module.dino.patch_embed.proj.weight
module.dino.blocks.1.attn.qkv.bias
module.dino.blocks.1.norm2.bias
module.dino.blocks.3.attn.proj.weight
module.dino.blocks.6.mlp.fc2.weight
module.dino.blocks.3.norm2.bias
module.dino.blocks.10.ls2.gamma
module.dino.blocks.5.mlp.fc1.weight
module.dino.blocks.5.norm2.bias
module.dino.blocks.8.mlp.fc2.bias
module.dino.blocks.10.norm2.bias
module.dino.blocks.8.attn.proj.weight
module.dino.blocks.1.ls1.gamma
module.decoder.blocks.3.self_attention.rope.inv_freq
module.dino.blocks.1.norm1.bias
module.dino.blocks.1.norm2.weight
module.dino.blocks.6.norm2.bias
module.dino.blocks.5.attn.qkv.weight
module.dino.blocks.6.ls1.gamma
module.dino.blocks.9.attn.proj.bias
module.dino.blocks.3.attn.qkv.bias


In [5]:
train_yaml_path = os.path.join(model_ckpt_folder, 'run.yaml')
args : ExperimentConfig = yaml.load(Path(train_yaml_path).read_text(), Loader=yaml.Loader)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
extra_kwargs = {
    'vocab_size': tokenizer.vocab_size + 1,  # eos token
    'num_tokens': args.shared_cfg.num_tokens
}
model = Dinov2DiscretePolicy(
    shared_cfg=args.shared_cfg,
    model_cfg=args.model_cfg,
    **(extra_kwargs if args.model_cfg.policy_type == "discrete" else {})
)

Some kwargs in processor config are unused and will not have any effect: time_horizon, min_token, vocab_size, scale, action_dim. 


In [14]:
ckpt2 = {
    key.replace("module.", ""): value for key, value in ckpt1.items()
}
model.load_state_dict(ckpt2, strict=True)
model = model.to(device)

In [16]:
if args.shared_cfg.s2:
    resolution = args.shared_cfg.image_size * 2
else:
    resolution = args.shared_cfg.image_size
base_vision_transform = default_vision_transform(resolution=resolution) # default img_size: Union[int, Tuple[int, int]] = 224,
aug_transform = aug_vision_transform(resolution=resolution) # default img_size: Union[int, Tuple[int, int]] = 224,
collate_fn = CollateFunction(
    args.shared_cfg.seq_length, 
    tokenizer=tokenizer, 
    max_token_length=args.shared_cfg.num_tokens
)

In [None]:
dataset_train = SequenceDataset(
    dataset_config=args.dataset_cfg,
    shared_config=args.shared_cfg,
    logging_config=args.logging_cfg,
    vision_transform=aug_transform,
    split="train",
)

Length of data before balancing:  1570
Length of data after balancing:  1304


100%|██████████| 1304/1304 [00:08<00:00, 162.08it/s]


Action statistics saved to  /shared/projects/icrl/dp_outputs/250208_1027/action_statistics.json
using numeric brightness and contrast augmentation
contrast range:  [0.8, 1.2]
brightness range:  [-0.1, 0.1]


In [17]:
dataset_train.vision_transform = base_vision_transform

In [18]:
from dp_gs.util.misc import MultiEpochsDataLoader
train_sampler = VideoSampler(
    dataset_train, 
    batch_size=args.shared_cfg.batch_size, 
    sequence_length=args.shared_cfg.seq_length, 
)
dataloader_train = MultiEpochsDataLoader(
    dataset_train,
    batch_sampler=train_sampler,
    num_workers=args.trainer_cfg.num_workers,
    collate_fn=collate_fn,
    pin_memory=False,
    prefetch_factor=4,
)

In [20]:
model.eval()
idx = 0 
for dataset_item in dataloader_train:
    if idx == 10:
        break 
    idx += 1
    for k, v in dataset_item.items():
        dataset_item[k] = v.to(device, non_blocking=True)
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        loss = model(dataset_item)
    print(loss)

{'loss': tensor(0.3391, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8774, device='cuda:0')}
{'loss': tensor(0.2982, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8968, device='cuda:0')}
{'loss': tensor(0.3608, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8592, device='cuda:0')}
{'loss': tensor(0.3804, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8652, device='cuda:0')}
{'loss': tensor(0.1943, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.9379, device='cuda:0')}
{'loss': tensor(0.4400, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8326, device='cuda:0')}
{'loss': tensor(0.3351, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.8871, device='cuda:0')}
{'loss': tensor(0.2315, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.9184, device='cuda:0')}
{'loss': tensor(0.2513, device='cuda:0', grad_fn=<NllLossBackward0>), 'acc': tensor(0.9108, device='cuda:0')}
{'loss': t