Skip to content

Commit

Permalink
feat(download): add an option to archive and download (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Yang committed Mar 11, 2021
1 parent 25fdd92 commit 3f87b20
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 60 deletions.
22 changes: 10 additions & 12 deletions app/codegen.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,39 @@
"""Code Generator base module.
"""
import shutil
from pathlib import Path

from jinja2 import Environment, FileSystemLoader


class CodeGenerator:
def __init__(self, templates_dir=None):
templates_dir = templates_dir or "./templates"
self.template_list = [p.stem for p in Path(templates_dir).iterdir() if p.is_dir()]
self.env = Environment(
loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True
)
self.env = Environment(loader=FileSystemLoader(templates_dir), trim_blocks=True, lstrip_blocks=True)

def render_templates(self, template_name: str, config: dict):
"""Renders all the templates from template folder for the given config.
"""
file_template_list = (
template
for template in self.env.list_templates(".jinja")
if template.startswith(template_name)
template for template in self.env.list_templates(".jinja") if template.startswith(template_name)
)
for fname in file_template_list:
# Get template
template = self.env.get_template(fname)
# Render template
code = template.render(**config)
# Write python file
fname = fname.strip(f"{template_name}/").strip(".jinja")
fname = fname.replace(f"{template_name}/", "").replace(".jinja", "")
self.generate(template_name, fname, code)
yield fname, code

def generate(self, template_name: str, fname: str, code: str) -> None:
"""Generates `fname` with content `code` in `path`.
"""
path = Path(f"dist/{template_name}")
path.mkdir(parents=True, exist_ok=True)
(path / fname).write_text(code)
self.path = Path(f"./dist/{template_name}")
self.path.mkdir(parents=True, exist_ok=True)
(self.path / fname).write_text(code)

def make_archive(self):
raise NotImplementedError
def make_archive(self, format_):
return shutil.make_archive(base_name=str(self.path), format=format_, base_dir=self.path)
24 changes: 23 additions & 1 deletion app/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import streamlit as st
import shutil
from pathlib import Path

import streamlit as st
from codegen import CodeGenerator
from utils import import_from_file

Expand Down Expand Up @@ -60,9 +62,29 @@ def add_content(self):
for fname, code in content:
self.render_code(fname, code, fold)

def add_download(self):
st.markdown("")
format_ = st.radio(
"Archive format",
[name for name, _ in sorted(shutil.get_archive_formats(), key=lambda x: x[0], reverse=True)],
)
# temporary hack until streamlit has official download option
# https://github.com/streamlit/streamlit/issues/400
# https://github.com/streamlit/streamlit/issues/400#issuecomment-648580840
if st.button("Generate an archive"):
archive_fname = self.codegen.make_archive(format_)
# this is where streamlit serves static files
# ~/site-packages/streamlit/static/static/
dist_path = Path(st.__path__[0]) / "static/static/dist"
if not dist_path.is_dir():
dist_path.mkdir()
shutil.copy(archive_fname, dist_path)
st.success(f"Download link : [{archive_fname}](./static/{archive_fname})")

def run(self):
self.add_sidebar()
self.add_content()
self.add_download()


def main():
Expand Down
62 changes: 16 additions & 46 deletions templates/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,74 +4,44 @@
def get_configs() -> dict:
config = {}
with st.beta_expander("Training Configurations"):
st.info(
"Common base training configurations. Those in the parenthesis are used in the code."
)
st.info("Common base training configurations. Those in the parenthesis are used in the code.")

# group by streamlit function type
config["amp_mode"] = st.selectbox(
"AMP mode (amp_mode)", ("None", "amp", "apex")
)
config["device"] = st.selectbox(
"Device to use (device)", ("cpu", "cuda", "xla")
)
config["amp_mode"] = st.selectbox("AMP mode (amp_mode)", ("None", "amp", "apex"))
config["device"] = st.selectbox("Device to use (device)", ("cpu", "cuda", "xla"))

config["data_path"] = st.text_input("Dataset path (data_path)", value="./")
config["filepath"] = st.text_input(
"Logging file path (filepath)", value="./logs"
)
config["filepath"] = st.text_input("Logging file path (filepath)", value="./logs")

config["train_batch_size"] = st.number_input(
"Train batch size (train_batch_size)", min_value=1, value=1
)
config["eval_batch_size"] = st.number_input(
"Eval batch size (eval_batch_size)", min_value=1, value=1
)
config["num_workers"] = st.number_input(
"Number of workers (num_workers)", min_value=0, value=2
)
config["max_epochs"] = st.number_input(
"Maximum epochs to train (max_epochs)", min_value=1, value=2
)
config["train_batch_size"] = st.number_input("Train batch size (train_batch_size)", min_value=1, value=1)
config["eval_batch_size"] = st.number_input("Eval batch size (eval_batch_size)", min_value=1, value=1)
config["num_workers"] = st.number_input("Number of workers (num_workers)", min_value=0, value=2)
config["max_epochs"] = st.number_input("Maximum epochs to train (max_epochs)", min_value=1, value=2)
config["lr"] = st.number_input(
"Learning rate used by torch.optim.* (lr)",
min_value=0.0,
value=1e-3,
format="%e",
"Learning rate used by torch.optim.* (lr)", min_value=0.0, value=1e-3, format="%e",
)
config["log_train"] = st.number_input(
"Logging interval of training iterations (log_train)", min_value=0, value=50
)
config["log_eval"] = st.number_input(
"Logging interval of evaluation epoch (log_eval)", min_value=0, value=1
)
config["seed"] = st.number_input(
"Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666
)
config["log_eval"] = st.number_input("Logging interval of evaluation epoch (log_eval)", min_value=0, value=1)
config["seed"] = st.number_input("Seed used in ignite.utils.manual_seed() (seed)", min_value=0, value=666)
if st.checkbox("Use distributed training"):
config["nproc_per_node"] = st.number_input(
"Number of processes to launch on each node (nproc_per_node)",
min_value=1,
)
config["nnodes"] = st.number_input(
"Number of nodes to use for distributed training (nnodes)", min_value=1,
"Number of processes to launch on each node (nproc_per_node)", min_value=1,
)
config["nnodes"] = st.number_input("Number of nodes to use for distributed training (nnodes)", min_value=1,)
if config["nnodes"] > 1:
st.info(
"The following options are only supported by torch.distributed, namely 'gloo' and 'nccl' backends."
" For other backends, please specify spawn_kwargs in main.py"
)
config["node_rank"] = st.number_input(
"Rank of the node for multi-node distributed training (node_rank)",
min_value=0,
"Rank of the node for multi-node distributed training (node_rank)", min_value=0,
)
if config["node_rank"] > (config["nnodes"] - 1):
st.error(
f"node_rank should be between 0 and {config['nnodes'] - 1}"
)
st.error(f"node_rank should be between 0 and {config['nnodes'] - 1}")
config["master_addr"] = st.text_input(
"Master node TCP/IP address for torch native backends (master_addr)",
"'127.0.0.1'",
"Master node TCP/IP address for torch native backends (master_addr)", "'127.0.0.1'",
)
st.warning("Please include single quote in master_addr.")
config["master_port"] = st.text_input(
Expand Down
2 changes: 1 addition & 1 deletion templates/base/utils.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ DEFAULTS = {
"help": "rank of the node for multi-node distributed training ({{ node_rank }})",
},
"master_addr": {
"default": {{ master_addr|safe }},
"default": {{ master_addr }},
"type": str,
"help": "master node TCP/IP address for torch native backends ({{ master_addr }})",
},
Expand Down

0 comments on commit 3f87b20

Please sign in to comment.