Skip to content

Commit

Permalink
ch7 cnn
Browse files Browse the repository at this point in the history
  • Loading branch information
thehaigo committed Apr 11, 2021
1 parent 3eef8a8 commit 65b9207
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
55 changes: 55 additions & 0 deletions lib/ch7/simple_conv_net.ex
@@ -0,0 +1,55 @@
defmodule SimpleConvNet do
require Axon

def train_inputs do
x_train =
Dataset.train_image
|> Nx.tensor
|> Nx.reshape({60000, 1, 28, 28})
|> (& Nx.divide(&1, Nx.reduce_max(&1))).()
|> Nx.to_batched_list(100)

t_train =
Dataset.train_label
|> Dataset.to_one_hot
|> Nx.tensor
|> Nx.to_batched_list(100)
{x_train, t_train}
end

def test_inputs do
x_test =
Dataset.test_image
|> Nx.tensor
|> Nx.reshape({10000,1,28,28})
|> (& Nx.divide(&1, Nx.reduce_max(&1))).()

t_test = Dataset.test_label |> Dataset.to_one_hot |> Nx.tensor
{x_test, t_test}
end

def model do
Axon.input({nil,1,28,28})
|> Axon.conv(30, kernel_size: {5, 5}, activation: :relu)
|> Axon.max_pool(kernel_size: {2, 2})
|> Axon.flatten()
|> Axon.dense(100, activation: :relu)
|> Axon.dense(10, activation: :softmax)
end

def train do
{x_train, t_train} = train_inputs()
{trained_params, _optmizer} =
model()
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
|> Axon.Training.train(x_train, t_train, epochs: 10, compiler: EXLA)

trained_params
end

def test(params) do
{x_test, t_test} = test_inputs()
Axon.predict(model(), params, x_test, compiler: EXLA)
|> Axon.Metrics.accuracy(t_test)
end
end
Empty file removed lib/jose/mnist.ex
Empty file.
3 changes: 2 additions & 1 deletion mix.exs
Expand Up @@ -21,7 +21,8 @@ defmodule NxDl.MixProject do
# Run "mix help deps" to learn about dependencies.
def deps do
[
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla", override: true},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:expyplot, "~> 1.1.2"},
{:erlport, "~> 0.9.8" },
Expand Down
6 changes: 4 additions & 2 deletions mix.lock
@@ -1,18 +1,20 @@
%{
"axon": {:git, "https://github.com/elixir-nx/axon.git", "e450f32416179baf818c2948e082932b475d8ed9", [branch: "main"]},
"benchee": {:hex, :benchee, "1.0.1", "66b211f9bfd84bd97e6d1beaddf8fc2312aaabe192f776e8931cb0c16f53a521", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}], "hexpm", "3ad58ae787e9c7c94dd7ceda3b587ec2c64604563e049b2a0e8baafae832addb"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark": {:hex, :earmark, "1.3.6", "ce1d0675e10a5bb46b007549362bd3f5f08908843957687d8484fe7f37466b19", [:mix], [], "hexpm", "1476378df80982302d5a7857b6a11dd0230865057dec6d16544afecc6bc6b4c2"},
"elixir_make": {:hex, :elixir_make, "0.6.2", "7dffacd77dec4c37b39af867cedaabb0b59f6a871f89722c25b28fcd4bd70530", [:mix], [], "hexpm", "03e49eadda22526a7e5279d53321d1cced6552f344ba4e03e619063de75348d9"},
"erlport": {:hex, :erlport, "0.9.8", "b7dc57eb87f215a671926bfbcd23e6e9c76f8653b0d072627b41431ef51c4d20", [:rebar3], [], "hexpm", "df57d99455d4bf2bab83e12f242d4e5513ad094b6c73179a85d084c929ce697c"},
"ex_doc": {:hex, :ex_doc, "0.20.2", "1bd0dfb0304bade58beb77f20f21ee3558cc3c753743ae0ddbb0fd7ba2912331", [:mix], [{:earmark, "~> 1.3", [hex: :earmark, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.10", [hex: :makeup_elixir, repo: "hexpm", optional: false]}], "hexpm", "8e24fc8ff9a50b9f557ff020d6c91a03cded7e59ac3e0eec8a27e771430c7d27"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "6c836bfa617178efc0e703f7794e7e58440581b3", [sparse: "exla"]},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "bb6f2a811882921de684833e64154bc8084bd792", [sparse: "exla"]},
"expyplot": {:hex, :expyplot, "1.1.3", "3b47e107e979f3d7701460c4cfe4d2535bb991ba354663f08ac20030f14bed8a", [:mix], [{:earmark, "~> 1.3.2", [hex: :earmark, repo: "hexpm", optional: false]}, {:erlport, "~> 0.9", [hex: :erlport, repo: "hexpm", optional: false]}, {:ex_doc, "~> 0.20.2", [hex: :ex_doc, repo: "hexpm", optional: false]}, {:statistics, "~> 0.4.1", [hex: :statistics, repo: "hexpm", optional: false]}], "hexpm", "b4f9d32b56f2782b8cc2bde6017984197eb569d088b656b88efd5eb2d73522a8"},
"flow": {:hex, :flow, "1.1.0", "b569c1042cb2da97103f6d70a0267a5657dce1402f41b4020bef98bbef9c7c1e", [:mix], [{:gen_stage, "~> 1.0", [hex: :gen_stage, repo: "hexpm", optional: false]}], "hexpm", "066f42f7a1ea6a86cb4ef763310338981a5cfb93bcebce10863a23a4859fd785"},
"gen_stage": {:hex, :gen_stage, "1.1.0", "dd0c0f8d2f3b993fdbd3d58e94abbe65380f4e78bdee3fa93d5618d7d14abe60", [:mix], [], "hexpm", "7f2b36a6d02f7ef2ba410733b540ec423af65ec9c99f3d1083da508aca3b9305"},
"makeup": {:hex, :makeup, "1.0.5", "d5a830bc42c9800ce07dd97fa94669dfb93d3bf5fcf6ea7a0c67b2e0e4a7f26c", [:mix], [{:nimble_parsec, "~> 0.5 or ~> 1.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cfa158c02d3f5c0c665d0af11512fed3fba0144cf1aadee0f2ce17747fba2ca9"},
"makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"},
"nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "6c836bfa617178efc0e703f7794e7e58440581b3", [sparse: "nx"]},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "bb6f2a811882921de684833e64154bc8084bd792", [sparse: "nx"]},
"pelemay_fp": {:hex, :pelemay_fp, "0.1.2", "6ddb0fb8e91c75f490b12ddb69fd3be77d54ce32e8ed31964a0b8e4e4cf05382", [:mix], [], "hexpm", "2d3f1ed87c0a275cd34abd7eadc5fd6bbaecfa75963308fc1d12f8bf40b4fcbe"},
"statistics": {:hex, :statistics, "0.4.1", "e9bfe6649f70842d9ebce69aa31ac8b6928e096bde43b683d7d673057237d028", [:mix], [], "hexpm", "726d8791e9bafb08b3ceeb5b08df6664f29a73a0e6ac0db835500b686a153bd5"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
}

0 comments on commit 65b9207

Please sign in to comment.