In [2]:
import torch
import numpy as np
from torch.nn import functional as F

In [39]:
# Draw a probability p uniformly from [low, high), then sample a 3x4
# Bernoulli tensor whose entries are 1 with probability p and 0 otherwise.
msm_bernoulli_prob = [0.2, 0.8]

low, high = msm_bernoulli_prob
p = np.random.uniform(low, high)
print(p)

probs = torch.ones(3, 4) * p
mask1 = probs.bernoulli()
print(mask1)


0.33066794459097115
tensor([[1., 1., 1., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.]])


In [6]:
# Softmax over dim=1 normalises each row of x into a probability
# distribution, so every row of y sums to 1.
x = torch.rand(2, 3)
print(x)

y = x.softmax(dim=1)
print(y)

z = y.sum(dim=1)
print(z)

tensor([[0.6817, 0.9028, 0.9015],
        [0.7266, 0.8289, 0.2455]])
tensor([[0.2863, 0.3571, 0.3566],
        [0.3669, 0.4064, 0.2268]])
tensor([1., 1.])


## torch.min() with `dim` (torch.max() returns values and indices the same way)

In [11]:
import torch

# Reduce along the last dimension: for each row keep its smallest value
# and the column index where it occurs.
data = torch.randn(3, 2)
print(data)

row_min = data.min(dim=-1)
min_dist, min_idx = row_min.values, row_min.indices
print(min_dist, min_dist.shape)
print(min_idx)

# Re-insert the reduced dimension: (3,) -> (3, 1), which broadcasts against data.
min_dist = min_dist[:, None]
print(min_dist, min_dist.shape)

tensor([[-0.8275, -1.9518],
        [-1.5592, -0.0333],
        [ 1.1647,  0.3005]])
tensor([-1.9518, -1.5592,  0.3005]) torch.Size([3])
tensor([1, 0, 1])
tensor([[-1.9518],
        [-1.5592],
        [ 0.3005]]) torch.Size([3, 1])


## torch.mean()

In [8]:
import torch

data = torch.arange(10, dtype=torch.float32).reshape(5, 2)
print(data)

# Mean of each row (reduce dim=1) -> shape (5,)
data_mean = data.mean(dim=1)
print(data_mean)

# Mean of the row means; it matches the global mean here because every
# row has the same number of elements.
data_mean2 = data_mean.mean()
print(data_mean2)

# Global mean over all elements.
data_meann = data.mean()
print(data_meann)

tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.],
        [8., 9.]])
tensor([0.5000, 2.5000, 4.5000, 6.5000, 8.5000])
tensor(4.5000)
tensor(4.5000)


## torch.argmin

In [4]:
import torch

distances = torch.randn(5, 2)
print(distances)

# Column index of the smallest entry in each row -> shape (5,), int64.
# PyTorch's documented keyword is `dim`; `axis` is only a NumPy-compat alias.
dis_idx = torch.argmin(distances, dim=1)
print(dis_idx, dis_idx.shape)

tensor([[ 2.7553,  1.3871],
        [ 1.0058, -1.3197],
        [ 0.9727,  0.0923],
        [ 2.9525, -1.1732],
        [-0.7277,  1.3667]])
tensor([1, 1, 1, 1, 0]) torch.Size([5])


## torch.linspace and Tensor.split

In [5]:
import torch

# 10 evenly spaced values from 1 to 10 inclusive -> 1., 2., ..., 10.
X = torch.linspace(start=1, end=10, steps=10)
print(X, X.shape)

# Splitting by an int yields chunks of that size along dim 0; the last chunk
# holds the remainder (10 = 4 + 4 + 2). The result is a tuple of tensors.
X_splited = torch.split(X, 4)
print(len(X_splited), type(X_splited))
print(X_splited)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]) torch.Size([10])
3 <class 'tuple'>
(tensor([1., 2., 3., 4.]), tensor([5., 6., 7., 8.]), tensor([ 9., 10.]))


## torch.cdist
https://pytorch.org/docs/stable/generated/torch.cdist.html

In [4]:
import torch

x1 = torch.arange(4).reshape(2, 2).to(torch.float)
x2 = torch.arange(6).reshape(3, 2).to(torch.float)
print(x1, x1.shape)
print(x2, x2.dtype, x2.shape)

# cdist(a, b)[i, j] is the (default p=2) distance between row i of a and
# row j of b, so swapping the arguments transposes the result.
dist_12 = torch.cdist(x1, x2)  # shape (2, 3)
dist = torch.cdist(x2, x1)     # shape (3, 2)
assert torch.allclose(dist, dist_12.T)
print(dist, dist.shape)

tensor([[0., 1.],
        [2., 3.]]) torch.Size([2, 2])
tensor([[0., 1.],
        [2., 3.],
        [4., 5.]]) torch.float32 torch.Size([3, 2])
tensor([[0.0000, 2.8284],
        [2.8284, 0.0000],
        [5.6569, 2.8284]]) torch.Size([3, 2])


## torch.topk

In [10]:
import torch

pred = torch.randn((4, 5))
print(pred)
# The 3 smallest entries of each row (largest=False); with sorted=False
# their order within a row is not guaranteed.
values, indices = torch.topk(pred, k=3, dim=-1, largest=False, sorted=False)
print(values)
print(indices)

# Advanced indexing with a (4, 3) index tensor selects rows of the
# (100, 3) tensor, giving shape (4, 3, 3).
pc = torch.randn(100, 3)[indices]
print(pc.shape)

tensor([[ 0.3988, -0.4285,  0.9232, -0.0042, -1.7357],
        [-0.5347,  1.1426, -1.2428,  0.7063,  1.0780],
        [ 0.9742, -0.0708,  0.3661,  0.8740,  0.9038],
        [ 0.5233,  0.0808,  0.3399, -0.7865,  0.0210]])
tensor([[-0.4285, -1.7357, -0.0042],
        [-0.5347, -1.2428,  0.7063],
        [ 0.3661, -0.0708,  0.8740],
        [-0.7865,  0.0210,  0.0808]])
tensor([[1, 4, 3],
        [0, 2, 3],
        [2, 1, 3],
        [3, 4, 1]])


In [2]:
import torch

# A 5x3 tensor of independent samples from the standard normal N(0, 1).
Q_SHAPE = (5, 3)
q_loc = torch.normal(mean=0.0, std=1, size=Q_SHAPE)
print(q_loc)

tensor([[ 0.4793, -0.6942, -0.0746],
        [-0.5317,  1.2179,  1.0645],
        [-0.7935,  2.3039, -0.7851],
        [ 0.3673,  1.7328, -0.6026],
        [-0.1441,  0.9631, -0.7099]])


## torch.repeat_interleave
https://blog.csdn.net/weixin_45261707/article/details/119187799

In [7]:
import torch

B = 6  # total number of elements; must be divisible by 3 for the reshapes below
N = 1  # number of times each element is repeated

batch = torch.arange(B).reshape(3, -1)
print(batch, batch.shape)

# Without `dim`, repeat_interleave flattens its input first and then repeats
# each element N times in place: [0, 0, 1, 1, ...] -> 1-D, length B * N.
batch_flat = torch.repeat_interleave(batch, N)
print(batch_flat, batch_flat.shape)

# Restore 3 rows; -1 lets the column count grow with N (a hard-coded
# (3, 2) would only work for N == 1).
batch = batch_flat.reshape(3, -1)
print(batch)

tensor([[0, 1],
        [2, 3],
        [4, 5]]) torch.Size([3, 2])
tensor([0, 1, 2, 3, 4, 5]) torch.Size([6])
tensor([[0, 1],
        [2, 3],
        [4, 5]])


## torch.randperm

In [3]:
import torch

# A random permutation of the integers 0 .. n_items-1.
n_items = 10
index = torch.randperm(n_items)
print(index, index.shape, index.dtype)

tensor([1, 4, 3, 0, 6, 7, 9, 8, 2, 5]) torch.Size([10]) torch.int64


## torch.cumsum
https://pytorch.org/docs/stable/generated/torch.cumsum.html

In [7]:
import torch

a = torch.arange(10).view(2, 5)
print(a, a.shape)
# Running sum along each row (dim=1); the output has the same shape as a.
out = a.cumsum(dim=1)
print(out, out.shape)

tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]) torch.Size([2, 5])
tensor([[ 0,  1,  3,  6, 10],
        [ 5, 11, 18, 26, 35]]) torch.Size([2, 5])


## Tensor.transpose

In [5]:
import torch

inp = torch.randn(5, 1024, 3)

# Swap dims 1 and 2: (5, 1024, 3) -> (5, 3, 1024). transpose is symmetric in
# its two dim arguments, so transpose(1, 2) and transpose(2, 1) are identical.
# Avoid naming this `input`: that would shadow the Python builtin for the
# rest of the notebook session.
inp_t = inp.transpose(1, 2)
print(inp_t.shape)

torch.Size([5, 3, 1024])
