# Manipulating tensor shapes

In [1]:
import torch

## 1. squeeze and unsqueeze

### 1.1 unsqueeze
1. pytorch默认是按照batch处理数据，以image为例，默认的input大小是(N, C, H, W)。如果想让model处理单个数据，就要把单个样本的大小从(C, H, W)改成(1, C, H, W)
2. 常用于方便broadcast（增加长度为1的维度，使两个tensor的shape可以对齐）

In [2]:
# unsqueeze(dim) inserts a new axis of length 1 at position `dim`.
a = torch.rand(3, 226, 226)
b = torch.unsqueeze(a, 0)  # new leading axis: (3, 226, 226) -> (1, 3, 226, 226)
print('增加一个长度为1的维度到新的第0维：',b.shape)

c = torch.unsqueeze(a, 1)  # new axis at index 1: (3, 226, 226) -> (3, 1, 226, 226)
print('增加一个长度为1的维度到新的第1维：',c.shape)

增加一个长度为1的维度到新的第0维： torch.Size([1, 3, 226, 226])
增加一个长度为1的维度到新的第1维： torch.Size([3, 1, 226, 226])


In [3]:
# A common use of unsqueeze: line the shapes up so that broadcasting applies.
a = torch.ones(4, 3, 2)
b = torch.rand(3)            # shape (3,): a * b cannot be computed directly
c = b[:, None]               # same result as b.unsqueeze(1): shape (3,) -> (3, 1)
print(c.shape)
print(a * c)                 # (4, 3, 2) * (3, 1) now broadcasts to (4, 3, 2)

torch.Size([3, 1])
tensor([[[0.6783, 0.6783],
         [0.1558, 0.1558],
         [0.2602, 0.2602]],

        [[0.6783, 0.6783],
         [0.1558, 0.1558],
         [0.2602, 0.2602]],

        [[0.6783, 0.6783],
         [0.1558, 0.1558],
         [0.2602, 0.2602]],

        [[0.6783, 0.6783],
         [0.1558, 0.1558],
         [0.2602, 0.2602]]])


### 1.2 squeeze

In [4]:
# squeeze removes axes whose extent is 1.
a = torch.rand(1, 2, 1, 4)
b = torch.squeeze(a)     # no dim given: every length-1 axis is dropped
c = torch.squeeze(a, 0)  # only axis 0 is dropped
print(b.shape)
print(c.shape)

torch.Size([2, 4])
torch.Size([2, 1, 4])


## 2. reshape
1. pytorch要求'shape'参数必须是tuple of ints。但是当shape是第一个参数的时候，可以用series of integers或者单个integer作为shape的值，但如果shape不是第一个参数，就必须是tuple。\
· 注：在extent of dim为1时，要注意，x=(3)是int，x=(3,)是tuple
2. 通常情况下，reshape返回的tensor只是原tensor的一个view，两者指向的是相同的memory location。但实际使用的时候，不要依赖这里的view vs. copy关系，不然容易出错。\
· reshape实际上是打包了两种处理方式的method：当input tensor的layout和output tensor的shape是compatible的时候，它等价于tensor.view()，此时没有copy发生；当两者不兼容的时候，它等价于先调用tensor.contiguous()再view，这时候就会发生copy。

In [5]:
# Parentheses alone do not make a tuple: (3) is the int 3; a trailing comma, as in (2,), does.
x, y, z = (1, 2, 3), (3), (2,)
tuple(type(obj) for obj in (x, y, z))

(tuple, int, tuple)

In [6]:
output3d = torch.ones((2, 2, 3))

# Tensor.reshape: the shape is the first argument, so a bare int is accepted.
input1d = output3d.reshape(12)  # 12 = 2 * 2 * 3
print(input1d, input1d.size())

# torch.reshape: the shape is not the first argument here, so it must be a tuple.
# input1d_2 = torch.reshape(output3d, (12))  # wrong: (12) is just the int 12
input1d_2 = torch.reshape(output3d, (12,))   # (12,) is a real 1-tuple
print(input1d_2.size())

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) torch.Size([12])
torch.Size([12])


In [7]:
# reshape usually returns a tensor that points at the same memory as the original,
# so a write into output3d is also visible through input1d.
output3d[0, ...] = 0
print(output3d)
print(input1d)

tensor([[[0., 0., 0.],
         [0., 0., 0.]],

        [[1., 1., 1.],
         [1., 1., 1.]]])
tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.])


### reshape在什么时候是view，什么时候不是？
1. output tensor的size and stride要与input tensor的size and stride兼容
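2. 规则：new tensor的每个dim，要么是原tensor某个dim的subspace，要么只span across满足contiguity条件的连续几个original dimensions $d, d+1, \dots, d+k$，即对所有 $i = d, \dots, d+k-1$ 都有 $\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]$（doc原文描述比较抽象，看下面的案例）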

In [8]:
## The shapes before and after view are compatible
x = torch.arange(16).reshape(4, 4)
print("原shape:", x.shape)
y = x.view((16,))
print("用view改城1维后:", y.shape)
z = x.view((-1, 8))  # -1 means this extent is inferred from the other dimensions
print("用view改城2维后:", z.shape)
x[2, :] = 0
print('改变x后，z也改变:\n',z)

原shape: torch.Size([4, 4])
用view改城1维后: torch.Size([16])
用view改城2维后: torch.Size([2, 8])
改变x后，z也改变:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 0,  0,  0,  0, 12, 13, 14, 15]])


In [9]:
## When the shapes are not compatible but a copy is unwanted, use transpose

a = torch.arange(24).reshape(1, 2, 3, 4)
print("原shape:", '\n', a)
print(a.shape, id(a), '\n')

b = torch.transpose(a, 1, 2)  # swap the 2nd and 3rd dimensions (axes 1 and 2)
print("用transpose交换了2nd和3rd dim:", '\n', b)
print(b.shape, id(b), '\n')

c = a.view((1, 3, 2, 4))  # keeps the element order in memory, only regroups it
print("用view不改变元素在memory中的排序方式:", '\n', c)
print(c.shape, id(c), '\n')

# b and c only agree in shape; their element values differ
torch.equal(b, c), c is a

原shape: 
 tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]]])
torch.Size([1, 2, 3, 4]) 140126738017120 

用transpose交换了2nd和3rd dim: 
 tensor([[[[ 0,  1,  2,  3],
          [12, 13, 14, 15]],

         [[ 4,  5,  6,  7],
          [16, 17, 18, 19]],

         [[ 8,  9, 10, 11],
          [20, 21, 22, 23]]]])
torch.Size([1, 3, 2, 4]) 140126747618864 

用view不改变元素在memory中的排序方式: 
 tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7]],

         [[ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23]]]])
torch.Size([1, 3, 2, 4]) 140126738019200 



(False, False)

In [10]:
# Changing elements of a also changes the corresponding values in b (transpose) and c (view)
a[0, 1, 2] = 100
print(a, '\n')
print(b, '\n')
print(c, '\n')

tensor([[[[  0,   1,   2,   3],
          [  4,   5,   6,   7],
          [  8,   9,  10,  11]],

         [[ 12,  13,  14,  15],
          [ 16,  17,  18,  19],
          [100, 100, 100, 100]]]]) 

tensor([[[[  0,   1,   2,   3],
          [ 12,  13,  14,  15]],

         [[  4,   5,   6,   7],
          [ 16,  17,  18,  19]],

         [[  8,   9,  10,  11],
          [100, 100, 100, 100]]]]) 

tensor([[[[  0,   1,   2,   3],
          [  4,   5,   6,   7]],

         [[  8,   9,  10,  11],
          [ 12,  13,  14,  15]],

         [[ 16,  17,  18,  19],
          [100, 100, 100, 100]]]]) 



In [11]:
# Plain assignment binds a second name to the same list object, so both ids match.
a = list(range(1, 7))
b = a
id(a), id(b), b is a

(140126738153600, 140126738153600, True)