In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

In [2]:
from ezflow.models import MODEL_REGISTRY
from ezflow.config import configurable
from ezflow.modules import BaseModule

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

___
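1. Hugging Face blog post introducing Perceiver IO: https://huggingface.co/blog/perceiver
2. `transformers` v4.21.3 Perceiver model documentation: https://huggingface.co/docs/transformers/v4.21.3/en/model_doc/perceiver
___
3. `PerceiverForOpticalFlow` API reference: https://huggingface.co/docs/transformers/v4.21.3/en/model_doc/perceiver#transformers.PerceiverForOpticalFlow
4. `PerceiverForOpticalFlow` source code (`modeling_perceiver.py`): https://github.com/huggingface/transformers/blob/v4.21.3/src/transformers/models/perceiver/modeling_perceiver.py#L1612
5. Tutorial notebook on Perceiver for optical flow (Niels Rogge): https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Perceiver_for_Optical_Flow.ipynb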
___

In [12]:
from transformers import PerceiverModel, PerceiverConfig
from transformers import PerceiverForOpticalFlow
from transformers.models.perceiver.modeling_perceiver import PerceiverImagePreprocessor, PerceiverTrainablePositionEncoding

In [5]:
def count_params(model):
    """Return the total number of elements in ``model``'s trainable (``requires_grad``) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

___

### Pretrained model (`deepmind/optical-flow-perceiver` via `PerceiverForOpticalFlow`)

In [6]:
# Dummy Perceiver input laid out as (batch, frames, patch channels, H, W):
# 27 = a 3x3 patch of a 3-channel image; 368x496 is the checkpoint's train_size.
torch.manual_seed(0)  # make the random input reproducible across runs
patches = torch.randn(1, 2, 27, 368, 496)

In [7]:
# Download (or load from cache) DeepMind's optical-flow Perceiver checkpoint.
pretrained_model = PerceiverForOpticalFlow.from_pretrained(
    "deepmind/optical-flow-perceiver"
)

In [8]:
count_params(model=pretrained_model)  # trainable parameters of the HF checkpoint

41057134

In [9]:
# Disabled: uncomment to run the pretrained model on `device`.
# pretrained_model = pretrained_model.to(device)

In [10]:
# Disabled: uncomment together with the model move above so input and model share `device`.
# patches = patches.to(device)

In [11]:
patches.device  # shows where the dummy input lives (CPU while the .to(device) lines stay commented out)

device(type='cpu')

In [15]:
# Inference only: no_grad() skips recording the autograd graph, so activations
# are not kept for a backward pass that never happens (lower peak memory).
with torch.no_grad():
    outputs = pretrained_model(inputs=patches, return_dict=False)

flow = outputs[0]

In [16]:
flow.size()

torch.Size([1, 368, 496, 2])

In [17]:
# Release the pretrained model and its outputs before building the custom model.
del outputs
del flow
del pretrained_model

___

### Custom ezflow `Perceiver` module

In [19]:
class Perceiver(BaseModule):
    """
    Implementation of PerceiverIO Optical Flow
    https://www.deepmind.com/open-source/perceiver-io
    https://huggingface.co/docs/transformers/v4.21.3/en/model_doc/perceiver#transformers.PerceiverForOpticalFlow
    https://github.com/huggingface/transformers/blob/v4.21.3/src/transformers/models/perceiver/modeling_perceiver.py#L1612

    Builds a Hugging Face ``PerceiverModel`` with a patch-based image
    pre-processor and an optical-flow decoder, and exposes it through the
    ezflow ``BaseModule`` interface: ``forward(img1, img2)`` returns a dict.

    Parameters
    ----------
    cfg : :class:`CfgNode`
        Configuration for the model. Its items are unpacked as keyword
        arguments into ``PerceiverConfig``. ``train_size`` sets the Fourier
        position-encoding resolution and the decoder's output image shape.
    """

    def __init__(self, cfg):
        super().__init__()
        # Local import: this decoder class is not among the names the notebook
        # imports from ``modeling_perceiver``.
        from transformers.models.perceiver.modeling_perceiver import PerceiverOpticalFlowDecoder

        self.config = PerceiverConfig(**cfg)

        # Fourier position encodings sized to the training resolution, for the
        # pre-processor and for the decoder queries.
        fourier_position_encoding_kwargs_preprocessor = dict(
            num_bands=64,
            max_resolution=self.config.train_size,
            sine_only=False,
            concat_pos=True,
        )
        fourier_position_encoding_kwargs_decoder = dict(
            concat_pos=True, max_resolution=self.config.train_size, num_bands=64, sine_only=False
        )

        image_preprocessor = PerceiverImagePreprocessor(
            self.config,
            prep_type="patches",
            spatial_downsample=1,
            conv_after_patching=True,
            # Presumably 2 frames x 27 patch channels (3x3 patches of an RGB
            # image), folded together by temporal_downsample=2.
            conv_after_patching_in_channels=54,
            temporal_downsample=2,
            position_encoding_type="fourier",
            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
        )

        self.perceiver = PerceiverModel(
            self.config,
            input_preprocessor=image_preprocessor,
            decoder=PerceiverOpticalFlowDecoder(
                self.config,
                num_channels=image_preprocessor.num_channels,
                output_image_shape=self.config.train_size,
                rescale_factor=100.0,
                use_query_residual=False,
                output_num_channels=2,
                position_encoding_type="fourier",
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
            ),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise all sub-module weights with the Perceiver scheme.

        Linear/conv/embedding weights, latents, trainable position embeddings
        and ParameterDict entries are drawn from N(0, ``initializer_range``).
        Biases are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        """
        std = self.config.initializer_range
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif hasattr(module, "latents"):
                module.latents.data.normal_(mean=0.0, std=std)
            elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
                module.position_embeddings.data.normal_(mean=0.0, std=std)
            elif isinstance(module, nn.ParameterDict):
                for modality in module.keys():
                    module[modality].data.normal_(mean=0.0, std=std)
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    # source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
    def _extract_image_patches(self, x, kernel=3, stride=1, dilation=1):
        """
        Extract TF-style (``padding="SAME"``) image patches.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(B, C, H, W)``.
        kernel : int
            Patch size along each spatial axis.
        stride : int
            Step between neighbouring patch centres.
        dilation : int
            Spacing between the pixels sampled inside one patch.

        Returns
        -------
        torch.Tensor
            Shape ``(B, kernel * kernel * C, ceil(H / stride), ceil(W / stride))``.
            The channel axis is ordered ``(kernel_row, kernel_col, C)``.
        """
        b, c, h, w = x.shape
        out_h = math.ceil(h / stride)
        out_w = math.ceil(w / stride)
        # Extent of one dilated kernel window.
        span = (kernel - 1) * dilation + 1
        # TF 'SAME' padding; clamp at 0 because a negative pad would crop in F.pad.
        pad_h = max((out_h - 1) * stride + span - h, 0)
        pad_w = max((out_w - 1) * stride + span - w, 0)
        # F.pad lists the last dimension first: (left, right, top, bottom).
        x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))

        # (B, C, out_h, out_w, span, span) windows, then keep every `dilation`-th
        # pixel so each window holds exactly kernel x kernel samples.
        patches = x.unfold(2, span, stride).unfold(3, span, stride)
        patches = patches[..., ::dilation, ::dilation]
        patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

        return patches.view(b, -1, patches.shape[-2], patches.shape[-1])

    def forward(self, img1, img2):
        """
        Predict optical flow for the frame pair ``(img1, img2)``.

        Parameters
        ----------
        img1, img2 : torch.Tensor
            Frames of identical shape ``(B, C, H, W)``. The decoder is built
            with ``output_image_shape=train_size``, so ``(H, W)`` is presumably
            expected to equal ``train_size``.

        Returns
        -------
        dict
            ``"flow_preds"``: 2-channel flow in channel-first layout. When the
            module is not in training mode, the same tensor is also stored under
            ``"flow_upsampled"``.
        """
        B, C, H, W = img1.shape

        # Interleave the frames per sample: (B, 2, C, H, W) -> (B * 2, C, H, W),
        # ordered [s0_img1, s0_img2, s1_img1, ...]. A dim-0 concat would order
        # them [all img1..., all img2...], and the regrouping below would then
        # pair frames from different samples whenever B > 1.
        frames = torch.stack([img1, img2], dim=1).reshape(B * 2, C, H, W)

        patches = self._extract_image_patches(frames)
        # (B * 2, 9 * C, H, W) -> (B, 2, 9 * C, H, W): two frames per sample.
        patches = patches.view(B, 2, *patches.shape[1:])

        flow = self.perceiver(
            inputs=patches,
            return_dict=False
        )[0]

        # (B, H, W, 2) -> (B, 2, H, W)
        flow = flow.permute(0, 3, 1, 2)

        output = {"flow_preds": flow}
        if not self.training:
            output["flow_upsampled"] = flow

        return output
        

In [20]:
# Hugging Face config of the "deepmind/optical-flow-perceiver" checkpoint;
# the Perceiver wrapper unpacks every key into PerceiverConfig.
config_dict = {
    "_name_or_path": "deepmind/optical-flow-perceiver",
    "architectures": ["PerceiverForOpticalFlow"],
    "attention_probs_dropout_prob": 0.1,
    "audio_samples_per_frame": 1920,
    "cross_attention_shape_for_attention": "kv",
    "cross_attention_widening_factor": 1,
    "d_latents": 512,
    "d_model": 322,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "image_size": 56,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 2048,
    "model_type": "perceiver",
    "num_blocks": 1,
    "num_cross_attention_heads": 1,
    "num_frames": 16,
    "num_latents": 2048,
    "num_self_attends_per_block": 24,
    "num_self_attention_heads": 16,
    "output_shape": [1, 16, 224, 224],
    "qk_channels": None,
    "samples_per_patch": 16,
    "self_attention_widening_factor": 1,
    "seq_len": 2048,
    "torch_dtype": "float32",
    "train_size": [368, 496],
    "transformers_version": "4.21.3",
    "use_query_residual": True,
    "v_channels": None,
    "vocab_size": 262,
}

In [21]:
model = Perceiver(cfg=config_dict)

In [22]:
count_params(model=model)  # should match the pretrained checkpoint's count

41057134

In [23]:
# Two random RGB frames at the model's 368x496 training resolution.
img1, img2 = [torch.randn(1, 3, 368, 496) for _ in range(2)]

In [24]:
# Inference demo: eval() switches off training-mode behaviour (dropout, and
# forward() then also returns "flow_upsampled"); no_grad() avoids storing
# activations for a backward pass.
model.eval()
with torch.no_grad():
    output = model(img1, img2)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [25]:
flow = output["flow_preds"]
flow.size()

torch.Size([1, 2, 368, 496])

___

### Build with ezflow's `build_model` (YAML config)

In [18]:
from ezflow.models import build_model
from nnflow import Perceiver

In [19]:
# Build the registered "Perceiver" from its YAML config through ezflow.
_model = build_model(
    "Perceiver",
    cfg_path="../configs/perceiver/models/perceiver.yaml",
    custom_cfg=True,
)

In [20]:
_model.cfg  # display the config stored on the ezflow-built model

{'_name_or_path': 'deepmind/optical-flow-perceiver',
 'architectures': ['PerceiverForOpticalFlow'],
 'attention_probs_dropout_prob': 0.1,
 'audio_samples_per_frame': 1920,
 'cross_attention_shape_for_attention': 'kv',
 'cross_attention_widening_factor': 1,
 'd_latents': 512,
 'd_model': 322,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'image_size': 56,
 'initializer_range': 0.02,
 'layer_norm_eps': 1e-12,
 'max_position_embeddings': 2048,
 'model_type': 'perceiver',
 'num_blocks': 1,
 'num_cross_attention_heads': 1,
 'num_frames': 16,
 'num_latents': 2048,
 'num_self_attends_per_block': 24,
 'num_self_attention_heads': 16,
 'output_shape': [1, 16, 224, 224],
 'qk_channels': None,
 'samples_per_patch': 16,
 'self_attention_widening_factor': 1,
 'seq_len': 2048,
 'torch_dtype': 'float32',
 'train_size': [368, 496],
 'transformers_version': '4.21.3',
 'use_query_residual': True,
 'v_channels': None,
 'vocab_size': 262}

In [7]:
# Fresh random RGB frame pair for the ezflow-built model (368x496).
img1, img2 = [torch.randn(1, 3, 368, 496) for _ in range(2)]

In [22]:
# Disabled: uncomment to move the inputs and the ezflow-built model to `device`.
# img1 = img1.to(device)
# img2 = img2.to(device)
# _model = _model.to(device)

In [23]:
# Inference only: eval() switches off training-mode behaviour and no_grad()
# avoids keeping activations for backprop, lowering peak memory.
# NOTE(review): the failed allocation of 1,495,269,376 B equals
# 2048 latents x (368*496) inputs x 4 B, i.e. presumably one fp32
# cross-attention score matrix, so no_grad alone may not avoid the CPU OOM.
_model.eval()
with torch.no_grad():
    flow_outputs = _model(img1, img2)

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 1495269376 bytes.

In [None]:
flow = flow_outputs["flow_preds"]
flow.size()

In [16]:
import math

In [17]:
def extract_image_patches(x, kernel=3, stride=1, dilation=1):
    """
    Extract TF-style (``padding="SAME"``) image patches from a 4-D tensor.

    Mirrors the patch extraction used inside the ``Perceiver`` module, so
    shapes can be inspected step by step in the notebook.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape ``(B, C, H, W)``.
    kernel : int
        Patch size along each spatial axis.
    stride : int
        Step between neighbouring patch centres.
    dilation : int
        Spacing between the pixels sampled inside one patch.

    Returns
    -------
    torch.Tensor
        Shape ``(B, kernel * kernel * C, ceil(H / stride), ceil(W / stride))``.
        The channel axis is ordered ``(kernel_row, kernel_col, C)``.
    """
    b, c, h, w = x.shape
    out_h = math.ceil(h / stride)
    out_w = math.ceil(w / stride)
    # Extent of one dilated kernel window.
    span = (kernel - 1) * dilation + 1
    # TF 'SAME' padding; clamp at 0 because a negative pad would crop in F.pad.
    pad_h = max((out_h - 1) * stride + span - h, 0)
    pad_w = max((out_w - 1) * stride + span - w, 0)
    # F.pad lists the last dimension first: (left, right, top, bottom).
    x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))

    # (B, C, out_h, out_w, span, span) windows, then keep every `dilation`-th
    # pixel so each window holds exactly kernel x kernel samples.
    patches = x.unfold(2, span, stride).unfold(3, span, stride)
    patches = patches[..., ::dilation, ::dilation]
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])

In [12]:
# Pair the frames along a new "time" axis: (B, 2, C, H, W).
imgs = torch.stack((img1, img2), dim=1)
imgs.size()

torch.Size([1, 2, 3, 368, 496])

In [14]:
# Flatten (batch, frames) into one axis so each frame is patched independently.
# `n_frames` is used instead of `_`: assigning `_` shadows IPython's last-output variable.
b, n_frames, c, h, w = imgs.shape
imgs = imgs.view(b * n_frames, c, h, w)
imgs.shape

torch.Size([2, 3, 368, 496])

In [18]:
# 3x3 SAME-padded patches with stride 1 keep H and W; channels become 9 * C.
patches = extract_image_patches(imgs, kernel=3, stride=1, dilation=1)
patches.size()

torch.Size([2, 27, 368, 496])

In [19]:
# Regroup the per-frame patches as (batch, frames, patch channels, H, W), the
# layout the Perceiver consumes; dim 1 of `patches` currently holds the 27
# patch channels (viewing with the raw image channel count `c` would give 18 x 3).
patches = patches.view(b, -1, patches.shape[1], h, w)
patches.shape

torch.Size([1, 2, 27, 368, 496])