In [1]:
import torch
import torch.nn as nn

from ezflow.encoder import build_encoder
from ezflow.engine import get_training_cfg as get_cfg
from ezflow.models import build_model

In [2]:
from nnflow import BasicEncoderV2

In [3]:
# Load the RAFT model config from its YAML file and display the
# ENCODER.FEATURE.INTERMEDIATE_FEATURES value it ships with.
raft_cfg = get_cfg("../configs/raft/models/raft.yaml")
raft_cfg.ENCODER.FEATURE.INTERMEDIATE_FEATURES

False

In [4]:
# Override the flag on the loaded config object before the encoder is built from it.
raft_cfg.ENCODER.FEATURE.INTERMEDIATE_FEATURES = True

In [5]:
# Build the RAFT feature encoder from the (modified) ENCODER.FEATURE sub-config.
raft_encoder = build_encoder(raft_cfg.ENCODER.FEATURE)

In [6]:
# Run the encoder once on a random input of shape (1, 3, 256, 256).
features_raft = raft_encoder(torch.randn(1,3,256,256))

In [7]:
# Print the shape of every output the encoder returned.
for feature in features_raft:
    print(feature.shape)

torch.Size([1, 64, 128, 128])
torch.Size([1, 96, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 32, 32])


In [8]:
# Display the encoder's LAYER_CONFIG list for comparison with the shapes printed above.
raft_cfg.ENCODER.FEATURE.LAYER_CONFIG

[64, 96, 128]

In [9]:
# Build the full RAFT model from the same YAML config file.
raft_model = build_model(
    "RAFT",
    cfg_path="../configs/raft/models/raft.yaml",
    custom_cfg=True,
)

In [10]:
def count_params(model):
    """Return the number of trainable parameters of ``model``, in millions, as text.

    Args:
        model: Any object whose ``parameters()`` method yields items exposing
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        str: The count divided by 1,000,000, followed by ``"M params"``.
    """
    trainable = 0
    for param in model.parameters():
        # Parameters excluded from gradient computation are not counted.
        if param.requires_grad:
            trainable += param.numel()
    return f"{trainable / 1000000}M params"

___
### PWCNet

In [11]:
# Load the PWC-Net model config and display its ENCODER.CONFIG list.
pwc_cfg = get_cfg("../configs/pwcnet/models/pwcnet.yaml")
pwc_cfg.ENCODER.CONFIG

[16, 32, 64, 96, 128, 196]

In [12]:
# Build the default PWC-Net encoder from the ENCODER sub-config.
pwc_encoder = build_encoder(pwc_cfg.ENCODER)

In [13]:
# Run the encoder once on a random input of shape (1, 3, 256, 256).
features_pwc = pwc_encoder(torch.randn(1,3,256,256)) 

In [14]:
# Print the shape of every output the encoder returned.
for feature in features_pwc:
    print(feature.shape)

torch.Size([1, 16, 128, 128])
torch.Size([1, 32, 64, 64])
torch.Size([1, 64, 32, 32])
torch.Size([1, 96, 16, 16])
torch.Size([1, 128, 8, 8])
torch.Size([1, 196, 4, 4])


In [15]:
# Load the PWC-Net variant config stored in pwcnet_raft_encoder.yaml.
pwc_cfg2 = get_cfg("../configs/pwcnet/models/pwcnet_raft_encoder.yaml")

In [16]:
# Build the encoder described by the variant config.
pwc_raft_encoder = build_encoder(pwc_cfg2.ENCODER)

In [17]:
# Run it on a random input with the same shape used for the default PWC-Net encoder.
features_pwc_raft = pwc_raft_encoder(torch.randn(1,3,256,256)) 

In [18]:
# Number of outputs returned by the variant encoder.
len(features_pwc_raft)

6

In [19]:
# For each index, print whether the variant encoder's output shape equals the
# default PWC-Net encoder's output shape at that index, followed by the shape.
for i in range(len(features_pwc_raft)):
    print(features_pwc_raft[i].shape == features_pwc[i].shape, " ", features_pwc_raft[i].shape)

True   torch.Size([1, 16, 128, 128])
True   torch.Size([1, 32, 64, 64])
True   torch.Size([1, 64, 32, 32])
True   torch.Size([1, 96, 16, 16])
True   torch.Size([1, 128, 8, 8])
True   torch.Size([1, 196, 4, 4])


In [20]:
# v1: PWC-Net built from the default pwcnet.yaml config.
pwcnet_model_v1 = build_model(
    "PWCNet",
    cfg_path="../configs/pwcnet/models/pwcnet.yaml",
    custom_cfg=True,
)

  (out_channels, in_channels // groups, *kernel_size), **factory_kwargs))


In [21]:
# v2: PWC-Net built from the pwcnet_raft_encoder.yaml config.
pwcnet_model_v2 = build_model(
    "PWCNet",
    cfg_path="../configs/pwcnet/models/pwcnet_raft_encoder.yaml",
    custom_cfg=True,
)

In [24]:
# Switch the model to evaluation mode. eval() returns the module itself, so the
# trailing semicolon keeps IPython from echoing its very long repr; no extra
# print() call is needed to hide it.
pwcnet_model_v2.eval();




In [25]:
# Smoke-test the forward pass on a pair of random (1, 3, 256, 256) inputs and
# display the shape of the 'flow_upsampled' output.
# Only the shape is inspected, so run under torch.no_grad() to avoid building
# an autograd graph for the whole model.
with torch.no_grad():
    flows = pwcnet_model_v2(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
flows['flow_upsampled'].shape

torch.Size([1, 2, 256, 256])

In [28]:
# v3: PWC-Net built from the pwcnet_raft_encoder_no_norm.yaml config.
pwcnet_model_v3 = build_model(
    "PWCNet",
    cfg_path="../configs/pwcnet/models/pwcnet_raft_encoder_no_norm.yaml",
    custom_cfg=True,
)

In [29]:
# Switch to evaluation mode; the echoed repr lists the model's layers
# (the positions after each Conv2d show Identity() in this config).
pwcnet_model_v3.eval()

PWCNet(
  (encoder): BasicEncoderV2(
    (encoder): ModuleList(
      (0): Conv2d(3, 16, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): Identity()
      (2): ReLU(inplace=True)
      (3): Sequential(
        (0): BasicBlock(
          (activation): ReLU(inplace=True)
          (residual_fn): Sequential(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Identity()
            (2): ReLU(inplace=True)
            (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): Identity()
          )
          (shortcut): Identity()
        )
        (1): BasicBlock(
          (activation): ReLU(inplace=True)
          (residual_fn): Sequential(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Identity()
            (2): ReLU(inplace=True)
            (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): Identity()
    

In [30]:
# Smoke-test the forward pass on a pair of random (1, 3, 256, 256) inputs and
# display the shape of the 'flow_upsampled' output.
# Only the shape is inspected, so run under torch.no_grad() to avoid building
# an autograd graph for the whole model.
with torch.no_grad():
    flows = pwcnet_model_v3(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
flows['flow_upsampled'].shape

torch.Size([1, 2, 256, 256])

In [31]:
# Trainable-parameter counts of the three PWC-Net variants, in build order (v1, v2, v3).
tuple(
    count_params(model)
    for model in (pwcnet_model_v1, pwcnet_model_v2, pwcnet_model_v3)
)

('9.374274M params', '10.088782M params', '10.088782M params')

___

### FlowNetC

In [11]:
from nnflow.models.flownet_c_v2 import FlowNetC_V2

In [12]:
# Load the FlowNetC model config.
flownetc_cfg = get_cfg("../configs/flownet_c/models/flownet_c.yaml")
# Keep a reference to the full ENCODER.CONFIG list before it is replaced.
channels = flownetc_cfg.ENCODER.CONFIG
# Replace ENCODER.CONFIG on the loaded config with only its first three entries.
flownetc_cfg.ENCODER.CONFIG = flownetc_cfg.ENCODER.CONFIG[:3]
# Display the entries that were dropped.
channels[3:]

[256, 512, 512, 512, 512, 1024, 1024]

In [13]:
# Confirm the truncated ENCODER.CONFIG that the encoder will be built from.
flownetc_cfg.ENCODER.CONFIG

[64, 128, 256]

In [14]:
# Build the FlowNetC encoder from the truncated ENCODER sub-config.
flownetc_encoder = build_encoder(flownetc_cfg.ENCODER)

In [15]:
# Trainable-parameter count of the truncated FlowNetC encoder.
count_params(flownetc_encoder)

'1.033856M params'

In [16]:
# Run the encoder once on a random input of shape (1, 3, 256, 256).
features_flc = flownetc_encoder(torch.randn(1,3,256,256)) 

In [17]:
# Print the shape of every output the encoder returned.
for feature in features_flc:
    print(feature.shape)

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])


In [18]:
# Load the FlowNetC variant config stored in flownet_c_raft_encoder.yaml.
flownetc_cfg2 = get_cfg("../configs/flownet_c/models/flownet_c_raft_encoder.yaml")

In [19]:
# Build the encoder described by the variant config.
flowc_raft_encoder = build_encoder(flownetc_cfg2.ENCODER)

In [20]:
# Trainable-parameter count of the variant encoder, for comparison with the FlowNetC encoder.
count_params(flowc_raft_encoder)

'2.846336M params'

In [21]:
# Run it on a random input with the same shape used for the FlowNetC encoder.
features_flowc_raft = flowc_raft_encoder(torch.randn(1,3,256,256)) 

In [22]:
# Check that both encoders return the same number of outputs.
len(features_flowc_raft) == len(features_flc)

True

In [23]:
# For each index, print whether the variant encoder's output shape equals the
# FlowNetC encoder's output shape at that index, followed by the shape.
for i in range(len(features_flowc_raft)):
    print(features_flowc_raft[i].shape == features_flc[i].shape, " ", features_flowc_raft[i].shape)

True   torch.Size([1, 64, 128, 128])
True   torch.Size([1, 128, 64, 64])
True   torch.Size([1, 256, 32, 32])


In [24]:
# v1: FlowNetC built from the default flownet_c.yaml config.
flownetc_model_v1 = build_model(
    "FlowNetC",
    cfg_path="../configs/flownet_c/models/flownet_c.yaml",
    custom_cfg=True,
)

In [31]:
# v2: FlowNetC_V2 built from the flownet_c_raft_encoder_no_norm.yaml config.
flownetc_model_v2 = build_model(
    "FlowNetC_V2",
    cfg_path="../configs/flownet_c/models/flownet_c_raft_encoder_no_norm.yaml",
    custom_cfg=True,
)

In [35]:
# Switch to evaluation mode; the echoed repr lists the model's layers
# (the positions after each Conv2d show Identity() in this config).
flownetc_model_v2.eval()

FlowNetC_V2(
  (feature_encoder): BasicEncoderV2(
    (encoder): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): Identity()
      (2): ReLU(inplace=True)
      (3): Sequential(
        (0): BasicBlock(
          (activation): ReLU(inplace=True)
          (residual_fn): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Identity()
            (2): ReLU(inplace=True)
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): Identity()
          )
          (shortcut): Identity()
        )
        (1): BasicBlock(
          (activation): ReLU(inplace=True)
          (residual_fn): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Identity()
            (2): ReLU(inplace=True)
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): Id

In [33]:
# Trainable-parameter counts of the two FlowNetC variants, in build order (v1, v2).
tuple(
    count_params(model)
    for model in (flownetc_model_v1, flownetc_model_v2)
)

('39.24424M params', '78.80824M params')

In [34]:
# Smoke-test the forward pass on a pair of random (1, 3, 256, 256) inputs and
# display the shape of the 'flow_upsampled' output.
# Only the shape is inspected, so run under torch.no_grad() to avoid building
# an autograd graph for the whole model.
with torch.no_grad():
    flows = flownetc_model_v2(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
flows['flow_upsampled'].shape

torch.Size([1, 2, 256, 256])

___

### Test: restoring optimizer and scheduler state from a checkpoint

In [95]:
from ezflow.engine import schedulers, optimizers

In [96]:
# Look up the optimizer registered under the name "AdamW" and display what
# the registry returned (a class, not an instance — see the echoed value).
opt = optimizers.get("AdamW")
opt

torch.optim.adamw.AdamW

In [90]:
# Look up the scheduler registered under the name "OneCycleLR" and display what
# the registry returned (a class, not an instance — see the echoed value).
sched = schedulers.get("OneCycleLR")
sched

torch.optim.lr_scheduler.OneCycleLR

In [92]:
# Load the RAFT trainer config stored in kubric_v1_0.yaml.
train_cfg = get_cfg("../configs/raft/trainer/kubric_v1_0.yaml")

In [93]:
# Display the SCHEDULER section of the trainer config.
train_cfg.SCHEDULER

CfgNode({'USE': True, 'NAME': 'OneCycleLR', 'PARAMS': CfgNode({'max_lr': 0.0004, 'total_steps': 100100, 'pct_start': 0.05, 'cycle_momentum': False, 'anneal_strategy': 'linear'})})

In [122]:
# Convert the scheduler PARAMS node to a plain dict (it is later unpacked as
# keyword arguments when the scheduler is constructed) and display it.
sched_params = train_cfg.SCHEDULER.PARAMS.to_dict()
sched_params

{'max_lr': 0.0004,
 'total_steps': 100100,
 'pct_start': 0.05,
 'cycle_momentum': False,
 'anneal_strategy': 'linear'}

In [110]:
# Instantiate the looked-up optimizer class over the RAFT model's parameters,
# taking the learning rate and the remaining keyword arguments from the
# OPTIMIZER section of the trainer config.
_optim = opt(
    raft_model.parameters(),
    lr=train_cfg.OPTIMIZER.LR,
    **train_cfg.OPTIMIZER.PARAMS.to_dict()
)

In [111]:
# Inspect the freshly built optimizer: 'state' is empty and the param group
# carries the configured hyperparameters.
_optim.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.0004,
   'betas': [0.9, 0.999],
   'eps': 1e-08,
   'weight_decay': 0.0001,
   'amsgrad': False,
   'params': [0,
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    8,
    9,
    10,
    11,
    12,
    13,
    14,
    15,
    16,
    17,
    18,
    19,
    20,
    21,
    22,
    23,
    24,
    25,
    26,
    27,
    28,
    29,
    30,
    31,
    32,
    33,
    34,
    35,
    36,
    37,
    38,
    39,
    40,
    41,
    42,
    43,
    44,
    45,
    46,
    47,
    48,
    49,
    50,
    51,
    52,
    53,
    54,
    55,
    56,
    57,
    58,
    59,
    60,
    61,
    62,
    63,
    64,
    65,
    66,
    67,
    68,
    69,
    70,
    71,
    72,
    73,
    74,
    75,
    76,
    77,
    78,
    79,
    80,
    81,
    82,
    83,
    84,
    85,
    86,
    87,
    88,
    89,
    90,
    91,
    92,
    93,
    94,
    95,
    96,
    97,
    98,
    99,
    100,
    101,
    102,
    103,
    104,
    105,
    106,

In [112]:
# Load a saved training checkpoint, mapping its tensors onto the CPU.
# NOTE(review): the output recorded for this cell is a FileNotFoundError, yet
# later cells read `state_dict`, so their outputs must come from an earlier
# successful run. The notebook will not survive Restart & Run All until this
# path exists.
# NOTE(review): torch.load unpickles the file — only load checkpoints that
# come from a trusted source.
state_dict = torch.load("../../results/raft/ckpts/exp200/raft_step30000.pth", map_location=torch.device('cpu'))

FileNotFoundError: [Errno 2] No such file or directory: '../../results/raft/ckpts/exp200/raft_step30000.pth'

In [113]:
# Show which entries the checkpoint contains and the training step it was saved at.
state_dict.keys(), state_dict['step']

(dict_keys(['model_state_dict', 'optimizer_state_dict', 'step', 'scheduler_state_dict']),
 30000)

In [114]:
# Pull the optimizer and scheduler state out of the checkpoint.
opt_state_dict = state_dict['optimizer_state_dict']
sch_state_dict = state_dict['scheduler_state_dict']

In [116]:
# Restore the checkpointed optimizer state into the optimizer built above.
_optim.load_state_dict(opt_state_dict)

In [123]:
# Construct a new scheduler around the optimizer using the config parameters.
# NOTE(review): this appears after the optimizer state was restored; the state
# dump that follows shows the new scheduler starting at last_epoch 0 — confirm
# that constructing it here does not reset the restored learning rate.
_sched = sched(_optim, **sched_params)

In [124]:
# State of the newly constructed scheduler, before the checkpoint is loaded into it.
_sched.state_dict()

{'total_steps': 100100,
 '_schedule_phases': [{'end_step': 5004.0,
   'start_lr': 'initial_lr',
   'end_lr': 'max_lr',
   'start_momentum': 'max_momentum',
   'end_momentum': 'base_momentum'},
  {'end_step': 100099,
   'start_lr': 'max_lr',
   'end_lr': 'min_lr',
   'start_momentum': 'base_momentum',
   'end_momentum': 'max_momentum'}],
 'anneal_func': <bound method OneCycleLR._annealing_linear of <torch.optim.lr_scheduler.OneCycleLR object at 0x2ad4cee0eb50>>,
 'cycle_momentum': False,
 'base_lrs': [1.6e-05],
 'last_epoch': 0,
 '_step_count': 1,
 'verbose': False,
 '_get_lr_called_within_step': False,
 '_last_lr': [1.6e-05]}

In [119]:
# Scheduler state stored in the checkpoint, for comparison with the fresh scheduler.
sch_state_dict

{'total_steps': 100100,
 '_schedule_phases': [{'end_step': 5004.0,
   'start_lr': 'initial_lr',
   'end_lr': 'max_lr',
   'start_momentum': 'max_momentum',
   'end_momentum': 'base_momentum'},
  {'end_step': 100099,
   'start_lr': 'max_lr',
   'end_lr': 'min_lr',
   'start_momentum': 'base_momentum',
   'end_momentum': 'max_momentum'}],
 'anneal_func': <bound method OneCycleLR._annealing_linear of <torch.optim.lr_scheduler.OneCycleLR object at 0x2ad4cede8b90>>,
 'cycle_momentum': False,
 'base_lrs': [1.6e-05],
 'last_epoch': 30000,
 '_step_count': 30001,
 'verbose': False,
 '_get_lr_called_within_step': False,
 '_last_lr': [0.00029485924594983967]}

In [125]:
# Restore the checkpointed scheduler state into the new scheduler.
# NOTE(review): the surrounding dumps only show the scheduler's own fields
# changing; whether the optimizer's param-group lr is updated as well is not
# shown here — confirm before relying on this sequence to resume training.
_sched.load_state_dict(sch_state_dict)

In [126]:
# State after loading: last_epoch, _step_count and _last_lr now match the checkpointed values.
_sched.state_dict()

{'total_steps': 100100,
 '_schedule_phases': [{'end_step': 5004.0,
   'start_lr': 'initial_lr',
   'end_lr': 'max_lr',
   'start_momentum': 'max_momentum',
   'end_momentum': 'base_momentum'},
  {'end_step': 100099,
   'start_lr': 'max_lr',
   'end_lr': 'min_lr',
   'start_momentum': 'base_momentum',
   'end_momentum': 'max_momentum'}],
 'anneal_func': <bound method OneCycleLR._annealing_linear of <torch.optim.lr_scheduler.OneCycleLR object at 0x2ad4cede8b90>>,
 'cycle_momentum': False,
 'base_lrs': [1.6e-05],
 'last_epoch': 30000,
 '_step_count': 30001,
 'verbose': False,
 '_get_lr_called_within_step': False,
 '_last_lr': [0.00029485924594983967]}