diff --git a/form/form.py b/form/form.py index 640df85..477b3db 100644 --- a/form/form.py +++ b/form/form.py @@ -263,10 +263,30 @@ def read(self, *names): is done in the preprocessor of FORM (i.e., at compile-time), so one may need to write ".sort" to get the correct result. The return value is a string, or a list of strings when multiple names are passed. + + If non-string objects are passed, they are considered as sequences, and + the return value becomes the list corresponding to the arguments. If + a sequence is passed as the argument to this method, it guarantees that + the return value is always a list: + fl.read(['F1']) --> ['a1'] + fl.read(['F1', 'F2']) --> ['a1', 'a2'] + fl.read(['F1', 'F2', 'F3']) --> ['a1', 'a2', 'a3'] + A more complicated example is + fl.read('F1', ['F2', 'F3']) --> ['a1', ['a2', 'a3']] """ if self._closed: raise IOError('tried to read from closed connection') + if len(names) == 1 and not is_string(names[0]): + names = tuple(names[0]) + if len(names) == 1: + return [self.read(*names)] # Guarantee to return a list + else: + return self.read(*names) + + if any(not is_string(x) for x in names): + return [self.read(x) for x in names] + END_MARK = '__END__' END_MARK_LEN = len(END_MARK) diff --git a/form/tests/test_form.py b/form/tests/test_form.py index 0436196..686a8be 100644 --- a/form/tests/test_form.py +++ b/form/tests/test_form.py @@ -82,7 +82,7 @@ def test_flush(self): f.flush() f.flush() f.flush() - self.assertEqual(f.read("F"), str(N**M)) + self.assertEqual(f.read('F'), str(N**M)) def test_errors(self): with form.open() as f: @@ -90,7 +90,7 @@ def test_errors(self): L F = (1+x)^2; .sort ''') - self.assertRaises(RuntimeError, f.read, "F") + self.assertRaises(RuntimeError, f.read, 'F') with form.open() as f: f.write(''' @@ -98,7 +98,7 @@ def test_errors(self): L F = (1+x)^2; .sort ''') - self.assertRaises(RuntimeError, f.read, "G") + self.assertRaises(RuntimeError, f.read, 'G') with form.open() as f: f.close() @@ -148,5 +148,39 @@ def test_keep_log(self): msg = str(e) self.assertTrue(msg is not None and msg.find('L F = (1+x)^2;') >= 0) + def test_seq_args(self): + with form.open() as f: + f.write(''' + #do i=1,9 + L F`i' = `i'; + #enddo + .sort + ''') + # normal arguments + self.assertEqual(f.read('F1'), '1') + self.assertEqual(f.read('F1', 'F2'), ['1', '2']) + self.assertEqual(f.read('F1', 'F2', 'F3'), ['1', '2', '3']) + # a non-string argument + self.assertEqual(f.read(('F1',)), ['1']) + self.assertEqual(f.read(('F1', 'F2')), ['1', '2']) + self.assertEqual(f.read(('F1', 'F2', 'F3')), ['1', '2', '3']) + # a generator + self.assertEqual(f.read('F{0}'.format(i) for i in range(1, 2)), + ['1']) + self.assertEqual(f.read('F{0}'.format(i) for i in range(1, 3)), + ['1', '2']) + self.assertEqual(f.read('F{0}'.format(i) for i in range(1, 4)), + ['1', '2', '3']) + # more complicated arguments + self.assertEqual(f.read(['F1'], 'F2'), [['1'], '2']) + self.assertEqual(f.read(['F1'], ['F2']), [['1'], ['2']]) + self.assertEqual(f.read('F1', ['F2', 'F3']), ['1', ['2', '3']]) + self.assertEqual(f.read('F1', ['F2'], ['F3']), ['1', ['2'], ['3']]) + self.assertEqual(f.read(['F1'], ['F2', ['F3', 'F4']]), + [['1'], ['2', ['3', '4']]]) + self.assertEqual(f.read('F1', + (('F{0}').format(i) for i in range(2, 5))), + ['1', ['2', '3', '4']]) + if __name__ == '__main__': unittest.main()