In [2]:
import torch


In [15]:
d_k = 50
device='cpu'
theta = 100

In [16]:
torch.arange(0, d_k // 2, dtype=torch.float32, device=device)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.])

In [17]:
k_indices = torch.arange(0, d_k//2, device = device)
k_indices

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24])

In [18]:
freqs = theta**(2 * k_indices / d_k)
freqs

tensor([ 1.0000,  1.2023,  1.4454,  1.7378,  2.0893,  2.5119,  3.0200,  3.6308,
         4.3652,  5.2481,  6.3096,  7.5858,  9.1201, 10.9648, 13.1826, 15.8489,
        19.0546, 22.9087, 27.5423, 33.1131, 39.8107, 47.8630, 57.5440, 69.1831,
        83.1764])

In [23]:
torch.arange(20, dtype=torch.float32, device=device)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19.])

In [None]:
torch.outer()

In [25]:
freqs_neg = -1 * freqs
freqs_neg

tensor([ -1.0000,  -1.2023,  -1.4454,  -1.7378,  -2.0893,  -2.5119,  -3.0200,
         -3.6308,  -4.3652,  -5.2481,  -6.3096,  -7.5858,  -9.1201, -10.9648,
        -13.1826, -15.8489, -19.0546, -22.9087, -27.5423, -33.1131, -39.8107,
        -47.8630, -57.5440, -69.1831, -83.1764])

In [26]:
torch.stack([freqs, freqs_neg], dim=-1)

tensor([[  1.0000,  -1.0000],
        [  1.2023,  -1.2023],
        [  1.4454,  -1.4454],
        [  1.7378,  -1.7378],
        [  2.0893,  -2.0893],
        [  2.5119,  -2.5119],
        [  3.0200,  -3.0200],
        [  3.6308,  -3.6308],
        [  4.3652,  -4.3652],
        [  5.2481,  -5.2481],
        [  6.3096,  -6.3096],
        [  7.5858,  -7.5858],
        [  9.1201,  -9.1201],
        [ 10.9648, -10.9648],
        [ 13.1826, -13.1826],
        [ 15.8489, -15.8489],
        [ 19.0546, -19.0546],
        [ 22.9087, -22.9087],
        [ 27.5423, -27.5423],
        [ 33.1131, -33.1131],
        [ 39.8107, -39.8107],
        [ 47.8630, -47.8630],
        [ 57.5440, -57.5440],
        [ 69.1831, -69.1831],
        [ 83.1764, -83.1764]])

In [27]:
import torch

def demo_merge_techniques():
    # Example: d_k = 6, so d_k/2 = 3
    batch_size, seq_len, d_k = 2, 4, 6
    
    # Simulate rotated even and odd elements
    x_even_rot = torch.randn(batch_size, seq_len, d_k // 2)  # Shape: (2, 4, 3)
    x_odd_rot = torch.randn(batch_size, seq_len, d_k // 2)   # Shape: (2, 4, 3)
    
    print(f"x_even_rot shape: {x_even_rot.shape}")
    print(f"x_odd_rot shape: {x_odd_rot.shape}")
    print(f"Target shape: ({batch_size}, {seq_len}, {d_k})")
    
    print("\n" + "="*50)
    print("METHOD 1: Direct indexing (most efficient)")
    print("="*50)
    
    result1 = torch.empty(batch_size, seq_len, d_k)
    result1[..., 0::2] = x_even_rot  # Fill positions 0, 2, 4, ...
    result1[..., 1::2] = x_odd_rot   # Fill positions 1, 3, 5, ...
    print(f"Result1 shape: {result1.shape}")
    
    print("\n" + "="*50)
    print("METHOD 2: Stack + flatten (your approach fixed)")
    print("="*50)
    
    stacked = torch.stack([x_even_rot, x_odd_rot], dim=-1)  # (..., seq_len, d_k/2, 2)
    result2 = stacked.flatten(-2)  # Flatten last 2 dims: (..., seq_len, d_k)
    print(f"Stacked shape: {stacked.shape}")
    print(f"Result2 shape: {result2.shape}")
    
    print("\n" + "="*50)
    print("METHOD 3: Stack + reshape")
    print("="*50)
    
    stacked = torch.stack([x_even_rot, x_odd_rot], dim=-1)
    result3 = stacked.reshape(*stacked.shape[:-2], d_k)
    print(f"Result3 shape: {result3.shape}")
    
    print("\n" + "="*50)
    print("METHOD 4: Concatenate + rearrange (using einops)")
    print("="*50)
    
    from einops import rearrange
    
    # Stack along a new dimension, then rearrange
    stacked = torch.stack([x_even_rot, x_odd_rot], dim=-1)  # (..., d_k/2, 2)
    result4 = rearrange(stacked, '... pairs two -> ... (pairs two)')
    print(f"Result4 shape: {result4.shape}")
    
    print("\n" + "="*50)
    print("METHOD 5: Interleave using repeat_interleave")
    print("="*50)
    
    # This doesn't work directly but shows the concept
    # We'd need to manually interleave, which is essentially method 1
    
    print("\n" + "="*50)
    print("VERIFICATION: All methods give same result?")
    print("="*50)
    
    print(f"Method 1 == Method 2: {torch.allclose(result1, result2)}")
    print(f"Method 1 == Method 3: {torch.allclose(result1, result3)}")
    print(f"Method 1 == Method 4: {torch.allclose(result1, result4)}")
    
    print("\n" + "="*50)
    print("ELEMENT ORDER VERIFICATION")
    print("="*50)
    
    # Check that the elements are in the right order
    print("First few elements of result1:")
    print("Should be: even[0], odd[0], even[1], odd[1], even[2], odd[2]")
    print(f"result1[0, 0, :]: {result1[0, 0, :]}")
    print(f"x_even_rot[0, 0, :]: {x_even_rot[0, 0, :]}")
    print(f"x_odd_rot[0, 0, :]: {x_odd_rot[0, 0, :]}")

demo_merge_techniques()

x_even_rot shape: torch.Size([2, 4, 3])
x_odd_rot shape: torch.Size([2, 4, 3])
Target shape: (2, 4, 6)

METHOD 1: Direct indexing (most efficient)
Result1 shape: torch.Size([2, 4, 6])

METHOD 2: Stack + flatten (your approach fixed)
Stacked shape: torch.Size([2, 4, 3, 2])
Result2 shape: torch.Size([2, 4, 6])

METHOD 3: Stack + reshape
Result3 shape: torch.Size([2, 4, 6])

METHOD 4: Concatenate + rearrange (using einops)
Result4 shape: torch.Size([2, 4, 6])

METHOD 5: Interleave using repeat_interleave

VERIFICATION: All methods give same result?
Method 1 == Method 2: True
Method 1 == Method 3: True
Method 1 == Method 4: True

ELEMENT ORDER VERIFICATION
First few elements of result1:
Should be: even[0], odd[0], even[1], odd[1], even[2], odd[2]
result1[0, 0, :]: tensor([ 1.8610, -1.4852, -1.3570,  0.9508,  1.7945,  1.1790])
x_even_rot[0, 0, :]: tensor([ 1.8610, -1.3570,  1.7945])
x_odd_rot[0, 0, :]: tensor([-1.4852,  0.9508,  1.1790])


In [21]:
range(30)

range(0, 30)

In [28]:
from einops import rearrange, einsum, reduce

In [48]:
query = torch.tensor([[1,2,3],[6, 8, 10]])
key = torch.tensor([[-1,-2,-3],[6, 8, 10], [-3, 2, -1], [8, 4, 0]])
query.size(), key.size()

(torch.Size([2, 3]), torch.Size([4, 3]))

In [53]:
scores = einsum(query, key, "... q_len d_k, ... k_len d_k -> ... q_len k_len") / torch.sqrt(torch.tensor(d_k, dtype=query.dtype))
scores

tensor([[-1.9799,  7.3539, -0.2828,  2.2627],
        [-7.3539, 28.2843, -1.6971, 11.3137]])

In [3]:
torch.manual_seed(42)
logits_2d = torch.randn(5, 10)  # (batch_size, vocab_size)
targets_2d = torch.randint(0, 10, (5,))

In [4]:
logits_2d

tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
         -0.7521,  1.6487],
        [-0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624,  1.6423, -0.1596,
         -0.4974,  0.4396],
        [-0.7581,  1.0783,  0.8008,  1.6806,  1.2791,  1.2964,  0.6105,  1.3347,
         -0.2316,  0.0418],
        [-0.2516,  0.8599, -1.3847, -0.8712,  0.0780,  0.5258, -0.4880,  1.1914,
         -0.8140, -0.7360],
        [-0.8371, -0.9224, -0.0635,  0.6756, -0.0978,  1.8446, -1.1845,  1.3835,
         -1.2024,  0.7078]])

In [5]:
targets_2d

tensor([3, 0, 1, 1, 7])

In [7]:
max_logits = torch.max(logits_2d, dim=1, keepdim=True)[0]
max_logits

tensor([[1.9269],
        [1.6423],
        [1.6806],
        [1.1914],
        [1.8446]])

In [10]:
logits_stable = logits_2d - max_logits

In [12]:
log_sum_exp = torch.log(torch.sum(torch.exp(logits_stable), dim=1))
log_sum_exp

tensor([1.2072, 0.9042, 1.5671, 1.2431, 1.0228])

In [14]:
target_logits = logits_stable[torch.arange(logits_stable.shape[0]), targets_2d]
target_logits

tensor([-4.0324, -2.0348, -0.6023, -0.3315, -0.4610])

In [17]:
[torch.arange(logits_stable.shape[0]), targets_2d]

[tensor([0, 1, 2, 3, 4]), tensor([3, 0, 1, 1, 7])]

In [22]:
logits_stable.unsqueeze(-1).shape, logits_stable.shape

(torch.Size([5, 10, 1]), torch.Size([5, 10]))

In [24]:
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math
class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"] # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p] # Get state associated with p.
                t = state.get("t", 0) # Get iteration number from the state, or initial value.
                grad = p.grad.data # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
                state["t"] = t + 1 # Increment iteration number.
        return loss

In [None]:
weights = torch.nn.Parameter(5 * torch.randn((10, 10)))
opt = SGD([weights], lr=1)
for t in range(100):
    opt.zero_grad() # Reset the gradients for all learnable parameters.
    loss = (weights**2).mean() # Compute a scalar loss value.
    if t//10 == 0:
        print(loss.cpu().item())
    loss.backward() # Run backward pass, which computes gradients.
    opt.step()

21.302330017089844
20.458757400512695
19.884187698364258
19.42763328552246
19.04102325439453
18.701929092407227
18.397775650024414
18.120677947998047
17.865320205688477
17.627910614013672
17.405637741088867
17.196352005004883
16.99835777282715
16.810300827026367
16.631071090698242
16.459749221801758
16.295564651489258
16.137855529785156
15.986066818237305
15.839704513549805
15.698347091674805
15.56161880493164
15.429190635681152
15.300771713256836
15.176095962524414
15.05492877960205
14.937060356140137
14.822297096252441
14.710461616516113
14.601398468017578
14.49496078491211
14.391012191772461
14.289432525634766
14.190107345581055
14.092930793762207
13.997806549072266
13.904642105102539
13.81335735321045
13.723869323730469
13.63610553741455
13.550000190734863
13.465487480163574
13.3825044631958
13.300997734069824
13.220909118652344
13.142192840576172
13.064800262451172
12.98868179321289
12.913800239562988
12.840112686157227
12.76758098602295
12.69616985321045
12.625840187072754
12.556