import { Callout } from "nextra-theme-docs" import { Tab, Tabs } from "nextra-theme-docs"
Base class for all multi-head self-attention modules.
Attention(dim=768, num_heads=8, head_dim=64, plugins=[])
`dim`
: The dimension size.

`num_heads`
: The number of attention heads.

`head_dim`
: The dimension size for each attention head.

`plugins`
: A list of `AttentionPlugin`s to use.
(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']