Skip to content

Commit

Permalink
Add copy() method to datatype (#6269)
Browse files Browse the repository at this point in the history
### Problem

We have a lot of calls to the `CCompiler`, `CppCompiler`, and `Linker` constructors in `native_toolchain.py` to add extra cli arguments or environment variables, and we currently have to specify the value of every field of these constructors, even if we are only changing one or two of them.

### Solution

- add a `copy(self, **kwargs)` method to `datatype` which creates a new object with any entries from `kwargs` applied. *this particular interface was inspired by scala case classes*
- add testing for the new `copy()` method
- consume `copy()` in `native_toolchain.py`

### Result

When creating datatype objects derived from existing objects of the same type, only the changed fields need to be specified, improving the readability of the code in `native_toolchain.py`.
  • Loading branch information
cosmicexplorer authored and Stu Hood committed Aug 1, 2018
1 parent 492acb2 commit 89c6b40
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 42 deletions.
47 changes: 17 additions & 30 deletions src/python/pants/backend/native/subsystems/native_toolchain.py
Expand Up @@ -126,32 +126,28 @@ def select_llvm_c_toolchain(platform, native_toolchain):

if platform.normalized_os_name == 'darwin':
xcode_clang = yield Get(CCompiler, XCodeCLITools, native_toolchain._xcode_cli_tools)
working_c_compiler = CCompiler(
working_c_compiler = provided_clang.copy(
path_entries=(provided_clang.path_entries + xcode_clang.path_entries),
exe_filename=provided_clang.exe_filename,
library_dirs=(provided_clang.library_dirs + xcode_clang.library_dirs),
include_dirs=(provided_clang.include_dirs + xcode_clang.include_dirs),
extra_args=(llvm_c_compiler_args + xcode_clang.extra_args))
extra_args=(provided_clang.extra_args + llvm_c_compiler_args + xcode_clang.extra_args))
else:
gcc_install = yield Get(GCCInstallLocationForLLVM, GCC, native_toolchain._gcc)
provided_gcc = yield Get(CCompiler, GCC, native_toolchain._gcc)
working_c_compiler = CCompiler(
path_entries=provided_clang.path_entries,
exe_filename=provided_clang.exe_filename,
working_c_compiler = provided_clang.copy(
# We need g++'s version of the GLIBCXX library to be able to run, unfortunately.
library_dirs=(provided_gcc.library_dirs + provided_clang.library_dirs),
include_dirs=provided_gcc.include_dirs,
extra_args=(llvm_c_compiler_args + gcc_install.as_clang_argv))
extra_args=(llvm_c_compiler_args + provided_clang.extra_args + gcc_install.as_clang_argv))

base_linker_wrapper = yield Get(BaseLinker, NativeToolchain, native_toolchain)
base_linker = base_linker_wrapper.linker
libc_dev = yield Get(LibcDev, NativeToolchain, native_toolchain)
working_linker = Linker(
working_linker = base_linker.copy(
path_entries=(base_linker.path_entries + working_c_compiler.path_entries),
exe_filename=working_c_compiler.exe_filename,
library_dirs=(base_linker.library_dirs + working_c_compiler.library_dirs),
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)),
extra_args=base_linker.extra_args)
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)))

yield LLVMCToolchain(CToolchain(working_c_compiler, working_linker))

Expand All @@ -172,35 +168,32 @@ def select_llvm_cpp_toolchain(platform, native_toolchain):

if platform.normalized_os_name == 'darwin':
xcode_clang = yield Get(CppCompiler, XCodeCLITools, native_toolchain._xcode_cli_tools)
working_cpp_compiler = CppCompiler(
working_cpp_compiler = provided_clangpp.copy(
path_entries=(provided_clangpp.path_entries + xcode_clang.path_entries),
exe_filename=provided_clangpp.exe_filename,
library_dirs=(provided_clangpp.library_dirs + xcode_clang.library_dirs),
include_dirs=(provided_clangpp.include_dirs + xcode_clang.include_dirs),
# On OSX, this uses the libc++ (LLVM) C++ standard library implementation. This is
# feature-complete for OSX, but not for Linux (see https://libcxx.llvm.org/ for more info).
extra_args=(llvm_cpp_compiler_args + xcode_clang.extra_args))
extra_args=(llvm_cpp_compiler_args + provided_clangpp.extra_args + xcode_clang.extra_args))
linking_library_dirs = []
linker_extra_args = []
else:
gcc_install = yield Get(GCCInstallLocationForLLVM, GCC, native_toolchain._gcc)
provided_gpp = yield Get(CppCompiler, GCC, native_toolchain._gcc)
working_cpp_compiler = CppCompiler(
path_entries=provided_clangpp.path_entries,
exe_filename=provided_clangpp.exe_filename,
working_cpp_compiler = provided_clangpp.copy(
# We need g++'s version of the GLIBCXX library to be able to run, unfortunately.
library_dirs=(provided_gpp.library_dirs + provided_clangpp.library_dirs),
# NB: we use g++'s headers on Linux, and therefore their C++ standard library.
include_dirs=provided_gpp.include_dirs,
extra_args=(llvm_cpp_compiler_args + gcc_install.as_clang_argv))
extra_args=(llvm_cpp_compiler_args + provided_clangpp.extra_args + gcc_install.as_clang_argv))
linking_library_dirs = provided_gpp.library_dirs + provided_clangpp.library_dirs
# Ensure we use libstdc++, provided by g++, during the linking stage.
linker_extra_args=['-stdlib=libstdc++']

libc_dev = yield Get(LibcDev, NativeToolchain, native_toolchain)
base_linker_wrapper = yield Get(BaseLinker, NativeToolchain, native_toolchain)
base_linker = base_linker_wrapper.linker
working_linker = Linker(
working_linker = base_linker.copy(
path_entries=(base_linker.path_entries + working_cpp_compiler.path_entries),
exe_filename=working_cpp_compiler.exe_filename,
library_dirs=(base_linker.library_dirs + working_cpp_compiler.library_dirs),
Expand Down Expand Up @@ -230,22 +223,19 @@ def select_gcc_c_toolchain(platform, native_toolchain):
else:
new_include_dirs = provided_gcc.include_dirs

working_c_compiler = CCompiler(
working_c_compiler = provided_gcc.copy(
path_entries=(provided_gcc.path_entries + assembler.path_entries),
exe_filename=provided_gcc.exe_filename,
library_dirs=provided_gcc.library_dirs,
include_dirs=new_include_dirs,
extra_args=['-x', 'c', '-std=c11'])

base_linker_wrapper = yield Get(BaseLinker, NativeToolchain, native_toolchain)
base_linker = base_linker_wrapper.linker
libc_dev = yield Get(LibcDev, NativeToolchain, native_toolchain)
working_linker = Linker(
working_linker = base_linker.copy(
path_entries=(working_c_compiler.path_entries + base_linker.path_entries),
exe_filename=working_c_compiler.exe_filename,
library_dirs=(base_linker.library_dirs + working_c_compiler.library_dirs),
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)),
extra_args=base_linker.extra_args)
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)))

yield GCCCToolchain(CToolchain(working_c_compiler, working_linker))

Expand All @@ -268,10 +258,8 @@ def select_gcc_cpp_toolchain(platform, native_toolchain):
else:
new_include_dirs = provided_gpp.include_dirs

working_cpp_compiler = CppCompiler(
working_cpp_compiler = provided_gpp.copy(
path_entries=(provided_gpp.path_entries + assembler.path_entries),
exe_filename=provided_gpp.exe_filename,
library_dirs=provided_gpp.library_dirs,
include_dirs=new_include_dirs,
extra_args=([
'-x', 'c++', '-std=c++11',
Expand All @@ -281,12 +269,11 @@ def select_gcc_cpp_toolchain(platform, native_toolchain):
base_linker_wrapper = yield Get(BaseLinker, NativeToolchain, native_toolchain)
base_linker = base_linker_wrapper.linker
libc_dev = yield Get(LibcDev, NativeToolchain, native_toolchain)
working_linker = Linker(
working_linker = base_linker.copy(
path_entries=(working_cpp_compiler.path_entries + base_linker.path_entries),
exe_filename=working_cpp_compiler.exe_filename,
library_dirs=(base_linker.library_dirs + working_cpp_compiler.library_dirs),
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)),
extra_args=base_linker.extra_args)
linking_library_dirs=(base_linker.linking_library_dirs + libc_dev.get_libc_dirs(platform)))

yield GCCCppToolchain(CppToolchain(working_cpp_compiler, working_linker))

Expand Down
20 changes: 8 additions & 12 deletions src/python/pants/util/objects.py
Expand Up @@ -6,7 +6,7 @@

import sys
from abc import abstractmethod
from builtins import map, object, zip
from builtins import object, zip
from collections import OrderedDict, namedtuple

from future.utils import PY2
Expand Down Expand Up @@ -119,17 +119,13 @@ def _asdict(self):

def _replace(_self, **kwds):
'''Return a new datatype object replacing specified fields with new values'''
result = _self._make(map(kwds.pop, _self._fields, _self._super_iter()))
if kwds:
raise ValueError('Got unexpected field names: %r' % kwds.keys())
return result

# TODO: would we want to expose a self.as_tuple() method (which just calls __getnewargs__) so we
# can tuple assign? E.g.:
# class A(datatype(['field'])): pass
# x = A(field='asdf')
# field_value, = x.as_tuple()
# print(field_value) # => 'asdf'
field_dict = _self._asdict()
field_dict.update(**kwds)
return type(_self)(**field_dict)

copy = _replace

# NB: it is *not* recommended to rely on the ordering of the tuple returned by this method.
def __getnewargs__(self):
'''Return self as a plain tuple. Used by copy and pickle.'''
return tuple(self._super_iter())
Expand Down
25 changes: 25 additions & 0 deletions tests/python/pants_test/util/test_objects.py
Expand Up @@ -564,3 +564,28 @@ def compare_str(unicode_type_name, include_unicode=False):
"""error: in constructor of type WithSubclassTypeConstraint: type check error:
field 'some_value' was invalid: value 3 (with type 'int') must satisfy this type constraint: SubclassesOf(SomeBaseClass).""")
self.assertEqual(str(cm.exception), expected_msg)

def test_copy(self):
obj = AnotherTypedDatatype(string='some_string', elements=[1, 2, 3])
new_obj = obj.copy(string='another_string')

self.assertEqual(type(obj), type(new_obj))
self.assertEqual(new_obj.string, 'another_string')
self.assertEqual(new_obj.elements, obj.elements)

def test_copy_failure(self):
obj = AnotherTypedDatatype(string='some string', elements=[1,2,3])

with self.assertRaises(TypeCheckError) as cm:
obj.copy(nonexistent_field=3)
expected_msg = (
"""error: in constructor of type AnotherTypedDatatype: type check error:
__new__() got an unexpected keyword argument 'nonexistent_field'""")
self.assertEqual(str(cm.exception), expected_msg)

with self.assertRaises(TypeCheckError) as cm:
obj.copy(elements=3)
expected_msg = (
"""error: in constructor of type AnotherTypedDatatype: type check error:
field 'elements' was invalid: value 3 (with type 'int') must satisfy this type constraint: Exactly(list).""")
self.assertEqual(str(cm.exception), expected_msg)

0 comments on commit 89c6b40

Please sign in to comment.