
Commit
lreg
cgarciae committed Jan 15, 2021
1 parent 4da6c0f commit 951415b
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions examples/logistic_regression.py
@@ -23,17 +23,17 @@ def init(self, x):
# friendly RNG interface: rng.next() == jax.random.split(...)
rng = elegy.RNGSeq(42)

- # pred
+ # params
w = jax.random.uniform(rng.next(), shape=[d, 10], minval=-1, maxval=1)
b = jax.random.uniform(rng.next(), shape=[1], minval=-1, maxval=1)
net_params = (w, b)

- # test
+ # metrics
total_samples = jnp.array(0, dtype=jnp.float32)
total_tp = jnp.array(0, dtype=jnp.float32)
total_loss = jnp.array(0, dtype=jnp.float32)

- # train
+ # optimizer
optimizer_states = self.optimizer.init(net_params)

return elegy.States(
@@ -44,14 +44,15 @@ def init(self, x):
)

def train_step(self, x, y_true, net_params, metrics_states, optimizer_states):
+ # flatten + scale
x = jnp.reshape(x, (x.shape[0], -1)) / 255

def loss_fn(net_params, x: jnp.ndarray, y_true):
# model
w, b = net_params
logits = jnp.dot(x, w) + b

- # binary crossentropy loss
+ # crossentropy loss
labels = jax.nn.one_hot(y_true, 10)
sample_loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.mean(sample_loss)
@@ -68,7 +69,7 @@ def loss_fn(net_params, x: jnp.ndarray, y_true):
)
net_params = optax.apply_updates(net_params, grads)

- # metrics
+ # cumulative metrics
sample_acc = (jnp.argmax(logits, axis=-1) == y_true).astype(jnp.int32)

total_samples, total_tp, total_loss = metrics_states
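
Read together, the renamed comments trace the shape of this example: uniform parameter init, a flatten-and-scale preprocessing step, a 10-class softmax cross-entropy loss, an optax update, and cumulative metrics. A minimal standalone sketch of that training step in plain JAX/optax follows; the optax.adam choice, the has_aux plumbing, and the running-average bookkeeping are assumptions filled in around what the diff shows, not Elegy's actual implementation.

    import jax
    import jax.numpy as jnp
    import optax

    optimizer = optax.adam(1e-3)  # assumed optimizer; the diff does not show which one the example uses

    def init_params(key, d, n_classes=10):
        # the example's rng.next() wraps jax.random.split; here we split a raw PRNG key directly
        w_key, b_key = jax.random.split(key)
        w = jax.random.uniform(w_key, shape=[d, n_classes], minval=-1, maxval=1)
        b = jax.random.uniform(b_key, shape=[1], minval=-1, maxval=1)
        return (w, b)

    def loss_fn(net_params, x, y_true):
        w, b = net_params
        logits = jnp.dot(x, w) + b
        # categorical (softmax) cross-entropy over the 10 classes
        labels = jax.nn.one_hot(y_true, 10)
        sample_loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
        return jnp.mean(sample_loss), logits

    def train_step(x, y_true, net_params, metrics_states, optimizer_states):
        # flatten images and scale pixel values to [0, 1]
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # loss, gradients, and parameter update
        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(net_params, x, y_true)
        updates, optimizer_states = optimizer.update(grads, optimizer_states, net_params)
        net_params = optax.apply_updates(net_params, updates)

        # cumulative metrics: running sample count, correct predictions, and summed loss
        sample_acc = (jnp.argmax(logits, axis=-1) == y_true).astype(jnp.float32)
        total_samples, total_tp, total_loss = metrics_states
        total_samples += x.shape[0]
        total_tp += jnp.sum(sample_acc)
        total_loss += loss * x.shape[0]

        return net_params, optimizer_states, (total_samples, total_tp, total_loss)

Driving one step with this sketch (x_batch and y_batch stand in for one batch of MNIST-style data):

    params = init_params(jax.random.PRNGKey(42), d=784)
    opt_state = optimizer.init(params)
    metrics = (jnp.array(0.0, jnp.float32), jnp.array(0.0, jnp.float32), jnp.array(0.0, jnp.float32))
    params, opt_state, metrics = train_step(x_batch, y_batch, params, metrics, opt_state)
    accuracy, mean_loss = metrics[1] / metrics[0], metrics[2] / metrics[0]
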
4 changes: 2 additions & 2 deletions examples/logistic_regression_minimalist.py
@@ -19,7 +19,7 @@ def init(self, x):
# friendly RNG interface: rng.next() == jax.random.split(...)
rng = elegy.RNGSeq(42)

- # pred
+ # params
w = jax.random.uniform(rng.next(), shape=[d, 10], minval=-1, maxval=1)
b = jax.random.uniform(rng.next(), shape=[1], minval=-1, maxval=1)

@@ -37,7 +37,7 @@ def loss_fn(net_params, x: jnp.ndarray, y_true):
w, b = net_params
logits = jnp.dot(x, w) + b

- # binary crossentropy loss
+ # crossentropy loss
labels = jax.nn.one_hot(y_true, 10)
loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))

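
Both files get the same comment rename because the loss is a 10-class categorical (softmax) cross-entropy, not a binary one. As a quick sanity check (a sketch with made-up logits and labels), the hand-written expression matches optax's reference implementation:

    import jax
    import jax.numpy as jnp
    import optax

    logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
    y_true = jnp.array([3, 1, 7, 0])
    labels = jax.nn.one_hot(y_true, 10)

    manual = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    reference = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=labels))

    assert jnp.allclose(manual, reference)
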
