# 4. Reshaping Tensors: flatten, reshape, view, squeeze, unsqueeze

Change the shape of tensors without changing the data!
Essential for preparing data for neural networks.


In [11]:
import torch


## 1. Flatten: Make Tensor 1D

Flatten converts a tensor to 1D (all dimensions merged)!
Very useful for feeding data into fully connected layers.


In [15]:
# Create a 2D tensor
t = torch.tensor([[1, 2, 3], 
                  [4, 5, 6]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

# Flatten to 1D
t_flat = t.flatten()
print("Flattened tensor:")
print(t_flat)
print(f"Shape: {t_flat.shape}")
print()

# Flatten from start dimension
t_flat_start = torch.flatten(t, start_dim=0)
print("Flattened from start_dim=0:")
print(t_flat_start)
print(f"Shape: {t_flat_start.shape}")
print()

# 3D example
t_3d = torch.tensor([[[1, 2], [3, 4]], 
                     [[5, 6], [7, 8]]], dtype=torch.float32)
print("3D tensor:")
print(t_3d)
print(f"Shape: {t_3d.shape}")
print()

t_3d_flat = torch.flatten(t_3d, start_dim=1)
print("Flattened 3D tensor:")
print(t_3d_flat)
print(f"Shape: {t_3d_flat.shape}")


Original tensor:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
Shape: torch.Size([2, 3])

Flattened tensor:
tensor([1., 2., 3., 4., 5., 6.])
Shape: torch.Size([6])

Flattened from start_dim=0:
tensor([1., 2., 3., 4., 5., 6.])
Shape: torch.Size([6])

3D tensor:
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
Shape: torch.Size([2, 2, 2])

Flattened 3D tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])


## 2. Reshape: Change Shape

Reshape changes the shape while keeping all elements!
Total number of elements must stay the same.


In [6]:
# Create a tensor
t = torch.tensor([[1, 2, 3, 4], 
                  [5, 6, 7, 8]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape} (2×4 = 8 elements)")
print()

# Reshape to 4×2
t_reshaped = t.reshape(4, 2)
print("Reshaped to (4, 2):")
print(t_reshaped)
print(f"Shape: {t_reshaped.shape} (4×2 = 8 elements)")
print()

# Reshape to 1D
t_1d = t.reshape(8)
print("Reshaped to 1D:")
print(t_1d)
print(f"Shape: {t_1d.shape}")
print()

# Reshape with -1 (auto-calculate)
t_auto = t.reshape(-1)  # -1 means "figure it out"
print("Reshaped with -1:")
print(t_auto)
print(f"Shape: {t_auto.shape}")
print()

# Reshape 3D to 2D
t_3d = torch.arange(24).reshape(2, 3, 4)
print("3D tensor (2×3×4):")
print(t_3d)
print(f"Shape: {t_3d.shape}")
print()

t_2d_from_3d = t_3d.reshape(6, 4)
print("Reshaped to 2D (6×4):")
print(t_2d_from_3d)
print(f"Shape: {t_2d_from_3d.shape}")


Original tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4]) (2×4 = 8 elements)

Reshaped to (4, 2):
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
Shape: torch.Size([4, 2]) (4×2 = 8 elements)

Reshaped to 1D:
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
Shape: torch.Size([8])

Reshaped with -1:
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
Shape: torch.Size([8])

3D tensor (2×3×4):
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
Shape: torch.Size([2, 3, 4])

Reshaped to 2D (6×4):
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])
Shape: torch.Size([6, 4])


In [7]:
# Create a tensor
t = torch.tensor([[1, 2, 3, 4], 
                  [5, 6, 7, 8]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

# View (similar to reshape)
t_view = t.view(4, 2)
print("View as (4, 2):")
print(t_view)
print(f"Shape: {t_view.shape}")
print()

# View shares memory with original
print("View shares memory:")
print(f"t[0, 0] = {t[0, 0]}")
print(f"t_view[0, 0] = {t_view[0, 0]}")
print()

# Modify view, original changes too!
t_view[0, 0] = 99
print("After modifying t_view[0, 0] = 99:")
print("Original t:")
print(t)
print("View t_view:")
print(t_view)
print("(Both changed - they share memory!)")
print()

# Reshape vs View
print("Difference:")
print("- reshape(): Works even if tensor is not contiguous (copies if needed)")
print("- view(): Requires contiguous tensor (faster, shares memory)")
print("- Both change shape without changing data")


Original tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])

View as (4, 2):
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]])
Shape: torch.Size([4, 2])

View shares memory:
t[0, 0] = 1.0
t_view[0, 0] = 1.0

After modifying t_view[0, 0] = 99:
Original t:
tensor([[99.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.]])
View t_view:
tensor([[99.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.]])
(Both changed - they share memory!)

Difference:
- reshape(): Works even if tensor is not contiguous (copies if needed)
- view(): Requires contiguous tensor (faster, shares memory)
- Both change shape without changing data


## 4. Squeeze: Remove Dimensions of Size 1

Squeeze removes dimensions with size 1!
Useful for removing unnecessary dimensions.


In [8]:
# Create tensor with size-1 dimensions
t = torch.tensor([[[1, 2, 3]]], dtype=torch.float32)
print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

# Squeeze all size-1 dimensions
t_squeezed = t.squeeze()
print("Squeezed (all size-1 dims removed):")
print(t_squeezed)
print(f"Shape: {t_squeezed.shape}")
print()

# Squeeze specific dimension
t2 = torch.tensor([[[1], [2], [3]]], dtype=torch.float32)
print("Another tensor:")
print(t2)
print(f"Shape: {t2.shape}")
print()

t2_squeezed_dim = t2.squeeze(dim=0)
print("Squeezed dimension 0:")
print(t2_squeezed_dim)
print(f"Shape: {t2_squeezed_dim.shape}")
print()

# Squeeze dimension 2
t2_squeezed_dim2 = t2.squeeze(dim=2)
print("Squeezed dimension 2:")
print(t2_squeezed_dim2)
print(f"Shape: {t2_squeezed_dim2.shape}")


Original tensor:
tensor([[[1., 2., 3.]]])
Shape: torch.Size([1, 1, 3])

Squeezed (all size-1 dims removed):
tensor([1., 2., 3.])
Shape: torch.Size([3])

Another tensor:
tensor([[[1.],
         [2.],
         [3.]]])
Shape: torch.Size([1, 3, 1])

Squeezed dimension 0:
tensor([[1.],
        [2.],
        [3.]])
Shape: torch.Size([3, 1])

Squeezed dimension 2:
tensor([[1., 2., 3.]])
Shape: torch.Size([1, 3])


## 5. Unsqueeze: Add Dimension of Size 1

Unsqueeze adds a new dimension of size 1!
Useful for adding batch dimension or matching shapes.


In [9]:
# Create a 1D tensor
t = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

# Unsqueeze at dimension 0 (add dimension at start)
t_unsqueezed0 = t.unsqueeze(0)
print("Unsqueezed at dim 0:")
print(t_unsqueezed0)
print(f"Shape: {t_unsqueezed0.shape}")
print()

# Unsqueeze at dimension 1 (add dimension in middle)
t_unsqueezed1 = t.unsqueeze(1)
print("Unsqueezed at dim 1:")
print(t_unsqueezed1)
print(f"Shape: {t_unsqueezed1.shape}")
print()

# Unsqueeze at dimension -1 (add dimension at end)
t_unsqueezed_end = t.unsqueeze(-1)
print("Unsqueezed at dim -1 (end):")
print(t_unsqueezed_end)
print(f"Shape: {t_unsqueezed_end.shape}")
print()

# Common use case: add batch dimension
batch = t.unsqueeze(0)
print("Adding batch dimension:")
print(f"Original: {t.shape}")
print(f"With batch: {batch.shape}")
print("(Now ready for batch processing!)")


Original tensor:
tensor([1., 2., 3., 4.])
Shape: torch.Size([4])

Unsqueezed at dim 0:
tensor([[1., 2., 3., 4.]])
Shape: torch.Size([1, 4])

Unsqueezed at dim 1:
tensor([[1.],
        [2.],
        [3.],
        [4.]])
Shape: torch.Size([4, 1])

Unsqueezed at dim -1 (end):
tensor([[1.],
        [2.],
        [3.],
        [4.]])
Shape: torch.Size([4, 1])

Adding batch dimension:
Original: torch.Size([4])
With batch: torch.Size([1, 4])
(Now ready for batch processing!)


## 6. Comparing All Methods

Let's see all reshaping methods together!


In [10]:
# Create a tensor
t = torch.tensor([[1, 2, 3, 4], 
                  [5, 6, 7, 8]], dtype=torch.float32)

print("Original tensor:")
print(t)
print(f"Shape: {t.shape}")
print()

print("=" * 60)
print("FLATTEN:")
print("=" * 60)
t_flat = t.flatten()
print(f"t.flatten(): {t_flat.shape}")
print()

print("=" * 60)
print("RESHAPE:")
print("=" * 60)
t_reshaped = t.reshape(4, 2)
print(f"t.reshape(4, 2): {t_reshaped.shape}")
t_reshaped2 = t.reshape(8)
print(f"t.reshape(8): {t_reshaped2.shape}")
print()

print("=" * 60)
print("VIEW:")
print("=" * 60)
t_view = t.view(4, 2)
print(f"t.view(4, 2): {t_view.shape}")
print("(Shares memory with original)")
print()

print("=" * 60)
print("SQUEEZE:")
print("=" * 60)
t_with_ones = torch.tensor([[[1, 2, 3]]], dtype=torch.float32)
print(f"Original: {t_with_ones.shape}")
t_squeezed = t_with_ones.squeeze()
print(f"t.squeeze(): {t_squeezed.shape}")
print()

print("=" * 60)
print("UNSQUEEZE:")
print("=" * 60)
t_1d = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
print(f"Original: {t_1d.shape}")
t_unsqueezed = t_1d.unsqueeze(0)
print(f"t.unsqueeze(0): {t_unsqueezed.shape}")
print()

print("=" * 60)
print("SUMMARY:")
print("=" * 60)
print("flatten(): Make 1D")
print("reshape(): Change shape (may copy)")
print("view(): Change shape (shares memory, faster)")
print("squeeze(): Remove size-1 dimensions")
print("unsqueeze(): Add size-1 dimension")


Original tensor:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
Shape: torch.Size([2, 4])

FLATTEN:
t.flatten(): torch.Size([8])

RESHAPE:
t.reshape(4, 2): torch.Size([4, 2])
t.reshape(8): torch.Size([8])

VIEW:
t.view(4, 2): torch.Size([4, 2])
(Shares memory with original)

SQUEEZE:
Original: torch.Size([1, 1, 3])
t.squeeze(): torch.Size([3])

UNSQUEEZE:
Original: torch.Size([4])
t.unsqueeze(0): torch.Size([1, 4])

SUMMARY:
flatten(): Make 1D
reshape(): Change shape (may copy)
view(): Change shape (shares memory, faster)
squeeze(): Remove size-1 dimensions
unsqueeze(): Add size-1 dimension


## 7. Key Takeaways

**Reshaping operations:**
- `flatten()` - make tensor 1D
- `reshape(shape)` - change shape (may copy)
- `view(shape)` - change shape (shares memory, faster)
- `squeeze(dim)` - remove size-1 dimension
- `unsqueeze(dim)` - add size-1 dimension

**Important:**
- Total elements must stay the same for reshape/view
- View requires contiguous memory
- Squeeze/unsqueeze only affect size-1 dimensions
- Use -1 in reshape to auto-calculate dimension size

**Remember:** Reshaping doesn't change data, only how it's organized!
