In [27]:
import torch
import torch.nn as nn

from ezflow.encoder import build_encoder
from ezflow.engine import get_training_cfg as get_cfg
from ezflow.models import build_model

In [2]:
from nnflow import BasicEncoderV2

In [3]:
# Load the RAFT model config and display whether its feature encoder returns intermediate feature maps.
raft_cfg = get_cfg("../configs/raft/models/raft.yaml")
raft_cfg.ENCODER.FEATURE.INTERMEDIATE_FEATURES

False

In [4]:
# Mutates raft_cfg in place: the encoder built from this FEATURE sub-config then returns a
# list of feature maps (four of them in the printed output below) instead of a single one.
raft_cfg.ENCODER.FEATURE.INTERMEDIATE_FEATURES = True

In [5]:
# Build the RAFT feature encoder from the (modified) FEATURE sub-config.
raft_encoder = build_encoder(raft_cfg.ENCODER.FEATURE)

In [6]:
# Run a random (1, 3, 256, 256) input through the RAFT encoder to inspect its output shapes.
# Only the shapes are used, so skip autograd bookkeeping with no_grad.
with torch.no_grad():
    features_raft = raft_encoder(torch.randn(1, 3, 256, 256))

In [7]:
# Show the shape of each feature map returned by the RAFT encoder, one per line.
for fmap in features_raft:
    print(fmap.shape)

torch.Size([1, 32, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 96, 32, 32])
torch.Size([1, 256, 32, 32])


In [8]:
# Display the encoder's LAYER_CONFIG; its values match the channel counts of the first three
# feature maps printed above.
raft_cfg.ENCODER.FEATURE.LAYER_CONFIG

[32, 64, 96]

In [80]:
def count_params(model, trainable_only=True):
    """Return the parameter count of ``model`` in millions, formatted as ``"<N>M params"``.

    Args:
        model: a ``torch.nn.Module`` whose ``parameters()`` are counted.
        trainable_only: if True (default), count only parameters with
            ``requires_grad=True``; if False, count every parameter.

    Returns:
        A string such as ``"9.374274M params"``; the count in millions is not rounded.
    """
    selected = (p for p in model.parameters() if p.requires_grad or not trainable_only)
    n_millions = sum(p.numel() for p in selected) / 1000000
    return str(n_millions) + "M params"

___
### PWCNet

In [9]:
# Load the default PWCNet config and display its encoder CONFIG; the values match the channel
# counts of the six feature maps printed below.
pwc_cfg = get_cfg("../configs/pwcnet/models/pwcnet.yaml")
pwc_cfg.ENCODER.CONFIG

[16, 32, 64, 96, 128, 196]

In [10]:
# Build the default PWCNet encoder from the ENCODER sub-config.
pwc_encoder = build_encoder(pwc_cfg.ENCODER)

In [11]:
# Run a random (1, 3, 256, 256) input through the default PWCNet encoder to inspect its output
# shapes; no_grad because nothing is backpropagated.
with torch.no_grad():
    features_pwc = pwc_encoder(torch.randn(1, 3, 256, 256))

In [12]:
# Show the shape of each feature map returned by the default PWCNet encoder, one per line.
for fmap in features_pwc:
    print(fmap.shape)

torch.Size([1, 16, 128, 128])
torch.Size([1, 32, 64, 64])
torch.Size([1, 64, 32, 32])
torch.Size([1, 96, 16, 16])
torch.Size([1, 128, 8, 8])
torch.Size([1, 196, 4, 4])


In [13]:
# Config for the PWCNet variant that (judging by the filename) uses a RAFT-style encoder.
pwc_cfg2 = get_cfg("../configs/pwcnet/models/pwcnet_raft_encoder.yaml")

In [14]:
# Build the encoder described by the variant config's ENCODER sub-config.
pwc_raft_encoder = build_encoder(pwc_cfg2.ENCODER)

In [15]:
# Run the same input shape through the RAFT-style variant so its outputs can be compared
# level by level with the default encoder's; no_grad since only shapes are inspected.
with torch.no_grad():
    features_pwc_raft = pwc_raft_encoder(torch.randn(1, 3, 256, 256))

In [16]:
# Number of feature maps returned by the variant encoder (the default encoder returned six above).
len(features_pwc_raft)

6

In [17]:
# Compare each level of the variant encoder against the default PWCNet encoder: prints
# whether the shapes match, followed by the variant's shape.
# Fail loudly on a level-count mismatch instead of raising IndexError or silently skipping levels.
assert len(features_pwc_raft) == len(features_pwc), "encoders returned different numbers of feature maps"
for raft_feat, pwc_feat in zip(features_pwc_raft, features_pwc):
    print(raft_feat.shape == pwc_feat.shape, " ", raft_feat.shape)

True   torch.Size([1, 16, 128, 128])
True   torch.Size([1, 32, 64, 64])
True   torch.Size([1, 64, 32, 32])
True   torch.Size([1, 96, 16, 16])
True   torch.Size([1, 128, 8, 8])
True   torch.Size([1, 196, 4, 4])


In [81]:
# Full PWCNet model built from the default config (pwcnet.yaml).
pwcnet_model_v1 = build_model("PWCNet", cfg_path="../configs/pwcnet/models/pwcnet.yaml", custom_cfg=True)

In [82]:
# Full PWCNet model built from the RAFT-encoder variant config (pwcnet_raft_encoder.yaml).
pwcnet_model_v2 = build_model("PWCNet", cfg_path="../configs/pwcnet/models/pwcnet_raft_encoder.yaml", custom_cfg=True)

In [37]:
# Put the model in eval mode before the forward pass below; the trailing ';' suppresses
# the module repr that eval() would otherwise display.
pwcnet_model_v2.eval();




In [38]:
# Smoke-test the variant model on a random image pair; no_grad because this is inference only.
with torch.no_grad():
    flows = pwcnet_model_v2(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
flows['flow_upsampled'].shape

torch.Size([1, 2, 256, 256])

In [87]:
# Trainable-parameter counts: default-config model vs. RAFT-encoder variant.
count_params(pwcnet_model_v1), count_params(pwcnet_model_v2)

('9.374274M params', '10.05017M params')

___

### FlowNetC

In [63]:
# Load the default FlowNetC config and display the first three entries of its encoder CONFIG.
flownetc_cfg = get_cfg("../configs/flownet_c/models/flownet_c.yaml")
flownetc_cfg.ENCODER.CONFIG[:3]

[64, 128, 256]

In [64]:
# NOTE: mutates flownetc_cfg in place, keeping only the first three encoder levels; the encoder
# built below then returns three feature maps (see its output). Presumably this is done to match
# the three levels of the RAFT-encoder variant compared below. Re-running is harmless, because
# slicing an already-truncated list with [:3] is a no-op.
flownetc_cfg.ENCODER.CONFIG = flownetc_cfg.ENCODER.CONFIG[:3]

In [65]:
# Build the (truncated) default FlowNetC encoder.
flownetc_encoder = build_encoder(flownetc_cfg.ENCODER)

In [66]:
# Run a random (1, 3, 256, 256) input through the FlowNetC encoder to inspect its output shapes;
# no_grad because nothing is backpropagated.
with torch.no_grad():
    features_flc = flownetc_encoder(torch.randn(1, 3, 256, 256))

In [67]:
# Show the shape of each feature map returned by the FlowNetC encoder, one per line.
for fmap in features_flc:
    print(fmap.shape)

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])


In [68]:
# Config for the FlowNetC variant that (judging by the filename) uses a RAFT-style encoder.
flownetc_cfg2 = get_cfg("../configs/flownet_c/models/flownet_c_raft_encoder.yaml")

In [69]:
# Build the encoder described by the variant config's ENCODER sub-config.
flowc_raft_encoder = build_encoder(flownetc_cfg2.ENCODER)

In [70]:
# Run the same input shape through the RAFT-style variant for a level-by-level comparison with
# the default FlowNetC encoder; no_grad since only shapes are inspected.
with torch.no_grad():
    features_flowc_raft = flowc_raft_encoder(torch.randn(1, 3, 256, 256))

In [71]:
# Sanity check: both FlowNetC encoders return the same number of feature maps.
len(features_flowc_raft) == len(features_flc)

True

In [72]:
# Compare each level of the variant encoder against the default FlowNetC encoder: prints
# whether the shapes match, followed by the variant's shape.
# Fail loudly on a level-count mismatch instead of raising IndexError or silently skipping levels.
assert len(features_flowc_raft) == len(features_flc), "encoders returned different numbers of feature maps"
for raft_feat, flc_feat in zip(features_flowc_raft, features_flc):
    print(raft_feat.shape == flc_feat.shape, " ", raft_feat.shape)

True   torch.Size([1, 64, 128, 128])
True   torch.Size([1, 128, 64, 64])
True   torch.Size([1, 256, 32, 32])


In [85]:
# Full FlowNetC model built from the default config file (flow_c.yaml is re-read, so the
# in-place truncation of flownetc_cfg above does not affect this model).
flownetc_model_v1 = build_model("FlowNetC", cfg_path="../configs/flownet_c/models/flownet_c.yaml", custom_cfg=True)

In [86]:
# Full FlowNetC model built from the RAFT-encoder variant config (flownet_c_raft_encoder.yaml).
flownetc_model_v2 = build_model("FlowNetC", cfg_path="../configs/flownet_c/models/flownet_c_raft_encoder.yaml", custom_cfg=True)

In [74]:
# Put the model in eval mode before the forward pass below; the trailing ';' suppresses
# the module repr that eval() would otherwise display.
flownetc_model_v2.eval();




In [75]:
# Smoke-test the variant model on a random image pair; no_grad because this is inference only.
with torch.no_grad():
    flows = flownetc_model_v2(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
flows['flow_upsampled'].shape

torch.Size([1, 2, 256, 256])

In [88]:
# Trainable-parameter counts: default-config model vs. RAFT-encoder variant.
count_params(flownetc_model_v1), count_params(flownetc_model_v2)

('39.24424M params', '40.98216M params')

___