-
Notifications
You must be signed in to change notification settings - Fork 254
/
im2col.cpp
76 lines (66 loc) · 3.07 KB
/
im2col.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <algorithm>
#include <cstddef>
/// Unfold a multi-channel image into a column matrix ("im2col").
///
/// Layout: `img` holds `channels` planes of `height` x `width` elements; the
/// sample (c, y, x) is read from `img[(c * height + y) * width + x]`.
/// `col` is written as a row-major matrix with `channels * kernel_h * kernel_w`
/// rows and `height_col * width_col` columns, where
///   height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1
///   width_col  = (width  + 2 * pad_w - kernel_w) / stride_w + 1
/// Row r = (c_im * kernel_h + kh) * kernel_w + kw, column h * width_col + w
/// receives the input sample at
///   (c_im, h * stride_h - pad_h + kh, w * stride_w - pad_w + kw),
/// or 0 when that position falls in the implicit zero padding.
///
/// @param img      source image (channels * height * width elements).
/// @param col      destination. When height_col > 0 and width_col > 0, every one
///                 of its channels * kernel_h * kernel_w * height_col * width_col
///                 elements is overwritten; otherwise nothing is written.
/// @param width    image width.
/// @param height   image height.
/// @param channels number of image planes.
/// @param kernel_w patch width.
/// @param kernel_h patch height.
/// @param pad_w    zero padding added on the left and right.
/// @param pad_h    zero padding added on the top and bottom.
/// @param stride_w horizontal step between patches; must be non-zero (divisor,
///                 not checked here).
/// @param stride_h vertical step between patches; must be non-zero (divisor,
///                 not checked here).
///
/// Flat offsets are computed in std::ptrdiff_t. The earlier plain-`int`
/// arithmetic overflowed (undefined behavior) once a buffer exceeded INT_MAX
/// elements.
template <typename T>
void im2col(const T *img, T *col, int width, int height, int channels,
            int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
  const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
  const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
  // Empty output: nothing to write, and avoid forming row pointers from a
  // non-positive extent.
  if (height_col <= 0 || width_col <= 0) {
    return;
  }
  const int channels_col = channels * kernel_h * kernel_w;
  for (int c = 0; c < channels_col; ++c) {
    // Decompose the column-matrix row into (input channel, kernel row, kernel col).
    const int w_offset = c % kernel_w;
    const int h_offset = (c / kernel_w) % kernel_h;
    const int c_im = c / (kernel_h * kernel_w);
    for (int h = 0; h < height_col; ++h) {
      const int h_pad = h * stride_h - pad_h + h_offset;
      T *col_row = col + (static_cast<std::ptrdiff_t>(c) * height_col + h) * width_col;
      // The whole output row maps to one input row; if that row is padding,
      // every sample in it is zero.
      if (h_pad < 0 || h_pad >= height) {
        std::fill_n(col_row, width_col, T(0));
        continue;
      }
      const T *img_row = img + (static_cast<std::ptrdiff_t>(c_im) * height + h_pad) * width;
      for (int w = 0; w < width_col; ++w) {
        const int w_pad = w * stride_w - pad_w + w_offset;
        col_row[w] = (w_pad >= 0 && w_pad < width) ? img_row[w_pad] : T(0);
      }
    }
  }
}
/// Fold a column matrix back into an image ("col2im"), the adjoint of im2col.
///
/// `col` is read with the same layout im2col writes: a row-major matrix with
/// `channels * kernel_h * kernel_w` rows and `height_col * width_col` columns,
///   height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1
///   width_col  = (width  + 2 * pad_w - kernel_w) / stride_w + 1
/// `img` (channels planes of height x width, sample (c, y, x) at
/// `(c * height + y) * width + x`) is first zeroed. Then every column entry
/// whose source position lies inside the image is added (`+=`) to that
/// position. Overlapping patches therefore accumulate, and entries that fall
/// in the padding are dropped.
///
/// @param col      source column matrix.
/// @param img      destination image; all channels * height * width elements
///                 are overwritten.
/// @param width    image width.
/// @param height   image height.
/// @param channels number of image planes.
/// @param kernel_w patch width.
/// @param kernel_h patch height.
/// @param pad_w    zero padding on the left and right.
/// @param pad_h    zero padding on the top and bottom.
/// @param stride_w horizontal step between patches; must be non-zero (divisor,
///                 not checked here).
/// @param stride_h vertical step between patches; must be non-zero (divisor,
///                 not checked here).
///
/// Sizes and flat offsets are computed in std::ptrdiff_t. The earlier `int`
/// arithmetic (including `width * height * channels`) overflowed for buffers
/// larger than INT_MAX elements.
template <typename T>
void col2im(const T *col, T *img, int width, int height, int channels,
            int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
  const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
  const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
  const int channels_col = channels * kernel_h * kernel_w;
  // Accumulation below uses +=, so the image must start from zero.
  // fill_n with a non-positive count is a no-op.
  std::fill_n(img, static_cast<std::ptrdiff_t>(width) * height * channels, T(0));
  if (height_col <= 0 || width_col <= 0) {
    return;
  }
  for (int c = 0; c < channels_col; ++c) {
    // Decompose the column-matrix row into (input channel, kernel row, kernel col).
    const int w_offset = c % kernel_w;
    const int h_offset = (c / kernel_w) % kernel_h;
    const int c_im = c / (kernel_h * kernel_w);
    for (int h = 0; h < height_col; ++h) {
      const int h_pad = h * stride_h - pad_h + h_offset;
      // Whole row falls in the padding: its contributions are discarded.
      if (h_pad < 0 || h_pad >= height) {
        continue;
      }
      const T *col_row = col + (static_cast<std::ptrdiff_t>(c) * height_col + h) * width_col;
      T *img_row = img + (static_cast<std::ptrdiff_t>(c_im) * height + h_pad) * width;
      for (int w = 0; w < width_col; ++w) {
        const int w_pad = w * stride_w - pad_w + w_offset;
        if (w_pad >= 0 && w_pad < width) {
          img_row[w_pad] += col_row[w];
        }
      }
    }
  }
}
extern "C" {
void im2col_float(const float *img, float *col, int width, int height, int channels,
int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
im2col(img, col, width, height, channels, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
}
void im2col_double(const double *img, double *col, int width, int height, int channels,
int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
im2col(img, col, width, height, channels, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
}
void col2im_float(const float *col, float *img, int width, int height, int channels,
int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
col2im(col, img, width, height, channels, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
}
void col2im_double(const double *col, double *img, int width, int height, int channels,
int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h) {
col2im(col, img, width, height, channels, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
}
}