Skip to content

Commit

Permalink
Merge branch 'master' into low-level-guide
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Feb 1, 2021
2 parents e2ba648 + a8b917b commit d1f9048
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
8 changes: 4 additions & 4 deletions elegy/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@

PRETRAINED_URLS = {
"ResNet18": {
"url": "https://github.com/poets-ai/elegy-assets/releases/download/resnet18_rev0/ResNet18_ImageNet.pkl",
"sha256": "4397cd02b56a29825243341204710daa1de9f3d6ad776558e61b34690896aaaa",
"url": "https://github.com/poets-ai/elegy-assets/releases/download/resnet18_rev1/ResNet18_ImageNet_rev1.pkl",
"sha256": "02824ae2f29563add46feff14f40c362ae5f9af3f01ea2edc0812e5ca06ca9ae",
},
"ResNet50": {
"url": "https://github.com/poets-ai/elegy-assets/releases/download/resnet50_rev0/ResNet50_ImageNet.pkl",
"sha256": "aadeb068ee6b5e114bc1902159e592c5170a27a661fb3a3d7c463607b25f1381",
"url": "https://github.com/poets-ai/elegy-assets/releases/download/resnet50_rev1/ResNet50_ImageNet_rev1.pkl",
"sha256": "c69086813ccff6b67b2452daabdf64772f8a7f5c04591e1962185129e18989fc",
},
}

Expand Down
23 changes: 22 additions & 1 deletion elegy/nets/resnet_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from elegy import utils

import jax.numpy as jnp
import jax, jax.numpy as jnp
import numpy as np
from unittest import TestCase
import tempfile, os, pickle
import PIL, urllib

import elegy

Expand Down Expand Up @@ -36,3 +37,23 @@ def test_basic_predict(self):
y2 = elegy.Model(new_r18, run_eagerly=True).predict(x)

assert np.allclose(y, y2, rtol=0.001)

def test_autodownload_pretrained_r18(self):
fname, _ = urllib.request.urlretrieve(
"https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg"
)
im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255)

r18 = elegy.nets.resnet.ResNet18(weights="imagenet")
with jax.disable_jit():
assert elegy.Model(r18).predict(im[np.newaxis]).argmax() == 245

def test_autodownload_pretrained_r50(self):
fname, _ = urllib.request.urlretrieve(
"https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg"
)
im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255)

r50 = elegy.nets.resnet.ResNet50(weights="imagenet")
with jax.disable_jit():
assert elegy.Model(r50).predict(im[np.newaxis]).argmax() == 245

0 comments on commit d1f9048

Please sign in to comment.