diff --git a/lib/ch7/simple_conv_net.ex b/lib/ch7/simple_conv_net.ex new file mode 100644 index 0000000..4bdf755 --- /dev/null +++ b/lib/ch7/simple_conv_net.ex @@ -0,0 +1,55 @@ +defmodule SimpleConvNet do + require Axon + + def train_inputs do + x_train = + Dataset.train_image + |> Nx.tensor + |> Nx.reshape({60000, 1, 28, 28}) + |> (& Nx.divide(&1, Nx.reduce_max(&1))).() + |> Nx.to_batched_list(100) + + t_train = + Dataset.train_label + |> Dataset.to_one_hot + |> Nx.tensor + |> Nx.to_batched_list(100) + {x_train, t_train} + end + + def test_inputs do + x_test = + Dataset.test_image + |> Nx.tensor + |> Nx.reshape({10000,1,28,28}) + |> (& Nx.divide(&1, Nx.reduce_max(&1))).() + + t_test = Dataset.test_label |> Dataset.to_one_hot |> Nx.tensor + {x_test, t_test} + end + + def model do + Axon.input({nil,1,28,28}) + |> Axon.conv(30, kernel_size: {5, 5}, activation: :relu) + |> Axon.max_pool(kernel_size: {2, 2}) + |> Axon.flatten() + |> Axon.dense(100, activation: :relu) + |> Axon.dense(10, activation: :softmax) + end + + def train do + {x_train, t_train} = train_inputs() + {trained_params, _optmizer} = + model() + |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) + |> Axon.Training.train(x_train, t_train, epochs: 10, compiler: EXLA) + + trained_params + end + + def test(params) do + {x_test, t_test} = test_inputs() + Axon.predict(model(), params, x_test, compiler: EXLA) + |> Axon.Metrics.accuracy(t_test) + end +end diff --git a/lib/jose/mnist.ex b/lib/jose/mnist.ex deleted file mode 100644 index e69de29..0000000 diff --git a/mix.exs b/mix.exs index d0aba18..c7e1999 100644 --- a/mix.exs +++ b/mix.exs @@ -21,7 +21,8 @@ defmodule NxDl.MixProject do # Run "mix help deps" to learn about dependencies. def deps do [ - {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"}, + {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, + {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla", override: true}, {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}, {:expyplot, "~> 1.1.2"}, {:erlport, "~> 0.9.8" }, diff --git a/mix.lock b/mix.lock index b694e69..b7ef83c 100644 --- a/mix.lock +++ b/mix.lock @@ -1,18 +1,20 @@ %{ + "axon": {:git, "https://github.com/elixir-nx/axon.git", "e450f32416179baf818c2948e082932b475d8ed9", [branch: "main"]}, "benchee": {:hex, :benchee, "1.0.1", "66b211f9bfd84bd97e6d1beaddf8fc2312aaabe192f776e8931cb0c16f53a521", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}], "hexpm", "3ad58ae787e9c7c94dd7ceda3b587ec2c64604563e049b2a0e8baafae832addb"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark": {:hex, :earmark, "1.3.6", "ce1d0675e10a5bb46b007549362bd3f5f08908843957687d8484fe7f37466b19", [:mix], [], "hexpm", "1476378df80982302d5a7857b6a11dd0230865057dec6d16544afecc6bc6b4c2"}, "elixir_make": {:hex, :elixir_make, "0.6.2", "7dffacd77dec4c37b39af867cedaabb0b59f6a871f89722c25b28fcd4bd70530", [:mix], [], "hexpm", "03e49eadda22526a7e5279d53321d1cced6552f344ba4e03e619063de75348d9"}, "erlport": {:hex, :erlport, "0.9.8", "b7dc57eb87f215a671926bfbcd23e6e9c76f8653b0d072627b41431ef51c4d20", [:rebar3], [], "hexpm", "df57d99455d4bf2bab83e12f242d4e5513ad094b6c73179a85d084c929ce697c"}, "ex_doc": {:hex, :ex_doc, "0.20.2", "1bd0dfb0304bade58beb77f20f21ee3558cc3c753743ae0ddbb0fd7ba2912331", [:mix], [{:earmark, "~> 1.3", [hex: :earmark, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.10", [hex: :makeup_elixir, repo: "hexpm", optional: false]}], "hexpm", "8e24fc8ff9a50b9f557ff020d6c91a03cded7e59ac3e0eec8a27e771430c7d27"}, - "exla": {:git, "https://github.com/elixir-nx/nx.git", "6c836bfa617178efc0e703f7794e7e58440581b3", [sparse: "exla"]}, + "exla": {:git, "https://github.com/elixir-nx/nx.git", "bb6f2a811882921de684833e64154bc8084bd792", [sparse: "exla"]}, "expyplot": {:hex, :expyplot, "1.1.3", "3b47e107e979f3d7701460c4cfe4d2535bb991ba354663f08ac20030f14bed8a", [:mix], [{:earmark, "~> 1.3.2", [hex: :earmark, repo: "hexpm", optional: false]}, {:erlport, "~> 0.9", [hex: :erlport, repo: "hexpm", optional: false]}, {:ex_doc, "~> 0.20.2", [hex: :ex_doc, repo: "hexpm", optional: false]}, {:statistics, "~> 0.4.1", [hex: :statistics, repo: "hexpm", optional: false]}], "hexpm", "b4f9d32b56f2782b8cc2bde6017984197eb569d088b656b88efd5eb2d73522a8"}, "flow": {:hex, :flow, "1.1.0", "b569c1042cb2da97103f6d70a0267a5657dce1402f41b4020bef98bbef9c7c1e", [:mix], [{:gen_stage, "~> 1.0", [hex: :gen_stage, repo: "hexpm", optional: false]}], "hexpm", "066f42f7a1ea6a86cb4ef763310338981a5cfb93bcebce10863a23a4859fd785"}, "gen_stage": {:hex, :gen_stage, "1.1.0", "dd0c0f8d2f3b993fdbd3d58e94abbe65380f4e78bdee3fa93d5618d7d14abe60", [:mix], [], "hexpm", "7f2b36a6d02f7ef2ba410733b540ec423af65ec9c99f3d1083da508aca3b9305"}, "makeup": {:hex, :makeup, "1.0.5", "d5a830bc42c9800ce07dd97fa94669dfb93d3bf5fcf6ea7a0c67b2e0e4a7f26c", [:mix], [{:nimble_parsec, "~> 0.5 or ~> 1.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cfa158c02d3f5c0c665d0af11512fed3fba0144cf1aadee0f2ce17747fba2ca9"}, "makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"}, "nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "6c836bfa617178efc0e703f7794e7e58440581b3", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "bb6f2a811882921de684833e64154bc8084bd792", [sparse: "nx"]}, "pelemay_fp": {:hex, :pelemay_fp, "0.1.2", "6ddb0fb8e91c75f490b12ddb69fd3be77d54ce32e8ed31964a0b8e4e4cf05382", [:mix], [], "hexpm", "2d3f1ed87c0a275cd34abd7eadc5fd6bbaecfa75963308fc1d12f8bf40b4fcbe"}, "statistics": {:hex, :statistics, "0.4.1", "e9bfe6649f70842d9ebce69aa31ac8b6928e096bde43b683d7d673057237d028", [:mix], [], "hexpm", "726d8791e9bafb08b3ceeb5b08df6664f29a73a0e6ac0db835500b686a153bd5"}, + "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, }