In [2]:
import torch 
import torch.nn as nn

In [5]:
x = torch.rand(2, 3, 224, 224)
print(x.shape)
# Concept: flatten(start_dim) collapses dims start_dim..end into a single dim.

# Layout before flattening: (B, C, H, W)


# Keep the batch dim (0) and merge C, H, W into one: (2, 3*224*224) = (2, 150528)
x = torch.flatten(x, start_dim=1)
x.shape


torch.Size([2, 3, 224, 224])


torch.Size([2, 150528])

In [6]:
# Register Buffer

# - Scenario: a tensor the model needs but should NOT learn
# - It is just a utility, e.g.
    # - an attention mask
    # - a non-learnable positional encoding, etc.

# => Solution
# Make it part of the module's state as a buffer, NOT a parameter
# (buffers follow .to(device) but are not returned by .parameters()).
# nn.Module provides this via the register_buffer method.

class Temp(nn.Module):
    """Toy module holding fixed (non-learnable) positional indices as a buffer.

    A buffer is reported by ``named_buffers()`` and moves with the module on
    ``.to(device)``, but it is not returned by ``parameters()``, so an
    optimizer never updates it.

    Args:
        num_pos: number of positions; the buffer is ``torch.arange(num_pos)``.
    """

    def __init__(self, num_pos):
        # nn.Module.__init__ must run first: it creates the internal _buffers /
        # _parameters dicts that register_buffer writes into. Without it,
        # register_buffer raises AttributeError.
        super().__init__()
        self.num_pos = num_pos

        # Register the tensor directly. Assigning self.pos_encodings first as a
        # plain attribute and then registering the same name raises
        # KeyError("attribute 'pos_encodings' already exists").
        self.register_buffer(
            name="pos_encodings",
            tensor=torch.arange(self.num_pos),
            # Exclude from state_dict(): it is cheap to rebuild in __init__,
            # so there is no need to save it or load it back from a checkpoint.
            persistent=False,
        )


In [4]:
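# Expand-and-broadcast technique
# A handy trick when we want to grow a size-1 dim by repeating its values.
# expand() does this by broadcasting: it returns a view (stride 0 along the
# expanded dim) and does NOT copy memory; use .repeat() or .clone() for a real copy.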

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
x_mask = x.ge(4)  # boolean mask: True where x >= 4

print(x_mask)
print(x_mask.shape)


# Read (2, 3) as (batch=2, seq_len=3); add a trailing size-1 axis, then
# broadcast it to embed_dim = 4 -> (2, 3, 4). The result is a view, not a copy.
x_mask = x_mask.unsqueeze(-1).expand(-1, -1, 4)
print(x_mask.shape)

tensor([[False, False, False],
        [ True,  True,  True]])
torch.Size([2, 3])
torch.Size([2, 3, 4])


In [6]:
import torch

# Target device and tensor sizes for this demo.
device = torch.device("cpu")
batch_size, seq_len, embed_dim = 2, 5, 10

# Random float32 input embeddings of shape (batch_size, seq_len, embed_dim) on `device`.
input_embeds = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)
dtype = input_embeds.dtype  # reuse the input's dtype so derived tensors match it

# Zero-filled tensor sharing the input's shape, dtype and device.
final_embedding = torch.zeros(batch_size, seq_len, embed_dim, device=device, dtype=dtype)

print(final_embedding)


tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
