# **Reshape, Permute, Squeeze, Unsqueeze using einops**
* The content follows the video "Reshape,Permute,Squeeze,Unsqueeze made simple using einops | The Gems".<br>Reference: https://www.youtube.com/watch?v=xGy75Pjsqzo&list=PLivJwLo9VCUJXdO8SiOjZTWr_fXrAy4OQ&index=5
* Extended by **Vigyannveshi** 

In [1]:
import einops
import torch as tr
import numpy as np
import torch.nn as nn

In [2]:
# A random 5-D tensor of shape (1, 3, 85, 13, 13) to experiment with.
x = tr.randn(1, 3, 85, 13, 13)
x.shape

torch.Size([1, 3, 85, 13, 13])

In [3]:
# Goal: move axis 2 (size 85) to the end,
# i.e. (1, 3, 85, 13, 13) -> (1, 3, 13, 13, 85).
# permute lists, for each output axis, which input axis it comes from.
y = x.permute(0, 1, 3, 4, 2)
y.shape

torch.Size([1, 3, 13, 13, 85])

In [4]:
# The same permutation with named axes: `p` moves from position 2 to the end.
y1 = einops.rearrange(x, pattern="b num_anchors p h w -> b num_anchors h w p")
y1.shape

torch.Size([1, 3, 13, 13, 85])

In [5]:
# Goal: reinterpret a (1, 507, 85) tensor as (1, 3, 13, 13, 85).
# The middle axis has 507 == 3 * 13 * 13 entries.
x2 = tr.randn(1, 3 * 13 * 13, 85)
x2.shape

torch.Size([1, 507, 85])

In [6]:
# Split the 507 axis into (num_anchors=3, h=13, w=13), with num_anchors
# outermost. The lengths given here are also checked by einops.
y3 = einops.rearrange(
    x2,
    pattern="b (num_anchors h w) p -> b num_anchors h w p",
    num_anchors=3,
    h=13,
    w=13,
)

y3.shape

torch.Size([1, 3, 13, 13, 85])

In [7]:
# Flatten a (1, 3 * 85, 13, 13) map "the plain torch way":
# move the channel axis last, then regroup it into rows of 85.
# reshape() copies a non-contiguous input when it must, which is exactly what
# the explicit .contiguous().view(...) spelling does.
pred = tr.randn(1, 3 * 85, 13, 13)
pred_reshaped = pred.permute(0, 2, 3, 1).reshape(1, -1, 85)

pred.shape, pred_reshaped.shape

(torch.Size([1, 255, 13, 13]), torch.Size([1, 507, 85]))

In [8]:
# Einops equivalent of the permute + view in the previous cell.
# permute(0, 2, 3, 1) places h and w *before* the packed (num_anchors p)
# channel axis, so after view(1, -1, 85) the anchor index varies fastest.
# The flattened 507 axis must therefore be written as (h w num_anchors).
# Writing (num_anchors h w) yields the same shape but a different element
# order.
ein_pred = einops.rearrange(pred,
                            "b (num_anchors p) h w -> b (h w num_anchors) p",
                            num_anchors=3,
                            h=13,
                            w=13)

# Matching shapes cannot reveal an ordering mismatch, so compare the values.
assert tr.equal(ein_pred, pred_reshaped)

ein_pred.shape

torch.Size([1, 507, 85])

In [9]:
# Now tidy up an "ugly" reshape. Start from the 1-D float ramp
# 0, 1, ..., w - 1.
h, w = 13, 13
t = tr.arange(w).to(tr.float32)

t.shape

torch.Size([13])

In [10]:
# The classic approach: a reshape with a -1 that the reader must decode
# to see that the ramp lands on axis 2.
ugly_c_x = tr.reshape(t, (1, 1, -1, 1))

ugly_c_x.shape

torch.Size([1, 1, 13, 1])

In [11]:
# The same result with einops: the pattern states where each unit ("1") axis
# is inserted around w (i.e. an unsqueeze on three axes).
nice_c_x = einops.rearrange(t, pattern="w -> 1 1 w 1")

nice_c_x.shape

torch.Size([1, 1, 13, 1])

In [13]:
# torch.equal demands identical shapes and exactly equal elements.
# allclose broadcasts mismatched shapes and tolerates small differences, so it
# cannot confirm that the two reshapes produced the very same tensor.
tr.equal(ugly_c_x, nice_c_x)

True

In [15]:
# Without a dim argument, squeeze drops every size-1 axis: (1, 1, 13, 1) -> (13,).
nice_c_x.squeeze().shape

torch.Size([13])

In [None]:
# Squeeze only axis 0: (1, 1, 13, 1) -> (1, 13, 1).
nice_c_x.squeeze(dim=0).shape

In [16]:
# einops drops the unit axes on the left and re-adds a single leading one on
# the right: (1, 1, 13, 1) -> (1, 13).
einops.rearrange(nice_c_x, pattern="1 1 w 1 -> 1 w").shape

torch.Size([1, 13])

In [17]:
# The Rearrange PyTorch layer
from einops.layers.torch import Rearrange

In [19]:
class ANeuralNetwork(nn.Module):
    """A 1x1 convolution followed by an einops ``Rearrange`` layer.

    The convolution maps ``in_channels`` feature channels to
    ``num_anchors_per_cell * (4 + 1 + num_classes)`` channels. The
    ``Rearrange`` layer then splits that channel axis into
    ``(num_anchors_per_cell, p)``, with the anchor index outermost, and moves
    ``p`` to the last axis:
    ``(b, num_anchors_per_cell * p, h, w) -> (b, num_anchors_per_cell, h, w, p)``.
    ``p`` is inferred from the channel count, so ``p == 4 + 1 + num_classes``.

    Args:
        in_channels: Number of channels in the input feature map.
        num_anchors_per_cell: Number of anchors produced per spatial position.
        num_classes: Number of classes; each anchor gets
            ``4 + 1 + num_classes`` output values.
    """

    def __init__(self,
                 in_channels: int,
                 num_anchors_per_cell: int,
                 num_classes: int):
        super().__init__()

        values_per_anchor = 4 + 1 + num_classes

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_anchors_per_cell * values_per_anchor,
            kernel_size=1,
            stride=1,
        )

        # Only the anchor count is given; einops infers p from the channels.
        self.rearrange = Rearrange(
            "b (num_anchors_per_cell p) h w -> b num_anchors_per_cell h w p",
            num_anchors_per_cell=num_anchors_per_cell,
        )

    def forward(self, x: tr.Tensor) -> tr.Tensor:
        """Run the 1x1 convolution, then split and move the channel axis."""
        return self.rearrange(self.conv(x))
  

In [21]:
# 3 anchors x (4 + 1 + 80) = 255 conv output channels, rearranged to
# (batch, 3, 13, 13, 85) for a 13x13 input.
net = ANeuralNetwork(in_channels=512, num_anchors_per_cell=3, num_classes=80)

input_x = tr.randn(1, 512, 13, 13)
output = net(input_x)

output.shape

torch.Size([1, 3, 13, 13, 85])