In [1]:
import torch as th

# Aim 
* Understand the tensor reshaping operations 

# Introduction 

Following are a few shape operations:
* Reshape: Reshape an input tensor to a target dimension.
* View: Returns a reshaped tensor that shares the same memory as the original tensor.
* Stack: Stacks tensors in a specified dimension.
* Squeeze: Returns a tensor by removing any dimension of size 1.
* Unsqueeze: Adds a dimension of size 1 at a specified dimension.
* Permute: Returns a view of the input by permuting the dimensions in a specific order.

# Reshape and View 

Reshape and view perform similar operations. View ensures input and output tensors share the same memory location. Reshape may or may not use the same memory location.

In [20]:
vector = th.arange(1, 11)

print(f'vector : {vector}') 
print(f'vector reshaped as 5, 2 : {vector.reshape(5, 2)}')
print(f'vector reshaped as 2, 5 : {vector.reshape(2, 5)}')

print(f'vector viewed as 5, 2 : {vector.view(5, 2)}')
print(f'vector viewed as 2, 5 : {vector.view(2, 5)}')

vector : tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
vector reshaped as 5, 2 : tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10]])
vector reshaped as 2, 5 : tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
vector viewed as 5, 2 : tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10]])
vector viewed as 2, 5 : tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])


* View shares the same memory address. So updating the reshaped tensor updates the original tensor.

In [21]:
reshaped_vector = vector.view(5, 2)
reshaped_vector[0, 0] = 1000
reshaped_vector, vector

(tensor([[1000,    2],
         [   3,    4],
         [   5,    6],
         [   7,    8],
         [   9,   10]]),
 tensor([1000,    2,    3,    4,    5,    6,    7,    8,    9,   10]))

# Stack
Concatinate a tensor along a new dimention

In [25]:
vector = th.arange(0, 5)

vector_stack_0 = th.stack([vector, vector, vector], dim=0)

vector_stack_1 = th.stack([vector, vector, vector], dim=1)


print(f'vector_stack_0 and shape : {vector_stack_0}, {vector_stack_0.shape}')
print(f'vector_stack_1 and shape : {vector_stack_1}, {vector_stack_1.shape}')


vector_stack_0 and shape : tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]]), torch.Size([3, 5])
vector_stack_1 and shape : tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), torch.Size([5, 3])


# Squeeze and Unsqueeze

Squeeze removes any dimension of size 1. Unsqueeze adds a dimension on a specified axis.

In [32]:
M = th.rand(1, 2, 3)

print(f'M and shape: {M}, {M.shape}')  
print(f'M.squeeze() and shape : {M.squeeze()}, {M.squeeze().shape}') 

N = th.tensor([[[1, 2, 3]]])

print(f'N and shape: {N}, {N.shape}')  
print(f'N.squeeze() and shape : {N.squeeze()}, {N.squeeze().shape}') 

M and shape: tensor([[[0.2786, 0.0214, 0.4552],
         [0.6411, 0.6284, 0.8966]]]), torch.Size([1, 2, 3])
M.squeeze() and shape : tensor([[0.2786, 0.0214, 0.4552],
        [0.6411, 0.6284, 0.8966]]), torch.Size([2, 3])
N and shape: tensor([[[1, 2, 3]]]), torch.Size([1, 1, 3])
N.squeeze() and shape : tensor([1, 2, 3]), torch.Size([3])


In [34]:
P = th.rand(1, 2, 1)
Q = P.unsqueeze(0)
R = P.unsqueeze(1)
S = P.unsqueeze(2)

print(f'P and shape: {P}, {P.shape}')
print(f'Q and shape: {Q}, {Q.shape}')
print(f'R and shape: {R}, {R.shape}')
print(f'S and shape: {S}, {S.shape}')

P and shape: tensor([[[0.5835],
         [0.8572]]]), torch.Size([1, 2, 1])
Q and shape: tensor([[[[0.5835],
          [0.8572]]]]), torch.Size([1, 1, 2, 1])
R and shape: tensor([[[[0.5835],
          [0.8572]]]]), torch.Size([1, 1, 2, 1])
S and shape: tensor([[[[0.5835]],

         [[0.8572]]]]), torch.Size([1, 2, 1, 1])


# Permute 

Returns a view of a tensor by permuting the dimension in a specified way. 

In [36]:
IMAGE = th.rand(256, 256, 3)

IMAGE_3_256_256 = IMAGE.permute(2, 0, 1)

print(f'IMAGE shape: {IMAGE.shape}')
print(f'IMAGE_3_256_256 shape: {IMAGE_3_256_256.shape}')

IMAGE shape: torch.Size([256, 256, 3])
IMAGE_3_256_256 shape: torch.Size([3, 256, 256])
