In [70]:
{-# LANGUAGE RecordWildCards #-}
import GHC.Generics
import Torch
import Torch.NN as NN
import Torch.Functional as F
import Torch.Functional.Internal as FI

import Torch.TensorFactories
import Torch.TensorOptions

import Control.Monad (when)

## casualSelfAttentionInit

In [71]:
-- | Hyperparameters consumed by 'casualSelfAttentionInit'.
data Config = Config
  { configNEmbd :: Int -- ^ Embedding width; 'casualSelfAttentionInit' requires it to be divisible by 'configNHead'.
  , configNHead :: Int -- ^ Number of attention heads.
  , configBlockSize :: Int -- ^ Side length of the causal mask that is precomputed at initialisation.
  } deriving (Show, Eq)

In [72]:
-- | State of a multi-head causal self-attention layer.
data CasualSelfAttention = CasualSelfAttention
  { cAttn :: Linear -- ^ Input projection; its output is chunked three ways into query, key and value.
  , cProj :: Linear -- ^ Output projection applied after the heads are merged back together.
  , nHead :: Int -- ^ Number of attention heads.
  , nEmbd :: Int -- ^ Embedding width the layer was initialised with.
  , attentionBias :: Tensor -- ^ Causal mask built from the configured block size. NOTE(review): 'casualSelfAttentionForward' builds its own mask per call and does not read this field.
  } deriving (Generic, Show)

In [73]:
-- | Build a causal attention mask of shape @[1, 1, seqLen, seqLen]@.
--
-- Starts from a @seqLen x seqLen@ tensor created with @ones'@, applies
-- @tril@ with diagonal offset 0, and reshapes the result with two leading
-- singleton dimensions.
createCausalMask :: Int -> Tensor
createCausalMask seqLen =
  let square = ones' [seqLen, seqLen]
      lowerTriangle = FI.tril square 0
   in FI.reshape lowerTriangle [1, 1, seqLen, seqLen]

In [74]:
-- | Initialise a 'CasualSelfAttention' layer from a 'Config'.
--
-- Samples the two linear layers (input projection of width
-- @3 * configNEmbd@ and output projection of width @configNEmbd@) and
-- precomputes a causal mask of side @configBlockSize@.
--
-- Calls 'error' when @configNHead@ is not positive or when @configNEmbd@
-- is not divisible by @configNHead@.
casualSelfAttentionInit :: Config -> IO CasualSelfAttention
casualSelfAttentionInit Config{..} = do
  -- Checked first: a zero head count would otherwise make the `mod` below
  -- fail with a divide-by-zero instead of a configuration error, and a
  -- negative one would pass the divisibility check.
  when (configNHead <= 0) $
    error "configNHead must be positive"

  when (configNEmbd `mod` configNHead /= 0) $
    error "configNEmbd must be divisible by configNHead"

  -- One projection produces query, key and value side by side.
  cAttn <- sample (LinearSpec configNEmbd (3 * configNEmbd))
  cProj <- sample (LinearSpec configNEmbd configNEmbd)

  return CasualSelfAttention
    { cAttn = cAttn
    , cProj = cProj
    , nHead = configNHead
    , nEmbd = configNEmbd
    , attentionBias = createCausalMask configBlockSize
    }

## scaledDotProductAttention

In [75]:
-- | Scaled dot-product attention.
--
-- Computes @softmax(q * transpose(k) / sqrt(d)) * v@, where @d@ is the size
-- of the last dimension of @k@ and the softmax runs over the last axis
-- using the dtype of @q@.
--
-- * mask: when @Just m@, the scores become
--   @scores * m + (1 - m) * (-1e10)@ before the softmax, so entries where
--   @m@ is 0 get a large negative score and entries where @m@ is 1 are
--   left unchanged.
-- * dropout: accepted for interface compatibility but NOT applied; the
--   attention weights are used as-is whatever value is passed.
scaledDotProductAttention :: Tensor -> Tensor -> Tensor -> Maybe Tensor -> Maybe Float -> Tensor
scaledDotProductAttention q k v mask _dropout =
  FI.matmul weights v
  where
    -- sqrt of the size of k's last dimension, built as a tensor so it can
    -- be passed straight to FI.div.
    scaleFactor = FI.sqrt (fromIntegral $ last $ shape k)

    scores = FI.div (FI.matmul q (FI.transpose k (-2) (-1))) scaleFactor

    maskedScores = case mask of
      Just m -> scores * m + (1.0 - m) * (-1e10)
      Nothing -> scores

    weights = FI.softmax maskedScores (-1) (dtype q)

In [76]:
-- Random inputs for a shape check of scaledDotProductAttention.
q <- randIO' [4,8,16,64]-- (B, nh, T, hs)
k <- randIO' [4,8,16,64]-- (B, nh, T, hs)
v <- randIO' [4,8,16,64]-- (B, nh, T, hs)
-- NOTE(review): this mask comes from randIO' rather than createCausalMask, so
-- it is presumably only meant to exercise shapes, not real masking.
mask <- randIO' [1, 1, 16, 16] -- (1, 1, T, T)

In [77]:
-- Run attention with the mask and no dropout, then inspect the result shape.
output = scaledDotProductAttention q k v (Just mask) Nothing
shape output

[4,8,16,64]

## casualSelfAttentionForward

In [78]:
-- | Forward pass of multi-head causal self-attention.
--
-- The input @x@ must have exactly three dimensions, read as
-- @[batchSize, seqLen, embedSize]@. It is projected with @cAttn@, chunked
-- into query, key and value along the last axis, split into @nHead@ heads,
-- passed through 'scaledDotProductAttention' with a causal mask built for
-- the current @seqLen@, merged back to @[batchSize, seqLen, embedSize]@
-- and finally projected with @cProj@.
--
-- Calls 'error' with a descriptive message when @x@ does not have three
-- dimensions or when the projection does not split into three chunks.
casualSelfAttentionForward :: CasualSelfAttention -> Tensor -> Tensor
casualSelfAttentionForward CasualSelfAttention{..} x =
  case shape x of
    [batchSize, seqLen, embedSize] ->
      let headSize = Prelude.div embedSize nHead

          -- [B, T, C] -> [B, nHead, T, headSize]
          splitHeads t =
            FI.transpose (F.view [batchSize, seqLen, nHead, headSize] t) 1 2

          -- [B, nHead, T, headSize] -> [B, T, C]; the transposed tensor is
          -- made contiguous before it is viewed.
          mergeHeads t =
            F.view [batchSize, seqLen, embedSize] (contiguous (FI.transpose t 1 2))

          -- The mask is rebuilt for the actual sequence length on each call.
          causalMask = createCausalMask seqLen
       in case FI.chunk (NN.linear cAttn x) 3 2 of
            [q, k, v] ->
              let att =
                    scaledDotProductAttention
                      (splitHeads q)
                      (splitHeads k)
                      (splitHeads v)
                      (Just causalMask)
                      Nothing
               in NN.linear cProj (mergeHeads att)
            _ ->
              error
                "casualSelfAttentionForward: projection did not split into exactly 3 chunks (q, k, v)"
    dims ->
      error $
        "casualSelfAttentionForward: expected input of shape [batch, seq, embed], got "
          ++ show dims

In [79]:
-- Hyperparameters for the demo layer.
-- NOTE(review): this top-level `nHead` has the same name as the `nHead` field
-- of CasualSelfAttention; presumably it shadows the field in later cells.
-- Confirm, or rename it.
nHead = 8
nEmbed = 128
blockSize = 32

-- Dimensions of the demo input.
batchSize = 4
seqLen = 16


In [None]:
-- Build the attention layer from the demo hyperparameters and create an
-- input tensor of shape [batchSize, seqLen, nEmbed], then print its shape.
attentionModule <- casualSelfAttentionInit (Config nEmbed nHead blockSize)
x_example <- randIO' [batchSize, seqLen, nEmbed] -- BATCH_SIZE, SEQ_LEN, N_EMBD
putStrLn $ "Input shape: " ++ show (shape x_example) 

Input shape: [4,16,128]

: 