# Diagonal extraction

In [None]:
# Shared constants for the benchmark cells below.
# NOTE: torch.randn(n, device=DEVICE) samples directly on DEVICE using that
# device's generator (on CUDA this is a GPU kernel, not CPU sampling); it is
# the modern equivalent of the legacy torch.cuda.FloatTensor(n).normal_().
import torch

# Fall back to CPU so the notebook still runs without a GPU. Benchmark timings
# are only representative of GPU performance when this resolves to "cuda".
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
N = 1000  # side length of the square matrices being benchmarked
# Precomputed operands for the "static" torch.where variant: a scalar zero and
# an N x N boolean identity mask, both matching DTYPE/DEVICE.
ZERO = torch.tensor(0.0, dtype=DTYPE, device=DEVICE)
EYE = torch.eye(N, dtype=torch.bool, device=DEVICE)

In [None]:
%%timeit
x = torch.cuda.FloatTensor(N, N).normal_()
torch.diag(torch.diag(x))

In [None]:
%%timeit
# creating I, 0 dynamically
x = torch.cuda.FloatTensor(N, N).normal_()
eye = torch.eye(x.shape[-1], dtype=torch.bool, device=x.device)
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
torch.where(eye, x, zero)

In [None]:
%%timeit
# using I, 0 statically
x = torch.cuda.FloatTensor(N, N).normal_()
torch.where(EYE, x, ZERO)

# Transpose

In [None]:
%%timeit
x = torch.cuda.FloatTensor(N, N).normal_()
x.swapaxes(-1, -2)

In [None]:
%%timeit
x = torch.cuda.FloatTensor(N, N).normal_()
x.T

In [None]:
from typing import Optional

from torch import Tensor, jit


@jit.script
def norm(x: Tensor, p: Optional[float] = None) -> Tensor:
    """Return the matrix norm of ``x`` via ``torch.linalg.matrix_norm``.

    Args:
        x: Input tensor; the norm is taken with matrix_norm's default
            matrix dimensions.
        p: Numeric ``ord`` forwarded to matrix_norm. ``None`` (the default)
            selects the Frobenius norm.
    """
    if p is not None:
        # Numeric order, e.g. 1, -1, 2, -2, inf, -inf.
        return torch.linalg.matrix_norm(x, ord=p)
    return torch.linalg.matrix_norm(x, ord="fro")

In [None]:
x = torch.randn(7, 5, 5)

# Induced 1-norm of each 5x5 matrix in the batch: the maximum absolute column
# sum. The dim order matters for ord=1. dim=(-1, -2) would sum along rows
# instead and yield the infinity norm. (-2, -1) matches the default matrix
# dims of torch.linalg.matrix_norm.
one_norm = torch.linalg.norm(x, ord=1, dim=(-2, -1))
one_norm

In [None]:
# Frobenius norm of every matrix in the batch, with the matrix dims spelled out.
torch.linalg.matrix_norm(x, ord="fro", dim=(-2, -1))

In [None]:
# Vectorized trace: sum each matrix's main diagonal in one call.
x.diagonal(dim1=-1, dim2=-2).sum(dim=-1)

In [None]:
# Trace of each matrix computed one at a time; yields a list of 0-d tensors.
list(map(torch.trace, x))