Skip to content

Commit

Permalink
Merge branch 'main' into polish_filetable
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber committed May 23, 2023
2 parents c51bae6 + 0fe1191 commit fee694c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .ci_support/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ dependencies:
- pyfileindex =0.0.11
- pysqa =0.0.24
- pytables =3.8.0
- sqlalchemy =2.0.13
- sqlalchemy =2.0.15
- tqdm =4.65.0
- traitlets =5.9.0
116 changes: 72 additions & 44 deletions pyiron_base/storage/filedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,62 @@
)


def _load_txt(file):
if isinstance(file, str):
with open(file, encoding="utf8") as f:
return f.readlines()
else:
return file.readlines()


def _load_json(file):
if isinstance(file, str):
with open(file) as f:
return json.load(f)
else:
return json.load(file)


class FileLoader:
_file_types = {
".json": _load_json,
".txt": _load_txt,
".csv": pandas.read_csv,
}
default_assumed_file_type = ".txt"

@classmethod
def register(cls, file_type, load_callable):
"""Register a load function for a specific file type.
Args:
file_type(str): File extension to be registered, e.g. '.txt', '.csv'
load_callable(callable): function excepting a file or file-handle, returning an appropriate object for
this file type.
"""
cls._file_types[file_type] = load_callable

def load(self, file_type, file, *args, **kwargs):
if file_type in self._file_types:
return self._file_types[file_type](file, *args, **kwargs)
else:
return self._load_default(file, *args, **kwargs)

def _load_default(self, file, *args, **kwargs):
try:
return self._file_types[self.default_assumed_file_type](
file, *args, **kwargs
)
except Exception as e:
raise IOError("File could not be loaded.") from e


if _has_imported["PIL"]:
for ext in Image.registered_extensions():
FileLoader.register(ext, Image.open)


if _has_imported["nbformat"]:

class OwnNotebookNode(nbformat.NotebookNode):
Expand All @@ -65,6 +121,13 @@ def _repr_html_(self):
(html_output, _) = html_exporter.from_notebook_node(self)
return html_output

def _load_ipynb(file):
return OwnNotebookNode(nbformat.read(file, as_version=4))

FileLoader.register(".ipynb", _load_ipynb)

_file_loader = FileLoader()


@import_alarm
def load_file(fp, filetype=None, project=None):
Expand All @@ -85,42 +148,13 @@ def load_file(fp, filetype=None, project=None):
Image extensions supported by PIL
Returns:
:class:`FileHDFio`: pointing to the file of filetype = '.h5'
dict: containing data from file of filetype = '.json'
:class:`FileHDFio`/:class:`ProjectHDFio`: pointing to the file of filetype = '.h5'
dict/list: containing data from file of filetype = '.json'
list: of all lines from file for filetype = '.txt'
:class:`pandas.DataFrame`: containing data from file of filetype = '.csv'
"""

def _load_txt(file):
if isinstance(file, str):
with open(file, encoding="utf8") as f:
return f.readlines()
else:
return file.readlines()

def _load_ipynb(file):
return OwnNotebookNode(nbformat.read(file, as_version=4))

def _load_json(file):
if isinstance(file, str):
with open(file) as f:
return json.load(f)
else:
return json.load(file)

def _load_csv(file):
return pandas.read_csv(file)

def _load_img(file):
return Image.open(file)

def _load_default(file):
try:
return _load_txt(file)
except Exception as e:
raise IOError("File could not be loaded.") from e

def _resolve_filetype(file, _filetype):
if _filetype is None and isinstance(file, str):
_, _filetype = os.path.splitext(file)
Expand All @@ -139,18 +173,8 @@ def _resolve_filetype(file, _filetype):
return FileHDFio(file_name=fp)
else:
return ProjectHDFio(file_name=fp, project=project)
elif filetype in [".json"]:
return _load_json(fp)
elif filetype in [".txt"]:
return _load_txt(fp)
elif filetype in [".csv"]:
return _load_csv(fp)
elif _has_imported["nbformat"] and filetype in [".ipynb"]:
return _load_ipynb(fp)
elif _has_imported["PIL"] and filetype in Image.registered_extensions():
return _load_img(fp)
else:
return _load_default(fp)
return _file_loader.load(filetype, fp)


class FileDataTemplate(ABC):
Expand All @@ -164,7 +188,9 @@ def data(self):
class FileData(FileDataTemplate):
"""FileData stores an instance of a data file, e.g. a single Image from a measurement."""

def __init__(self, file, data=None, metadata=None, filetype=None):
def __init__(
self, file, data=None, metadata=None, filetype=None, pyiron_project=None
):
"""FileData class to store data and associated metadata.
Args:
Expand All @@ -173,7 +199,9 @@ def __init__(self, file, data=None, metadata=None, filetype=None):
metadata (dict/DataContainer): Dictionary of metadata associated with the data
filetype (str): File extension associated with the type data,
If provided this overwrites the assumption based on the extension of the filename.
pyiron_project(Project): Project this file belongs to, if any, used to load files with project awareness.
"""
self._project = pyiron_project
if data is None:
self.filename = os.path.split(file)[1]
self.source = file
Expand Down Expand Up @@ -203,4 +231,4 @@ def data(self):
if self._hasdata:
return self._data
else:
return load_file(self.source, filetype=self.filetype)
return load_file(self.source, filetype=self.filetype, project=self._project)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'psutil==5.9.5',
'pyfileindex==0.0.11',
'pysqa==0.0.24',
'sqlalchemy==2.0.13',
'sqlalchemy==2.0.15',
'tables==3.8.0',
'tqdm==4.65.0',
'traitlets==5.9.0',
Expand Down

0 comments on commit fee694c

Please sign in to comment.