In [1]:
from typing import Optional, Tuple
import importlib
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
is_xformers_available = importlib.util.find_spec("xformers") is not None
if is_xformers_available:
    import xformers
    import xformers.ops
else:
    xformers = None

from torch.utils.checkpoint import checkpoint_sequential

In [17]:
class MemoryEfficientAttention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
    ):
        super().__init__()
        assert (dim % num_heads == 0), 'dim should be divisible by num_heads'
        assert num_heads > 0

        self.dim = dim
        self.num_heads = num_heads
        self.dim_heads = dim // num_heads
        self.attn_drop = attn_drop

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=False)

        self._use_memory_efficient_attention_xformers = False
        self.set_use_memory_efficient_attention_xformers(True)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available:
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers",
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, x, att_mask=None):
        # B : batch size
        # S : sequence length
        # D : embedding size
        # H : number of heads
        # K : embeddings size per head

        B, N, C = x.size()
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.dim_heads)
            .permute(2, 0, 3, 1, 4)
            .flatten(1, 2)
        )

        q, k, v = qkv.unbind()

        if self._use_memory_efficient_attention_xformers:
            x = xformers.ops.memory_efficient_attention(
                q, k, v, p=self.attn_drop, attn_bias=att_mask
            )
        else:
            # todo implement using torch.baddbmm
            q = q / math.sqrt(k.size(-1))
            attn = q @ k.transpose(-2, -1)
            if att_mask is not None:
                attn = attn + att_mask
            attn = attn.softmax(-1)
            attn = F.dropout(attn, self.attn_drop)
            x =  attn @ v

        x = (
            x
            .view(B, self.num_heads, N, self.dim_heads)
            .transpose(1, 2)
            .reshape(B, N, C)
        )

        x = self.proj_drop(self.proj(x))
        return x

In [17]:
class MemoryEfficientAttentionNew(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
    ):
        super().__init__()
        assert (dim % num_heads == 0), 'dim should be divisible by num_heads'
        assert num_heads > 0

        self.dim = dim
        self.num_heads = num_heads
        self.dim_heads = dim // num_heads
        self.attn_drop = attn_drop

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=False)

        self._use_memory_efficient_attention_xformers = False

    def forward(self, x, att_mask=None):
        # B : batch size
        # S : sequence length
        # D : embedding size
        # H : number of heads
        # K : embeddings size per head

        B, S, D = x.size()                                     # B x S x D
        qkv = (
            self.qkv(x)                                        # B x S x D*3
            .reshape(B, S, 3, self.num_heads, self.dim_heads)  # B x S x 3 x H x (D//H)
            .permute(2, 0, 3, 1, 4)                            # 3 x B x H x S x (D//H)
            .flatten(1, 2)                                     # 3 x (B*H) x S x (D//H)
        )

        q, k, v = qkv.unbind() # (B*H) x S x (D//H)

        if self._use_memory_efficient_attention_xformers:
            x = xformers.ops.memory_efficient_attention(
                q, k, v, p=self.attn_drop, attn_bias=att_mask
            )
        else:
            # todo implement using torch.baddbmm
            q = q / math.sqrt(k.size(-1))
            attn = q @ k.transpose(-2, -1)
            if att_mask is not None:
                attn = attn + att_mask
            attn = attn.softmax(-1)
            attn = F.dropout(attn, self.attn_drop)
            x =  attn @ v


        x = (
            x                                            # (B*H) x S x (D//H)
            .view(B, self.num_heads, S, self.dim_heads)  # B x H x S x (D//H)
            .transpose(1, 2)                             # B x S x H x (D//H)
            .reshape(B, S, D)                            # B x S x D
        )

        x = self.proj_drop(self.proj(x))
        return x

In [18]:
attn = MemoryEfficientAttentionNew(dim=128, num_heads=8).to('cuda').train()

x = torch.randn((2, 512, 128), device="cuda")

out_eff = attn(x)


# Test

In [52]:
class MemoryEfficientAttentionOne(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.0,
            proj_drop=0.0,
    ):
        super().__init__()
        assert (dim % num_heads == 0), 'dim should be divisible by num_heads'
        assert num_heads > 0

        self.dim = dim
        self.num_heads = num_heads
        self.dim_heads = dim // num_heads
        self.attn_drop = attn_drop

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=False)

    def forward(self, x, att_mask=None):
        B, N, C = x.size()
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.dim_heads)
            .permute(2, 0, 3, 1, 4)
            .flatten(1, 2)
        )

        q, k, v = qkv.unbind()

        x = xformers.ops.memory_efficient_attention(
            q, k, v, p=self.attn_drop, attn_bias=att_mask
        )

        x = (
            x
            .view(B, self.num_heads, N, self.dim_heads)
            .transpose(1, 2)
            .reshape(B, N, C)
        )

        x = self.proj_drop(self.proj(x))
        return x

In [47]:
class MemoryEfficientAttentionTwo(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias: bool = False,
            attn_drop=0.0,
            proj_drop=0.0,
    ):
        super().__init__()
        assert (dim % num_heads == 0), 'dim should be divisible by num_heads'
        assert num_heads > 0

        self.dim = dim
        self.num_heads = num_heads
        self.dim_heads = dim // num_heads
        self.attn_drop = attn_drop

        self.proj_query = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_key = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_value = nn.Linear(dim, dim, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=False)

    def forward(self,
                query: torch.Tensor,
                key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                att_mask: Optional[torch.Tensor] = None, ):
        """
        Expected input dimensions are [batch size, sequence length, embed dim]
        Output dimensions are [batch size, sequence length, embed dim]
        """
        if key is None:
            key = query
        if value is None:
            value = query

        B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
        _, S_K, _ = key.size()  # K, Q's sequence length could differ

        q = self.proj_query(query).view(B, S_Q, self.num_heads, self.dim_heads).transpose(1, 2)
        k = self.proj_key(key).view(B, S_K, self.num_heads, self.dim_heads).transpose(1, 2)
        v = self.proj_value(value).view(B, S_K, self.num_heads, self.dim_heads).transpose(1, 2)

        # Input tensors must be in format [B, M, H, K], where B is the batch size, M
        # the sequence length, H the number of heads, and K the embedding size per head
        x = xformers.ops.memory_efficient_attention(
            q, k, v, p=self.attn_drop, attn_bias=att_mask
        )

        x = (
            x
            .view(B, self.num_heads, S_Q, self.dim_heads)
            .transpose(1, 2)
            .flatten(start_dim=2, end_dim=3)
        )

        x = self.proj_drop(self.proj(x))
        return x

In [63]:
def benchmark_model(model, input_size, device, dtype, warmup=2, runs=10, kwargs={}):
    model = model.to(device=device)
    model.train()
    mem, t = [], []
    with torch.autocast(device_type='cuda', dtype=dtype):
        x = torch.randn(input_size, device=device, requires_grad=True)

        for i in range(runs + warmup):
            start = time.time()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

            y = model(x)

            torch.cuda.synchronize()
            stop = time.time()

            # now report
            max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
            diff = round((stop - start) * 1e6) / 1e3

            if i >= warmup:
                mem.append(max_memory)
                t.append(diff)
    print("Max Memory:", np.mean(max_memory), "Time:", np.mean(t))
    return None

DEVICE = torch.device("cuda")
DTYPE = torch.float16
BATCH = 64
HEADS = 12
SEQ = 512
EMB = 768

benchmark_model(MemoryEfficientAttentionTwo(dim=EMB, num_heads=HEADS),
                input_size=(BATCH, SEQ, EMB),
                dtype=DTYPE, device=DEVICE, runs=100, warmup=10)
benchmark_model(MemoryEfficientAttentionOne(dim=EMB, num_heads=HEADS),
                input_size=(BATCH, SEQ, EMB),
                dtype=DTYPE, device=DEVICE, runs=100, warmup=10)


Max Memory: 1406.0 Time: 8.993
Max Memory: 1357.0 Time: 5.708809999999999


In [36]:

class MLP(nn.Module):
    def __init__(self,
                 dim: int,
                 dropout: float = 0.0,
                 activation: nn.Module = nn.GELU,
                 hidden_layer_multiplier: int = 4,
                 bias: bool = True):
        super().__init__()
        dim_mlp = hidden_layer_multiplier * dim

        self.mlp = nn.Sequential(
            nn.Linear(in_features=dim, out_features=dim_mlp, bias=bias),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(in_features=dim_mlp, out_features=dim, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class AttentionBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.attention = MemoryEfficientAttentionTwo(dim=dim, num_heads=num_heads)

        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_layers: int,
                 gradient_checkpointing: bool = False):
        super().__init__()
        self.num_layers = num_layers
        self.gradient_checkpointing = gradient_checkpointing
        print("Using gradient checkpointing: ", gradient_checkpointing)

        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.gradient_checkpointing and self.training:
            return checkpoint_sequential(functions=self.blocks, input=x, segments=self.num_layers)
        return self.blocks(x)

In [8]:
DEVICE = 'cuda'
BATCH = 8
HEADS = 8
LAYERS = 8
SEQ = 512
EMB = 1024


attn = Transformer(dim=EMB, num_heads=HEADS, num_layers=LAYERS).to(DEVICE)
x = torch.randn(BATCH, SEQ, EMB, device=DEVICE, requires_grad=True)
x = attn(x)
x.shape

ImportError: attempted relative import with no known parent package