Skip to content

Commit

Permalink
Merge pull request #136 from chummels/windows
Browse files Browse the repository at this point in the history
Fixing bug preventing Trident from running on Windows
  • Loading branch information
chummels committed Jun 24, 2020
2 parents 25e6dee + bbad8c1 commit 3a49b1d
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 8 deletions.
34 changes: 34 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Tests for Config Code
"""

#-----------------------------------------------------------------------------
# Copyright (c) 2016, Trident Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#-----------------------------------------------------------------------------

from trident.config import \
trident, \
trident_path, \
create_config, \
parse_config

def test_banner():
"""
Tests running the banner display
"""
trident()

def test_path():
"""
Tests that the trident path is working ok.
"""
trident_path()

# Need to make a test_config but this necessitates changing the config
# code around a bunch.

98 changes: 98 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Tests for Utilities Code
"""

#-----------------------------------------------------------------------------
# Copyright (c) 2016, Trident Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#-----------------------------------------------------------------------------

import os
import tempfile
import filecmp
from trident.utilities import \
ensure_directory, \
gunzip_file, \
gzip_file, \
make_onezone_dataset, \
make_onezone_ray
from trident.spectrum_generator import \
SpectrumGenerator
from trident.ray_generator import \
make_simple_ray

def test_ensure_directory():
"""
Tests ensure_directory code, which ensures a path exists and if not,
it creates such a path
"""
tempdir = tempfile.mkdtemp()
ensure_directory(tempdir)
assert os.path.exists(tempdir)

# is new random path created?
subdir = os.path.join(tempdir, 'random')
assert not os.path.exists(subdir)
ensure_directory(subdir)
assert os.path.exists(subdir)

def test_gzip_unzip_file():
"""
Adds a test to ensure gzip and gunzip code is functionity properly by
creating a random file, zipping it, unzipping it, and checking if the
output is the same as the original file.
"""
# Create a temporary directory
tmpdir = tempfile.mkdtemp()

# Create a temporary file
fp = tempfile.NamedTemporaryFile()
fp.write(b'Hello world!')
filepath = fp.name
directory, filename = os.path.split(filepath)
gzip_filename = os.path.join(tmpdir, filename+'.gz')
gunzip_filename = os.path.join(tmpdir, filename)

# Gzip it
gzip_file(filepath, out_filename=gzip_filename, cleanup=False)

# Gunzip it
gunzip_file(gzip_filename, out_filename=gunzip_filename, cleanup=False)

# Ensure contents of gunzipped file are the same as original
filecmp.cmp(filepath, gunzip_filename)

def test_make_onezone_ray():
"""
Tests the make_onezone_ray infrastructure by creating a ray and making a
spectrum from it.
"""
dirpath = tempfile.mkdtemp()
ray_filename = os.path.join(dirpath, 'ray.h5')
image_filename = os.path.join(dirpath, 'spec.png')
ray = make_onezone_ray(column_densities={'H_p0_number_density':1e21},
filename=ray_filename)
sg_final = SpectrumGenerator(lambda_min=1200, lambda_max=1300, dlambda=0.5)
sg_final.make_spectrum(ray, lines=['Ly a'])
sg_final.plot_spectrum(image_filename)

def test_make_onezone_dataset():
"""
Tests the make_onezone_dataset infrastructure by generating a one_zone
dataset and then creating a ray and spectrum from it.
"""
dirpath = tempfile.mkdtemp()
ray_filename = os.path.join(dirpath, 'ray.h5')
image_filename = os.path.join(dirpath, 'spec.png')
ds = make_onezone_dataset()
ray = make_simple_ray(ds, start_position=ds.domain_left_edge,
end_position=ds.domain_right_edge,
fields=['density', 'temperature', 'metallicity'],
data_filename=ray_filename)
sg = SpectrumGenerator('COS')
sg.make_spectrum(ray)
sg.plot_spectrum(image_filename)
9 changes: 5 additions & 4 deletions trident/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import os
from configparser import \
ConfigParser
ConfigParser, \
NoSectionError
import shutil
import tempfile
import sys
Expand Down Expand Up @@ -64,7 +65,7 @@ def create_config():
datafile from the web. It does this using user interaction from the
python prompt.
"""
default_dir = os.path.expanduser('~/.trident')
default_dir = os.path.expanduser(os.path.join('~', '.trident'))
trident()
print("It appears that this is your first time using Trident. To finalize your")
print("Trident installation, you must:")
Expand Down Expand Up @@ -110,7 +111,7 @@ def create_config():
config.add_section('Trident')
config.set('Trident', 'ion_table_dir', datadir)
config.set('Trident', 'ion_table_file', datafile)
config_filename = os.path.expanduser('~/.trident/config.tri')
config_filename = os.path.expanduser(os.path.join('~', '.trident', 'config.tri'))
with open(config_filename, 'w') as configfile:
config.write(configfile)

Expand Down Expand Up @@ -154,7 +155,7 @@ def parse_config(variable=None):
parser.read(config_filename)
ion_table_dir = parser.get('Trident', 'ion_table_dir')
ion_table_file = parser.get('Trident', 'ion_table_file')
except BaseException:
except NoSectionError:
config_filename = create_config()
parser = ConfigParser()
parser.read(config_filename)
Expand Down
6 changes: 2 additions & 4 deletions trident/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

import gzip
import os
from os.path import \
expanduser
import requests
import tempfile
import shutil
Expand Down Expand Up @@ -88,7 +86,7 @@ def download_file(url, progress_bar=True, local_directory=None,

# Set defaults
if local_filename is None:
local_filename = url.split(os.sep)[-1]
local_filename = os.path.basename(url)
if local_directory is None:
local_directory = '.'
ensure_directory(local_directory)
Expand Down Expand Up @@ -210,7 +208,7 @@ def get_datafiles(datadir=None, url=None):
>>> get_datafiles()
"""
if datadir is None:
datadir = expanduser('~/.trident')
datadir = os.path.expanduser(os.path.join('~','.trident'))
ensure_directory(datadir)

# ion table datafiles are stored here
Expand Down

0 comments on commit 3a49b1d

Please sign in to comment.