Skip to content

Commit

Permalink
split pytorch dependence
Browse files Browse the repository at this point in the history
  • Loading branch information
TsumiNa committed Jan 27, 2018
1 parent c452fa1 commit 4d4cc09
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 27 deletions.
14 changes: 8 additions & 6 deletions xenonpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# license that can be found in the LICENSE file.

__version__ = '0.1.0'
__release__ = 'b6'
__release__ = 'b7'
__short_description__ = "material descriptor library"
__license__ = "BSD (3-clause)"
__author__ = "TsumiNa"
Expand All @@ -12,7 +12,9 @@
__maintainer_email__ = "liu.chang.1865@gmail.com"
__github_username__ = "yoshida-lab"

__pkg_cfg_root__ = '.' + __name__
cfg_root = '.' + __name__

dataset_ver = 'v0.1.0b6'


def get_conf(key: str = None):
Expand All @@ -33,7 +35,7 @@ def get_conf(key: str = None):
import yaml
from pathlib import Path
home = Path.home()
dir_ = home / __pkg_cfg_root__
dir_ = home / cfg_root
cfg_file = dir_ / 'conf.yml'
with open(cfg_file) as f:
conf = yaml.load(f)
Expand Down Expand Up @@ -61,7 +63,7 @@ def _get_dataset_url(fname: str):
return 'https://github.com/' + \
__github_username__ + '/' + \
__name__ + '/releases/download/' + \
__version__ + __release__ + '/' + \
dataset_ver + '/' + \
fname + '.pkl'


Expand All @@ -80,7 +82,7 @@ def _init_cfg_file(force=False):
from shutil import rmtree, copyfile
from pathlib import Path
home = Path.home()
dir_ = home / __pkg_cfg_root__
dir_ = home / cfg_root
cfg_file = dir_ / 'conf.yml'

dataset_dir = dir_ / 'dataset'
Expand All @@ -107,8 +109,8 @@ def _init_cfg_file(force=False):
_init_cfg_file()

from . import descriptor
from . import model
# from .pipeline import *
# from .preprocess import *
from . import utils
from . import visualization
from . import model
5 changes: 3 additions & 2 deletions xenonpy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

from xenonpy.model.nn.checkpoint import CheckPoint
from .nn import *
from . import nn

from . import extern
13 changes: 11 additions & 2 deletions xenonpy/model/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,14 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

from .base_model import *
from .layer import *
import warnings as warnings

with warnings.catch_warnings():
# warnings.simplefilter('default')
try:
import torch
except ImportError:
warnings.warn("Can't fing pytorch, will not load Neorul Network modules.", RuntimeWarning)
else:
from .base_model import *
from .layer import *
18 changes: 9 additions & 9 deletions xenonpy/utils/datatools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import requests
from sklearn.externals import joblib as jl

from .. import __pkg_cfg_root__
from .. import cfg_root
from .. import _get_dataset_url


Expand Down Expand Up @@ -155,19 +155,19 @@ def __init__(self,
'electron_density', 'sample_A', 'mp_structure'
]
self.chunk_size = chunk_size
self.dataset_dir = Path().home() / __pkg_cfg_root__ / 'dataset'
self.user_data_dir = Path().home() / __pkg_cfg_root__ / 'userdata'
self.cached_dir = Path().home() / __pkg_cfg_root__ / 'cached'
self.dataset_dir = Path().home() / cfg_root / 'dataset'
self.user_data_dir = Path().home() / cfg_root / 'userdata'
self.cached_dir = Path().home() / cfg_root / 'cached'

def _fetch_data(self, url):
def _fetch_data(self, url, save_to=None):
schemes = {'http', 'https'}
scheme = urlparse(url).scheme
if 'http' in scheme:
return self._http_data(url)
return self._http_data(url, save_to)
else:
raise ValueError("Only can access [{}] data but you send {}. :(".format(schemes, scheme))

def _http_data(self, url):
def _http_data(self, url, save_to=None):
r = requests.get(url, stream=True)
r.raise_for_status()

Expand All @@ -176,7 +176,7 @@ def _http_data(self, url):
else:
filename = url.split('/')[-1]

save_to = str(self.cached_dir / filename)
save_to = str(self.cached_dir / filename) if not save_to else str(save_to)
with open(save_to, 'wb') as f:
for chunk in r.iter_content(chunk_size=self.chunk_size):
if chunk: # filter out keep-alive new chunks
Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(self, dataset=None):
The dir to save and load data
"""
self.dataset = dataset
self._path = Path.home() / __pkg_cfg_root__ / 'userdata' / dataset
self._path = Path.home() / cfg_root / 'userdata' / dataset
if not self._path.exists():
self._path.mkdir(parents=True)
self._files = None
Expand Down
20 changes: 12 additions & 8 deletions xenonpy/visualization/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def fit(self, desc):
self.desc = pd.DataFrame(minmax_scale(desc_), index=desc.index, columns=desc.columns)
return self

def draw(self, y):
def draw(self, y=None):
ax = sb.clustermap(
self.desc,
cmap="RdBu",
Expand All @@ -64,15 +64,19 @@ def draw(self, y):
col_cluster=self.col_cluster,
**self.kwargs)
ax.cax.set_visible(False)
ax.ax_heatmap.set_position((0.1, 0.2, 0.84, 0.6))
ax.ax_heatmap.yaxis.set_ticks_position('left')
ax.ax_heatmap.yaxis.set_label_position('left')
ax.ax_col_dendrogram.set_position((0.1, 0.8, 0.83, 0.1))

ax = plt.axes([0.95, 0.2, 0.05, 0.6])
ax.plot(y.values, lw=4)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('{:s}\n(property)'.format(y.name), fontsize=20)
if y is None:
ax.ax_col_dendrogram.set_position((0.1, 0.8, 0.9, 0.1))
ax.ax_heatmap.set_position((0.1, 0.2, 0.9, 0.6))
else:
ax.ax_col_dendrogram.set_position((0.1, 0.8, 0.83, 0.1))
ax.ax_heatmap.set_position((0.1, 0.2, 0.84, 0.6))
ax = plt.axes([0.95, 0.2, 0.05, 0.6])
ax.plot(y.values, lw=4)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('{:s}\n(property)'.format(y.name), fontsize=20)
if self.save:
plt.savefig(**self.save)

0 comments on commit 4d4cc09

Please sign in to comment.