<a href="https://colab.research.google.com/github/prakash-bisht/Pytorch_Basic/blob/master/torch_math_operation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

# torch.cat walkthrough: join tensors along an EXISTING dimension.
# All dimensions other than the one being concatenated must match in size.
x = torch.randn(2, 3)
y = torch.randn(2, 3)

# 2D inputs: dim=0 appends rows, dim=1 appends columns, and dim=-1 names the
# last dimension, which for a 2D tensor is the same as dim=1.
for axis, heading in (
    (0, "Concatenated along dim=0:"),
    (1, "\nConcatenated along dim=1:"),
    (-1, "\nConcatenated along dim=-1:"),
):
    z = torch.cat([x, y], dim=axis)
    print(heading)
    print(z)
    print("Shape:", z.shape)

# 3D inputs: the same rule applies with one more axis to choose from;
# dim=-1 is now equivalent to dim=2.
a = torch.randn(1, 2, 3)
b = torch.randn(1, 2, 3)

for axis, heading in (
    (0, "\nConcatenated 3D tensors along dim=0:"),
    (1, "\nConcatenated 3D tensors along dim=1:"),
    (2, "\nConcatenated 3D tensors along dim=2:"),
    (-1, "\nConcatenated 3D tensors along dim=-1:"),
):
    c = torch.cat([a, b], dim=axis)
    print(heading)
    print(c)
    print("Shape:", c.shape)

Concatenated along dim=0:
tensor([[-0.9952,  0.6874,  0.5074],
        [ 0.4323,  0.4121, -0.8569],
        [ 0.3915, -0.9231, -1.1468],
        [-0.1419, -0.1024, -1.2639]])
Shape: torch.Size([4, 3])

Concatenated along dim=1:
tensor([[-0.9952,  0.6874,  0.5074,  0.3915, -0.9231, -1.1468],
        [ 0.4323,  0.4121, -0.8569, -0.1419, -0.1024, -1.2639]])
Shape: torch.Size([2, 6])

Concatenated along dim=-1:
tensor([[-0.9952,  0.6874,  0.5074,  0.3915, -0.9231, -1.1468],
        [ 0.4323,  0.4121, -0.8569, -0.1419, -0.1024, -1.2639]])
Shape: torch.Size([2, 6])

Concatenated 3D tensors along dim=0:
tensor([[[-0.2414,  1.3774,  0.0271],
         [ 0.5963,  0.6126,  0.2702]],

        [[ 0.2197, -0.7314, -0.3730],
         [ 1.1572,  1.2152, -0.4455]]])
Shape: torch.Size([2, 2, 3])

Concatenated 3D tensors along dim=1:
tensor([[[-0.2414,  1.3774,  0.0271],
         [ 0.5963,  0.6126,  0.2702],
         [ 0.2197, -0.7314, -0.3730],
         [ 1.1572,  1.2152, -0.4455]]])
Shape: torch.Size([1, 4, 3])

In [None]:
# Colab-ready: inspect RoPE slicing, rotation, and broadcasting step-by-step
import torch
torch.set_printoptions(precision=4, sci_mode=False)

# Tiny fixed configuration so every number below can be verified by hand.
batch_size, num_heads = 1, 1
seq_len, head_dim = 2, 2   # head_dim must be even so it splits into two halves

# Input laid out as (batch_size, num_heads, seq_len, head_dim):
# position 0 holds [1, 2] and position 1 holds [3, 4].
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]]).reshape(batch_size, num_heads, seq_len, head_dim)

print("x shape:", x.shape)
print("x:\n", x)

# Halve the last dimension: x1 is the front half, x2 the back half.
half = head_dim // 2
x1, x2 = x[..., :half], x[..., half:]   # each (1,1,2,1)

print("\nAfter split:")
print("x1 shape:", x1.shape, "x1:\n", x1)
print("x2 shape:", x2.shape, "x2:\n", x2)

# "Rotate half": negate the back half and place it in front of the front half.
#   pos0: x=[1,2] -> [-2, 1]
#   pos1: x=[3,4] -> [-4, 3]
rotated = torch.cat([x2.neg(), x1], dim=-1)  # (1,1,2,2)
print("\nrotated shape:", rotated.shape)
print("rotated:\n", rotated)

# Toy cos/sin tables of shape (context_length, head_dim), one row per position.
# Here context_length == seq_len == 2 and head_dim == 2.
#   row 0 (pos0): angle 0    -> cos=[1,1], sin=[0,0]
#   row 1 (pos1): angle pi/2 -> cos=[0,0], sin=[1,1]
cos = torch.tensor([[1.0, 1.0], [0.0, 0.0]])   # shape (2,2)
sin = torch.tensor([[0.0, 0.0], [1.0, 1.0]])   # shape (2,2)

print("\ncos shape:", cos.shape)
print("cos:\n", cos)
print("\nsin shape:", sin.shape)
print("sin:\n", sin)

# Keep the first seq_len rows and prepend two size-1 axes so the tables line
# up with x's (batch, head) axes: (2,2) -> (1,1,2,2).
cos_slice = cos[None, None, :seq_len, :]
sin_slice = sin[None, None, :seq_len, :]

print("\ncos_slice shape (after unsqueeze):", cos_slice.shape)
print("cos_slice:\n", cos_slice)
print("\nsin_slice shape (after unsqueeze):", sin_slice.shape)
print("sin_slice:\n", sin_slice)

# Elementwise combination: x * cos + rotate_half(x) * sin.
# All operands are (1,1,2,2) here, so no actual expansion takes place.
x_times_cos = torch.mul(x, cos_slice)
rotated_times_sin = torch.mul(rotated, sin_slice)
x_rotated = torch.add(x_times_cos, rotated_times_sin)

print("\nIntermediate results:")
print("x * cos:\n", x_times_cos)
print("rotated * sin:\n", rotated_times_sin)
print("\nFinal x_rotated:\n", x_rotated)

# Expected from the hand calculation:
#   pos0: [1, 2]   (cos=[1,1], sin=[0,0] leaves the input untouched)
#   pos1: [-4, 3]  (cos=[0,0], sin=[1,1] selects the rotated row [-4, 3])

# OPTIONAL: a small helper function that performs the same computation (no debug prints)
def compute_rope_simple(x, cos, sin):
    """Apply the rotate-half form of rotary position embedding to ``x``.

    Computes ``x * cos + rotate_half(x) * sin`` where ``rotate_half`` splits
    the last dimension into a front half ``x1`` and a back half ``x2`` and
    returns ``cat(-x2, x1)``.

    Args:
        x: Tensor of shape ``(..., S, D)``. The usual layout is
            ``(B, H, S, D)``, but any number of leading dimensions
            (including none) is accepted.
        cos: Tensor of shape ``(context_length, D)``; only the first ``S``
            rows are used.
        sin: Tensor with the same layout as ``cos``.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``D`` (the last dimension of ``x``) is odd, since it
            cannot be split into two equal halves.
    """
    S, D = x.shape[-2:]
    if D % 2 != 0:
        raise ValueError(f"head_dim must be even, got {D}")

    half = D // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    # The (S, D) tables align with the two trailing axes of x, so broadcasting
    # covers every leading axis without explicit unsqueeze calls.
    cos_b = cos[:S, :]
    sin_b = sin[:S, :]
    return (x * cos_b) + (rotated * sin_b)

# Cross-check: the helper must reproduce the step-by-step result in x_rotated.
x_rotated_func = compute_rope_simple(x, cos, sin)
helper_matches_manual = torch.allclose(x_rotated, x_rotated_func)
print("\ncompute_rope_simple result matches manual?", helper_matches_manual)


x shape: torch.Size([1, 1, 2, 2])
x:
 tensor([[[[1., 2.],
          [3., 4.]]]])

After split:
x1 shape: torch.Size([1, 1, 2, 1]) x1:
 tensor([[[[1.],
          [3.]]]])
x2 shape: torch.Size([1, 1, 2, 1]) x2:
 tensor([[[[2.],
          [4.]]]])

rotated shape: torch.Size([1, 1, 2, 2])
rotated:
 tensor([[[[-2.,  1.],
          [-4.,  3.]]]])

cos shape: torch.Size([2, 2])
cos:
 tensor([[1., 1.],
        [0., 0.]])

sin shape: torch.Size([2, 2])
sin:
 tensor([[0., 0.],
        [1., 1.]])

cos_slice shape (after unsqueeze): torch.Size([1, 1, 2, 2])
cos_slice:
 tensor([[[[1., 1.],
          [0., 0.]]]])

sin_slice shape (after unsqueeze): torch.Size([1, 1, 2, 2])
sin_slice:
 tensor([[[[0., 0.],
          [1., 1.]]]])

Intermediate results:
x * cos:
 tensor([[[[1., 2.],
          [0., 0.]]]])
rotated * sin:
 tensor([[[[-0.,  0.],
          [-4.,  3.]]]])

Final x_rotated:
 tensor([[[[ 1.,  2.],
          [-4.,  3.]]]])

compute_rope_simple result matches manual? True


In [None]:
# Draw a fresh 2x3 tensor of random values to use as an indexing example.
x = torch.randn(2, 3)
print(x)

tensor([[-0.7488,  0.9776,  1.0700],
        [ 1.0171,  0.0370, -1.1758]])


In [None]:
# Row index 1, all columns: integer indexing drops the row axis, so the
# result is a 1-D tensor holding that row's values.
print(x[1,:])

tensor([ 1.0171,  0.0370, -1.1758])


In [None]:
# All rows, column slice :1. Slicing (unlike indexing with 0) keeps the column
# axis, so the result stays 2-D with a single column.
print(x[:,:1])

tensor([[-0.7488],
        [ 1.0171]])


In [None]:
# Integer positions 0..5 as a 1-D tensor.
positions = torch.arange(6)
# Bare expression: a notebook displays its value as the cell output.
positions

tensor([0, 1, 2, 3, 4, 5])

In [None]:
# Indexing with None inserts a new size-1 axis at that slot:
# (6,) -> (6, 1), i.e. a column.
angles = positions[:, None]
angles

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5]])

In [None]:
# Confirm the trailing size-1 axis added by the None index.
angles.shape

torch.Size([6, 1])

In [None]:
# None in the leading slot puts the new axis first instead:
# (6,) -> (1, 6), i.e. a row.
angles_ = positions[None,:]
angles_

tensor([[0, 1, 2, 3, 4, 5]])

In [None]:
import torch

# torch.where(cond, a, b) picks elementwise: the value from a where cond is
# True, otherwise the value from b.
cond = torch.tensor([True, False, True, False])
a, b = torch.tensor([1, 2, 3, 4]), torch.tensor([10, 20, 30, 40])

result = torch.where(cond, a, b)
print(result)  # tensor([ 1, 20,  3, 40])


tensor([ 1, 20,  3, 40])


In [None]:
import torch
# repeat_interleave duplicates each element in place along the chosen axis
# (1,1,2,2,...), unlike a tiling repeat which would give (1,2,3,1,2,3).
keys = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
])   # shape = (2, 3)

# Every column is repeated twice, so dim=1 grows from 3 to 6.
out = torch.repeat_interleave(keys, repeats=2, dim=1)
print(out)


tensor([[1, 1, 2, 2, 3, 3],
        [4, 4, 5, 5, 6, 6]])


In [None]:
import torch

# Keep the k largest scores (with their positions) along the last dimension,
# then renormalise just those k values into probabilities.
scores = torch.tensor([[0.1, 0.3, 0.2, 0.9]])
top = scores.topk(2, dim=-1)
topk_scores, topk_indices = top.values, top.indices

print(topk_scores)   # tensor([[0.9000, 0.3000]])
print(topk_indices)  # tensor([[3, 1]])

# The softmax runs over the selected scores only, so each row sums to 1.
topk_probs = topk_scores.softmax(dim=-1)
print(topk_probs)

tensor([[0.9000, 0.3000]])
tensor([[3, 1]])
tensor([[0.6457, 0.3543]])


In [None]:
import torch

# scatter_ writes src values into the destination in place: along the last
# dimension, src[0][j] is stored at column indices[0][j]. Columns that no
# index points at keep their initial zero.
indices = torch.tensor([[2, 0]])
prob = torch.tensor([[0.6, 0.4]])
gating_probs = torch.zeros(1, 4)

gating_probs.scatter_(-1, indices, prob)
print(gating_probs)
# tensor([[0.4000, 0.0000, 0.6000, 0.0000]])


tensor([[0.4000, 0.0000, 0.6000, 0.0000]])


In [None]:
import torch
import torch.nn as nn

# A bias-free linear layer mapping an emb_dim-sized vector to one value per
# expert. nn.Linear stores its weight as (out_features, in_features), so the
# printed shape is (num_experts, emb_dim).
emb_dim, num_experts = 4, 3
gate = nn.Linear(in_features=emb_dim, out_features=num_experts, bias=False)

print(gate.weight.shape)


torch.Size([3, 4])
