Skip to content

Commit

Permalink
fix export of export_generator
Browse files Browse the repository at this point in the history
  • Loading branch information
tmabraham committed Sep 13, 2020
1 parent 45e6dd4 commit 020f8e2
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 23 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# UPIT-specific
*.tif
*.pth

ignore/
test/

*.bak
.gitattributes
Expand Down
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,3 @@ cycle_gan = CycleGAN(3,3,64)
learn = cycle_learner(dls, cycle_gan,opt_func=partial(Adam,mom=0.5,sqr_mom=0.999))
learn.fit_flat_lin(100,100,2e-4)
```

## Demo web app

In the examples/web_app folder, code for a Heroku demo web app (which can be deployed for free) is provided.
A Heroku web app runs at:
https://upit-cyclegan.herokuapp.com
33 changes: 20 additions & 13 deletions docs/inference.cyclegan.html
Original file line number Diff line number Diff line change
Expand Up @@ -294,21 +294,28 @@ <h1 id="Exporting-the-Generator">Exporting the Generator<a class="anchor-link" h
{% raw %}

<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="k">def</span> <span class="nf">export_generator</span><span class="p">(</span><span class="n">learn</span><span class="p">,</span> <span class="n">generator_name</span><span class="o">=</span><span class="s1">&#39;generator&#39;</span><span class="p">,</span><span class="n">path</span><span class="o">=</span><span class="n">Path</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">),</span><span class="n">convert_to</span><span class="o">=</span><span class="s1">&#39;B&#39;</span><span class="p">):</span>
<span class="k">if</span> <span class="n">convert_to</span><span class="o">==</span><span class="s1">&#39;B&#39;</span><span class="p">:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">learn</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">G_B</span>
<span class="k">elif</span> <span class="n">convert_to</span><span class="o">==</span><span class="s1">&#39;A&#39;</span><span class="p">:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">learn</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">G_A</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;convert_to must be &#39;A&#39; or &#39;B&#39; (generator that converts either from A to B or B to A)&quot;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span><span class="n">path</span><span class="o">/</span><span class="p">(</span><span class="n">generator_name</span><span class="o">+</span><span class="s1">&#39;.pth&#39;</span><span class="p">))</span>
</pre></div>
</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="export_generator" class="doc_header"><code>export_generator</code><a href="https://github.com/tmabraham/UPIT/tree/master/upit/inference/cyclegan.py#L85" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>export_generator</code>(<strong><code>learn</code></strong>, <strong><code>generator_name</code></strong>=<em><code>'generator'</code></em>, <strong><code>path</code></strong>=<em><code>Path('.')</code></em>, <strong><code>convert_to</code></strong>=<em><code>'B'</code></em>)</p>
</blockquote>

</div>

</div>

</div>
</div>
</div>

Expand Down
1 change: 1 addition & 0 deletions nbs/04_inference.cyclegan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def export_generator(learn, generator_name='generator',path=Path('.'),convert_to='B'):\n",
" if convert_to=='B':\n",
" model = learn.model.G_B\n",
Expand Down
3 changes: 2 additions & 1 deletion upit/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"cycle_learner": "03_train.cyclegan.ipynb",
"FolderDataset": "04_inference.cyclegan.ipynb",
"load_dataset": "04_inference.cyclegan.ipynb",
"get_preds_cyclegan": "04_inference.cyclegan.ipynb"}
"get_preds_cyclegan": "04_inference.cyclegan.ipynb",
"export_generator": "04_inference.cyclegan.ipynb"}

modules = ["models/cyclegan.py",
"models/junyanz.py",
Expand Down
14 changes: 12 additions & 2 deletions upit/inference/cyclegan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04_inference.cyclegan.ipynb (unless otherwise specified).

__all__ = ['FolderDataset', 'load_dataset', 'get_preds_cyclegan']
__all__ = ['FolderDataset', 'load_dataset', 'get_preds_cyclegan', 'export_generator']

# Cell
from ..models.cyclegan import *
Expand Down Expand Up @@ -79,4 +79,14 @@ def get_preds_cyclegan(learn,test_path,pred_path,bs=4,num_workers=4,suffix='tif'
preds = (model(im.cuda())/2 + 0.5)
for i in range(len(fn)):
new_fn = os.path.join(pred_path,'.'.join([os.path.basename(fn[i]).split('.')[0]+'_fakeB',suffix]))
torchvision.utils.save_image(preds[i],new_fn)
torchvision.utils.save_image(preds[i],new_fn)

# Cell
def export_generator(learn, generator_name='generator',path=Path('.'),convert_to='B'):
if convert_to=='B':
model = learn.model.G_B
elif convert_to=='A':
model = learn.model.G_A
else:
raise ValueError("convert_to must be 'A' or 'B' (generator that converts either from A to B or B to A)")
torch.save(model.state_dict(),path/(generator_name+'.pth'))

0 comments on commit 020f8e2

Please sign in to comment.