In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Two random input tensors of shape (1, 3, 256, 256).
# A dedicated, seeded generator makes them reproducible across kernel restarts
# without altering the global RNG state used elsewhere in the notebook.
input_rng = torch.Generator().manual_seed(0)
img1 = torch.randn(1, 3, 256, 256, generator=input_rng)
img2 = torch.randn(1, 3, 256, 256, generator=input_rng)

# The tensors are left on the CPU; uncomment to move them to `device`.
# img1 = img1.to(device)
# img2 = img2.to(device)

img1.device

device(type='cpu')

In [4]:
def count_params(model):
    """Return the total element count of the trainable parameters of `model`.

    Iterates over `model.parameters()` and adds up `numel()` for every
    parameter whose `requires_grad` flag is set; other parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

___

## EzFlow PWCNet

In [5]:
from ezflow.models import build_model

In [7]:
# Build the 'PWCNet' model using the YAML config file at the given relative path.
ezflow_model = build_model(
    'PWCNet',
    cfg_path='../configs/pwcnet/models/pwcnet.yaml',
    custom_cfg=True,
)

  (out_channels, in_channels // groups, *kernel_size), **factory_kwargs))


In [8]:
# Number of trainable parameters in the EzFlow PWCNet model.
count_params(ezflow_model)

9374274

In [20]:
# Forward pass on the image pair; the result is indexed by string keys
# (e.g. "flow_preds") in the cells that follow.
output = ezflow_model(img1, img2)

In [21]:
# Print the shape of every entry in the "flow_preds" sequence, one per line.
for flow in output["flow_preds"]:
    print(flow.shape)

torch.Size([1, 2, 4, 4])
torch.Size([1, 2, 8, 8])
torch.Size([1, 2, 16, 16])
torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 64, 64])


In [22]:
# Call eval() on the model before running the forward pass again
# (presumably nn.Module.eval(), i.e. evaluation mode — confirm).
ezflow_model.eval()

# NOTE(review): this rebinds `output`, so any later cell that reads `output`
# sees the result of this post-eval() forward pass, not the earlier one.
output = ezflow_model(img1, img2)
output["flow_upsampled"].shape

torch.Size([1, 2, 256, 256])

___

### Loss computation

In [23]:
# All-ones target tensor of shape (1, 2, 256, 256).
target = torch.ones((1, 2, 256, 256))

In [24]:
from ezflow.functional import FUNCTIONAL_REGISTRY

In [25]:
# Look up the 'MultiScaleLoss' entry in the functional registry.
# NOTE: `loss` holds what the registry returns (the displayed output shows a
# class), not a computed loss value; it is instantiated in a later cell.
loss = FUNCTIONAL_REGISTRY.get('MultiScaleLoss')
loss

ezflow.functional.criterion.multiscale.MultiScaleLoss

In [26]:
# Keyword arguments that are unpacked into the loss constructor.
loss_params = dict(
    norm="l2",
    weights=[0.32, 0.08, 0.02, 0.01, 0.005],
    average="sum",
    resize_flow="downsample",
)

In [27]:
# Instantiate the loss with the keyword arguments from `loss_params`, then
# display its `average` attribute to check the setting was picked up.
loss_fn_2 = loss(**loss_params)
loss_fn_2.average

'sum'

In [29]:
# Evaluate the loss on the "flow_preds" entry of `output` against the target.
# NOTE(review): the target is divided by 20.0 — presumably a flow scaling
# factor; confirm against the training configuration.
loss_fn_2(output["flow_preds"], target/20.0)

tensor(0.6452, grad_fn=<DivBackward0>)