forked from hasktorch/hasktorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial Hasktorch 0.2 interface (hasktorch#42)
* Starting point for Hasktorch 0.2 interface * Flesh out basic functionality in the new interface * Working XOR MLP * Add the independent function * Aten -> ATen * Add hasktorch and examples to CI * Add side-effect for ones * Make autograd safer
- Loading branch information
1 parent
cf969c3
commit 0af2a97
Showing
21 changed files
with
737 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
-- Local packages built as part of this project,
-- including the vendored inline-c fork used for the C++ FFI.
packages:
  codegen/*.cabal
  ffi/*.cabal
  hasktorch/*.cabal
  examples/*.cabal
  inline-c/inline-c/*.cabal
  inline-c/inline-c-cpp/*.cabal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
name:                examples
version:             0.2.0.0
synopsis:            examples for the new version of hasktorch
-- description:
homepage:            https://github.com/githubuser/ffi-experimental#readme
license:             BSD3
author:              Austin Huang
maintainer:          hasktorch@gmail.com
copyright:           2019 Austin Huang
category:            Codegen
build-type:          Simple
cabal-version:       >=1.10

-- XOR MLP training example exercising the new Hasktorch 0.2 API.
executable xor_mlp
  hs-source-dirs:      xor_mlp
  main-is:             Main.hs
  default-language:    Haskell2010
  build-depends:       base >= 4.7 && < 5
                     , hasktorch
                     , mtl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
{-# LANGUAGE RecordWildCards #-} | ||
{-# LANGUAGE FunctionalDependencies #-} | ||
|
||
module Main where | ||
|
||
import Torch.Tensor | ||
import Torch.DType | ||
import Torch.TensorFactories | ||
import Torch.Functions | ||
import Torch.TensorOptions | ||
import Torch.Autograd | ||
|
||
import Control.Monad.State.Strict | ||
import Data.List (foldl', scanl', intersperse) | ||
|
||
-- | A learnable parameter: an independent (leaf) tensor tracked by autograd.
type Parameter = IndependentTensor

-- | A stateful supply of parameters, consumed one at a time.
type ParamStream a = State [Parameter] a

-- | Pop the next parameter off the supply, failing loudly when the
-- supply runs dry (indicates a mismatched 'replaceParameters' call).
nextParameter :: ParamStream Parameter
nextParameter = do
  supply <- get
  case supply of
    [] -> error "Not enough parameters supplied to replaceParameters"
    (p : rest) -> do
      put rest
      return p
|
||
-- | Things that expose a flat list of learnable parameters and can have
-- those parameters swapped out wholesale.
class Parametrized f where
  flattenParameters :: f -> [Parameter]
  replaceOwnParameters :: f -> ParamStream f

-- | Replace every parameter in @f@ with the supplied list, in
-- 'flattenParameters' order.  Errors when the list is too long
-- (leftover entries) or too short (via 'nextParameter').
replaceParameters :: Parametrized f => f -> [Parameter] -> f
replaceParameters f params
  | null leftover = f'
  | otherwise =
      error "Some parameters in a call to replaceParameters haven't been consumed!"
  where
    (f', leftover) = runState (replaceOwnParameters f) params
|
||
-- | Specifications @spec@ that can be turned into a randomly
-- initialised value @f@ (e.g. a layer with random weights).
class Randomizable spec f | spec -> f where
  sample :: spec -> IO f

-- | A module is randomizable from its spec and exposes its parameters.
class (Randomizable spec f, Parametrized f) => Module spec f
|
||
-------------------------------------------------------------------------------- | ||
-- Linear function | ||
-------------------------------------------------------------------------------- | ||
|
||
-- | Hyper-parameters of a fully connected layer.
data LinearSpec = LinearSpec { in_features :: Int, out_features :: Int }
  deriving (Show, Eq)

-- | A fully connected layer: @y = x W + b@.
data Linear = Linear { weight :: Parameter, bias :: Parameter }
  deriving (Show)

instance Randomizable LinearSpec Linear where
  -- Weights and bias drawn from a standard normal distribution.
  sample LinearSpec{..} =
    Linear
      <$> (makeIndependent =<< randn' [in_features, out_features])
      <*> (makeIndependent =<< randn' [out_features])

instance Parametrized Linear where
  flattenParameters Linear{..} = [weight, bias]
  replaceOwnParameters _ = do
    w <- nextParameter
    b <- nextParameter
    return $ Linear { weight = w, bias = b }

-- | Apply the layer to a (batch of) input(s).
linear :: Linear -> Tensor -> Tensor
linear Linear{..} input = matmul input (toDependent weight) + toDependent bias
|
||
-------------------------------------------------------------------------------- | ||
-- MLP | ||
-------------------------------------------------------------------------------- | ||
|
||
-- | Spec for a multilayer perceptron: the layer widths plus the
-- nonlinearity applied between layers.
data MLPSpec = MLPSpec { feature_counts :: [Int], nonlinearitySpec :: Tensor -> Tensor }

-- | A stack of linear layers interleaved with a fixed nonlinearity.
data MLP = MLP { layers :: [Linear], nonlinearity :: Tensor -> Tensor }

instance Randomizable MLPSpec MLP where
  sample MLPSpec{..} = do
    -- One LinearSpec per consecutive pair of widths.
    linears <- mapM (sample . uncurry LinearSpec) (mkLayerSizes feature_counts)
    return $ MLP { layers = linears, nonlinearity = nonlinearitySpec }
    where
      -- Turn e.g. [2,20,20,1] into [(2,20),(20,20),(20,1)].  The
      -- explicit error replaces a silent pattern-match failure when
      -- fewer than two widths are supplied.
      mkLayerSizes (a : b : t) = scanl shift (a, b) t
        where shift (_, q) c = (q, c)
      mkLayerSizes _ = error "MLPSpec requires at least two feature counts"

instance Parametrized MLP where
  flattenParameters MLP{..} = concatMap flattenParameters layers
  replaceOwnParameters net = do
    new_layers <- mapM replaceOwnParameters (layers net)
    return $ net { layers = new_layers }

-- | Run the MLP: linear layers interleaved with the nonlinearity.
-- Note: no nonlinearity is applied after the final layer.
mlp :: MLP -> Tensor -> Tensor
mlp MLP{..} input = foldl' revApply input $ intersperse nonlinearity $ map linear layers
  where revApply x f = f x
|
||
-------------------------------------------------------------------------------- | ||
-- Training code | ||
-------------------------------------------------------------------------------- | ||
|
||
-- Explicit signatures avoid type-defaulting surprises on these
-- previously unannotated top-level bindings.
batch_size :: Int
batch_size = 32

num_iters :: Int
num_iters = 10000

-- | The full model: MLP followed by a sigmoid squashing outputs into (0, 1).
model :: MLP -> Tensor -> Tensor
model params t = sigmoid (mlp params t)

-- | One step of vanilla gradient descent: @p' = p - lr * dp@.
-- Takes the learning rate, the current parameters, and their gradients;
-- returns the updated parameter tensors.
sgd :: Tensor -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters gradients =
  zipWith step (map toDependent parameters) gradients
  where step p dp = p - lr * dp
|
||
-- | Train the MLP to compute XOR of two {0,1}-valued inputs with plain SGD.
main :: IO ()
main = do
    -- Randomly initialise a 2-20-20-1 network ('initModel' rather than
    -- 'init', which shadows Prelude.init).
    initModel <- sample $ MLPSpec { feature_counts = [2, 20, 20, 1], nonlinearitySpec = Torch.Functions.tanh }
    _trained <- foldLoop initModel num_iters $ \state i -> do
        -- Random boolean batch, thresholded then re-encoded as Float 0/1.
        input <- toDType Float . gt 0.5 <$> rand' [batch_size, 2]
        let expected_output = tensorXOR input
            output = squeezeAll $ model state input
            loss = mse_loss output expected_output
            flat_parameters = flattenParameters state
            gradients = grad loss flat_parameters
        -- Report progress every 100 iterations.
        when (i `mod` 100 == 0) $ print loss
        -- Take an SGD step and rebuild the model around the new weights.
        new_flat_parameters <- mapM makeIndependent $ sgd 5e-4 flat_parameters gradients
        return $ replaceParameters state new_flat_parameters
    return ()
  where
    -- foldM over [1..count], threading the model state through.
    foldLoop x count block = foldM block x [1..count]
|
||
-- | Elementwise XOR of columns 0 and 1 of a {0,1}-valued batch,
-- expressed arithmetically as (a OR b) AND NOT (a AND b).
tensorXOR :: Tensor -> Tensor
tensorXOR t = (1 - (1 - lhs) * (1 - rhs)) * (1 - (lhs * rhs))
  where
    lhs = select t 1 0
    rhs = select t 1 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
{-# LANGUAGE FlexibleInstances #-} | ||
{-# LANGUAGE MultiParamTypeClasses #-} | ||
|
||
module ATen.Managed.Cast where | ||
|
||
import Foreign.ForeignPtr | ||
import Control.Monad | ||
|
||
import ATen.Class | ||
import ATen.Cast | ||
import ATen.Type | ||
import ATen.Managed.Type.IntArray | ||
import ATen.Managed.Type.TensorList | ||
|
||
-- | Marshal a Haskell @[Int]@ to/from an ATen @IntArray@.
instance Castable [Int] (ForeignPtr IntArray) where
  cast xs f = do
    arr <- newIntArray
    forM_ xs (intArray_push_back_l arr . fromIntegral)
    f arr
  uncast xs f = do
    len <- intArray_size xs
    -- Guard the empty case: if the size type is unsigned,
    -- @len - 1@ would underflow and enumerate a huge range.
    let idxs = if len == 0 then [] else [0 .. len - 1]
    f =<< mapM (fmap fromIntegral . intArray_at_s xs) idxs

-- | Marshal a list of tensor pointers to/from an ATen @TensorList@.
instance Castable [ForeignPtr Tensor] (ForeignPtr TensorList) where
  cast xs f = do
    l <- newTensorList
    forM_ xs (tensorList_push_back_t l)
    f l
  uncast xs f = do
    len <- tensorList_size xs
    -- Same empty-list underflow guard as for IntArray above.
    let idxs = if len == 0 then [] else [0 .. len - 1]
    f =<< mapM (tensorList_at_s xs) idxs

-- | Scalars pass through unchanged (identity marshalling).
instance Castable (ForeignPtr Scalar) (ForeignPtr Scalar) where
  cast x f = f x
  uncast x f = f x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
module Torch.Managed.Autograd where | ||
|
||
import Foreign.ForeignPtr | ||
|
||
import qualified Torch.Unmanaged.Autograd as Unmanaged | ||
import qualified ATen.Unmanaged.Type.Tensor | ||
import qualified ATen.Unmanaged.Type.TensorList | ||
import ATen.Type | ||
import ATen.Class | ||
import ATen.Cast | ||
|
||
|
||
-- | Compute gradients of a scalar tensor with respect to a list of
-- inputs; managed wrapper around the unmanaged implementation.
grad :: ForeignPtr Tensor -> ForeignPtr TensorList -> IO (ForeignPtr TensorList)
grad = cast2 Unmanaged.grad


-- | Detach a tensor from its graph and mark it as requiring grad,
-- producing an independent (leaf) variable; managed wrapper.
makeIndependent :: ForeignPtr Tensor -> IO (ForeignPtr Tensor)
makeIndependent = cast1 Unmanaged.makeIndependent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE PolyKinds #-} | ||
{-# LANGUAGE TemplateHaskell #-} | ||
{-# LANGUAGE QuasiQuotes #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
{-# LANGUAGE OverloadedStrings #-} | ||
|
||
module Torch.Unmanaged.Autograd where | ||
|
||
import Foreign.Ptr | ||
import qualified Language.C.Inline.Cpp as C | ||
import qualified Language.C.Inline.Cpp.Exceptions as C | ||
import qualified Language.C.Inline.Context as C | ||
import qualified Language.C.Types as C | ||
|
||
import ATen.Type | ||
|
||
C.context $ C.cppCtx <> mempty { C.ctxTypesTable = typeTable } | ||
|
||
C.include "<vector>" | ||
C.include "<torch/torch.h>" | ||
C.include "<torch/csrc/autograd/variable.h>" | ||
C.include "<torch/csrc/autograd/engine.h>" | ||
C.include "<ATen/core/functional.h>" | ||
|
||
-- | Compute d(y)/d(inputs) for a scalar tensor @y@ by driving the
-- autograd engine directly.  Throws (as a C++ exception surfaced via
-- 'C.throwBlock') when @y@ does not require grad, when @y@ is not a
-- scalar, or when any input does not require grad.
-- Fix: corrected the grammar of the first error message
-- ("not require grad" -> "does not require grad").
grad :: Ptr Tensor -> Ptr TensorList -> IO (Ptr TensorList)
grad y inputs = [C.throwBlock| std::vector<at::Tensor>* {
    torch::autograd::Variable y = *$(at::Tensor* y);
    const auto & inputs = *$(std::vector<at::Tensor>* inputs);

    torch::autograd::edge_list roots { y.gradient_edge() };
    if (!roots[0].function) {
      throw std::runtime_error("Differentiated tensor does not require grad");
    }

    if (y.numel() != 1) {
      throw std::runtime_error("Differentiated tensor has more than a single element");
    }
    torch::autograd::variable_list grads { torch::ones_like(y) };

    torch::autograd::edge_list output_edges;
    output_edges.reserve(inputs.size());
    for (torch::autograd::Variable input : inputs) {
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        grad_fn = input.try_get_grad_accumulator();
      }
      if (!input.requires_grad()) {
        throw std::runtime_error("One of the differentiated Tensors does not require grad");
      }
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }

    auto & engine = torch::autograd::Engine::get_default_engine();
    auto outputs = engine.execute(roots, grads,
                                  /*keep_graph=*/true,
                                  /*create_graph=*/false,
                                  output_edges);

    return new std::vector<at::Tensor>(at::fmap<at::Tensor>(outputs));
  }|]
|
||
-- | Return a copy of the tensor detached from its computation graph
-- with requires_grad set, i.e. a fresh leaf variable.
makeIndependent :: Ptr Tensor -> IO (Ptr Tensor)
makeIndependent t = [C.throwBlock| at::Tensor* {
    return new at::Tensor($(at::Tensor* t)->detach().set_requires_grad(true));
  }|]
Oops, something went wrong.