Skip to content

Commit

Permalink
Merge pull request #12204 from AtsushiSakai/issue_1367
Browse files Browse the repository at this point in the history
ENH: Add overwrite argument for odr.ODR() and its test.
  • Loading branch information
rgommers committed Jul 18, 2020
2 parents 0b77a6d + 57e00e2 commit 99fb387
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
24 changes: 19 additions & 5 deletions scipy/odr/odrpack.py
Expand Up @@ -35,6 +35,7 @@
robert.kern@gmail.com
"""
import os

import numpy
from warnings import warn
Expand Down Expand Up @@ -656,11 +657,13 @@ class ODR(object):
ODRPACK User's Guide if you absolutely must set the value here. Use the
method set_iprint post-initialization for a more readable interface.
errfile : str, optional
string with the filename to print ODRPACK errors to. *Do Not Open
This File Yourself!*
string with the filename to print ODRPACK errors to. If the file already
exists, an error will be thrown. The `overwrite` argument can be used to
prevent this. *Do Not Open This File Yourself!*
rptfile : str, optional
string with the filename to print ODRPACK summaries to. *Do Not
Open This File Yourself!*
string with the filename to print ODRPACK summaries to. If the file
already exists, an error will be thrown. The `overwrite` argument can be
used to prevent this. *Do Not Open This File Yourself!*
ndigit : int, optional
integer specifying the number of reliable digits in the computation
of the function.
Expand Down Expand Up @@ -709,6 +712,9 @@ class ODR(object):
iwork : ndarray, optional
array to hold the integer-valued working data for ODRPACK. When
restarting, takes the value of self.output.iwork.
overwrite : bool, optional
If it is True, output files defined by `errfile` and `rptfile` are
overwritten. The default is False.
Attributes
----------
Expand All @@ -725,7 +731,8 @@ class ODR(object):
def __init__(self, data, model, beta0=None, delta0=None, ifixb=None,
ifixx=None, job=None, iprint=None, errfile=None, rptfile=None,
ndigit=None, taufac=None, sstol=None, partol=None, maxit=None,
stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None):
stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None,
overwrite=False):

self.data = data
self.model = model
Expand All @@ -743,6 +750,13 @@ def __init__(self, data, model, beta0=None, delta0=None, ifixb=None,
if ifixx is None and data.fix is not None:
ifixx = data.fix

if overwrite:
# remove output files for overwriting.
if rptfile is not None and os.path.exists(rptfile):
os.remove(rptfile)
if errfile is not None and os.path.exists(errfile):
os.remove(errfile)

self.delta0 = _conv(delta0)
# These really are 32-bit integers in FORTRAN (gfortran), even on 64-bit
# platforms.
Expand Down
25 changes: 25 additions & 0 deletions scipy/odr/tests/test_odr.py
@@ -1,4 +1,7 @@
# SciPy imports.
import tempfile
import shutil
import os
import numpy as np
from numpy import pi
from numpy.testing import (assert_array_almost_equal,
Expand Down Expand Up @@ -495,3 +498,25 @@ def func(par, x):
sd_ind = out.work_ind['sd']
assert_array_almost_equal(out.sd_beta,
out.work[sd_ind:sd_ind + len(out.sd_beta)])

def test_output_file_overwrite(self):
"""
Verify fix for gh-1892
"""
def func(b, x):
return b[0] + b[1] * x

p = Model(func)
data = Data(np.arange(10), 12 * np.arange(10))
tmp_dir = tempfile.mkdtemp()
error_file_path = os.path.join(tmp_dir, "error.dat")
report_file_path = os.path.join(tmp_dir, "report.dat")
try:
ODR(data, p, beta0=[0.1, 13], errfile=error_file_path,
rptfile=report_file_path).run()
ODR(data, p, beta0=[0.1, 13], errfile=error_file_path,
rptfile=report_file_path, overwrite=True).run()
finally:
# remove output files for clean up
shutil.rmtree(tmp_dir)

0 comments on commit 99fb387

Please sign in to comment.