Skip to content

Commit

Permalink
add overwrite argument for odr.ODR() to fix gh-1367 and its test.
Browse files Browse the repository at this point in the history
  • Loading branch information
AtsushiSakai committed May 24, 2020
1 parent 40bf81b commit 47c8f37
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
24 changes: 19 additions & 5 deletions scipy/odr/odrpack.py
Expand Up @@ -35,6 +35,7 @@
robert.kern@gmail.com
"""
import os

import numpy
from warnings import warn
Expand Down Expand Up @@ -656,11 +657,13 @@ class ODR(object):
ODRPACK User's Guide if you absolutely must set the value here. Use the
method set_iprint post-initialization for a more readable interface.
errfile : str, optional
string with the filename to print ODRPACK errors to. *Do Not Open
This File Yourself!*
string with the filename to print ODRPACK errors to. If the file already
exists, an error will be thrown. The `overwrite` argument can be used to
prevent this. *Do Not Open This File Yourself!*
rptfile : str, optional
string with the filename to print ODRPACK summaries to. *Do Not
Open This File Yourself!*
string with the filename to print ODRPACK summaries to. If the file
already exists, an error will be thrown. The `overwrite` argument can be
used to prevent this. *Do Not Open This File Yourself!*
ndigit : int, optional
integer specifying the number of reliable digits in the computation
of the function.
Expand Down Expand Up @@ -709,6 +712,9 @@ class ODR(object):
iwork : ndarray, optional
array to hold the integer-valued working data for ODRPACK. When
restarting, takes the value of self.output.iwork.
overwrite : boolean, optional
If it is True, output files defined by `errfile` and `rptfile` are
overwritten. The default is False.
Attributes
----------
Expand All @@ -725,7 +731,8 @@ class ODR(object):
def __init__(self, data, model, beta0=None, delta0=None, ifixb=None,
ifixx=None, job=None, iprint=None, errfile=None, rptfile=None,
ndigit=None, taufac=None, sstol=None, partol=None, maxit=None,
stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None):
stpb=None, stpd=None, sclb=None, scld=None, work=None, iwork=None,
overwrite=False):

self.data = data
self.model = model
Expand All @@ -743,6 +750,13 @@ def __init__(self, data, model, beta0=None, delta0=None, ifixb=None,
if ifixx is None and data.fix is not None:
ifixx = data.fix

if overwrite:
# remove output files for overwriting.
if rptfile is not None and os.path.exists(rptfile):
os.remove(rptfile)
if errfile is not None and os.path.exists(errfile):
os.remove(errfile)

self.delta0 = _conv(delta0)
# These really are 32-bit integers in FORTRAN (gfortran), even on 64-bit
# platforms.
Expand Down
19 changes: 19 additions & 0 deletions scipy/odr/tests/test_odr.py
@@ -1,4 +1,6 @@
# SciPy imports.
import os

import numpy as np
from numpy import pi
from numpy.testing import (assert_array_almost_equal,
Expand Down Expand Up @@ -472,3 +474,20 @@ def test_quadratic_model(self):
output = odr_obj.run()
assert_array_almost_equal(output.beta, [1.0, 2.0, 3.0])

def test_output_file_overwrite(self):
"""
Verify fix for gh-1892
"""
def func(b, x):
return b[0] + b[1] * x

p = Model(func)
data = Data(np.arange(10), 12 * np.arange(10))
ODR(data, p, beta0=[0.1, 13], errfile="error.dat", rptfile='report.dat',
).run()
ODR(data, p, beta0=[0.1, 13], errfile="error.dat", rptfile='report.dat',
overwrite=True).run()
# remove output files for clean up
os.remove("error.dat")
os.remove("report.dat")

0 comments on commit 47c8f37

Please sign in to comment.