import { Callout } from "nextra-theme-docs"
import { Tab, Tabs } from "nextra-theme-docs"

Attention

Base class for all multi-head self-attention modules.

Attention(dim=768, num_heads=8, head_dim=64, plugins=[])

Parameters

  • dim: The feature dimension of the input and output (d in the forward signature).
  • num_heads: The number of attention heads.
  • head_dim: The dimension of each attention head.
  • plugins: A list of AttentionPlugins to apply (empty by default).
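The parameters fix the shapes of the projection weights. The sketch below assumes a conventional layout in which queries, keys, and values are each projected from dim to num_heads × head_dim, and the concatenated heads are projected back to dim. This is an assumption, not necessarily how this class stores its weights, and `projection_shapes` is a hypothetical helper. With the defaults the inner width is 8 × 64 = 512, which differs from dim = 768, so an output projection is needed to return to dim.

```python
def projection_shapes(dim: int = 768, num_heads: int = 8, head_dim: int = 64) -> dict:
    """Hypothetical helper: weight shapes under an assumed conventional attention layout."""
    inner = num_heads * head_dim  # total width of all heads concatenated
    return {
        "q": (dim, inner),    # query projection
        "k": (dim, inner),    # key projection
        "v": (dim, inner),    # value projection
        "out": (inner, dim),  # maps concatenated heads back to dim
    }


shapes = projection_shapes()
print(shapes["q"], shapes["out"])  # (768, 512) (512, 768)
```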

Forward

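(x: jaxtyping.Float[Tensor, '... n d']) -> jaxtyping.Float[Tensor, '... n d']

Takes an input x of shape (..., n, d), where n is the sequence length and d is the feature dimension, and returns a tensor of the same shape.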
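The following is a minimal NumPy sketch of standard scaled dot-product multi-head self-attention that preserves the (..., n, d) shape of its input. It assumes the same conventional projection layout as above, uses random weights, and applies no plugins. The function `multi_head_self_attention` and its weight arguments are hypothetical; the documented class operates on Tensor objects and may differ in its details.

```python
import numpy as np


def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads, head_dim):
    """Hypothetical sketch: x has shape (..., n, d); returns shape (..., n, d)."""
    *batch, n, _ = x.shape

    def split_heads(t):
        # (..., n, num_heads * head_dim) -> (..., num_heads, n, head_dim)
        t = t.reshape(*batch, n, num_heads, head_dim)
        return np.swapaxes(t, -3, -2)

    q = split_heads(x @ w_q)
    k = split_heads(x @ w_k)
    v = split_heads(x @ w_v)

    # Scaled dot-product scores with shape (..., num_heads, n, n)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(head_dim)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys

    out = weights @ v  # (..., num_heads, n, head_dim)
    out = np.swapaxes(out, -3, -2).reshape(*batch, n, num_heads * head_dim)
    return out @ w_o  # project concatenated heads back to (..., n, d)


# Usage with the default sizes from the constructor signature
rng = np.random.default_rng(0)
dim, num_heads, head_dim = 768, 8, 64
inner = num_heads * head_dim
w_q, w_k, w_v = (rng.standard_normal((dim, inner)) * dim**-0.5 for _ in range(3))
w_o = rng.standard_normal((inner, dim)) * inner**-0.5

x = rng.standard_normal((2, 10, dim))  # batch of 2 sequences, n = 10
y = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads, head_dim)
print(y.shape)  # (2, 10, 768)
```

Because every leading "..." dimension is carried through unchanged, the same function handles unbatched (n, d) inputs and batched (b, n, d) inputs alike.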