In [1]:
#pragma cling add_include_path("../../libtorch/include")
#pragma cling add_include_path("../../libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("../../libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <tuple>
#include <torch/torch.h>
#include <ATen/ATen.h>
namespace nn = torch::nn;

In [3]:
// Seed the global RNG so the sample input -- and therefore every printed output
// below -- is identical on each Restart & Run All.
torch::manual_seed(0);
// Sample 4-D input of size {1, 2, 3, 4}; Unfold reads it as (N, C, H, W).
torch::Tensor input = torch::randn({1, 2, 3, 4});
std::cout << input << std::endl;

(1,1,.,.) = 
  1.8324  0.8139 -0.2464 -0.5051
  0.7070 -0.1387 -0.2475 -1.4061
  1.0875  1.1796  0.5683  0.3874

(1,2,.,.) = 
 -1.1217  0.2747  1.2427  1.0365
  1.1961  0.7335 -0.8076 -0.0196
 -0.2003  1.4604 -1.0378  1.1555
[ CPUFloatType{1,2,3,4} ]


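# nn.Unfold
`nn.Unfold` extracts sliding local blocks (patches) from a batched input tensor of shape $(N, C, *)$. It returns each block as one column of a tensor of shape $(N, C \times \prod(\text{kernel\_size}), L)$, where $L$ is the number of block positions.

Python signature: `torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)`. In C++ the equivalent is `torch::nn::Unfold(torch::nn::UnfoldOptions(kernel_size))`, with the `.dilation()`, `.padding()` and `.stride()` setters for the remaining arguments.

Example: for a $(1, 2, 3, 4)$ input and `kernel_size = {2, 3}`:
- each block holds $2 \times 2 \times 3 = 12$ values;
- there are $(3 - 2 + 1) \times (4 - 3 + 1) = 4$ block positions;
- so the output has shape $(1, 12, 4)$.

Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold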

In [11]:
// kernel_size = (kH, kW) = (2, 3). Dilation, padding and stride are written out at their
// default values so the options read like torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1).
torch::nn::UnfoldOptions unfold_option = torch::nn::UnfoldOptions({2, 3}).dilation(1).padding(0).stride(1);

In [12]:
// Build the Unfold module from the options above; it is used through its module holder below.
torch::nn::Unfold unfold_operator{unfold_option};

In [13]:
// Calling the module holder forwards to UnfoldImpl::forward. For this input the result is
// (N, C*kH*kW, L) = (1, 2*2*3, 4) = (1, 12, 4): one flattened patch per column.
torch::Tensor output = unfold_operator(input);

In [14]:
// Rows are (channel, kernel row, kernel col) entries; columns are sliding-window positions.
std::cout << output << '\n' << std::flush;

(1,.,.) = 
  1.8324  0.8139  0.7070 -0.1387
  0.8139 -0.2464 -0.1387 -0.2475
 -0.2464 -0.5051 -0.2475 -1.4061
  0.7070 -0.1387  1.0875  1.1796
 -0.1387 -0.2475  1.1796  0.5683
 -0.2475 -1.4061  0.5683  0.3874
 -1.1217  0.2747  1.1961  0.7335
  0.2747  1.2427  0.7335 -0.8076
  1.2427  1.0365 -0.8076 -0.0196
  1.1961  0.7335 -0.2003  1.4604
  0.7335 -0.8076  1.4604 -1.0378
 -0.8076 -0.0196 -1.0378  1.1555
[ CPUFloatType{1,12,4} ]


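# Reshaping the output of nn.Unfold into convolution windows
Unfold stores each block as one flattened column of length $C \cdot k_H \cdot k_W$, ordered channel first. Three steps turn this into convolution windows:
1. Reshape to $(N, C, k_H k_W, L)$.
2. Swap the last two axes.
3. Reshape again to get $(N, C, L, k_H, k_W)$.

The result gives, for every channel and every sliding position, the $k_H \times k_W$ window a convolution kernel would be applied to.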

In [15]:
torch::Tensor output_like_convolution = output.reshape({1,2,6,4}).transpose(-1,-2).reshape({1,2,-1,2,3});

In [16]:
// Each printed (n, c, l, ., .) slice is one kH x kW window of channel c at position l.
std::cout << output_like_convolution << '\n' << std::flush;

(1,1,1,.,.) = 
  1.8324  0.8139 -0.2464
  0.7070 -0.1387 -0.2475

(1,2,1,.,.) = 
 -1.1217  0.2747  1.2427
  1.1961  0.7335 -0.8076

(1,1,2,.,.) = 
  0.8139 -0.2464 -0.5051
 -0.1387 -0.2475 -1.4061

(1,2,2,.,.) = 
  0.2747  1.2427  1.0365
  0.7335 -0.8076 -0.0196

(1,1,3,.,.) = 
  0.7070 -0.1387 -0.2475
  1.0875  1.1796  0.5683

(1,2,3,.,.) = 
  1.1961  0.7335 -0.8076
 -0.2003  1.4604 -1.0378

(1,1,4,.,.) = 
 -0.1387 -0.2475 -1.4061
  1.1796  0.5683  0.3874

(1,2,4,.,.) = 
  0.7335 -0.8076 -0.0196
  1.4604 -1.0378  1.1555
[ CPUFloatType{1,2,4,2,3} ]
