# TensorConversionwithPermuteAndView
# 问题： 在pytorch中使用permute和view实现空间通道、特征通道或者batch通道之间的互相转换。
*****************************************
最近用 pixel shuffle 模块实现空间通道和特征通道之间的互相转换。出于突发奇想，我还想利用空间通道和 batch 通道的互换，把一些循环操作变成并行操作（感觉就是在给自己找事）。问题不是很难，但很容易粗心写错转换顺序，从而使处理效果变差（虽然我后来改正了，训练效果也没有变好，o(╥﹏╥)o）。这里给出我自己在检查（check）顺序时用的代码，方便理解。

In [None]:
# Code for pixel_shuffle adapted from:
# https://github.com/myungsub/CAIN/blob/master/model/common.py
def pixel_shuffle(input, scale_factor):
    """Move data between the channel and spatial dimensions of a 4-D tensor.

    * ``scale_factor = s >= 1`` (depth-to-space):
      ``(N, C*s*s, H, W) -> (N, C, H*s, W*s)``.
    * ``scale_factor = 1/b`` with integer ``b > 1`` (space-to-depth):
      ``(N, C, H, W) -> (N, C*b*b, H//b, W//b)``.

    The element ordering is the same as ``torch.nn.functional.pixel_shuffle``
    (upscaling) and ``torch.nn.functional.pixel_unshuffle`` (downscaling), so
    ``pixel_shuffle(pixel_shuffle(x, 1/b), b)`` returns ``x`` unchanged.

    Args:
        input: tensor of shape (N, C, H, W).
        scale_factor: an integer-valued ``s >= 1`` (``2`` and ``2.0`` both
            work), or the reciprocal of an integer ``b > 1`` (e.g. ``1/4``).

    Returns:
        The rearranged, contiguous tensor.

    Raises:
        ValueError: if ``scale_factor`` is not positive, or neither it nor
            its reciprocal is (within float tolerance) an integer.
        RuntimeError: raised by ``view`` when ``C`` is not divisible by
            ``s*s`` (upscaling) or ``H``/``W`` are not divisible by ``b``
            (downscaling).
    """
    batch_size, channels, in_height, in_width = input.size()
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    if scale_factor >= 1:
        # round() rather than passing the raw value on, so float factors such
        # as 2.0 become the exact ints that view() requires.
        factor = round(scale_factor)
        if abs(factor - scale_factor) > 1e-6:
            raise ValueError(
                f"scale_factor >= 1 must be an integer, got {scale_factor!r}"
            )
        out_channels = channels // (factor * factor)
        out_height = in_height * factor
        out_width = in_width * factor
        # Split C*s*s into (C, s_row, s_col), then interleave s_row with H and
        # s_col with W: (N, C, H, s_row, W, s_col).
        input_view = input.contiguous().view(
            batch_size, out_channels, factor, factor, in_height, in_width
        )
        shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        # 1 / scale_factor may land just below the intended integer in float
        # arithmetic; round() instead of int() so it is never truncated down.
        block_size = round(1 / scale_factor)
        if abs(block_size * scale_factor - 1) > 1e-6:
            raise ValueError(
                f"scale_factor < 1 must be 1/b for an integer b, got {scale_factor!r}"
            )
        out_channels = channels * block_size * block_size
        out_height = in_height // block_size
        out_width = in_width // block_size
        # Split H into (H//b, b_row) and W into (W//b, b_col), then move both
        # block offsets next to C: (N, C, b_row, b_col, H//b, W//b).
        input_view = input.contiguous().view(
            batch_size, channels, out_height, block_size, out_width, block_size
        )
        shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return shuffle_out.view(batch_size, out_channels, out_height, out_width)

# nn was never imported before this cell ran, so defining the class raised
# NameError; import it here so the cell is self-contained.
from torch import nn


class PixelShuffle(nn.Module):
    """``nn.Module`` wrapper around ``pixel_shuffle``.

    ``scale_factor >= 1`` moves channels into space (depth-to-space);
    ``scale_factor = 1/b`` moves space into channels (space-to-depth).
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Apply ``pixel_shuffle`` to ``x`` with the stored ``scale_factor``."""
        return pixel_shuffle(x, self.scale_factor)

    def extra_repr(self):
        """Show the scale factor in ``repr(module)``."""
        return 'scale_factor={}'.format(self.scale_factor)


【懒得分解每个单元，一些解释就放在了代码的注释里，聪明的人一看就懂】

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as f
from PIL import Image
def _load_image(path):
    """Read an image file into a (1, 3, H, W) float32 tensor scaled to [0, 1]."""
    # convert("RGB") pins the channel count to 3 so grayscale or RGBA files do
    # not break the C=3 layout used below; the with-block closes the file.
    with Image.open(path) as img:
        array = np.asarray(img.convert("RGB"), dtype=np.float32)
    return torch.from_numpy(array.transpose(2, 0, 1) / 255.0).unsqueeze(0)


# Down-sampling factor per spatial side; each image yields factor**2 slices.
factor = 4
image_size = 1024
patch_size = image_size // factor

# convert (N,C,H,W) - (N,C*16,H//4,W//4)
in_shuffler = PixelShuffle(1 / factor)
# convert (N,C*16,H//4,W//4) - (N,C,H,W)
out_shuffler = PixelShuffle(factor)

# Build a two-image batch of shape (2, 3, 1024, 1024). Each image is resized
# on its own before concatenation, so the two files may differ in size.
color_image1 = _load_image("/path/for/image1")
color_image2 = _load_image("/path/for/image2")
color_image = torch.cat(
    [f.interpolate(img, (image_size, image_size)) for img in (color_image1, color_image2)],
    0,
)
batch, channels = color_image.shape[:2]

# Coarse version: downsample with interpolate's default (nearest) mode to the
# patch size, laid out as (N, h, w, C) for matplotlib.
color_std = f.interpolate(color_image, (patch_size, patch_size)).permute(0, 2, 3, 1).numpy()

color_image_patch = in_shuffler(color_image)
print(color_image_patch.shape)
# Decompose the channel axis C*16 into (C, 16) slices.
color_image_patch1 = color_image_patch.view(
    batch, channels, factor * factor, patch_size, patch_size
)
# Move the slices into the batch: (N, C, 16, h, w) -> (N, 16, C, h, w) -> (N*16, C, h, w).
color_image_patch1 = (
    color_image_patch1.permute(0, 2, 1, 3, 4)
    .contiguous()
    .view(-1, channels, patch_size, patch_size)
)
image_patch = color_image_patch1.permute(0, 2, 3, 1).numpy()

slices_per_image = factor * factor
plt.figure(figsize=(16, 32))
for i, patch in enumerate(image_patch):
    plt.subplot(len(image_patch) // 4, 4, i + 1)
    # Difference between each fine slice and the coarse version of the image
    # it came from (batch index i // slices_per_image).
    plt.imshow(np.abs(patch - color_std[i // slices_per_image]))
plt.show()

In [None]:
# Reconstruct the original images from the batch layout (N*16, C, H//4, W//4):
# (N*16, C, h, w) -> (N, 16, C, h, w) -> (N, C, 16, h, w) -> (N, C*16, h, w),
# then apply the inverse shuffle. Shapes are read from the tensors instead of
# being hard-coded.
num_slices = round(out_shuffler.scale_factor) ** 2
batch, stacked_channels, patch_h, patch_w = color_image_patch.shape
color_image_fuse = out_shuffler(
    color_image_patch1.view(batch, num_slices, stacked_channels // num_slices, patch_h, patch_w)
    .permute(0, 2, 1, 3, 4)
    .contiguous()
    .view(batch, -1, patch_h, patch_w)
)
# Reference reconstruction straight from the (N, C*16, H//4, W//4) layout.
color_image_fuse_gt = out_shuffler(color_image_patch)
print(color_image_fuse.shape)
# If the permute/view order is right, both reconstructions are identical.
print("max abs difference:", (color_image_fuse - color_image_fuse_gt).abs().max().item())

image_fuse = color_image_fuse.permute(0, 2, 3, 1).numpy()
image_fuse_gt = color_image_fuse_gt.permute(0, 2, 3, 1).numpy()
plt.figure(figsize=(32, 16))
for i, (fused, fused_gt) in enumerate(zip(image_fuse, image_fuse_gt)):
    plt.subplot(1, len(image_fuse), i + 1)
    print(fused.shape)
    # Absolute difference of the two full-resolution reconstructions
    # (expected to be all zero).
    plt.imshow(abs(fused - fused_gt))
plt.show()


再后来，我还尝试将不同 batch 的特征通道进行组合。举例来说：特征 x1 在两个 batch 中分别为 x11、x12，特征 x2 分别为 x21、x22（下面代码中 x1、x2 分别取两张图像的第 0、第 1 个颜色通道）；目标是从 (x11,x21)、(x12,x22) 变成 (x11,x21)、(x12,x21)、(x11,x22)、(x12,x22)。

In [None]:
# Stack the two images along a new axis at dim 2: (1, C, 2, H, W).
# Assumes color_image1 and color_image2 share the same (1, C, H, W) shape.
color_set = torch.stack((color_image1, color_image2), dim=2)
# Channel 0 of image j, placed at pair index [i, j] for every i: (1, 1, 2, 2, H, W).
color_set_x = color_set[:, 0].unsqueeze(1).unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
# Channel 1 of image i, placed at pair index [i, j] for every j: (1, 1, 2, 2, H, W).
color_set_y = color_set[:, 1].unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, 2, 1, 1)
print(color_set_x.shape, color_set_y.shape)
# Concatenate on dim 1: entry [0, :, i, j] is the pair
# (channel 0 of image j, channel 1 of image i).
color_set_stack = torch.cat((color_set_x, color_set_y), 1)
print(color_set_stack.shape)
image_fuse = color_set_stack.numpy()

# Pair indices in subplot-column order: (0, 0), (0, 1), (1, 0), (1, 1).
pair_order = [(row, col) for row in range(2) for col in range(2)]
plt.figure(figsize=(16, 12))
for i in range(2):
    print(color_set_stack[0, i, :].shape)
    for k, (row, col) in enumerate(pair_order):
        plt.subplot(2, 4, i + 1 + 2 * k)
        plt.imshow(image_fuse[0, i, row, col])
plt.show()