In [1]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
import torch

In [3]:
import sys

# Make the parent of the working directory importable so that `source.*`
# resolves when the notebook is run from its own folder.
# Guarded so that re-running this cell does not pile up duplicate ".."
# entries on sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [4]:
from source.datasets.ptz_dataset import PTZImageDataset

In [5]:
import logging

In [6]:
# shandler = logging.StreamHandler()
# shandler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
# shandler.setLevel(logging.INFO)

In [7]:
# Notebook-level logger; records are emitted through the root logger, which is
# configured here to let INFO and above through.
# A dedicated StreamHandler (see the commented-out cell above) is deliberately
# not attached.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [8]:
import yaml

In [9]:
# Locate the config relative to the repository instead of through a
# user-specific absolute path. Like the sys.path.append("..") setup, this
# assumes the kernel's working directory is the `notebooks/` folder, so its
# parent is the repository root that holds `configs/`.
config_path = Path.cwd().parent / "configs" / "Config_file.yaml"
with open(config_path, 'r') as fp:
    # NOTE(review): FullLoader resolves more than plain YAML types; if the
    # config only uses scalars, lists and mappings, yaml.safe_load is the
    # safer choice -- confirm before switching.
    params = yaml.load(fp, Loader=yaml.FullLoader)
# `args` is the name the rest of the notebook reads the configuration from.
args = params

In [10]:
# Flatten the loaded configuration (`args`) into notebook-level variables,
# one config section at a time.

# -- META
_meta_cfg = args['meta']
use_bfloat16 = _meta_cfg['use_bfloat16']
model_name = _meta_cfg['model_name']
load_model = _meta_cfg['load_checkpoint']
r_file = _meta_cfg['read_checkpoint']
copy_data = _meta_cfg['copy_data']
pred_depth = _meta_cfg['pred_depth']
pred_emb_dim = _meta_cfg['pred_emb_dim']
camera_brand = _meta_cfg['camera_brand']

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')

# -- DATA
_data_cfg = args['data']
# augmentation switches
use_gaussian_blur = _data_cfg['use_gaussian_blur']
use_horizontal_flip = _data_cfg['use_horizontal_flip']
use_color_distortion = _data_cfg['use_color_distortion']
color_jitter = _data_cfg['color_jitter_strength']
# loading and cropping
global_batch_size = _data_cfg['global_batch_size']
batch_size = _data_cfg['batch_size']
pin_mem = _data_cfg['pin_mem']
num_workers = _data_cfg['num_workers']
root_path = _data_cfg['root_path']
image_folder = _data_cfg['image_folder']
crop_size = _data_cfg['crop_size']
crop_scale = _data_cfg['crop_scale']

# -- MASK
_mask_cfg = args['mask']
# whether context and target blocks may overlap
allow_overlap = _mask_cfg['allow_overlap']
# patch size used for model training
patch_size = _mask_cfg['patch_size']
# number of context blocks
num_enc_masks = _mask_cfg['num_enc_masks']
# minimum number of patches in a context block
min_keep = _mask_cfg['min_keep']
# scale of context blocks
enc_mask_scale = _mask_cfg['enc_mask_scale']
# number of target blocks
num_pred_masks = _mask_cfg['num_pred_masks']
# scale of target blocks
pred_mask_scale = _mask_cfg['pred_mask_scale']
# aspect ratio of target blocks
aspect_ratio = _mask_cfg['aspect_ratio']

# -- OPTIMIZATION
_opt_cfg = args['optimization']
ema = _opt_cfg['ema']
# scheduler scale factor (def: 1.0)
ipe_scale = _opt_cfg['ipe_scale']
# the two weight-decay values are coerced to float explicitly
wd = float(_opt_cfg['weight_decay'])
final_wd = float(_opt_cfg['final_weight_decay'])
num_epochs = _opt_cfg['epochs']
warmup = _opt_cfg['warmup']
start_lr = _opt_cfg['start_lr']
lr = _opt_cfg['lr']
final_lr = _opt_cfg['final_lr']

# -- PLATEAU
_plateau_cfg = args['plateau']
patience = _plateau_cfg['wm_patience']
threshold = _plateau_cfg['wm_threshold']

# -- LOGGING
_logging_cfg = args['logging']
folder = _logging_cfg['folder']
ownership_folder = _logging_cfg['ownership_folder']
tag = _logging_cfg['write_tag']

# -- MEMORY
_memory_cfg = args['memory']
memory_models = _memory_cfg['models']


In [11]:
import copy

In [12]:
import source
import source.helper
import importlib
# Re-execute the already-imported local package, then its helper submodule,
# so the kernel picks up the current on-disk code without a restart.
importlib.reload(source)
# Last expression of the cell: the reloaded module object it returns is what
# the notebook shows as this cell's output.
importlib.reload(source.helper)

<module 'source.helper' from '/Users/yufengluo/Research/anl/su24/up-PTZJEPA/notebooks/../source/helper.py'>

In [13]:
from source.helper import (
    load_checkpoint,
    init_model,
    init_world_model,
    init_opt)
from source.transforms import make_transforms

In [14]:
# Image directory, built from the working directory instead of a user-specific
# absolute path. Assumes the kernel runs from the repository's `notebooks/`
# folder and that `data/` sits next to the repository checkout (two levels up).
data_dir = Path.cwd().parent.parent / "data" / "collected_imgs"

In [15]:
# Wrap the images under `data_dir` in a dataset and serve them in their
# stored order (no shuffling), four samples per batch.
ptz_transform = make_transforms(to_tensor=True)
dataset = PTZImageDataset(data_dir, transform=ptz_transform)
loader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=4)

INFO:data_transforms:making ptz image data transforms


In [24]:
# Pull only the first batch from the loader; each batch unpacks into an image
# tensor and a position tensor.
batch_iter = iter(loader)
img, pos = next(batch_iter)

In [25]:
# Shape of the image batch; the recorded run shows [4, 4, 224, 224]
# (four samples per batch, four input channels to match in_chans=4).
img.size()

torch.Size([4, 4, 224, 224])

In [26]:
# Shape of the position batch; the recorded run shows [4, 2], i.e. two
# position values per sample.
pos.size()

torch.Size([4, 2])

In [19]:
# Build the encoder / predictor pair from the configuration values unpacked
# earlier in the notebook.
world_model_kwargs = dict(
    device=device,
    patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_name=model_name,
    in_chans=4,  # the image batches carry four channels
)
encoder, predictor = init_world_model(**world_model_kwargs)

# Independent deep copy of the freshly initialised encoder.
target_encoder = copy.deepcopy(encoder)

INFO:source.helper:VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(4, 192, kernel_size=(14, 14), stride=(14, 14))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
)


In [20]:
# Sanity-check the target encoder's output shape on one batch.
# The batch is moved to whatever device the encoder's parameters live on, so
# the cell also runs when the model was placed on a GPU, and autograd is
# switched off because only the shape is inspected.
target_encoder_device = next(target_encoder.parameters()).device
with torch.no_grad():
    target_repr = target_encoder(img.to(target_encoder_device))
target_repr.shape

torch.Size([4, 256, 192])

In [21]:
from source.run_jepa import forward_context, arrange_inputs

In [22]:
# Split the batch into context / target images and positions.
# The third argument used to be the literal "cpu", which does not follow the
# device selected for the models; use the encoder's actual device instead.
# NOTE(review): assumes arrange_inputs treats its third argument as a device
# and accepts a torch.device as well as a string -- confirm in source/run_jepa.py.
model_device = next(encoder.parameters()).device
context_imgs, context_poss, target_imgs, target_poss = arrange_inputs(img, pos, model_device)

In [23]:
# Run one forward pass through encoder and predictor and display the raw
# result, with rewards requested as well.
# NOTE(review): in the recorded output the rows within a sample print
# identical values -- confirm this is expected for the untrained models.
forward_context(
    context_imgs,
    context_poss,
    target_poss,
    encoder,
    predictor,
    camera_brand,
    return_rewards=True,
)

(tensor([[[ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          ...,
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440]],
 
         [[ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          ...,
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440],
          [ 0.1075,  0.1310,  0.4331,  ..., -0.6255, -0.0306,  0.3440]],
 
         [[ 0.1016,  0.1288,  0.4331,  ..., -0.6296, -0.0164,  0.3493],
          [ 0.1016,  0.1288,