In [48]:
import torch
import torch.nn as nn
import os
import numpy as np

from util.helpers import vis, get_all_frames_labels
from mae import MaskedAutoencoderViT, mae_vit_base_patch16_dec512d8b

os.chdir('../aot-benchmark')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
device='cuda:0'

In [74]:
model = MaskedAutoencoderViT(img_size=464, patch_size=16, in_chans=3,
                 embed_dim=[96, 96, 96, 192, 192, 192, 384, 384, 384, 768, 768, 768], depth=12, num_heads=16,
                 decoder_embed_dim=768, decoder_depth=1, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.Identity, norm_pix_loss=False)

Dim for norm: 96
Dim for norm: 96
Dim for norm: 96
Dim for norm: 192
Dim for norm: 192
Dim for norm: 192
Dim for norm: 384
Dim for norm: 384
Dim for norm: 384
Dim for norm: 768
Dim for norm: 768
Dim for norm: 768
Dim for norm: 768


In [75]:
model = model.to(device)

In [None]:
model = mae_vit_base_patch16_dec512d8b(img_size=465).to(device)

In [52]:
# load dataset

import importlib
from dataloaders.train_datasets import AOT_Train
from torch.utils.data import DataLoader
import dataloaders.video_transforms as tr
from torchvision import transforms

stage = 'pre'
exp_name = "tristan_test"
model_name = "aott"
engine_config = importlib.import_module('configs.' + stage)
cfg = engine_config.EngineConfig(exp_name, model_name)

cfg.DIST_START_GPU = 0  # default value
cfg.TRAIN_GPUS = 1
cfg.DATASETS = 'AOT'
cfg.DIR_AOT = '/gv1/projects/AI_Surrogate/dev/tristan/aot-benchmark/datasets/AOT'
cfg.PRETRAIN_MODEL = "pretrain_model/mobilenet_v2-b0353104.pth"
cfg.DIST_ENABLE = False
cfg.DATA_WORKERS = 2
cfg.TRAIN_BATCH_SIZE = 1
cfg.DATA_RANDOMCROP=464

cfg.enable_prev_frame = True
train_sampler = None

composed_transforms = transforms.Compose([
                tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR,
                               cfg.DATA_MAX_SCALE_FACTOR,
                               cfg.DATA_SHORT_EDGE_LEN),
                tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP,
                                      max_obj_num=cfg.MODEL_MAX_OBJ_NUM),
                tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True),
                tr.ToTensor()])

train_dataset = AOT_Train(
                root=cfg.DIR_AOT,
                transform=composed_transforms,
                seq_len=cfg.DATA_SEQ_LEN,
                rand_gap=cfg.DATA_RANDOM_GAP_AOT,
                rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ,
                merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB,
                enable_prev_frame=cfg.enable_prev_frame,
                max_obj_n=cfg.MODEL_MAX_OBJ_NUM)

train_loader = DataLoader(train_dataset,
               batch_size=int(cfg.TRAIN_BATCH_SIZE /
                              cfg.TRAIN_GPUS),
               shuffle=False if cfg.DIST_ENABLE else True,
               num_workers=cfg.DATA_WORKERS,
               pin_memory=True,
               sampler=train_sampler,
               drop_last=True,
               prefetch_factor=4)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~
./result/tristan_test_AOTT/PRE/ckpt
*********************************************************************************************************************************
Video Num: 23 X 1


In [53]:
for frame_idx, sample in enumerate(train_loader):
    all_frames, all_labels, bs, obj_nums = get_all_frames_labels(sample, 0) # obj nums is the classes
    break

In [50]:
all_frames.shape

torch.Size([6, 3, 464, 464])

In [13]:
obj_nums

[1]

In [None]:
model(all_frames)

In [18]:
from timm.models import VisionTransformer

In [23]:
model2 = VisionTransformer(
    img_size=465,
    patch_size=15,
    num_classes=10,
    embed_dim=768,
    depth=12,
    num_heads=12).to(device)

In [24]:
model2(all_frames)

tensor([[ 0.0541,  0.3998,  1.0502, -0.7676,  0.0494, -0.7532,  0.9203, -0.5654,
         -0.3561,  0.3753],
        [ 0.0354,  0.3576,  1.0069, -0.7382,  0.0290, -0.8259,  0.8893, -0.5758,
         -0.2815,  0.4169],
        [ 0.0296,  0.3585,  1.0130, -0.7445,  0.0295, -0.8259,  0.8898, -0.5791,
         -0.2739,  0.4148],
        [ 0.0368,  0.3589,  1.0240, -0.7475,  0.0292, -0.8081,  0.8945, -0.5773,
         -0.2934,  0.4151],
        [ 0.0444,  0.3634,  1.0280, -0.7481,  0.0295, -0.7995,  0.8929, -0.5710,
         -0.2928,  0.4132],
        [ 0.0405,  0.3898,  1.0462, -0.7587,  0.0332, -0.7734,  0.8959, -0.5792,
         -0.3250,  0.3978]], device='cuda:0', grad_fn=<AddmmBackward>)

In [39]:
from typing import Callable

def func(module, name):
    print("found:", name)

def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module

In [39]:
named_apply(func, model)

found: patch_embed.proj
found: patch_embed.norm
found: patch_embed
found: blocks.0.norm1
found: blocks.0.attn.qkv
found: blocks.0.attn.attn_drop
found: blocks.0.attn.proj
found: blocks.0.attn.proj_drop
found: blocks.0.attn
found: blocks.0.ls1
found: blocks.0.drop_path1
found: blocks.0.norm2
found: blocks.0.mlp.fc1
found: blocks.0.mlp.act
found: blocks.0.mlp.drop1
found: blocks.0.mlp.fc2
found: blocks.0.mlp.drop2
found: blocks.0.mlp
found: blocks.0.ls2
found: blocks.0.drop_path2
found: blocks.0
found: blocks
found: norm
found: decoder_embed
found: decoder_blocks.0.norm1
found: decoder_blocks.0.attn.qkv
found: decoder_blocks.0.attn.attn_drop
found: decoder_blocks.0.attn.proj
found: decoder_blocks.0.attn.proj_drop
found: decoder_blocks.0.attn
found: decoder_blocks.0.ls1
found: decoder_blocks.0.drop_path1
found: decoder_blocks.0.norm2
found: decoder_blocks.0.mlp.fc1
found: decoder_blocks.0.mlp.act
found: decoder_blocks.0.mlp.drop1
found: decoder_blocks.0.mlp.fc2
found: decoder_blocks.0.mlp

MaskedAutoencoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(15, 15), stride=(15, 15))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (ls2): Identity()
      (drop_path2): Identity()
    )
 

In [45]:
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def load_network(net, pretrained_dir, gpu):
    antialias = True
    pretrained = torch.load(pretrained_dir,
                            map_location=device)
    print(pretrained.keys())
    if 'state_dict' in pretrained.keys():
        pretrained_dict = pretrained['state_dict']
    elif 'model_state' in pretrained.keys():
        pretrained_dict = pretrained['model_state']
    else:
        pretrained_dict = pretrained
    model_dict = net.state_dict()
    pretrained_dict_update = {}
    pretrained_dict_remove = []
    for k, v in pretrained_dict.items():
         # move to device
        if 'patch_embed.proj.weight' in k:
            O, I, H, W = model.patch_embed.proj.weight.shape
            if len(v.shape) < 4:
                # For old models that I trained prior to conv based patchification
                O, I, H, W = model.patch_embed.proj.weight.shape
                v = v.reshape(O, -1, H, W)
                if v.shape[-1] != W or v.shape[-2] != H:
                    v = resample_patch_embed(
                        v,
                        (H, W),
                        interpolation='bicubic',
                        antialias=antialias,
                        verbose=True,
                    )
            elif v.shape != model.patch_embed.proj.weight.shape:
                print("interpolating v")
                # send to cpu and then bring back after
                v = v.to('cpu')
                O, I, H, W = model.patch_embed.proj.weight.shape
                v = resample_patch_embed(
                        v,
                        (H, W),
                        interpolation='bicubic',
                        antialias=antialias,
                        verbose=True,
                    )
                v = v.to(device)
            elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
                # To resize pos embedding when using model at different size from pretrained weights
                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
                v = resample_abs_pos_embed(
                    v,
                    new_size=model.patch_embed.grid_size,
                    num_prefix_tokens=num_prefix_tokens,
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
            if k in model_dict:
                pretrained_dict_update[k] = v
            elif k[:7] == 'module.':
                if k[7:] in model_dict:
                    pretrained_dict_update[k[7:]] = v
            else:
                print("removing:", k)
                pretrained_dict_remove.append(k)
    print(pretrained_dict_update.keys())
    model_dict.update(pretrained_dict_update)
    net.load_state_dict(model_dict, strict=False)
    del (pretrained)
    return net.cuda(gpu), pretrained_dict_remove

In [46]:
load_network(model, 'pretrain_models/hiera_base_224.pth', 0)

dict_keys(['epoch', 'model_state', 'optimizer_state', 'scaler_state'])
interpolating v
dict_keys(['patch_embed.proj.weight'])


(MaskedAutoencoderViT(
   (patch_embed): PatchEmbed(
     (proj): Conv2d(3, 96, kernel_size=(16, 16), stride=(16, 16))
     (norm): Identity()
   )
   (blocks): ModuleList(
     (0-2): 3 x Block(
       (norm1): Identity()
       (attn): Attention(
         (qkv): Linear(in_features=96, out_features=288, bias=True)
         (attn_drop): Dropout(p=0.0, inplace=False)
         (proj): Linear(in_features=96, out_features=96, bias=True)
         (proj_drop): Dropout(p=0.0, inplace=False)
       )
       (ls1): Identity()
       (drop_path1): Identity()
       (norm2): Identity()
       (mlp): Mlp(
         (fc1): Linear(in_features=96, out_features=384, bias=True)
         (act): GELU(approximate='none')
         (drop1): Dropout(p=0.0, inplace=False)
         (fc2): Linear(in_features=384, out_features=96, bias=True)
         (drop2): Dropout(p=0.0, inplace=False)
       )
       (ls2): Identity()
       (drop_path2): Identity()
     )
     (3-5): 3 x Block(
       (norm1): Identity()
   

In [None]:
model(all_frames)

In [32]:
checkpoint = torch.load('pretrain_models/hiera_base_224.pth')

In [33]:
model_dict = checkpoint['model_state']
print(model_dict['patch_embed.proj.weight'].shape)
for k, v in model_dict.items():
    if 'blocks' in k:
        print(k, v.shape)

torch.Size([96, 3, 7, 7])
blocks.0.norm1.weight torch.Size([96])
blocks.0.norm1.bias torch.Size([96])
blocks.0.attn.qkv.weight torch.Size([288, 96])
blocks.0.attn.qkv.bias torch.Size([288])
blocks.0.attn.proj.weight torch.Size([96, 96])
blocks.0.attn.proj.bias torch.Size([96])
blocks.0.norm2.weight torch.Size([96])
blocks.0.norm2.bias torch.Size([96])
blocks.0.mlp.fc1.weight torch.Size([384, 96])
blocks.0.mlp.fc1.bias torch.Size([384])
blocks.0.mlp.fc2.weight torch.Size([96, 384])
blocks.0.mlp.fc2.bias torch.Size([96])
blocks.1.norm1.weight torch.Size([96])
blocks.1.norm1.bias torch.Size([96])
blocks.1.attn.qkv.weight torch.Size([288, 96])
blocks.1.attn.qkv.bias torch.Size([288])
blocks.1.attn.proj.weight torch.Size([96, 96])
blocks.1.attn.proj.bias torch.Size([96])
blocks.1.norm2.weight torch.Size([96])
blocks.1.norm2.bias torch.Size([96])
blocks.1.mlp.fc1.weight torch.Size([384, 96])
blocks.1.mlp.fc1.bias torch.Size([384])
blocks.1.mlp.fc2.weight torch.Size([96, 384])
blocks.1.mlp.fc

In [21]:
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

Using cache found in /home/tpeat3/.cache/torch/hub/facebookresearch_dino_main
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /home/tpeat3/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth
100%|██████████| 82.7M/82.7M [00:03<00:00, 28.7MB/s]


In [None]:
print(vits16.state_dict().keys())

for k, v in vits16.state_dict().items():
    if 'blocks' in k:
        print(k, v.shape)

In [20]:
for k, v in model.state_dict().items():
    if 'blocks' in k and 'decoder' not in k:
        print(k, v.shape)

blocks.0.norm1.weight torch.Size([96])
blocks.0.norm1.bias torch.Size([96])
blocks.0.attn.qkv.weight torch.Size([288, 96])
blocks.0.attn.qkv.bias torch.Size([288])
blocks.0.attn.proj.weight torch.Size([96, 96])
blocks.0.attn.proj.bias torch.Size([96])
blocks.0.norm2.weight torch.Size([96])
blocks.0.norm2.bias torch.Size([96])
blocks.0.mlp.fc1.weight torch.Size([384, 96])
blocks.0.mlp.fc1.bias torch.Size([384])
blocks.0.mlp.fc2.weight torch.Size([96, 384])
blocks.0.mlp.fc2.bias torch.Size([96])
blocks.1.norm1.weight torch.Size([96])
blocks.1.norm1.bias torch.Size([96])
blocks.1.attn.qkv.weight torch.Size([288, 96])
blocks.1.attn.qkv.bias torch.Size([288])
blocks.1.attn.proj.weight torch.Size([96, 96])
blocks.1.attn.proj.bias torch.Size([96])
blocks.1.norm2.weight torch.Size([96])
blocks.1.norm2.bias torch.Size([96])
blocks.1.mlp.fc1.weight torch.Size([384, 96])
blocks.1.mlp.fc1.bias torch.Size([384])
blocks.1.mlp.fc2.weight torch.Size([96, 384])
blocks.1.mlp.fc2.bias torch.Size([96])
bl

In [None]:
load_network(model, 'pretrain_models/dino_deitsmall16_pretrain.pth', 0)

In [None]:
model(all_frames)

In [None]:
shortcuts = model.get_intermediates()
for i in shortcuts:
    print(i.shape)