In [1]:
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Control.Exception.Safe
  ( SomeException (..),
    try,
  )
import Control.Monad ( forM_, when, (<=<) )
import Control.Monad.Cont ( ContT (..) )
import GHC.Generics
import Pipes hiding ( (~>) )
import qualified Pipes.Prelude as P
import Torch
import Torch.Serialize
import Torch.Typed.Vision ( initMnist )
import qualified Torch.Vision as V
import Prelude hiding ( exp )

-- | Trainable parameters of the multi-layer perceptron used by 'mlp':
-- three 'Linear' layers applied in sequence.
data MLP = MLP
  { -- | First linear layer; its output is passed through 'relu' in 'mlp'.
    fc1 :: Linear,
    -- | Second linear layer; its output is passed through 'relu' in 'mlp'.
    fc2 :: Linear,
    -- | Final linear layer; its output is passed through 'logSoftmax' in 'mlp'.
    fc3 :: Linear
  }
  deriving (Generic, Show, Parameterized)

-- | Layer sizes used to randomly initialise an 'MLP' (see the
-- 'Randomizable' instance, which builds one 'LinearSpec' per layer).
data MLPSpec = MLPSpec
  { -- | Input size of the first linear layer.
    i :: Int,
    -- | Output size of the first layer and input size of the second.
    h1 :: Int,
    -- | Output size of the second layer and input size of the third.
    h2 :: Int,
    -- | Output size of the third (final) linear layer.
    o :: Int
  }
  deriving (Show, Eq)


-- | Left-to-right function composition: @(f ~> g) x@ applies @f@ first and
-- then feeds its result to @g@.
(~>) :: (a -> b) -> (b -> c) -> a -> c
(f ~> g) x = g (f x)

-- | Forward pass of the network.
--
-- The input is pushed through @fc1@ and @fc2@, each followed by 'relu', and
-- then through @fc3@; 'logSoftmax' over dimension 1 is applied to the result.
mlp :: MLP -> Tensor -> Tensor
mlp MLP {..} input = logSoftmax (Dim 1) logits
  where
    -- Layer 1
    hidden1 = relu (linear fc1 input)
    -- Layer 2
    hidden2 = relu (linear fc2 hidden1)
    -- Layer 3
    logits = linear fc3 hidden2

In [2]:
-- | Randomly initialise an 'MLP' from its layer sizes by sampling one
-- 'Linear' layer per consecutive pair of dimensions in the spec.
instance Randomizable MLPSpec MLP where
  sample MLPSpec {..} = do
    layer1 <- sample (LinearSpec i h1)
    layer2 <- sample (LinearSpec h1 h2)
    layer3 <- sample (LinearSpec h2 o)
    pure MLP {fc1 = layer1, fc2 = layer2, fc3 = layer3}

In [3]:
-- | Run one pass over a stream of @(input, label)@ batches, updating the
-- model after every batch, and return the trained model.
--
-- For each batch the negative log-likelihood loss of 'mlp' against the labels
-- is computed and a single optimizer step with learning rate @1e-3@ is taken.
-- The loss is printed every 50 batches.
--
-- The optimizer returned by 'runStep' is threaded through the fold together
-- with the model, so optimizers that carry internal state keep accumulating
-- it from batch to batch instead of being reset to the initial value on every
-- step. For a stateless optimizer the behaviour is unchanged.
trainLoop :: Optimizer o => MLP -> o -> ListT IO (Tensor, Tensor) -> IO MLP
trainLoop model optimizer = P.foldM step begin done . enumerateData
  where
    -- Fold state is the current (model, optimizer) pair; 'runStep' already
    -- returns exactly that pair, so its result is the next state.
    step (net, opt) ((input, label), iter) = do
      let loss = nllLoss' label $ mlp net input
      -- Print loss every 50 batches
      when (iter `mod` 50 == 0) $
        putStrLn $ "Iteration: " ++ show iter ++ " | Loss: " ++ show loss
      runStep net opt loss 1e-3
    begin = pure (model, optimizer)
    -- Callers only receive the model; the final optimizer state is dropped.
    done = pure . fst

-- | Render a test image, then print the model's prediction (the 'argmax'
-- over dimension 1 of the exponentiated 'mlp' output) next to the
-- ground-truth label.
displayImages :: MLP -> (Tensor, Tensor) -> IO ()
displayImages model (testImg, testLabel) = do
  let prediction = argmax (Dim 1) RemoveDim (exp (mlp model testImg))
  V.dispImage testImg
  putStrLn ("Model        : " ++ show prediction)
  putStrLn ("Ground Truth : " ++ show testLabel)

-- | Entry point: load MNIST from the @data@ directory, train a freshly
-- sampled 784-64-32-10 'MLP' with 'GD' for 5 epochs (batch size 256), then
-- display test items 0 through 10 alongside the model's predictions.
main :: IO ()
main = do
  (trainData, testData) <- initMnist "data"
  initialNet <- sample (MLPSpec 784 64 32 10)
  let trainMnist = V.MNIST {batchSize = 256, mnistData = trainData}
      testMnist = V.MNIST {batchSize = 1, mnistData = testData}
      -- One epoch: stream the training set and fold 'trainLoop' over it.
      -- The epoch counter supplied by 'foldLoop' is not needed.
      runEpoch model _ =
        runContT
          (streamFromMap (datasetOpts 2) trainMnist)
          (trainLoop model GD . fst)

  -- Train for 5 epochs
  trainedNet <- foldLoop initialNet 5 runEpoch

  -- Show test images + labels
  forM_ [0 .. 10] $ \idx ->
    getItem testMnist idx >>= displayImages trainedNet

  putStrLn "Done"

In [4]:
-- Run the full pipeline defined by 'main': load MNIST, train, display predictions.
main

Iteration: 0 | Loss: Tensor Float []  12.3775   
Iteration: 50 | Loss: Tensor Float []  1.0952   
Iteration: 100 | Loss: Tensor Float []  0.5626   
Iteration: 150 | Loss: Tensor Float []  0.6660   
Iteration: 200 | Loss: Tensor Float []  0.4771   
Iteration: 0 | Loss: Tensor Float []  0.5012   
Iteration: 50 | Loss: Tensor Float []  0.4058   
Iteration: 100 | Loss: Tensor Float []  0.3095   
Iteration: 150 | Loss: Tensor Float []  0.4237   
Iteration: 200 | Loss: Tensor Float []  0.3433   
Iteration: 0 | Loss: Tensor Float []  0.3671   
Iteration: 50 | Loss: Tensor Float []  0.3206   
Iteration: 100 | Loss: Tensor Float []  0.2467   
Iteration: 150 | Loss: Tensor Float []  0.3420   
Iteration: 200 | Loss: Tensor Float []  0.2737   
Iteration: 0 | Loss: Tensor Float []  0.3054   
Iteration: 50 | Loss: Tensor Float []  0.2779   
Iteration: 100 | Loss: Tensor Float []  0.2161   
Iteration: 150 | Loss: Tensor Float []  0.2933   
Iteration: 200 | Loss: Tensor Float []  0.2289   
Iteration: 