diff --git a/scipy/odr/odrpack.py b/scipy/odr/odrpack.py index 878327d0f90b..45fb08c91dcf 100644 --- a/scipy/odr/odrpack.py +++ b/scipy/odr/odrpack.py @@ -35,6 +35,7 @@ robert.kern@gmail.com """ +import os import numpy from warnings import warn @@ -656,11 +657,13 @@ class ODR(object): ODRPACK User's Guide if you absolutely must set the value here. Use the method set_iprint post-initialization for a more readable interface. errfile : str, optional - string with the filename to print ODRPACK errors to. *Do Not Open - This File Yourself!* + string with the filename to print ODRPACK errors to. If the file already + exists, an error will be thrown. The `overwrite` argument can be used to + prevent this. *Do Not Open This File Yourself!* rptfile : str, optional - string with the filename to print ODRPACK summaries to. *Do Not - Open This File Yourself!* + string with the filename to print ODRPACK summaries to. If the file + already exists, an error will be thrown. The `overwrite` argument can be + used to prevent this. *Do Not Open This File Yourself!* ndigit : int, optional integer specifying the number of reliable digits in the computation of the function. @@ -709,6 +712,9 @@ class ODR(object): iwork : ndarray, optional array to hold the integer-valued working data for ODRPACK. When restarting, takes the value of self.output.iwork. + overwrite : bool, optional + If it is True, output files defined by `errfile` and `rptfile` are + overwritten. The default is False. Attributes ---------- @@ -725,7 +731,8 @@ class ODR(object): def __init__(self, data, model, beta0=None, delta0=None, ifixb=None, ifixx=None, job=None, iprint=None, errfile=None, rptfile=None, ndigit=None, taufac=None, sstol=None, partol=None, maxit=None, - stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None): + stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None, + overwrite=False): self.data = data self.model = model @@ -743,6 +750,13 @@ def __init__(self, data, model, beta0=None, delta0=None, ifixb=None, if ifixx is None and data.fix is not None: ifixx = data.fix + if overwrite: + # remove output files for overwriting. + if rptfile is not None and os.path.exists(rptfile): + os.remove(rptfile) + if errfile is not None and os.path.exists(errfile): + os.remove(errfile) + self.delta0 = _conv(delta0) # These really are 32-bit integers in FORTRAN (gfortran), even on 64-bit # platforms. diff --git a/scipy/odr/tests/test_odr.py b/scipy/odr/tests/test_odr.py index f58f0190e7b9..75a4b942701a 100644 --- a/scipy/odr/tests/test_odr.py +++ b/scipy/odr/tests/test_odr.py @@ -1,4 +1,7 @@ # SciPy imports. +import tempfile +import shutil +import os import numpy as np from numpy import pi from numpy.testing import (assert_array_almost_equal, @@ -495,3 +498,25 @@ def func(par, x): sd_ind = out.work_ind['sd'] assert_array_almost_equal(out.sd_beta, out.work[sd_ind:sd_ind + len(out.sd_beta)]) + + def test_output_file_overwrite(self): + """ + Verify fix for gh-1892 + """ + def func(b, x): + return b[0] + b[1] * x + + p = Model(func) + data = Data(np.arange(10), 12 * np.arange(10)) + tmp_dir = tempfile.mkdtemp() + error_file_path = os.path.join(tmp_dir, "error.dat") + report_file_path = os.path.join(tmp_dir, "report.dat") + try: + ODR(data, p, beta0=[0.1, 13], errfile=error_file_path, + rptfile=report_file_path).run() + ODR(data, p, beta0=[0.1, 13], errfile=error_file_path, + rptfile=report_file_path, overwrite=True).run() + finally: + # remove output files for clean up + shutil.rmtree(tmp_dir) +