# torch.scatter_reduce_ and torch.gather



| Operation        | Direction | Meaning                                               |
| ---------------- | --------- | ----------------------------------------------------- |
| `gather`         | read      | “pull” data from positions in input                   |
| `scatter_reduce` | write     | “push” data into positions of output (with reduction) |
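
The push/pull relationship in the table can be illustrated with a small round trip. This is a minimal sketch, assuming the index is a permutation (so no two writes collide); the names `perm`, `pushed` and `pulled` are made up for illustration.

```python
import torch

src = torch.tensor([10., 20., 30., 40.])
perm = torch.tensor([2, 0, 3, 1])  # a permutation, so every target position is written exactly once

# "push": pushed[perm[i]] receives src[i]
pushed = torch.scatter_reduce(torch.zeros(4), 0, perm, src, reduce="sum")

# "pull": pulled[i] reads pushed[perm[i]], which undoes the push
pulled = torch.gather(pushed, 0, perm)

print(pushed)  # tensor([20., 40., 10., 30.])
print(pulled)  # tensor([10., 20., 30., 40.])
```

When the index contains duplicates, the push is no longer invertible, which is where the `reduce` argument comes in.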


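`torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True)`
- goal: write (“scatter”) values from `src` into a copy of `input` at the given indices, combining values with `reduce` when several writes land on the same position.
- `input`: base tensor to scatter into
- `dim`: dimension to scatter along
- `index`: where to write each element of `src`
- `src`: the values to write
- `reduce`: the reduction to apply, one of `"sum"`, `"prod"`, `"mean"`, `"amax"`, `"amin"`
- `include_self`: whether the existing values of `input` take part in the reduction
> - `input`, `index` and `src` must all have the same number of dimensions.
> - `torch.scatter_reduce` returns a new tensor; `Tensor.scatter_reduce_` (trailing underscore) modifies the tensor in place.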







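In [1]:
import torch

# index[i][j] is the column of row i that src[i][j] is added to (dim=1)
index = torch.tensor([[0, 1, 1, 2],
                      [1, 2, 2, 3]])
src = torch.tensor([[10, 20, 30, 40],
                    [50, 60, 70, 80]])
x = torch.zeros(2, 4, dtype=src.dtype)
# e.g. out[0][1] = 20 + 30 because index[0][1] == index[0][2] == 1
out = torch.scatter_reduce(x, dim=1, index=index, src=src, reduce='sum')
print(out)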


tensor([[ 10,  50,  40,   0],
        [  0,  50, 130,  80]])
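
The effect of `include_self` is easiest to see with a non-zero base tensor. The following is a small sketch with made-up values (`base`, `index`, `src` are illustrative names), comparing both settings for `"sum"` and `"amax"`.

```python
import torch

base = torch.tensor([100., 100., 100.])
index = torch.tensor([0, 0, 1])   # two writes collide at position 0; position 2 is never written
src = torch.tensor([1., 2., 3.])

# include_self=True: the existing value in base takes part in the reduction
sum_with_self = torch.scatter_reduce(base, 0, index, src, reduce="sum", include_self=True)
# include_self=False: only the scattered values are reduced; untouched positions keep their value
sum_without_self = torch.scatter_reduce(base, 0, index, src, reduce="sum", include_self=False)
max_with_self = torch.scatter_reduce(base, 0, index, src, reduce="amax", include_self=True)
max_without_self = torch.scatter_reduce(base, 0, index, src, reduce="amax", include_self=False)

print(sum_with_self)     # tensor([103., 103., 100.])
print(sum_without_self)  # tensor([  3.,   3., 100.])
print(max_with_self)     # tensor([100., 100., 100.])
print(max_without_self)  # tensor([  2.,   3., 100.])
```

With `include_self=True`, the initial value of the base tensor matters: zeros are a neutral start for `"sum"`, while `"amax"` needs `-inf` for the base to have no influence.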


`torch.gather(input, dim, index, *, sparse_grad=False, out=None)`
- goal: Collect (“gather”) elements from a source tensor according to given indices.
- `input` is the source tensor from which to gather values.
- `dim` is the dimension along which to index.
- `index` is a tensor containing the indices of the values to gather.
> - `input` and `index` must have the same number of dimensions. 
> - `out` will have the same shape as `index`

In [2]:
import torch
x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])  # shape (2,3)

idx = torch.tensor([[2, 1, 0],
                    [0, 2, 1]])   # shape (2,3)

# for dim=1: y[i][j] = x[i][idx[i][j]]
y = torch.gather(x, dim=1, index=idx)
print(y)


tensor([[30, 20, 10],
        [40, 60, 50]])
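
To make the indexing rule concrete, here is a sketch that reproduces `gather` along `dim=1` with explicit loops and then shows the `dim=0` case. The loop version is only a reference for understanding, not how one would write it in practice; `manual` and `idx0` are illustrative names.

```python
import torch

x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])
idx = torch.tensor([[2, 1, 0],
                    [0, 2, 1]])

# dim=1: out[i][j] = x[i][idx[i][j]]
manual = torch.empty_like(idx)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        manual[i, j] = x[i, idx[i, j]]
gathered = torch.gather(x, dim=1, index=idx)
assert torch.equal(manual, gathered)

# dim=0: out[i][j] = x[idx0[i][j]][j]; the index may be smaller than the input
idx0 = torch.tensor([[1, 0, 1]])
picked = torch.gather(x, dim=0, index=idx0)
print(picked)  # tensor([[40, 20, 60]])
```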


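## Example: merge the KV cache

Goal: Compute the cosine similarity between two groups of cached keys, A and B, assign each key in A to its most similar key in B, and merge the keys of A within each group into a softmax-weighted mean, using the similarity scores as logits.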
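
One way to write down the computation is sketched here. The notation is introduced only for illustration: $a_i$ is the $i$-th key of A, $b_j$ is the $j$-th key of B, and everything is computed independently per batch element and head.

$$
\begin{aligned}
s_{ij} &= \frac{a_i \cdot b_j}{\lVert a_i \rVert \, \lVert b_j \rVert} \\
g(i) &= \arg\max_j s_{ij}, \qquad s_i = \max_j s_{ij} \\
m_g &= \max_{i:\, g(i)=g} s_i \\
w_i &= \exp\!\left(s_i - m_{g(i)}\right) \\
\bar{a}_g &= \frac{\sum_{i:\, g(i)=g} w_i \, a_i}{\sum_{i:\, g(i)=g} w_i}
\end{aligned}
$$

Subtracting $m_g$ does not change $\bar{a}_g$, since the factor $\exp(-m_g)$ cancels between numerator and denominator; it only keeps the exponentials numerically stable. The normalized weights $w_i / \sum w_i$ are exactly the softmax of the scores within each group.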

In [3]:
import torch.nn.functional as F
import torch

In [4]:

bsz, num_head, head_dim = 2, 4, 64
A_len = 20
B_len = 5

A = torch.rand((bsz, num_head, A_len, head_dim))  # keys to be merged
B = torch.rand((bsz, num_head, B_len, head_dim))  # keys that define the groups

# cosine similarity between every key in A and every key in B
A_norm = A.norm(dim=-1, keepdim=True)
B_norm = B.norm(dim=-1, keepdim=True)
similarity = (A @ B.transpose(-2, -1)) / (A_norm @ B_norm.transpose(-2, -1))  # [bsz, num_head, A_len, B_len]

# for each key in A: the best similarity score (_val) and the index of the matching key in B (_idx)
group_idx_per_A_val, group_idx_per_A_idx = similarity.max(dim=-1)  # both [bsz, num_head, A_len]

# per-group softmax over group_idx_per_A_val (groups given by group_idx_per_A_idx)
# numerically stable softmax: exp(x_i - max) / sum_j exp(x_j - max)
## per-group maximum; groups that receive no key stay at -inf
group_max = torch.full((bsz, num_head, B_len), float('-inf'), dtype=B.dtype, device=B.device)
group_max.scatter_reduce_(dim=2, index=group_idx_per_A_idx, src=group_idx_per_A_val, reduce='amax', include_self=True)

## unnormalized softmax weights: exp(score - maximum of its group)
weight_per_A = group_max.gather(index=group_idx_per_A_idx, dim=2)  # [bsz, num_head, A_len]
weight_per_A = torch.exp(group_idx_per_A_val - weight_per_A)

## softmax denominator: sum of the weights per group (created on the same device as the weights)
weight_sum_per_group = torch.zeros((bsz, num_head, B_len), dtype=weight_per_A.dtype, device=weight_per_A.device)
weight_sum_per_group.scatter_reduce_(index=group_idx_per_A_idx, src=weight_per_A, reduce='sum', dim=2)

## weight each key in A (broadcast over head_dim), then sum the weighted keys per group
weighted_A_per_group = weight_per_A.unsqueeze(-1) * A  # [bsz, num_head, A_len, head_dim]

weighted_sum_A_per_group = torch.zeros((bsz, num_head, B_len, head_dim), dtype=A.dtype, device=A.device)
weighted_sum_A_per_group.scatter_reduce_(index=group_idx_per_A_idx.unsqueeze(-1).expand(-1, -1, -1, head_dim), src=weighted_A_per_group, dim=2, reduce='sum')


## divide by the softmax denominator; clamp_min avoids 0/0 for groups that received no key
weighted_mean_A_per_group = weighted_sum_A_per_group / weight_sum_per_group.unsqueeze(-1).clamp_min(1e-12)  # [bsz, num_head, B_len, head_dim]
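
As a sanity check, the vectorized merge can be compared against a slow loop-based reference that calls `torch.softmax` per group. This is a sketch under assumptions: the helper names `merge_vectorized` and `merge_with_loops` are hypothetical, the shapes are reduced to keep the loops fast, and everything runs on CPU.

```python
import torch

torch.manual_seed(0)
bsz, num_head, head_dim = 2, 3, 8
A_len, B_len = 12, 4

A = torch.rand(bsz, num_head, A_len, head_dim)
B = torch.rand(bsz, num_head, B_len, head_dim)

# Same similarity and grouping step as in the cell above
sim = (A @ B.transpose(-2, -1)) / (
    A.norm(dim=-1, keepdim=True) @ B.norm(dim=-1, keepdim=True).transpose(-2, -1)
)
score, group = sim.max(dim=-1)  # both [bsz, num_head, A_len]


def merge_vectorized(A, score, group, n_groups):
    """Softmax-weighted mean of A per group, using scatter_reduce_ and gather."""
    group_max = torch.full((*score.shape[:-1], n_groups), float("-inf"), dtype=A.dtype)
    group_max.scatter_reduce_(2, group, score, reduce="amax", include_self=True)
    weight = torch.exp(score - group_max.gather(2, group))
    weight_sum = torch.zeros_like(group_max).scatter_reduce_(2, group, weight, reduce="sum")
    index = group.unsqueeze(-1).expand(-1, -1, -1, A.shape[-1])
    weighted_sum = torch.zeros(*score.shape[:-1], n_groups, A.shape[-1], dtype=A.dtype)
    weighted_sum.scatter_reduce_(2, index, weight.unsqueeze(-1) * A, reduce="sum")
    return weighted_sum / weight_sum.unsqueeze(-1).clamp_min(1e-12)


def merge_with_loops(A, score, group, n_groups):
    """Hypothetical reference: the same merge with explicit loops and torch.softmax."""
    out = torch.zeros(*score.shape[:-1], n_groups, A.shape[-1], dtype=A.dtype)
    for b in range(A.shape[0]):
        for h in range(A.shape[1]):
            for g in range(n_groups):
                members = group[b, h] == g  # keys of A assigned to group g
                if members.any():           # empty groups stay at zero
                    w = torch.softmax(score[b, h][members], dim=0)
                    out[b, h, g] = (w.unsqueeze(-1) * A[b, h][members]).sum(dim=0)
    return out


merged_vec = merge_vectorized(A, score, group, B_len)
merged_loop = merge_with_loops(A, score, group, B_len)
assert torch.allclose(merged_vec, merged_loop, atol=1e-5)
```

Agreement between the two versions suggests that the `amax` / `gather` / `sum` sequence implements a per-group softmax-weighted mean, with empty groups mapped to zero vectors in both.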

