From 2d6ffb4a7b2e53dd54a37fc636c377d54ccb42be Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 22 Apr 2024 17:11:45 -0700 Subject: [PATCH 1/4] chore: Upgrade TensorRT version to TRT 10 EA (#2699) Co-authored-by: Evan Li --- .github/scripts/install-torch-tensorrt.sh | 9 +- .github/workflows/build-test.yml | 36 ++++-- README.md | 2 +- core/conversion/converters/converter_util.cpp | 9 +- core/conversion/converters/converter_util.h | 3 + core/conversion/converters/impl/chunk.cpp | 4 - .../converters/impl/constant_pad.cpp | 11 +- .../converters/impl/conv_deconv.cpp | 14 +-- core/conversion/converters/impl/cumsum.cpp | 2 +- core/conversion/converters/impl/expand.cpp | 15 +-- .../converters/impl/interpolate.cpp | 65 +++++----- core/conversion/converters/impl/linear.cpp | 33 ++++-- core/conversion/converters/impl/select.cpp | 13 +- core/conversion/evaluators/eval_util.cpp | 17 +-- core/ir/ir.cpp | 1 - .../unpack_scaled_dot_product_attention.cpp | 8 +- core/plugins/impl/interpolate_plugin.h | 5 +- core/plugins/impl/normalize_plugin.h | 5 +- core/runtime/TRTEngine.cpp | 49 ++++++-- core/runtime/execute_engine.cpp | 1 - cpp/include/torch_tensorrt/ptq.h | 5 - dev_dep_versions.yml | 2 +- packaging/pre_build_script.sh | 7 +- packaging/smoke_test_script.sh | 6 + py/requirements.txt | 2 +- py/torch_tensorrt/__init__.py | 1 - py/torch_tensorrt/_enums.py | 5 +- py/torch_tensorrt/csrc/torch_tensorrt_py.cpp | 5 + py/torch_tensorrt/dynamo/_compiler.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 9 +- .../dynamo/conversion/_conversion.py | 6 +- .../dynamo/conversion/impl/conv.py | 2 +- .../dynamo/conversion/impl/deconv.py | 2 +- .../conversion/impl/elementwise/base.py | 66 ++++++++--- .../dynamo/conversion/impl/elementwise/ops.py | 4 +- .../conversion/impl/normalization/ops.py | 14 ++- .../dynamo/conversion/impl/pad.py | 8 +- .../dynamo/conversion/impl/permutation.py | 4 +- .../dynamo/conversion/impl/select.py | 10 +- .../dynamo/conversion/impl/shape.py | 59 ++++++++- .../dynamo/conversion/impl/upsample.py | 7 +- .../runtime/_PythonTorchTensorRTModule.py | 112 +++++------------- .../fx/converters/acc_ops_converters.py | 27 ++--- .../fx/converters/converter_utils.py | 10 +- py/torch_tensorrt/fx/utils.py | 25 +++- pyproject.toml | 4 +- .../converters/test_conv_deconv.cpp | 4 +- .../test_scaled_dot_product_attention.cpp | 9 +- .../core/partitioning/test_loading_model.cpp | 2 +- tests/cpp/test_compiled_modules.cpp | 2 +- tests/cpp/test_modules_as_engines.cpp | 2 +- .../py/dynamo/conversion/test_arange_aten.py | 8 +- tests/py/dynamo/conversion/test_erf_aten.py | 6 +- .../dynamo/conversion/test_layer_norm_aten.py | 49 -------- tests/py/dynamo/conversion/test_neg_aten.py | 6 +- tests/py/dynamo/runtime/gen_hw_compat.py | 33 ++++++ .../test_convert_method_to_trt_engine.py | 8 +- tests/py/dynamo/runtime/test_hw_compat.py | 6 +- .../test_trt_intercompatibility.py | 20 ++-- tests/py/ts/models/hw_compat.ts | Bin 386982 -> 110482 bytes third_party/tensorrt/archive/BUILD | 2 - third_party/tensorrt/local/BUILD | 18 +-- .../WORKSPACE.x86_64.release.rhel.tmpl | 20 +--- 63 files changed, 481 insertions(+), 420 deletions(-) create mode 100644 packaging/smoke_test_script.sh create mode 100644 tests/py/dynamo/runtime/gen_hw_compat.py diff --git a/.github/scripts/install-torch-tensorrt.sh b/.github/scripts/install-torch-tensorrt.sh index 48c6c90cbf..b2b19b139d 100644 --- a/.github/scripts/install-torch-tensorrt.sh +++ b/.github/scripts/install-torch-tensorrt.sh @@ -5,6 +5,13 @@ source ${BUILD_ENV_FILE} ${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision ${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0 export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()") -${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com + +# Install TensorRT manually +wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz +tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/ +python -m pip install /opt/torch-tensorrt-builds/TensorRT-10.0.0.6/python/tensorrt-10.0.0b6-cp${PYTHON_VERSION//./}-none-linux_x86_64.whl + +# Install Torch-TensorRT +${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl echo -e "Running test script"; diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f5ddce8924..c19b4b4b2e 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -15,7 +15,7 @@ on: jobs: generate-matrix: - uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@release/2.3 with: package-type: wheel os: linux @@ -36,11 +36,11 @@ jobs: - repository: pytorch/tensorrt pre-script: packaging/pre_build_script.sh env-var-script: packaging/env_vars.txt - post-script: "" - smoke-test-script: "" + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh package-name: torch_tensorrt name: Build torch-tensorrt whl package - uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@release/2.3 with: repository: ${{ matrix.repository }} ref: "" @@ -64,7 +64,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-torchscript-fe repository: "pytorch/tensorrt" @@ -76,9 +77,11 @@ jobs: script: | export USE_HOST_DEPS=1 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/modules - ${CONDA_RUN} python -m pip install --pre -r requirements.txt --use-deprecated=legacy-resolver + # Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now. + ${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2 ${CONDA_RUN} python hub.py popd pushd . @@ -99,7 +102,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-dynamo-converters repository: "pytorch/tensorrt" @@ -110,6 +114,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver @@ -126,7 +131,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-dynamo-fe repository: "pytorch/tensorrt" @@ -137,6 +143,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver @@ -154,7 +161,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-dynamo-serde repository: "pytorch/tensorrt" @@ -165,6 +173,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver @@ -181,7 +190,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-torch-compile-be repository: "pytorch/tensorrt" @@ -192,6 +202,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver @@ -209,7 +220,8 @@ jobs: - repository: pytorch/tensorrt package-name: torch_tensorrt pre-script: packaging/pre_build_script.sh - uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + post-script: packaging/post_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3 with: job-name: tests-py-dynamo-core repository: "pytorch/tensorrt" @@ -220,6 +232,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver @@ -249,6 +262,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/core ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver diff --git a/README.md b/README.md index 561ffdc0ad..038956da32 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorR - Libtorch 2.4.0.dev (latest nightly) (built with CUDA 12.1) - CUDA 12.1 - cuDNN 8.9.5 -- TensorRT 8.6.1 +- TensorRT 10.0.0.6 ## Prebuilt Binaries and Wheel files diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 3dcd2e9d80..39afe9945f 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -39,6 +39,12 @@ nvinfer1::ITensor* addPadding( } } +nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name) { + nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0); + input_shape = castITensor(ctx, input_shape, nvinfer1::DataType::kINT32, name); + return input_shape; +} + nvinfer1::ITensor* addUnpadding( ConversionCtx* ctx, const torch::jit::Node* n, @@ -134,7 +140,7 @@ nvinfer1::ILayer* add_elementwise( } auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask); auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask); - auto selfShape = ctx->net->addShape(*self)->getOutput(0); + nvinfer1::ITensor* selfShape = getShapeOutput(ctx, self, std::string(name + "_shape_cast").c_str()); // size of dynamic dimension of other need to the same as that of // corresponding dimension of self auto otherDynamicShape = @@ -348,7 +354,6 @@ nvinfer1::ITensor* normalize_indices( auto neg_itensor = tensor_to_const(ctx, neg); // find the indices that = -1 auto signs = clamp(ctx, indices, neg_itensor, zero_itensor, "clamp layer for " + name); - // get the inputDim value where indices == -1, else 0 auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, signs, input_dim, "prod layer for " + name); TORCHTRT_CHECK(mul, "Unable to create mul layer in normalize_indices"); diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h index 3342302431..ad57c476e1 100644 --- a/core/conversion/converters/converter_util.h +++ b/core/conversion/converters/converter_util.h @@ -62,6 +62,9 @@ nvinfer1::ITensor* castITensor( nvinfer1::DataType dtype, const std::string& layer_name_prefix = ""); +// Get the shape of the input tensor and cast it to INT32 type +nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name = ""); + // Freeze an at::Tensor in a IConstant layer nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string()); diff --git a/core/conversion/converters/impl/chunk.cpp b/core/conversion/converters/impl/chunk.cpp index a7191133fb..b3d2441706 100644 --- a/core/conversion/converters/impl/chunk.cpp +++ b/core/conversion/converters/impl/chunk.cpp @@ -17,7 +17,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() auto chunks = args[1].unwrapToInt(); auto dim = args[2].unwrapToInt(); bool dynamic_shape = ctx->input_is_dynamic; - int size = in->getDimensions().nbDims; int maxDim = static_cast(in->getDimensions().d[dim]); c10::ListTypePtr lt = n->output()->type()->expect(); @@ -41,9 +40,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() size_.nbDims = nbdims; stride_.nbDims = nbdims; - int startIdx = 0; - int endIdx = maxDim; - for (int i = 0; i < nbdims; i++) { start_.d[i] = 0; size_.d[i] = 0; diff --git a/core/conversion/converters/impl/constant_pad.cpp b/core/conversion/converters/impl/constant_pad.cpp index 4191cb1bab..42947e1c03 100644 --- a/core/conversion/converters/impl/constant_pad.cpp +++ b/core/conversion/converters/impl/constant_pad.cpp @@ -55,18 +55,15 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns util::toDims(c10::IntArrayRef(stride))); TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n); slice_layer->setName((util::node_info(n) + "_slice").c_str()); - slice_layer->setMode(nvinfer1::SliceMode::kFILL); + slice_layer->setMode(nvinfer1::SampleMode::kFILL); slice_layer->setInput(4, *value_itensor); if (ctx->input_is_dynamic) { // build the size using inetwork layers - auto shape_layer = ctx->net->addShape(*in); - TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); - shape_layer->setName((util::node_info(n) + "_shape").c_str()); auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32)); - - auto add_layer = ctx->net->addElementWise( - *shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM); + nvinfer1::ITensor* shapeOutput = getShapeOutput(ctx, in, (util::node_info(n) + "_shape").c_str()); + auto add_layer = + ctx->net->addElementWise(*shapeOutput, *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM); TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n); add_layer->setName((util::node_info(n) + "_add").c_str()); slice_layer->setInput(2, *add_layer->getOutput(0)); diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index 66620197a9..c71007ac03 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -33,7 +33,7 @@ nvinfer1::ILayer* add_bias_layer( nvinfer1::Dims& input_dims, nvinfer1::Dims& output_padding, Weights& bias) { - nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0); + nvinfer1::ITensor* input_shape = getShapeOutput(ctx, input_tensor, std::string("bias_shape_cast").c_str()); // Add padding layer nvinfer1::ITensor* start; nvinfer1::ITensor* totalPadding; @@ -61,7 +61,7 @@ nvinfer1::ILayer* add_bias_layer( auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride); sliceLayer->setInput(1, *start); sliceLayer->setInput(2, *size); - sliceLayer->setMode(nvinfer1::SliceMode::kFILL); + sliceLayer->setMode(nvinfer1::SampleMode::kFILL); nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0); nvinfer1::Dims constantDims; @@ -146,9 +146,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) // TensorRT expects nbSpatialDims = 2 or 3 filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false); // Reshape input dimensions - in = addPadding(ctx, n, in, 4); + in = addPadding(ctx, n, in, 4, true, true, std::string(util::node_info(n) + "_input_shuffle")); LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions()); - kernel = addPadding(ctx, n, kernel, 4); + kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle")); LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions()); if (transposed) { num_output_maps = kernel_dims.d[1]; @@ -194,7 +194,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) nvinfer1::IConvolutionLayer* convLayer = ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data); convLayer->setStrideNd(stride); - convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN); + convLayer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN); convLayer->setPaddingNd(padding); convLayer->setPostPadding(out_padding); convLayer->setDilationNd(dilation); @@ -291,11 +291,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) // shape of convolution's weight: [out, in/groups, ...] auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data); TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n); - conv->setStrideNd(stride); - conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN); + conv->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN); conv->setPaddingNd(padding); - conv->setPostPadding(out_padding); conv->setDilationNd(dilation); conv->setNbGroups(groups); new_layer = conv; diff --git a/core/conversion/converters/impl/cumsum.cpp b/core/conversion/converters/impl/cumsum.cpp index 5c518fd635..f856ca5d4e 100644 --- a/core/conversion/converters/impl/cumsum.cpp +++ b/core/conversion/converters/impl/cumsum.cpp @@ -36,7 +36,7 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32); tripLimit = tensor_to_const(ctx, axis); } else { - nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0); + nvinfer1::ITensor* inpShape = getShapeOutput(ctx, in); torch::Tensor dimValue = torch::tensor(dim, torch::kInt32); nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue); tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0); diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp index 6b22fea8d4..49f18159dd 100644 --- a/core/conversion/converters/impl/expand.cpp +++ b/core/conversion/converters/impl/expand.cpp @@ -19,11 +19,11 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe if (max_rank - old_rank > 0) { torch::Tensor thOne = torch::tensor(std::vector(max_rank - old_rank, 1), torch::kInt32); auto one_tensor = tensor_to_const(ctx, thOne); - auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0); + auto in_shape_tensor = getShapeOutput(ctx, tensor); nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor}; return ctx->net->addConcatenation(args, 2)->getOutput(0); } else { // max_rank - old_rank == 0 - return ctx->net->addShape(*tensor)->getOutput(0); + return getShapeOutput(ctx, tensor); } } @@ -44,8 +44,7 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor if (size != targetSize) { if (size != 1) { TORCHTRT_THROW_ERROR( - "The expanded size of tensor (" << targetSize << ")" - << " must match the existing size (" << size << ")" + "The expanded size of tensor (" << targetSize << ")" << " must match the existing size (" << size << ")" << " at dimension " << i); } } @@ -132,8 +131,7 @@ bool add_expand_dynamic( // if size == -1, we can't validate the expansion before setBindingDimensions. if (!(size == -1 || size == 1)) { TORCHTRT_THROW_ERROR( - "The expanded size of tensor (" << targetSize << ")" - << " must match the existing size (" << size << ")" + "The expanded size of tensor (" << targetSize << ")" << " must match the existing size (" << size << ")" << " at dimension " << i); } } @@ -221,8 +219,7 @@ auto expand_registrations TORCHTRT_UNUSED = auto targetDims = targetTensor->getDimensions(); LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims); if (ctx->input_is_dynamic) { - return add_expand_dynamic( - ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false); + return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false); } else { return add_expand(ctx, n, in, targetDims); } @@ -357,7 +354,7 @@ auto expand_registrations TORCHTRT_UNUSED = if (ctx->input_is_dynamic) { auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32)); - auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0); + auto expand_output_shape = getShapeOutput(ctx, expand->getOutput(0)); std::vector repeat_const_vec(repeat_shape_dims.nbDims, 1); repeat_const_vec[dim + 1] = repeats; auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32)); diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index b9a5f631b0..96c257c4b7 100644 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -72,12 +72,11 @@ void resize_layer_size( nvinfer1::ITensor* in, std::vector out_shape, std::vector scales, - nvinfer1::ResizeMode mode, + nvinfer1::InterpolationMode mode, bool align_corners = false) { TORCHTRT_CHECK((out_shape.size() > 0) ^ (scales.size() > 0), "only one of out_shape or scales should be defined"); auto resize_layer = ctx->net->addResize(*in); TORCHTRT_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n); - if (out_shape.size() > 0) { auto th_dynamic_shape_mask = torch::zeros(out_shape.size(), torch::kInt32); auto th_static_shape_mask = torch::zeros(out_shape.size(), torch::kInt32); @@ -91,7 +90,7 @@ void resize_layer_size( auto dynamic_shape_mask = tensor_to_const(ctx, th_dynamic_shape_mask); auto static_shape_mask = tensor_to_const(ctx, th_static_shape_mask); - auto input_shape = ctx->net->addShape(*in)->getOutput(0); + nvinfer1::ITensor* input_shape = getShapeOutput(ctx, in); auto dynamic_shape = ctx->net->addElementWise(*input_shape, *dynamic_shape_mask, nvinfer1::ElementWiseOperation::kPROD) ->getOutput(0); @@ -108,13 +107,17 @@ void resize_layer_size( resize_layer->setResizeMode(mode); resize_layer->setName(util::node_info(n).c_str()); -#if NV_TENSORRT_MAJOR < 8 - resize_layer->setAlignCorners(align_corners); -#else + if (align_corners) { resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS); + } else { + if (mode == nvinfer1::InterpolationMode::kLINEAR) { + resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kHALF_PIXEL); + } else { + // kASYMMETRIC is the default transformation in TensorRT + resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC); + } } -#endif auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0)); LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); @@ -141,7 +144,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = float scale = args[2].IValue()->toDouble(); std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 1] = scale; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -150,7 +153,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -172,7 +175,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = float scale = scale_factors[0]; std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 1] = scale; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -181,7 +184,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -203,7 +206,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -212,7 +215,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -236,7 +239,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -245,7 +248,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -270,7 +273,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = padded_scales[padded_scales.size() - 3] = scale_d; padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -279,7 +282,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -306,7 +309,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = padded_scales[padded_scales.size() - 3] = scale_d; padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -315,7 +318,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST); } return true; @@ -336,7 +339,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = float scale = args[3].IValue()->toDouble(); std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 1] = scale; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -345,7 +348,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; @@ -368,7 +371,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = float scale = scale_factors[0]; std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 1] = scale; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -377,7 +380,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; @@ -400,7 +403,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -410,7 +413,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; @@ -435,7 +438,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = std::vector padded_scales(in_shape.size(), 1); padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -445,7 +448,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; @@ -470,7 +473,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = padded_scales[padded_scales.size() - 3] = scale_d; padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -480,7 +483,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; @@ -507,7 +510,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = padded_scales[padded_scales.size() - 3] = scale_d; padded_scales[padded_scales.size() - 2] = scale_h; padded_scales[padded_scales.size() - 1] = scale_w; - resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners); } else { // Case 2: user uses output size auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList())); @@ -517,7 +520,7 @@ auto interpolate_registrations TORCHTRT_UNUSED = auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); + resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners); } return true; diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp index 6289334736..0e4452dec0 100644 --- a/core/conversion/converters/impl/linear.cpp +++ b/core/conversion/converters/impl/linear.cpp @@ -40,22 +40,29 @@ auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat in = in_shuffle->getOutput(0); } - auto w_tensor = args[1].IValue()->toTensor(); - Weights w = Weights(ctx, w_tensor); + // Convert w_tensor to ITensor and broadcast 2d to 4d if needed + auto weight = args[1].IValue()->toTensor(); + auto weight_tensor = tensor_to_const(ctx, weight, util::node_info(n) + "_weight"); + auto weight_shape = util::toVec(weight_tensor->getDimensions()); + weight_tensor = addPadding(ctx, n, weight_tensor, in->getDimensions().nbDims, false, false); - nvinfer1::ILayer* new_layer; - if (!args[2].IValue()->isNone()) { - Weights b(ctx, args[2].IValue()->toTensor()); - new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, b.data); - } else { - LOG_DEBUG("There is no bias for the linear layer"); - new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, Weights().data); - } + auto mm_layer = ctx->net->addMatrixMultiply( + *in, nvinfer1::MatrixOperation::kNONE, *weight_tensor, nvinfer1::MatrixOperation::kTRANSPOSE); + + TORCHTRT_CHECK(mm_layer, "Unable to create linear layer from node: " << *n); + mm_layer->setName(util::node_info(n).c_str()); - TORCHTRT_CHECK(new_layer, "Unable to create linear layer from node: " << *n); + auto mm_output = mm_layer->getOutput(0); - new_layer->setName(util::node_info(n).c_str()); - auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + if (!args[2].IValue()->isNone()) { + // Convert bias to ITensor + auto bias = args[2].IValue()->toTensor(); + auto bias_tensor = tensor_to_const(ctx, bias, util::node_info(n) + "_bias"); + auto bias_add_layer = add_elementwise( + ctx, nvinfer1::ElementWiseOperation::kSUM, mm_output, bias_tensor, util::node_info(n) + "_bias_add"); + mm_output = bias_add_layer->getOutput(0); + } + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_output); LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 8334205879..d6b49aa609 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -368,8 +368,7 @@ auto select_registrations TORCHTRT_UNUSED = int rank = inDims.nbDims; LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results."); int adv_idx_count = adv_idx_indices.size(); - auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0); - + nvinfer1::ITensor* in_shape_itensor = getShapeOutput(ctx, in); std::vector dim_tensor_list; for (int i = 0; i < rank; i++) { auto dim_tensor = @@ -401,7 +400,7 @@ auto select_registrations TORCHTRT_UNUSED = // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n] nvinfer1::ITensor* flatten_tensor = NULL; { - auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0); + nvinfer1::ITensor* shuffle_shape_tensor = getShapeOutput(ctx, shuffle_out); auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32)); for (int i = 0; i < adv_idx_count; i++) { auto dim_tensor = @@ -479,7 +478,7 @@ auto select_registrations TORCHTRT_UNUSED = nvinfer1::ITensor* reshape_output = NULL; { - auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0); + nvinfer1::ITensor* cum_adv_index_shape_tensor = getShapeOutput(ctx, cum_adv_index); // check if all advanced indices are consecutive. if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) { // unfold regular index axes @@ -559,8 +558,7 @@ auto select_registrations TORCHTRT_UNUSED = bool dynamic_shape = ctx->input_is_dynamic; auto input_dim = in->getDimensions(); // add Shape Tensor - auto ishape_layer = ctx->net->addShape(*in); - auto ishape_tensor = ishape_layer->getOutput(0); // input shape + nvinfer1::ITensor* ishape_tensor = getShapeOutput(ctx, in); std::string node_name = n->outputs()[0]->debugName().c_str(); int startIdx = 0; @@ -605,6 +603,7 @@ auto select_registrations TORCHTRT_UNUSED = stride_.d[i] = 1; } } + if (!dynamic_shape) { auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_); LOG_DEBUG("start_:" << start_); @@ -617,7 +616,6 @@ auto select_registrations TORCHTRT_UNUSED = LOG_DEBUG("Using dynamic version of slice"); // start tensor at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32); - ; start_tensor[axis] = startIdx; auto start_itensor = tensor_to_const(ctx, start_tensor); @@ -647,7 +645,6 @@ auto select_registrations TORCHTRT_UNUSED = // calculate size auto size_itensor = get_slice_size(ctx, out_start, out_end, stride_itensor, nbdims, node_name); - // update slice layer auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_); slice_layer->setInput(1, *out_start); // start diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index 71b6de9eb2..2c6567aa95 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -34,11 +34,8 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw auto in = args.at(n->input(0)).ITensorOrFreeze(ctx); auto input_dims = in->getDimensions(); LOG_DEBUG("Input dimensions: " << input_dims); - - auto shape_layer = ctx->net->addShape(*in); - TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); - auto shape_1d_tensor = shape_layer->getOutput(0); - + nvinfer1::ITensor* shape_1d_tensor = torch_tensorrt::core::conversion::converters::getShapeOutput( + ctx, in, std::string(util::node_info(n) + "_dynamic_shape_layer_cast").c_str()); if (n->inputs().size() != 1) { auto maxDim = static_cast(in->getDimensions().nbDims); auto dim = args.at(n->input(1)).unwrapToInt(); @@ -168,8 +165,7 @@ c10::optional toIValue(const torch::jit::Value* v) { void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) { if (!elem_type->isSubtypeOf(c10::NumberType::get()) && elem_type != c10::BoolType::get()) { std::stringstream error; - error << "Input must be of ints, floats, or bools, " - << "got " << elem_type->repr_str(); + error << "Input must be of ints, floats, or bools, " << "got " << elem_type->repr_str(); // special case empty list torch.tensor([]) if (elem_type->isSubtypeOf(c10::TensorType::get())) { if (empty_list) { @@ -423,13 +419,12 @@ c10::optional newTensorLikeImplementation( // broadcast constant to output shape std::vector start_vec(self->getDimensions().nbDims, 0); auto start_offset = util::toDims(c10::IntArrayRef(start_vec)); - auto shape_layer = ctx->net->addShape(*self); - TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); - shape_layer->setName((util::node_info(n) + "_shape").c_str()); + nvinfer1::ITensor* shape_output = torch_tensorrt::core::conversion::converters::getShapeOutput( + ctx, self, std::string(util::node_info(n) + "_shape").c_str()); // slice implements expand auto slice_layer = ctx->net->addSlice(*constant_itensor, start_offset, self->getDimensions(), start_offset); TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n); - slice_layer->setInput(2, *shape_layer->getOutput(0)); + slice_layer->setInput(2, *shape_output); slice_layer->setName((util::node_info(n) + "_slice").c_str()); auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0)); LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp index c98d17c5ef..b67e228f1f 100644 --- a/core/ir/ir.cpp +++ b/core/ir/ir.cpp @@ -151,7 +151,6 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* // If node outputs a Tensor it might be a result of tensor calcuation so check to see // if any inputs to the calculation can give us hints - c10::optional const_tensor_n = {}; // Backtrace to constants which will immediately give us the Tensor type if possible for (auto in : ins) { diff --git a/core/lowering/passes/unpack_scaled_dot_product_attention.cpp b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp index bfe0004bd6..3c347f65ca 100644 --- a/core/lowering/passes/unpack_scaled_dot_product_attention.cpp +++ b/core/lowering/passes/unpack_scaled_dot_product_attention.cpp @@ -12,12 +12,12 @@ namespace passes { // https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html void UnpackScaledDotProductAttention(std::shared_ptr& graph) { std::string sdpa_pattern = R"IR( - graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): - %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal) + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale): + %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale) return (%out))IR"; std::string unpacked_sdpa_pattern = R"IR( - graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale): %none : NoneType = prim::Constant() %1 : int = prim::Constant[value=-1]() %2 : int = prim::Constant[value=-2]() @@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr& graph) return(%out))IR"; std::string unpacked_sdpa_attn_biased_pattern = R"IR( - graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal): + graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale): %none : NoneType = prim::Constant() %0 : int = prim::Constant[value=1]() %1 : int = prim::Constant[value=-1]() diff --git a/core/plugins/impl/interpolate_plugin.h b/core/plugins/impl/interpolate_plugin.h index ced4cbee20..661cee270f 100644 --- a/core/plugins/impl/interpolate_plugin.h +++ b/core/plugins/impl/interpolate_plugin.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -58,7 +57,7 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override{}; + void setPluginNamespace(const char* pluginNamespace) noexcept override {}; nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; @@ -118,7 +117,7 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* libNamespace) noexcept override{}; + void setPluginNamespace(const char* libNamespace) noexcept override {}; const char* getPluginName() const noexcept override; diff --git a/core/plugins/impl/normalize_plugin.h b/core/plugins/impl/normalize_plugin.h index 28c3a5c5da..7e564b505b 100644 --- a/core/plugins/impl/normalize_plugin.h +++ b/core/plugins/impl/normalize_plugin.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -42,7 +41,7 @@ class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override{}; + void setPluginNamespace(const char* pluginNamespace) noexcept override {}; nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; @@ -102,7 +101,7 @@ class NormalizePluginCreator : public nvinfer1::IPluginCreator { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* libNamespace) noexcept override{}; + void setPluginNamespace(const char* libNamespace) noexcept override {}; const char* getPluginName() const noexcept override; diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 92e5d7a8ff..31fbf60204 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -120,16 +120,26 @@ TRTEngine::TRTEngine( } else { uint64_t inputs_size = _in_binding_names.size(); in_binding_names.resize(inputs_size); - for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) { + for (uint64_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) { auto binding_name = _in_binding_names[pyt_idx]; - auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str()); - std::string engine_binded_name = cuda_engine->getIOTensorName(trt_idx); - TORCHTRT_CHECK( - (binding_name == engine_binded_name), - "Could not find a TensorRT engine binding for input named " << binding_name); + // Check if the binding name provided is in the list of engine's bindings + // by iterating through nbIOTensors and verify it is an input binding + bool is_binding = false, is_input = false; + int32_t trt_idx; + for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) { + std::string curr_bind_name = cuda_engine->getIOTensorName(idx); + if (curr_bind_name == binding_name) { + is_binding = true; + trt_idx = idx; + if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) { + is_input = true; + break; + } + } + } + TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for input named " << binding_name); TORCHTRT_CHECK( - (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT), - "Binding " << binding_name << " specified as input but found as output in TensorRT engine"); + is_input, "Binding " << binding_name << " specified as input but found as output in TensorRT engine"); LOG_DEBUG( "Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx << ", Torch binding index: " << pyt_idx); @@ -141,11 +151,26 @@ TRTEngine::TRTEngine( out_binding_names.resize(outputs); for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) { auto binding_name = _out_binding_names[pyt_idx]; - auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str()); - TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name); + // Check if the binding name provided is in the list of engine's bindings + // by iterating through nbIOTensors and verify it is an output binding + bool is_binding = false, is_output = false; + int32_t trt_idx; + for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) { + std::string curr_bind_name = cuda_engine->getIOTensorName(idx); + if (curr_bind_name == binding_name) { + is_binding = true; + trt_idx = idx; + if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kOUTPUT) { + is_output = true; + break; + } + } + } + + TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for output named " << binding_name); TORCHTRT_CHECK( - !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT), - "Binding " << binding_name << " specified as output but found as input in TensorRT engine"); + is_output, "Binding " << binding_name << " specified as output but found as input in TensorRT engine"); + LOG_DEBUG( "Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx << ", Torch binding index: " << inputs_size + pyt_idx); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 5ff163fbfb..a1ee30e994 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -179,7 +179,6 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index()); - // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it. std::unique_lock lock(compiled_engine->mu); compiled_engine->exec_ctx->enqueueV3(stream); diff --git a/cpp/include/torch_tensorrt/ptq.h b/cpp/include/torch_tensorrt/ptq.h index d8570f0e6e..6650f45fe9 100644 --- a/cpp/include/torch_tensorrt/ptq.h +++ b/cpp/include/torch_tensorrt/ptq.h @@ -21,11 +21,6 @@ #include "torch_tensorrt/macros.h" #ifndef DOXYGEN_SHOULD_SKIP_THIS -namespace nvinfer1 { -class IInt8Calibrator; -class IInt8EntropyCalibrator2; -} // namespace nvinfer1 - namespace torch_tensorrt { namespace ptq { TORCHTRT_API bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data); diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 442485474c..02faa26e05 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,4 +1,4 @@ __version__: "2.3.0.dev0" __cuda_version__: "12.1" __cudnn_version__: "8.9" -__tensorrt_version__: "8.6" +__tensorrt_version__: "10.0.0.6" diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 18cd5d9fe2..8f5d1d8acc 100755 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -2,10 +2,11 @@ # Install dependencies python3 -m pip install pyyaml +yum install -y ninja-build gettext TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") -yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo -yum check-update -yum install -y ninja-build gettext tensorrt-${TRT_VERSION}.* +wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz +tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/ +export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \ && mv bazelisk-linux-amd64 /usr/bin/bazel \ && chmod +x /usr/bin/bazel diff --git a/packaging/smoke_test_script.sh b/packaging/smoke_test_script.sh new file mode 100644 index 0000000000..d3bed3249e --- /dev/null +++ b/packaging/smoke_test_script.sh @@ -0,0 +1,6 @@ +# Smoke test is intentionally disabled. +# The issue was smoke test installs the built torch_tensorrt wheel file and checks `import torch_tensorrt; print(torch_tensorrt.__version__)` +# Since tensorrt cannot be pip installable in CI, the smoke test will fail. +# One way we tried to handle it is manually install tensorrt wheel while by extracting from the tarball. +# However, the TensorRT-10.0.0.6/lib path doesn't seem to show up in LD_LIBRARY_PATH even if we explicitly set it. +# TODO: Implement a custom smoke_test script to verify torch_tensorrt installation. \ No newline at end of file diff --git a/py/requirements.txt b/py/requirements.txt index 3571b9ac34..621297d46c 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -5,5 +5,5 @@ pybind11==2.6.2 torch>=2.4.0.dev,<2.5.0 torchvision>=0.19.0.dev,<0.20.0 --extra-index-url https://pypi.ngc.nvidia.com -tensorrt==8.6.1 pyyaml +tensorrt \ No newline at end of file diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index f95f33bc74..b2bc0660e6 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -60,7 +60,6 @@ def _find_lib(name: str, paths: List[str]) -> str: elif sys.platform.startswith("linux"): LINUX_PATHS = ["/usr/local/cuda-12.1/lib64", "/usr/lib", "/usr/lib64"] - if "LD_LIBRARY_PATH" in os.environ: LINUX_PATHS += os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 350d8a299e..062abb9a87 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -5,11 +5,10 @@ from typing import Any, Optional, Type, Union import numpy as np +import tensorrt as trt import torch from torch_tensorrt._features import ENABLED_FEATURES -import tensorrt as trt - class dtype(Enum): """Enum to set supported dtypes in the compiler""" @@ -108,7 +107,7 @@ def _from( return dtype.f16 elif t == trt.float32: return dtype.f32 - elif trt.__version__ >= "7.0" and t == trt.bool: + elif t == trt.bool: return dtype.b else: raise TypeError( diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index e4d88088e4..81814486f6 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -2,6 +2,7 @@ #include "pybind11/stl.h" #include "ATen/core/jit_type.h" +#include "NvInferRuntimeBase.h" #include "Python.h" #include "core/compiler.h" #include "core/conversion/conversion.h" @@ -77,6 +78,10 @@ class pyIInt8Calibrator : public pyCalibratorTrampoline; using Derived::Derived; + nvinfer1::InterfaceInfo getInterfaceInfo() const noexcept override { + return nvinfer1::InterfaceInfo{"PYTHON CALIBRATOR", 1, 0}; + } + nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override { try { PYBIND11_OVERLOAD_PURE_NAME( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ed9a0bb7ae..9eb6008952 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -578,7 +578,7 @@ def convert_module_to_trt_engine( import io with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) + engine_bytes.write(interpreter_result.engine) engine_bytearray = engine_bytes.getvalue() return engine_bytearray diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 6362e4253b..9a75add755 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -25,7 +26,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def _populate_trt_builder_config( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( - trt.ProfilingVerbosity.VERBOSE + trt.ProfilingVerbosity.DETAILED if self.compilation_settings.debug else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) @@ -193,6 +193,7 @@ def _populate_trt_builder_config( if self.compilation_settings.version_compatible: _LOGGER.info("Using version compatible") builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) + builder_config.set_flag(trt.BuilderFlag.EXCLUDE_LEAN_RUNTIME) if self.compilation_settings.hardware_compatible: _LOGGER.info("Using hardware compatible") builder_config.hardware_compatibility_level = ( @@ -312,7 +313,7 @@ def run( ) timing_cache = self._create_timing_cache(builder_config, existing_cache) - engine = self.builder.build_engine(self.ctx.net, builder_config) + engine = self.builder.build_serialized_network(self.ctx.net, builder_config) assert engine serialized_cache = ( @@ -323,7 +324,7 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory") + _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory") return TRTInterpreterResult( engine, self._input_names, self._output_names, serialized_cache diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 6a2530a956..373c128920 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,6 +4,7 @@ import logging from typing import List, Sequence +import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -17,8 +18,6 @@ from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -114,8 +113,9 @@ def convert_module( from torch_tensorrt.dynamo.runtime import TorchTensorRTModule with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) + engine_bytes.write(interpreter_result.engine) engine_str = engine_bytes.getvalue() + return TorchTensorRTModule( serialized_engine=engine_str, name=name, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 26e0d59b8f..6c15b4b5fe 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -63,7 +63,7 @@ def convNd( ) # Process weight terms - if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor): + if isinstance(weight, TRTTensor): weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index f66bff7c82..03a209e2a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -63,7 +63,7 @@ def deconvNd( ) # Process weight terms - if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor): + if isinstance(weight, TRTTensor): weight = get_trt_tensor(ctx, weight, f"{name}_weight") # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index ce4d70cef5..ffac049140 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -3,20 +3,24 @@ from typing import Any, Callable, Optional, Union import numpy as np +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, get_trt_tensor, ) -from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name +from torch_tensorrt.fx.converters.converter_utils import ( + broadcast, + has_dynamic_shape, + set_layer_name, +) from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor -import tensorrt as trt - def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, @@ -139,27 +143,51 @@ def convert_binary_elementwise( if trt_promoted_type != lhs_val.dtype: lhs_val = cast_trt_tensor( - ctx, lhs_val, trt_promoted_type, name, target, source_ir + ctx, lhs_val, trt_promoted_type, f"{name}_cast_lhs_val", target, source_ir ) if trt_promoted_type != rhs_val.dtype: rhs_val = cast_trt_tensor( - ctx, rhs_val, trt_promoted_type, name, target, source_ir + ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir ) - # Check the limitation in the doc string. - if ctx.net.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - - lhs_val, rhs_val = broadcast( - ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" - ) + if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape): + lhs_val, rhs_val = broadcast( + ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + else: + lhs_val_shape = lhs_val.shape + rhs_val_shape = rhs_val.shape + rank_diff = len(lhs_val_shape) - len(rhs_val_shape) + if rank_diff > 0: + rhs_val = impl.slice.expand( + ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape + ) + elif rank_diff < 0: + lhs_val = impl.slice.expand( + ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape + ) + else: + if tuple(lhs_val_shape) != tuple(rhs_val_shape): + sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape) + if sum_diff > 0: + rhs_val = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_rhs_val", + rhs_val, + lhs_val_shape, + ) + elif sum_diff < 0: + lhs_val = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_lhs_val", + lhs_val, + rhs_val_shape, + ) + layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type) set_layer_name(layer, target, name, source_ir) output = layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 30a5203eed..593094d331 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,6 +1,7 @@ from typing import Optional, Union import numpy as np +import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target @@ -21,8 +22,6 @@ from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor -import tensorrt as trt - def trunc_div( ctx: ConversionContext, @@ -422,6 +421,7 @@ def add( lhs_val: Union[TRTTensor, int, float], rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: + return convert_binary_elementwise( ctx, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index f45d067349..bbe566d0b7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -452,6 +452,7 @@ def pdist( p: float = 2, ) -> Union[TRTTensor, Sequence[TRTTensor]]: shape = input.shape + # Extend input from shape [N, D] to [N, 1, D] extend_input = impl.shuffle.reshape( ctx, target, @@ -460,7 +461,18 @@ def pdist( input, shape=shape[0:1] + (1,) + shape[1:], ) - x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", extend_input, input) + # Expand the input from [N, 1, D] to [N, N, D] + x = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_sub", + extend_input, + (shape[0], shape[0]) + shape[1:], + ) + # Subtract the expanded input from original input. Result shape = [N, N, D] + # This matrix has the distance of each sample to every other sample and hence the shape is [N, N, D] + x = impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", x, input) if p == 0: # norm = torch.sum(x!=0, dim=2) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 3764667ffb..9031426c5c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -53,7 +53,7 @@ def constant_padNd( ) value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype) layer.set_input(4, value_const) - layer.mode = trt.SliceMode.FILL + layer.mode = trt.SampleMode.FILL set_layer_name(layer, target, name, source_ir) return layer.get_output(0) @@ -91,7 +91,7 @@ def reflection_padNd( shape=tuple(new_shape), stride=tuple(stride_list), ) - layer.mode = trt.SliceMode.REFLECT + layer.mode = trt.SampleMode.REFLECT set_layer_name(layer, target, name, source_ir) return layer.get_output(0) @@ -129,7 +129,7 @@ def replication_padNd( shape=tuple(new_shape), stride=tuple(stride_list), ) - layer.mode = trt.SliceMode.CLAMP + layer.mode = trt.SampleMode.CLAMP set_layer_name(layer, target, name, source_ir) return layer.get_output(0) @@ -167,7 +167,7 @@ def circular_padNd( shape=tuple(new_shape), stride=tuple(stride_list), ) - layer.mode = trt.SliceMode.WRAP + layer.mode = trt.SampleMode.WRAP set_layer_name(layer, target, name, source_ir) return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 48a91faa40..4fabebd176 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -66,7 +66,7 @@ def roll( shape=shape, stride=stride, ) - layer.mode = trt.SliceMode.WRAP + layer.mode = trt.SampleMode.WRAP set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) return layer.get_output(0) @@ -83,7 +83,7 @@ def roll( shape=flatten_shape, stride=stride, ) - layer.mode = trt.SliceMode.WRAP + layer.mode = trt.SampleMode.WRAP set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) output = layer.get_output(0) output = impl.shuffle.reshape( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index a4507ece3e..6f827de2eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -2,12 +2,14 @@ from typing import Optional, Sequence, Union, cast import numpy as np +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( broadcastable, + cast_trt_tensor, get_positive_dim, get_trt_tensor, to_numpy, @@ -20,8 +22,6 @@ ) from torch_tensorrt.fx.types import Shape, TRTTensor -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -258,6 +258,12 @@ def index( cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir ) cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0) + cum_adv_index_shape_tensor = cast_trt_tensor( + ctx, + cum_adv_index_shape_tensor, + trt.int32, + name + "_cum_adv_index_shape_casted", + ) cum_adv_index_shape = cum_adv_index.shape _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index_shape}") # check if all advanced indices are consecutive diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index ef30b186c1..b620e13637 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -8,14 +8,55 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_positive_dim, + get_trt_tensor, +) from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.converters.converter_utils import ( + Frameworks, + set_layer_name, + unified_dtype_converter, +) from torch_tensorrt.fx.types import TRTTensor +def shape( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + dim: int, +) -> TRTTensor: + """ + This is the general shape layer implementation in TensorRT. + sym_size.int ops map to addShape layer in TensorRT and returns + the dynamic shape of the tensor optionally taking in a dim argument. + """ + shape_layer = ctx.net.add_shape(input_val) + input_shape = shape_layer.get_output(0) + input_shape = cast_trt_tensor( + ctx, + input_shape, + trt.int32, + name + "_shape_casted", + ) + set_layer_name(shape_layer, target, name + "_shape", source_ir) + + n_dims = len(input_val.shape) + dim = get_positive_dim(dim, n_dims) + dim_tensor = get_trt_tensor(ctx, dim, name + "_dim") + gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0) + set_layer_name(gather_layer, target, name + "_gather", source_ir) + input_shape = gather_layer.get_output(0) + + return input_shape + + def get_shape_with_dynamic_shape( ctx: ConversionContext, target: Target, @@ -48,17 +89,25 @@ def get_shape_with_dynamic_shape( """ # Ger real shape info for input_val input_shape = ctx.net.add_shape(input_val).get_output(0) - + input_shape = cast_trt_tensor( + ctx, + input_shape, + trt.int32, + name + "_int32_casted", + ) + # input_shape.dtype is int64 in TRT 10.0 + input_np_dtype = unified_dtype_converter(input_shape.dtype, Frameworks.NUMPY) scale_layer = ctx.net.add_constant( - input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + input_shape.shape, np.ascontiguousarray(shape, dtype=input_np_dtype) ) set_layer_name(scale_layer, target, f"{name}_scale") scale_res = scale_layer.get_output(0) length = input_shape.shape[0] zero_layer = ctx.net.add_constant( - input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + input_shape.shape, np.zeros((length), dtype=np.int32) ) + set_layer_name(zero_layer, target, f"{name}_zeros") condition_val = convert_binary_elementwise( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py index e2e8481e24..c61aad4290 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py @@ -1,13 +1,12 @@ from typing import Optional, Sequence +import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor -import tensorrt as trt - def upsample( ctx: ConversionContext, @@ -35,9 +34,9 @@ def upsample( # interpolate mode if resize_mode == "nearest" or None: - resize_layer.resize_mode = trt.ResizeMode.NEAREST + resize_layer.resize_mode = trt.InterpolationMode.NEAREST elif resize_mode == "bilinear": - resize_layer.resize_mode = trt.ResizeMode.LINEAR + resize_layer.resize_mode = trt.InterpolationMode.LINEAR if align_corners is None or not align_corners: raise RuntimeError( f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT." diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 20762731b0..0c152e15f1 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -2,8 +2,9 @@ import logging from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -14,8 +15,7 @@ _select_rt_device, multi_gpu_device_check, ) - -import tensorrt as trt +from torch_tensorrt.logging import TRT_LOGGER logger = logging.getLogger(__name__) @@ -56,65 +56,28 @@ def __init__( def _initialize(self) -> None: self.initialized = True + runtime = trt.Runtime(TRT_LOGGER) + self.engine = runtime.deserialize_cuda_engine(self.engine) self.context = self.engine.create_execution_context() - # Indices of inputs/outputs in the trt engine bindings, in the order - # as they are in the original PyTorch model. - self.input_binding_indices_in_order: Sequence[int] = [ - self.engine.get_binding_index(name) for name in self.input_names - ] - self.output_binding_indices_in_order: Sequence[int] = [ - self.engine.get_binding_index(name) for name in self.output_names - ] - primary_input_outputs = set() - primary_input_outputs.update(self.input_binding_indices_in_order) - primary_input_outputs.update(self.output_binding_indices_in_order) - self.hidden_output_binding_indices_in_order: Sequence[int] = [] - self.hidden_output_names: Sequence[str] = [] - for i in range( - self.engine.num_bindings // self.engine.num_optimization_profiles - ): - if i not in primary_input_outputs: - self.hidden_output_binding_indices_in_order.append(i) - self.hidden_output_names.append(self.engine.get_binding_name(i)) - - assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == ( - len(self.input_names) - + len(self.output_names) - + len(self.hidden_output_names) - ) + assert ( + self.engine.num_io_tensors // self.engine.num_optimization_profiles + ) == (len(self.input_names) + len(self.output_names)) self.input_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.input_binding_indices_in_order + dtype._from(self.engine.get_tensor_dtype(input_name)) + for input_name in self.input_names ] - self.input_shapes: Sequence[Sequence[int]] = [ - tuple(self.engine.get_binding_shape(idx)) - for idx in self.input_binding_indices_in_order + self.input_shapes = [ + self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] self.output_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.output_binding_indices_in_order + dtype._from(self.engine.get_tensor_dtype(output_name)) + for output_name in self.output_names ] self.output_shapes = [ - ( - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - ) - for idx in self.output_binding_indices_in_order - ] - self.hidden_output_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.hidden_output_binding_indices_in_order - ] - self.hidden_output_shapes = [ - ( - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - ) - for idx in self.hidden_output_binding_indices_in_order + self.engine.get_tensor_shape(output_name) + for output_name in self.output_names ] def _check_initialized(self) -> None: @@ -142,8 +105,7 @@ def _load_from_state_dict( # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() - logger = trt.Logger() - runtime = trt.Runtime(logger) + runtime = trt.Runtime(TRT_LOGGER) self.engine = runtime.deserialize_cuda_engine(engine_bytes) self.input_names = state_dict[prefix + "input_names"] @@ -212,12 +174,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] - bindings: List[Any] = [None] * ( - len(self.input_names) - + len(self.output_names) - + len(self.hidden_output_names) - ) - + bindings = [] for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: logger.warning( @@ -236,11 +193,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . contiguous_inputs[i].dtype == self.input_dtypes[i] ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - idx = self.input_binding_indices_in_order[i] - bindings[idx] = contiguous_inputs[i].data_ptr() - - self.context.set_binding_shape( - idx, tuple(contiguous_inputs[i].shape) + bindings.append(contiguous_inputs[i].data_ptr()) + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) ) with ( @@ -253,26 +208,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . # create output tensors outputs: List[torch.Tensor] = [] - for i, idx in enumerate(self.output_binding_indices_in_order): - shape = tuple(self.context.get_binding_shape(idx)) + for i, output_name in enumerate(self.output_names): + shape = tuple(self.context.get_tensor_shape(output_name)) output = torch.empty( size=shape, dtype=self.output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) + bindings.append(output.data_ptr()) outputs.append(output) - bindings[idx] = output.data_ptr() - for i, idx in enumerate(self.hidden_output_binding_indices_in_order): - shape = tuple(self.context.get_binding_shape(idx)) - - output = torch.empty( - size=shape, - dtype=self.hidden_output_dtypes[i].to(torch.dtype), - device=torch.cuda.current_device(), - ) - bindings[idx] = output.data_ptr() + # Assign tensor address appropriately + for idx in range(self.engine.num_io_tensors): + self.context.set_tensor_address( + self.engine.get_tensor_name(idx), bindings[idx] + ) with ( torch.autograd.profiler.record_function( @@ -281,9 +232,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . if self.profiling_enabled else nullcontext() ): - self.context.execute_async_v2( - bindings, torch.cuda.current_stream().cuda_stream - ) + self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream) if len(outputs) == 1: return outputs[0] @@ -307,7 +256,6 @@ def disable_profiling(self) -> None: Disable TensorRT profiling. """ self._check_initialized() - torch.cuda.synchronize() del self.context self.context = self.engine.create_execution_context() diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 1765077930..f998ddb27a 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -3,30 +3,27 @@ import math import operator import warnings -from typing import cast, Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union, cast import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch - -from ..converter_registry import tensorrt_converter - -from ..tracer.acc_tracer import acc_ops -from ..types import * # noqa: F403 from torch.fx.immutable_collections import immutable_list from torch.fx.node import Argument, Target - -from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks - -from .converter_utils import * # noqa: F403 +from torch_tensorrt.fx.converters.impl import activation, convolution from torch_tensorrt.fx.passes.lower_basic_pass import ( trt_transposed_linear, trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -from torch_tensorrt.fx.converters.impl import activation, convolution + +from ..converter_registry import tensorrt_converter +from ..tracer.acc_tracer import acc_ops +from ..types import * # noqa: F403 +from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter +from .converter_utils import * # noqa: F403 _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -323,7 +320,7 @@ def acc_ops_pad_with_slice_layer( ) layer.set_input(4, value_const) - layer.mode = trt.SliceMode.FILL + layer.mode = trt.SampleMode.FILL set_layer_name(layer, target, name) return layer.get_output(0) @@ -840,7 +837,7 @@ def acc_ops_tile( shapes = [1] * len(dims) strides = [1] * len(dims) layer = network.add_slice(input_val, starts, shapes, strides) - layer.mode = trt.SliceMode.WRAP + layer.mode = trt.SampleMode.WRAP set_layer_name(layer, target, name) if has_dynamic_shape(input_val.shape): # type: ignore[union-attr] @@ -3536,9 +3533,9 @@ def acc_ops_interpolate( layer.scales = [1, 1] + list(scale_factor) if mode.lower() in ["linear", "bilinear", "trilinear"]: - layer.resize_mode = trt.ResizeMode.LINEAR + layer.resize_mode = trt.InterpolationMode.LINEAR else: - layer.resize_mode = trt.ResizeMode.NEAREST + layer.resize_mode = trt.InterpolationMode.NEAREST if (align_corners is not None) and align_corners: layer.coordinate_transformation = ( diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 49bf401f58..510d4ef69b 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -1,8 +1,8 @@ import operator import warnings +from enum import Enum, auto from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -from enum import Enum, auto import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt @@ -20,7 +20,7 @@ TRTPluginFieldCollection, TRTTensor, ) -from ..utils import unified_dtype_converter, Frameworks +from ..utils import Frameworks, unified_dtype_converter class SourceIR(Enum): @@ -351,13 +351,17 @@ def prepend_ones( # compute the final shape. if has_dynamic_shape(tensor.shape): tensor_shape_layer = network.add_shape(tensor) + tensor_shape = tensor_shape_layer.get_output(0) + tensor_shape = type_cast( + network, "shape", name + "shape_casted", tensor_shape, trt.int32 + ) tensor_shape_layer.name = f"{name}_broadcast_orig_shape" prepend_shape_layer = network.add_constant( (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32) ) prepend_shape_layer.name = f"{name}_broadcast_prepend_ones" reshape_dim_layer = network.add_concatenation( - [prepend_shape_layer.get_output(0), tensor_shape_layer.get_output(0)] + [prepend_shape_layer.get_output(0), tensor_shape] ) reshape_dim_layer.axis = 0 reshape_dim_layer.name = f"{name}_broadcast_final_shape" diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 4202e1e96b..5bef21b6be 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -1,18 +1,21 @@ from enum import Enum -from typing import Dict, List, Optional, Callable, Union +from typing import Callable, Dict, List, Optional, Union + import numpy as np -from packaging import version # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch from functorch import make_fx from functorch.experimental import functionalize +from torch_tensorrt._utils import sanitized_torch_version from torch_tensorrt.fx.passes.lower_basic_pass import ( replace_op_with_indices, run_const_fold, ) -from torch_tensorrt._utils import sanitized_torch_version + +from packaging import version + from .types import Shape, TRTDataType @@ -35,6 +38,11 @@ class Frameworks(Enum): Frameworks.TORCH: torch.int32, Frameworks.TRT: trt.int32, }, + trt.int64: { + Frameworks.NUMPY: np.int64, + Frameworks.TORCH: torch.int64, + Frameworks.TRT: trt.int64, + }, trt.float16: { Frameworks.NUMPY: np.float16, Frameworks.TORCH: torch.float16, @@ -45,6 +53,11 @@ class Frameworks(Enum): Frameworks.TORCH: torch.float32, Frameworks.TRT: trt.float32, }, + trt.bool: { + Frameworks.NUMPY: bool, + Frameworks.TORCH: torch.bool, + Frameworks.TRT: trt.bool, + }, } if trt.__version__ >= "7.0": @@ -89,13 +102,15 @@ def unified_dtype_converter( The equivalent data type in the requested framework. """ assert to in Frameworks, f"Expected valid Framework for translation, got {to}" - + trt_major_version = int(trt.__version__.split(".")[0]) if dtype in (np.int8, torch.int8, trt.int8): return DataTypeEquivalence[trt.int8][to] - elif trt.__version__ >= "7.0" and dtype in (np.bool_, torch.bool, trt.bool): + elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool): return DataTypeEquivalence[trt.bool][to] elif dtype in (np.int32, torch.int32, trt.int32): return DataTypeEquivalence[trt.int32][to] + elif dtype in (np.int64, torch.int64, trt.int64): + return DataTypeEquivalence[trt.int64][to] elif dtype in (np.float16, torch.float16, trt.float16): return DataTypeEquivalence[trt.float16][to] elif dtype in (np.float32, torch.float32, trt.float32): diff --git a/pyproject.toml b/pyproject.toml index 5f681f5a15..c307381ef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt>=8.6,<8.7", + "tensorrt", "torch >=2.4.0.dev,<2.5.0", "pybind11==2.6.2", "numpy", @@ -42,7 +42,7 @@ requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ "torch >=2.4.0.dev,<2.5.0", - "tensorrt>=8.6,<8.7", + "tensorrt", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/core/conversion/converters/test_conv_deconv.cpp b/tests/core/conversion/converters/test_conv_deconv.cpp index 27baa1df5e..faaf7f2474 100644 --- a/tests/core/conversion/converters/test_conv_deconv.cpp +++ b/tests/core/conversion/converters/test_conv_deconv.cpp @@ -126,13 +126,13 @@ TEST(Converters, ATenConv1dWithWeightTensorsConvertsCorrectly) { %5 : int = prim::Constant[value=127]() %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5) %6 : int = prim::Constant[value=6]() - %7 : int = prim::Constant[value=5]() + %7 : int = prim::Constant[value=4]() %8 : Device = prim::Constant[value="cuda:0"]() %9 : None = prim::Constant() %10 : int[] = prim::ListConstruct(%7) %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9) %12 : int[] = prim::ListConstruct(%7) - %13 : int = prim::Constant[value=1]() + %13 : int = prim::Constant[value=0]() %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9) %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5) %15 : None = prim::Constant() diff --git a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp index 785363ccca..5550d5409b 100644 --- a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp +++ b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp @@ -10,8 +10,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) { graph(%query : Tensor, %key : Tensor, %value : Tensor): %none : NoneType = prim::Constant() %0 : float = prim::Constant[value=0.]() + %scale : NoneType = prim::Constant() %false : bool = prim::Constant[value=0]() - %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false) + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale) return (%3))IR"; auto g = std::make_shared(); @@ -36,7 +37,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) { graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): %0 : float = prim::Constant[value=0.]() %false : bool = prim::Constant[value=0]() - %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false) + %scale : NoneType = prim::Constant() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale) return (%3))IR"; auto g = std::make_shared(); @@ -62,7 +64,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) { graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): %0 : float = prim::Constant[value=0.]() %false : bool = prim::Constant[value=0]() - %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false) + %scale : NoneType = prim::Constant() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale) return (%3))IR"; auto g = std::make_shared(); diff --git a/tests/core/partitioning/test_loading_model.cpp b/tests/core/partitioning/test_loading_model.cpp index b42368fe3e..67c42caef2 100644 --- a/tests/core/partitioning/test_loading_model.cpp +++ b/tests/core/partitioning/test_loading_model.cpp @@ -7,7 +7,7 @@ #ifndef DISABLE_TEST_IN_CI -TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) { +TEST(Partitioning, ComputeConditionalLoadingGraphCorrectly) { torch::jit::script::Module mod; try { mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt"); diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp index 62bae5756d..7def168249 100644 --- a/tests/cpp/test_compiled_modules.cpp +++ b/tests/cpp/test_compiled_modules.cpp @@ -58,7 +58,7 @@ INSTANTIATE_TEST_SUITE_P( PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), - PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}}), + PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}}))); // NOTE: ViT tests are disabled until Python 3.11 issue is resolved // https://github.com/huggingface/pytorch-image-models/issues/1946 PathAndInput({"tests/modules/vit_scripted.jit.pt", // {{1, 3, 224, 224}}, {at::kFloat}}))); diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp index 4cb9dd9f8d..cc9fdd24a4 100644 --- a/tests/cpp/test_modules_as_engines.cpp +++ b/tests/cpp/test_modules_as_engines.cpp @@ -29,7 +29,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values( PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), - PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}))); // NOTE: ViT tests are disabled until Python 3.11 issue is resolved // https://github.com/huggingface/pytorch-image-models/issues/1946 PathAndInput({"tests/modules/vit_scripted.jit.pt", // {{1, 3, 224, 224}}, {at::kFloat}}))); diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py index 035b957865..e06239eb4e 100644 --- a/tests/py/dynamo/conversion/test_arange_aten.py +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -15,14 +15,18 @@ class TestArangeConverter(DispatchTestCase): (5, 0, -1), (5, 1, -2), (5, 3, -3), + (5, -2, -1), + (-5, -2, 2), + (-5, -3, 1), + (-2, -5, -1), ] ) def test_arange(self, start, end, step): class Arange(nn.Module): def forward(self, x): - return torch.ops.aten.arange.start_step(start, x.shape[0], step) + return torch.ops.aten.arange.start_step(start, end, step) - inputs = [torch.randn(end, 1)] + inputs = [torch.randn(1, 1)] self.run_test( Arange(), inputs, diff --git a/tests/py/dynamo/conversion/test_erf_aten.py b/tests/py/dynamo/conversion/test_erf_aten.py index 3f52e436b4..d9d201b0ae 100644 --- a/tests/py/dynamo/conversion/test_erf_aten.py +++ b/tests/py/dynamo/conversion/test_erf_aten.py @@ -22,11 +22,7 @@ def forward(self, input): return torch.ops.aten.erf.default(input) inputs = [torch.randn(x, dtype=type)] - self.run_test( - erf(), - inputs, - precision=type, - ) + self.run_test(erf(), inputs, precision=type) @parameterized.expand( [ diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py index 8013768214..7f43234211 100644 --- a/tests/py/dynamo/conversion/test_layer_norm_aten.py +++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py @@ -24,31 +24,6 @@ def forward(self, x): inputs, ) - def test_layernorm_with_dynamic_shape(self): - class LayerNorm(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.layer_norm.default( - x, - torch.tensor([3, 224, 224]), - torch.ones((3, 224, 224)), - torch.zeros((3, 224, 224)), - 1e-05, - True, - ) - - input_specs = [ - Input( - shape=(-1, 3, 224, 224), - dtype=torch.float32, - shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], - ), - ] - - self.run_test_with_dynamic_shape( - LayerNorm(), - input_specs, - ) - class TestNativeLayerNormConverter(DispatchTestCase): def test_layer_norm(self): @@ -68,30 +43,6 @@ def forward(self, x): inputs, ) - def test_layernorm_with_dynamic_shape(self): - class LayerNorm(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.native_layer_norm.default( - x, - torch.tensor([3, 224, 224]), - torch.ones((3, 224, 224)), - torch.zeros((3, 224, 224)), - 1e-05, - )[0] - - input_specs = [ - Input( - shape=(-1, 3, 224, 224), - dtype=torch.float32, - shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], - ), - ] - - self.run_test_with_dynamic_shape( - LayerNorm(), - input_specs, - ) - if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py index c49fc32c23..795a78354f 100644 --- a/tests/py/dynamo/conversion/test_neg_aten.py +++ b/tests/py/dynamo/conversion/test_neg_aten.py @@ -22,11 +22,7 @@ def forward(self, input): return torch.ops.aten.neg.default(input) inputs = [torch.randn(x, dtype=type)] - self.run_test( - neg(), - inputs, - precision=type, - ) + self.run_test(neg(), inputs, precision=type) @parameterized.expand( [ diff --git a/tests/py/dynamo/runtime/gen_hw_compat.py b/tests/py/dynamo/runtime/gen_hw_compat.py new file mode 100644 index 0000000000..e279015aa2 --- /dev/null +++ b/tests/py/dynamo/runtime/gen_hw_compat.py @@ -0,0 +1,33 @@ +# This script is used to generate hw_compat.ts file that's used in test_hw_compat.py +# Generate the model on a different hardware compared to the one you're testing on to +# verify HW compatibility feature. + +import torch +import torch_tensorrt + + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + +model = MyModule().eval().cuda() +inputs = torch.randn((1, 3, 224, 224)).to("cuda") + +trt_gm = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=inputs, + min_block_size=1, + hardware_compatible=True, + version_compatible=True, +) +trt_script_model = torch.jit.trace(trt_gm, inputs) +torch.jit.save(trt_script_model, "hw_compat.ts") diff --git a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py b/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py index b10cae23fa..46a9ab392c 100644 --- a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py +++ b/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py @@ -25,12 +25,10 @@ def forward(self, a, b): symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1] ) - # Deserialize the TensorRT engine - with trt.Logger() as logger, trt.Runtime(logger) as runtime: - engine = runtime.deserialize_cuda_engine(trt_engine_str) - # Inference on TRT Engine - py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"]) + py_trt_module = PythonTorchTensorRTModule( + trt_engine_str, ["a", "b"], ["output0"] + ) trt_output = py_trt_module(input_data_0, input_data_1).cpu() # Inference on PyTorch model diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py index 29bd17cfde..fa87c9947c 100644 --- a/tests/py/dynamo/runtime/test_hw_compat.py +++ b/tests/py/dynamo/runtime/test_hw_compat.py @@ -75,16 +75,14 @@ def forward(self, x): "HW Compatibility is not supported on cards older than Ampere", ) def test_hw_compat_3080_build(self): - inputs = [torch.randn(5, 7).cuda()] + inputs = [torch.randn(1, 3, 224, 224).cuda()] cwd = os.getcwd() os.chdir(os.path.dirname(os.path.realpath(__file__))) model = torch.jit.load("../../ts/models/hw_compat.ts").cuda() out = model(*inputs) self.assertTrue( - isinstance(out, tuple) - and len(out) == 1 - and isinstance(out[0], torch.Tensor), + len(out) == 1 and isinstance(out, torch.Tensor), "Invalid output detected", ) os.chdir(cwd) diff --git a/tests/py/ts/integrations/test_trt_intercompatibility.py b/tests/py/ts/integrations/test_trt_intercompatibility.py index ed3d906386..2ee3f7bf7a 100644 --- a/tests/py/ts/integrations/test_trt_intercompatibility.py +++ b/tests/py/ts/integrations/test_trt_intercompatibility.py @@ -1,12 +1,11 @@ import unittest +import tensorrt as trt import torch import torch_tensorrt as torchtrt import torchvision.models as models from utils import COSINE_THRESHOLD, cosine_similarity -import tensorrt as trt - @unittest.skipIf( not torchtrt.ENABLED_FEATURES.torchscript_frontend, @@ -37,18 +36,19 @@ def test_pt_to_trt(self): with trt.Runtime(TRT_LOGGER) as rt: engine = rt.deserialize_cuda_engine(trt_engine) with engine.create_execution_context() as ctx: - out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0") + out = torch.empty( + size=tuple(engine.get_tensor_shape(engine.get_tensor_name(1))) + ).to("cuda:0") bindings = [ self.input.contiguous().data_ptr(), out.contiguous().data_ptr(), ] - ctx.execute_async( - batch_size=1, - bindings=bindings, - stream_handle=torch.cuda.current_stream( - device="cuda:0" - ).cuda_stream, - ) + + # Assign tensor address appropriately + for idx in range(engine.num_io_tensors): + ctx.set_tensor_address(engine.get_tensor_name(idx), bindings[idx]) + ctx.execute_async_v3(torch.cuda.current_stream().cuda_stream) + cos_sim = cosine_similarity(self.model(self.input), out) self.assertTrue( cos_sim > COSINE_THRESHOLD, diff --git a/tests/py/ts/models/hw_compat.ts b/tests/py/ts/models/hw_compat.ts index ab43e5e040ef3850c8e0b879a973dc759565830e..ad974d16578f904c40be563c6eee0bc80f53abe1 100644 GIT binary patch literal 110482 zcmeFa*{kk1wz=nXMzTssRJCPD4HtHxK*pMiRlBix&q$D6f zeCLPAL*woj7KmEIZ_ou)6{I7rd zU;eZ$eoM~B%i^oNT6``3a>)*s{Gb0c|HJ41`d7(szkQu=$@aJ3{_^?$%jWr)^tQNc ze>z94K3Buhk5T-ibJu^PMz;N@|Gedi|CT4NPwDQozWxt?`>4MphsFKAzW?Rt zBmZoETK@U^Z~x z9)5qR;#pX>5-xTc+Po*l_h3Guo9=|k+k`?1vs)aS=<^XItL{X!I6`R-p2`2GES;c!-6mmdn~-{bsoO`ZBY zyId}7<Vqx;qm*~X219z_V=^65VY&%uT63NTt9vKRc)pCMNT6<+sl^m z*hi&${_^wre!5>)!*#@Ii$aTAV zs*V<={vdhg6TO=+)hz$}=f6|nzbysiyQTZF-U*EAqOIxLUoAqnhDjnzk zGxyHLY+ng#X`{4v^-eBE!}Y#i8SaCUEp`KcxqqGCmHT8f?L2oub$CDPkx@K4H_>V6 zy_r9|=Vzm0goWsI^p53kW23#SPW@7>>D|N3k7}jvtsOnusdhgvFaFd?#Ii9f!d|I$ zGfMuXG;rVbT^?mx(JDIrY_;DT74P^Qmui#vmY=Jm?x<}~!kyOai$k8NMCiTbsp z)=1vMUxo*3S*wM|R6E$l#viGF56=jpr#`dujST0EH_j8yf0Q+UqW9wct7VURrK6X2 z{fp>q`?ajID0MoW`05|)THb#?j36uN>DkS9&#${tj8k22^~}}C*WXraS)WaOpY zluwN0w5&ANBI|U@(``P8+x30j-EI=?-a9|@Zck{&ySx+YQgn;^d76fuTU~6PyV7~k zpFBFxnlWnBW14xl`z_D+)%)nZxOLB3V1>a@yOgz2y?zQEZFoti`9g2^;1}Z@wf8qK z^Ve=USk~%+7i*rAcKlIy=l5G3?{e3U(N_A)XRUdwd0woyTiyH9Kip3B{o9!ai&Fh^ ze)!j@QEtCtYdbE@FVkUe>p}CKUpIkSaz|d%UmvTZQG1&1j@m)r4fA2xYxj$K3w}I( zofdg4i>6U&wB4Sjh1X#|cf68OEVAXabGUcj`Tj6;-j&X9BiiHlzMs#!Ezz!2wyfXK zsTlb&?-}pba$!ZN#JJICFlgXT-~L-N(%w;;UGPK|~1t#WPuuGg&IQOm@t z;YWVxE?e_{I4<>{{=m2VXnU)cjNH_UNpw2x@K#^di?w!t z_RGU|^PU;^Su`4!R*TA)A6~M7(Lw{Ion2Ya{ka%5Z$-3fJhXdi_Vaw@RhN}rui4Vh z%Wfr~=lABjVmi;Fl!(&BA6?DIzIr~rS{=)Zd;VHzx$(GJ_p8w_@*Khlxco&YA zukOOXHBQ}qHCaVY>HZ~p{v|$6`mJ5oyK8f?JNoOLyID4(`}3}4V}Iz6!uZ-QA~(Nk zj_$a_O6fFs*Nb;9I)lni_ZRN{u&*{7p;^j{UNCHl>*8IjKjf=4mXAvw+)e$~>gM}% z$=hqgeBDcX?dQqani|q+n$KRl&8^nR>Wj`=@1=g^_KkKcd*)GF%6rQn$DR27eh8^+ z+odb=v)BC^)!wrn^uJ|&f7&}AZqLm5aylie(#pMlnzbwDd7ZV#J3=VznRnI_AbzZzx3eXZsV zBCVYd>wYNm@cm9#`+LXP4AbZG) zb-bQA%;_^8Rh-jRv>(RooF|1SMn!O|SK8j!G0SJq)29Ebt(Qrs-m}N~^1MXC#;CfMP51Rd$sKEMWiaxm(e(7L>bv?v3vZjgKY3@HQ8gHp z;?B$#xt~^wpx?YNADOmI$BBQNg`!{YxVMVdKDC)ziE%sQIQS2 zT6^_ndVZMd+jebGi>9UNz2W=g=Jm3#zAuAo$22=d!wBc_3M8B7J{nz+8 zo0g8*B*=qma=blP&t{hw9+KbNk>eTa{5ofoemGxsS8P>F6!`N=I z4wkRDq{*-9qqM%imaks7q4YcJu?wqmUl+Qbwd%J<#b>|X+I4B^_fo%q9Q(7^`T4Fj zzl@<)4A=Q1eY%HsT@U2AR9_D?e;FkMqn^I{(^g|y((>?Cs*h@~+H9xQ_x31X?+0Sp z?j2tl{2#pXdti>*&(^}T?=b;<#^TMoXuiCp#|D{{IQgq&;J=!mP_J7d1z2kd1>+JNc;m2-a@2fqn zUuc#wt@{EIJE{e@cS8%B?!b?FEp1u5NvEWRr^4}v&CRr5ZBF7!M-<7`-xu%6{^`VP z=5O3%e>%3;x0*h16Ky_h%zXc)P4|`JDcI27Yv$X_&Gy|YCdZ04IXm6_F&C?TrCwZz znm=g35A|@?ucnQKqt&v(*$;cm5&Cg*9%`klY38~9ibl`cc=^)Y*)q#lgKN64RW8_% zS^v==Ci?a^%G1}YRl2*rn-7{nx7~1N&Q2?)*((2fH5cgnUL`B#d3o!HdJ*5=^JwX4 zFXw*A?XEw3S8F5mckSHI{m2*;`}s zY4hNA$tS_{u&iH?SxYUCr}2K4O;IZO9mit7?FQ|}(AZ8(y8f)>QPW--HDp{| zS`V)WKk$d^Wqlz>St-h{i9gEQ*ndUrOgpPXA9**s_m>SXhrbu+NPpxnZEMu`SsIm` z>b;uZ9+hRgkUsj?4j18Y+gRN2KSI;#lsg0LDaVjTYY=%4ZJlME{C=s-`_*{0s%VDe zp7S_;8~sXi)E1rieoXyU-81{Gs#$tyvuh*tPvcf}YXk?Ztgt?K`YV0bGukgPuW4sJ zo8-IT;u=pat6jC`M!Nn=ZkNV$Rlm2i+I{Bb(c={_+gI^Uw3EoT`OF=+mvv)lR!Ym; zv6Ihq*X%!zm%P)N87Js1uB%ba8D8|x&g_2q^NO@?4SN_n+I}Zje$vdC4`5Z=Df(F@ zkHvlX{>n7yIm6n@hx0G_Zj`#iM+alOo%dFC>s?C0 zyZt8W_H3-% zk5X?l>eu)CeM2)J(^1~do%j2_^_N#|D)V(dS>Er<>gcOK)pW1B7+&I*`L2z+MWZBa z*MIEY)Wmq4{nOyXR8p<>kYC*+1UVXAdXSmgc>b zBUF0J_NsavV{fNt|9qZKmen*FJ~U4^JAN>z4zVwcsHb%=rM@3L>vGgw*s@W&l+1hX zcebm*+Z(zxn z!u7&mcUY6h@@Rd{r$IV8PAaRK_FBAv9An{p`u_c1@e7s|3)P1LAO-E{)BIqCpWi^0 z={Jt4aMPck=;wj%D7t;$s?MDhZmKWT{XeivPP#Sq`!ja=J1qEf>|J%==VNQtRaaJ5 zhNyqW@WEikP#u5crJq*7mW)r>f5d1%7Xp!b!WnUFuuy#dFxSuT-~ap?Tk37g=4UNG zYpXg|H27Tq!-PMFpNoSA4ZiztpE}uJk4=t#|NUugf5xiDA*X*$lMk;!st+2>b=xh+ z-O@(2cRv{IXFObhrfPihgC+k$meu@0*cn&+X`}wcGxy{jI!}hk`Qc~&#^3ZMzG({W zKmGr{a6)hKWvuW z&-j0z=lC$ z{yy!Y#|Qf#Si$f2|LDkF+tsf)((23O3b_Br$2z>1qqLaSw_nStwV9vfYgw0@RjqlN zTiVfkyXxO&zQ^FJ#c%JRJ>uH7zNigeN%i}9Ru3;rd-7h)E{9<^eGX2ae~npvwM|a) z*ZH+!efMj)4ew_tq5Hi`pV54u`>dt@`-%U*zrUC7vsT|UQC+M&G~ef6%FrzpwFk ztIybwvVZ+iQ~b3UKQE~%vn&7l^N;WUt2V~>ukYdfpwKrTe#dsyp8Wp(Y50=|LN6E( ztCL5HPT+!B{A|Mz0_!u5O#jIzy4(MGpGdg=9cTK(Fv7nbxVm<>Rcz^RK0UV_j@kIB zDZS#;*HNT;NpunJf!G#yqbm$kr1rfl{P9p&!Wt`)#P&y!(HU_%=Gs3E&Pg3*Z(BQ| zHyoj_noFW@nm*Vhwe4e$6z(6Jwf^z&m>)Z|4-e~Z@JMQvhG%Jw<+Qq;)#gQV)Yyls z$GjGP_Swyvc-F`z|#P2P8%X%f!5zZ`q1^`3y%|1J5hqXP~ zuV&$4R{J_6$L3*aX(}EGLA08%w45<+lj9^?3z>}S?qO+LM=>?aoV(WpD?8N&`SLWV z#M5TQ^Bp9$Ru=Z|HMb|b)pUi-aJS3p6v0@8LSD8?A72Yy#?$A04&0sC=}mMR_JVj{g#JVZaCax%g(;LiCj^cZ6x!Lfi>_ zvIJH~9kUV~+xO3Lp^m%Gr;R#qFK0?MnOmg2ZUdMlVwLhZJ}#WkB1*c2al?)-jL;}v zGIfZqihVha#aS13wq_vuZx%WbhO`2jN4c;ZGgOOjniwrXY^vv4f~X|}+wnjiYHpqH z{`JH(gF@&?XJ?d{M{%m^fmx-I35@+6p8cpBVc2U1Bopp4MQ7A5(!eup;-b=ks(lsM zHHwH^8a$8=2lOE(92OS*E2L7Q>5O4XS&W4xXKH#=PMQ`|>PkY@x=`0e@cCnXc5J}o z`0b{_dh<X{>2kcLl#-)`3(l!Lw~7y3{Z z#*Y5HHS4BsODvRE!)eDW>3lqrzdyn(TnVEHKM0wl4$sqVT0AScq|t_TuFz z8k`Mv&dECj=H=e)#yWJwKDCOcqEQ{ceWB>3jAi2I$S098tGbLjn5bgPdN5^%0;(XA#8N=NH zLogC+y4{l(0r$vmp-e#Np-cK_U1C!S3%QNPUkpewAmd@u=)N%>HjHP za$FH}R#3S(j@cX@3uhu)LJa6eT-mNl0;XG@Yev@tBzld?9(JfBI(`GGGv*oRt!(K* zCKhUOA8HOsz;i4C&lkb7#gJ6@tU|x0+%Xr-5{Wx{vWv~?o4aVBkByI%nF(cSc00n_ z5J{<%?=h3XYnT}z@Zf7h^o9jDG44}Kms!L`sT!hwB*3<5=@T5^> zK@*NL$I!4~Ye5la=L=ezz%s~NleQlp5XY3~!Q!dJ){=#p$CfAbkvofrMj_cza;#hq zv+K&>!$-I!!kkJ?^<}eVVq=uby#qU$U8f#QXnbS+;e?M2oxmuk)Mgy4NoJZ|80yX8 zDXk8Wg|QgyjV>#j$|Wn2Dj9P-Y5u5)bcUT|5VnybqQk)w-qJl`)>x+HSy(gnd3XBI z;bHj5i0h~kJzEF96ss3u>M|pQG)zDQ+1g0zACeLnlXa&$!1Q;{iZJFOfHdUt6yoBx zBsL7R6lHA>M|<65AJ^a}o-55{Kk)`5E~23WW$V~If+<*1(p4HREHUqKC_7R#BKF6w z6LbO*i5mAv*NAmkWHe~=+y)XsSUXwBHND|BLw&_8sqX3FVCRX@P4-DmxtC^uWKK~K zu#MR*ETbCtg1nq<3WU=xg+{^-gVpQ9H39l9^E@MvF&0eNbXg0#qq9zU+%Y|b^oTQv zbrruRtqFQ~=F#(1@!?s`oBbb8XK1TPOd3!R?#QgQM6p>8vMKf^YL~E|>;T1P$JtsR zEYvBnU*sQsgx)t&UD|=;8GT2164EAD3A=wc!4Y@DCe#t}`yBY~2SDSL6>IZ-(7s zYDf}~0!EjG9jmC!GN}1x5!_JNccIT(5Q4t@GX3+cAcm8A*eCF^HT*$JAO`JZo&R%NQlY zPmrM|ItPU(1h!1ZD%jmPWruf**ii5R&DV^efY41HI*c-IEIfFtVh5Fpvh- zA~riBn>j-R+B^Hs{Ihqv&K}t&L~YkxD^+GSxz(%yJSYgpVvUE!UC6c4`>ypeB|Zvf zVzLirYtcYI?tuO*)h;lV0ft&^Ae( zNsu!46*B1D!QOY4mflc0NUVa6#c#t>PF*+l#m-`LWlcC6kF?P0ZXm}W>=OBXkD^`jA=9bdX&N4$+@_aE-<2I2b!tkH~x}?i;&_B8n z6N*kbR>q~R*3a-ZoPx#FT!INZ8b0N(F&>F6l2jlEjd-dD7iD28Sg_9a#=;a$GtDBT zRYj(8k8vyg@<#B%N3sb6Zj)^?M+0GMl6&w(-)$>PCmqu{eZZi!uew1RBBd*%XFz-S zpN^CqA3Q5UjnPD-E2{uOaWNUU4920-W_F#WJ=1nX#8R0ToMZf3OIJfghZ*S-LB>|k zgo|v2ws?!C0bT11;LLYT7I&~SIg^Su!2{dC2jxtYiZC<%`QdFA)~Q&buoxEl%bE0T zWQvZwv(C!1i=MDIT2)tk>8w-L?#;BN(D|UhwR%EV&U?<^dGT3r5 zr&uc~C9+rARc-sUO<9vm3%$$lp6;l%WdC@Dlg>H7vZD{D>!doiPgN}n4W7st&TVvo zt>K(9qX^Pui}A@g2HT1wu|pQ{U9v4)>}}E8-j#WKuu}v$&J}~d+5yM<`J?Q5+oZ?! ztCe2WxrtJM!!YPe)~2`h`douI?>1Dtn!y*e90$IUf*>AzwDik(7ft$X8v6oBp6^l7 zR`(d>L#b!!?<)6XJ%5Cqs-ImSAvg@u(oprECMDWleHkT(_57B$z1j^Ld?F$n{cEDb_7v^bSJ{1>3M=$I!d>q}Z zG=w(-Ztgk3sZf=CBp(LJvi?5qJ{rR$F_AEyA;vN;3jI>e;d_ zJaaTxd%T&jCq;b=(>0W9QIE?Vo$n6kc8o6=1RWFnl`*PK$8$b@g`9JJQ!bQ!eZlt& z)f6|n>FM0Abeb0XWLOG`6d08Y2g_6*mdL@-sXM8wlkJt7shlPj{n%w71~GT`$+nXz zCc~>@K7bHi>`uR2m{FfV5q9{tkr8a+dV}k@sr9W@u0YqWw{mr6tHiw!ZO(TYa@y5V z!_rIGBfU>wtdAq_5Kv5QtS6JCMj?qD)7IXpw&MlYmXWR6W#8@KL;Fx=y+y2?B`ahD znFF)T{!yDx-D|xT+#Aoh9L2OZ%e2wO^t(L>rThZdrd{-D2j(pnGoJ6r;KcDr39vAj zG#2oP{^2b#a@F{rjL-JWHIqQkF;T^LB2w3^5T8bV9hq-x;vF#(xCQgrm{brHy6E;<2;2>5HzNv_NB|4 z9V#87+nG!)FbRd@C_0+`w;ad-(@Ed<$1f3#qrpqOI?sJDh6m(P0#dE4c#6HVl{(jU zomUnpFVl*8x@o5%A@~|f7BCO&8Jgu&j}IzF^K2j={1cYlZ(< zoHW_Bd{azQG2Y>)rs#inMsfz&8+>9;J-DEcUmq%e(H0m1QKEBI5(vi@lF51q&-R4r z8bRYMA{#r@WIt{O#xaoA*i56tX>t@pr{^HvgVVsc2GYoaS@)!TWAQrcFL0p zdkv0*#tpfP=`XkiC;S}n6?b<}rFr&}R5F*iA@u|Wp{^spjKpwX+0W}H9%57M5_Lb^ zRm{+161yptX^?WC?!hobrpSiK!96vS`Iu?LJ-bK1= zq-9|!D-4}A&vC|ahOu9Z5p${Clqc8DMkZj89)#=|pG!YfsHkuw?m-gvrJ{!@0^nAg`BmTp<&1iGk zaIKb*I@FM+YU}dbGi7$1*-b{rvI}p%q|f12-V=D)JwBi8>o>N935GF?_Cx?gtuWqc zFtQ7H-&5Getp5-*B}=iZEQbVsd>$|xbv136=2OW#QAp=B75Fkm22vo7#+QhC6VGQ1 zbuu~?d1hw!8IF}t1KOqsZqU_*weDloV8PBG>|*5lq)BIUtkdjL!MRH>CK^FxM>!m05F(&FHj;*AUE3iZ z%rB77z=kidT_lM4)>fA6YF#M*{1TG5CA!P2!Ucxo19*x@7r{uo3uCwwcbay4Iw^5t zHC>EXNYE${)quUEY~3@@o3 zju*&*x*b*CWASxPM~8!ti$&v_jr<>r`xN`zROtWUnLNCS=nsoYLj}GDTjYH%>?x*} zJq=rquompGNq_HS3tA9!JZAChCbHZUU&nNEFqdEm4H-Xa2lkUqu;uC3%smZ!CN`V% z3`=u1xhlUq8cz^@PCt*ascXWvOQS#!l-Gyi z3I@mtRwUiPs}@YT;hORRZ=jFJr)Y(_3|s`XuJ$1IrmMy(Ti@SCQ@E%h+LvgbDjTZH zdMjCv-S+v${Lgx38U}jUF*rU#kFHJqgb~cuvrzXd?p1vx#gA|JB*!SsKKGCc9C()q}ze=9TTKA`jJ(^6ih8{sKMJ7aax@j-8o%kB-VN>hJb#5n5+X zS4`R+PK+I;YcrW%Vd4AX;XWf`Vs?r*Y)rl)0u@9Ne1Sf;?ug%N%za;{rPAXl=5E9S zv$VtM>a?JzfvsJa%nz096j-%|NZ0+EKL)iDK!rxaYP3C8ZM}<5y5M%XTPVpzj3l2m( z8iO2@4T4wvvzW|jI}0uNPMivZea#%-jDlM;MUQ8MV>B=?{rQb133D{*OF6N)^&lLy z5Btd&KZPwOjQJ^=vfHpQ=+|ZFD0^s7x#J+Bw`=pR>Sr2RuW$?yP8W+>714rt-YPm) zzoH(`!7A$2tsbPQFF@CkGmypUA|sp8_@er)gw>Ub0U(e?eThk^zUR|9GFc%IR5~F8 z2Qw47Od|#t^314;n-h<+y~1w}tJ(a>vsAlUqQFcAV{~n#0W@<3AP{^db&J1bj{wn+mI1iKAZNr#1AbPJq^KpwCwQ>Q*W;gKsF;oez%7PX5- zIu$l?Eu9V12miH5w(x^EY=-)9*T)|=Mxo=R08O%Yb!2BOBe5}e=d}Ay5MpZm)3dsc ze5!c5W738|%Z9YMb0+{(?(Lt55+%2RKh!z&Z*O14dFI?0KH4Ft0jD`fRwL(-7V>p2 z1<11Az6tj^oExfpl^mv>Mv@BWD%h#&T$6KI=%{m{iL1A}DBD?gDW;T9G&)&4?GLRF$nqBHCRT(L!ObJR2G0>Se}}FRlRg2^rY*&CkeOE(-k6JYpJd{uj@kd4a3Akr z@zhbaN#>mKsr4=Roi5#m6FEwUAgj!ZyUUPvluR%WdzMz5bS7MN&+dFBy01go(?BzP3*YYU$5XZ#f z*!Er=XH#=M3OWs7tBNS9?pxfUm)h!;{zyFF3si`Cu#6|@xS1OJutMNNHKaP^A-u_J@N$f-2Fed|fz6J8;}{f`!7M(0j5ju0 zgRHSvI(#e&#VZgG%NjWj&$a1p(o_D9o_-%I#iwVnDU(SLKh)9h@mZcHJ;sZ}8z&rt zyiS#u3oS!J2 zkw%~J?!ieIGonTnUTGmB(!;Sj5SnvG14CklTxqf$WI9D0P_MI5b>qx|BO>r)4C#Z;z- zob>1`^ycb1wv2Hu>FKNh5$j2XK?G`_F=fo5LFV)&vWF$!C;LeCT~TxCslKC66#J|f zz;+}a;queJnR-uAwziH+{mdltr5B2U8Z$=942O#lC1%_Ed^5=i0c8pB-#>eBn0ahB zXdq1{rP_;D7@7!O2b@|{GCKOct|qPA%HV3IpFW~bAK2Byy%?!B;M}Us)hs^ON9wxe zjnI($!Wx}5S3c2yafUHVeiMOWeI$2+FAeu;*5$t5Zp;Lo%nhroC{Z@! zoVwws--Egi6Vlc^viP)(QTWc)-F~hgKi5@XnR8OtT|H6PKXdEtk~PN{_sQJ0_pBqw zWC{ioDi~Z@N4sNhPN-*;q>|vf@!&uuM*4gMv7n~0MQwrl6D4S$ZkhyPyT5*$8w8k1 zauuw~b&RdlAznp%j<6iY=XXEdk}?1bC6E8AUlUuN=NP2F^6A)WcJXfF3T9Q zcr5FmWBB$D7{jlAz`1!2{(L%z=8UmeV_T?a?Xec)&s-P-Zf1b~8f^B3lG6d6YB9qB z?gmkZ)U9mYFWg}GdhUoxng%p6LSRq=QP}(nXC*p*qR+d7ExXq7!`1kVu)}7|5^J|$ z3_s{&;56fBFwZI6NBWGQRF)X=86j;8!+vfsQR5}yIo&gGDusMCbW?8-qCg&QDm3ee z4p8(bX>*EUXezy7OG@a+sYe^)t6^WsQQ~f<{bb|H6bfNS%NTh>H5cQSqB_6ONq$mTO#!J`+YB2@)u&bWDr>J5C*36Rbemrdwy$8IgCS zV%;kDWKTi@k~zs%x=}`ry;Y6_^QdF3H1_G5AQ`C+Mg=t(Bas>xM?oCe*j^E)p6x~J z9Wt*A3*!zkseR7`^+qguApFh>6lZn@)3P6%!r`qH!PAzGch(uB(J`Rj3h$97tnA1Xj1#hsa-kX_9J)6A`G8l**GJ zA$7c4W0M7iS6DqKTc`6#kb{KjMjXp|s$4z~2aQZZOae)qxTAU`Oic&Hpka__rUn@l zD`0W<<$;n&PCK2>(Exzz%_Rb+Uo^bHek|X$C~XDtFJN^vTWC4U7d#x z8sLS!YagBx6RDeDJrD}z_u*>J#gyEp8o3-IePA-(p+|4q?6AMzLwl_U`WYIYiEV{ktX&#rD(5op5@kn`s|Y| zumyGU_@KxR`DL5y$kZ$3MzLrL=GZ!ggWwx0`)wf}NFj7ENzmB7NXRSVN&U7zzvIt! zoBi zwhO|7>|-$jTj)o{7Nw>FG~mnbS{9%ZLgv17NQKp{7+S)2cjt*bvjk>0G+lJMDu!q( z&_fUAM55JetvX&5WcCUR@d7Q#4L0QyKQi#xO0|BNeyQgK#{K4mfVl=ZvWe`(!RS)L zZM|MxWsF9G<|DpLo{SZRMK?vg8>Ew&!~&Sh@EwfFKS3rKD!F{?O8g3?u#BMZ>A)JU zio-kMIz4x>07Ure30YGl#1Y&sn7nBAqskqc?fUvVRZxMjZGBy!2f7ug&+?3dzPnYi zEoB8d6+I|i)((4_TrYVi^p7mjvBS5R_qG?T$%6nD3Hsor!nltulC69hU{Bzm$n#vs zuTnUUE6^9DdH6;$CN%&WINc;2h7pyJ!GTPkG@14n&`6)O2zCc@WRlv~_0gRrT9#Cfbk*!& zyjf}A9IJgwK3x#izy*6qK8Sq4>mv(kn;9wAf(T8OU$Na1F$?TrCS(YmAS1z1(}bP$ zY>cWVKQB#Gf`8Y_N!D=Xu3aWd6n1yy;>3rlB+svlj3I6(<|b~KH1S7C@g|4C-Wpdh zU!!A=BLWE0G`Dv>=oX_uWe1#;l|qh#18#QEJFMVqNc@egLLvJvWYG*%6AQ4}L^}kc zUh|p?Y99vR$m37GM-g~tL9Pkw5!J|NgkS*EiOKJcM5>D|PsTbD7 zW8qz7k%NI^l2=Z=kOyZ^cb1aZq>{Uwl-C`wp99;n*^_IB@P0Ns!(iZYMDI(~E$}VK z-QA^g>?_?auxWR9S07>Piwr(?7V6lG=grSvPm4Km0^&tZ&wQ{b0sindvJ2x3uEc6% z=~Od;&&Z42C^->6=`E^RY+ht~3HsrTeq!7eGfdCQ&mS)=WIv;@!jW_Di(@K|yroIy zUxvy>g`IZh4!pjHpN78s!i!|FVE=|sd}cOIce~;)s8w@}YEQ2#%NA885+tQkw+?=o z(j5r~3fM~k3Lj*K(s$!WO7t#9z$<^#idsm)Efw}3pJF!K!E0M{7Z3v$(6IzxI0#!2 z@q!*PSe!GLd42SN6O+K=;D|^F5WwQB5$3kJW&ZM0ozggRp3}%$Z>qNY+<~4F)0y>F z_&)dt)klobR5^U{0sd3=bPfIne;tQO&W-F-`BGq>Gbf$I9FRZA5S*aokekdMNTeTf z>fnYvTY%F|yDBb_tQ8>*s|bC994hA<1C5**VsH1fm&Um3X6(Wv(F;1Od zJ$x&I@z49*7=!Z)0;JyMyOBw6Co3QS{9^F=1(9-ks-| zATIfrFTg^V*zK$4RvIE?K=*4Eow?+MtXN~T#q9imw+Ti`-U)10wv5j**>&WV_!OKJlaL>9N((~dAJ{`n!&64(pNycZ)2P?Oq+rV?vsEO@0R<}>R< zCT}lr$m-Es$WYyD5R)LMP6A;neZaHsXP$L+1sd8g7A!e-arsk0A@xpXl-Ltegn89I(xKX1mPZ? z=E0B|y4<5`FUHW_7y7YFjR*>rQc3m?Mi;X;(u+!e)vNJb@KO%Ejvn!E#2>{i)?MXZ z2R*W7ohVK4yvJSk?;y(TE*ghe`X&ksB*!=5*-F9dPh)qp%XFI5NWz`CAS_F`cmjMWZ%wM=*2x%IgP``C;sn* zk_K@t_8JHBf&e4!Z_%TuW7Q_ng182OxTvNiH$X8Z1msl1eSP|=Ubb};$6jv%k zi%Etv&KXAfD0Cv-QeOsG>AI(M0(Sq$qC)k&by01Czi8fQ}lM>3x3^hyKo-yHF*NIfgVF@t`r79;()m*Gp0bbYAv%!shje&?w&akq;0?7%$c}Dqjt*8L~#q&;vX3 zxMq&=0<1-XPapv*F`HG$Fo)L#sOq|)o=_np7#Y;rP;G#`WPRAvDt8zYFPW&^RX52!C-W2^jUQZ_M1u3SU05GE%jt(UWA|5s}_-{T;!BhZ* zGVqin{Ct4pejaisfpVo`{_r75@b5n=67Y8)eYdGhyf7syGg5i+K!f1^c?x91v6ILL z*$Gd?(T(hnDn7_mPKaWV8KvgN_lH(OZ@w4ZfvV%);u$^|^l3eVfe73DntNt}K3Pgx^A9W-|vL@yx-xd-tmCz>sY_P=Yc$i z3N_=4>GP1(4z2kJ;W(n+!}fl>;{i}ww;JBq5n7PEFvd4L>k!u`mH9-w&1ddoK7DS~ z`>o*-f0fG3k}>k0kAtd>pPAZ=>)Uy1ZGBaLKB&mKtL>My=9`-O$I4Kg8YYw0hVp`x_}0tK=t8V9=ub(q6sdnN+o0x|-k0)++0nf=9{xx3z7o-d9m; z&q#KqrWvtZEBtEV>jq+UnzDbz)x+!N16HE=RTYUnl^X;1;8Jg3(7sv;j&Gl1x>m0C z=a?G4U9X&VjN=F zcG!Jvlp_3;)84?v2xM^r3zu)3jAZ&jPOFq|RU*ZJr4xPkFHQh`G)UQxF4?<0xE!cObTWQAftuFj2V}tNIvKiwH{W1>h^#vR=AEDT6CU7GC@by#}3l$-s9_`G2A2ys0qQA2sKuiglRK5Eqs` zvvskU=|tsnqiR%bKUq~(==8JZoF!v5cI1PBJNjkyY7Xk6NNL7I=L&VZU6p9$ZOU7G zL`8>}5dh1K`{ka~>2f=m(;7tx7k~wLTjk$E*igrAmYG|ev7B`kTV@!97B&R)3{#a= zr^M#fI|mpqo&`1JFzfEGYS4FVZbiFd!}}3BWy(>(J{*(ei_iArJshl}xznjwVjE}r zj<^jqo@A+DFgZFFJO3jRg%HE9U`K&Lq^!wEd3{I&sT7d1*@z6kyrtNBAAb1MjURai zItRa&dUB$v_%Pydj`6XB=TP-`DxdIGu!qReQnVYDyHDst1Y`|n10}GFbZVNVI`|=w#HNF63C#AIz{5$!E zBLt(Mik2Z#^F%Qfkpy-pJIdpCh`+CpS@@w?rxhJ35*Z^!yP&&PRb+t<2VntpKZyx{ ztUFir%dE_=x^pX>-=#_TkkSD4rlUy7TOw!&VPp1<(;(S@F+Xd{oWZZrWYi^aY<_HQ zSb$n{@N`y7Nf3mqS;g0_sX9>ov3admde`;3E^r?2f5v_uO^HEK3W5CtEJMD|pqMPs zKDh6P#&CZp=-nggjxSnwwnJhAY9A?#MpAD$eCI47Ol2xuR@Ecp7^uLNPPuNIeo+SJ z+1a8s6g0{;2i(t8b>?U44X{1ImXm$9g|8;-lrv2`O{TC6LIs3mo++*m1 zMIe!hxp~D#B!fwM(aWWtkK5TOH$*fzXq zvP9rD$fcuqSRc^67LKb@XfhkZTTK4HRA-**KkLj@T{_C_$U7-EQ_3Bwq~-O1NF-!l z4|9%S3NIIW@8LVlpckA;Ws_gp>RF!}GL_&#QN;3#$GVykR;wC$(P+(SeDsxv32;8&eFSl6#Qa{?R~L#(ONe2SkHP5yD6 z`LfIk7XQ^cbA$bs@K#r6gA?8%E8f{vb`Yg_>?d7+@=tNGxh{*AW>i< z9K!jk$8a%e7-<_1d5@0Q!(>Z+792`WK*M?yTax{Ag$Ur^)t9T93{_uF0zDe$Q(q34 z&R76)pNxu{F>Hn{I1iCY+&q`SOH9qw%(1LbT`jLn2M@NN>U}3rMNL(znrc!X(N*wX z4eibEV4|!Asy^vU^i?gq{kr`?Y_^7~RHrwzl5KocMa!inX20WvDzH&QPG8ip#P5goB5*_&@R zG|pJZCeR?vJM3Ygqa_=?n}B~&?@b9iIRFZ0`O>#6TsO&8{!;3L@m>!X zoE4zfYK!Y#jXtVbWh&510Ty&HpP&J>Q6zEXp3&7hdfGgtH zk9uzK4$E5r>o|#e9x6RtV<~v4EP@MDM2p)IdnG_YQhOk*NyDyUtKbtOCqH|R+zb`t zi`MeK1PlRV!-Fu!=t{6^6Z=wDmNyoT+#dD9xGMak4EQ6wxEgYnd`tF*Gr_z?dS9Ai z$Uqr++yT#$RhXE_iJHk!)!R@@F3Ue^$uWw)Ysq=*@9(wb0hlNqFh-0!s0;rRxP#`F zbMF34Eja}lpIY+lM=g0@^6Clr(i~!7lOxbYM6mv^20uZ6KXFRZ3y2Mk(SXkJQ%l~> zKK0z7V=B&EBexKnLh-45IccVV1!~F;a1}M;ELd;bz@`IpC@`AzbAnLg9PbX{oSlBo z#R^mDZDWN4QDgR{+Ktt@K~NP0S;TtGp6fahNqxw%tREOE$1x+xPgCyAFc<;V>1SS= zM?dL-1p;@;STz0I<@G+C+Y(lyetZmcsh{odb0LN-_2U_K8EFJC6*@?A4ip=dj!pR_ zw;YzBcAN<2U)PRTu~B~2j`M#MK<&7CchXZ;Yf&EdhVA}PJ5F7Z`KcYprBUXD^Dbk1 z@;V70bn5`#sQeBoT1X(ECTjE?5UL$xg%^yf+IlXzuT&C!-ZO#&hZ&>$3C@@&$vz z#+RCrK!AiKhC2*cfWX+swt!#X?-Gv8$ZyE1>aKdaFFLBe3_=Kv=j^kGwf3r&M34k) zYmbAPnu{t>a*a9$57i}E0S|-U)M^$Vhlg%Wxqrh28(wN{*ZR~T9>ZEvc{BAqw=4x2 zevG@c*Um?cu1w~QBhC6puD}Kxg&zX_+u7a;6JYKS%(E%n@8HG>AA+4OpTRlebugYW zzhhE*DflyQoSz=a;jALJMNC`aJEnN4%VXRgR`W?0$NI4g$g^;43nCuRI2_N8{bi@( zGJ}2Nx>)zS;j|}@j@5_r`@)J$dERfFC$RD=-^H@QgB$q2bH@EGHCM1pI+1hWeCCKj zO;-S`_xJj=t>`SK4=m3i0&(BU|!MkXvD zT=9=SkSsDSY=={-?HB;xsFy>N9zTq0aC07$osr>XoMRjknkx&>n%h*h%5d+1fexA7`F1YZmv}4Y^M@S}AAYhr8*Er4DlNfL@%Zgkq_B z6S(1S@l@G*lI@Oru4iQ7*x#d5&Iyn)&m1;!9$n!uy7E23$iEH}r&`B{(;JsNjKb^S z!1oF3NrBDcEr*Ofiq&`n#)Z5%zN~9=3>rmnOFX;5m5TF)*;n!K@u*Ux;^B9U!J#1j z!O12FqWVjI5Z{}3;Osg%egm;1zvuqCrPW}JSgiJwYy{w$I2=2yQxWTsDDpXuZ}Hye zE$40Cn1#&6Vca;k)A%K)ouYWdz=SR5FWh$?72BT~Lt~Oq#;Vk{*g2Uf##7oqF(P@V znS~j3h8HewF3dzE6B_3y;xMjU9C8o0B>jpguu?1mYD6+RgR9CL#ZU334@C#6&Hg>h zY0hDEE}T=~TA^;cgALqz4qB~K=Kh{1nL1uzW$rw8raSUm)wP2^H{nPJc9y|;AWsJr z$amlobDV5Ul%;9dbAl%Q&V1PIO z#>0oh9CIy;^^MOQb01{kUHAkU{^)r9!6Q%bVo70n!&rAgRueAsWrbI_?ENIG{O8pVml+Pic|_(nKj6%kmb>D8 z2N&^}ol8wMapL&K0Y2}LUWlYmBYcX~;c-ml%sNAw?{Bl|i$}g=L79IGf0_OSZ{&CI zn}k(OZBOAo~Vtt@`*HjNP zcRr^6qP1_3?gQ~DR|h~|TsY}A%?k|8T%>rsi15M9E`pTTfa@h?q^>E?H%@o1LmS_f&cZY9u(bOfi_0U&mc2h)eB6A1o(08#r#|S%aY@P zH3xfpEC=AK;7gG$hjV3;!4oYYIc~#*g=2d>52&K`b`AsCe-D$Dwp|7!Gw3EoAR*hXOzQ3NI-nB4wtKeiDQL6J|0coY;P1knQ%79?J)#tAZ1+aY@5v~~ zx9?~jMJPsJMFj}#>ypc1-&E&+1MZ-0!R_~iUDdW~zkz=DPMbJwvk7b1Z-8O<#Vf|8 z{)1DzpsM3|!nx;nN|L>hR&jdgJKi8|W-e?~3_wqaE?ELIOcQY%y5$f{cLJJ@TTy}+ z7@p7!q<1gePLBV9_rq>G4_}nA3aF@8-{h}X+#ftry7RO>ij1vmS}8|QxHGV+2w5T--nd*tL@xd#arAJ*l4!H4deN?N4&cpW0!3Q#x0yU^;`>H z6K{wRW)%($`*XPaEJBtgW zsJOpfK>xi&kOj(uFx`4J^>GwhsFqj zXT2HDacl-fdpPHSbCX!Tc$g0o4*6c?7WbpcrnHfG*$N^C-x7nGn`M)$6(0(TYdZ9nkm8ZN4ABnej^7Trhvy_M&aH2i3NK>@VAj=i9a?Z5yx#vaMbn&%PaR zK=p39N5%zgiV2P-eIK(Ra~wZ;*KanrZ7JW6--2YJZTn%hUwWr)y#%I7uXuwZ!f%Vm z$h@q23371ah>YFhyN|0LXpP9@K0yS(|J1&(@X8gDesT@TvUAE4nzv@ zI`sP)!>=)dcX5Tf+YX=oL*F6~bNkJ*FOJ}$(C_0ud~IVloAv>u{(N8bFV$zQC-42s z8y*4NlzVWl)~7z}cW?de>DT^g|H9{uLxH&yw9He}-PS+g=N_mQT>85Rt3XAJ&oMpi z@OiegxUOh>txc+Im`QI}Xmk0dA&U&A&iFPN3( zD?}-xVyEHn>a9A@fn^DDSiKF`?O~AAANU*n&TlGH#ru4cN_Ob{N@3PKrsqwbO{0$sM1u59le<5Y|j)*O5(*_;!!Qi2ei2 zeZBL$5jGQksbazQi2R~EmMy~Py25!I8;_*@&hEJSOk&tMy5mh`H zsmWgBBG^=J&C{r--gg#`JxjAYZevepN4R+r0|9nc%(H{<%eDKo;Kp=k{j)yLtIsi~ zr!%S2WLZcINQbs6w*&T*&#UTl%>H2#SYF3yD_88(lAq^1#a9`5@P^It$|h8$JZp?x zqyfz=70B(KW6vlKjxVHUN-uVsds&NYYI8**&?PPTCmWWV6yZt?Gcx`|6k zCV+hgxHsCzWL%NFgxgC7+r;qbJD#UQLri}6#2k$=Kz#pb*qajIlA~_!9@}SJ}zDK8i_`HJ#Kq3&vbiy zb2HAfE#Qdw@4Z+)%JH4z0o#rl+YUW(Mh>`2jL+VN?+a!CeF}IcbamqXH|59e7~+}u zMeoc$gcD|bI>PO33@;@u2mNT*;5w`l*8w0blby^Uv7vuIx|9q0CG2 zebzNTw-~5}2?8$euj^8X`^xVb|ywK~zuP@2qIE+iz zO$mn1)q5A`H|P9pP{Jg7N%*}BgOJZ3tiW7X$Uft_gDjg!Qf^%Nq&ughMWx)zC-uOu z;;TXmis(TCcgj4$8v=jFJ=lE4FBSfg{q$X-^MF9XSEU!dxVzMfy4Ct|W@do-_pngf zPFcI+)6x&70^GP~MW$Tp<}{=K+cVuOY_Ht0jk)C%AsCqgfGuU%sH@&K*acEt=`$D; zmo{Lm#ljk;n!dzVgxvLte@ihungOr{21hS}?SV_6Di?j~8h}Vf&*a}6F9!K#tJeaq z4+|;pD5z7M&~;ny{1!^yGN(6+_JNznv0e8025&hF?^y$S#)x^T`x^{cfayUom{6T( zpA#`b1g6h1#USaDj^v$rRznpSZH$@N`z|K$3rCOM)N3D;ibJ`J$vcp+vB-!J8SbQ{IP zRi@7H-c`ZxqIpa+ga^a8%h8*RdR_1cug4{HSg7%IZqzy+aYNikl+)7^IK|LZpD_kZ zvT+^d#vQN;rN0KRXC`JowueMC_jm^!gQC&;{K+b;bFdPuJ>!Y@BTldJ0_PY1&7${X zv@FY*{3Gj*+z)(a(@`#U+?Uu!Y&7TMD+?pjra~OkuO?<>upDy!d|*4^G;j*g4;h{t z$Pi}-xFjqrup4Z*0AJzxSX8PY(XL}Wxnuks(6gh2dUJQ5kIXKGc^as#d1_d?QACa0 z`e-V&0Bb|Oj)5U=QxwK}*Q`Bv6!gFzz$1X~3wWAP6o_8ecgz-g5lVQPc8{SOv+pCd zfG5_wHY{G|t{oQye3DV*b^>a8oD7OTQ5@!6C6dY3y7zwv!(w;$QFitl@3 zMu3D;3g!y_cgb%{&_Wd$&vPrQ%P%?*P89J+nPBiP+wfp#_^n7;0 z@pw#gA&4Ba9u>rx-aG%u@$JGj_~H2CV|~G93U3@=^Gtk5>+seb?u0osRr3~JkA_38 z=S&ND`HZPZz7ug-7$v2~B%kB@f* zvFtx0Hbbch!-8+zjrzE8*cg@TEWuhq^D*o})mYS5<;q1p`FUnxO=6LSv{z=N zF-B$P`*)77v0^^vU@N&tA`|+<6IvLfdNfwOmB`~u^?Ksp1^c0Q(6WYTe*``^~h zxwX>QB4?#(~_y4E}`CTcxa#Q@`h(@x6i`na|)+*v%KLD)U4b)Ofui z!sy0P#TVw`4B-;0=>YlE50-YGDadNFJ<_7#g&Qk%6w>vVMOA7WVt*}M;p_sw@;v)g z-3^Rc51WaQ5vv`xb8wF2oHL5y_p-VWW3wo?H7>hjUcSh@uuQ>*lEHKxDh2$?rRHw= zxNlJm7}+LJjKXOb?IhMy)&F6E_)h}G2EeZuIkdl^b_DS@h=Cl;c#<^r#gI_5aGX#2 z`&xInfIKvfxE?-@GiNX+!qa&z2Q7oyV{QU>dm#m-MhQ{W@{L|XM-Sq%F(ZN$&vMdH zA>9jrNVCtnG*|Iy@s;yQeL(y=2*j#6cT4IThX10^ z-X&#jYa2c_{3Y?voKIcAGd*w_cjyCA8}D%zz)WCv!LL%SSJ4P_FdnYS?z!CmKXZJy zj6F;NaU(Y8(~IE5(O;j>AGZ3}Z1a$MsO9B+-gzGP%{l$hL%#w}m_>`}EpwRR=BJyJ z@qiVEJx+8wn6tu)f5jO2o#VTK1ZV!0<2xLu3lCowpuk1Qr6tkI#p%EfR;h*g6mgpc~?xca>(%;Vnv%#_p8*GV_bE0u1_}dA&*Liq{Kl z7AYGnQD}iP}6WxZeS{yZMdh z8-n5fgXe33*~eJ?!Se;>&y%?Ee8F%tgQ+#DKjAN)@8HJsrO$wHIV;c+0~N*@!L)?H zqO4d0YiS=+BG{u0O2CrAwg`0YuX5#=|hbL?l5DM*WPWgJ5PSec4r9L zem5I<8}iU8?Z~J3qQeuQofw?iRh?CNS}`KLOUZsQgKpvg=@vuwJ~&(0S$9 z2<{XPiXZ3YSkThcTGr8M5QQXabH_ehUwYA#jybjIEzPSuZfU7T%@j{Rqn`s>LJo6T zeKzRrFK8O;iSY4%k%2LmP_-P>ggLlT_=c7D^gGz6S z0jX%7W5MI2V@y6vc#n}hWDo+3udWC$2dWy}33)%OSgKrwu11&KA+cE9101g29H5{b zAFi*gi|1^ZD017usNme#0j*t#*{9Tmdi1#;5df1+j4Ix5_@;MFf6}jG#jmVGhY8{q zeCd1yKi6^#8Rl!-HHR`*D5al%4;=|7H4xFD-ne?s+VooDEAP-GZ|++mML_JL0_d zYr_k*!tY10Mrge(KHIwfRJ=tW*>vI{o?s-m1WZfHnoAI*;jgvTjK!Pzk2t{GG7i}r zy}p?yMCn@Ltil>{Cx;zA)Nbg)uS`cq+Tgvwb*8-&c?URt(TEvdHo+vF7~?HmeNOr< zZZKTUi2;|>neYHEBwOHeW_ibE-6hZZPGgXc!I}$gWyHQt0m2jd>*wEcfOGpqisD~! zfR827#iH^>9QuT-C6w_wA^u#WF;!Z!^^u^J;6mU8Io1bjnsfAsL|hYY?k;^ZH}r>L zzA!(>?h7sZ@e5n;;i*#`pW8xJ^x2TW-N5P(I`C88DQg@ZZ1)*e!wYHUsV z&dKwKP0B)0P}l$L($^(#Zo#L8sEruM*Y0MDbu98rULP>G#@Uyc~ALVHEJt~;i`R_Qtcs=?5Cg+!qY7TiSYkOR4To6Ng z@Ou0FEpBk1<{*rNlf&otF5VhFA+Wg}F%<}XbFf#^_?UV91HQfin+H9BE{4I}-tpkY z@qOZ55bId~Wv*{b4)2Bj4!kwM>fh*sU?G6lGk;~wqOp4>9cNf~Gv!^e{)ut=!q&2) z`~%l_^l!Mn#%S*tL%^AL{C@e}lf|Rg7b|JNOW|}!k4fD}p>fA7;pF&qoZorCFjxzU zGO{JQ7Gq4~I?5t^n^+B~bl}-_OlL}RxcJ$p(u(ZyQ_WK*7{!Fb?phsE@(u}RHUH%n z?jiaEyJdg?v1A#izE_=7_!7f@42k)e>m0+`f;YyVq@FA|$6VOac~4r7=Vd++`(g6@ zXM@nBr-+JCzw}z~dVM{|kE33UELQ{;i($Cz1Ex>q+OY|LuCx8QuiOhKLACIo<=9 zUaQK*gyhE%mLn=MrnU#4gC3L8jC9B`61aL{vC4&jTT+#+xYqPf3|+#kM752oKN}to z&deUKkb_9~iX{UMZ3>Q&$596NfpoZWsfc=q0T|PNrzgFFoc9+!>4Eg5qwmgfM=x9k zl|{WrY*veP=|b&aaXZL{4w;Gwt!cnFzVJz0#@HCt$bg4YowbTehxCrIv}8%>{sES$ z2md3d9}cJp*60ZiEp?B)Z`L3$T=7{`;;=kYA3TLg2E!>k2 zNMY*|ZSeoG4W^b=PTJ99GDnof=yrjhiG`=*Gsd1Pr+fHWn$j$|q?{~RY;ZD&Lh`@?tkaGYv zgibW|EfpKNE@@0}!0Lx&VHppJ32&KEX!`WkCdxBbEm%F^lNq!Vyey(x4A=N62ly3) z`&C||G{c*5=~c#iS{9jmDDk4$-Bk_($vEDJN33GZfDDf5^*XF`f7y(CCVs)E&Zss} zml98kAz)6{%AqT?q+cIg<77hj_^M|>uU!ish);`O8q(xB(2!;hsz(mVQE5o~_h$A% zwIfdH5AXn%AUEFBhb`(Zw~m?*@+_IXASK|RJ`Nee8dOof!td9N=?O1MqK}T0+)xf9 z3C?asU4hFv+cnv@TwR8{H>?D)wY|R#$E0q8?18681xZhlab^Y%<21eFDgU9H4Ra4^ z8T#d5FRcEV?j4|j)yUB)hE?yNkYJP+3$i{CTMW0wGoWb2b)cEOnSFD8Sphft^f1(W zk9fB^CFUztpt3O$KepY4OJmIDyxJcPX>c#%T^5W1z!<`3!1*DM439{y3jYHQX)mHQ zi8pI*=7cnJ-$a{Ni`Ec+sknaW#=MYLq{ijQh8;0DIKOB}n@=%s2OW*l02uQ3b_nl-$bGmB1?Bm#LpUG&b**vr@k|6veIY#GxpM=S}) z?m&d)*lq`|A&r+j2B)bgw;Inp6GxmWoKy47?8PmY@q09uk@<~p4|937oRf*tGfMcK z4J-@XOHBSZMhSr5@&q=e-=kOTf}clRw-b?<>xSh_m0Zj$4sTTS2`tRbMXV>Yhk{n%YN+Kj(Dk0-In!I4f&EKA8YDm0 zrRFGmbEErS{^&;+f9gk{F`85eEBlFlG+HC^KWRsUuUrawFlLkpdkv(OKNJW4%i7T| zhFK$f^%w1E_$q&+9ZiCFxyz?RWuCrv#QYP+0M%zM06Q1FxV$&Y3j;SCfF(f%gsMX} zqw_k?%5|{Z7+ck9sd3Q{9n~#ZMPQET7f;-urqm=E0UeM0+p|l|#zfReq<0-$U%r!w z6Ub;~J1#DqT<#gHAo&AB++D0Dxr9C~0{L0gW7y&hU@wY^RRT@5XNg7Ma|3(I@!w$F zv5nOwW*^t`m!-MrFX;2>uuh9r3YCnEIr~t%jb$H}W9uwWMdl2)hCme_D$yllaOfLQkXZ-T1$NC>VyekG1ERaHSbDZR*U5;K5x$ z$Y3ictlM<`;O|O*Ii#nUIASYspWFv+FBn&DDo%4xLt%9h9@!7z=5Q)qD?O$y)(X>5 z$gOICc%2dZbWL)aTyt!nx!t7rG^~4NQ9h`nYx@Yk6o01)Xmi#$kKCZa*jCZj+y&|* zM6Cb@ZF+uZW!q3#C8qiFgkg^Dk}sq=MTX|zCIg*H1J$6{&ZIGBHJtMHNUmdashLi76AHphpi?xOZ-XhK(p^`IBVRCDXu_3EG`Qv094p~jnvwj4#&TDO%Ro~q*+JR_ zKqbT#N`xWV(GL(_Z3LF4h=$9PYBkmLmaSuaN?(iv7u!wFMVijTl6VEa5j6pL#5vhz zeQrX;eN|~?(3%}}A8@~QLtG=rzzhPn-~uKQzo~VE(|MS^RK1gpX}KZySCd|^F`C#l zDmGG}c6gqud(0mv3v;5S`J@TYIynws$#`CPgWt5j zWxk>}^?5zRK*Io15NRbI=Q+b|;zkV<7wE>1T+v;s;5 z%-(`GIp!{@cmuY)=|MxV?2Yfg=^nkN84yoe{ z&zJP{{5RMg*KthTPCtQZrh8|SR^OemdyX)O&O3LRsm=q|n|l`Z0Eutu z++Ys*JmI?WM^E6>#B*S~Rc$A@V+18e%Zo$h(y2$xnCy&K z4&Bl>xS*KJi_-#Xh(79YVlv-K?#6a{VY2>q(dp7t2VH*QJ~Q4rRxQ5_Rvl?d08fAq z+!*F|L!908GWe)hwtm6dIPJz0m&YGQGZ=Q|Tydf_opIUdgOF+oI~%*OK~n*(zjN#Euf)2CZQ98Adaon!`Nph{xNP9oDIcby%Tq zdeE#Eua$-~#mSk`Mq7FxyLZsmE3PZqa9SuJ6xsziS*tY}w>cbG7hH_lS3&ol%fJru z9YfoVzA!PEBIY`#IC|`{?Lhj$L@b&`*kavpF-Ne~%Q+-TpzkwE-i!$f{`l|eaYprx z?piPJNx4OPm7cZmOYEyS^(=a3mkymvaaNm*`D&f=rW(0XG@}fTN@Tnqq7z(Q`y0NI z&-Ouv0t)F9Ly}iu-%5^?A<0O?G;DnIT4z7>D#&`UuvS5?ckOfGv-ZNpSEk2mbwYA(OCEyYGTF%{It+`R2-M=!!^MrI1$VVp~(Pu5AN4I#jkrF2yD;_|3&vhP+4qT0jQs%+ig?0P-hU4O_dC|3hJQ1r$bbQ* zp*vDB`+Zc!OvXZHJ4jj0c?=HtdT9n@^QQ}4Q?~b|4a|6LjMM#zIn!L=e{rpcn4u4Y zA+qu9K_f>1fX%Fq-$m({ignXV^3L&xmm9!?M7uzIUSgOw4YJU%+(xZY zdVk~e5;?MsDW1QvR^kPL&QbpiV&i>pW#}dmC0&2$N#DHg3BL1#V;POI2wup#N)-q zbI!3i7ATFfy?3r*D3Y#yXZR0iRLa2=fUf}KcKR;mmr_3m_rHSc;(Uxf>jn1*kMs&v zqL`VT9ZL}N*5)mgjc2F*mcq;2H4OOGc?Ztgg7-Ea+70gG*aFa{?KL2C+CCg6*l2Nj zt9Mu2@$TUHV&}HAhY@QgWlr>RSK<7Dj?`L(KD>E4JTsReWe)9RrQe<+EpR5hyX>69Suw?M306|n5 z$RU5@_gjn*gPsXPQR;`y>_AIe{tZqN%zTK_X4Mnyow+%w(-4F2rr?CCb3hr<-*lb8 z*(T0NjShWhw$tmVD7iO9N)S~0XI_l*!m}QIjwmvPLuJe&v7muDT37Cwa6)C z{ViCu!%=+9H9gv;3U=yke}1zmKlOPb-qXhU(&xcu`PBo$cRObEefun3$K!VoTQk?% zf)-5=JA2!w_@Z~*?em8H`RqhnFzm0ncb!f8bov$BhVy^Hi(#1+H#Bg>b95sxe)jBp zuw#40FAluHmCw7`p+DX-n_$k}`yT^w+cB=)YufV3dj&)#@MF&diS8UypE}vf*Ltc6 zWzYRV!V-b5FTVP-4^gzIwXs+>=|^PI5q;<1Ax^>3`w4Rr&!c56y_6u9chIKq&;^s! zB#Ut+>Vpw$6fRyMesa9rE>{S72|xocIB(EB-Ja+0%}_NCAAg?vbY*_}7kr<<_5iGicUD$yadb`JM)ES=1!b$gX*Scfa zX2R&;FM+z^vush4*M?Uj2m8}lB96^i;4@}@oM%0>%d#`~M$Vr@H^?_FFfItc@CGoH zUfKCjp3q0{jYDO999wt0iQJ9FKtS(L&>w*DTzfuF$zPf&ct({U^f`lOLq}Dm$;=d( zKKm$7MeD|ua77K_BtyKo`9+0rT}RKhEt4F?uJFY^c5fPPDo+e|#h z2^N8qlU{*a?Q}bHHzEIUpFV)a)1!XSvl~&T^Uf*x6JMOl30k|Ab}!;EdUhHc_cXb! zw0@~o!mnlfg(BO8_mL7G*>B$OI=*mvi7&+IZOb32G0D4vr|A=jk zc$}yBB=plp*PN;3iqk7!Y?}jl@?U$Vd@-`S7Fxay!^Y2iyYY7m@ps|$;suE3i;sf; zi|8lYHl%d6vv-$H097xaUF%)nEB6sq`ISt5!9&+N-Hst1$~OAYJ{VlryGTzsee9bX zydw+W%rE*TuES;o&xMSk&%(T0*81if^}Q40-w@yPy&r4;(oZ1(;xg^^d)d(=yQzEh}MU*i%WW2t50Z}QCP97_YU5U^^@~6=t!8D=97HP^rwH! zYefbDKW7F-LjO(JU`BTHNfAlbqC#(YogbbBzAD^VcxB>CdF2KFmb^yA^`+MoK@&dr z?YoHcfD@LlO3{yfcX)8pNcXAn3q}DMM_gL!95?OT0KBbAK@px5c8{dEcc!#@md;N2 z4E@B$z`(67=}$3+zyR+R&r?0sj0nze41X7XuMK~oOG*uhDm6KJ2EwTsYtC&Z+;>h> zxY)PZ-h<=I%N$rdGw=MyeVp(uqdCk_#PIkYpPFzPoBRstE7x_Cu8V_;|E@ig`!rhWNE zee^{H=;e5AfULGZ;q2~PHtfZm?mF0sHrQX)Ch)|mtr;dPKfWI$7N$@d1BGw9A1awh!M&E7N7-6TIN{1RNQ-d zvORGZ#}yTkHPk;tI=*E4w2rGk5BFqAM-~cV%k8S{WUPh`nJDr)5jhuQ!~YyNl%{XX znxpAUCAfRwIv%|`o`}xD({G``$2b6ZW0-dX^^5^8;F{)@{5$$etUDFOfb)stVB)t@ zxOeRNp6gH(OPdnTzyT}@Ul*6mNl3-l)t{X?DJi! zA@Wnt&@?1ZV?0rBHr5krFi=Q27dQ~PxW4?v{ds5}r=@!SsM4Zl9dguA}9ywuUb?S;A%)M{$EO}P} zEEKc{xF-s)gu$$3c1Smy&z!Cg-pMKw=Fm;Y_mWaL6W;;c|HXx~iB9Dp9N^mbF|ML`~!0%?F;iYOnYVT!pr3N ze`xuVhp=7e)!NcYpjdA*Qb>M8r#kpexqO1iB_Hivm|AFUQ%T6^e1{7T_6OWT%NYjn zvhb|JsbNB>;q8jw>$X{O$h?5lo78G>^gwS%`}xI-mP~2mE;^H|E@3F}z@e_5y?8m7 zf%!wjcN5J`mf0T--_)fI2YFKy?mT)}C=%Emimx5Vo_&Ua$G<<`t@b{fiUJgA} z(C~%ptLIKOm22PNlR7^1(jf2;JjWtHUlQwdfRmw0`ndj1FG%zfs7s>vI^V&|P2uIj zQHV?7Gv28|i?cN;XBcb(VWcoE3N_87dLGT z50&hq1|q=){m1ZJeFhfEyPesxzbTiD(YEC^&kMQ?GVZw#N|)rhVdG(+4j2@giNMd5 zD)+|sk4!!Maxch06B)B+3?>rwZ^qUV-xmK^Jh0Tr&r}VKklf3 ziCSfYs!@HFB@;mzE;09_U=8k_et47QYtn^blSdVXf-W1*5g4!=eTe%povG-Z6gTbHq7vQId~)tnny0L=)3;Qog6jX;E+1SNhe(DgOCcn$i1 zamM2_z8B)HWG`P&!frk>VqmDkbuirrHR@YY3#JH84$gSsya5LXRCqG=K6=bc;@%xv zAVha&V@U--*N={vN^NcIFMUtxtPFO6UQw#MbvY*?O~`$V=}#YX$RM~pa7K?~Ru`Af z2-dC|mn%jMEAbM<1#z@^5-BtXB2rLsGa8xD?*rZz8R(4hbzZLNuweeu_!WnVL^t}q z=DX2b8?S%r_wINLaB7+1)GTiLy?tC@?n_l1mJjCEfwbxn+zVOk=;D~WFI?!ze6o2s zoyT7xr>9v1D1bVt2dEqdbqBj;&R zb2g9B1JY};G3)U}nbq`og@{huA3*{Fe2!O%`2lgFXP-?=qFE|&PUlqYhF~Ch_3+mT~kSX!kO=|6aS7+W3))xTk%%u?N5UA|M9m zrFn6iCpSy6b=dtBY)hp;B>i2p*HeT}zehU6z=Z!@zc;+;_m*mJx%e;Z_wHF?0WlDp z7C5UD|NkL42fF4K^WBHGhv?>-))-LXu)V@PGc;$KyDX zgJRH;Q^#E=8ENMoR+#yg`aARIWotZ*D}6<*x&$f9wbT8J~xn z)?8r(KXU`-Y|S_|h0SK6o0`tkh~LYPWOU$XH}Sy)I&Dc~*QbZP`Na{GF$E4dl*TU4 zuPd!wWA@DALS2avp%S;T$>D;JjELVX7vJpYDFz9QYYi)k%1OtFuVD|5fDBqP7VjD7 zj@Dw#K}y&7C{#0{`#@7ddFd!QXSJU=Q7451!cy>qS+ET1{Fh?=psN+P&iL0{ggqhkhx`5(=fXkj6$P zk9+Nh(UER%kx93gJLs4{y1lelSqT@w6FQEk%v8qQORE0$!~|@QS|>C%^ZG(^G;!{5 zeee*Gv4Y&zbML7&x=Z8<6acZ`#+lJaOE4u|Ub5?k+2Un$-R1a&V>@0iig(V`J)%LxYS_Z=CV;!OD$Zm$*a3 z#%!BvytPmBI5d5;egl0@`E_qQ z`To|luGp4MAj9dph+hVt_ji5Y1$|=wC-r@?+T#C?>EJ*niC*tE3!_0zuV)rDuukRx z%M{RPft9Jp_`&{A=ZLqo=?0NJqYWZyCGlMMeCQ^8+z&3&STrmTa=8|*Wv-Jg;x6vQTynCussxoO zJBuD3^Z+2*k`wq2ew3JC$%1-X(>I{sEndXcCFgb1=rxA$D5TZ9X#rn85<@=oOvRUH zzZ~5ezN>V9#r>t94KaHU{E3Gz!+-wN@||q~KUZ45#S+8t7cF1VZ}tA&ame3cclN|s zp+=@Ycyv^zZUnfZ17?Me59CpZ8=k<@A>{4=j_XKD%_0Q_1J;-8rYc>fi%Xd*UnGqg|@T4Qh|> z|5!Y92U$tME~7OJj~SQZn`astZx}k}qHI+1&QAkMnJnx-8ofQR2x>QnZ`w~M#(!GJ zH(d71XWrF><9>1x{|7q0CVU|;B`wDY=pKpv2yr(jwO)8vM8&0*&s!F0o{ZIG{Ly;C zEhk4XCRloMDA*cyVg+G}krE-R!}cRjm?aa3p-X=sR!4RrufThHm($AQsQ&1u;8^&% z*k-pscAXo#{lU^=^El-1f5;#F{?FTIFzNri5w`u4n>ehak)zJBc+GqUgn%0-FOnyS zft-}6X{Ip$kM0D&-H?+*7E;+My(qMIkIVb`f~GB28|r|yYCm&tPuAdN+1z|d3%7r0 zH~WYE7YHr(kN4|t^tEi(TcP{ZdY&{k=-!sg!&5oEUwzG+@ymP{Ui7~ynWB8s-YvST zkJ){2r8!(QmCmh~T^G~0{jik;_h*UihSNk4Hc6Z(|r2(+Mi4O-#;X}|We-D(=| zr_Im(^3dKtqxrjh1=46dH4XRan`;ESkQHktb1ia<`6Ot}8)yh0+)pz4Uwn5W`meNj z%Zct|uzB{*uY6w$i2!Yg0kKexr=nZF= z73WNwjs5euP3L3#(byeF*W8@jO_~CC4IuaXlU@SCa7rMB=h@3!i@L(EVbSm2W1ogMw*E63loNhv?Z;k30#-h5#3|y-B?4T#iP#i;WVUQB@<0~W(77TiIG9K7iv0DC(POqKQh3gM3-e#r6 z+u%3k^6@|EF{i$BoM)u80keeT39&;k!>LSz*bamrM#3}~{3n_)!e<)yp~+Asjl@Kh zj|h4g6al=qC^{ne&hPmaN4x_j9;sVHjb5>|97FCCKjax<1=&$02#x`IkpE#SebNRb z9P*{@t#t)J3?C*4+eheC1|GBYPi5mq#lo%d2Q3u<@IE8 zpFK{764O9k9jtW;FSB>Tza2`CV-BtX>t46u%2k@ZLA!3W2-tsV_M(}`)zT-7(<|S? zdDQ^t#!Z!-a0N2rHEemKY60N$V(ide>lGPma~EC09y=sMD)Gs0h-EuOw*)Ns|w$&iVdhW#NZMg}eO$5JNn z6)Z_C-YWU`#>#lF;e6LXLK&nW3K=M$0Kv_V(<&P4(rzH!W!=G zn1MhqB65#z?R@9s)2q`%a+l4*xXi3hP@;#Vz#7Q4}RBBhKLcgakfmDF{vY-HmtpIIz3rDP)hTlNr}I$ z+IBMP>bZ^7T z`RTq-vYi{F(GAy_OT&%tAtD&|1U_(@sUMPQz0m)jiWXOJ$58cPQd4H*(eWGQ4hwsz z2`uK2MeVoZfZUsVhnH<^2iq;1*^~WweykfLzNMY~ee5%*7%yd)tT*dB4<8oxY2T+e zr9nLlw1cy3@|;M&z}%6Oo>!81tWwP19221#jk}B?!Hb|+Mil~Mj^@*hv8neEO*Yt@ zG}3u}VnS56b7l?F(r_$IGpJ9A1e0w+)pHzHSx4&Kbqh>4Ir6Qr^Cop-xCa-`;j;TSNN{d> zPh1$$AKJvb_h(?3lzh+|Cg;YbjCtwi;9Bf0dUAnt=gn|3<&NqWJcGg!Kby%!E#ur_ zY2Gb`gD--Xe=S%%E1l>Bz?p$T0GH;^fV$f7DacxJKQ-<#Su)O5QD;VAQwQi-??A*m2>1qi(Z`g?kMx98e z2zux@Q)_WzYVcq?fIWE`??xK0)J$71RS$;creCX}T@g}$O~ph~uR$-APr_?c*HCN} zSZS(h@UFS3Y&~BCHDR9NPoR-29&A^AE{(>@e68^}LzxD5W&N4Ic^={3Gr8B+3mkF_ zUh~W&7_u|jEGDJ$U<{?VrT{>m{R(9a&SlvXpLcV|*r$fxMPm>Qex+ke$9@jkr)V?Zkry6wT%~m* z&E&xuE-ueMt#or!^#>Sx)Q0BpveMcUvzX!#PkqglXU&+P3w+~di%-pP^`aIwh38?- zw{}|=f*H_=q8GCaJ#sY8;F+gWAE`EsJvQ}xaJ}xkQJo}%>vBD`0c5r=@?phGrnYQM z?@r85j?EIt%Wwhr;QuddbHVk@-n_W%$u!b;0+k`UGT>xYJqr#C(~H3u(c!L|+xTEY zX6c=Ag;3n3YYa%mJPlmd9)xAmE|zAG^~cP;Q7=%53f^YX^)M08;+UI0aXqs`{TD_Z zIZIrA?t1x=$pOC6Jkr;$I+jO{1IP^&aON0i<%#*xMZ6_ul3}xI$2P@3hx%Ekrt^@8 z;J?uY+6vAe@8yZqZ+xOg%A{pPu@*gZHsP2y-56N}GBc((WL=@*TxoA%Utt&gOE<|^ z*ov|8$SxL^8K)I(Xrp=)&*>MKZ<%^zs_d9X$qzgM%$Svqk+8$Gu|InmCHN6Rc5HWI zbI*o3y2{38WgY$zcK)6#g_{g-Q}ubQ3|b}hx^&J3m>H~~YG<2#qWS2D3-$Bcpfkzb z<6Oe((~r=V_vOvhn{bKzB~!pqZF0fYCJONw5O9F~i3)wPlf5`cD2gzMqYuxH##@l&(V}wiq(#^hD2wE%P^jeWj`qWjv3QvNoapx)%qsSTDFq>037ae zi4Q?@IWyequ?;TV6-ImMZvd;YEL+K1n2^HA=3ML6p}RE}vS&tbnp4W}UWt6^DQLEt z+td%~0l>4O_nT+&ELe?W`Jdnm6I}y7Jz6}*-j+|6*jxKduMnG(4J$Es&1V`ENb1V! zdaA(<)EAojq3C6aZ;XL)Kz!SI+??rvAp{ig(>+a42AaZiJo|*U9n0 zg>XGwm&ywoFZAcb^U&W^h4eT1X|xD+K4XN-QRxgL45Ksrt}`4d1RUWGpNALLuCww* z)VKX^L-vi~#XbMdOMWM2{WrX1dM&izcF~zhDVHfZ_VJCQ%$>h3Sj0GP#ZlHg8(ZM8 zY0AgItu%*OwumP1O3)nUygUvezy^DIi1Vg$)<094vhakLzUs4sZRD$W>jHn7u@>NW z(mSyN47|HwxbUpk_(nX3Q2myxq(AfE#!a@1g_W*6W6FrfnvAjeXm|Ham>fciZ7XS;yeUOMa98^==B|X<=U&6}>!@ z6?->Hk+RLX@sa~z0}$p&jT)4qv@K270k_3_wzjg=k9wM~<+I28P*_3U&+|u}_DdO9 z^54PV(&yO~t-Lqdgz%pR++!|^7sXl>vfkC2Jev}sKp$ThjMFPahV$XO`43nC!oJgFBWua} zKD1BlWAn%;IDbfwSS#!D!M_@VG5rA7?wtl!izXe0bH$#@7ArAE{XnO5whCdnR-!L> z&x)JKnM!Y%*ntMw%u(KY#ch^d7sjlr32HvjhLJnsO;Ob5`Bb%|*Yb@_0D8jMb*Djs zBxmkdTEvi(xp{XFq}H?*gFKI^sXLRr&l`g+;=sdn6Jo1G9N`ua%LQ>1vkv=8&-CIG z5SHHihN`;kp7eI<3BZL*NBD=I%;JDO-t6g<)PAk_pR51~c}U;%83TaYA&=S(Z`|Y+ zB~W^Eao?3zFgFwJ)DyP=Me?Ib%C=a8P*_rL3{G@Ia}Ep*H3}Dw>@LOiu`)3ETeOPj z&~vb#XmN18Gr+Z6Jvi{R6t@-(99;I}@quNTj5*zpNpuj^{HbTmgSqJ$GeO6Q7A^MF zo34I$?NE8+z1ULi`<)A|b>ch^@y0Z?n$lCH4$VYw)syXDD7+xsbAS7)vBnrT=2Hez zw%xySJzxQDn%W~)Sz>I}m_=TF>#yPSGN6X|GtxxXy}LAA;usnKoJ|k6ZC$-&TQpj% z`EErH4O<{B;Z69?+A^UaZv1F1H+3C1P2|~GINrv2 z?SWd-mk(`f$;D11lGoix=p9Vf^`iG7(uf5Vy)@xyyxd6iY&*bRVs_2XByx)|l zf^fg6($_izpOu8)gca{P#h>SdJ!b%Vvi+l{{7j8y%BX#VM~>2p^vq0pabxpI zX;pmdTIB!1w}wHkIei9$)lBG$tm{Wtk;kRj467a(1h`8g)fzZ|XoJKU6d-W!V&Ds) zXuuXbS`cY2U+G2jY`GT-^Nj7uq8c1TO4yzP{w(WnI#Fx{rzVQpTddJkhWJbSaiLMh zYc@ZJ_?E;ZIzFjb{A<&d4l@~2=D*h2Jg`|(q{RkDnqP=3Va(^cZQ5{<9b&i6wPDm7UpIrtk$@LZ`e!D^f}Y}tTcreQ*Ow+$Nw|WJuD@{d0>yF zP5X-Qa&n5`B=e}w;^@xBTkc)X;E~}cbKPEs7wqg$OLs;w&ADGPuG#Ma4)+&x=rHy0 zYw!{1nG=ym+0E*Y5p}qbiA(1dxWyjVLGL2=O=9w0JQwWo0*`TYY0<5ll>!{A%~(yjF967;N+48?`bsS9c`5NpnDsc3hvo0aLXb2n|jvxbq2fLn?06>b+sq*>!5zGP24NtTT=Y51;H0 zHgY6uv*)$RaOnck492JT7T^bo(`eje1TP2TG*didE5AlgV>zGQ&w3QzxVinlz!y52 zVKK=AILG#Zeo6Z|fivlgf}AjXLG(R2IyHS|ZCY?GX}u_})U3tqskkD8Geu!|zS0;D zD3;>>i<7(^CeRdIagx!=qi$De5Th$h<5_(V_xOn={D*ri?l->|_ZTm4!T0fS*P=j; z7VDX-ftT#rNC;OPX^28$SWOl02i4aC?q!sA@P8j^A#64H!t+y_@dJn z=b)JNE^Vh@Pz*bAGvO_ag534-t6MPU2j-=z-W!2Yc4HeDDFqUNTm1n z!wDg`dU}E1?|jEYUOQoX#YY}vNge(@e)`cQX2?b#o?4CxAJWfX61d7As3XA7JtgpR zJ<@H&+To3lY7Q6JG3hkS>*rV2!h_=2McRw=% zVtAwzhqu_Z_iu@Z*<`=x5kJA`hD|}6FeW8jYurPO(G_dQJETTU1+iPzM5}g-nC^~^ zJ#t|haXRdW99oNBG)LKs}jZtQ`9$!5%LG=ek3T#zZRG=ecCg994AjK0MNY|&xf zr|G%To9yM+w#SEoV(jh0xrai*Wpo!K^)gG+ieG$G485erdNj}bb!2AX*fCcwv2svf zOj@ivdwsK`o12eCb&F&N*gjwd(Z$&w_zmPW5w77MdN(*v23$fg9A;kUIH=rczdOV8uoGaG zxDcuK%$}>`)M+22??HLMX3;TtBvyNjA_HXPszE zwg}ETImhZ4bRkhP1kIN2=KC0$mH+yHyz&zsEJit83eXj>$HT ziT?DJ-YXaMry;?>Rfhw=d>dYC@OROwxCd=}M+ZD!1R$hIz}TKX7}>bcOzfXpVn4Lr~t zY0|FzsdUftcc6bxpF>)L5nM39cy2baAxkb15tJ+2bQ?{iPX}0JVKWp2b z-Xq#Zthj~G4F4C7^Pt*Ruug8#Eq-JSQWqJ!A20prIrnx)XIy81ZQuF{_|%~OIb`8R z`8`Sa`S5|ZZn}k5XU8;f*}LVfmuc`FxX4dg^R*5dC(8mdCFG(imU6KsK~nW zywLjj33v2|*L*>V6aU$k{xJ2Xn(!+hQme_8lhhe~TVLBIHy)~@u6^;tTM(l7qm z9$5V`)9A)c{%76d0+l3qEhqC>=_7M*2l3COk1Svv9Kl2X@~Pj8+Xc`03yKzl>;BnGu44qJZD|o$gFBZkyHf0byYm)pYsI08@IXOUIKW; zEK9rD7ae1M?#cfg?DQG?b?;vWSA5Gox{&5+&tm+}8*wVl;(IyvxpLuq^w6M!EJ+Xf zO|F57_)9ApJru))_ewmMqodr$l^|!QcZxzy#c_7Q>IUVY>`@B{KpDYhh+Z6v%Il)+ z?v`kajmRCiQD}N=Cj#k>x{xys2(jq2fd(Fj!eHVS!hOK;nvs=Q61w14&|H&j_RX({ z0}aNy%HMB|=4eVH8J%ErnpvX+qs!tC7pBrUCN+x2@d;foX&l45OJ-;se{cclkeiSDs5lv?SP`SD(k)-*&dzpO~EO$Lh1Vf^;^mK1&8OO#CgaiYj8DZ$~@rul|WccsgDJ zr}+{tLb@+99_aSUc-_m7LjDhIAMeD1 zb5GjBkI_Meqg*96P`Z33;BuNv_{jKLhIiQ-?iGn{Vj{6`e&al6gl}#}Dc__V!acDdu z{DT|k`73wCDK%#geKd`k3GWU7H~N%BCM3R5lMpWF{ZIBG8FQY z*alLZZDX`5D&6ASwmWT0@O{}9;()OqSVpin2?;C9L2R5H2_t>BeSk7XA5m5nf8xJW zhP>k9-p%eK_|Nc7!6Jm@-_rMy-_7ocdL8lYPwyKK`9V|Ld;G;P83v4 zR4?#x%*QPc`i3*m0IKn%$U@pxCq|+;aXb9rG34#NdSGx?I7ZSG#tz$?ehWK~ZY)_@ zyKPB(7#l=G&kU}#1@p9vA(*|0mvb`OXawg4UB$_kck|Iz7zLDx67QiLz|JKdqws_d#uc|WRvp9vVk0MxnzV2ZP{X5p)C+n5^fcy_p4k>V@gAH~>8t>{auZ}(hJaoW zpyOVdT)oKy?(^U^NAXvqByjocs->2*TyTY~HbfsCo-Ch1;+Y3tp#e@p9Iy0=F`o z8bkhA=JA-!VkC^*4=d@gm4`TE${kVnm#I?*?!K2Ig z{RntxUSABAp~=Vys1K8YwP@F;erVVw)Q-s$JM@d|?~th`{>3-Uh{3*{;W(nt3Uct| zOpwm?xgUx_9lN0#IXG_?oRbq%o9?LAUvmunZ`*;B-^X5amFy2G_3FO1?jcBzsXsOu zXFT~D`x1dT&fZ1O9x6;ojozhyD)}?=)Q1Q>gLwhlGi|Y)r(DvF*ZidVi2BFOZPuMQ zapvwBBZe0MeooZ#wVgsfzNQA3*AlhjNg253Y zg6yAD^_s;JQGI29#tOM3KMNNBi?26ZPnGSubq^>TLB&KjDHWv@EO1XTqM$K}NkEQ% z-5ryR_nUKG>;3=l8c`{qqL;0=j_rvKaMmqu3|!wkmR}TK=;L9JNnZIAcwSQ2sjTGb zlVI;Vxu-HGPF(F29fs&t4j=pc2(e%uS|l5ziXt6dyM2vbj5Ti~F@jK%BTTQYZ8mL9 zGv6_O3^B5yKtchb7>y)WcX~1K7j%yT96iFuKR;Fp-D^eu8&_3_ohcaTx^&wv;Vb)o zypV}!c;?FsHpsoU8O(j`onP>}oY@s0+Mxt$ULzmDA!J?PZp1*Gy-I7qLv0vG2>Y@#);&i`X$8(V?A zk<~Nw!A724kvS$VBt8*24u=tt=NbEudw?7Jk~{(S9X6Q0B6n-)kL|erLj`Q1_ShHTFmki>7O`_d<6vG*9S^fvy zJ0vTK)eBaPMall3Z=tuboXc7TO@`Mok{V=4#*`J4d#__Ai)-mS;dA(JkK)BNq(oV_ zUO>NeVap|sWf1PT_fE8D*`6ukO!^fg63C!{CVS0})Hqf8RYe~f^h5?1Gk1C1+`x=U z34sm9B+vj8zRYM_K+XjVgGC~oPP+Ccwaj4Uz4W8u)#{N#SjKdumKi1QQuCcjOQfg^ z>j+pML~|+7nY}Jy$3x4dy`5kJvM%>J>lFP0?OmH?86W2&BCzfRA6hB;oC59}dK4g_ zNvvdcJInPPXq@ynRXU^eO5)LEW0;>6y(ZDK&m#?U?*MiIZ|uOca_yr5GlK>iOL~?| z{FJKbvT`e7IziayH%7I^8aY2TV}LjqG3u*~{Uf6`8>|V=q{4|pw0CTz$VeLVX0JROb1g3$ zB?eu7h#zP#zSLJ4o)Jlk1sNc{i6NNz>rE_wxBV(3{G7ku#9R@P_uHFTI^8dsEP!EP zW=xg?I{Mz57$bw=^^$>}UL9J^V0*;EV{wo%%9$j(z?%39jH?)I{C>H?p)9y7^%Xq| zplOT_cUZsf4uSt7)>MT(XqRLvcYQ2az3n#)(x3X#i`1uKP%s}6HNEp=mS_2-C~+x> z{hHJiNWW04?K6oTt!-8GBE`i2Q@>2Hti|lHw8{P$+PdvajJ1{S9L9=z4l;&}9-8-y zy^AQZuJn;FC&-XJauaNS+n0DooCwBTjuSwU!6B@_0GkcFLBv2}ZgY3-y~Oxj1&% z1ysJ^QKbJnbtgTC`~?l4M{EG<=4^%6;I*hx#sy$iQ)etiD$Y|8q}2G$#vId`!6E`m z^6&{x?79-;8b&vSl=RdQLkkF49VHu?{fb`2)GYHVD@!ntHu+^6Im-)V@;i2$Q2@`O z7Yp=FxpB%&`EBP@S6R*H#I_;l!+**LbjlG25blaPinrX8VG(SCYxhW-H<|P)hDlzS z2g6Ph8%Oe$j2ejMtSUb*v%0f=-rhizhB9pSsB)_#P*%gKA>;Sa%4yF05@2m zxrFm6`M6~k8+vbl-2#RX4~>m*s}1va!uyH?L}L>lXGJ3n%l0li2d zyD48t!vM#YKEyVSyp9f1Wx)B$yd-zdTaG7R6@| zwW3E&4`;UH4|k7@iNnxy?}@Xb7g&8gXx~q{F)C61XU2mpjVQ!Of4lc2&g~TN$oEQ+ zr;FS;hk|XM?8T;{QUpc#&KaJNL}W-TVjPHR6>kDu$IY0k#NSrRKuf4pN9Ip)c(m~i zY%{1BQE4LI$jb70KvnEO506q72E1VcZxV1vNBdm4)d-#gFE=`$$O7h!^%}5p*pDv$ z0JzTUB4J-bxXqQWnGRZ*SdH|EUF*c8^t}$;I77u>qvp&UM`(Adhp{RU6G~tD5Ywx^ z?L!Pc_&v9|vk3^p5Cs;m^={ct&+(fje2!cTSIvpvckPwF8HLIeXkl zuFxVz*jgW)OMUVq&-KZ;hoy=9b{xpa^n|^6w2t)c9x)a@?Pz8kOMNo=^;`DBW}Q%< zeAh$NCsRj!LXS$mJD^GdY#jOxa}I+84As%IY7v)*@%sz)YZ~c%bgy~vWSj%`=Y~DT z0;ibZ$aAr8T1X4Fr=K%Bd+x3VmMuvfO^>r13ubtbD{1Sk=r)IiLu?Fz*vbb)7cmyVUV1cV?yIfL z>!_kSx*^1YY{Uzuz>i(2;oWIhH>_xX#W7OT%;zl_RhL?3*2qligzfu+NEW72XB1dQ0B)!}LfqwK5Qw|c;XXx|$il#x74WiK0_Tp!PP z2EPgt@&k^JzTiOR{(1l*h`0dqw=b7Ff5~I%19U%UUB!s(GMNN?C)o(0&9hLX7!y%L zWc}1Vztdv^Izi9m=y&dnsps7=EN1#UEzfvof)`F_l+nop3kF>v??0-9;qncwf)q=& zQ<(c5UYY0cCd`$D1(O~j8|ezcvzYcH;wv9=u9XNSB@D6OkaQ9MX6#TmI#F#%%h>OE z%3cS@GkeqVv{MAmh(>eF1;@?57!aOdAhMiFJ;EOHfB*9zD=|Wh@gMEye}1{wTjIL# zA7%)7`JW$t{{iLvfBy%noqPTt|1o>~$4W#5(eU3?&vy$50ChH8_T4E|^6r;eSFs!fj>Y%Jw6L@s{ z&1M8Xfqw1o*}iU;yD#!5UA{jChQ42xs-_mn^ z;@H#nb#{A+xp$IN<9ztsF75sLRqBRyxE&2^)lD1k%>PdCX*l**+CB6(UH_KBS~I6! zd9$ofFk@=}KlRQEJvsh=>YbyDaqL&=YI4|a2p}iNGciwgV>w$flLKS7_w4JKE+OXl5G23uP~f1J#M{Pe=Hvo{pR$u}y>`TU#L zT&LUB;_9C3Z}-*fU_LNLQitw68?SyRJ-|WR%=tdQd79u%!~r+}^IAUP5g@kyJ9afV#8V=Wj?sr1NUBTPT)Y z2&aU#)gGwbXYaCqzg@6C+b-)D?c)^Qm1yQZct$YJoNhr^UtfiXzHyWwNb# zuw4yq`qEzz^hU}~XuXKAk35;@Wd-OV?FlrmT^IDPV!UfOHngB1Uk;kvP7X}#MJjDd z#vC7@gl8P<5dTphLl$MuWR7lfPMA2iU^h1>+;a)DqGujvD%x|51dS|_?E9PtF9pM2SXWDe7DJiJSpmH$Lw9H%3n5&t~``K9f7N{4K^_+Vr z55Uffbr3baJe$Z?{c@vPY~cf~7@L!C^vV{U*5x~nBIjhzsHQ6AGmScb&H)>(9p_E3 zB+-;0w_ybwHz&@WB)+L}&Vh6C9_J|F7sy*-+;g`hTyp6WFtY9_uayOChF;B4`o^5V zFpw$Ujo8y+P;IC;wh9u3vr>Tr;TqUICL>c_DygSM)WMu8{o{>)#yo#Xjj46=&GZ>F zPsKOw!&^-9ZnfbV6?Q1=BBw)Kq`HzlSrL6Q%bpfzJ=tRhyU0e2+r2>RN3_2J{qZZ5?Gqh-gPjq?1Ts-87ywZz!32;y0U0`}C6x4+a91YEGI{SaSo4bp76DIXhw7BpZh|6fzsWzN_>}k={0WDrU^pyWPM>9RkcZ>Xy?-OpC79 zk^M5mhAX>^DplA&tQ9>a4c>UoI>vNkPo=W%UDol!@1Q}C*D-U=T=Iqq*1ham)47t} z-UmXNSv%baDlq-WY^-hOQB}A>f8|$~iQD8#kB877h)ahQMR@E1kg~$Hf5Qp# zb{DRnDY90kiHFV0+Gg#8>o|6F?s{K#eeRBx8TA-tpS&75(UXL5+6TRljZuaA|s3G}(aZygs$WfUt4egHK+DbW|&zSN%TvCohhAAENb zYgLcvl!+6n*QI)kt{o5kL&{gMtnB-|4><2Fz(qhu^jR0YXVp2wvm6L&yLY8q9;*ZS zeMYhoslhFUp(hCbe%Q=^nPF2`=@e@pWyI36n^m2n_6 z62&~mhOAF{-gqP8@KCH^%ofN|6wdg`JeJrD*T56q{*CS(%zYuJWnF@Aoa~B7ee=za zcj6xnNS4gIw6Qu8S0uv<(g(Ka(y?&$(#qQ%@k*SO&J7{+4 zyfAV@a%pJrhQFtIL|GHvVV!(S1$MMkiX`>t!_0w*iSy5jtQPp zyG0#;N=bg2B9HOl>!|+$XB*wum4NRjW_HE#GpD(=@q1T6aF}ea4%c^&xiiHVFBu;~ z#7ftfwfG8x(af*dz-w=Ou_wqb8cXlJ4Y8=!rbLyz&#@0SU?#3z^!eWJK^1Oz*qh8# z4hjx(7M}T5NX-p3c8nnor?v4e7jOv{;>&))-r_w8eH|n8Al^II7!@b>;D8uTl*qFh z?}Via=g2$OUQf|hiPQDKzQ5;=@c@4enKyE|R+vPDOxLR4u9Gbwy2AJMH!pdZWs3*-7jSq*hdcj_}^ftLg517%2xU~*1 zvrrolAF`&OQM&g&=v{x~{j`$t%Q)ui#aw&DKf-@TW}mPXrf`IRkb(Y! zkCMKy8uo*+9k32*LX<0^Np-e#Xm%5?&))`dY)UINL1=t}Q{LhlV|hrQW?D09 z2Kr0K9!Ed*h#e^ueaTUAtl+t(U#yjy7&zSLodo(7<7x+9G+mVScNWSO#MWzFU#Ly*=IM*iaoCdAX_W6gWP!Ft>nFfzZt0n)3zf*Iyn@O}Dy)cBIf7<(w7s$tKCbV~1^%qnJV#;c#P zY5t83h)%}FZ^wqevH#j@mR6vp8SCUDiDMH@m*($`_bm zc~|T@857Q7jI5&MoG#UWlP?-dVs^xb);tAKPj6C!hh^iLvQFH+c2gUV?t?C}n&+N~ z-}v(mC+u`EZj?BI+2bCKdw;H4)|qyh>oAA)OOS@SxZ1+#248T*Qvr{G{o*;e$N zwUz5Sp9f2guBOMD@EhhW0H0q5!w*#4Q?a2`Gq*RTevmy~ip%W7t_b`+1wKE0$| zAvrNl?K{bRUVMwOK|aQj6T;VPb3AfVRbaB#boFW^jfH(KhzXvnlH&Lz+q9jIjMDX*sNfF$-h7X5spx1KO9qo;qlY3B-cuSc2&)(Vr*(#s%&O(>=RI&yg4!*7cNBKB5s-i;_q1noyqh_;rkbV0pkNY$Mq;0`5r9I4gxTrvcKJUJaZb{ z#nNjwA`f~JRcCTD=>)2+2Me=9hy*5_Q7y+2i)=FCq_l|J?fvJJgI%9~4>M%kEkJgF zRjEB$2!8#*TMlUR`OF%^G3z7xC$)MvmYp$Oc*tGb8UPq3@ zwJiO_=ld8rJ0~RPpI>KK3FJTw$@O8EGk)D2v^ zQDyTly@Bw0D<|Y{Y-$`D{UX29uT$({%7n#_Bu;No?qhYgJ1bEe=7I0WY;K~hZKS0} zf0Zh}3(j8a#}Y@;LQveTzr){C&=AIuBok{v<@$8UQ?Snr-X&WL|B2b>lC!dr^Ej6C zoIZcM{LOo)nn_6zPZ;RA-1C9+qBZ>lf6rr;w&#h>%I{w<$i5rrzS9>5-Drva`1Tul zWn!)3$FW}##=Nhi@Fg)#ZmO3Je>-)q#OsSSLmOKr-%k^(&-WXjaj=^7U5_M3fserF zQ9GtyZvsyp$AD67nOh_s%s0q>nw>=c4pdtj_dDBjiP4sKgVkk8eD5hcu0z)PjozYX#0F`D zw7b74qq-`OULq}jc@iV*1j=9~vDp|;7(2&2MpwMF%girS5Si)H*w7IB8=sK)+9g3OTR%a>ZK4I&OO3X%qO4e)B_5Xcnmz?;_rED)*r?tu@K|~@pIz~{$IFo z<`i&K7I+Eu@-nrO{jd(?yK1!R=(N5LfD-qhQ57We`jNI5%6*v+5r8} z^VTWcPbOGTU-CL1+w&0QFY>RPoz~=kB4YDt_w$Wf(PX#Mt7K#=p~wJW>5tW0LXPFM zwkId@Knci0V%7^Ug(r4<_UUu$-KYW$9<|zEFFV0jk=&2ng!9^(-#%&V9sz=V=(6*6X~B#8jA9>y>k`f+AJGTi5uGG@)2ruhtC!hNe1^%F z)I$D7+Ru}Y0^tGKHcHU__s_&=unolLM=$WZ#PBMQSU2C1-+aS9m|$9iA=8F@c24+^ z`0CkI{9b^5T#KNrCdzz(4KANO`HXxSDEH4aXQ6j8sa?^Cy&40l81~ zH`iuQG9exXJ@kws2hihuOO6*DsCFpj`X--=jlH7xJhieV9hY{j!?Mr7wVknRKMQCypuFQ4Cb|MGC7+8&5 z_x%|z7$Qo?rcUUc@Yeju4g%#ULqucvVm)8l74v2}$;^vWag zvBo2oM+)DrK@!Hz`x4e`@m5tBo?lAdXZE3aWg0Gpg=3Ioy1j# zav=CwR8tf;&WKn|+hCco-xou&pNd1A{0go)1=mDO0}ch&hVALZo~gLOPHSVVE;92y zaE)unr)(cyQb!*W@FzPj>-8h9`s~w+2bxgx)*-J@^LOYSuVClA&@MK(SMQZ;R2~0K za*oy&la5NK!8Jgi((M)Ahk5sv+=Tarec6Fnm1;cRgQsUz!+iqxjWGfLFY}%^=)g?w82VQ`GH^@&ER|f|`&H zv4>LWPzvD}&Xam8kK}a>Ls_WA3fHtmymo!i(d;29N-__5h93nCr;&ckmpwJ3-lUo?!m8r$DX|biQ6i5) zSO!VhVJ!L{46l&ty$Y%bS#rfch9^Yr-y-&UVsUJSElw$u)vKT2kfRXb2K6Ff7tkcN%3sF)*zx+$HM z!BDL1>8yG28RSL zqn2l*mx1I52e#pLymv}6z-xG?&K8@phrhS?X64*S=u=(z_&*N3t5geO;}Gjd*eJaf zvxGhWrALon;F)57M@OQVmrK2)+2fF0_TqGD6N7s=MEFgtbLL`Oh2+e#xLvfIfNvNxQ7QPCBvr+*NZmY)L=#OEL_kILt7|0 znHElBKErHo5G&ChRVh!SIMRJ5%f))d7m5cMJ{(h*4T#p*Xa3D0HdStajv|=#h$r$r zTLAwunSaPO!`=ZtKe>j7fubCU>n7m2+Usq_K9VXg$Q>j z)v1g(pO_HrHF@u6>=IDs6cWie3ESM@hTroRBr;C&u3$bS3^2AXU`FH_aFrk*k!ul3 zFclU|KH0;8d_P`$=wcsHqo0$rY{xYmBa7+Ve9@20?*N~VepJ|CRQFh+OK&V>EAjKc zvHqF2jO#JSzw>r~KIR;o3tF5&##u$TF0U`hN5%WPjLj$S)tFhGmsx(uiY1(56Y?pM zkz~#=J;ffZ)0K){xXMU za1p{Y+kqi~{kKGh4mzJGR`5#~!x(hiSWD%ujf2zq(4E*{Ra5eTolhy-h;s-3ZEa9( zOK&u6F!UcWUsGXhNae5xNZZN!q8*6##;)=jRB#*X1n+ZY++&ZSHs&-TFCy^F1M)f& z(`3!Hb1pGWV2Z3WjtyUl(b1E8P4OWy6yc7^r?P&)jfC)08rnNDUgi+F>o%FN2D4XW z4Z}WV40_&;@jaC7_gyYbZl2%ZE0yn=L+~NY1sAM@8pbi|I;XT9o(R^{mHEbwl>H|A zDwH+M=XSlxXLwlI!?NCFeO**pUzWt=wMSY1#Nw`$c(m07-{1l6_x;&d;15JT@T~HE zyBAgJ<>XoA`*t6*uFPw7#qaRgcb4QVARuEZKhb2wBQ<#khQ(;dO;EuPtb5iJ@m}qX zO}kHfR(E}N5*cF|SIhXAk(O@r(#$~=O}v7<7h?-u&ItQj@wk@uoo|S}_a)|RHNnY@ zBo^!l77jVW!LtGQfIs>5#rx7V!XZZ_dk4jQ|E#l@WbaU3fD1}*oEO6*8m5NWB?3RO z+j(Qm#?1_e#F_aJLAQI0{rK~pvAWGguIXGuUnW|}-y~ap#ta+=+d{BLBXh`_uj~C;Ol3E+aNLClOrDWBdXV|1-^agaTo%qNIKdCYBqnH5nOfzu07e^n zN-fcewtI*=Ry;ueuRH9;)IVOGut|v@@_>?CR3ruzu{Rz0-P%>`RgMweIj!S3v+2D&E{YDH{m_ZEi5$IW+U^RKD4_*MS zc#+8m`rb{^&jW>=4pp3t?a$hno>Rt0F&B0Di2RWJgUCVJi!;ZJi~K&MJbeOPE4gXO zLkL$bVIJ9Qp6oT}Q^uF3%7Pb=-$%(>?NA%fU#-WBf*um*JAk0hb*9M0y|aq zZo;b%r~qQUHNX0E2&PCr|IRH`9T8294uej8!{$6*(my%hfrRRo_kQKBsnGA>R;ulD z$-Rbz(TkF1=B*4tZeJSqVXIFKu{9F*qBh)7pBMW+7FqFs?nBvU!wc1m*t*o`{^-9Y zHfUhV9(^QMgkCbq327i2jZ6E!Rw1w)aD3$6Dg2f5jD2pgrz8gibTKz@z3%&vocMho zMzY4KIT1PVrQ#DrKRQRosojT+)xSMpB;|tr-uLA_yKJP}oE$pyC7(tA9U&+a83cUP ze7snUOTQhbGcN&{igD?{)X%k@K8vKoxDW>ij5Hf`<#{K>WPapbNS&w=u%>?H?6{pP zWV~Fz-H)u7x|HV?SvR>HRrX%ax*k#se1S?XLif$5I(Z7dZ*|`;%J#vSeGXog{j~Qn z)@gmM{=fDe@4>#yuy^usKBGnse3RIcEIoJ<|Mpz}zUSCa26_ivQp4 zW6^tz1A1?9w#ql2hx}f}{0|hzym23F*{J4yS=PMI`?`oPuGq+CJEvQU&0;@;ho3!i zpXU#o@FzUvI!!g(IVG)YwR1k#)QzkwvByPEnhNB461Ndq+}cCuK*u(iN9s4@bD`G3 z_{0MiokdbFbOleDSbsJV_|PAlNY(>fb?|YtbIN0rbFCSB^S!##eh73cpu4e2m@nccP8H!-w0ja>tG|60`_J={L8%Wl?NmBleUo`W zo^Y;=4HqNlPHaGQbuDA_jNObcBKv~80^hU9h$FQ&HE~JVi^!3(BV(PxpFfPLI4D*> zW8K>2GGEBnU-|W`Pk(I3Km97U276C?JB_1++IIAm%v%zjlzZpBR%Stz^HDe!6 z&Z}9-Vw*bsZf+5SL`Qt$P%yRZdC_~5_PBzayf!@3B<%u68h^vxcDQs{f7asQg!72- zH0MbS;ckCgKfYP3m#~X|^i=U{7qxy7*OGY^Jyss5wzU@%T_GFdv6tkzf9*Ba7c}@_MU7Pd^AQOX$$li) zaz8gb(%(?~s*>Nm-R|Xq;!VNKV0VIrz{VeTJ}?EpBF7)fR)?VTri`!o87VSv%3ZI3mt!%C zd>i9k(+|X#=WOi?HHV65x`Y~oae+?B;GzGsHyO+67~btsrONeW-=n8mOh{|vQ&p^k zbZ)QL%CrvH1N7jZ-3Y5#`}x=BL&kyL;L(RXfOooMIRDAZJjz;sk@3CH;rkry7l(2l zHq?Lp6ZWdod7BRm`ss~?K-2449(lP7d53R58ozIUaX-Yqux%>uSnBk}?w)gH&lHpGerUqd{!#V!d(*H3KESmilQ$092z$cK`H z3t})rCbHH%aGv-ZEpEwBtHBsA@5KX0CAg(YWN&j z#IQ`(VSwNF1^)KGxFuMhqL4UZ`z*oVvp7AT<&Im*ooAv!j0^8q$TQMg%utDIfiG)X z+>*h#eBo#P;+6vKnRTXIP3XuSuLqJ1&&7=T_|9tzH&ZbCDl1XOB)jg->_xgk;p7O; zjOTT_UH|Mff z|9zh@fB$e#|Gtkv`$nH!+jbGDzxPP|Pk#XxslB)0jm%nKfKl%%ojK51hoz%|U9yb$ zzt-Kv6eJe2Ht%79rU17I&S&bp&qc_xhd1{DC?%dWCWw;s?B! z!+R|kT*s#}2Iv5VlfTzlFG85mTZ|KUB)NtG)}(W4uOT_haV*%}8P9(B`@UZnJoCwO z^#whwy5Kqr;`vG5`tR-hnmIkJEaT}A{n;>|2SUHhi7)ov5SUT~(_fCCh4-yTpTQ`a z*6z(+6fG6&*@B|L!j=;_>GmA z{cHiH!jyVk#S{OG+y%bJGhO14qu7dn@v2{4bzY4liusNM^0dh7Huw4CJM^jz>`CC8 z#fBr6p$&^sg$F3~7-k-j@9wo!C(AU~W7Pq#B z#4G#M9+307`@|B@7!F8dGDg|2PZh2Y)PW`BS^h_NG44?tpE6eXEdOJlo8m9E@tPm| zAjcl$J+P53^fa3$c{`sE9OKdU&ef>ZXspgdMPAS_9+%aD|BZ{D2sWY;A{zLktdyB1sB_Q=0f;T5x#yeST0xD{yy{=%zfOf$%+6z+e zSr4*j40$wLbOSqb0V~bqmx`y#HJqkvUPHLzENw$B|A^8k2Gk!p3Uz1lJ##X@@wat< z4v04}#C(c*%m%J5C=b~K;9HeP@>AHg7)N~1e0F#q3#&!V$o-Q3f*CNK^qOi<=SLQR zMP6)jQ}Z;=KGadGFJ9^}A+P1zJ%p&$Xw{@3xk9dgxJAE`7oY}6upY!Y#kPJ@2H4gs zWcl<+M>X(wmXXNy-pI$iI$OP%@jM&RPu&(rx8PaNv_)Gd9bxQ&MF`{&`GasFiko zq%KE3YdTmjQ<-GbLg@zPd4yfnh2}^0W%U@%iEl`@WjLVcg$2zhb%>ftULSvWa~f&$ zggAl7UPCLYiQeKisa>K>@a-Tz1hyHtKTp`aZsfg49Seg5b$+l^AshLA+=z_!E?!)- zKb)=`5^j>?!Lyqk@JadXbA{b27(9G5#UH5>V`%XqR6TtA<{aTeJVY2UhENnWY5SSF zAht86%*PGTYI#qF5Ii{f*tE-Upc~bZNv44Nk&9(apeZ#$9m{jJ z&rQc#bC{xExUP&Riw92o`3w*BB;*8`ob0fIBpX2W1E7+0y+~q&d zECIuTUf*KSQ~AW!c^KC*73J*CJ* zt8*hHLtO=%UDJcb+@E%_DNTJ{F0ZUNa%&#kA7E?AJr3dgrq-f>^XXdC&1>a9uuf`7 z=8F^=WqjdNZ5ZFm78D+ObT^Mi)YwMBWSY=ia=s@nFufrQz{DSpIk7?q+g_I%kDg~z zv&%VK`gSUGLE&=~Pn@aUB}eNbp9M2{yqIC5^}Qt)5G08jh~ojvcrM2nD7hQsKB!ju zM{dDikkJbmliDt5L5Z-QQE1G5yGeae%a3L-Ut@I4#-r!NL@+tz44@Jp^S#9VT0G^9 zVro9`z5NcG)PG9jh+0?VfAwjQ*QV}h*Eeei%sN=x7ivYCCnroYYJPGbf2ENaq4_x5 zlN;-u$uzTO1+(1d-DnjYgw&7v)aqi7gVzCzGF@5=Z7?)Uq> zd`~v64sAUhgzvUwcq8_+Ga~)!&3-B(iflP8$2dldQQy|xZu4~rfV+CWjZ{_zP z6MPz&S1sfjuzNDN07S-z@(wd!#R1`3g7Y)h*NAbtd7pC6(EGinTDRAvs+D#Zi-B2E zB+I{+_MXjkaC75E#>t`%bbkN-)w1v3%Qyi^raQVEZxuT#;lVLdS0m3tyGZg=r#;*U z1$liWu3+rsJuKs<%|{}uhS=}llD~OTYV4Mv(NCm?pv8|^ZxUBb)?1WHcNOI!JPi}8*uU5{wO0q$&27G<&Qk22z5kIR zny`DhzvO$_H)E*c$j|Lpa7a)?Ox}rt5lvU+TIN@jZ~Z6aP3Njp&q@4hT8}GX=#scB zEej+LsJUDQB@JF;2d^2n`!R@XFqjg*htITNk2LMy!T>M6myTUbE2DcZTH;cSx2f%* zG$58GxwB2){+7^!ciQ?l5}TA4AH@`^g<7uwhkb{7* z|HT@}ialV?LC^2C`>ZGX41Szhiz%IfqgH|E~ud;aC4dlPEmeMOm%n;q>quq*C4Qi;h^msd>2?BO73^Jb3Z%?R$z z{^Ifkq;`jD20n%Dkxtrub;#OyhwR7bXU}=+-=4e26k2RZ8#^G!@)Mf}cP{hwlXH=H zM>~Huxpl40f0cn*5Z5+=qH9Cul2~txNhH>b54J`Z64$w7y(B*0DSiBRYU+OS#@f(@ z9~gcHOV*rVAg1S8)$qBY6?^zjUnzPT&8%O%(9gbuuYmi@yx;dB{_}r7zf#}(^L(<0 zt*teyZ>%AsrCNIR7zQ;O_sFkA(Zy@?7*t5cm)^g)#={w(tq(>; z;_Y{Q&Qf)#Euhbt#Pw0)3GenV)-fS>#yn(!?DcojDOOQ@{pH7^A-GeAx@bq@34+Zc zCn4hkKE7#VL@DnF)X?Fc-fxU(S~gV@xZ2Wu23>h@)JG^YfA=8xD&jPrvdunGrDb1IL+Sb@6dO_s@;ceF4iJjhS-N-;zWmPTX`7C>@~7UgD7tZF81jc}jE^InZcYiHyRq zl=UI!{OHr)^V44()=w-fZJ$ln-aU5)X8GsY)#Jo2+MM#mZ(1Cw?7trI!Z@&D!2ZSHeA2_ZIdoij zktgz72=S2UQ;9b|@;7sK$Fa#cKhW`k_+dNtHf5%Do;2np0kMG{?ACZ-U#X2(o zVh6QY5cy1OL~sQBCiN3CccXvUO~s4RS#R(_etKh5874+F*_rIUb5%*q%V%Gy9(#`* zR2P4Y-k_B!Ix{)3SHDw_7#ckZCbz1p(3$zywP2r!&8gWe$7H7S&Tdswk;(O5Q`dT( z)-1SN;2Skqy_@)Hw8n6Cp2K{g##5V6yFTF!O73rY?fvoD@Q*|nBHNg0iKDRp9P~Z8 zt5BP(%+X=Oxc}Op=)5)$-7w$O*@{fVmu1wcpZfhx@*ZJ^y!+i_;8qPNDkJ0r{pXz! z{*b@lT8!j<|HXlq(RH_SO=LK@#GkdnhnW1o*c<=w3*lT-*drx5M!pwk`5A|Y`Vx@g zcm1R1f?q(#E7NiF62tQr2UMZ8>w`HJOd7T_wNyWIT9fDc!RE@IAxr*WtYo*v;=WxOQL-{z|6=<4+cQ4|hy$zu@n z;yO;@NDe`CVH|8^TrWXulNdj_j5{V(bkaY&;|I$@ou7;=`|!Rm@%#Cluyo2U^@4q` z_Tvf*(%IehDl%X#?GQ!ycSb%V3-j zR=Wr1Hhb`+nxQXU>oB=ZYDf0qYzp_GMI4-J!I|#U#XQ7v?y%9SRu5A1i<-mVgZKaM zn#X)VVXCHwde!2d!7+wh_bjpi&cFY;E_r%=ts(l1-=TS( z>x&knb%b4xy|P0eh%9A|j+^ru-LD2o_XD;cBK;8wM=g&H7Kf`_SJGVo@7V`Pmo)r6W++M4OY#71gy zh3Owg@-F*{^Z(gL^?&6D>u?$y=`7d}q7nT9t(nz-g7rEpj>$yiguMG7PV!e5wmCGx zg#7Ble}16klEJIXZ|JyRKM))=zACap_9*;%e=<)vQ$!xZTd^>YL*nLHA$*xXHb;va z0nR77v&CF9-cl<-d|1ZnZx5m?f5+;AJsd8_z?NDa`5!rFs`!n^viA=s+3H2&G*V~K z;&SeLG5*1nVsAja2D3;x99P5A`mfylr&q-emS!1E(6iow$va?UB*!QC-z0olNU8mSHeHR_in67 zzLDd=Kh^vO9>^kJ$#*CAAordw)Bd>{B-#6>B<^kLl?B@mU4;$jicCI~g7-K1e$!%!yAoR@Z$f)A zlDe!1`Y1@Pm5;ojI`2n*x4151((iy;I3{5H%AsjclR$=Y+;dmZyGxx`7he-GKi{?2 zAooLUGjXp+>WJ(qzr)gn!hR?COby0Na)NlnV3bK)38$ad@SWvU)PcPKD`m>byTLgM z&zY&o@g_M^klEDT2q)FtLg!o6yRPlCPj5l2yY?o)%0P1H$4qCzqfWSn*yi_Z^xA7w zmOHB)i@LrrhRQsanrLbJMNd|qyGyvA=Pr5fA+hWGa|h%u$?HoEUClKL_9(x@0=bsaFAWM*>w!P$i!!N&hjl)c7xXm$;^ z(IXP0PJH}U<1y8_PKVFrx!b=h=U|avKu$6FNDDb^pv*(P66)|SY zOy@IPKpUO(BX{h22d)d>JpxkR(W95bf&>q9r>9mVb4O2+=(M5Z<842?5aU7(QzOpY z-Ml&R?^7wYgYetH^8mJk9%1dfUS^iOpCMRzze&z*Ju7aK*}G0DdE{A~_8h!{DVxT) zyCGWKW(BeO&m-7pxX1(Sv!BXE46%ko`-S}?Gb*l`!o1ZG7#H>%J`ykKh_WV zt9n3e{(8d(5HKv2+u%s(B*+63O^xCCXBD)UAt@b10@Xu=clE$|& zJc9L9&5psMgDa;a_XZz_+H$tCUZ!KifMX3z$3BV%HQwa`dDPt-i>e1c>YcbRc0L#- z!QRfro`!!oHR>Z=-^82Dh$2ktwe2cw!u&8i5{?~3-WRSa0kYyXj9hZkA_LBEbu=A; zS7A38ihfwAY$?(06Ww6t$|Tc-jlMKsX|M^_mm%Z1)_UcGR#2`hMS(l^V8g%3U?!iZ zfK;)F@%~&G?Uqtw%@v4~wpLEPnEj;xuG~Sclh`wsvE9 zDmCK?&aA}#hCc|wcS*w=yeKoVPtzyKF@xbM;q{>xf~NXY$FRrW52hBp8D#b-^=OP) zXLzQ~t&>;O5rf+}s^OA+81-CIV@bU?jdt@Zjhloju$14xgPk_Sz`m0^R>2~DY}!n~ zrpn-VqC;v0{}n#gH#rFYg?55#G8eY+;K2#0=K-t0hnurcOhUeM9!L3v)(R^lf$Af< z2?sA?g;FHF3qLbBH|$@wKAFo^3U6o;zSpH!DHgmVp%GD2IaOWy47iI?yoNIraZHHc zg-7mTh|q;s*E1$QF%fVA@IRY$W-;K+yO0Mr;F^Z(1uW3*DEQSFnQl;{{X)kU;n#xj zCXh}huUBGRp9o!i@oZw~8{(|V=x9ZoZq!^DP562%ih&~YI>)!HBWg0xmgFAm4&hrJ zLMN~x{hdmKBYyuV$-QsSPw*jL!KuLXhi^r#xsMj}-k|BG705WD-6FQ0o23TFtlND4 z2EBqlw8qA4RK0j5_)5PyJ}cNgA4)cTyMPmEY~u~yV$Yk5r)E|IGcyk0ci9sewdm;5 zx?MY3hAzUJ(H2r`1+wH->+jLY*Ta?mPnOTx&D{SHu|X@U=inlBBqyt=8me z!PH1$tsAo0X0k;L5P&S@vFI`Vx zkUPAPPwU7WFSt%duUNP+8!Q-dnxD${io(*B z+1v$dv>=wsG4Mvf=cVm9ewtGYgss$_y`Ul&M)Uy>!9=;pbI%+|0qKG4bk|7Hz42v( zj|mU)iW;nBZpYOe<(?~33oI}DtEtuf7MLm=zYU*hB?ynIIeX|3+&CZ2;s!|6L{Ez*L9+LQX z9(F|$FNyOhpHk15l=KRY`gDFlSEhZfTjAh1d!N=MYe{*7;e66|d+bWEq6YwHDVV|X zaSf(!0S%Mzns(vNRgv`m-hf`3xDtz<3v2f!Df4hlWw3KBKu_Y$|)smtUYl2letGd zAM)Utt3|AHOnX_@46 z`kBn@hKd?KkCS{8<4=O8^hiU1-@0|L@UWdekb};VrjS*|#o;{575LY|mAylVnMWIK z2q5Ko03Qwzkvn8RD2@u)^{GBG~9PoIBbRgE4- z*bm0`aHW>=XAH2Bps*(9enNMo)TjSktHrC8f<+@g&6u7q+yA}KMYqip7_jZy*71LH z#eAsdlPyf~qk}qj1n<6pu1NDwTFjiS(pRy?=eIlxcGs&Rp?&7qH8YqE24eTZ*Y080 zP&0#`Xa6B&;ht!6);MG)Mh$0AZ=ALE^xy=n1uPa5aDr=<>pp4ugT%b$yo4Is2<~t4 z8=UK`$Tfmz=YH=VT#=c`dP6lbxDFU}CVrz|^bPs~p0JSfu?y2H=kI;Ee_L5a3mCbV z0hb=}%sqoVfmsB5*%{Nk9IP4qU&UoXt{;1l`+Qmk_Yq!TojM-t=ROV%_QW@!!>81L zIqiE9ZUr}?Hgc?RgfTx-+vdbzB|7M?^62wV>-oIJ!6TT6=8r(-}cQ$Zy%DI<% zF87fb`c%e%B!a%0UF6(1@5s3Z*l~HEKxI4qypM$$4(T3xr7+P6nAQ5>fE^FV0yZ_X66&MKC3%dQ zs#&QuyWGCvh6H!>i9Wzi@4uTj>T$=GXY9(YT}-xiWUpWI1AIc#U*ICTqDi=vog|JG zeTJryHG;cz$T9S^!Pff*pxW%WV8hq`!*^;mFUI@L|5~rr3>+KzCv$EeF4ai~$0vN- z60I{|7x*5k_ep=nXYeeVO;F>7?@|>GW?hf5iN{(^tM_d%dAb@7W4j;N`h{8yFil^z zRec0E_r#0$IJY{btr#ZdYBo+={x*QFd9|7&9Dig zT8Ds;k4xTq(eJXf0okvU*3fzN&DF$ER~0=Or^2(cT{(>>&KbH4_;aQ!MA zmakxz{;z8=kMu^oUjws$A8>yIST72l^HjO_I~KWIqu*Ym<;4DdjtOdWvZgO7&_&_w zI}Arr7em!%RO%slaNMeldPOY@N^(2z-grV?$4BjDtd?Gde&0esho3ATrn^IbjQJ?yJ6o;seDj-Om?YYRWK|=oQrvs*HX9cAEu6 z71)8)uEX1i{CRl+i@y(+^ZB+~qM1|mq+2dMZ7$-A>+z&qTWzmFI;Agc@BxI&gqGR9f!+>NEry;^UZplJ=$>XQtz;yM)4QhigCU&X7JchZ~r1SPcNmR z7pS+Uy#G{Tr+LN%qlQW57>=9B<~iOvSMZV87Wf_b)wF?tgN)cZW3}X& z@GtZyDvZckP-O&n9!3g?Ym^%P!gPio&;ACN6uoV@&eNz0@lV6$`)yIIpBdbq7LAq( zF`4W^j}~v_;=vvfWYUh@sB9kS@#P z^EWz@exo|MXu*%Re0SXbz)e zSYwV;BSY*$z47;?pI_kj;j?!5TpurepZ03V5BZ(3hX?g0I3hl4m;DE}On?sp=FYGa z=hZl+nZ^!r%pBXi$_WLj5pTvsfTt&|H@j1zH?1*=B&17rP&u5TZ zJ1~?tEmB_ms+zc9Vd!7Gg-6zrnjDkwz(N2T1G}Yd5l=o%k*RyM2ki6F3E!+2Ea^J} zE`nl#k+4jfJCwKTe&Au_Nz>@PO2>!bYT~RfU$ktdePC2I%q6iLIPK2bJ2(`460wA! zY(dNq0BU_K?L70D-%zVBHQZR}CTdk0tqX^e1J>iz;#N)Qj^L<}SelPL@gC*IzQ!gq zzc#E7>Q)xJHP#!rU>E9TaM#&8*iZQ7<~mqSQ@+o>D%Lmn>L~yAI;6zWnZ!uga|5+t zZO{)wFzgoj(!Cec10Ay$J%YjVz`vOern22T(IaX5DuYu=T7;83W<9yhX}88er9J@r zk3hBNQwwi?TGeSbu0`-&ek~4`pBI4L9k;E1sBYx9vrf0D(sp;Ui zSqfhf7-wA>X&+#q_zsR2jq+pK5OzU!!e!_k5D_IPj@2Xk!&aqZGsg^(Y zB1;SJoJ*TmVn#p5LgtVE6SBly&HA8E26h|EkNzV8#-R9U(!U2`EqModC(oi)$C$>S zNM{P}cWiuWb;-_X>lDL;f5mrUmBCfvn#5_qom1?}aoWcWIxUbNo!etl>MzmdNKe8^ zpmoJa>HXcO!C(QC!!uE5!fCNLHQKR5Uuz@c5Wi!0U;-#*kXod6jLTa)#=?Gz+uPTp z(F)x@i3TnL&-e%jg4x!owfDCS%7Uzf_K9*p= zi5dGQzl)F2zn_anC?0jGS<)}xaHZIGQmEsZqS`Z+j&Ir+S>xIV^$NCoIX>p_V`iBC zaB97#-{T^TjeNACr_%pd-IahxRb=Z+Sb`lsMN~B4a~H!z(@VOOZiRu~w{$1z?(~Ks z(|dQ)`;uN9M;sJf(Gf*PN1vb~%J@)6ahOk(5k_Ru8Ae6iH+&D3(Z^HK5pmvMb-UBI z19=4XoA=Gl*SWXqR-LL_b?Tf`Rj1CuT0_RpFN9hJ*dGD?^3yu#1U4o1;>cMG$W&co zg#h!Mo?2uL#*82Pb3y20E29O0E+{89 z+@l%rAe@AUeeiy4*xNb*}q`LNMV*TrLGSP)+sWC77&#n^X7yY*O;C;STi;V(h3<| z?&Jg-_;bLl1{?51Hv_xbDtDp_{+BdZ`;M{lGTbe~G&ea<7{r>)AnMfJ?IZFsc&tz9 zP-d|9IUpd@Vfey}m~hrbVHzDJ7ZGlYv8;T&1Z#}(2PY1a=syqfhZ{HuvUY!tZF4>-at25&#_|-FM;<4 z$e*dOzT2pW-Ub9k7dHFIT-i)yTuHEyLD~RpeC3fA;}B~DLC}^?PXcl;Hb)AsR;}FU z?evG>Yp_RBm5&n=$j9(hpMc^?k4B*tV0PS}Rz#q49LJ_Wh!6P~`=w6!1ueORoyoim z;9%z_m^8=@88t+2L%$;HcEH)TG8y;64Xgq#ZlOPsh5}7Y9)N=|xO66Vkm2VWG{6q8 z>1t(8nb3BjeZ zy)-I;*$Z0{CTU2cY<9Z4U?;#N$zVR8fZrC3(R_~-zFn{`YSLl$fi(}M2fDO4{^nDY z@drBJfqT#-!I_f~^!te)OK66hOPxrk1Ft5!JQyy(4OJ+Lc?)Ep4%luk;+%#z0hz!L zUQTSAW2_5sZommvArq^OC6k~)>?D`4za!`CjD$XwTu_D*a?Iw*1>q#9TEuEoc(d~ra z!BA4B4B?1ztF#m(ZfAg_1Si<5nS`Ov0lWYWCvNQ!##SIGB*NtjuGAnVaPuiB9LTJd zOAF-;0v94&`GSAyIDB9iAk`s0K=K19E*!Zg@Hv$z=cv84$EA>Rkq#%m%9VjbY7QKl zoPq>>#_l8eht`D@k_!qTj9}e^en*FyAB{^#`%R$Nm8tq_E46B5;tE<+MJw)uUsxD zOCXIGn8;pGn-$19?E^BC-EJ}11QLITJsXB^=bTd{%lnFwey>O(@`?&Vzn1t|mkhWa z?y%oc5IC(uZ->Y{=!LxK63O6)U6%1XElH&J>){#HAv9=g8vAf~@Eu+dcEpI^4vo;O zv*3MZ5Wd+neyfbWb6P0wR4RyjrJbu$unCWEq7sJWvb~(%xLn{bo5^>evVY$v8JcJPWyj!@N3dj ze55gK+fg{MmRD&@S{aZ|1rF}gcO-)1P_>$5d{P{$Mya$xoQJRAI~z~K913ytJ}tzh zY+tEz1Qx8=R>n4k2$eNbb0xM6EPQOKqkbup4$_k}Mc%P^ud)a8Y&;t-k8z`hpQJd0 zEo7RPZ3P0E;tnRNv=sa!CrGi$M`2guhQgI`(_!JqoKny|4@o!FR+yY3%n7RLw_wS$ zS{`5y3o8Xd3D(kFSX!t}_XFWq#{Ed}6|P_#sjrf@b&9mhKY?yx{!*dCv^8t-n4(N& zTyh!>h8U2#fUN9SkyoJO1ubd+|C`QIkfeLb=52)!z%&-^XErZZ!9fGS9J;aML*w;G zOn6X$#_@4+XbQF$YjCqnW9;QLwAxr^28p?A>70MUIQ{r}!APftmjMHqK!VzpE zcnEK;fW_epZ2mHztUw!Wd$6^l=uYVz+aIIiJG3Yw2b6~kMfBSKqe&|LM+5X9_*zYS zvyr6lRR03)%{JMOFPG4kY$aI9^6^pOGi+2soy>+zEZ-+I5aRO>kQ<~F*I9mwB}0M` zWDInb<6(*gWx&Y?i6bKc+tNe0P5CM(OmSLncb3zTy0Nk#9xB)2{G7#i^%2HEd0c4R ze0!^OhzE4#1Efy$3OZ#@QXk5vh@>1UK?9O`#EViBgrQ}Pfj4(qlJmOHpWU>c9ac`6Gb=X zfniu;6u6y+o`6#U-*#k9;U4s29_4Y!p!$@B-!2K9;VHbB&z79s$m6gKcpQ?r%^@3b zcPr8sd$`*Mv)Wd>CD=0{v)cwNL5o`Ev=Q%1G5b(@m^H?sNQEC3We0t&u;f;TQ;mlw zomiB>^g^GefU?vcM*;I8%zdC&P>$=uxHY?+3V4ZE!+oT`hVQLnwvP5&V$lrdcttYPA^9yj=yAfwwo710x|}8m!!*6s?KCIdI(yg)jqP&3xVoLN zh}vp_KQ)EWuZ|92gk^D(nAT^6cLa{YC(!gGeU8kJtqyXSM1cAtP8$_LpA_>(c*cU6 zdOJ>&yBucpy}1*0ipq87EcE|^0j)6WvddyVong?Yk+s6tI8D=3WFXGml8vYTE3#V!26zvNkJ2>~4$T4`q;NUuLMAHn za(GlVJxeDs&JLSnl*cDl_{vvv5UR$)VRRP9*P*gX^H_4o2ES;#9 zGn{VpkaE8_oN7A&jt6{2*%c9E%aLA6H+fDqI43X+<4K1^eZtfvXFURmTsE1@3Q|1h z7=(nW;{(p)@X!&-WankF>E6@{i!7nF~9J-hDiyT>r(*JT_# z?qp`-#EHmOdkGmO6VCVf60xKw%l9Sx0e)X!Hj(y)`}+9w&JPy&p|mF%?u#Y-xoDsz zS=1=zFzz3A#k#ilZx&2?>NCgQ%cIfUs|Fie>;l8_z*V)^1kH2*sle%Dqss5rfPW+GLWzBmh-zb|a59ik0xnkj! zuLXC0yi+xhx@Z4)X3=}k{=D83}`vH&qM#e=mGwmpC(U%RxGl_IM%l8-Ko>+pP z&c(Bl7;x4`P}WA@T5P5&qD~#5d15!21~#8HdC_MFoBwn9+@E|Td1S+vP9WeTlV&vj z?lH}#)sG+gs%Po4WuNLVpYvR3#Ex2>fui9^@ldGCt;IUj#IZo%*M%bLe~ zZ<-i7twsHB_gb#FZ=+>i?9{zK%f7hh<^!u&3Z>gVI?z&lZcXONJp?nCKK56^1+a_#kV?)v*n+g>?tN!#Ar z#=ZTT^Ua-ie9^FA&-gcvA6&C&@dF>ce#%$O9HnOR7rfgC6HE85>)Z2%+&by1%T`a@ z_O3B=!|Jbg`Ih$YURH0tCc1R+n>vDwzukWh^9dUN!7?%)nHV^k6Mp0l;9_b6w=15w zb$j9^!RcFG8vJD4zCV7xk+RcUbFI-k{_CCMw^3`i^UN%nZqrcbT^|F1Tv`FQY5GFK!!q zvsFDmAeZuHoL7I{cEN<@q4iI{HSNG<{CA$(D4lf6IN!XNcjkH~{dVH^^|$g=4O8zs z{rN_#?V0Nj{q=9tKYRYGg4n2QG=~22>Z3QTzp$?HwUt*swr!7Sf@RZI$FBeW=B@b5 zpLQ0m{Dw#XgluiLHpovF3=+$V09FZ)&auD4WwmKmq}lie*O&a*{N;cJbLz z7N0$7%lPAqcZ}cmd{b!8*Q*p4tdlC{zjDw%tLuGFgL{Q6mwoT_-|S32bjkZ^&C-e1 zRa5S0o;Rz%_wj3{$s63p{ku;)R9rpv_A{IdbGsMJOHOZd4Ay`8?CkKS%L~j2D|+6Z z{XZ{GH|`(je*3=lS3Pn6S8JxeGdtIN{i~N=vFlR9tlo>S7(+?vjhiep>lkJkSR+eH znIn)=5z~GY?B)BSo=hf?;aA8A{`*#IueAtZ?R-lfy_1*4~>pz{TAE_+%8^)5-7StVqsigxMaJoQWwwS~}Pl%Ud zE^neDSTK>!d(!^Ym@w1ZRLYsrm**yi{Ein~i|;gaWCtY!;p- zn($<;Y$Pu=j>>a_)OafL`Eq{G#e&oX1JhB*WV=)KMyAO)#yHmauJPS$6OA=7n-Ln5 zz1HVTM56&;Hj;>Ec%5l~ARX|lB0ezUi8C!(OX@g|geZ>}dsDncmYnls^btu#gvJZc zqT(ZY*_&#B{5Xkr=bazjHuog(2=W`biBDu&I88hlNM}$Hjp8TA(2cD;GO%~v9qPfj zX2@^k;!a{_bBgm8vw;NpWQ{_035=e#OHmOuBmbD66W#V>`Skx{Mi`yivqE?|y(k2s$ z#L`xol+^H}m%skD3AH1?F5?)|SVOWTSIVD4UesJw$(~f0pg)_x%Ko`hs<0;Sr{ zJ-h`_)Dc7?PiO5+s0H!{~E&s@*t;;xJ!iD8rXNB_8x+ivScc*i6kN`o@$n!;j3v{D^qq(;HxE3c-GAWZ1H4q3kBub(NN^+G4 zN&*6;-};ZvLw?`l=qkEw*S74c19M}SDSlkjH?Ku<{jNIy@jw0TZ~ywQfBWwEKmPW= z|7{um>Fu}IliSj4ay$9Q^(IjH|N1vKcl^Kq!Ta;i+x=H>`RAYi5P$!%IQ+x^I=No{ zx!uq1g84uGOm5gtc78Dbv+{S1$cCW){mb9}+kgLm{_BW3aQk(#+O2l}mw)*mqyH1X z|6JYtxws30Kc^FK6HdO){#@?&n=k*t9Q|)Z{d%|ibHDp@;(33Tzx>B$akPM><)7gW0+_oC?9m;XdPTmH-cSfa)Mj27oRf3=%``JcXR z>K|S(xw)C&{3B`d?~1td?_d7&fBWxV#FX{w_8)&Gz#{O`?w|NQUee>VrZth0a9<|>9W|2Mlo z`nOuSS^l^F>dW6p|KH+2{mtFk#~Bj`kN)6a`Czgz`X#Y3=RTeDu3c`Z;;Jkzh|^C_ zD2o%u!&$3c7U$i#Fn-iiI+J{+Fo|^-vzW&A4DBi68owLTwHDFw7fp$KEc(ssbf0KX z(4)T0ai6?*51WXttQ#8^l^1oeRaB1_^^7IIm89QPo{1qw^-EUTKGRW%=$v}OB+<{f zZ=!Re?NX_9Q4(dZN+sMT`c^76ak`D#Zprsix1+W|u>Po(ZWn(@L~o0+;bQ*i*P0m* z*cVYx+fmm=Rd3y)<@dLP@~nDsoE*W1 zG1Eo;345qOc8(552ak9`KTtQAYmsb{Lt~OozlndzkvmW86S9f)h;F+S^;~CTQRr%& zerr@8&($$zh8Sa77w@@4_uku9ma~=}xDsD_TQ3})l+R*3LU&jsK47N^@xwk9q}(GZ&6+}c{?h# z7Uf3ihU%4sKBeuk&L=7-WF7Ta6iH;=O8cx`RAwz+naS+ucp>!YjJ%-w!Hc$#Q8xC0 z&@u5_=vJh2q922A?2RMDJv>5aqir zZlDeBJouijkad^znaP>lqXr_|MZ2QkIVa!|UC1`>$)kTF-wv`atBZs1Gm#yMaV7Sj zRULVJs$afal)s=~{IW>rM1z-u(xGH&G|2)O7~Pw_%Aru^oacz1{EK_hrjxP zjf!#P;zFT~ikpOuHlsL`3_p&WBG#Cn!Ebr?kl$SR`zWqtwkqTrMG&x(as_Vn6P~s% z)V?y7n5CR@{!wC=bv7sc&K;f-r7O4;_LA}{Qyo$-LFZ=QB(gY95#JJ-gfud24=)F_UTB9~>H)u~@ct7O3FatZ7?#z2k4@F!3^|&%B&y0S#&HL)b zq90r=QohT4VXW=NQ~ExtwgYn=1m?-FjZVg~D~~J2+NcrVW1btL?M-ksXa%9`I3ejv zZ*#ad@Yn7yzxFrng{=gpQ+6KJ5vS6uCUL*#h_YwYn;FR_mdl|_ZTFlDhit+YGJ2sN z4fyt@<2|QM*e$dCjZGZCx8cWZqR}_hi}o3H=U^!IX{MT$@x~mwLo?Eq4RpnMA>Yqy zzkP3Oeq+DwXWMT-Fv3y!&3&Jc^N)RvQb(M=_cW0sP#jvgwKnE0`~*L|n5~tAS+nALi`Yha3H;@8S{sDk$zF_V;{*I4 z>IgQ;_3j|8^V(G+e_-^N>uL2OjB*^xSK@nGMj7##yL5aj5j-F+{#otn9!%M&(@)nTQr_TMd*61KNa%w;aE-v{3C4XZZ@nq zF`rLo=AUS@tG$`LqV^KoS1(@RAB+R>5V~-JB_O5{S0%l=gRJR zZ7EgEE$?9KSi7yiQGa;jdy=VGp|sdqCaSoJlXe8$6h)B4GHQo}qJbN%23F;Dm)qi-g@ zZ|@Ed6xU9%pMc-Bj@`c109WyMpdx;qkSF6h<;sC5Z(%(QeSB^_d&TybqW%6*Jx!h+ z3o_j?VFy#qIT)!gUIx<^a)X^oaT>fmd@*IcVal7TdBR^YZX>!bV#8sHwTlrm1KZ(^t)y5CquVaTSVUOpN_2ca#hMp!` zu$?LA;2phM8*wG+v;OnPS0-20sN{8*be zW2|+Z$%j1BrV?U3^BT{p&&A)m;BV}�h^Gt0EN&9mC;l63bF`3O_!F?>^fiumX{ z70kD8Up*#u71GtnA7O5ZbIh?ZFFCV352o>rzWB&ubia1C4 z%~|DX^j*@MlW14O_h_DZW_i-wAsHjqAK?orww>4K7Pom;d1Br|V-tMk#Al54%m zSQ^=guvPx3eVSNz=s69?9<;l(p8iT}%||>(sUk-|Vy(JtZsWO^94xVpu$X9AhiE11 z=;ITuxeb&-uTRe!UTZBM&t}Y>jpt>vU?#iqTq$gL)e4?P(3(n+%;&1eecS!OR-H`) zYwoZ0oTGg=Tf11#Y46;TRZZp_-H~~VHTfFmm)A*7@y0SZO$W^i){9q@Q8Hh7dl{Jh z_B+eK!My%mWl(*08LZWJmcjV=&N8UIy9`e6Z3kZEon=seZy9uD_nqwk&uQP+7uE5M z=UwALx6j{A21Z{Mc?n{470*V3KxyiuV()(=PiP-RpZV1IzGHPU$KTH-g$cyKP32CdwafmZ?CVXxAgCVul2X*t6zO< zeZ@PCZ>_KOZv1TcE#hV5UpL`-`T7&jY1e$wkDt*y)n&!qy7UaLGSe=$?k0X$@jPo< zFL{1_G_dC7ZdH2EU@;x4b>#PW@7MfF@2JeTXST^F-_6RN{~BxeG{$rFnP%=s)kbAH zxH#f{COrRj1H4PNn$xrXrW=c|wb-Zu#F&Ia`_dDks{@8nFj@eq@SfV4a_yJMy>C8v$1nG;+L?0Hm)N=jdrsaRjO`|Sm+n{SAHEM6+~E0t^3KRF zsfXs)yqS$meEZOTGaWqNJJQR$@g}(QT7&YL{n6RJFTfmh)vuxdzwsE%ZDVC?!R_xV z0~3B^uHd~6=e+Z{GupwLzD;rH{oV8CgJ-h?thb$#Z)|?^alj{u?;-klSMl7>^Ufv@ za~PZ)#k9|x>!0nrrh_rWcaasm!&@=elpFrm<0y;oY<}l))W%rf{k`qFEpiR|CbviP zpWW{oE0p7O+rGm71_8b=q&zVB{@_eG^x15@hH*Q~GUl4}QWo3nvG$h`$8SFlGqZ~p-tVSwOZ)Wg+wUqD za~quc8uHs{EEc}4!`=wzeYSf>n=jMaWwL*TH9TCqRqPig;XBdF#bmhF%TIlaj`xD= z*zZEZH>pckGnaVI&9UBy=iR1^Jp~SCf6UtWh86n|Oz>^{Q+o=ixaZm{w~zSlax<^2dZ0{3@g|VK&fB{nxEL;0^fq+oyGC6jJ+M~V7wdbMpYNTE3{t+zGV&X z3irRzl<%(Rwj_Dvl)HiuvUc$e(j=x|K0?Q)Q8+@DZKCqvMmeh20^P`Ml9Har8@>tp zJ}a5g8{#6-^Ba(5cg|=J6||4dS_+R_LAO9xt#ogn^n^*vl-|nr{>kWUAoh@<_Ol(Z zvKtv|#Pkl6v&&g|>;^^qzfk*nm|jj#7$1qv;;8<3w>ug|j^2#v9i|oaUk+2c*ds0C zL9YBkoKySQ@-Cw0>z9L+-YKe|Z^hn4LOx09e0j-F=~}YUO`a}0DZQNvHR}j=>Jobv zQT^B|FFDeVGOna_>{OQ=Ir^uRp6&dWt6!?6^j5ZaU{)S|rOSf5ExD$?NQrzpMf{c8 zDP1iZpS(}$*>+qx{ooy8c0&Gz^#>^(J9Z^U)<3zF(zS$|=RbF4>K4_n7uH`*=|WMX zI zEzexhEz$poM*rsee@y8)rON1XJEeDurD{sg=|@&S8ws&c`P`<38J*WtdM zGCKP#%!caE#m|f`O;dW#|73Kzl+v?{IsWpIsFvzyMeR$6l%C4~b5)tbZIJ%wj*QNC zF}=f(>?TJ>mo8Fz&Oc>z))#3N^^X_zk6)+sTzt#QOQn>a^N$&w4Mf-z<4<~**FXL> zrDHeON#`_fHH*9*%BBJCC9FKoY- z(mO?T-Vtgo+RyYO8=pLh>BtWX*OG5MJO2i~f7=s+w!>=;iph8?~>P#1$ z5#r~Pk^B~Qa;|^#TY30f_7v9Ih^kERnzKFV#3CeNH1tD&f2>iG5?~IDvB}}sVu~6j zeZddh)zmOC=Ux&$Pz3ZV+FLyz=6%oZo06(1Us2h< z-eSi79_8z3BdV8=;U?rR=B~!dmrG*qMdf`1$OcX9-|tJMFjjr4r)y#MaKJ{=j1BR> zr(FWwK-T0?()|$Sv($-wM)?(dODWVdf8F)37_a<6<+DS+LC7^l{vKb%aImRdfzd0v5$p>62XIp}c8@k}|jYkYDB| zIt&}EkMIS5EP$p2TE~*GUVmiw0@DKRNnMmz5ZgAMzbJ|Jnb9Hn9~F|H-r^zU3*fDg zANVKDz;h$Y!hKPGgYmm_D-twrUgLk}qJ1hKVP>EcI?4ZG`M7-@<ZznQ1vcO=0-ln4F`x<{BHGSra9 zKL-DLWAIPtB=^4%|Hb(C-;AHySN3}pB>hADuT%ep{$r1Rl6BgD4RedloH^0TNo>Ez zsC|`t6uX-R zuEPI_AP?GRj^%!YP0p$Pr z>noW%&&W~CjS4-;#}ob~`bxJX*Tn6yeaNZ#L3~j?IGH48*)@iIwVB*eN$I&|1%N>$ z2NW?ruKp+`?M5*&x}KO}Zc!{XP23@3hPYEM#}(0^q=8%vPOqVV_euY<<_Y*EIeH?V zM8A@n;usk*L(s*ks35u+{VN783T5$iU^}u+JdPqL+etYCB|FVkXa z^>Gx82;W$eC>0M%R9z%5u2aYm2UeKlv-ppQF}bEPYE(p!0OpfjPtTM@{!)&YG=3rf z+)8~~0zfa~JgP0>Z;s6L43O+AV&L~jHvWtziTWaaOZt>iadwyD&*Z(HxJ8V=Xmspa zFhI1I)0(VaLqWbF4mlcCMiOUHf4;dGWoaTxv`kM@u7o<4AhUI&ZAko~f`_A7*Q@N3}^HN&oKRl73bcKj3K5S>`vl*e_G zFhG7o9HL^q`RBkPDju5>{xUd3#k%v0;SfD)zrPF)(PMHhz+n5pAzDAVN5SC(hv)-` z=!e)@A2>vW8}fle6l1<7SYQdZS`PaJd&+#^5G8vUJdIzL?OB99;y-YRK5&SD9q@re z^npVpFi~H!Ki5y=E`8t-J%#z1`RJd-KKj5R`XQ$3v-$ghLv%h|sbWl^^WxSNw|Q3F znhzYJGh)Ae;1GRdpMQDZ&f@jLH^d=&WG(Owafpi7`rZ(S=w&iE0}j#4W$@ZKL@$@Y zE8`HoTn4XY z*TW%tHeX*4hv?aSeKj1SXY=*-aEPAG*Vn@#dM;mI4~OXae0?<>qG$8<^>B!a`TFF0 zKeYtE8PD;{6YMc=yaf(XF@Ik2tm}<%h#t|Mm*CO7B@WSF=c|yxS#gM7qmO6DA$kS5 zpAm=XuiN^0afn`_f9J;`dIdc^FAh<$4!*{o$i?zG0}fF!k6&Rg`eNRk1&8R3$MEVn zL~q^ySH~eLmh~C!;Pr8c-dF}_!y$V6ahwN-=#9tm+BihT{`(a$N{ab%HXNdN8OIrL zh~9V{ua84iEcdg;X`_EG9HO_573aetdTX6}bsVBM)~Ouc$}8blzqu}*2Z!j5JjqL$3?>sm}Z#|BjU7Q7n=*`D-Rve-?xA_koq7NLR4;-Qo9HMl4pltIb{((b8 zT%ld1bNdKp{Q1Bk`oJNI^+N!HL?)Eej}IIo+H;NU=L3f*QV#r^+o$aVhlsdBI1o9U zjEp}YI7BJVP&ZYdoc?^^5Yd+SLVrGRh|;i=o?!Fd4;-RcKR$4XG6$0E9m_s&h(2(L zN~H(;o@ZK;o_*jDr8q_rPd;#nQr-AC`K3AVfkTwWIzm7Az#;m;A^N}}vOjQ$ep?)( z`nmvzh;XnzaEKZYBk}ogj6QIPGES#J5!s?$9+Be%hlp4uau9)i@_|DHc6{Iv#rBgT zx5U5WKOZ^MXM>nX+|V&e#b66|bUsC{KD zF-tk67)$iX?y}Cl#Z+W!f@dWG+c6e}x>tt@4v~N+2kH}OmpgOs)$HR;Q&&uCoMx?d zc|6qVI=(Zrm%xv_^J>8@(ML6NKi%4O+;=xyH7f4Tx^st(vbWo?OYC`exee~zZ4+hf zVBP|)I=FL3?d41xZ2Ixu*p7B_J>A*0$;b@idakEi_x^G_3#Qw_@hN;9uew}%+rj3! z_q)}@{dlY1#BJ@S+Qn)-%y5WM*Ubz1dvu3~n0jO}(4R>}j6nEH6EUTqi9Pa!TZx$$LJmJd}5Eahhli;LNhS z>27>C8@lDo?cff2(a6h-X-~@wa<9*%#tu3XVEkowXm=ss>)Cc-O{)RwoLzfWC$BqC zk1x_a^ihQ@o~>hv+}Wsd-F0eK1LLWZ*#adoPlchmtXsv_QF;I(2H5BNJb8SreD*ih4Nx317%67XyFyr_13v?tf7iNIH@OS3D zw*f8(%AQegY9t%0;s)dL!01!ko8W3dHsLse=P%SF%iX+iyyvtD#Xil{Z){@oy$wHR z6V8b*dS}!fd@rCNHuq=4wmjGl){(BXp(}%5%6=ci6?&Ha!X6Fv;_ZDO$=MnF$ag%s z=Z4Xrl_#xsKp1mwt$jZm@da`Xco)LUF%b*e*RJA-nDaE&1!7JO_p28>Z+lVkOT=v8!d2Ud zhvrx8`xlK>y?6P9zD53^1vkhuQW;Q=0J}zBRdHh7IXzn*CwGB*%alWnE6CGE*TE>0_X@e~dQ!O% zd)0iyle)Vb+)R}U_;mYRT{@XT(Rd!p5ba-M&+MGuBbN1lN_S7j0PP_!M9z+J&6eJd zgMWqXMY0~bKhMjxS;Wh4#$o^;%{draqcR=DcGXH^^OLstll&8HzM;HCdyNFMXoNiq zp0AH8@en!fjC8y<9<{^A^pX6JacZlx%}?_)w;z*IpN_bW5Z zgJ#3QPw*|ke6JEtHO%`)#=)fdo;HBLS8qmD^B(yt=ADglZ`d}*Yw{Tn=U7k2%V*!9L-3sHDOm} ztXrJbo*!uwJ{{|)CUX6+Z#7Q(o_(SP+nI7M{w{O&Y6@PTU#5j^p4DEvLU)zOM_`?0 zg8gW*j(AR7tU|1)^UJdkU+By>TG-}U*NR2%Su9-+%-5Rfyft6_H|HyGFyET5^|$x0e&enEK)27&%hwm+ z{g~pph;u}qXvWXz`;*pS>uwK$RU4G2Rd;({{yh24M((+$egyv0c1-UXs28{6(GvMR z-X*Wz(mN`j&zy5T`EFM3-GQUI!;@9KUp=la7t`I~c2sKzcpvJB_pIq%E3C(!J0E@t zU&Xykyfbt{o{s({?{KB>ot!E6cu^lyjGYJX%?=L6DE)5Y_a4Wvu@c|6R7||9cv7E1Sl|7E$9uu&o#UU)$JQ_x44w1O{m-aNe1~spP7ZX*!gor)sT}BCVtlh; zTF|8y#`D|8iZOjx`}>b0XBV^RTj7fM_U(6-OBK14f%mB6vGAp<;TxIrzGrnto40B0 zGTC>5BgfrM@to#I$K|daV6OuB>LvDmc;el;cu$LQ@7ecnM)*#07~sA6r}mJ*w|8AF zI9a<`U~F9N&%I zd$ma3G^ZR@ng_E_`55|vV zlBW&Do-PKHlD%E>M4WGF?|JMgDlkyKq0!I#m@XI4=+aTtt1#?Tmu{nm-LcZUv2zR# zLBC8d=IEc&@;p~Q>uk%Sy3jw-kF18|>{8Gb(76-s`9qYf{63?H+2xF`TGTA{FW*8& z@5Xdc3(FtHPI1(}WE9mu2x%WEsvqUkBddRbu@N2p-xQT^rgZF>lom@*m}k+sC|^i# zq;#pUe%B}a7UNGgtitcC{yj_Rxt)Y^^h!#XGok0`vQK+DiSdigitqzRR)3Er^!$!c z8GVtK=VEiHZ_D3hU|!}^<6Qa@`J(YwBf%VPXFN0+s~wh}tzlO7qpp3#rSg?oY5CnLg$6$%Nac@l;dC9r*4V<7xr&%C3IF!yPqq6Ov~pkVwd6MP<;}c z=znqve>1wv61v>YZpc*OHlw$T=$w(*ME|owxvsb=os%HqBV$DvrsWImOZVwA)t@0s z)P8M~F;5&Vibw@~lulzY4L4`V0BT4m4~* zFSO60ozSI1`Fu^w3r13BaAf_HMA)JF^AR$mj}A$Ez&2~?k@{rPV!_MfRc|wef?-^Y=ME)23+ZT<0Cc=tnKc~N0`HG#C&&QXH-nFS1 zmCrFPce=^wdQtnI4@vpF{%7^uhs3`Y^LKAYbnquDlxzPYE|2^swVr(Z=-ZKhMgPY| z{o5ZBdOknQn);NMmkRZ>e~A1q>dzh7k83-j%Z2>^l$OuuKUsMPTO7vvQC$Bap^HEm z-;xGQF%NOis_P4L?0a!{|Nm}n@vwxvzO2zzw~l?!j?j< zMf*AZ$mpsF`yxKlK;k2#U-uIK$D2jv|INaWID=?M#j+a~_PzowVd6V~Y^_R9<9e}7Eq(rs2g zlCh~3MD{uJtG%;&swPFOGgx+k(f<3aYEzd?+I>sCH`fS&FO9mM7H z-XUMUL`SrUATEpF;-4l`M!P9a!=~Y5J6kA-UaD*OyW+kc`BOnZk_>-%r2XNM8%yd_ zCYi1&mzb+9MKvT-rs(N!aZyqHgdYFK+$_33`fY|*yuYqBO0r>SCE|)4jxDCBVlF}V z;s^e1Y6{~##brVfj6>iuW#!>&Apkj49+XEpD$gI3r;6k9#$#Fakkf)PCf}mVY)V*MTg*AZoJ9$zx3DN-Ej0(*uH2$vbu~1 z4Uv{o&2X#e)k^lc8tq*ib(e5(>VA`HBXlx8+N6vfz024EQEOD9%bp5!v`}8*U)rZO zj$Dqk#=QTyq4GW}g<`x+^-;G`zha8l79FAr2&poZzsR91&&67TxJE|7#+Y0AJ1UeZ z5$oz#>@xUF*+K*yK>kx+$)RM;nG@NdT=GkpS5g1vR*3dJ+`~iXo?QY%^op@ro!vKc zHYv&v=giUuBF(3P`m7-}S+Mg}9#|SKkW6*ae|_0tS53iAszhSQZ9++(TXHF40F4*r z!9y#@Ka`ePcOUg1cSN#I`p;(bZ_Mlu%pIx!(Zqs>W9q+&@=7^3*H8E-N%Qb)Db_8) zzx`wU6Y@h}ad_HgndBGo4~uy;^c??e>OYWq>E5XS;Gbb;`lHAF_gUY=T1P$)Bpe~W zGa$c6hXUdl=}+6Ob2e~#S&F^|{EL`(SRgz+z18@&X(j}IO4hyE>3;5+FB z859m_)%U44`7147;#Yr^N@1x~ziPIaakyt}fUnX?@o-#`7_fHdvyJCp;y$|8 zWTd|&|8`5)N+pUBOqTX_lwV$9o%;xB2pN>ZPL~V*(fD&IBnlkRe!HgQ%Xg2;OH4=k zu5~5sFOjvF7V?-vX}+O?H@6Iv@btwAN)uA%-;x^a9^~)kweSP>?bk}7xI@| zI%eLa4m?E5vd|w1NuRGroU^_p{Wm?6-He%S7q z?0?I2?1@f8-E;*1ME?{0h5hVd`-%PHu7~zZ{O?EnhoU9r;lHF3Y>R3Yx}`z>yiK|E zg07XCQqup;J}qyF_J#cSx>5@FxKl0}^_c%c|JQl{1^pxcO%goKa8J~%f~_T$N0&9BK^}vSA}~M@*hF}#2PW=PvyJ%JyYiGELjuwdlQw1{^&x7 zL?yBQG7fJ`{TKQR{&8i{V3yDnTo>ihJ~P0OmblvDtEPVg|I=hG0e(y6dDtZRlm0O= zX(68`>|YAyo*eR)7WI=Z!yFh#i!`SQxc z^18~nn6!xsye=G^;2&?z#acjAFSTDq|6#xCcjG7j54R(h2GA0C1N|=^Klx8C;!5H_ zfrF%LE*KsAkIfuwm*l_s%Ce&6D#qLv)-_%9q0oOTWt2~jWc=iRFX4a1G*8HXysfjl z(H?x3#4HeZvdDgTUBurmnIOkMq(Ay*!MsR*h%IaAp0Iz?pUD43`eW?xneH~w|H%IM z1_F^+EaYZBM@;0W6aUbP$U+nhIai)I$K&#lpP9!lNXpx_U6Wm%L>gGCV>b~49+M;f zJvh*Vkd5=@i5xFxgvz)GuB0ZBUtbZfSKOYG84J?#50d#v zO(Ch=t;sb}BB|_CT0Yn6kc1o^S^29{De0G)T_vNB4mB6Yl50%LOITV_c#z$64@F8M`n5spmR6%qW z{VOUKg)(}SN@TyIVG8>aw@2fN{uTTZH)zy@sDI+OIE^Y1zoYqLc7H??&r9RJfBgC2 zzy}A;%K>43xuu(LTFcxY$Mo+TBK<777Cq`qmLR?_)CnV72jza9IjHQ$$iZa4#aUnc z8#$D3Nc5A3Y0lV26iyDhUl*Os8cI?_QHt6sY%i*TasqK62P!&>#xHyEo%uEUTw!gQ zjz8*793vUXgJN|Ne~Q|RZtE%;+p>aU zy+X|IBBw7XS$ zn{<3qGH|9v~R z-=VIx8@kdM_8hxy^>}v>xTcQ%a;GCsSOZ6C)x4+w^Vf8xUBNf08(j&M?uyHOseCf9 zceb&1>SnmVzdUkzHnf+zmeix-8`g_Ey=C7|tB2F2#0jS?!26PcE-S~&wQ{4YIcy$e}9MX$B+2d9duSlxejp8fW0;gW)i1880#kP z!F!n-;EkNT7>`1i$7RXiw}}sUca(6iZlvK;?~scA{PlN2k9W zyNQk9FEF=W>m5;D0|woZuiW9_PRF~!*!v%Q-CLb_yukTl;LG9Kt=4^`$L?l*-~S=* zSpViZWr0IMIAtbbegHoz+nc`|?g=wP;N0m_1K&^K9p4oz;~aaQ>!R%fOs~f<%l?p7 zjRY?ZGE{FS!`YoKTmI$FITCuJ#c%6@r_w7(?p-vdYOujs_1SZhD2V{eTCh`pu1 z!2s;B3tiTKTBpyBGgf5hKf)Ok<57NcJmst|$F%V1%@1 z^ZC9vRKmw_=^n#ee7-+m$50IG`wsu^U1fv)7lEa*#okRf{h@lAyszF8PJ{YO=*lD5daom^3~LR3HQ6^S z#`{b5p7W2!FvYm!z&ZcDuP+%lN9IPL8bW7s}{SD42adqjrs$GZnOM~~RsgSDiW z*!&ZCY)@hxeuqQ%#^r9j*Du0I1J<9~&Ews@ zwGF%g=kSf?2y98J?DoPAyaY(KCXsc*Z;~N6(ngvp?gHF)3plxF0?zd$87VrhSJW!44IeEk-TI zDWhkbOmdz2yz;qzH2^dX8BAWqDbNaN&gQoF03A``&t) zVtJBYR?S~k=4&-C^W!h8n+Yyxu;AY&#v)2QefBH zgPZz@VpHqpa4E6yvHus3%-bmU8?U4Hde4NzbDZwR)dXktA>Qj?+EvToFaHB;y=C`WEv5yhotk&GG&Po(=EZk*y9-%NgnP{G?aRUg*=h zHO5jXk6%X)jeER*9lZyy2su2i+u3vNx8QGzQyb;k&wAeehW(mFbLW%yOYzGDKBDle z{U7+KcaS6D;zs@YWnz7d*Bh;Yn0@2dlH)p-_PIb#Gk91isV7*4>CX z6TbBo|J}C-qbl&0Bl$dH*T392 zeuT+PYYoRo=D82QgFVl)!sj`fZ=7 zyL6`KA!myx<}f%pifR8iA6Do)wX;5}I;SkZ#n*l$FF4S9cISO2bT%J!FoxKdr)-~Q zs=3CuZ-M4S3E+)=!I^c|qz+}9rAf7;V+Bk&U3XffXRoOhp5cY#;luMG$vJHaa# z&(9yR)3=z@p3mP_fz2LcnTwe6j_Y0L^SLv^+U`$A*Fh9ljQ3ohIGZf7SIk39abcs6 zy_@zHbLTVK=r_3ILU$j19`u`!b!ZOVp((K9ZQ$NVIy`>teeK^mZ!oK~qo)eY^43fd zwNGrjPrlSx#_{w@N7b9Eb*;b)per%4C7p%-kHa~mhLt7hrqK72m2jj zuZkA-?0gEFyythMUL(Z2Ov|a*;}LrcK2>IL5#2v%hL;R`WbQcLBT{d_*|U+}$=bO? z}SX@UQgc_)TBLo{>*ssgG(=-$y6o zM)EgzQGd$0{UY&R1HIe#e4dntC;hvzB>QM4@7q1UUxTT!wrA36)0v0F>vrHC*`0^I ze_V;R+auOA`r}WKe@Yv#m8#A**oVOYZS#QUt|4F0+;xceZ|i@EF>ePu$Zkb@Y)ZYs z@k-V!SIA|c7ubIl$Ua*g`C>^=DfZ`=qF(dsOa(c^b+hcBK1m z!q3F-NIq}>Ql5(BbRxuzllN#0wL9s0toE*uyk_ z*Kq=#KBb!~?z#5L?c=nnZpKyhG*i^Ix4k%d<--ztktma4Woit{;NiYK)vQE!qWw9E z|7#EX)!shUPR<6XE__ESCH5HYdrC4fSD`jhXV;0gqmNdXMzVfyT zcw`&c2TZc??46#iLs#%j*AC(950^Uwd)f+rKxLtO;vHJjrxW&24k7=8Dd_E!(T83p z@7PlPcwQ%s%yMQkw06SU!^(5_8h+eX;nOi5(WdCjOiF!&*iTp*+&*`|CrzwQ2w8 zOa?mks1|ZOklv6yBY7{6xpwG%XR70fzmxa1CQJ-l4>bxD0{z=GBw3NTM4$E_=R+6#!n^u(58_WbSD#L@0>3SxBPQwmu)Ur1vbaa@Zet8bA%Bd8>ZEZQYr%Ky z7Z!fgM?Ka9_$}xit)E9eH;Fxw9Yy-@lW%)XT^D+pv;Ww>8NGW>u|MJko?}Zx%nURI z{BEZEYeNq+-`2;w_UOMW3%-eQ&hTDEP7jF|k5!DB=1+;vXY~mG*4`LdeY6j>@L^`} zdml#M{6S8!51R#czs^=;4UIA&KK}vUm99b{zS_ z;aKosc%-qJ6vy&Dp(=dBL%Ti3c1t`I?=QhG2D26e-<(XBWXO`_z1GEF(Z5K?L%e6R z5xJMFP#w2ICq&=6hFeH8H#e}!)r{w|jfAIOd``$Ct=SHc&6uX_f4d-empG0~q$Z;Y?FxBdkG z^LmAEFCXP`?8m?3@elnu`5kv>f6nW*@UyXg8``*RC{753Is8xdj{O9s#?ADY<`vM@ z*zZL?PQ-z6%`K06jWX^_WjO2OyMM}Ew^kg(+8Ys9qB`mMZ_ZB1){WQ}P``}zGWcef za^X{2Pc-ISeJCS*7jl8L-AH$WDd~)3FS9#kJ z^`A~PQke_=XTeDBHZ3i>r^R5j41*lzMX$ZiN z2%7^4LV#Pd1Q>2~<*B*j)GX!X@^i+7n9bFFea9$a8fQn?*f;U~#g1{d%-ZTf#fau`(` zclduH_Ouz%dJGS z1H)l!q#apg@FoOLl0d{ne+52?z)ul)DiMB(D3JS|-6RSX-xRg-&8-oyME&0ul?A`E z`}0^SIq29jNsb2u>B)eL+ux+Wlp{z7?C%WIbxY>7~OB9@8^(W&4|HlEzs zqPxBp`i8$8dw0Tb3>J77m~?FG$mjrEW;6jPQMdpqrIRy2mfBxb??i_!%^db>(t;j}N^QFjZ{tCKyc8a+eka>}eixOI+tl zfxsnSN^l)J?GpB?)0nStgD{INzGc4jiL4gw9zwP!U?CW?r?V@^bl5<^p97Ftp`ig5 zlg4`IRhK8M+rh?kb_q=Pb}hIC#`A#>F0iK3Ae`_!d{2pYdfjre#wYu7O&JDv?zYK* zZ4=B}@W0hEeHSidGZMa_B5a%N8<@+X5?tI(@jv84d!U`@`Hr~;NbRXz@2X<2>77ga zj#e>e;vQ@I#$hhE>jvcNZY#@CHClfh_W_W6<#fjJ;sSO!@*|h^V?k&_&+=HD~+RkPx1D)!|r*9>5yUD2Ms^>OR7g|QBtt7X{J<%6Mj zEW%Bn7gJFs@!A5QR*;kwMp>H%q^xv{-A*1R(9vih#*2lnj}& zfa#a5CP0TI0=(Btbq#-gV39!}QWO4LK?CLu{!s#08z-F8jlvAe_7wDAt}%AmyT-^A z*3l#&fU{u>NLhBtf?aXcF$VA$28?TcU4r2(?-4ZhYrtN2_;|u(!y+={QV)Y(GD&aZ ze^i6|?gUn-@)ZJF1K#Q~^dpd~7BH11rzdrE-VQArAZkeT4 zI0t(K@Gv^-7yvWwGl1(|h1Q&_82+-U0?gJdg`0x{)NLV=7MzVP0i#d$cesU18)zlz z!YNcjYY%AquuJ|+M*S`J2jN0@ai2r@zZ!3VU?DkR_14!b<}?+iNbYPXNF^#xJ| z{CT~8g`OS!ov)f~4-7qk2Sdq^{T4gwe4Vyn&-2H{9O zZ{1YqA(sF>Es-54mgBRYUNUfQ?Pqm)eF4qZ$ilh96a~h#I)Qi09-~LjQeEwS z6V|Hq>QFW{oWQ6g!nyBWfu;n|6wAgKSlwc(bAq`oN$+~ldU#$-j2357tK=VGFD_Ch zL+Z+(z>|)B6fr{X1KwT*TCCwN`37?rkyV=T=B0EyJWZKJ-@3E>Ng)NBm)?; z&Cv&a-9+@|`gKd!z+SfhGN^L~G&3Amux1)J#-??k4+aDdCxpz}QcDZiLWqn+YiZCS z7^b(EEhS{jHN>mf%{5sLlU0V%Qf4n(+EwTvg!h40FIkH3Xq3k2H`djL8FrPj+Yvn6 z(XK)X*w}z;oiN#1LURP(vjOAeva|rLhf7^J48`x1LcT)(H3@b&Mje(NG~mM^2_0Oq zE5JX;-SPq5mdEwk5lAqweK-*VIP|*Wa!gTWpZgKEJP?JLvEytIvRVgrf(oz5?J=vq zh91j|?7<;S_tK$*vC69r(?eNqE--dW1yR{FeTJP9QBtWhV>p5ruy_FFYyu&k87kSa z;V%$djx4;Am?D=tdJ|qsW!lJBO7K*MWjK^uIlK*tUHSb`?}b<34+lR_uwf`4=T-@_ z=r^P$VZ6iQr4_u?5D)OAGFyiiXzU1my2c>kg?(m|uDAKMrQ?wz>N>*fcYrr<`F6nS zvVy3HCR=cRKx9W5j8|Dg=qgavedme~m}IpT+>Z__unE1XL#6G`bT6~7@R9TIO2-&B zBVCa?(64oK?9)(Nx+)O zM*GC)p>@k(rY&@se#1ke%C9hAwCvrJqSw)Hc!Bv2N?Nl%R+AI7$HMEzGc!t8GzX`Y!1N|5pPabqzeOVSrYQhH6TK8ilg%Ah%bE%I8z&7|;*#}a)1Ww41AG5lU=@+}cvG_(Y5Vw!$e1D92 zybV8v`3hpe zKHEoSxgRUc9&L8MrCL2Esq@~PD+7uXgU;aC;tt%aWg)p@1~&%U+#0GRY{~Hz#CQh3 z+_o*K8uBFstcAZWA&7CrDK=;$uG)@s7YSR=3GaMhSEF;ERLzKv7_ebg*C zq7A!5WD23&Av5IzWOB`}Ir5rj5B;kl2Kn&(NBGu$+1ktK=Rq0i=9mRcwh=qPf6OKj z(+T0(^7*D6s>n{^90~Ijxz|R9=W^vE_+;JK$Qy=m##}v+wfeTVLdq)a8u!2#C`XF> z9gSC+Y_jqtX1+3`>lUutP8b;Cv zhjm}tltNgwq~FSrzrJg>!Oe9iM8DwG;Ww_O75HGmI1l3?JQb3R{qP2ow5&lLkq0(? zS)v%qqOWvlLkF9?4YiF z#;nV}aoB=_4F8PvQFeO=sVv~eWySWE5SGSVlXBHA4N5s^p;RklbYnJH!hD+=vQMQw9K=88llq=0Q*>w%YqGs6*TK z$6!{yLH@cGE?elQJyshn1TWOE2R{Q$7vVZp*aoxavQ;LVuolQuyGYvBSZBOc5cB)+ zvE4Fy?H_pw{4fse5+1g7N#&N_64O6lM?_XDxD&r&(~uDa$+}j!rFlZv@^_SLn6@)k zp*u57NA;0zYnW+~?Sr2U+1{J5vxa2fkX}{QWe6QYhTel0S(ba^o@KoYT0>ekJMcSL zO<^PWoNKniJx8Ta<>jMJ^N_N1G%oR1M(4K5)gPg}(R5^~$P-wvkC14!)jIbBpZ8d) z1Dt1X-y}YPoyhqt%pjKvCFLv0o`7BGyfij2MW^Ltz$QbRtDdrH`MPSYVd9uiz(>G# z6i>T6BJrY7iunwMcE%PBrt3SJGq9ob6*9iD(37!q108{nz2tM&QhFte+n$Exu-Y2I z$}N8Q1G=(oZIGeBJ8f2&)#=vUusp~~T7J^6IDELwe9d69q0;gZu`suRC83PVYRZfZ zM;Trsf72;PTg_m;oGYC&yXhmzVvaSdv->9ZF#jT@V5&Y5(>+YClnJWidYi)d0x^Qi z-lc3IBXE@70`sLkyO9mHJff{3%3#3*W$rMI=AH%<8^NBOQW#oXqnL)AvkhN@nFHGC zsaTSMGo$)fkjDs4n)TYeF4+h;ica$))W5M1P&D7~DRUb>y91eFLN$huQA!h^FLmUZ zy0T|(YmP`y7P6?@4F!=RBm)usM$XmquVk!V3~gSAhlitrWkD)+HiJwTD>!Ga?5@F# zWnJQwb1}?OmDQa3C&9CqZfwMS%m&Z~+SYeELYcEyXzon&F6r_!4QgQ3VWx0b;3MWx z7A8Xck~QI5*d8+hi`fRy>JeRQ`WSW&r#P`#!&zv$jd=>9FuR2poKvjq!-p}uXC5_s z6SB6fV8x(4U;1O%rOE6r*rkqSjC0Je@PuH(oblRQ?k65XfB)d)UI+Ich`%%mAJ7sQ ztw^yl(Kap~i)iS_P{j}?`{r00?I#c#R#7*^XV{nlr$Vb+;=U0QU0O$Uai1tNL_kHC zF)d_ck$Hh+ox;r!uq4t+~3i&Y zv>a@(#s^>T_>>PrwGEh_HO9moQw{JOTF0CjG7M{qD4-)(=|j%E2VbFyr!U|QTITYm ztJ*X_!F=QkngE|k<;fnEF7zGo5q?}z!W&xE^ud1{zP3g3#TYrvM^TCQ7$|hg9#f`U z)34-Dh|){yuZ5&uu}Xb%amw2BZ!SG%88>V|y4R+8&-mCJ!{-hl6Fd>PgETgnf0nVP zWA!CCEUcUGf$yWNuOpUK)(1(|APtr7hij1iE#)$h8{T-~4su|fRtWmC1{u_l`}vBG z2L(3yhBil5$n>qUL|S2uLkSF&9r!74Ydr}dyyHgvhh`Leq~G{$re z{kOrGgE7EU3F3#Ww82w!9(-vJj`kdxFUH-uo&7MNOUy!%lgp#~$DhDuov%>n#-rcZe2RdOZ9VU0Uj#-PT zlT4vRa_f@PZ`azG zH1}5Q3PObedaTy+S&MSrw|t|DWm5RRyVoE&*fXX(SCGFQ1^tsZ=i zYGK~#<3GMJ!4O;Pqn}I6A2AfTP3WA}rY8*=hY4cp4f*zwLz5nLCYFR55f;uC#DA}6 zFc&^WH{?H{OUMD0<-y{WJ?1y1lKdGCc-e(dxapQc%$qo30A^m~OIwgn8+>t?XC095 zX$Sas+Z$j_REGoD&wS`tI7W`@J2JA%?w7$4!Kh_peSq;Vh&5XktD3zQ))zb6!4kP; z?3tlq5bHTZzRdTCLf%SVBIft5{(yr6V`1;%J$$r=S+d_os(oCr<`fUzR;3U$$+zWM zG!B{LK?Zb%+f8J9mfmmhBVGf*q`|T^0TipgZy`?|OZO15%=AHh0OuJ_ zZ$N^d>8M+*F!!dYEb*Oe^%gYp0*kWC&XS)RavC&0Xc0Rv`8LJ|YNzQ5w!hu+=i zCSkB;sZ_(NRi$@TsyDR)11Y${j+CV!HsxU6C0C|h=F%|fAiP(AaGqCc9UgKqVb zeXN^f<5uGaJZ^|!$G0yAa7fu!%(*7Vri;Pwfzr~&&_!~==qvT8yCc?~C%-l>$s{5_x>zxmnfJP{+{^c#D1aC$@8Dd_qDPL0jWdS$YyQT$)81iIxz(<6cb}6&7uz0w zvA?U0{7$BO=IBSw?|%n;X?J%id2K#y`nu5NQoXBL_(#^2d19^SqN}`OkH2JDnrpyw z_3%rIQWrO@@De)*!lx5r0_a1b8{G#xjZE1De@MfK;}P0=u|Sm2*&qM&PpZnagInTv zN%E9)3 z7r7_-A)k~e``z5cwyU38!(227Tnzr-9<$5w>;HR(GSE|x??$GBrUSR&xi&Exg6)z> z3~O%>6hJI{42z(AGPo1=N7uw;1NZTQb#u&4ts?BBB8&}!>?Ib2bEmt`t^dT(UKtnq zr*95+S*VgLp?@l{$jl~N6IKocA{!JFHBKiQ-5Pu)FT!u$;KOwJ>Jt;+k1L)T?B(1H z@(VcWWq{c1eRX5C$M_-n-NbFMu?roU@u}4NQpbYuVi0uAwtTMHGSHGkykz3I*ox#S zs>;rrwZkH@r$YA4$?W(JCI$UYYU6riJyyJ)c_-1eVc)6#HSTM~7;5$#(N;y+mLftt zQEH~|j$)0$n#fBJ7{7df%k8E<8E7wV#U5U|U`yz;qhCR<5WNReYO=pDgX}X~%{KUp zxvnDOe2?*M8Q&A5_4dfC3HxA6L)=TkOa4PI#T7YEm<}c-Gb67h@v_lKioPsd5`i=3 zA8(?>d;Q$#iJ0_cs|)T+=D};BM)+0L!A;)zR`rVRWDi{?*iIF7 zZ8!?mA%KRpH`T zjN?&vGy^pefBG)+i+RIW+Gs-kX&rCTn*I-C`vzItjffJr?=4uN#OC@B0$-DVQW0NCrQy40KiM>PKacGI5eB^C} zKc?~SxPNs+WURf<`Aar1?wYRmxuNa}DdPS-YeqhL!y45*E2dd|M0B!h<|^E~?;-J; zt?&_TZP~jBQ}D_&VErO)*-gRRoZv4O5hpiz!iNBJ*JT!{(LnTsVPA)FBRqoQhHu7{ zOeL>jZnz^RKC+KNXTK|*qyhf(5=UR@7v_Y098~R?P$4<|V?J(f-?Z34Re!*1=P+fC ziD9rybNpQ`d%+l9Cu)R(p-l++^t4@QQHK<*WjFHoJz&LU@{ofebA-e+H>xu9OYLu{k= z^(kOm~XqF<-1ghz}XR* z*axhS-7tUpf~hQzzQNyuhcG-G`#zqPvTJmBA1up?+@HX9;(*+wH(oXVPSTOVl;9Lx zmi#w?8mbD2gPYmJxg4eks6P|(}*>6PE6z_6?*uRvgcTj0GkmPfHr&n&rI=ay)=g) zQufqa)h*S+t>Sv@e=%Y&0P(skyg%Z#x9+lamwj|ZW=55Ji;w8A^oCRri!cU>V?9BL z;(_cVA9Ba@*=&wI)8uR>_V$fG+}7AolD~EZ=&i!4o*{c1^?|F6UUwSg!@cn2qps{8 znFr#cP%y*Mtvhj$w{bPWP(d7l7y#G}gGCTwK<9p(`IA-?U@ z&Cao#kUae_*PM)h)kY?s_*w_83Y-FJg4_tW1wR0ngqTpFGhok1jKTx5M<=gbmq+6L z~vs8J#fjFb}OvZdrqxx5;NtG5VT1LF)-iSxslI#w# zGIO;c;^jNJ3}o}lHVVE?cXW0kdO@93ZI0k2HxNkk*(b!7+`-(dve?z-A9Doi zW2?y7g^mNt=rrgmFw>-hZKQ*0S`Wf+7lF6f*W`^;d&pe6njjt7(bC5;?izXJm2nS z?^N(sc2^%v!Anx^d02w&t?EN#_{>0scYlvQ2a_gn#Vsp$?Gm}V`Q)v;s(1z8K9pvC z`a=2edmS6`7xBjv78~)QO1BD6$IGzddYM~=N3gDad52RoC5wt67Wi70JpCdk=W!03 zWo&N_Mvyq+&I+*}1KpeboD8Yd9w^ zW9yAKtTFpzp!-|$Qjxq4X7cM3DVI@&)S}UMcVe;Vy$u44jQlqSv4_}jFZ-*aT;{O9 z!A@r|>e&j%_i={K%L;Oxa0u7~)~@e*!9iGe3K~JJRu_264O#?&~&37lez$S+ZIp^c7u^}%HCx#J%K!TFYG8OFE|i1+0lPDNhX zLv|YU_F==mi@27aGmmROM#qsG@ZdRVopU9L)6^|Ut-nH=(2dV)xdKb4L-7gLfXBk4 z_H8m((VjK-l`~uDdr*-fR>60WddybPW#E4=6={mt72K`|euGBfnFa51pSbo8KtLYo z;ogY@@}-M%W+DSFkEK5%FO?*f*`a!XzoH~BmD~ie0&jEU9XfHqzC?MtV1BBb04Xn0 z_N%)?p7>^XD|h_jS%?!adssQ2I2l*9h|TC8+(G@oPZ>fO;P11(7efs;YKFKwI^a`) znUfwvX-{wXq}jzK4~DI~pkA?w_Yvzmk^CfHRJV3^;$QhBR&8xS8}N#1fePIeU)VUo z5E=6xKghqtH1y;5gN@G`;m43;QX>E{CM7|B1 z!`Abgw^2tPcyu=@G6gRIdlAg=m6=zjK~>?Mhz4&Df8>qnki6(Igv{ZADVvIUJ3`9z zn~Bdfm^Jvj#;v={wc?cmy(}2hj(OX0{?+MHZ7*a9LM!IMG$S!*?EuHG_pl>dO{p?loXD!O%m5Mm>3v!)O z90Qr})QWwasExs1c^do-8!~ipGH)tRa8gD~e8f`|N0yb?;fOs( zj=O5mn58OB4Eb7-)YgLa!(+X($7LNV)T$mZXA-Bt@5UfF_@+}Z0F6u?9mTzFL)$kw>sE-i{VN9j`V^O3Y~@-l#V;O+5yLmooLe&)O?9%Ivl>EuRY%L|hL230f!9L4>M zzoN&e(vLWDn+!~lD9VtFV^3I2u^S_XCLeq36Y;+WU~%mF%Wg`JXb00@fcdA--;7+9 zH*WjFy7m{me&N1QPM-OjLG%v*AwIq%F36TMAm%xSssS5O6g5OKy03%&9x#70;h4tA zL2-$`fw{==JH!Zfv#4AupG_2s)S$bVHrNtwihop{%qOU0M0)HIdjWB+;%Uqy4ua-~ z-(TQ+2YzMp=wk8--#0#tcH)khYa6l8VJx&zEAK)5XTFmcDmy0d%+zjiW#{M}D{F=A zN?}7}Hv8MaRxg5|EH(hoS9>wBLp_4M&UCDYc06p-Ju*qgF%Vw4&2L4__ys_Jvl0s+-sTg!U( z%%d;YTbW-y@Zs#F)#a|v>~~wI$p6VKc(yA|dDO1aoj}4tsSru^STi3EpEngoTpHN=LX?r@phK#+f3=b0{V+UZAEdzI(J9@?$m;EGji z_EO37uc~xcgQ1A7sv>y?@O^8*Ziq|<;FF`)J>Ng&7Nj6cEt;EMy#_2}e8Vl4$F}44 zFPjbZs7HDLZ>bk&z~FdG*bTQjUsdak=Lu8>)^)-qh+TY~VPpKV5_cYg3=*oTfluHs zYVxb%52ypMxL`+a-Usayt3#IR*d@<^+ikg__9pPPXY5NmJXOiA;B_jA=ew~4P%i#P zdXHL2?tgo?0_Em`4K~OVewn^XQfp=Q@;5OclRc==gUuG8uv-mo@B_^?8lZM~;v3x7 zO%AstPEt^nPCR&sH8Zi} zIOeV!z7E|}p6f>4mYsBKd;JGTtzu(WajXBQ)6j9zgK0JG()ovJiI0AbF}vZ7kP{*W z?AQnQp5jxPoOlHAa6cYt(k8t%NSJ*REsxHAFP&RWPQ_+k+X@mo#?m%1Uyu{XgW!Zq= zzT)cOfg!u}NfCR_Wm3FB{NEVeN(BQt#`W2;DkWMLu$Z)`cHemqJp@tNXK* zkfZpQ^@11uufd9a0!L`|Mu8Xa*d4p1v!>3B0YjMPdJYUDr_gN~?!= zm&3GHpzZibu&{2Id&rA8sE)UX7$gif(;+^H67SI5#TPUcSeKKh1Bba2U3yTF|$)TGsr$= zw=5dui0uu$`9se5b0|jjNT^P^l;jA>v%ux;Fj-L(wFALwK6&bV%^t9Ouwb3M^7n%j zBLK8L_S6eLk;Z?Z_6rvnX1@WmrXqnXkXzvK)xl6=AP+0vTpH%{S#jM~?#%4lKPp@i z&e;Uti#ip1vjO|lxEUsw`6iSsvH`Lu(cw+@n>$4U}xQpnG!1IltI8YV*fAh)Au2aJlf)rmkywpez1?- z#z`ns)RB8tZN5`}5hmu=4?c%?d>FxTiVeTlXUE;C<9g+QL@Z&3UiirEmU$O@1 z0f{+-$v@bR-io32xZpN#Az&&I`hq zqe=E1##7X-huGEFBQ?cXVAS3$JTW12ABnaDDrZg+>;Cv4)KCrI{n#HrqWG%GA(t_n=}pJJ#H%=eV8?;? zfdU6vBz}IpIZ7-;HRaAHu{(Bo9%{)2nPQK*i+A*~q5!K?qF?hVD_-Z=U!zBH->5Z9n1jJohj+H}XPY86$rWy>6F{r_vOeUV$gvHVwJJ{KRsFAG|1Z|vCpUcw zDhI2{ond#zJXbeNV?0aGQY&XIrXUnLFfwe=L#JK1RFl@_(5&&NLVhD=Xj;NLOcmfZ5bI>A}a8MncI$I?x} zVN2lj%J$s?cc1x1RZWSh*)xP4yB)aiRaG;X6TsFJ3b*lzq6fArlf0}d!ao#H2!02R z1A&H=nTtlLZswyhPV~vTGrx{U4&s>ar>r-gC_W-hT)ZePb2gY$f=i>kh+uF^J$H{- zlE@gwhR#9U<0hCn*fdz1#>C```I5N&mhW(a$hcX5H_llVY)2ybqV!n{%g)QInk+!4 zt&Ruuc&`PgM$9`nA(+DTMw+pcpn?0k6KqTNJb^YEPq`tSjXAjJEoADnA9M~txKXg)iJs!QGTTSvls=$``rv;erJ^T;xRsrZJ|?* zdQG{PeXL@f96p!#Fs^7%)y{w9b@Bt~$?_P+*^=FYEesqg$XSOG`R&|%k|7VeQvPQT zbkhI|Ag176WjewaT9BeC(Sw}XZp00=InQZ>SABua#Wm7{C9qZdk9dt<4o7*0o!JE- zA2!LK^HX_;PsR*SCq?bGvo>}Vme9Db!H@jV9DbfW)82q~sYwrvj4B2c2L1vyD#OrY z)5dU@_{UzHUa*;b&kNJ4iDx#Fx}l_kmyyE7XC9>fe6Uny`5R1(_#-^9-h*ZE9w<1P zS!IASR4OMubTsb47k)5xFbB1B(~}DBlFC#*&Z_Ja_?)6+0^~(6AwG+K#9rh+*$4jXXs6DD7X>I&xB(lwGF#lGQO9YO{%$){YHF@=lIt;h{W}L4WrO=$&za;~kQFBEcw; zFx8k>$(`f`?P|2DCJ2PD_IeWT+n3rTC_3&(j0actP=4Elr?|0?sLiULBE`58cLNvr z#$L2b83!r@zYkE_3J9t51$%KnKXJEYf}K?z*YqzrI+!m^i;)@z;~*?4D<0lYRERQ7 z8D}rXC6O`q)~NLmE#nGT*nF0xymge?-@&~vz@K3e*}o&W#vk0P5HO{fA5i|h7Cc{* zLS#Ne=*rK^zhgvov_(JB<>R)VgD(drVh+eTu?BtGb&)SCFkn%2fdx+AiiSwg4vx>H%A2M?7Aiw{e5i)woNEo} z$%QimPU#*AerWNK1$-uPNFlgX4H~;dBK5NfWNNsVm*9!g$*rv7hum}YX$qTRfshX{ ztcTohyZ|193qcp$yORa4T+hSrna)GD4Qh9^vv&v{_L-w_7hv)M1;t!V7h)`DGwPQ_ zN?F*4d~STc%}Mr{5tW8yZV1hI{v&{bR1Hvr5#hn&iZy1BT*#X|!6fe9eEMou$d9UO> zl<0%^2*Zju?@)L6ZVVdcyOSP1y_xX$3Ah%4vE&r+f(r#Rd_E~QH6if? zLO@pL;|HImmxCaNIxC9tV^)RYd)$ln5wTf<8hUjZsPsK>%H8-tJWB01dxF2?a{K^) z?`E~rG~|oazY#UyxNG}y((jIb(3qX6UuxJ9?7!eqQ4=@{U)`c5$`#H=Pd+27BQ=@e zf^Wg0md^yTZrGrS`GMEn?ne&2E1ZftGAs%ikGTgOAD`D3_=?ohe?R~66L~O8a*go~ zpK;SDo>;#t|H+HSACO!bM#@KRQYWMQncB;qjj0ErhNHwQRnw`f#BYJ2{CsZ=7l5>t zbT?AOKWqHtk7)+VJKTVw@V7&y!Sq*6_#5A*jNsT8NS_a$t#s#kFUl+UDPueE6CHzY zBo`OKIWC`hV)8gFG2`Z^BC9y1Ip=S<#04VTkM>`au{p#*pf>*Nj zzZzZV5tIpeoqCk#cTkZ0R-+80BH~p0)dxz4Z|X5;E0}=m)mg^S#}B-Wl;!GDd3CKL*NB zzreVBK8r33e0+75QkYI~j(s94^AYPg_mMo3!dDe8BN+34SS~Td6Voj*V!&<0Yjnim zC-wJW&s3^D=b3%Q{7Xk0$|87FY=`^&`N21N_0qTvpEDnGjO~~J^1rzsFdfWk4Iu;1VU@qKV*=;1*4Oe5 zZ`xjZbaRR0H8rqal6G}?py!c{y6)>u)`J(p-}LanmuFh>^s$H~ z|68Px@_Sxsn$-=2$`t2?`qzwi{+;a8E#wmyKj!U)AHtTUE779ZT^tPcfxwl$t&041 zfsM^>#iQ!$4+~T2tR84W2L+<}>=#C&ned#wCfr*iO@)Hpp3JXeCq_ri8NS^eeXk#G zb8EswVDRlLq}wINta($=ST#rAF@MN2LI>AX;2VCYml?=D6M#ZzP;VAB$xLPH zY0uNS3*P9*;3QL;uI(2Tz)*`sayH!QfOa!6nEA@T!V{{ONcr_d+y>eN3oX$+miL$7 zg@(I-QQ^m!so3-#^&o$zF8(#U?UBoEN=D(k*`_uv2l}6(ICL%bFvAmOoUEQl;}h)k z!Te2J{Om~nx}9UNjam4q>0ulEUWweo*lFV$@qbaeCNF`m(f}J1aI`xbMY#4`i%+ah z(Q&BpWri~vD0WdAcs+O&>)C|w>!#bXB+d^ ztabF=hXM%?^E=$}EPdj~cwLR^xwD~KihO+4g=e-%QKv1nm0nB8sW^NcHCyh9t$Ex% zdHuiB)zB8s2XD~Lp*=ZOi`%6YxSyJ-#;HT|p{ZD(epn#{j_q{^+iHBEmhjNUcd8eg zj1BKtaH#4V&-`KZWZ%a7zap%n#{;J?-0EBSYlI#xb?4`n@F@U{z*xY;wHejf3%rzr zTTSCF8Wg5SYakmuJi9&>W1R@SH&2;i*53j$@-wxab8 z@KAV$H%8@rsHuB_pNp+Uw#sDErW=7FXY1(qiD6E)lGu3nZa;_>7Zon3;s4fd1Elt- zJ8S#d3{sDrz=b?A3&GAB%tH^ZcBSWsU%#Q>g4iE@1Q%DJa?$La#4qxS;l_-aV0N7o zqwkZ~keuvQFbQfqUFk6NaHzSm9P!5)0N=kd+c zU_v-!&_-aC(QfoUKsj+*0=ZrCCU(-&kdD!+)T~N>DcwllI(^2C(lq{o*%fGd(Nw#~ z0APjrI_e?)?Lp=@=JM70qoxrXjspHe6FhLxzLKDU`!1Z7efEk;qF?Nw=B^BIEXo%tDBKsQrkM>pI0K#qmfmjL0 z-dsrsX`}+b0i1&WaIVPlhS3r16Z%8~ex!!5RzSiwcC`}!+=UN$F+o$-@dy2<(2kPR znUbo(VC>Z6_*=n6o4&?SC;sTa?l=?~(*GmO(F4h_*LB-Pf+i4VU6#(=!$fyM-k+tR184pWB3i1nPa{eyYTndah<+DUjar}UC6=AyNH}K>`|F4wFVQvJn;b4G z{10GKVhk%8L@-IQx2!xd7M==WaZnSWQAP^koo-@%fRNb!g?7Nt%PS@rBcXs%>;nfM z4A1^`JsiI@=Gyt{{wdMGsxnDPSrie9h&^HOmGKE~nq^qu_63ZVrFToP8hn&LIt3-{ zg*0p|JTdUzh5IeF_0)VwR#`MLt@1W?XApTR!{hSM-zWPIA0zQ`*7G`TW0uL-nH2w+ ziN!Y;%$NvWwaL4JYvF%;1(<3L@g6lJxolUUiP3`nhcC99NrDLbl+xg89MUF){8J7- zeS{Csyt>5RdOV|zLiE>%u%JId+yFPIn6-ju5}kl24JI_vxhM0;XX#gJS*#F3L*g6l z!?`Kl!LgRliU6>O$lJMuIff47Gz4V1d!ya4JEe7}c->QEV*o}2n5v0dM4Zz*acJa! zP6kvEM$;)&eky_K-6VTLy4r#1QG5BB4nB!Ge&dNHPMU&!C^K-Rb}^V*+1K~oqzh`L zj`^K@nW#YHPBZoc+@lv9+Ps?AAYV6CoM(_A=qH!npc3fU0V~RTmz*iF>W4hSR3P(R zQ2oSGa!pN3@*ouOY^&Q`l7yF za+JoQNGa`SAd2oAiV z&y@Hypey_8i|h#M$Em3ms0N*l+&lJjLpw6+otZ}nb`3P1nA)6+a4@)~A0V3OfcCUm zj+v91K2dTFDO!II)8ia{b4E1TtEpV42D`w=$C*NFhG)3<%JypdELPq~>VWDI_22#j zzvBos&R~zd%k*iFK{_lNlueeY0Y0pN_za<21jZoKdNF9Igml~e=zuY%JyD~gKvT{enfB6jA8&*|&UXj*5*_!;XGy&@7F z)?I(Bk(JWxpm_Dqii@o@+gp9oZC2d0xH!%f{w#TSy(@$oKnYzqfX?K)p zRW0S&+#_os!Uk~7_%h*frs(-8txf@4Wnac%hbp3OTcZ3?%J_%+b40ziq!6u?7U)m< ztf()_&e8wV?UM_NPdh$0zWx$9hVO*O8tI=akDaXe^S{?n+Gt+^0limwo{w(^UV5(t zG5SLYK1=F}=+p)_qgxo4f`+kw$Uvc4%M;H4bY)kADK_v-+5_w^DQs`>Mrd)~U11RF zbULM1k<_<_uqDBBki^Elbe8FHkM8+eRC5^{qxxYK5P|VQF$R=W^3Zw>_dqX)^w;L> zJKkR4l?LQ2k2N{!u0E8w#FFQ1KT>H$=>!>`fa^ERLmZ_aTmnv@H)=4-4h>fusrzvO z2cP>PmwoO@4^l%5lMWt2T2F9m$><9&j<^=+fQDg6Fs&XaO!MrX8Qf*8jfJ;wBv;Z>| zi6GF8nJ&HL0NovUgPxp-9uLr_!JO$JLyaP6Phvp^57dmMiS$h+mZMiJ*PI%rQ60_5 zr8XFO0{11q<=|s5R+o#4u+sIu8$L?$hCjpin+h%$ zoZ`(zCYe{nPxe{=#NK!C6p48*pLzyXsf#n0##YoWj9bAZkJeeoJShE!=t1q21ioYp z-ULvN1YdnMPPxHVgt%%K;VW|n-Jbeipc!gT4QtD@%&5Dxe;=Z|rsQCZ^x*`EMRgR_ zVuugOGBOs(J8Gu+nAgWoq?UhS&^9(<0zDtXJ(i?>v zPG6Ztd>wcsdNw_RU8vr8+5zvP=w$ly%g+cc)o1+p_B7GIV^#KMgp;n3j7t5{l966& z(c`iqe->cX%t`)s^-LCjQyVc{nn(rurcDuV(#_vxI8lS=F@9~*- zRPy;*{ht5t^NYuKGp};y#yyUIVN1){yy*S6eZ_~;^CG(t4A6hc4Yy>8 zU^?Bf{i2nc-`(l>tf?ZSc2#n?hls9iVe?MdPTcvR%r(c%%3SYl{2nI>?XcbYjf(Mq zq$k3{jaJe~bs#>gOD7xI{4TNK6|c@W>^^$RxxsO8|CQEJ^mrM#FkI4z&|^)AY@OTt zkQxAR1yGZqzq!M|_SuQ#zC3u&NIdzxPva)}H362hXi(1GlNT5Ax>Mx1o&|UbXh}5&+DW81u3#o&UeE069ny6wL zviaFf;q$p5C24`kX^hvw|2bM|XiR?uW)V(PLKXmq?wEZd{B+_S=|hHM{-Wdxd9`6l zY|68ddJ1N+@?S_f%)+XCf*}#-vudv|VYP|}^+HW3oX}iAq375MFs0A3abq^F8+4RCV9FP}tIfsl} z|D!Que@Bqrf)vEM5lI8<5c{gtorgdMh!ina!mrce@l4gNOb|qv#xTFqivZ>%ctrWj zEWwh`@enSw2Gnpto57&^YF7FmgL7eyhVyBQOVTfv!oY7d%a!)W_t}&fk{TVUpCTH9 zjMG^YFg)wnu<=ex9pkJ6O>UP5@FlRt=-U&awnI@Q3kfbS4bdwj)osMT@?GvdG59MY zZw{uA*Vj;MWbb-lDcYQ#^1&SS>3gGBELHU$d^FqL1x}8IkJClAJH!(~fW0xxg#JLd z-s863j&61WPk-VJ(5uoDifR?vd-N6Y;2U}XcJy?VJ_1cODokxPl}kgMXeJ61D*Q4Uo; zriGfUw!|^|I>->T2Z6~B9rPUGEdebtxgguiWetqQVi*>9dgvGBIVVyb-Pr4t#C1g<*knMwx_7@9v3o3z8?8Q&`5CbaeTT+v=5f}G zs3M^e548y6$qYkArWUT&RWM`nK(XN9qpJBR4Nlli;bH{b&`?6f%{ChCDQWVp`EXGVQAJ<_RQW*7W|$wRItI|@AC zl>7P!_VCcZNef}*N%Pp-5mD%xoXquj0g{71O}2<_%8A=^SJtAr>NEP0xwSiX?%l7q zci>@sQe)s}(Z_KhXMeQiv-Kx)CeBW0lQ8R=S*9nLHas-z=VVi0F5*tMPfe_+(~>ff zZjV}+C<^m8)?a?U!t)rzXHI&)nFa7b87q+jhFLcWJfsqPQ;>|JH_7JCeqcAnZ2~@e z6$TXasTIaLkzTrdZaGeh)aCZ@CNg3U2)b3|aV@GCUdBm15|}I4UB>NWPPQ)Yc}-vT z0qNihx3KB?am_yD?+SY_Kbz~8pRd$dz|7On{<=XdjCBIjO8TX6W@^J}tK5=jgNS$q zrRQSO!gOKJRp&$AJXUUDrK5%uo5ouHN2+R~oSgLp>2KWW7)vVhFVnAd7G zs0xpHZJSEwg<1wKNFR-m*HROhC9z4m1W!Ph`+$Qh+LYzUwjdD>$cTIw$f~!(qpSJ7 zHquW5z2@DouJoD#D=2yWPm?n6n&wnMXi~{Fiu83%AF>X8M=MSf`5aPbNH(IeJDCT< z1Ufu#f(so8k&z8oNA;Cds&=0N2hW-E8Zw7^7n_2_7xYY$E+7ra+b{RBkbiHXr8bIL zT?xN~4h^yN>j4WYqh{Sj50;j0*{Zg1=W2jheh@35GoZ%ern--KY@DGsm344aLEE$ zYJWfD75VFUsEIjD;?R?fKAF^DL2t1xdJ}g>`Y4Ww|1z>Y^iOEugb)t<*RyFB00m3_ z@+wz-W7`k=i)N82LF)&3K`9`SK}qTYeW3ErGKt1SUNNV|%z}Me?6OP=r}^Y+U@i<0 zhRz(F`46#xH$oV)fU|;qlEO4QL*4{urkgeT8Q$xF-um#?WPEgR=gYDhWWLNn`Yd6d zGwHVuUm!2|OenQmD0KVvBmj_&mY|q^JMujE1CW5kCMFEulrQ>K)ObVRF#*ODy&>#= zB0dX&1V}MG02)d@Xgi7S-!_no_2j|t=SeIEfV!{S1sxs;&6^Tv>|kE6h{yKG775u=}nGH~~r zRo=5N(PSSD>SABQfn#bC;&32%c!(Z3TKX)+yT)^ZE;%sqJw}q~yndn7B{bBA-e>GxE$Hg`e1M z_;r%<3m1H6I+eP^&f&wD!9H4;4*H; zFjoVq4*t9`Na@G3D_=hk!eIpvjV=u`48UBuS+*S zj~;lJc!|uu^4i4zY1!u+j^!Jq8M?Y#iuHsG706bjR+caw)oAr~I=W%xwsPMs zauqn6NS%mTwU@nd51x}cMN$*S%MfH?e30XPK6gfU6=KPn?q=XYT`%jD=o0C%CA}I! zgrd_d1N@rfzJRQZvz(j;v3xcntibzwtG-(77%hE{%s;6O(b9JgqJa~DIdmQRTHr)R z9ID0T(71QLR$vL~cA@FYB2clmV8!rsiL*FkAk*Y=-5p~Tss#0W-HHUxt`NV0hBEv| zcW-wm{&p48Hx!hlKlAqD9j`&U-~u1);R$UKp4tQDR04BRNu-HVdvq{4XC)r~$kmw3tN=u(j8uIGXG zys%Fgrk}uP8W8Vtu9^Fc`Em;_yjc9g9D@6f=3Lu%a6LZD`^;}mQXrRS;jp(R77nOe zZFvUXOBM@~17E_%@DksHl_#8SuIg3!zHftnBv9$icC79VeKcf!{Y$uz9x1noi0auR zg|OoL#@PuuT{P)&GQ6{<)V0aDrA{2-qqkYua^M$9^Go`!;{QAJX~XJb8>EjgdF~l? zTZ=&@od4=LOgoGpKff+}JOs05iIeTmyL#cVp#BtF5@ZAVAk)5bq1L17_^#gBmeHzm|KU0VYF=U@J>7<;fz`x!aFeO>_N|M<<&^o*)JBn%|D+%S^i{W~G!Nve zxLI0tK#3ONqfa|ZL4B$51e=9WE%BP2+?XaY!pN_iQjfOt(fhVYi9x}g*pWdtXJ)zY zGUHz2&)5Jg!|u@wZx_XDX2zXpRvKm{`7u0mKRyDaA*-$Z6-f;rHMJ4Ig&>ZTTR=8u z!>2eq9W3#j_Gy6<0z504&s_T@XD96oJp*Q1T|t7yf3Dzl z?j$DX13cm@2}TZ|@f!S;^t%RU+Tn|kxh96fD)cNlDSS_)R`7-Jxt8b`mKz*_Jal-G zzL6B=<;?%{>>L`VEXZ9^nB`NYf-IW`L9@i~JMj}Z9|g~!e*Mfgm@=(0yxWDAwbu*H z;aFDWLlx}H>zO_h>=s%)BsiKJ9!no$Y9iQguoZ|~v~}uCWCA{Fi*jf@Y@X17=c9d{ zKYbKT(z~UC1*nhsGW!>xTeW!^FMq)D;2%zT_7B>LP;JKcV^z6D+w;KeMMa>A-*Gy| zA0>LWNbLBki?*4!9qVzZi85bM^k67)YA3-S8ij5b1}Yt0^RYdLvFMOb_yTHs!L<73 z=WEDMEAvVElGjB};&BDmJ(4qcgjW!sTJ$RchL4N%S)U<0si(d_60o?J89bh3(rOh> z-#8UbY9=l0Kc3uf+^_}`7ac(rx&E*S=mpK^z6A!rpP+k*K5G1n5u}yAw1eU`f}U#f zpdW@Kb4%H~gcA`nlNLM1wNVddDtR;S%(-?Db%{GB_kOzPUQPl`$+C-d zqf4T;(}BnFNr8m9UEq;)k7{XivB)*OTK3Z1ODo9T1t|l2W|Ted)}`% z_LFsJ*Z9aRlGpgz5!Z(?H;~45Q?Y07x$RIx+drqAhp6Q_xl<-OxMuEd!d6q;l<|{e zLN~Cr=lWCT4Br-QET1p833KP`52ELf$Zs=0t(QAE2Dao#y;=uv zFz-cNAN|28+vLh4=FLEB;x*|aEGv<`j9;0F8^v9O9&mW&)d;tKZeYjmTr@1Odx-4PY4 zgT`Z4Mh(?HG`wk(Tte>C?1qGDs()P73J*7-U}I0C(}aFTXaX>1+ow73#?9O8?@-1> zoTZ6})bnsCPwKOr;7U)dW^(h;UjfhaK|P1M(v6AB-ANR{cTYQAOyvu`UR8A8MVxD>Nxb|+KU*+C<^EyhoA5=ybJ?)MzDiHzpa(W4 zx@>0Pd=)ReZelCl6FvCy0^ z2Em#tPr=Xe9ETUMhIM5R8@q|Wq~jm@si^VH@w)p9@T{yy{)EM(|Ca?{D@jyFCwoIf z)02bybnTD&2aX3^K+~`jaavA&8E4_h!g%;?;mX{m^yAc&t(U9>Gy3&BFu{~DK@uix zrw-5ir&JHK(6@q`FfsF-_s8j^gp0A{o(J6=gB*7@;u)YopYK=9rKwW^xIDL}$wPa{ zu>(W=h>;w`=iTT!aj#GG;#}Yt!(-EO1`TN&B%V3m05eXA`|_cjyGOqfYY0kBPn1qr z_Y2Gp(1^^*WenyDE-3xh0(w*vPfJ}4H9Q~yp;@wzp0z7mLPBKqmeXT`o9OBbTo*aR ztk5AZWj=IDGm5l)whwgaTNTLJ1saAMHANE+-6n6+!?z@F0^is)l=0<~Vcv(F0UTZi>|o-y1TQYYAU3YiyFFI-F>g7p1)py`0roM9A6tiTuOU&j2N$d4}sQwRZU z+-gmYAyb)6`8j5$aCcjX9}!D0NH7}m8Pz|= z!g(|p2(fK|IbGJE5J6Ar7H&~H<1`npmKkv#u(&w(jcb!LlM@8+O~D+NbY=t3JzeY$ zuZiAl#P^`obXsRjOX3E+VWNLs(H_H|1c06pJY97y<1PI3r5VXsRiDG(Do(O3zSiIj zpiYod?WY4DkR6^n#O)+RkPuP~;4rr76O^P%qyRgvEohdu;-8Az%wTI-?^}H8p>%K_ohxKCL7Azss)_ zU*1A|@Uccd!}vnIKlv-$Z-yK!VvHHPug7?l!ls0RYi_M&?&Kr?t4DVCBOYbygw4Wd zOh~H7)Gywvyy5kY^lBmZO4f)Gb$CkS%In~5aEV8`Gxx}q4Dr?`T+bP{gCVcA;^c!y z^Eozh9gnuvrT^?Zejq7z|0^6IctCo`R~@qlJ4eP!J}U=51|&`($ocpK%DhNUlUm>{ zIg;kgvra1Uy+?F$AH-7RlJWbt8KCq?aPQfjpWYa<#BM|&MD3rcS2qwS?`6pQEkGmC zsM}vnxAGM|+%Mis& zoB;>J23#Kc4FdEE5eL#woud1|ImRHn3(?!0eK#Yg<-_fYrZo#r-A{Ue!$(gO0@x21 z?A_%WnM>cfsLuUK(8rPO5}ZlEoXdHMIs{j9vE_r?s9Q(Z8PgbhV{|!lVGqZy8%NFd zu_OkjOH4$BPQ@mG z%F@F#u}k*6#J>>n=yO^%AUmKc(uwZMtl}4U&n{;wIdU!-+uc%s!kBRgqL9ilq^iFq zt_PjJa&jAKQ}Z+3B~;v3=mAAc6mv#ZTK_R-2j4`K?!zg9YeXGJotAI#7aY!HO60Y$ z+lN5NHiz-gdACb)dq3T(&sfosI&PRt@mnazxCFF}{vMvS4$ph!tLLw{-9FHrF3?HeRTGm=StH8 zBW#KfIq67nUL?|7=d2ljV^+n-V-k*bFLr2#?14)#HhA`Ha6f3XqR!_GoQ^lT?`q7w zoR5w8jPBQ|`w#3pOo1;js2XukN>==95C`OrI7=pm%lYh` zTtsGP%&frXGbDNxUSLsS#H6osIU+tc!6pBs7m|mJ9dWUcQU8SQ|I~C?#qLeTCzM0T%Ik^wgVU?viDBCDyx!-sPKvc+k6=A^TgJZ}Gm*NVY z;g#o;YT(m%jS!ErlI-11-X};MqBxmTBj|{L(#fL{@GZ`+Q0r9Eo`jSsO!N1-3A2N|rAb^gnKxoi!@-8xhc-Dx(3OKc-FW+ouC+mVAWT71?1HI+O~^xG#0``>Fkk&Gty z$ox4iptfE@8(KgRaVY6Mzk-{|^Fgr?fmPltYnH(awcK2>-13?ly-xgwlso$<2?3;7 zdh;MmXVAko89VmY|Hhtapfvoi{4=KB5e!w1%^-Z1tQ>82O81M_5U7Ub2QVBUjq>SL z&ZP3$Lh@LXFMg2;fBs`P%Eo1aVovFdKEiyusNj@$f$#yZ{9gDYVd%;*s z@AkQNKfdR$KgaUaJFeW1c6mn4s~0NmI)6Opt~|kXt@O|Nb@aD7k9g)@JHl(^TvizD zmkA9#KjG9M#k&_}Kkkd!ZMRYGdGZz>R})BWRqIWi`au<&`@);<*R%C@aSnH@S&l5R zDfCe^&ZKPeA5I!`I@JsEV8`7mn*4+5>)#V1(lz9Qk7ahC6|&WgF>Y@w-5SkDKafwi zNfCML*N<)Kx%Xn${#KoQ(u!BDRno8NEU<0%$#(ua>MT5W%bdSOAUKn^<8RR|=)1|Y zPV>pGNdxB=1;nS^e;L7=$M0O2Ery-tpQZgLe#H<=PruHq!~0;X+=b3Baxy_n{*f+2 zm5H3nnhuGU~n>b&ncqeoX!v+LGHRBpxR>W&&kP2sSh9||rB)pC7r3I)o z65R87&K{HNeT&0QPy!qW)uA$C?qxgb7oOor??4)hQbR}B2ud##i8&|(VAqVj%s`l2 zQ()Pc7Cu9@C2kOZ(2Iu8Li1L*_jCo}BogYCxj{VV2R zj8V`mU#I8*g%7EkH}c-1v@GY5zvAP2FzYBc%l}vu6D>B=h5=EF$3OgRdMN~6iwqP0 zz#?emO<>1s!hNSchunC{$u5)?@HDX8;gfI<^gH(fm#6*P1!%+UiQT0qL&#c;gpOl5 zfL*ZK7W$W2!cF8fFIrnax5A{4$TJp?+EVE;r5;oKBo+waJyTfp)gA_aE%G|r_2#5M zh(mMn@73*oYv{WsT|%F@H_u+dU8j|Rr0L_7Iin*AhD#zTu6<+&f&#F&R7-Gy;K9t@ z^jk%nExcORTS8`_A^%I(4}cClj6wloV6n%4h!R52-ZW0*-bzC$O5-fy1rxzZ9R}u9 zY9o1MemZW&Va*?@szsp$dbOq3cPX8E(j#_((Nc$cisbRtym$o?_~e1CAC# znsyWNmr@H%E3Wo)Be4=w*Ow?GW>suKxbJYj0SRI`bo4-Omix;&I5SR)81l3%H+{of zFWi@EEcAe-B7=MbT~bXK|F}-xPF*K?7c(pTr2x{=|7Ua+Jvg=qY$F)z42h(^Ry+)z$+plCAQY{Jy3v_GlY9hudXc{n zOjOQVF!lo+(G2%Ps-fKz`~BqgHak|xu5MryYz#UPBbQWjup->SFC&IpAEdaB^#19fvZBm=Bxk7ajsES(0Ou{)E+lt3qw3l?o^bpZpIG zkU_qD_N_}Qp2r1uu+%lGbF09)i-^Y&m_34qzfe()hTM7RNBD`6>}u&X%NfCB#i>=> z<2zHbb*{dn9S=a$YL4!K0Ym)=(;}KcvMq+`=gx*6rtRoU*11HWT(V+8?*rXXtnk-15|d^C=zT4If>|lr;>cB#6TJxkr|Rq zoW$WADDlnbGiJ~n=~+k}sW?)qMOW~hfB6|ypdDx^xX9>$;)|u6MxZ4&eOvms{0I3v z0_96RH4JgOZA#4?17^)mNtYeO;!r-9AIKRvhZkBDPik_CtQEr<@R&1SieJl(%m4|I zKA9e{Ehw9THDDw$QNuy{3On8r!v=O8v}BcXvJ&*^p#?C}Khk@nrw@-s?w`Yp8Z=TB zOAo~zeV>^!z%5<&-=TO&L}`iKA|!GNLfCOFW>nsY%to-?_GFq*ON1u2?apm#=)rrVT2Q z4h}m#1o&Hlx8HadhAZ*ZAO~aYYz2C z2jiR5Dpel^x%5m~>>z}$_KsAuahaliA<~S8L0*?!df<0LXt z^5)ro1=p#>jOaR zVzG1&V| z^hh?M=T0T9$f4&!9&(Ks3>9}Ej={+}4bb8GI(aR+P`Du*};QCqG* zWsmTzSp39yJGU#pb4JXxcv{iSk=^nd2%Vy&zc;lRdMZpzK&)y);8>gzxF_#|VL|;1 zdWJa26#eyvYAx84BY8ejRYc{#pXZa_ni~v5dO$=k=!+-bobp`Ke0=9`$c5QG3vXo9 zP&E*bF;0Kz1?gGePRNguyXqtFiw0f7@nSNfg1QZ5(7RPMl8MwLHfNhpyfE8=$R|5Y zb8<^03U)>q{=i2jZoq=E(fEgoEpeVckrpTv8qcxWq(l7D z@ak0R=jt7F=4c(*&H0?rT{Hn4IE~;8Qu9I;+(irn+3L;wy2C09Gg^j=6d)ExRM;@# zpy96(dY9`-Qz7A=5St4&9gc2#vISm%9KP@|`8*R;d8A=EGIoXiBe^MT zqQkOJuS+a<&L2JKiBCIl0R9n(h*<~9wqL<`Qu~ye=;UrdeB=Sc4#Ua^Hnu4y-$0dx zPkisq-VM<@D#Z@KcxCHVqISI&9WF?rnoUH%;lqo~ei6J#&}JB z!$l(+H}=VbL4#`-gU})m2NU(JFn&H+bC5sO-K7D%j(Al!sXe2;bQZL4l7*Ap>IF=* z#OIc_>^{h`08xn#f!!*ykV%TtJC*PiJ;R7ok3U_mW!><&(9QJyqX~@~x)JFzFuhw0 z6DIQYohT-*b|LBnE*ZH#7B#=Kf>v$(B7J4KpORRUN|TtfWd*;&4MQwNbl8J{-m9Em z@jbv#MfB|;0Y97^h~Q({QrpypBt}k*oXut~G0|z5E)+i}5&bART3aMI-<@*ikPpqs zDNpFn#&HOA9A?NHHDL+UYOB)Lf}dwZ+@AibeN(}61~Q-Xa2Zvyh1sjJ z@l3Gl>;UXPs?`j&9*$!4}G1m8fRrSV`q79||cSdypU2oCTtdJxzA8vbj53gW-8!IS;8~ z2Ni+s&9gc+7%5g(7v#s2FBt|o@>m=& zW!K2VSR5AbO9FfP%_{r&D`;>+7$&dGO%JG#zz+BvlA&sDCQ2dB(Qe?522R9A=jbGF&sP6chVN6BHSm=3_beqwXs`sTV%Wz+th4o zISUbaeP7g$(qohe^&|Wy`OAGVB^QV~&BaH*SQlPrtC&?o4hH>lyT}i41*O#Ub1r0E zt_wJ(4Uq|!N@HvrubOlHn#F|Q(-wb%R%*<_$Y$w8hm*+~>l1dAri~k1Qta6`D|+G> zh`T&{CB|?N&>MVm=O7`#FNdN!AuH$~CiN>O^jQS@pFQyK=$`5iFDxTdD|Mo5(F1rU zguF-1CVPo}i{9Y~;$iT*0q^0i1a>hmM4wyih1PsH=sI{wRWK&8AoWTv^?FM7jVc+q z>ajnwtb*KIFtNxL4z#h5E^-O|1?h22a4*<1Z-k~w&O&-)??HyF$u+X=RRPUP!}CSb z>u~IZKkTuTH8;qz3nljH291@x$p7@!m2$65o_)f z%`hFnfTZ>LY+#W)jLlCDGR%ldYOBcCB5EYS`80e-G_0b2#ueWl9W2C&O4q#*El5J8 zPbP!)in;_K_1})n3p~27J5?@E&>!KiypbJ1QD~s)k({l-8c{MpjgSN9E_re8rAgCs zuR?5=q9%^>O3VpqX1}>4$XK5>cm(E?^%KJ}CI5Lpks|aLvTqzon2N1g?wMG|UoH&H z5s~>L^N+nOy|fI>vMznWD_Sy(5kqQRJX9>{CC>!>n?7eBP*|bQ@kz0{=zojN%2GA3 zY0~dGw{7XKMTPElJGgB4Y&CqZju1X~SFMr%dr4(c)$t-Vd-%5$GYv z^#Dx|_bhYw{lZAghg^`Fe;y5APukgttJ9w@FGA!k>7d+}) z;I%ev2ck4N4+Lg#o48-Jq8jjM;|`h@_(9HuGo^|SpM5l0v~%-~z0CAX!#{c_@m1X@OG%J<4-X zW3Nl0=^;m*R(%TLaReqV#_62wfg)FL6Tonsh-tJMuCc>5IC35ta4LIPvxv|6MRM6v zUv8OO!;#+7W{Yto=cp{;HGv)vlZpS?b_-5S;zU(ReXfx?T#$ zbE&$H!5n?Fx1QI#V2waN);%FhQuMR_vLYr|)T!)A@Dd8I1pEvkEW#{?x1^_%%{DZ# zqbV5X?sJOTRIutZ0Qm717>G|CfhF-1^4HX@#uN?MHnk(HAkUl)EqWfIjU;~N0y0t) zYUwz~%KeS8@k_=+jVE8>%xB_@<8F-|_(?c#DQko5=Cj+lHRq`o8D&*Wnd%bfsbs!P za78I|7D&$@bwPq|(I`->a(a>%Y9ZiPK&G&HCGLIO^FX42=>6E^MVv>8ayjrj|HA?R zSQfbCi9QgmXLBRD$JmBuE~ZZ>EyS67>0SlPDB?wUCSB1b1+}PVG2%=}ePI}j9SppH zGDMcwT9GwBHs)5k2jF+7Tl}p96U^G?EFUo-dcT-8XL2bukHi)ti!^UQ!Z?RpD~JbE zUN_Ftj|QAvo??7y$(2((Ot~*{CE?7hLWek8OhI^VvUNZsk^pHqkKib+9KAq%u^Q_> z`4`6ioIB?1xsS=5EJR|-S#+fSFHS8*oOT|Zs1wQttiuPf@(KKUkWAzQ4X%mEE7nXd zJ!;-g{7x%Ihxe#G=DtWaWA+mP7g8gLrnxcgr-HU|%T4 z=XUtJ4*T?v$E*}XN+&2Y1&(Si6)5nbKI6UNBku`X@%jBfy*hGzo^52`z-PcQzg6Ko z@pP0vL4Jt=Vr|hw7~@)eJd}o8F40}{u>LL^W&m3gh2aT*pGhCcJM_YDgy@8?Na7h8 zoG#Q;CY>^**AKQ4@q`t4%tb-JR<5V;M*NLt49p1-cf1A3@&Z7-*X(Ae+xkF zvG*X3CmAp58p$7U=WoP`gB&nN+IkUN;kP9_ z0w*!fJ{TZ*7W;-u23v}(?As8R5EN;*FrR_NO&p#ntn8sH+;wUyW@n^25Hm&S6djBL z!n=GPmrtU+@@y%7J%eV6Qct=#2!0Hji>Zh<~%tVmUU@#cxX5k=@N)sV8F;Y z&^gRgS0y4#rNAAZDE44gK&@bV{(|rFnykUICwwISmM@$q_Ne+5 zJxt!D@v+5S_Q*=^Lwv!UQU}c>P^phTmVVY&aX_|~uqE$gPteyDd?Kr16n!&(?oq*a z)tDqGNVp|=bT%pWv`1|gJChJLHQwxa;x5cJHq%@ZNy3H5SY{*9PvDRFJbOeR9AnMj zc<|UJPX(kT-0tlJr;VM-oC!Y}l?ID;iA`X+Ex8Oq{9%(gk+`hHzvayL{;;V-#)F>* zpEs%qnh~A{OT`d7UZGUL?Q_o%$Tj-{vPSBI@tbAtrt%sxhu-95pRp)8s|1VFXh*P+ z{Y~;P#GGpJ(`B8&k$>=c$AIpSx6I``93RrZPQmSQvH0nP7i?6I?{}x_&6wGw+9;%3 z;Q%hdGg`=0iajDQq`z{e#DkzuazN=axJ~BEmN;r$c*uFEZ{giQ4}c;-9}qjQZ6IAF z=_5o2Yhq${KKt!P8ac1CDN%;WgK@AfRPwxAH&7iQ_Kc{E+$k*I_Lbe~o;PDTp4z>wBXcT#6Y(+h>ii?0q*v*6w%l-s;`@}Lu zG|48n>=Jdt-xc{uhODeN5-NzwJ~k9-3)Oe6TsD;8xw}N@Hpn>w4Sch@p-DjEGkmyy zFB7yd=LuOv_>O?y)esCm@mmG=UInh#JUyt{l=!uy%6(`(bhE8`;5OM5f5rqrVQ$W_ z33Gay8+p#T{g#jyT&WxUgwRd6=X70K39%`O|93Yp)17n&yy8Q&8}a?z1whI(IB3ef z^%c_@+9)2KIGKSD+JNVT`w9wS#^6dmN8%j~JqMW8{*;^xAAOc1(6?qu;(f~f!N~zq z0g;Jzl%hhTxd_HvDG{~%lUB7JM^ySK-V z(|LeTwx!QTOU=0CA=+TfJ*hXR>jP)Izw6wNGWNCT4Bs7KK_Sa@jbE4Znmyt>H*!_V ziNTsC1+wJ@2aniPs4vq9K@x6}A^5!})QAHT$>b6z5&R7%Pd;at$R5q`KhR~22QQ0k zI_rMNF4ECJE$nLV52!Dk1Oh|@cX}rk48heM#gtFi*Krl}@d|mXL*spXGK{G!C zrV9Pk4`C5yFXI|W&ynT^{{RagB#h@LETwO-nCW&`#0@#`5_9p6O#7Ww&65RE(?y!o z5OWRCyU4l(H=gqg2YwHqkOs}+pY#HB9^j6j`;KS&-W zfum&Ih#SBkH|fu)LQ%+rtex~D0>hHpmE0$tHXXz{^-!|szW6d>Ca$B=e{{-&eT87# zK4&h4--0z$4Qnj+0xQcy!A*hCkenrbs#bCx_I2aoLtgnSsA_6=6uq6D3Z2t<9qbZZ zAa6Y*k3Y5Yo)pJpFPOF~ND73M=?kaOggPITtoOYlp5jwL4&S}V+352Za3?YGkcgP> zLZK)8lon%47$b5`*ZW55`|-`GjV*c2?O=zJO+u=h*p>9T;+Fdz2zyX8pz2dR)6o>U zKrF(i@*OrMiC~M2`8)HS1AVEH3h+-jobPP_MT@6>x+){%YPEv);J^9U17Hzd70Ft# zCzg87I)Vemet6Rph01nNO56vV!m)OW{Odj>uc_ zSx0QX$)kmD7~2p)vFgEl7m}w?_-KHgw*pa%DnQ!=rDXg_vPB=Bp_ZgNPoZ5(x9EKY z!bOhPdmJ@VGCY^Kk~a-^+;}y^Vzu{g3`n@6m-mRJ#70!q$|0s8Ur)LzLKN>k6etT>l%Z&YDN?kQ210qdd<(hpOX`s^V~f`A&^dnQ}dj%n6 z-cLD4|3G0Zn8MdX4avVTq-B@;demLP0rQ5ekM^JP{*~F2F%nVZ{XSyC-=l3=ij!FY zO5p^GEf||5ZH>B7-Dom0uG(S=t%O{2)s7O4gdzE>O?!p5qJk-*jGhpt!4nv}D7NqT zoW#KvHGVm@Jgpq%`3+;!t`fLC$|2?lh*Qs;esE*gDCjrb72P|yuwP2=I}rYG)jzok zBO9_~qVy&klf0A8C%&;2E{xylZ zBailh9|c|vqAwH=rA)@jqqZ8G9PUg{T+WC9zjE{|153q^yYnD3o)NL}-MZBuNnt!8 zu$6vDd?YTu@R=Y9AK7?$IEcD@M=B7{xPTu(-elxh!`jCb;;A<_Op0X5C3SH$u*aS@ zd3EZMqd%DYj@XK9T&+GE;)4hz=LOp$>Wq*>5`i_m;!IyGZz0x(#EL!9g)D?~&nx%L zJ(!@8>f;>zTc;4Us!eQxs3hlZ(}SShK^hqIbtRod{;fcp~eG$Yplxrx0f;;+QtNiBpA$;A`Tr+QI(qY-&SZc1$lLs`i@ zv!x_=WL(?|j7yD&%ONB1DfT56F^s$Lr?SW48IUztkWaXTe!98SXD;D;csD7=D*;sM zmPf7l=ePjW;==8k@$iO%*?|$^yGS7ZDJd7?t4Awx{H4f3QhbCj7(exXtTFg%xCp^} z=m?EP&uVaB2X+S0?NK>rTVdQ=#Z0xVu?5>tx#lLRLkAFeC&>TNYez=PlAp*&^L%+Y zj%S~O{R%?iL&apoHiHv{7zvqS<{F%kAjW^;4YvdbiA_LKP&{A3)&k@Z*>&^ixx_{X zfq}1u-q;XjL%ve;CuKEW-K;tLin=g9Ck#R4HsR}3=s;=pB5~^z`7dm@ImQxaikLYz zAQ6Aoe1&S`u;7<5zMpCM=ty1^pLn9n3km=W3Z4}2PC!@q&9NmNZp){4h!jj&hziRS zf!t7H{4jjt7h%s+FawVp`pyQ;zkoA#h|5jyNLBWDI-8weU`Q~Qkla&658W;qfD#Kx zHU$r3l7QsaKM_4ZO+~oGMhGoVkttye>?f<*vRMRzet5_5Kg2ayIbx86qe#FiM`FIn zWf-84+ut=LHpT0}@olL=3>$G3Q!qX`+ zFa|35n~*}#{9r|=5YGS`x3riyFCxYZ)Pl>)bDd^_A#x&|igYEpiqX|zU(lZeA6hUl zeH<-v3HAsX{Hs0~=IyV8DZCHJ@6kCNJ}d)(-samtmXGWHpT=}wx72W%Y=52u`E(lB za8+e7c6qrW_h1q6Ir^#ve@KW3XIUD*fCi;sS^neR=<8Hochu1zSvFM9CiJV}FYiT2 zFhwopQDVR3m-%JZe5Mhu@2iOWs!ObCK|z_uM63?DxSTb~;^f zM${ssTUNULkz;~QAbE|eJWIlW-gSHL8qQ2Gmo8Y5GB(LovzK5}zr)q^t^ZI9^|L?5 zvla=1qm6296N+UAD+jBXG@sPb8{LW?&4?EO91Uzr!&=Upkx(7K2P=q8+j`!NOIPLt z`3oML-*44hIl3wqn6X(lM4GsD`br+_j!&IpkCs3Z&d6f=&HA^g+0_kCj+m6%|MvBb}(LR)&hW`DQjzOIBaBP1Gx>@p+l4A{MisJL!<=EUP zC`7?+x2a&uJ;A%8UzN-QEiPX4)uf`8V@>Q{oSo^? zt8;~Ky8%-#su$6sLvlcr$D_Y!3rVVho*op?@w_c((qf)UDj?FK{pP2aB0&a7;#XVObMK)a3we(%kV6fGusWp0?K}%LzuY4^i>@lr>Mr%pcxzE z7{VSxpQf)?atv#>Ath{FQ(Yx@v^HhcOX3?DtJ_n7)d1TQlwW!F6r}L%@+SriS8WVP zh`wZaEsN^c1{N&7dcp38_gU?cqT6vlp^h9SVBf>r2e|{cL$Tm$PIw!il+Qu^^X47> zpRPuJxO`Rp-JnB73gO?jVFw;fky+TusDLQ!shQUV%%I$}{+pJ_Q{ta)2>e%t)R|wiug=XyC}M`eB!?jMF&{gJ01@BOPIr{PSm7Big)^ zC!l5yTbbl<+)nFB7y1skxdzPrPU00}7ZMa}V&ok@r07zM{FRO%@}&hEDt6DBG9K|Y zk(}HHDIE#5t6sqIOkeXufAnD6aUelpqs%<|nx!F;Apzg94u2sKDch z3Fs@m6*E{r;VHqd(OupD%1dN56-LBBOne!lac7>KIuJOvohUeP-5dF^FV-a;Us&@+ z`9i8L8Lha!!~z;7oDOuX$5-_7)yCBxdn3S*-bC2V^tC*Qz0e?fl2@s-QDEpp6I(#t zghvD4=?;eg<`F2M+vxt!JZi+`MZGwEQguPL5vpryTrfhPOn5hr1oPJxRpoV)6~-Y#sz27*o?F!-eD)31KD)@n{6U{rs|JL=aNgGZw~nV zw#3e)kB0Kt@hr0UHj0H@4a^!O(bNYp=OcEvCXp-F+%!Yexa$XfLmhc`-I1daIp6w^JUIpSTM^-t%&I?PvmFd>>>DO z=3PzwjI0s65^x6b@|#3pRD^B-Yr%0Kcfl%B+hQv`A7lkV$-DG};&R|uR1dE9x&`J# zdSlNyFju@0d~*O9Q8`<;S4{p>e`{vsA7Bth8&ndfq}Q7jtq|>kuO+ry%Dx0lXylJE zNr5K=H}=)mEG`IF_JMaE-xJo0ZJ|@vM>s{;c^)}1_b2%ot4E`uH4W&V;QI35fZ&Hxj0Sa}-F0WOogvKmg30xW8? zuV#m#6%xw|hpcq>_Wd<%zS@bIAb&YJudy~ zNaE;(%5eoA=$#!5AQV1Ls`?omhb+@O3u>BwoC?6{n-6?Wv4N?WAFWA($y8!Rc}*%j zU$Dr5tmM7Zry&?hh69|ahYti0L#IM>rJu7XDT7Myn%OM9KS@ z@NIr`{;t&NF}b?^Ko=Vl%g;iGy?`!(H~FSMIWLJM#whf{{G@*7kM;Qq=n2jKWbQuy zRJ@X218S56lBW;1Fty6q;q$!VTF}>?x^ad9`2&{2@*R0>biPzj(612B9iNebirtMw z78Um^pF^P_769_mJ;Q+l9`WE=-A`&ab)#wM!!_C21j*&Id+FQrN0BneauB9)ws^)8 zp3d?-YgeuJ2%FgxKW)JsyhO*siAuNvHi-QvU4&U)fkYoZv7$^-ndaio@z`r9XwuuP zfinUD2rD~1+D6F7u-x3yfj8;+0{h}}f~`7Qs)Yzua#B|iu5xfAf5;wks&6$C%YPNDeQkFy$jNv7(KSI?X z6T>rTjm$n#_MCmN$>>Av;{ij^6bi~*fJ<|c>4P+q;sYQZzdK8|M3Y&cJ0(*iDB+r31 zpwJ!0KBHL3caTmi-9^C)PA(HJd3_=^=A|4J!b-O2T6pispL1V=K4zRgBuM4NYx+@< z-}?bYM4p)5uiR6YH}kVJt>DZuacc`4cW_AS3|tIJ6O(i1YkKk)L~96OWrFMl2ZWOc z_W~%MsGt6Tvwg+DvmCvJ15G%Hsb2tsqNYJcy(S8FNsqJS;5i%&tTXaG`b-W8gp84q zA3M8~jpz$?Cnrhfl`PVAZH#WlD?L#CCm)YnzSFw|9gz;C*a%-Um@ZBqQ+#Ou%X?+M z*%8W`ro_uLo+C8=!Y+sm55IvLQuF@vK3QT?Jp9>I>R-T~Ol+>FGSVsfE$a4(Cm)da zr(Rq$E016PFFkAzzT~Rds%Yb;|47ZNa6XZb^fB@`tiJ_qt6CZ0UgGA2Y>Ft5EQop8 zR4O0PI)@Q!*`Mw?_xm5IV~lF-py&OQ3uiWcMK|>5k^Yw3XlIDsQoPHB?P1#u!s#2! zh^+LirT9Y01lcVZj>(p8V8~&Q=#RvP@C3|Ese5`*;7>JY9QN;qSQ%W{Y%qc-O3=X~ zBVOs#Mf8q!lRWPpJwBJ*I?RCZLC`LnrZ8e%V*?1Oc{6-KkJY$dBmsr`EM#y)m- z4p+QyrH;?!)7WE1T5^C4MF3&bo2nk44M_9Tvskd)vs7`;UJdAp{lxP7K5IrR5xzkO z$8yf1gA_Tzc|f%z&rF@S!WSmyWfte5`YGjiE6y)nD`3`n>Nu}_tTg0RNrfbHK(h;; ziSK`Zw+-LtbxJLxdB^Mj zewTvh(v9J=SLsI9!;AahZ!O&=<@JA`DdajZYsq&70aLnnau6!r@%yU>cdoGEJYTOd zE0ka&`TX;DujW}WlhrW8<|Teq$e zJYuYgB21rUn!-iL`(12)PrZ>1GLa-t+b`qS7J9;(dVoH8E2xu8uup^u84X5E7jZW&!^cEIs{7!Tj1uo-e$=NsWU$l&iKfLM zXzJhHYc|rx>x%n?!DB~U>(p|#_@43hyl%3V*KU(fUQ5a@u_fxXJoX#bV14Q%ulQ`k z>);T;I`Z24gxKn6e|nB<`e0o3>Eqw$PQB;n_Rsr#{JYQfd7r8EyifmmpM4{r`2l5N`!)s&oo_(r6IR%D2=w|n9PoXS){+wCt9h2BqqrNu<39`dZ99hju)@vK8V6NUt==f=%~3U)7Bne5X=_K6`t0-%xWp1~&x5;Wa0dp3Q#cO3C=M*+ z%pO!RzhL{n%InMd8$#qFuY(OG;$)lWISE%2oR)2NpJU1;+r?QVT$Kjbf7e?O+Y;Cr z%F?4tJC$jyALn6B#7YylOUKk=_l^0+a`?<5<1Gi`t++eh!3dw z59}PbEV>){xIbjJ(w~20!WXQ%_(CbZbvpU?8u+bxXoESk-Sj!WDe~n{aKAhg_UmVs zSVRC3cwo~;^gm+aUP`m-a$e~r!k~>-_=WEwLQ!G>@B*Lki~N@MB3vT?OJr=U3w0OF zIUSTJARvFl1QSshSz)I%zq)YArt9|+UKg8){RxdJntoEvK@o`_W#3-Mlgi4om}h$N zKeh1?xu?Lfwg(Q&$3J5v*l6?|d2K=i0s= z;Z&OCbQK`tM?!(`BM-qorsJ?u?+MWcM|^QaHho*6Miz(Y6ByZvuH%x0;T<&-y)OCp zf}9HuI~tJQ8rBQg$9$jXvbzE1P+7?s&hrD5!d5dH%$Zwra4@Kge1;}?^(uS9q^@|nGzml-TWU}o5)h2(a|V$xnpRW)l^57uSYF1;qux>0 zm)ZmNNNdt-gL%&hGVWWa#@yiTkv&NuB>;SvBY0`7VKz&OHegcde&P#)rRR!f0~adx z8hKEdjot3mNZfYLIPHDZ<79c(na&*RJM-sVd>VLoZb79IYhj?XH184XCYXT>2f0dL z_Yz{HF7{<$5q|~~N)PA*$dHwB^fvyC&o5gFB7_a!uWX95Xnc4wj``;FM^0cx|APnq z$1%FYo;jjNOU`TbzreCXBLeon(L?ft2B8GNxlgzv5h)XDNM2v6AG?RL=m zbefai)?QT$BN9l~;K0uZrJA#XUbJXf8cXKlJ9V$F_vGKRdx5##^Ci2B)6}UsSL~%0 z-^*ChC4^j3To&Zd3{@BGidUcb%om4J-ixge1PApoR@3-*d?pI4{WA`t{BA7037=g& zKSwtu`%P*>sCekh4%dQykRuZ69- z5z1Cip9yEN6IeadgQ+#8KM6;r7c^#l6#{!=YQ%Wm0@)6CIHBwQFMUo@2b(~@+2npX z@nvyGGjSEx$#m#G?q-vo59iE6!aJF9U>n$fuR$7Ni3IM69zGzk`mA~(P*O2+VC zEWvBa!I(cWsE&Rem)pY@EV|zxZxC=JV>6BOBZNu#0%arh)^p-b_dG5&3O7;Iwfa5#ek&cmk| z-o#n-7npX51J1w58+r{^E>eR$(k(a>oh8PaU+EE37G@FMzj$sFrSs6_U`W7W4=Sgi zUaX`ECHbc;SdHLy&R@waL9uQUISis@e>709`u+#*A)Pzs^~jk{ZaggcuwGP5;0RzS zHC#exwRgw|_+9beksWyQ@HjXV;;^uPpWNRZqAa{}9X<$BAs>SkB(1A}+Bik#k=z^N z25ts(9q#se&>!u>>7xOsN^aX~z)SKqtfehr6k;M=1`91D`~v%mSn`}&dNyEt&}X7g zrLUaBT}>B3a4kENz)IKaMYFzu^?(giV?$gaaQB@NQy#^8k6vuVk#kaDB!b=c3*`aI z5w4|Y{loPr_WW$mflvX#bE>C^sD%`$7Ur^x?HfXkBYlL~n5;Wh(>nP?_Jm)ItKII~ zGde7sv1hz(Mf=6{m|Mhot2pQq&V;t-Tr}WbB){ts=EsU-t7!kr$4-Kp+A>bSP>3lu z_7*!2JB|4C`h*N-KhN3RF85l&qF^84AEQr*2{uTqyI`px?D#&oCKO~WC#^81E(_Z4&+ReV89 zCMG0fOwa+`5<`~#aTvi(41V&&Zbv&2=lVJwC2!#U4Ts2(XFmq{S!fG3E;^-gKaj>P zu48jF3r+-i7{8>E14eR1R|jxvSIjHGZ#=BTr5&R@>{i%ap~xfZ(#QoAM~q=(4DyRI zQl;1q%Q-J4f;n=ktA`1mZBf$ z7aNGY3BCbB+@vbNdb_!T=p#p|(g|jX`L3<$T>5E~YXJfPA)yrnUH6A{MROq2Z|lE< zv8XCiqaYudA3t6*XtbI4IiUABh>}!8y0d_v0s9S`q}k9LXSio~#w4x&bR2nm_6>Pp z;0Wj}N3<3+eQXnLoyo=^{@eyY1TKR5XaZz5*$O4hMJIN1VHHP|2oK6}RRys#cqG78 zQB!f;nDF`#522>R!56`f(Ntb}aAOt3QKctdaAe+?X|d5j6@Xo9fNoG;dt%?Y?(oj; zq`N${f8@G)EHjG1K+)+3#t(sN{nvL8&T$9^AzhQcq&_|=&B!Mv1&V{2;4||pI!AJ< zNrJZRo?vyqkzX%3Lp@Sge44aL9X*z(kOmZSy*693x3` z9yaFeeP9zrE-VCT6L-nIp)ts1$!C>XfTOA5wM4OdbV?&+!I`d>5Z{ZLNKUm{Aexf- zhWWln{cKMV4!_7N+CR3;5iSD8N8OB>&BYTXwmS0kV9A``J49gUiCz}X?Gvn=@r<>> z28Mxt8d}DiZx?t+DLEj&`ayvBTT|$>(4gf-FasLMxyU8rt*5!18|ny77RzllJ;nt} z0*)i+Jm=Y|{6tMsD}tZN^~aC~7VWkfo`#Nh%hKcpWba+#wv3s0JN_5F0k#J@=ju*= zH`MMqjgbc42@Sz#(wMVsjgZF}2xgYREsycQ`&mX04f6f70pY$_Z4;GRBUNu!giOiX zQsoJkSdJMDmsnhplR|7$N^fU^wKVFkOQr>dho=}l98RWQ<0(wq;13B`+i52GZE#M~ zM}$hAZkPDf;F;5$WPL~K7L-jvs_{eDHhQQL`Dfdx8R~_M@IhkK>qleuwdc_D1JVJ1 zEkyU+9n1-Nec8cn{zz>SNBqHB!1n^dDRK%8`Q&tIDF4SI*f;)I>g?c!8UN#j8IRj1 z_X{*~BJY~XNFc9PGlfS;e$GbHJAwKw^5pU}@*_mF)x>_qvs*>tQ>|w>OZU#5^^91) z(Ir=P@~^Of6EGw0#-V4w%_X;U*uk40lEW--bgG=P$)5dq=F*SHp~THCkMNi<>5>MR zykoN(KiDfkF8mDnjob5-;KTwg{|?{J2k&A3rEl;n)soGoAu_Ux{N^{TXmoNCtjlsh zw&R6t;O44rN$%6jIy^nsp1uZP3eM=qZq}}*$ef1$!r11P^yz%}2L9`hd0P%|BIScHZvUtHFn#~dkk?ier)zGz5#Dwp*fl0%F)S_6Oa^k4z;Qngwr{%{%T|tvV}s>!=lrX zIdA|mQsYVXJnamyAjm!B_K6;NV1)?dQm}R93_JEIXX#FJ_zGKg$Q$*GJPjIvuo3?4 zvmTH9F)8Xb=MGaoay#g!`bE$>N@EW)POnRHo^CoUH;kBxg`N@dYm1HZ#IJZ_OQ$S8 z#4UWNgRB#-i(41bWyG5!{^O#L-s`r;TQz;VUB%X!_o+2>=>f$#e+M76Pv~A}RX~ps zX8>sLm~Qw;c-lrW4@6?3jJ)Osvu(fh-FHKYqH%5{u@PQ&FaSS)JL!H;J1rm(ITS{_ zhR})M?0b*6HuS`=6BHkYgcVMUaJ{X_q*ojy|ae3gYbDch!b?Gxg(|x(&=3da_WZ-5*hn zCAqrQ4^N|vA&7p0FfBPwE{-p;KIk3;7q(e~-0wKL$OTjVaTarJ(zLk5ZGD80qma0j zp1XKfSB6Eyj#MR#md)!g9JkZXNB$Y@SVbWJ@dIA0;6!j z+~?@2oS1ts2N(%Zg2#gw1Nob)9$zF6X$zf+a%+!@cY0ITsA=lk==*aZ6VEIw=EXjc z=M`U5X>jNu&lqnE_d!+?YgM+K(=cP&f+@YGU3T2P4lyCg&Utfq-UP1C+MF?ZjmzR$ z)12*|VIRTn(wN>5B@6#3;~?z19%WRDw;vrn6*oZ-MuN)wstFjGD$cTB^hb}+ zf3b$dVIV}a%hy$WsF6h%U$`z19!B)TC?FALjj$0VTE*2jOsZO86Gg@Gled>xNz%0S zK=~Z`*?l4LO+;V!5xdYrTHcBXox)8~`i~={sVUwU2y$ahi4tyqAN!4dIlPj0A3 zutyv0_<2$-&5TU5)Qy?toAk7DXgY70Ri6raPv6k@tTpZ~`OcramSw5Y#3D05{Hi@D zBJ5&X`jvQ^7ERWOW7xjV^5ANNg{E0FMxgYdl-v*6^ zJx*V^^<;{P1G%L2VY-K|b>L@kh0^PdIu+{JhyziMg$8lb8@DZRY>lpXIjM;JLnJb)Q2w{%bUMu!0eSw{ZLqR7^41fCc zm1~lIw7ewJk2o8B<&v^b3dE0%9%e5Q*ZGm0)Ox`qIN1+d(s@emc~QKxW8e-js!ehV7h=zt_G9`7Xpk{_^Wi&z zxzt3Sdh_o46)+`Z5<9zGJuc)uuz5^sKfnL?^Tj*!CfM?7Nv?K5Efw-}KfU~nCAEHj zsmU)#gPf{~Cs&v4@1|@e!F5yJxS1G(ES##SD3+XuXpki*A`HaF?2!6^zVuny!=A9j zCz|eVACCjL#%H-bFX1c)GZTxwn{nU@D$kLO5zb< z?SNDPBBq{k4z`T=Hb1GuDKFdyM=_C-nL!=#t~-Fn;*-`Nf`^XCJP<2h^Vuq(rqiPr4E%uZH@wCkYZggO+2zSelR?mf@Ulibq;O^gmfXsGOm=aYQ03)3R85A3)=t$f=f zm#o)=Z@xNJ*_7C~^c=SZi7LI8qm-UzsbGg)s+@N)!Luu`G-fbXRGM&4WTx31FtoR$ zGCC8x_%mDd4n)4IP)*GjrWmL&pMNKzMBMhnFwSvB*trG$lr1MR{0A;k*k^cfMI#kQD_nF^rrrS1R* z4Sy*{E(oZcI8$(9qbjNAFR|+JYpMlf4jTK6%HlhcnTIaVW43d` zP-1hDyVdy%V}$dW-#w}>I9fHm!P3#E;%a&#ZfNF{87L6Yo-%;0|~Xw%oVe@=7QjcnYc9|1!YZ;U*w`D zoJg>Ffb^jF!I_{=RN(p4w>B19RkagQBJ9iv^afk^~g{M?P7yY~9Y?~y+yAC$b(hgc8!dx1Z{ z2Itp0JrI7>i{LwjF0o1Kf*9x#+y_2dBkZEN z{tz2=O(U_cO*21-eElb}K|E{mi&|tpy}v7qc!|-oFaYo~9yYmxQJqWQu$AnO6Wnd= zlm@;J*-7+FE^2siNnBTb(~QW)-8B7qkHmikkB0}CH#xF_kxgQJC109M(3_0DezY^? zuutHh0h8*-SE}6B%?|nNnFYPAPQoFvB=0fyS6pijB16Gn^JVnNDVUSCC1psAnVyLR zs}gS;b#VHm*3UyW!T-|z#_F)7me7nPbOQ#&%1~EiKB@L;_SsOGC{hHkPBO}D1ycd@ zKe9{344ma3r3d)C%3$dB-*AEDXAi4%s`ZfKSfXsiShOU0)c8aVY-H^TEA1GoMNgr~ z9NP)8NvlaUhs${s8JG~mcx|XZ&10z}h8rZYdEQA>2Ji^7tJDj%b`uFJO{3G)d4A?P z0}H_4Ev=x9Vj6XQ_xXN*^$e9H=#|8rIgYv8Jms)hgU~tzdmQV)>Y$yah8D zMPSY}m}&fJ$DY zu%tB}m(-~qz_I6`YE8Ja*!$8?j#!FH&S$!bf=TdxC?gTIY$&ct@bU%E!u~*u<*cGj2SH-&FACi}XEX0OpePt{O}eIkH`gb> z0BEVeRmLC3c{DCNdOl!y3lRBhhCbt87=wYu-x6>UF+SpBGTwbai~~20@h&d(>yjJG z7;ta#BG^DiO5p;RN3f|04`i%?aT{)jC&Io%?y~Xqij>fbGY8Hh^=07p6#^F+CPd@I zaRHzeusykSromy|NFL)QA;$TA0GD|l>}#!<`w@J;e7&5M(g)oe&)^_+@6{bJbIugy zPMmP^;NznL+(zp+tZS}^1dppZ_!DAJazZzy10LFLndiuTLo$I)zR~ZH0DOFS5)LW+ zDZszkRAeT?KN9XnIJbmrQwq|uOdurqXJ!>Y{cU-_hcstgG*#i*8hDs8-5$5DLv6&-IE&dAi04LtRB(E~IpCGwl?}S9`h-4Y@&pp8R9yQYYRsY2e zsjV=8>j&RC&TnD_uh?s!EQp!JyxzqI}$#$g@6 zh7YJ|Yn+hAREphg_wK zt-mNZ^+>G9HmC zz0J#;0m}-bGcWMgg*D~o-i^IZdY(;+ubaXk(?!=sebh8JuN!G~e)^H0Q za>fYF3O)(=lnHj8dRO$$ARd%9aDbza_RT9kcYnp|z&rAopHU)k$&Anvpof;k7=JrY z|Nko72mtj0vGbMz6bI;sh3c3 z!3GNlEPx~8W(nt;MYrMdwW=J48y(i6xc`@7?}?gt??{9=r1#5;_$0c*%fuewD|i$ek;yGu7(l$ zXcZe-TA=wSqq;APz*i|K`Gg%77{I(Pu_Vi<{A~P z0+E_qmFLtVDwqQ2?(*Mz3Qoi2FLfYrnAiheV~ln+hhz^?3io8gwHVtR?3=uoO;ovN zlpT6-{%Hy{1?N=IMN$zeB+JRjArg$}eBm~u#?O1iCdOUWUVL&z2eK^NC?}VX$|RDm z)DiAV>&W{D^?(aQVy)3~$@^jcs|F)5nGH9)`rQo0lnJJiqgO_E@|zrWcl(sR?fRCI zOk}22abOkrP5Kj%H?Qg1gpcpWg>-2%GP>fx$??bBuc@$_<8u$}=#x9h*ALV`a%3>w zLN{AVuj~w}D<^KK5odNRF~v|<^jJnW0DUMWCrJf-b56ldJ76=nu?2r3N`;bee9sLR zUB9!Y&hWxE0t9xaeDTj%kR@#03{n5)na(T2WrZOL5Wa%ncZqdap!IOB4fF3Paad!) zx6MQ8=pp_USqJne{)Y73Sg9Kumt0^1Hvu4i)MrjG&G04ol zvG>4H3la;P1C9Zs1}}Bvm`6V%19oy-+I*MAEd#@#ckfIV2f?S#hF=IX+=3t&3C~(y z<8xB&Fk{1{PyZ8Hz_St{1GTk*i=t6*V7u8G!v*YdfxlX|eZjl&I`=B?HF1Lb!J!Il zH7N%UzRL{GW8@9&)c1pHQoRD_iR6d?TanAQFhx~M6M&zEtdB)>#b~kpDo@}Ir_dq{ zWJ|ADavvqO9C42I8?(b(*atgu*VNu0W^hM{c%kw~E#CvoaX1X{H<*%NiR88L$zR~N z9ZN3;^5>L^qi2zLmKB0@sY2`PuV$xLbXURr1d~V}h0Ou?-wqsKI+aky=r^9BCB+9n z&`&{RwqVCRIDJ4>flF7{zdpV0xgV)mWWaOWjX+Ii!i^0~i%v*zC;z(#j)1o|Kha~f z$U?S3`2U4l&a|O&HS<=4!KAD_#=Nx6_oV4nFIU!BvVFEx-?exfJb4$ z=}E*VXgD~eHylA_a8#c6sAMn7nGd5%%u*UHok%)a;RqRuI_vs>Lrg{7VTX;%Xps(J zh(NQz-Ve@j+u7xj|}}&qPsuK+5NhUh(iNvqlfE~ zv$?p_!raU)dT8@{fDO*!$7bF1sXZVj#Eio)pG&2i>ePibLM>mi@a2mvj87NQ2yOsc$WS|r=oFMTeD3T-U2@bY5^S@M zr#J9u&_$Jq3b6`ua6$4TSwBy>0em+~QSUA5P0er!+d39mv(-I3YC6>2Y#s&d2QJMF zyhhCRtT-Ayu^DUdkm+C|bAc3f>=YM>&Sps(v3E9Ou(h^4C)=O+BPj{Kt2NkB^*Ltz z1SPPz38)rMBXjn&F`x5(`_%khxZO=T>!R%VUgtyVaGv@6F^N~@bbr9h;n9m7LP^Wk zjNq~lo_q5-M)0gSeZ(e*(s#Q{9)XjLH1&|lo^CIb2ga`XELRy*1)&FDuFKU4RG~tdkXnnfv?}m%G=%0-<>`& zI{6z{8aU+D1uwlZS8TnvkUAH_NIz6+;`h2{Kb+jEQEKI2o4c8oGgKb2n#5k9uJjRv z3o2x06-v7~}dY;#IHub7fMYF<`g(x?>(@i5|LK1VLm z+jt3QLV`>K*=sDFml5L>eY%`egM))N%+19)pCiMGG%57~H$sMlXVU#wA?H;q2cEd> zAOVaNMMS>ogVC`a#o+T_sd~k(bc`dC5_3TnXdqB+=x*>Zm{UZj?p2&}c1GxKx`*3+ z(?b$t;+frs*s6g|X!!kXi9_+zFIkv|9tF@zaNTe4M5KPx6G=_1>R)Phi;y^eC!hr=Kui_$XM#;Rm;Z@jK2v;#@ou&%%DEuFr+b3O*X{(dvDFe3BT7 zqbc#ek1o5z@=JWO?udC7c<0| zgg(&xHXb&GE0UojlzJJULM8;JVHLd6*LN|h%NM*}3Vbm+=b6A4zpXLPW>035*Du9pqVRADJI)I{1U(jnMs_*_B-{<+~gw(5$#;;&et2 z@S$YLq$khL%&+Fv_@cuUSDf6j(GI@&0R%*enq#@J2Hi7RLl6XUdeqFn92CFoACX7s z8azIdO4gltk%d9=}sJ_HFi_{4=zi#i&|2FVz0rI}M?A?4f)r zbR%P7-QIIC)a*NUyWjD(`fL;dQe2%zsFKjL>-Hf1C?j zUKhTSXgEo&1$=2&>VLc$*R#f5Z_I0*F#6Q|P10`@PIe_79DkUHeK=CLzuZa-B>agy z(<}QD&^8s~e0C_92)9g5R!XHUQ$_@9ylaU`m(s@7wWC=(0tZGASwfO)y)s@^hwFHd zcL@WK*@#^l>p4?_R3;N~#uo}l@psKS$!BaT{B|TZ&XDWX+;0u%=GkRs==(#20LCe> zbJIoc(nsPxqCW-rT_>j=q9|eBA+}^z_IrVHWB)BADJT7!yjQA7uZHwT?tv;sj*P&l zOiFEL(*NuZYIJy_`4jlO+JIYGRY`p^hOizJ0ltIxEfY9~8oRadWT_+NFo*q%8I58z ztlRDF%#gli!t{l-lxp{tLB2gD3m2H}c*yuzdERq*&oOBh0R{#34&)YLrCj~m)t!cG zmE?CQUicIq{DTK~sBaFn* zTW$VGm*Lr154i~{g9sdD>o{@nrpL#bL;46=5fRf3%`-%COTvXOUZ`ULE9^;!K22hL zK7_l=&SwrWS2*&<@*VYw()7`!1BoU>g=QcgqNXI|HN|zj(4&mkPjvj=3SVE%9tXCB z2|$1i4pCZHVXn%~sfufA_$%>IeQXez((Ive6PY=;ky8>o*9%s$y@w;w6VMWJ)Zm^Z zcW=@x4f$+-Ah+>e;Q0hosPXxpeS4mR#ao+0Iw1abu{SCDmYg5vh6py9H%kYfne>f#t!=4#@;KC{=dV{k8 zj9%ov7RINKF0^cXss!1tDqK@iYOPrb-(w0$6A&yN?vSCx9_RLn6?$P$B^Ty7UtsU8 zrFRf=+eGBN0scwxv0W#*QR9w|Mm9@Mae|$psJBDLy29U$U;h|SLv&AcCY*TYT41Ny z_0_tKTgjmk)2veq=uI%XMj$eCFcA5kj%ekrZcKC^Mrg>c2}DQL=mDfxGnH5a+_g5D zU#We1L3zo3NgePnz`?QCt~94Rk$d9##lEeVP(&G)GmPOdWev@8f>N?@E@EU1J>myE zC$)_xGK(A>yITw=P5cHP(LQ@LNN%DV&`VCpxXor}0c)kVKumufz7;@FCbTZr5qUxk zgK9%!sV6fJNQz(VyNEUmSVS>}jD;E3A5ZLkN9{Xy2&cVl%Toexq?xsJ#spggic41y8YLmZjrg6} z1KAns8bbG%eK(*{+ybUxJD3su{v_9$Q*&zyx0*qrtV-UY9u%{j$Db{rwUTw@ceh2? zP|S6u1eYf7nbdBYN+<83oERg%J1NZj7|8}H#!;@A{{DGL;P&5mjK7CKB3Peq zUh}bLxXVt!X;>AI7E5a&>gtYArNUXH7!co8#;fq`Sm3ruFF~iJg=AZ8h;&Mj5V=JZ)y#QePirjP+Imos(Rf9R5|p1Qs8 zn(>wocM!>gPqurk7|&YZqTzKT7^2Y2}ox!dZ)=kJ>T7H`pc?PQPD;(VCw z;qrO-{i*%=*f{prGMVf)J^K!~AX(Aiwzd3zkbNGPN4chbUClSYc|x0)Svt)h(Qdh@ z4*6TwT{70fxt%xm=l7!bb1063!`lAr7GILhC;jBS{_&huam*UfrWdP!RrB%%0BbX& z``EYX5|{|>W;yp?m#?44;%NU+k@d^xCHcF`BhPsB2EFld^0)q)Fjs>vT_dSZq_c@J zd?%bVX3l3-31wLKq?YYasom#Cr(OPT>cM@vm>i0)lQ@*KiG2HW zmFz3}H7Ld#c6X)fOcwIF*Ifj{A_FQ&AN3#i9s$7}&I+I-*8A1zJPHxd0Z0J;QDBPX zrGA6`z{Zr{DIb5Hc_)x~ofNb4n%YZFMFz`$k#c3Wr0_1%A{5Vv;vFK;d=Qt(qad7p zNyihyO{O&OMScKXa<~?@68BnnAAd0{;4Wi*IJ0nN|ro?Cf2;YSpS$JfN_I z$*Covaf(adG8?}9tXe*4WgqU`I_guJlcbBb@)Y7Mcz2iRHyM5~FDs4_X@M5A#!b+C z;vo>U?0eX99$eT`>h91_&e5F~POo;h`jXk)HJa*}(^RT1)P&vg(CKKpIsr<*V)A^? zCK$KYA1^@_SJR%Gl3o{IIwq`z_aWvvqcJI^UcnqV>;5l?4!HQUgO(kK_cH5Frrt-3 z2oO=S7@cZs#WxOJy^3QIhI~TwfRB&k-JtN=E^h+WsglcTPF>{r?ADr=fXnSZjmA^3 zBC2d1MH?8~{p5x&^MmoRUTU<;<9F`{hHac9KXKALYn+uF@I`S$)5=9lI9O7L!=>n- z*UA2p@Y!pOhZk#*NsHOWqLf{?J;>d2BG)2QM6^?WB7XN~xTN!q&ZZJxF}_}wqRy@T zPa1AoaUEG5ZizXU^kaooPj`AN7BN(`(p+W99llOz)V63f#PjkS{nKvK)W!qNle6nO zCYt1sP`4T7nPAG9<8QL)#R^DR2p`OZ<)SGKx^INW&a1Yir#f`kS+H9%Cn&D%?BWz_E*i;U)R@ANxDWxO%Z^@3neP4}ISFyhhJ2If)q3p{&X zMjtaNv2@-yd#Z1SR_6LF9U9;LLC76e*!CT^nDpqp1v#bMdhzCZvAaSp&0M}j^q{>D zj9Bdi-w1{02~;4?{=WS|Ho=^SM`Ahz@DO9T@2At4NbwH0V&a>N(6Q5R#4YD1jzh8M zH|Cuf9}BkxMdg9LRPm6LZqx%Tkj!Imfgu#w^a~_1K;zg2xkO4>)nvc|N_xj~ErLlp zl)wDGWZSL=15#~dS7XUoTzr>pGhM<8E~37x^>~#rkNZ%BZFA`<;!*KlF^pBu zB)h*Xtqh{$2Xz+xs!ybx@%POttA#DPDzB3xy=`yUGHg5vkOSsQ){MW`jONS+`%6r*2m*$3Q4jlFIn^h>QybN%4+=;6AO z;4tk8sx5bPd8`ESqo4b%%`rQK4z5r@el;<-Q~&m9vgW6JD-QPHnodzpmj^|BulV=&HZne9O-11qrJ>_0ejbSt}8txf2i5YK`SGQ`?e6TXx>$?j%i= z3q5$c!Q*Q8ploUJZ01xJ}qRKIcNOxsEd;AZM<~XI+N0G z$6E9lcxIpl4}E*Reew3v$s*QIXxxZx*k#^bV?Pe}JhYg-n)0LzdNR};4fEKl%xW^A zp|#rWUe|?f6~enBZF*hiU||0;l;t{I!=~Qi6?-6^e%aE{BTKYT8t8H$%UZttC_O`C zUJETo0EybsW%DfDaaz^xLFYf0e6d4l9^`thz)_0i(eoGdR_~1k^na7Qf*iu`jLj%$ z8b4$eD=BPS8a1blQ%vDp-RK7&HozVi{a)N_b?}nuZ{)m~(C(v|x-8&RaiPW-A&h}7 z>mu!)tnzZZJNn(8l6=;eY9vInY9v_EBwIXvsNG}NHGFW{wOq)+U=&3imZM)s%0!*u z1~ed1Ti$)Xu9pAYfF68}5ekWHs_Ad~%;n|Hq}qM@R4w4C?K^II>*=HiB3JX@UgI;@ z@v4&*9NrAjS4bQpRk8*)^rKhMQszn2mfRiI`|mRtzR9N5$+Y~J-1b9J^CYFM<|G$h&E@aoBSkTekNXAAZ>`V2-P)++ktRZZ9f;zRUw zD07;c@g|0?a?ymqq+@+>gvZoNDdu{l%~Zed`mM8hJK&F%I@+)q-u2*K`t5q4u`yB_VC zCFwifv!qRSw_jU7wRz;PC;G@XXjB-;&c_Vez6-1Seir$XPP&h|7`^6pnKhBmtuNR=<2d)=@oP9AB%tjbkQP zo^4QbtpLLA@2FG*l~Bl@m^xC=gdzTfdJI?ce<`qkwb%vWyt8dkA9Udhio_%eUKj_0OR zz3gT94Dg_n?P^(I=^**8neS}Iput~(xbq4ATlc04mmwd&ghp=~dbZsW-LD2cuj`u( zoi}ZrTNiqM?aS~p^D_<#mBa8|^RrF-{I_8qmhab*9>wOa`J0|5GtY;eg2QNymUs@fYQ0M1X6SN;C2E0Ar z=9Fava)EoN8Bt!T3o^<)d`DyN=>+>DL{@%6J*z1W=Z&Wqsq{Xfu zW2i@zFPVVrj>6>$NQpaUJNy#?U4sr!YE7-W&x_V4>V|6yf_b*K3=-+I{^nP*IZ zYt*_2?GUVkZ~J_Aux~cZuKZb~g3GX2zj(hDOincjtubzkyo_Y}G}uQF zS6$g3=#B<8J9W&f-7enX)*!q_ql>W1cwm(-4e6kjuy^LL)Q^4N{*iCg;{*Oq`owBB zmYUsi^pn*&uio#~2OQQ8U)pQnG zd-Ja`zb{jB20P6$({jK((pD%V*Ghw-jB#nqAKK|Ij#1aIHRK=YGp4nNff9f*^_kO#MC=7gN&^>Zko8Jz|{L;B5JYjXI- zP4?=&wE*|E>H6$5_P#$n2>K@0lF&KbsGnHM>5I+O()5}sQ-2=selrgYI_k3V63p>M!~vJt3aJs^ z7&;4X7{54O)vTK}CF)x8m&asHgbN;fIIqdO9YhUhf53W&a6mR zzaoxoZJwmVtg6L$)wg!^tqegl3CGp}a4ZLB>G8Fd6g>_3Gvu|SMcDkGJPy4~bB9N5 zC%#-3L*4K+g<03dx6Zp(iztS3v-hX56g*(xdp6q1Lw!={1JrObR!7Hf9DCA7Du<<# zLA3__4|cnr9+u%lF72&FIv^6or(wL~J8wM(I{fbTMHxQn=!t5Peiki0tbI8@nk8zV zx78!vcI9gBUkOboi2KQ`u{I``f?w9T0hlP?>Db{ewE!)e4p~OY@+tUZr+V6(AlTun zZfhy(ujhyzeB59S+dnFMTeALvqiLoIXqY`~{$q&apytTc_F|)(>scZOnxJq?N7%>53WvTfSbv4vsCw* z2^8{j%oZT*Rdaw4s*U6Tw+ZX!iLz^uh0t- z%anshi!S;PVbVp7J!27lJnJ&l+oOHemRyOwRZ|Iq=Q1fchx3
oIuGq0lfl!q<(lEn=@gd~4uOZ0Wgam{jB{6w zR`U+$?cm%Y#J8DR|Jzb~FAy5OpYd<~B|XzN5A%xR(HVV6Qzbx{t9?9wxPvh3Lw7M}L zo!y?tWDd6zM&XZbcoX?+j0%9HxTL_Lt7yFxgL{KX68X|&+~JGRXypx@xNT|xb}#n% ztQNl`^4agjkoXMsfH1ex-hSSnG~vC+p{|D<-6D7p_|299&uOu08gptaL!qVnX@W5_ z+h?xRL}4G$Ih6gf*A4XCOUt{)Kej}>cC4MS3G8q?dXTn*z1#YSl6ZRVzb>O5Bi&?b zS-YoNf0 zov)H%R9Fh!)KCAWu~XTY%IVMB(_Wc3f7pA(g~f2(O6f@E|LDdX%2^OMtg9*dWM+MK z@ojkhF;A&1uvHOrp$TOJbDDyh_wA+zH);*2df3XMpAW#2MLzZLq`yM`#BQ8NV?gy6 zMBS>QJn63n7MG>dzyq%d>m?6lLVrP`X3tgj)TaMj9G1`4&#aY3{lcQ#JY3*j%8hFs zBGVMHv7HR_FsaUZX^ltYm~c0g(yQWM6C8Cz`sG9b$5V~?Q$6@ooU{KjUT~ylH{a1< zc9HZ!6Ly?g{w?N+@1(VJ-tm-QG}|1ueAdOKB)d!YBBnd&EDEc-|d5>e^H6&+1WMt*PR8 z>-XaJw^J6^u%??XGWfH1Lw8}$F2ezpiXE;aCqvCY&LegpOJSy7miuxpz8|++o;m$# z2Qg3dqV_CRU2h5T=Bl;4{$zYi$LK4RH0`Neq=@irGJ!OFizzi#GoSf}$5kZj8H!DYYwVjc2Q%G!AIoYtRgTTisCL_fuK@37yehwkGd z8^9I!ANGRekzcIg<8@5Na0KnKO$%0M8u;pqJ~BhPH#%?x_UT>+e4s`432pQ=?Li$6 ziixJh{1g&-nMr9{<`PYi0C}X>sE7MUcMp?Sb>nRR+2X(T@6tMpGGptnIW8iAK8>QS z``vRF^F!xQzrkTgrAHzqUdvA+!Q@`5Qd}YB3FwLf2E9@}dLH9Fl1;C-XZt66D2q_O zr}lI?5759k%4BP0Ro>Gdmzag}are&>&pJcUgexNc8=KA++_8~rQ7XC-3-4_?`b9Lsi$1PQCT zo`#_l8vXNh?~x7i$x5k?n)q(u4gAA*aWszPW5k&StO9oSLEK0k6V_X!vo2S5@aZXO z`)LjmhGdTjMV|b}Dd+7W?1ho%2TZQJBOI&<&(isN<485MRks6HxlmVUlKQ@DR=C={ zqik{h3@CC88@+gY&ce-#d#aSgMvu1N+dDweC5>peRajz_*)d7e9VDa?=GsrIGi{H@GS}VcNtr5jK%x0K z-kxg{+;ZP1hCyRvj@?oGLe@!U#eq>AH`Du3?ES&=5J~C|JT>60MK<3{* z?z0&1b5ESX$=&S{{?Yu3nsmqa%%zQ92f<>^NBH7&2;%(UlH{EK$aW42a`uXFE=jDW z^-6~*JZ>C4oLLSrV^L1kfMYaQU%Z$tUxW8IunsonmxHT&el(9>TmdKl}q;&V3fz_#?V-Cp>~E$}m~W=Lkz`L|?p;B3sXXva29Hg?{+ewVaW%4CO@U9+e3@1>)A#zcM0+3IR_!(Yc&Fn-{Ef|R>x1`S z^3vIw-|&s$xYm<(4fMVnd(OF07zyvP?M3g({WQjx=KUmkJhi{83qKd*lys_h=(G5J z(u;iCXp;pl-XH1b=ikoDJU(0YzUhUer{iAx4e?t{P&h1dhgJGf?|sv%7;Z269rvuk zntOK5Mcr7;dOzJ`GuxD?kS)9=Y~`RJeZcpb;Ud>ubqsw1O~;C`Fm;tobFepkO7 z$NGH{pWnsrG;f;l^cai0IkZoWr-LVbg*I7n9y(a&XEXid#Ml`5FNAtf8z}y=dDz)s zes|ctdfQBC%ExA-+&XmscxLAI9lw6v99+?`#(BR>y70lsb#1fdhq+tP{LP2gS%En$ zH@@1vG_iy_w5PW+?c&rTqLQ&mh?DX1M*NgFsXR6-o-LqrV*Q+tn2d$GxL0*&k$k;)_)W)?G4 zxiz;EQ3J2T? z*=?~B^N9H^o0*a{&jq(x8I-6S$nK79Yx3p^5#0B!dpZ0387`HkhlhO3mq+2fMQeH` z{F4@|hmERm+zgdWRxEoS@fF)o6+!(^`-NuaI`)bAn|r{%4u zltCHAxQE4h3i3toI5?I+97sjIIPky2bB8g@iQ#kPx6NVX%ummQ$(UDBi+ywK#k*_= zMv$#MPbID$L>!XGcMvbhZ>BZIC2O5}__daBaJcEoUduzB^iY2qtzcA%O;1|K=jB~{3mdF0lcRZU%WTFAx!VV`SvZ%{D30ZrnX>rC6D?q}JkDyHGk9&W!S zV0h}?TKDW4c59wn+xaT&SNo^;+QcZPdkX=kT&knY7h8>^c)Nsw=(&1#4@>@R4IaWq z(Wee&rB{6jc8_WVx+(vyBvlTuXg;O2#}?b|5u^kEn&Z6-+-cej60IO`*BJc%r*|B> zMe`EYk`h`@|4w*8{=R zp;r)i6m)RcYS{l8tB)c?*)KLU$lnp4#Aq6Km`C_XJ=a(z*J@MtfQ^#up{j(&Hd|bn z`+Ax2oRV?i&p*bF|E-^9yVku>%&?qiZ`+?7vsG*JL*`)VC2b3VdH2u=yEXyF7*M)@ z=lN1+vb6i^O$H4~+c7@CNbSnq1a2%rHe6<#hr5^?OiRxugwOs_l4$bm6LF4ylz{q% zl4;lJ3>xx1>5Nn4Xt7SQIK7*F;oNjYb)!7 z`wmGUG_^*#*4TrJ&2Dh`Yp+Oek((hYdhD<#(8|G&;`Z9P2D89aZ(agC%#!wD!?94) z^QqSUhwte(2igw)7VdsKBBmd=yx9DuRw-CbTCi957WuHWJ*xkJt8gpllvIaR8?&EJ z|EbSa%#^p=?=vQ4s{a4=`u4Ojh*h#lqvB8IM8>RJ^f=QY`aL(tUvXS9QZI7q8>tr} ztI`eT#mGX5xKHCfiIW+dNj4yD&Hk#Cp;;ipJOrms*D_Z!6R;xg(|WW+Me3#SBVp8Ujw zX4)d^%kGE4i}0SYOKBsc z4|Xs8(x!l|`ml>K#m3~1Zt!vcvENL9*-1`d!l*ykm!BjQ=sm%&*J2TMzxFJ#9gOez z;6Qm+Y!`2poyr4mAXsIE!<=1rX;)?U={TkTVPFlH5z(f_sOcRrmh{&61d>`bD3-C& zV9k(huP2imG2Cq|G6LGp_i4|h1CfkFzo!U)h5c(5{|1$qm+`^pRLhhtr~Eqry1COh zj{7HlGF>v%y3sJ~TiavsOWc2L+gM)8Xu`q)Z6)V99TKEFu!~i_m}k{Av(Q>Q>maB5 z1s9!5yr>=jkPNEqpAUAnW{(db$N3TwD`n%r$2*z0JaBM*(>!Cn>Kw-%L|4z?7}RMm z^HN@>*a@=qiBe$uYM*}KlC&HYxF;QOexEPX(NV|1MOBTJez^BEh;iW`HzbHUto9dWKISU!Q5W?oujUIm_tPALgMki|7Nv+UfY z!kI}AJ9m>0Lkx*Y&$4SgHE+u`0`KKXb2_0DVsMF!?#F}d_isNhEe<1kDRMFoi6zJxneJO z0|UQHWpF2+uGNuu!n=6&9ZsFmJT6BCn?Q5o!|PEaO|%=1tf<+0D3|Tk_z~7=FVKxP zl|UPuZLew7i#gxaYw^gh0|oC>9J)%MubBe97&Ed1nF!r~mb-D8{J^E&cCx3~(DTwJ2)mXo zc)9Ml)n+~4)v}lCo1G#!Pxto3%nlE~s%^f|BaXY3*M^PXmVZyd>R}j ztx^7Ws(10`jLjM88~z{PzN*Pxjf=RhaAhgy6LX7CztKnXk8^HdbisSWLvdbZ&H-xt z|6H^T^PQQsonLSyXbjv4UROL;p$R+CSQfXQVtAke8z;?odc$n~)kUi20B_|v8J^P{ zXB`_H|=X zp530;g6k5wi0n!44exTjUH+nFG!4dKBY7setQ(@aad55PPKQ+Q7|wZB$h%+Rn791x zmw$tQau@y{|MoXA)9#=O&brsXZ|ytlNKN_dF~L2y%CemvEEGNEMBo{J8s{Rg)%DbR z>d+%ya(rAutyAwN#>o;f1KYDZyZa$OAJ)U$r6NUQ>9xaO7t&^9=S+o3|6jM#= zXXT&H@t0_ytNgP0J2=Cj8>X$ZGo0kw^S&65hfmq{pB&-{xrh~-)m1-PwVf4 zy}lyfcwihb*ykVp{_N|ezF{<=Q#{@8iUf~r7+X8okGLZ`wHO!m3+e^9f3XuBm;p{&k|?7i}DQ?Z_({Lt3%^-BvYumH&=wfHrN^;uMNc6bAVym4V>PQ0{NP;E~Cc z80SrUe=pB`v$0=I`=&Tk{!q%nIPt^5wOi~hAUJzcX8{rowqbH>kLY2#Tf=wa8~aXI zJtq@ha1pmME8LKn3)pe*>Yp=tpgPH5M4|Yz*ydJ*p$L&SL?iu7p8l)dsG7A@FNFLW zN26DGC$j@@v1tmJ>6N)2 zf477C&VTHtj>&E=jnp@^&V0+~aROv)NC_Sqpm#1%(9_h=hyX3$V_eh7yCI6IXs(Yg% zDH!xDec=B$(xO2BMAni2e!Z^7LP_tAs?=!M_ddJHi#Z&;{;a5*XUmSe?`w)>es%e| z)BDf_xHGO{P1NlyGyFrV6K!EFf$GKNCq4LRb}pm&ilh=zrMnuiU8>N{U1#@<B-35_db>xS9%6Eh~x9zzDqa5B({HAfp0N4rVuM z+^dSP4rvZIuBk}6qHmSu+xFq8&ceH4SA^J>S76CBl7ia(~VN=nr=0UvcQJk-pkqE zO!PQ(T;5}z5=eQn-i~q%FD3iX9?LftSqPp*2CtESbs!E90?+Y;n`+~maeb_p4s$2I zJKm6UV;6?e1iQ!>z^Dh{#aiRQTov?2y|=eE&trZU$If^82M>WYp}#Q{3+YL+SjMLw zlO3R*^6j3G?hkbg$5HuiztMFUmh$iH?6;8!To>OU8|K0NV{iZVvUB^oFCXTDj^4oE z^f70Q@Z?g3ii;vyXa9`ZIQZJ|B1tc=IqQ{pkTnbG+&OAlC;PE5A|ZPc~R> zinwV|o!0V_D*Yi}<3C!ry?lG&fQ!&=uMXyL-@*(3wXxWBqd!w~jmv-!cgt2c9&dVT z%zHI_mBt5-Ogeglogndq=e^I)?@cWG6FV5W^nQ&#N@H>O^L~wJwLOein_-ei|MGn; z^kDhM;e1}K z%a?TO@}18%=N_r*^~%=cYn;cPw)`7BX+m5u-r1qwcOph0_wy~x#-z{(AV2i;8Dkmm zKR>@!D_4(omN(YcsWH?A${QeJsOWLnXhuJbJg=h=1!s~K9t}Dp8K3A|H_7r#oy&Fg zocHIA*VBEDT6JiKBl+#UT1=N@7`+k@dCHv?cA&`O#f_{xCUj<;Rq4l7|MRc(w8dR& z%zyUqbfO1=@%z4Qcso`mCLQk^&{|F&6O&VTvao_%On_coy{_l|Z#L~Jj|gb;C7NkJMo%%+X zww@!;stLTkhj)Clj(#S=F0$9-G#X&wo9+2_T{RB3pFA;jNHw;V!@Cz<#XJ++^n%=a zI6yR%N-fkwp5~0B+(ZLU0Ni9QBIc&Bg|?P%*7-e%+v(@r))4@iitwj<@!pjQ*^UBe={FyR*y;laTS}ryRauK zh-pYL+URLyVqh#A?PfX~_lKvtudB6+E60IcJ(HfA`^MSNl8*68zBp^g^=e6qHy77K z4XJd8uscVOMD1;o+kW8L;?-eQvG4YkW4U<}=>*je*kP0-5282pUx{+JJ|?xC1AerNANG;=A8Vtz zzwE5{qD4*6XXa|Mya;`N$_&NMDBviW^>uK3yp~Nk1!K1N(4oVN4zlll&k0DtjzWLJ zy#!OdF)Cl_ix@(f7k&qS0e%!3)Ov5blnGM3T)6!sMoMxht%1z04dZQJx1(%bOyRAQMc~y2->Hsroz1b7>6T7B2yK z8lNnwO>X!BHAbl@()aDj0NPug>O9*MER2|s{JGJ2b?)C|uSR1xyXFZb zm!nBZ$>UW&*+b|pMvmyHtNImgyZ;tc`wmsWqWLYGCF{D)t^>O=mLk(v9TxQ*lvm8U zFN2Bw)0=y;KF9U`h3_opG+T>tmvFHxjttTMqD&SO<3V@vZ>RT^!f^)mUhS*8fk7=^ z+tmmgN?uoImFr1w&0vhL{=tunsgJu~;Qiag+Ov;DsIELhg7r-GN5mO$p> zEk5gwmgm@{Qse`5E$!R!_>Vdnu)N#cc@1jsro9;~smk~+{O0hFU*>>w-9o?8DL&3| z_>!AI7#<#nEX*r?I$kws&@<@adV>dHAG+Yb-%wOGXYg*pQo@!Jdu)H69;@v#@A0el z@VVW_M${4%KepH7?LIxL9&SX3IMb`9CuaFzlb6}5e=v_2ih7P&;k=){IpN7=c8VJd&67v)7CRbhG}4#>6kf5xbE{YrV=3U6#w!@k2%k z(@wT$6#TNAR zSS(%@bX`ij_%2p##J%Vl0TF=ljYZkNoF&Sn;{4sjgb8%&Z5bbA?ldpOMKHOx= zxSV>ezwafrV=g_<2jP3hp61@GzcB2>Vg@0?)hz~}zB54xIYKW5RjumHx5Ej(wilg_ zebjsF)Ji0Fj>o6SvZvS6SJ-fRCRIYOV+5D!d9dMi^n@xJp_5tseRM&Z36p8oiEHeg zH8a;~p>jwyIWDr=eLU~I?p=T6;HwlgW6z2nA9vF$M(oaGdL<=sJCSJ()by&7IJ|Kv2>wU>-?ws!0qVV%3%}*#g(+u8>gzZ%afY- z_`M#as6-lix`WLgRGHU@jSIE1Kb?z%_FPP6^0(iEe~qY@yTw&m-dra7+>GCbr!ka2 z-di}UzkK@NlW|n-(nIA7cAB*;p0oe3)X|h?mhOTNRgB5_8&ergfByr#vFn!Fjs>qD zLF5j*60<{10ZzWUnueT0CO_P19iQcL%tb%weii7@xN8b@HKLzc?xWvkcZ?%>N;bMZ z@wdFLR?*nC^PI1jmQIID;G*e3jb)P#a#B8rW7yb8=M^(hz<7&I=^a^1)yF0?2S>SP zm*gSri+2Tl*O<5l&uuv9WmCS$lR5NkO>@5Fod_S>GiwN*sX3d>d$A5j8(iPCb{iAN z`fg%{#`@r7tyof~%k{GT4lf%1c$xNatQ=$fT6+uT!#25FhMQa_cZ#fn+&o@3!#_%F zuXNqgmrwDrR-)fUeK4MV$?qt+c=NXx`e$edFZb0OK6Egw#0*BLv23j$Z*BNb{X6hY zj}{t`L_V(>_dA~CZ5@YVKLi5Gc_er5F&5+VuCWw{=xc4YdS&2}W;gQ_U#Iq4iW_oC zzULr#wz+jeE0L!0<0|@@Mm}l&nqC6|2eF4I>SQP_U%bv4Yt3*Qdpy;&b)kGJ-5s`( zRnEHL>&5$g{fnyRoqZKh96iHMzNOOlfy#jN;{(r!pHiOd2^T*4`!&7g_LstQvCq?S zDVdZ!hQ@#wy91pNzfyCLqjV8fg!^Igct|(6Byxg_Ame$QbR8IFC!22!aPUm*rT6#` zsvEO~Ts}hj3cebi8SB=4Z#0WI*_h4i*DCxk&iN&7PnurZ zWGlRxcMD-C|2Z4Cwz#>fmvws5g;{n`E5pY2Q|7VuI9%$Z zlHc)H!-rXyS5{Xp#1|VpKDAFVAzlb&Ydwmz;cFk~fbeE3v}g1Y;^#$fG!sn=^#KAh zKec|g+{OS5XRF4b4V7EHwzuh(It@C}rU>m{r=YtrD?b<|@3PgYYn$TXv4n_r*5GZ8 zUG|3w5f?0<^5n;9k;)&j3BUbqz*oJmaMGXhzxJE_7yHTXcS#-Wl=3BmV=MWSH}u6E zR30Wctox7CR%(9v70xZ*6#We5ornI{nT}UX-0=Qk_&2?f)7KH_gZc!qk)Yws;Emyj zP5EJ~eO}LhN4&U+z^26)gMz;X=9}mG*`!}Io1_g}^Ve#~CmEivg{v3c#-vtFASn0O zyPJJu>hrX@pCJv-o3+*)p;l*)(eLVc)e2#?LFBw){d+}YJi8iTIQ`X|)1tmfFJvUG z=T4WlO>hz!N4!mi9PahLt;?Bya36XYs725Qw=u_I=Fd(^dJ=Yc%9IU^aKOBP7G=q& zgSmQ-%b5|JI0wh8yMr>-JL?TX$8vqa4%;vH>aN!>JWKV>jmNhVJ1AV5H*CcsMr+s} zZeABU8|||Fdvp9@3yyluyLzWEM~~!aH{&V6E_r7E)J%e#Ryrjo= z<42ZK9k_ks%Iq8KKYO|%%P@EmTfr3GU(TW@SNHf_u5V1H{ZgLmta7yPTC*c6ke#O_ zIcF_fP#XNkE)+c!BQ&V}ivT+2E8ZpJ|Q6+JsM{i=pcjpi6%l*r?Y|K~DLLPmrvd73< zPR%?&r`@-zwY9^Ix0euPbKo_-NE#(Orq|e1CgP%kRS$mBdJn#L+WOpO<4xPMFIQii zvecQ$^D)9Ak9@RPudr-AsY{2B)kr4ir+emkX-n z+x227AZ#l}tdHmd(CAS}R7Vzs17_RLv#+zW#s%u>{Jo-ybc?E^BK# z7~Fh>m|q=|P|JU-+rTf{;#}BfxXK6BDw}`N=MhpAeeK(q<$LG2Ym}ibot=;U{?R_h zxI9R7?O!{Q%EQNLt21B)h6{Ep%WlpP-ef0&k7A9!t2oW`Q04J-hgIa#+U?~#C1I7~ z9}x^yZ`tO9Z3L)WN9ow+dD0x>Rl!r~%UQYaDD`|`HDYJZ4e$v4uRCMdF@{(FqKE~} zM=iL0F3msa!`T^&5u2N=mzA>0IK88v>;r)daUbu zklrhILjoqfXs z1haT9jn;4cm@5Bj=ho3+doJqSxVPx*w{F$F>Kks z4mYSkVTJ6N$=Osr?fMtvRh0Hg_NhX{VK3;w%GGGPu16FT1QV`^PFtn&Q@vs`K~RI| zs%y$*88^vWIT>@|dK-0LM?Dq@c2(8!9dE3_BE9zxTuKKMGh0F0>q1V3m*NTZ3#!j* z{77rd+pWzsOkYT(UHx^@$)z>ZerD6v(5^X6Cgw!`E}J~L(cr0Ma^`3-*)!H+<&cY z;4@L3ySg>A?ad}>KA4)Ddhw2V#=N26xsH!|F@m_07dh;u0-r&r!wWtCo1g8}+P$vk z#%^%b8KsL)AIi1eHT^e_(VlESeXpW2oRzPKzzo;Rs58%snLa`UwnjfFEx%8P=5j%I zzjb)fGmUK+{kYCJfA6pl&aJwxunku#Gauq#5M}tcSG>T1e`_5k%fHDEG72l(=W9(G z-96tYcy3i_l=R{2(--$N-0!TlN*zKxrylLb=HK0(9S{FxE4Mm}@9Q?HI+_N~pv8~w zCbLyt3Pff!yRxse7a_3Gu8-(Y1jcwbHCJQ-Vq8~oFQean?iC;R{ZQ*+{JXDzANmiD zupj~t}j zF5ZB?0$H@MvPhcMRW=m=`Yc=5!P5RcUR!_3m6S|3K5q;+p%Zo0Sa033e~&x-tmF#f zP{rr@a4c)+_w_v9m;HYcT-&R`^($a#Ok`h0c{2XsO6bw>`$*earl1c4?M;I8Vy!*S z;H~-Vxi*!jy1UKnU}f|$5n9W)*~&w!^5^=#wU_sSN-V3|K=0`a4p|zHZmoS>w8m+m zpyfsM#IU_YYrP_1|VzFKO=_+TutJ+Ks%lK|O zdYdtaTR;D74__a>j`vPz)fB_#6N*J(Ha9%Uhyd=Ie)Mr zb^(Q0(~|$2$77fNBw>P?8n2Q{Qq>8azif5m8{PDcpZC|C{m@@{h;#LW%(EIRzOgcs z(|d2}B*}5L!|4i}&eC7%m74k&|HNkrbL4-EO15MAL_lGlTrVSE#|s^&cGB3781Mbr zg)2u7yrer7equVtK6O^-S2bIExo$Y)Sq5i5Z{O!h{qwcr)tFO{y5i98V7tY`mn-^v zJKy$Bi$&w9Xz8)-s8u?g9+G~4JaUZ^Gc0Q6+XdE(Jujd4&y(vh{e;f3y0ovSW%0NXpI!4mN;V!1G9NsdsFC09 z2&KA;2HC_-`@+dU7r8!s!bBYDbg;{VRhPWq4Ml}w{2;#jou&?_@h2Yd`NNsbKAn_B zyF1(tK^I@!Wyictrgvs--X~C+-o5({sIZ!ox+^J11eu=zfgcg~ZJb@i4%`|-J z7>D7c?`g;7duAV>@T9bJu~Yn{wbQ-hAZ;u__%**^LUOsu!y?9O#&5Uf#U*lOC!#MY zF*Tp=?nPaQNBoIDA7#d$4-m6NW;~T6@d1a3o{4xtpV#QLxSt;G{yHLj*_iw1Z9$@w z=?SwB*lTigZhS<35g~R^sZ|@y7ivr=GI009$A$A%vDfA^uWfK0`Z3;NMc1mFwyjk= zO|RL!^7vad#9#Ij%^bYn2W!;(Y4j$W*LmBI9G>)hMUelJX3(IXe(wlytIVRwLGu(z zR!+CEv-MNPl3aYcOJDQv<35{g#8LHMQE)(8$w?bJ*p5R>!>FInjKTPnq!r%|XUnR8 zh}iw{(LK!c6KdXo?XQbntZmuxxp&IU7xnqoClk9#7xu_{I#TZtV@G-zZwsYEkD&`U zjenqDJ$o2ISG?v9&hSDj+Vsb8UVJpB4A0oJjCNXmh)pc%AInl^*F2+Nx|4ff4mXNz z+DtBt$UU!O{^0m{w}z))pH{E@b>D51^?mb0Yi@(qZ^}U2>JOauJ!a-fCi3FT#Ww7a zuDcU%LD~Fa^B>%7>}S)~ZJZAXo^GC;@6K!f=vTdZ_D_Ldh39~Dm<*pcVLO|bVs9(u zQO$a)6NSrpdRs0{nbx_rN>U_m{hr8#y>-Xi<@oSH2b+dUanXg&sgi| zrt!U{Z(zZL<@|j{A1X^bli#19o{P`Gng*|r3(@C2F2Hj7I(x|YNt4Z=Mf{s=J|4s_ zlIc(i9!Ee$R*TxNdH0aq^z~qF&?B=p+vV35yF-Sz)~nX`*{Z0iL}L1uxJ4NUw}!@H zNZPZB-kQtms+~Ty*GoFO#T|A_<4JwIyg~jn>tS3kXL4~jIg^^bs7*sY&BiplW>n*N zsQY$T-C3&sI~o08-ui*H@$cHx-|>06DeA8e|laJ~;Yq5dH zt{!y#5BKejcN>Rd)p~u8&q})KLhWcszMLM_a-11LxLS93Hy{ptr?h@OXqNwOUValb z6C-7-*Xk^54dmH6E%rWW%+Wv1(0A1_Wk0~BV%&jV$Magq;99G&e@n)F2G(_LFVTBa zW_?W%2OmRy6aP59Q4dd_51Q&4$%l~)Tmk-IcBRkl zQH;}SfA9US1TFS0+8k|8?SLBPHhe)Gdy8){_5h`NwDgyz=aH9<{PM*9i$-BTzBC%N zvChuv>8j(vmKomKf(B|t790D2SoFhE7ow;IDvyH?xAvSocN)ExN!{cl-$y=?OV?;m zc9ApGl}Hcf%!{3Cd!_f=VmUbH)4Ve`1dxJJk8NCxK?Fr%j>DM*eSN*6aU?=>Ep3JU zO}=vdr1j~Gp&0azclWnv$*&$BE#Ha3Vu06=4}68yi|(qoqe-SHAl5jPwi92eX5$6? zulCm%2j(6ch4-$9E6u-IFHXgM6Zz-#J^zV;>7M3qS#&iTT#K1`p!>PmoX0Y8Uut1O zOMsc`4CncS&`dB1bgLkB4K^!yjzVpv39Wea`X1TrPx2OY#v;5P7^W^{L zv+Us6upTvW5ZoFIz3=X0zx4E-<;H#i>%m#zgGF|K(xp+9KZSo|j=PcAE96G9Q@t9R zaLS_lu6Z?v)0JWhC0vi9#zA_#S$i(1dIZd?@R&k7ejIf_cklK9o}$^{Us@cWH%Hgh zyDMzYzS!KvKBaD^8E59mhIRvkxq5cB!NKB#H%BZoQp*rk2U$t`NhOKcrF+@P>Swja zU2qf1JBHD_`#g-9$ao?Tnukm_kr5OR9{6jP4W@$YV8dlXXeQX~fAB=>n*0Uizt#~R}jSUDskM~7p z(WK*eL>@@R48$sM&bJ2all~h!x1hPpKVeh-$6E<~FPvCPEHDC5Q)r%wIm)ubw|y1+ z7;AVa=3r(oD7#zw66M7FiFzaPWtT4XP8!2bn9ZZ^;p>W@-7*AnGBdHr6yxKJS{K@K*{d6K8kE=V5AOVCNSMU?IvID z__t~n^VcU^@4{rf^Sq0-UdQ1Ohn|I=@8O%Lz9-QbMWJGodFaJ_w5eMs%>_<5jYz%* z7r}eR0Pts5w30ToYAfY|$wTEo3w(ZP^1G_r;8B{JyCe6ZKTl~guadvvuZbny=l<_a z+$Ss2UD5dg{)_gk?%bwoPl7HQ=PNAWPcP>(#RlFoRm<6?qS+t5(J>p=Kf+hp(C;y> z|7z)L;q$IgF?os_DO_<|7jXuuQL&uhR{l`@!R|!?HADDk@E&{F8(Z|WY^ApXtGUJ~ ziAtM@#P{dLAfLF8u=H2OsytyXYTajXZ%xuPj=54}uSbC2m-YM8-}l!ijE1({B0g_b zo5%RPwcI|%=dJzjIsRqC1kL-avKp`Aa}2q)ztbdIy!iKQa7C}v{AiAk=D|Rv)3vIu z#Y1|;(;iMp6MH!B5$P*7Vj#B+VfQDmmGCJ*~!~ht)gl@C4bJWq<4?P>Kv{%0WlDRmPZOrKrM@= zYthCOLLv{O@0m6;Q5(Ez(MK2plX`q1)W;KqW>bT$AL$T?|JuhrqHy|VNgwT17xnVv z6icVes=3fv(43tL>ZZO8i>uMHNXCjP_${?Ci zhM7$)1IPTpsZKiEYnZN``P?29HuW<$A2{Ul*^A%e>1Zo#o-~a+rr{09FU8~J3`B}0 z`D|BWTfe)fw4h^n(!R<*-kdvaHrCc;wDo2EMa7Ki``up7GM~qJufmTrcfYzI1{D<< zBh`i{80F->t8O`&;1tZtebg-%OU{!)lWYF$R{#9`jD;4zhmKRIrsP)YxPUq+&)-$fjNkOv4akSNXj`^n19 zL~%?t`0>xXi@HI*rRfmWrxOp7t)^3h8@2CwfDAIn+8a-+R(x>6J@>hdBts2ZF4-|G-A~3d-sofCrv!RADCz1&)Gqi z1N|fWa;N^XAFWoJ;+leK{YzT9T34iWi~Zk6u}w29|B#radFw*Y^i%C2y$hJr$2d|p zqsei|DEpK@v})flhskK+o-eF}4S!TD`jE_n^E#Y(_4$^6eh0S#1-$H>6)>G{>iuEu z%^nf&sHM0iNyC$+9a>Zm&K+ttqM@mqTz@Cc)fnI6wf+rf6(|2`i?c#!5}q04YC9hu z(b^{Lhh=+Qu)PmI)o@{R-}m1(AMl*}=8&(5jY0rI1*F&_U7TflU-)tse7wQDxNCpQ z^JJD0J;%;vRPE%oo5t?w7$v*g53w>+GTvRECP#iY9LhzQ=K0zEK6#mc_P>XsN6`1> zO4Y{++ktHbhMZi$S>cljnpd(^x-BpAmiFn8?b98MMqW4FUD491vp)-a>RnNNiWYiH zL3zmDu#?-5&ghbDwwwnGbFe6a$JOT7hV4>s@#|L1qaSs648{)(I#|k6(Bs9Prk{d$+nqHejkT>Hz=5$ zX2#!*T=p$q&A+A3w&k`n5mW#%_*{j=eYxDJ8*43}4tg9hh^`s;|0sT36zeWq(@dUL z^OeuS7n^$Zk9nuuw~+fpZnNHYa1{k^aebfPtSe#({08>2vBY-_{DVf^+;O;rGuRVj zm2vnMr*jXuHqFt;vQT>r_rxDZC!|>|v3ytHl{e)rBZe$Z1niV6_@=N{p$Bfkg8aSU zOOf6Dy@!-goNt%3Q@@5g-(8C*;79yGy#=CHFYVA_ILTA%Hh|KIF~+uw_kv^3SvdEO zTY-PxxBO0lL1nMC3J&&~DwWHNoVjWJTQ8e>8uzR8KKbn=I)rz2`h9zo#}~D~6FrR_ zkKdL6foD-U**uCtG!b*H2a`5w2CSW4a*IDE$#w$UxEGN%w&|1iSXf~H>EyUS%@Sx$ zbd)xql$%!KChxl|UBi+N6kdKraM-thi;4Mir6L)K!z%2_wVZRH1|czSGi&W&R^~FF z4`Rh0L+q77%27upmauElWNYUf?+1R!6MEh6{LES&j&lS|cWyl`yDHfCe@N*XtMSOC zA2~U2)=VjPTJbX?<`&;$g$=%f}9nx_jchlQK0Pi|xL3Sg}`WU^smk&VB!>x(&Tt zGYn7uw>;bjPP6zobdD68Js!6Dx1i;{EX)ngM%F1h!0wdHLew?@5pY*~v&?>!FS}qU$_rT6z($IsAWx~WSTa#7$7SCx7v@? zlgH53>Eg{N1AaoHqgRuKRt)^6&%#MNiTHqhhp(V*v3K=8+?*g0b4M59bjQIWG*gqa zT`oa-=S^|iZER#Ww112Z96@f1q0r~TD?1O}kyA_+de_5^ITw1_$xgy1ulw#OyuG>> zsl!jcieplmaeS;|$Tw#98q?C;OzLpAwB+z;cM_A%TVBFBGWnCM^tZ8|_N2LQc--QK z;I&@muOcM((O=sm*cjSC%VuLbe5$hNTd4N0u#IbX+Iai(1m}2sx&IS0Hp#wr_^dx0 zd%Hu`{@29cWWHeYK5O>XEoV(jMAoVa7pab4GxQX}xqIAajc)t70&^oCSR|jNPC^1= z!nNtX|K@4Rx{A18d`Qu!_0&3IZRHvaS=`DOe@<$TL21A}N&44~>xcexa(4Uua>lBj z*-ISzdK}>cnTONi+B5waMJ2|w#vx~o!EfLB{nPWG36LLeX7HD&PoWi-q-SeCr))sz zA9QMs{b2aQ4{6GkU4m8xb^RF`8#CzAe^|RYe=L*0>(rt&oI5 zG+M$`*X33WJ&fb{F!8)G&ejw!BtGk& zKARqX@E`JBvAGHk3QV+p?8hv0Z0q~MaUeBOr+ocDd)eV=l|06^vh?@SIirKcQ}AH= zlVZ?wt+FWk_MS=~Hx|%+=U08TRFP9W>lMZAm>x`gKRIr$#HBAdF{dUiY*5v?T;U^j z2jkz7iIZZ98p?R~;xFkp=bO?N%^p@0RpYncWtRopSF1^15AP1W<`Z9YGFI>vv!(Gk z9n-*5N#l1R{~&#sH?HMeX}7$}4@N_4MXux5C>v^z{QRUY=7mMyvoT);W^j8~9Z$Xa4pu_Uv}Pvly>bg3x+s`1`X}={-(3 zEZ*ax7F2nP%IA;04L@cvN@mB$nH3it_v3W3;otuKiI;ck8_-v+Vmu#uTd1aY)YG({ zXSoVrwO6!iv_8JRP; zb>c5wai+1{=X1V+b=qH@39ZaH_OhT0*Ix#v58k?pXxahm5r6oFK()Csg} zjq_ygRoXZ3Es}3+mLE%2x3g8~Yug@1p8w`@$;;iRjV5dDTN?TvL_+RG4wQUg@HRMfeESo>qRY2+5M7(@ z(AU#;xr?4YLd5hohY4igu$l%OB`bb5e#=qayT)sF2%i>zdfxdg`}CLu9Lv{?@29XQqkCvp4>G^DIFm-@mHHELezT&?R4hk`Kfm%tLznH= ztHi^my0>=~>#K-;+4lKPoE__F)>+deh|i8okr=C$F5d88bujd%tNqXSRdjP9rqHgY zr`p(gGHO?IenRr2uBQ;(^uY~Wv|poEHTihag8}6&{Bl?^54$~X1*vDjzZ-1Ah2Gz) zX&CqFp8NB^_u5l5n7$bg&)BXFPB+-3<2TsE`{7>?PdPVPZC_^y+Q&MhlHND%{aj*W zu;|n|T;aq0mv?Ha!&<^wol_d6*w1Ko1k5AXn=L;(pbc?VRdmKX!W(DByZ&CvPydr+1d5rpUwb=0L!7K6(4BhU7*C?hM(7!ww8zsg z+Btge=??6k><5kI--KVZc6~KtN?CN6xu^0Fxo~MZQ2!Ce-%TgMGb5TcG*3Dh#s|J}*)0|k1AL73>#8eX2BG+;wIFPvvFvAyHZM*>&CP}e%3vWN zJ>^@z-~lf$DpPP%N~1SdN!yINsiZjd!}!->90L|JkN?4_F4*(HF9yHgV5?RrR-pz? zKVl4x{ruio(N|s$4q^s*$l%O@B=SfDKZo2J3=1Y03XnmhFkX=>-pEJELCNb>G!bvu z;cx~c{}}nGG&A=@kEmTeGH_AoFk$D>9&H~dzh!uS%|-?cm*P#8Rb$)SHpUMa-9Pew z58mw-zJWi~Snhw?kBA`N<^PgP*HP<`r|fb0Bpyo2dwN=T?SrAE^Da$>Sw%Ij zF>bWJUzfe%;u%BZJXsBBhhDepS@)Ko+r&o~sG{p1x;OZDj-RYxSF=Lvf?lZ594N4H zFXoNeplIB?!QQerm|t3l!N(cI_s{nGXV`;S<_9Y#FJT78->>)g%Y2~zLSGv$#pBH{ zf5(Uy#ir{0XuhCSNx*c6Ka%u^zb_?6Rp{`+azML$MgI4i{9h-!)mYO#biU5WJbpRw z%;Dg0X=ATCy@K3Yl1CHxQ*eUZ}DPj2_j{M-$!GfQl zj;~rIeXZrNp72}5z}#OEXJ-~Qxp2R*#|EG8PHe_ATsp>}M2N2S(3v07U9Pt3TcKuC zUa(cJTQ9f^awgeev-aA{9hEkVc3U2P8*L3_)_?1bU&*mBX4zr!wn%L;mwNwt(iYu5 zVJMbFI?LZ_J?eF-7%idu892@@8!0udf6+_hwpcO&6vI?4H2OyiT>7oEZ<3{(=kOGw zANjjJy&l1`M#nV->)+9HCf(q71mh_%LHgbBk_MbU*TR%1=oPtm;o-P-2wuWoZfTv|#NJZ$BP)>% z^x@!hC%fRUs!P1vW^>vT37;WahtwhhqM+iLqN1SDL=%v!-*3lc&bjvg z&GoD&7C0D;s#mSIuBk5^e*g>qHW8W;A5{cf4W*50b?aq)IAIABlqHae8~t@ zmN$DICYDJ!qW!rPF=h6MTd|%6kBZ+U6#NMuVWd7J$6fFjG)D(UOn1-!AcESoMcG6U~3aT=KL5VyTJ4h@G&aKp(>AWp>qmO zz$Bz(?mT?9WabN&TXpySBN%vUhl%co(+oea$nt$8V^Nhxh2IRC9cpfa%cg*6 zi+CJoZm1)r6Z^n@DXW8=nv&Vi{u7^dif$ffByqGoy#2(NTsWoBclZii7-zLS1Dv~1 zPpZ_%1xxg);#*+xrr?issvH}KL;^}Ahdl8kV!XtlhzU-FqXpki3_UkKj2;myU_2eE-DzG=mzfjGID;zBE~LSO^nT z7ivpk@4WBXQ8{OXQDp8JU?qn$vd+vtnZr*IcPsV}IxcylJ~n}=U62BobLu$D(zEWR^!-BFmN^I&$~sLOizqp4^4ajdL;#rob(JlL%heV z2+f)0q~J(o;&8h#)^$0YT-P9S7LyL@8%Babe=r($Ys}3t=wQ`n{~Kq>6WIyMhQSvf zDw!~>Dstlf{;Si8x~-lDxSyXvf-w4~@sc8od06}}I*JWN59~lypWlhluxVU#GVU26 z$=Qa*II?GS$^nrc%DHnkrgCiIOVE;Fly{}PsI-}*dFDEhM&@}Wm4*0BaM(PI+ZxeT zc$s*7g{`DC_mu#LruhV()RRySvd)p={8(vt6b<)M`mATFq;-}X5J!N>j6zO01Ncm%f+x>&ELKpBl}_(e{F2*zHHAFXW-u!b=KAoz5UXj+?NZ# zUUdA+_7rD+j`2dG{c^Ea`2Clo_v)4WZccTvlXYFqH?(Jh&+i=qvO(uXx4X4lw`BIF zr0)Pu+Oyj_n$>Fp#S3M&UPQMGEW_q%;a`$dMg8aAas58IZ>L(lj-t=XwSOhQQ6qD-KssiuTQgUQR*XV2Y#hITQYB>wLX#I(4QRjExwU)rC6w{G^8E6@yDj%{7uVzb8msQ? z;6+I`WDdqK_d}@JUgqB2(Pt+(J!`m%L=p(h7Jh!G5`prD=bZNCH^tb9d%n?qMQ?D| z^S3r8*x1X{dPGoUMP4>0>vV%j2siAtz~bR}gM!WYq=IWZn3maa?kmasSd8AY6|oGd zO_ppYRP2hp1s;)?puAj{_Ga3Ohj$~jy8dPf^`4$ft(|LND}eX(LU z1ExfSc=q$G4Raa_qf;)TtDCZ>$h|tI(_Wo$2mB-L5qx532WW|Ez8-L_WO9Uql+ny9 zo)lcU;#Ws&zq1+1mV<|jQ=2p;z26OX0MY8M0#{6%$W3P?65UW=*|!iDq!1^AUYx2$ zWWDXefP5se_g!?Fs96jex`?B z3y(FDS{=c9{_HcCSw=E#&M8c`>XL#;SmYcZB5>+EA@DYaCwOI4=4sMtfH=I?a6`$n z^Gq9JJ@7AwS7^@Mj}SdJr^L^+=#a1`ut$zW4kbeVQoYjfD8rVgtvWNWoGgL5SRT@o z`L&t%E-S2&uKeEkymG^iXTUI<__qY7$=dFJrFWR5^0*1&pxMJ0^ab0GtS2rv9mCJl z(uYe=!$JuRUMr3f$Y*=#H}XA-5OktYA~RxSp>{iXyAppB>UQ+r;?QSK`KxwMtdk4& zT8nCFx6Qo#+O&_wawq3T2wSs>ov-I!6tYf|7PYjfgQ395(;t|*TQh8*XkoRNU)`j$ zhM7=ozwixtz(S+GRU{M>c_;dKW97kAnz)pZql!FO1ZsE@gmGDSQZscIpkQhfYG8?6s}l^9 zc;`==&y>e!Uh|o^Z$7IYpJ6=W=Aw`C-58D78r5pkhJgo2oA%45jSz-p+H3BS^&>o% zy!LRh_o34>xgX9*KBE?;eAan4Vi z3|5PCXpMhz@77C5J#l}QgzE7tKbKc5>ChOkR^a`-vZr&eHYvG928EnC7*THI^r#nq z^x-&V*fD5|PM?fO&J>tuh;Q+;;b3;~OO?3;DWVm<1NM&}nY_+w*e^`Kj+?z|M%3^g{23ZXw|FlNFZs)yxbUGGWQ1)<0Qnv_h!v)rT7JO)G!hO9K+i&L^M1u{r z6Br)nXulVf*j7=gv0qwajwin^A(tmVN=l2Xn0%rCjzWcXUbJS+J?j|4qwGfeGxl!V zSEvIJrC^!8?g(pYP&>(= zll|y@@;y3juIu)2zs6PJ)SMYG*OHHll0=}0HSdp2SYYFBi*!DZ@UR&rVCD zg$9Ec94NR~J*)^P&j4X`SKe!C4*@tcHy{H_^d_)(I=sC8)uxouKz&L9#`6sM;tF4H zn`?O1bC-2F^H{fbZec0J$Tyr5@TegJXI>?*J?dRk&kniZ=2H6-S%hrVyCj%L$%-?$ z<)1klC#k1?qSqRtj^E%VNIN5Zk7iDkHS%kb)p+oWB2;gjORPC|B&Do)kTPFiLyu$ znM;Xp8IhK)(eb*SG5rtlEROQH&1?IJ{))GBxXl#2$ROTS^^G}Hz`z5OH^=-Xpo$@m zXKoNB)P796;GA{6W~oS)g?tV!57C@TE-j^IxCVI&3Stx$QXS;oJKQm%u4GE4HSafG z&TsFTKBz-}hvk=>yv}-V|)3#l9 zKLZ2KD^2Z@=g=fgzLaILJV8Mu=QxNmFn~5n0;B`LW~*@FAFn(D9t^yPX1>{=U?VkC zB`KB2?_F`fP!v2*I|R1Kb0w`<@Au&HiCNB+w&M-^zxIB{KcN@`1NGYeR_Phljs}maaKUv6vgmkqARd=5vngf&&4(H`Du6; z`yF;A@{@A|NP_ssv7I=3QXM_$m1~I^z?e%NdE-S$w{X*!Civm>?qTThUF@MBk?o|8 z*{_~&+Lke#z*x~9?*-f%lgCQEF4p*BhQBvIP;BX1dE^EHHJW&TciLKNffyi{)VRryL$*9sJdL+tOc+;{J7fOnx@7}>=`uJrkk9;gQ zDqhIsC;|ihPAh8qIfaE*Q|rEihaLOqvc_ZEVvnpr3M>|6hH=J%S11B)3|n~aT=hHh zXk`xZ7mit=$kFQ;p&wz+#=*_&d))P#@IuO*MJoX0&|iynC*Kv+m1K+T&(4sz%SKIc z&WYe+;8Chck*tc0T?J>tiez86Nv3n)@ENb*yFAdU)ke+ zJ4~ita*ep~=kruP-}*UB?R9oVxAJ&@X5OAub&}&WY z%f!e-^w!30sg8Jm_9CBWdBTGUDg$&0x*z2pVkA$hGj^=iqRm_VlON*7tDYX*sy+-5RI3YFFd305YN6e)IJ;oF7C{w_tZ@S$P?W%`uD(qXPG1kzC%aw zUcM;P?Hf+lr~D7NO%}&}ZL~uBSQO{Up8x1~Ax()BPqAhm*vuGR`|gZSg#uBlIqV$D2LQI#~ND z^?RcOHO`S2deNqCxD4u*b%yRP2)dyS3+tL{p7#@e=mBV7e-e&eLj%x~7;|*_G5&|Z zm`ynY?;t4<Ehe|+n70xi@dNzlKS#y6%Oi>pan!zASU-c8-kUG3#;BI}~_|1Lc z`SEKE7A8Drn6#w!n@i{9?Y_4mupU}U_i)?Y07#lg4mF8SV|9e~)>nAiYwqcQnj?=<94TjW4KUxJ->6g=SODg9w_s^`e0DBa->ID`y z`cqt4E7r%}v6({<@czPbM8?#g?oKCH*&{3877w~P)ZPKM66oqr`oqz?hW}-A!=69j zm+5YF?GqdQfRE~!bSB>^->=)P?~Q%@Ugt{X<=FvwmE&*NQzx8ihW67i3n)pZoj`k` zJ6oG|IoUJTtp#@P6>frnFoQZT2HI!JS-3)0ohvd$7{3x3b0@hzYm03*s+|^}_uzlx znxxI^iBF`2)H>M%Lc+1~*En3?5_IJ$q_YL}l3_l}w_9|+_B(nQcG<~~HemRQ?st|c z9{Yl4UxTgET=oo^0pnR{zF5&#?$PBQn}B;*`Juz!+ry{b14f`cblGbgXXXXJ1l#3_ z&(DyWGOy8zwW28Ah}Su)T}duj2D2A ze0K74$$lA;R&hYrqo9lmh+(Cneo&Eu+w4GoS+(0 zURDQFil2&oW)7Va2Eow%w#rdp>%{RXfk)OK}-^P zodz2A1AB`xc1afn<<^yT=<)u{-hpVef0y|U!jKYdF=I1jEhH_f9CFN~D_%UCay^_N zDCb0@wXUKWJ@my%(t6Fjxy-%s)itbq@{HZTAi7jgu8uXu>%CK-V7vM=-}&Mb+EhvN zIsk##5W8Z*S>Nk<>$0)l-qX`t3k*6VS$NiI+cNd(^zI%W#d-ApIwY4C*7!Qu9#W(NZ`}LhMp> z=v`*G_*S#BunuLh@;D=1{3%80ZM=2gM{o2=cNh_pRSmKLhoZBA(E_LDX=7}5ysv~+ zkQFq!L42H$kvOQak+8+C_)6dbC6!dI6$-s_zcIdycktX9ktKZ4{u@}6+9MNG6@S&Z z%@9zJN_6i-U~E1qT6XvCE8f+0t>4p5@D1m28~1ANbambHx+Qk)RP=#yfIfCVa!%r- z3OLu1amNK)mh6|L7i4E>6?)(Jpg2#etHWj;)yeg{B<&_YJd-aLtwXXHeSw3`u!{$9hen|9 z&}T-U_$ly_pd+(ay+waKM*j^gWonlG)+}eZSTCB9xx+ySO`PUbb6glthuBz_Gkt@P zKM`zRZLYynRQoE8yd+!1_$`76urDqg^k`+Vroy;!K#%uNoH4IP_J(eqM4rMVwc&`b zi;97X2{fxCKBPQqwp>^Tr}ByXqPMF0$9WF_W#<(V43UEvUHI&` z1etscI}RG{iXg+EbC5KtA~)j%o55=-nECiO{Luv4eqK1E+-vRk;fo(7bp#xn<6`Aq z`?)4rvx2x@feqO94N>B(8Y%(xlv#KBQ5F*omMXjXECRJ%-Y*l-nwdBL(kPz}s$QF- ze__W4jI-!v;5+d3CY9)A2EaedNV%wc=lc^rV$f>cj<1;5vEE(Rk`ca+n!n)6@c?25 zsI8dNel!#`v1fE#C+^$JbNRs%rxq8Cwh=&b2FNJ-!xl zDbIBv*PwHNE4+dav{^GfUZc-=G3Q{*!*A-=iGnUNXWwAEBhy$LRE84y znHX8d{osHI4Q~`m30*XcsM`&ql?EiEen-E?JR49XhqbQWH(L7X)nXTAwL8H9Upp5Y zT*64`6`=)Z4OGfLO#-5T#wS3GRrd>@2ZvQb#p*tG3^BZI9~_(Ll$yHon4;HLYF$wNYb>&LdiCsR z<{CQtYKLIP5HgfLu}TXLHGI$5-Z0Hb6)Tw=Dd}yG;N$!Pv&T6~S_B@05GE5CLG+da z>mAfR_|1%*Giyj@wn1E$P^*rwhAi(y>^TD&Jg8@2C&SQF!j&E=tkanNDJevgixFM% zjU#O|uvgf#{>T&E0p|K^hrKs8rTh+t0{pgZ+d3Um^FyX3dE%XeLB177d`B-ISER)iu!Y4(-0XG3C}$Mm#3rWyM9 z-JL=Tx$)Iiu9eRfrht}ELJl!M^{8fJ=75qkr?ts09=?Y5Vun1)OI;< z39`@W@*C}dFE@xr7rpM{m$K-u)NUdadFNWqfd2Ix;_eyllrd)_DhqS|RMqO(RR zzm|50Ti6LM@?(dtj6Wc5w<($aQlSMRjz!hMK7=M6Pc@;%7tv+!_|btJz@^9$BJ@mU zXI7Q1Ao40y0{7>2{m*N$Bec)bbYzc8j88wNDL(l+V9z(Nqe5wPy;+)yEkJ+2^b99t zSghE>(l4uaSyC) z&Y54L-(tUHqEipiRi~61zJ*uqo#TYX^o`5_ zJdjzR8s|{$iLnrQ>=~eN^!Xj^#i&Q<10UN6N`*Pg{t24^KS-4j$L7f8e&k)rGqk-8 zD6lIKUjak)0c5{FUR$|D_rT$lSFTSaHtDUZa})@F_(JwvBL0>82m=S|VlAlR7M>Gp zt)1VPgZ4EukAYr4aiBbrlSpdo-+uHa1^ozUx9p4qcnrdK9?I zoLC9!z+jDfOG=`LZ)J(yiavShoi>AWG$rI>mkxRj@fH1e^6?Jh7pYd{>IQh^@rfei z`&8@%LX{ISQcbd}kLXNruWO#dd@TLFTsM>JI&z)RYc{0CFXE%jTBs029u&skd*k=i zTt%JWj^Lz#>tyut7~zxdc_5qywZub7C*P+!=af*sQHu_kfe4;SA01TsA?={e=mlpA zxwbXRjaDc;Gh9aaV`k{~5+CN*8cLFWm1|P^+n2uZ`;6byI(ejIngV|!{)xWmhsHPz z-(p$Y0ECPl*OI^?Osd?BeRM=#<_Xo9@W*?YU-)Fut;HXsi}r_9bGaU}8E}CyI#iZ^ zpnoUa=Mx>V_8_bZ<<`jZQD_Pbje5>z6mYFkpP`a6t67B>oq1F#B&Y#&68%=M_95K{ z2$b_MQbtBUqgj}kOP&VbD(E*i86Ebg08|{caX789#xs-nW)l^dlkdaw1IbmufqE{(uPNS`1)__eAGytTbfSR5 zFL`eOtd(&HN`=%Bie&K}_;}F>k){0+K5{T4SJXf8dz&9#4Z=h7L>7; zyQTyMZK1>CuY4)UT~;RJ8cadKzT2mF#1dX%d@g>UXA9-NV-nIqb5e?uG1Cak+4*EZ zMZl>Y^?O9*Fh}SQ){}gv3u5z4S<+@()&*_KxLLI6Quxw%(F&Oph4sd<*{8@Yx)5h% z^Sft9e%HZs+w?{HsWF_hgEJn7=(%4Z{Bc*!=%g|1V)z4{-ez(}{+cW7q84O=oX>HK z(k6NiG8v@jcvFQ#hjn#4iUNkCfo8i-!9(ogxB%|iqIkjGZM{>5d<@yz-cZ)N?e=+BBe5w1a z`@tOG!w_P8+4s|CAHN_oV9g<|iw7Z85I!%+R?#5avj~z%y85`qd7-Y%rFP)j7ILaX zTsX6iHrRbN^Le1B3VLPDom_3&*Tv0GNj{$0nptIhC0&DSo)?VS(*J2g|8`t6-=d#C zp0u|WU+V{c>O=oxyph_3J-vblmD9zN{vi*bGBJau ze>gW#N*Si~-FSuCife(4$hEdyOM4VEjd?TxPxwyo*+>4TvTXL!Atug99|*W%+bMEy zI|A&>`6s@GrtkZ#2g?c&Wak^-iVP9R;I2G3=1w~11AMdRfK;$%tVVTsf3yRhMU}CV zad6J)709^XaiKO?6#nzQm-u0<%UF-pcT3Jl} z{h644(he&&{(`zrc#vo4AF75fK3LG*}bjI2`PXW%q{&y|t{gKU6rD*o9U&2KJ`eb$^a#9HvB zz<3Rou^Cs+D!45Gv_Hmu=U=<*jd9<)_Z7e21#V&IIy_nly3n6BKS)^f`?lB=T-OVR z;n>iLLjqlbl~e3Ru?5YvOt--_anx>q!9Io$C=H-s7&`Ax&xHFkh>BXSLuTxPC9%Z4 zcYirJ*p!KG@0IaM1Hl@NUhuy?yUvX;;)=@BypnCybh*d?nH*d)M1K^_n4z*tf-RPoJ}&Hmn<`k zi8CI7^dbIBfzLnJGFP0<(k{?zbTTR#b1tIG7*!wlgu6DVY{OO~zI8cym+J&Q`(S3z zeXCjbB#Vo73lynB9!Xp;SYd-z7(4)ToHXY#duipJ*)Q+ch?<&nNR97!P{p@sfl?Xk6HnGu+n2F=2X9ow z31bAs!E;~MidQy%$VL`W$_E{`^#90>gM(9#ka(TW>4PzA?j$6EHB?DvZg z+X7TjeKMLo{QzD2k1RqaFh;WOWlh6qCq6>fCg;kQbKg&qVSK`O`=TqKqt){pIr|-m z64AGQ&x1qJ{WE6Gz8JPqfS7-#8^$X}UjeN9&GfA>Z9J7(C~20o&A4_4-)hg28H zp2x@y&8o;C$_xB|Kf03YBSR3ET5q#}tihjFgqxpT{HVlk)C$8Q@;IbHJ+P;e&lQ#z+@*3Im=?`LNfCk|b?2?FaP9UPfYEK%w(hCJ_3>fQZGm{@G_wjb0P{?#UYfw}mo+-?P^#v2{L4X=E>u*)UOekZ{Sq zu5RX_es58>f^1IyWAE@>-^+YdAV>7Q;!xYMk$wM>g{=K=bjD___;K3k2Szs~SHXRS zev{8n&wpiGBeubm%RlfAH5yoR#M&l|fy+znxyYSo>)t3*uD-v!e$SBd->wsxEIK7NJyfh z9hcq#yOQ|aGk8x&zrbJ;aP+b<)weQrvAtrCQk({9M})0`j0g?8BNd79|G{H>?_w_= zHDth#?f|>Af*&GiL%0)!O(JBC4SR{P)!_+XzByyR@tY&3`V}}gFpOr?7n%Q*d;$D%z3Gok>69@t zrO~Y$U8Ku;WoJz^gScsP2IC4l6s@Lz$Su;AG_W4X!YA;Vj*bKrT&dXy$9{mWhWq0` zZS0CyHhb%5eR<+H>}eIbgCBd1+}}6yES7a;@ZSA~2g;9*T>Hip({|Og-CQd^ zc4VLI$>!R)40Asp$}?hLI{N!R*WP&lD<{!QSSQU{VbBJS4(yfuAK3(xwDC&pkT}5& zL;rW*fP4Q(Peo*s78pRq+>8k!5yqtRM=td^6EpXjgWy+|b2f^c{EZ0oGPWK$o9vN;f2f}cTi)HusSkX639&=4u^X%@bBfOz8(GFaYxGc=(>8M& z82AcAci=nl=zr#aESNvAN56B*LCv+YZl5ZyCHmHdu^t^9-2g0*ocBNagy?4BozG+c z9=Zl4tU$<=T|IyNtc<_2}RDlRW``v_ZsyT^_%{{0|1H=dJL`+6q2Agyye2 zZ0ve`U5Bj4UVF^%O`DcP~)SJ*b_CPJ3_B(L$St1W(`Gd|dij6HrCUXMHMQLd+EiCoW~)cP{^e{Fm930$=^ z^~V`_+YqLEXYA{uM}|af8P*ScP`#@>yTY0BXRkc=u_G(W@c<;M4QM>WLXj@bD-Ukyr zoRYgz>ku|^#*;KjTm$DYxhy-s?N6~CxMl)g4Sz+OzoEkeOfO@#0*e_E>neo4R>!|d zuDYCAdtxfycu_Xjp?k@7*!7YNAru&ge#e8#6kkbC@@4Rq(jH|2$-c(VRZdAWCI!hP zy>&bZ3)6UggTX?>ckseB!PJ)Al%P97|JGx|>QmwtU~>te#s$Tt&>P>_KEOPQ$Lp53 z!|}Vf6iYv3vG2m^^BHa|FheGKv;T#zcK{;@4o|&HpUy6kMsl>AYD305V=)A#WKJjE zcmpkgnNMuEIPt?adjlWUt;ri8{^Jfxa5@^kCp#NV&A zv_a3IEbPyH@~Vxp`;9DC?ET}+N7ceug#Cu`lrzRPJh?B`S;4r9&-N2Pbk^-!(PU40 z3vmEPql%)1mmVF!;`zW+pidH?#n(LKy(gG)d^EC`WGtZ1DxqZoyrI>7uyfKkC-I>q zXr41KVB*G?MK1L%6Cd7$3R`j}0%29JRM;EebxI z$W6GO*ql5A`v5GDY1&2sKDW6ZjJHpo#uMyl-o+P|n@9V6?)CLU3I{WcX7hyjRt#f< zIsPa`cV|bygr#i?-OY>UIa~0M_^=tv-h8WoJJa`e=zyy{(&p&X*q@b`syl&%j*@ub z^IOc! zYx)NEmU;stBkPCjN0jf@Uds$!eu`?%nKD#n#DSZ8>6|9mL!3c-pX#Y6pOOJahvS>N z?%2_MzM{=*4_ct9k>kIJh`cg49`i1JHIO0r#aNpK*%}$&k8w5wn#<3MXUOxy3qs84 zqJdY7!dA0(X;oyOj0>^mG4XB#a!k7!Z!`^0*w;9h@iVsMIqm3ouEDpH4T=55B8ImE z9ub?^Et_=&f4Y89`#W-&0&+`Rf$Whki>{5nx;?w2cA|t+&ZT^!P(svv%y|Vi$l0>r zaj3|dCwV(yiXQE;0Itpc?U~u9AFUO9swce!SHMXT!62T5%4J5SH+~O4vPaq^~XZfW$8&a~sP7hlH&4ys|b3agX(q9Rck2=HSt!1vY z?=BSU6mTMsZgf|wT`NYfCZ`l%0;Fin*Tk4+&nNV*tw(OZZdcJDC-pt&X*n7C^9+O;pf>7OuL&BSN8H=WEzJD?`i8*d>cQbNE3uv7ens+7oGv@3E7n*j^Uk>EXy`r-Sv`wod zzXlcCD$P-M5MRh=o+g=n*xtbax&EvE{MN$fF;J!MNqc#)>&YqEEA9EePxh1Y^UAghod;Rh!GeiX7VRk+%@=F`uS!i<~b38BzwE_uQX;S=|6@jbpu3r<6Vaml@Fio7<)gaGC+7R~_pMQq;ZV8WC{+vIlm@gx6+!RsYjfVoIDuJZO_BCvGk_o6&llb;+zWJaXXT@U%g~g z!YxO>+(Ot~+vMTZ4w%by>aTH4LA1;{Y~g!`y(T#V>f{;2FJ0|^5HCk!#HYY&lh+Ey zmt4mY<0RqzL5@oYbRh68bWnJX!Vmtg2)pG^T3(X7__0F%Z`o_$U$i$nj|bGg>lkG@ zDfUkJ3c!^0P>}11O=6SDe;6Fl5m-i1pfmRK%@K3?-J4d(g#|har7X$QCZ;Fs*Wj?2 zz?W3lIrGffjWgg0ZYFTo20_LNn7)qUqW=>9?#_u_lY4ksX?JAB;ANygp>IRvtHCll ziYA*+Vr0NuWxp>x^J}4A@Pn_g*I^j3bD=&!Yie-UnuWed$)WStT1S5%e;s+2Xm@P+ zZ-*QZ@)$<=vU}v3Mn&}@@}CK$bLLYC)?ulbWP3<85N!K)sU5tI5AF}V2xgF~tq!4a zVD0v?nU_;5qJC(;)z?i?Wmjf_KIkuINpv{SGrGmA?_SQ@Ka_PXN09XAGn^#_sc0)A-kKb<% z(BsvYLXW@#`Nz(H2j+kLlGvQRGX6q;Pu1#NA#7}(bq$t!NS=I(yP=f?+5(OUB|gEo za(cT@;5L}F6@MHlinEVXBGPKuAxH8Z?iE%*D}&+~X6#dy}akWpec!3P05qTpnVl)N@5 z$R&;k<*2G9RZAEqa)w*oy%4iTu{-0>NHyfMoV9?@RdfR`U8YcDiR=p|=$Uh$N+EV_ zeMt_7d*ZNE(&mx1O(o zr6vY=(D2D>2l7lU>WHERtDDT zEo}LY55`r^iB*3BttGEPTRCvNaCmb9U$jvFCiljJd624gFbwwL>4)=99lR5Zd18K% z8R%~NbLYD(&giq`s;T#PWXLVswiA$zoB8p+oc7+Sh>Wg4n2z64aG88o!B#islQbyRpU7>R^8DUh?DoY~e9mybO2#TN74X_#z64{|E&Z1S zJdjgf$anBNT@#;HdzVc*uwX<&7jN(xHl8w2y=QWGTRR+igL?SHTw||fCGRP9)VKMI z>)#ph_vz>JsqcDb^Lwx%eEy0~-u2EYILDVX~ypFG1)?KI1~)n>7WLoaYRmX7o-KkW`yieJRm0 z<}=ZQI)FYTr>9%d&O1VveE`uKcA|RTtM{Y~rlz4`5~%rQr-ksQohFyf7qu%o!_lsh zb7LIha2cz9yePvDCXTsTTRCuXE!%KhS(8Qi9f9|4i(Sr{V5g;n@szZoL~ftZofGOx zT3I%=Yrli-RwYXcxZ-c@P_|lP1^5y)i8*=L5Gy+7=~YYCl#kY9w^)?-^9YlRRN^3BcD$e&?F%@rLG7-=rp>soo4z z^56xwYLREh*~0bSR(y#E$$hGN;?LlH>KFJTJ{?%hIa}*J%Qu9#x!=KsV&Io(az}Z+ z_BvpTb5^BFauQr3IBAgw4ta~$-WL2AxmR6$dY$0plO&KlYx+K4*2+z@%^S5ob3kD5 zVwh=KQh`nLH{&Vv(`;{Y*IXO`h5mLOs z4*@O^>XERu%1mLEuyx2&0?vCU zkBxp=dHl8#KW3w=|Lr%_+c&%3&}UEX>rSaimy$=kAUGv$lPmc@+RkS#WQ-_`oJn5u zoZj^Kskc$wcEp#(LjM^T_?W*%*TC*TQ^eloIXN(PY)0uzBe&&RL&2e%y^lwS;mIH1 zV;?yN*blDVU4HuAaK$v&K0#Ft9=o}g0&(a7t3{r`+hx43vKv+@yb028!9r_C<1shr zg^V>A4feqWUD;oNgQWTnlztOImvS+FZ*%{FxtUeM)TpNc)}3m0w9e^82c6`i%|!x8K^f{2p`v zV286fODrAkBb-sX_E9>-wFl@nc2s6N|Bmp;5`rIDXNO-~6ifEV)|)9VaTGZ_%_HGV zzZU}}AzzRjiU0!*c`L3%Z=WgfS)7LO*;PI?q2NDy+IQdKx72;4%#TsR^@|S(j~_B1 zMp+)H&&9wzUI~U+n-S?DaHtJX?gx}#w`qW!iRIfCge)id&fE6kPO^GbmBYz}?PW;7 zIGgz-QwX0IXNY-$FM>W#g;!rE$Hg_$>k9oFZ0v~hmUTj2o7XiZptH^(oh%sA@czoYRi|t z^Zjy`9KE#{Z*yFfLI>qr^;|humz>C&G3UDZJ@t|EU{Qz*-8~mPt8isB>j(ceT$hec zl)=oas|Ft73%r@+eBa0XrZ-oIWAf+xJ-CVydFQX?JMr*#_7QTv5m3^-O8CmGfuF!( z()JekJvh9nNyzBnAbXKh3A_0n_FnLqI54RIF=Pyp^&tjEc$ypK=(G&Q)}_LMTHu?B z>k+d;Fjyr|7F_;t#hD847q*3PU8p@aUe>{AKp#l9^i|o4J=hU{kS9w2)|cP|zA4rc zNh87|QUS(na$2PpFY@Hbxwe7+T50n=*{=j|;U(d173Oi%T-^;Sm$$&G= z#W#0H?y0>MevBdI;eW1$@a%EzifencgC6TF0+F4^tImNArY7m~xv9^tjjx{WExiF4 zWvkSO$VleXUtr^hZPP{;d%V|Q=CR4&7vk0>e;3|@X@Trr=Xc`i)#UdxV$$yu@;>o{ zgYV)Qu@iYl$@B_aM7~7$Nno9l1xuozJ6?Pqi!IL~> zpZ9RTSmgeb4k0}0PV%Fm1%Y_8;o3KtbVA`N}sG^Vd`3XuEjysfa&XF9nJxy511`GA}R z>;tR!jU1iLCTOK@0AGycPYi04c0+jNgp?b8x(RlZ9(So-0X{SJpa|`$)4pblKS@EIv3*&4qE2@~0f@q)+G+{pVuJz~yc zd3Kwe9=Wf*{w#+rU`>6_Mw9bLY)NUG*XW22ZO6n;S$(<>x1OV6@h9ZOcKDt`s3o{3 zzMm-ae#|Db$*nPjzedIsKVz5A(FK#Mote_6L`b&eOwcZPt;w~`XZn$}Oa8jlFvfpf zaKFeHX+<@!W4O+EPb`PFvH1$A{9_qiaxSKDY|#$>Dcrot3T9nmy@C_vz|^FD`Q2Ii z1#5qJAo=}sD0wZXB!DB1ieaX)ojP6NCgk-E-cj!RhAT1#0~E^GjL8dPoDx2F#J+HT zo5UXQlQ+ zp=W)rW6d1O-)BH(Kb`^o8(l=6!@S$Cu=O!U7tNtD*KplxT?~AeGIb*oCoJm4J z;pywGWiAPA(Ppz}(7#Q_CgVHWKVX`MgsP@+zBGNOofiBwwTa)rocOd${y*1#=EM6H zUp>5wUePAcGSw#C4hj^uq>PQ(%rASR8Gk-EzyTwyu;U)>bEeDKNc*L<57#a2)1KMz z34X-GgXDYD{vfPqf1^#5?t(e>Y6-Ey`6&Af|Mz2j68Z#QZr`Xmr0)8Oo zo{qwGhW5-RqJPAN(XSt4#NOrkzIv3C=uwHQWUuo6<3(qT8t%){8GR0aD$h<^JMW3M zvp4(zj90@k7|T6`-)=AQDCxU6TnD*~Tg7!>FSN}U(zbl|@hWXM^Ez9;5|1DTwrR;M3%oA;rVcEKmvo0k#~gO{o%?nL?< zGK${;uTclrP6T> z#Kcr>to57NIv%U9rq5aPyRHZHx%nOE=BDUM|B0n!6@SOvQ~2=|`3#rE&$IQxW;UN; z>XV5OO`qbhiRU$a45YF<9GQW%JGY8I_kv%9_Uvymd3Gx}0XeK%@DY>hLY2nG1~;z4 zJ(HW*hh9f~3m*2LsNcxkmiR9q>>Z8gU1LgKHFkyA?FM=2_^9{|)dSFN2Alys`1)w$ zEWg2zEIzDMzCQu_B7rZ$h8cim%^siW=WcT!Qj0hXe}8kq&(HTEx+A$q*7c?x`bYH( z=t3)g#-GQ!y-|nRKT(;e2s8L?LKXk&1;>y#K`!0BIew&k!uz+^1A4&*KH4Q1B6Js3 zpI+a!jWQU)-@s=pA6_2gwLK$07kz0XZGgwReuF#c@%_f*d+`_Zdi>D^M;$68qrf?t zr&lloLp~=Kj87J#C0;M(dd^a5TW|WHIo%PPWLP>}HcxCq2TlyGs}IXwuHHinL+=JZGqFmp7kntc#lIfP^H0$iRPqvuZb31)Ah(xsd(6u- zxFzSqH)K5ExZ|PvnV0*9Jm2dbAFI;L%d@avoZ({R`Q*N|<=|hX?SSh)@K2yrz67cwX#k<5`-Q@+N(9KT3(w$_KbvOuxlkeUocW8{;QzOY$rX z;$+lFq7NlbrgHx0zK=ctt-wG1G%bj)WaRL1pAGiupZl02H~YCy9{%S(capyU-X|`8 z??bjsdo+C9lh?iGHC#YY6s`?*0q{-k>1n;PG{Imz`nkS-7!|x`!n;qD(73}b!uT0J z^ZTK=YLE7s=XSxwO@E#N?AV1^+wZ+VMKa)gC+*`O+n;UGfp{j{%vT5W$Pd;+PvkD~ z${sQ6U|hrV!Tgx&3qHWkjP*v|aX=e#9oQzY@bZkrKPmUY;QrnVx$*b8oRby?OFe6z zi?wt%exD0Hrsh6l{UotOc+i|GCc+CFjpt01x}mIQ{RnfF?jp@@9lGW{&P`)8DmVsK@nL)`DIRDFQ8f zr2N7b6E$PT{srZN3q|q?B@g*`9-8&-fV2Ny8{xZmmwhGeoA85(eo43ruH#PD|7U*4 zZGX(q{Xg?_I|qN~$6XrP-*(~b`&obN_nWmQJi0%x4Y>Hen|W*YkgT<_Mx7s?g&g$i&$v<6`uz60hdjpS z@0uR;tfv2u`;f)^dmq|P>)$oOShoJIY2P);4<%<(BL@^dx8Jo!PQWe2UNuK>y_Hgc z(3F}bp)=&{nFiQvGH$=;++z+N=bV4|J8nU9{>d|t3r`89?ZlLE!*2y4^Y7Usyoi%# z-0nYT0+#d7nNXt7Ba=Df$eC?kx0t`**A$@!=(!X8D(A+Zar0;9&mR2_|6K#&&l-sT zy9Sze!}i~Mggt3jcme;;Q`)@#vp=zAAN$jNBPREA-iR*;ejVALbt4;o&uXx-@|nR| zll{CF{)6J*GeABE|3~l~5()ecgqui8ojOm%@VkC1yKnT~>4SJEronzOt?)Kv)Fp8^ z6Fv$!Jgz~s0T82fSZG785kD-oLa>F&FEI1vD>i`Q^BTPSi$S$=I97UF!TwAubg=Wo ze2g@MQRl<+lWb8!QUj0hC1^)JA4Fva`R-tug!^c-L7f4MBDzige!7FzNWQ*xaYNrc_?h9c#>Gq=lRgh1 zc3a|Ag{S!|q5lJSnDIIJzHo1EIW{ZxfIoA_Z_f(VX@tWyt45*M=ZwQ&z_{`qdUu!C zYp=z1+1ZfmuoLb(2sqK?ntPi$xH)*7!B?pI8*qwscSLZqxvmlpnS@es(#MP(mdVRq zqb>a{(cp<2{9b1sxGvJ)o9nKEiO#^ffc5ja4otxBeHT;m^EQw6$Y12XS>0gLj1lo< zS4y5?4@`Leud^H4&(l)o&+Tau+~|?Se}N-Vf26ri#^&2Q<10RgrHq;T3Cmm;90cQ( z!PkGj4a;Yd{90fZZhe$Iiyk!zy@<~PKG*j>XUylA2pWD^6N3`$_nd6(nes}RUDdeU zB2CBUBas5^k&@+LV~x{@T6ohyIKTOrk~q*dIS#>euw1fUUjq7a+mK2?f1e|&t8J+| z%@}t*C$-e$rX%EM#Hw1XKjT(-1NOpJ#tR-L#^l=b8E7B~CeloLw+N z%Pj#PiVctV6s+y58}_lI9Me6e`ZLvaF}%&3_r(kj;W0I^I>|^AJD@(wS{AzkP734| zMmc^siFuQ|Ks+3~V~}%=j;{;doA@35y+%HjZ_U54aZ{PW1tWMu)JtNc(k9$>VEpU7 z$m$+#QNMgeVLzT zFbq_PB92>pNZH_GW*Twdb0s!6@pP?6p-PGCvIc(c@jThG7M&N*Y2v!X#lb*eb0^)- zYua4f#NDOulkzfLIMSAz(N-6EKwFJ%Z`KxlZ-e28I7RSZgM-xVvnIJeLia-dU1`6G zyT7yVcjQOE&fAavH&~7f{mY~8c9r${3=zOZ+FC0DjScBcUyIgxx}6t?iJZjJ1_NLIrFL_JF<8} zjFV&}*ZL{W6u`Jre}MfzU0x1T0ckmP0tbxhVyv6@G6BMxxjm9Lp2^any^S}+kNg3@ zG4UiVn85Qsyi!5C@-8Ig;|(Gohq`~?Q)&=|SvGv8&P4UqU~_%)iO%}u+{D7J>_cL| z;oZ-uOf7gZk9?0qta$CR)+mLm?o3-7dUQqNpeA(z2A}qwx_bs;066|cwv(f#?GtB= zsbgD_|3)ml9Qlbwyb^4T#6&l5Bz}HS*9qO`jX5o1-_E+ECh6j z-|MWY1VmdAf$;p{nyg&QQ7NpAjcEWDPZl|m;&+L%jw>SDgvp814P0-v3JGZ_`0!1+vK016XB1xOue@%BnTY5#fAP>_AOp(vt8HWz4nDp zi4G(_PM7z>SBsBV@{7TWYiS3aF?bETvEZxlBNA^5&)_79OP;~;JYv7_1b5Z*eBH!x zZ@5wV-Uz)MG9}#8tMf=Bu4(7Q*{LJJGZGWi_a3*}ja@8XUs&m_? zeiHd|d@pfd>k~p)UQ>E1vu}p_6Zd5R;5&X?a{yCJn-crw(bMJ`PrOdSb4aXOo^ghs zguco(Ke6AA^i%3Z$TLc;_*TXeE0#V#o{^Xx&q&Q7)&00pSZmO4<{w`Y&zU{st33a2 zUjBxTU*a}BiYldnS1jNkc0%fZ`BToqOF=AB;#3v)&!~@R65ow~ z2to$@M4R056f$2qW$Y->}3`r5$hZ zhHu)!9{!EF^O~4Drg85$fFls>l6v=O%dSs*V|pcK74r;BV)|)|YvJctY3B(tF6!3h zaT7B=G;y_ryffnNa28|3`@H^%yGvWCJj0r2K!1~GXr4(QK(;CO(wu%E9N_#$v))tk z+~9|dqzXCqch=Ozd&J;h`{W6EgLRlNPOOnXv3I!#nrQe%>algTnt2F$|4oD)|1i%0 zzhFtbt3bvlX=36P79nwck99^|<|jtWHQ?6EiAFs{gKLGOOXix-!Hmk--_~(vImj@c zgS^)8D|HwrU&jAPyz6H!j5*joatNDyv}G-WHGSNJTqPMN=Xab0`^aWt-a~=ye<@Q6 zJEQm0KJN7J7CDEjB6F<;J=RiI^Lu~QUWJS^V~0MIrzP_WE}Qnp>2!bXvH!sI$=nVNC})vVjOSEot~21*r2pMVox?Z!2wuUE zIc?(bzjN^uhi5JEye=H++vtlV+*js%JRA`(xDbbL(WmjVU^rxq!Y67O zCd|Ezv!^T2zyGeKMd@v241U_a)ApVm!T~l7*M47#0k>qHwE+HSX`8;wdw7R5Eu^3c z{4wwb;6dAIjfx`LmaQ`c7W1IgxcpX#|&XY1(e)_qT2OF7KOtteU|plh`}V=zH!H z%6&W)%$M^sHo_b8SU;1iJe#XzXMHnusubNqo(UX++@sy><-hl6-ou3?`|P(4X8Ung z{C{&BV#I-~r!9~9TFW??GENgRJMq_XKlBBM=Z>c6kiF)bQ(>8rORcA8+V_6;@dIli zbB&GeoSO5FeH=!2+L!&mmglx>$SuL2@II0`)EOuIGcs=a3z??#*5ka>?@uoHje2o3 zW3_usG)|5>XZkL+wEB**ZzJMYy*P_j*t+TXt0p6J*a2$+Y4eXZyt) zFeqJrUt)Lj{-wWAcf?i5{q1M($S*tHgl|ub6}mg}sC$*`d*76H^cMI+j_4(PZnSz6 zxa1Y`Ywv@JCtUfF55yficN@c(wVf#7vAfm>dpIqCp6|QUd5!ioYkxPs4-%N*#-}_O ze*sqtr-VN~d^3jl;aom^qn%h=R+Re`AAs-*5n3cRh|2%?ZvFUf8vr$@U>h|xfRNHJ z_+9sw)T$G|Ke2J^|8e(b+o@~Kw&nx*g4%{l)d-M4fT-?37*Q7nTcD@UFNP>Q=fC#e z>zr(v*-pBxHBm7%>as^fkCfT`{tZ6T4JAuvI}Zp@kZbA(Utvr~$gj;jdBvZgj~qRV z%XSDU-)C}glWp8Rt`R0nDlX*GY<-pg6!lGDm(5@4i(ER3%jsf^E8mL@+?XYi`kwS1 z_NH!@zLHlLp6MeS#$Ms^%^nV)XZlt429NYPGbtIlifb}mpzljviX5D#g%?FV?>0Up z4UPO5)Ad}(1%P2AM~2u29Ul%)=RAiSI0@OiF6CJu&&oVA;er367PBrSZnhl0(z6j;qZm^DSZ>Z7Jz}U3&uC>>aO~7QH1v1gi1U*b zjLaQ;8Efl06->3SbCz(?%W8ZHCPM#c`o`Eiq9d}yCmLbEV7i3KCFn)qOJ((=W2c3b z`b>UY8|_2lAB~{bumg^9qIb)CGY5}p7fiuEyTsA=(Zetn<2`s#G(?5a4l3x0pwH$= z*sBWrDbC$ESD_CuKs(?Ej#RKXbh8^M1raA_8nGZpSBSidy%_IxiT4M8<7gFW9k-u_ zI~NHD=NFu|Z|o@v2mb?Ye=e}YxwW}xOsc%LhoqnA=%k@_(LC2BuWgBQe1`YRDzp;g z(~8L}EMkyAkY{wF_y)Zj$%~F9wU19UbA?@hO}Fm#H6VOVKf&v11RIHyEOPjjy5U>2 zICzhL4cZ#c4z6>%6Z(OJWTVQz;oSO-U&1zg`Y=7U`5qa#EotYazQB6(9&ISLar7eB zDY*6oyUxk#>=9Pg3$jiNPT0Jpr_JEk4L`n^pvwLU-f;;Xyx;Y;PVE0~8~n;so?mSv-<7Hjvs&xY|2MZ} z$o|T@eq349O!)QBXPdL_%&&F3S{jYb+5p8Kzn&^Nr0Dpd=PmdMxQ4Z?7SYKzn!D{c zzxd+2RXJyf4XCBfIO9!eCBktgDYV%1+^@@WExZdbu<0631@y)k{twJ1Sufeoe3J$~ zFGKwr1hQ7OIiHk>v)h8Bgx_|V6|jcBAPa4W&4aDY(0_ZS#m4NO7QVlcg)?B`h6j&m zB%Xo&R|mZa{1wCJ-va~3c8vX00UKi7l6^Yo94mg0wIC#`3-{;l%42)neKZm!tBNjl52)RmOS>?ddQnNA9UDW}YP zeU+Z+Xsz1EvplZ()vj8sV^SVV`!Ro>Nse!5cD2#i*Cg88d*5ks*OzaWKB0ovNrnOaISu}k)O3u!I!?*E5oHqANpop z9a>%0F4C4k{oD$q&844?rBB|kLme(6)t1llB#EDu$LZSftKW6H(Y@crU{+esJh2CO zleOJG)>Zvy{d3)oRD6VS@DgyAy2`G{>Y8*opBHla0)8V^Z+IDvfihM94vao}arAFr zd;3{STbuD4{myGrcgFW9FL5LPX8Z(b8ob&?lQCz^ZtaUF&v{Mqw?@-r)yHp z+Ubn1jbH8L-P=jc5x>R~wn?wC)Ut;9>DkBsV}6h_2)`&eAUl_AEZuteTYZeBaTwjl zQt93w;EBEWhcRUQG9P69q2^|7_IH0UYs>K2zu^f-|9{iX+NqPXwMRRNK6os4?nc{C z?b|UybF}1Z#rmAWlld4+@SYrB^HBU(n~6VbRSwtOUm{e;T&X+k``TxCkzZqJjc$sJ zrFAG{=~LQY^>LJj)HqdT>bdZV-#q3}IHF$jLy$ceAC$VPJO9f{RNa|h+)NnotamWiQng;)vG+HyP`ZuSl%`6*Al)d-u6^po^9lZcGgCJrh{)k<~T7`jz&gsDc_tm%E zq_gb@I`#?0Q<(JHlWj-aryk4p20J9TA2fTbctVr`*QxE@4){qIn>;wB*w}%;nmx7=b`8>IPbV_&Ldy}Zz6=TR&l}cg-w8|eQ0KZuIrZ_5*Tb$hXnloMI+J+AI*Epc&%d3Vs$?Z@L@WaG!4ANl4ohEtRxDn-y}@$gmA^P$ zFRa0N(xWS&vT%J*^p4ns;X@u-GaQDTzr`<|IW-Q*t5bx>3%`l848zTft+V;gb&oyj z@$KsqJHQsg!Fu{*vyd4r=5)qe=P-?1c-2Oej`Jb918m3)dyOnAjyUr)bNl3N$uy09 zTHbOwrxPC9yB=~ih7tZ2^%H6AZYSX%?5yEh7#$4kD-9N9cbnj25w4GIIM%mwh6eHy ze?3Bf^$_%e=5m@ul`lOhNSwI^E>+bV4`4E8<5a z>RgS@=MP+4;=i23bTIx}Rtu#8{3q{kOY>3~X=_-@A9@g<;2wK;02x!}7yjw6C7iIY z`(aL`Z)%%f)CBCo@2$u_-l||9dI$Ru+WgQ?zBr&;Ycpv3$ggu06RwGfS^bRVgvj#* z!#d8rt_2&ZTc2BbY>oyWv}F;Up_vc0H5jq@+V;4&@Ex`AOPu}ihsS-?w%Jo2oQ;O< zg7MMBf6Q)!-+9d_F_soM$q+of&YZW_bIhYJ+z z4rz|L#{AGRmkk)0t?NN-aWg#zr}-YkDyKLB8sGs1Vv7z*&p+J4e7pC35bPI8UUG)V1C-w0&m9XzeqxQMi(UCH999OVW zNwWI5fSnlBMvd zh(0U%O#wk`Oi_Kf{4opKbs|!s8IWs*!;n)*x;)wekY57B!=LT;mhc z>R!g3us{N?IW)Ka7^E(AgK)+2MgL)zh?^!au|v$`Ecj1724kJLd<$1&^UcDp38mlu z5YfLA*Z2)BPPY@k;E#XX@tfwHG&*4XV#EJ_hGOC3YN6Qh!;SlshN9yj zgnO+UFLZvfw@&KRZ7w7V*FvV-h}rVol%?jaF?!-J&0^Sd^>9SeOqMRncl(%7CWVHA@cdV zUcy*p{Pi!+_4ue2c|A6&nGBx&gP!RE2>*2;`WA3A;xi8Uyo&nUp9+q}^aRnG(vOdu-$g^=Yo;Lpy7CsZvlU4c< z5-%%K=uE7!oR1IP0A{rD+{PzcSMT71!p8U6d5n!ec9#aW;A}z6V%IUc7s&5=PWeQ@G zXu_x~@x9}^Tx2dPzPZFFG9$2e&0$|7Hj=eXHxA)V4#iHfgB(G_hRD@Ko+N24y#j__ z$Rff4`y&JF{T1Kj?H(V<6KDHh@xe0&mof3#?5C{$Xsjo{(;DAYS}awZp_ELE{5j&I z+n2nHou+1={`#i)_=eS{j^E5zT`1d7C8I)bUfFV|~(A zGPODF43mam1!J*{&#SG?(@*Uizc#Gz@Gl@M4@7pvRuLRPWIaE^H587}OT(|X>`N}3 z3)mMR-~F&cYtV>4s7Lw6tg)u|x*Kb7lxKg*hm`ziQ-lx}qgO}>Qp!!5cUh8>U$Hl6 z;TJq?kGv`Ss~d3+{Z^1Of6h6rIIQUlem@#Dc-<;V;I+EFvys;iwjaaA^-;}Xt5Nb> z)tsp{5A>-MR15WHauz*x7`GMPoHGCVQDg308fBNU+&T8#ur+VJYkO=!m$)}+Sqk5UIjk~ zzZwfptJ3QKPHTiNTHkg(Chk{PWXY{mZX6l0m$jjn+S zMg%{Hs=J^p-pucHVpHE)-67XkSf=2#<~Q5>Sv8AW)t#L})BCC~g_A1m9>3ILH$mG3 zq$6b!-ePH^DJt^Mbgbf|jo?yI*fu&i^w*CI!+3~$Zi!p>v%Uh$t8`hL2J@e|euKU{+k!ofv)*fgE`C__E!Zc$INx&ycWv)|2wzG@d$cRrBJ ztKVgE6sfxqn_TJuu07(0(;$#~TYwW%CI_9f*I1h2BjP@m1b;==0T06)Z3Hj5leM|3 zyLs1rlnL2ce){YitQFE{WxR@B3m6~$jy@4=bZ5Ddxhd;+(?j>(?@E8q59ndM=7+wU zVfWdb@R}d6BOy*W@MO#iwkmzQvpoiKA41H0HXmt;UJ+#i-`>dlfR%Z(XVRAp>3bhN zHBPX{k$Uy~Zd+|y2=NEZCte2+grL(DG5_zK&LeLHrR96JR^#n zxt{rRIEp#vvOeG>0v7x=Wm+$|f5$zOvU|&H=33+f#&LCqQSrfz_OOAw2{&1K6>DNAO5p;dp0Z#i0JfBgkmmrbnzGCa> zn3n0@=uJqtCbEJqXY@jWeQ1$3GneZF^-X(%^U6IAm+_sgmAP&%jxg*jle)e~SBvtJ0btL@eLeuqU|I*RC6h9)A}k z4d0Zqc5(rBZ=TteG@wX_kTrg$q2indA+9u38dCPYugc*{*)ef-4c=SWVEEj3J^Cqa zVjFL^jT|Ab%6YqptzJRinUbgd!FiM$-kS5d7n*zvZ;!}%Rk@Ek{$Swwr@rBPtpt8< zot)t=e}l^;&ozE*LJkhOH=M~MdD?V%M!kFF5AVSP!dwR~gwhiCUvQ%16Ksjj#g`d4 z;8Ex>Cq}PexokT>MV4fXObrlctJxJZjyMh}ySYq{6}b!8vyy)S9GTPg*+G;600Im9 z3ancE>wt!Py&Llj&CDRpc$|>IYJ(I#_M6v8U~7TThodp`1XH3j^b#Bwtgzd@yoDyL z(|#~8fFbd1UcGk2cWw9FkG?iD!7|Z31D^~b#PO2vG-&|r&y@XKDqI%-VKZmrWbg-+ zk9RQEf)5aWAzT~A)hjb5@w3iB6Q&^&v;!4Zcu(lxgvspi4(5ArgYN_j{|+2`LKDAqX zX5O?-$D5#AeQ`b>{h|6R-w|4i*zb^U&eVzP!$*5n;lJ7_8fWFKz50*ID8#>1gkOfp z*~Z8%&Ai@bY%o6BTpOH(1eZn#mv5a1n(>)L^>2KZ%R4NzTJWDD*Q_43%eZNz&4INZ zZuIzVZUp0=KTRB+Eqr?&a#jY8GqvLCZe!$GN2iRv z6u+^%mN@ZIqWrsIk62@u!_*ovY5|De^CG)Khq=d|G~9LI!=EePI(e?UtWjMIF3%-5 zJsXwc(|%5ahJv*=-^{VmVGiV4 zg-z;DTd|4QZ#iplSom=TZ~S3h1DAypLSet3+rV!ZW5wV4z#yl>6Kd?r%#T|}1RU?& z2c#Jpdk71 z3K@?q_t@HD?@4{=>2f2uu3*2+Sqh&eEJTL!!){I4Rz#TPpkd|&pt#6OaD0zw~X zyGi*aL-2_Mo5$soPga&Dr@78|KGx?tyuRuB;4Q877(HZs&uGifO2 z!kEdwjUEL)Z@`st323b+ksILaN97+(c#RKxnFh^64?B{T!lG}C{n*BCPH@)TOyZjw zRfP%P<0}fUUWPH_dr{%OQXdi@KJXqG^SHWJ7_*A6{@qHH&4Ab}AYW8md!0vo!iA?W zVYq+1kZbuy{0<2>;5`rK8tr-FmT$C4^7FDy!0*E2n{~i93WpYZ1~>x2)t_aJp7=66 z#g%7Voc@2sg?_?4{WZ3TCh3bF6wed~U?Anh76ZO}zK^{3Xl$F!>u`yYKE5DXGy8%g zBPV6M+7}s?V8!4wvR}&V7SON9W=GjA2W#zrrxwPM-fM7p@YbbqOa7(in9kliBN1utW^bBM9 zg-fBI|HRPbB zFZR%a=RL{>-w~Cm@Jh}NR;OC3f*-Zg#S=do!`YrkzYJW(M1KvtfGYlR9%X|QV5V>} z!Ox{0;}^E`fe$GukHTYX?EjdN?zs6&_dIr)FZOSX3>U*Ud*~{BN3bgN_MG5}RpB*n z{J{DIv-?$#*3$Te^=-8bug~f+QfcV8%y!{*B^}kD;M1mDn=n=wy=ENn75u`jEc99*7$*KN%bJwS z(@M3h&0g4f*BeJem1n_AH9wcNOV(b3<<}7}a<)V4$L3F^6@GYqSIPvZi{O1e0PC|3 z!9xOtk$LM#Va@2ivt<#ixmj^l913eTA~wK}zNq{r-%jjdyc^n3zp&;2{YcVLZS@D% zOgo+*ZFSc{f8erh)((aF_C*%*re`|sSzEny(D93^8)yqiTitb4>GSA_GPZ`g_8eRA ziOrw5?4SKlAH~^tQvGi@&T!`pmwo9R!w@CogK^`Jhuki4j8q1P)Y<^O= z&vDfcruTATVf4qi()Y{wg*RK(f7U}+c=MLYMd8h0v*^Vl_aa?<#f@9Pv02}|W3!S@ zuUwG&P`K$1(h=s0Oh)8+)-N6YUv>$C;Xca5{$ebB0_~G*u4s#o@dt*!K4%y@_RD|G z6~^pt2cX&ty!OAIE4JrHnUp^5L+(Y~FZqKDE-dq49+?NYF6RiwimCAPQFzl&&-u@g z=Z}8Q8v2?m9(jbhI<38uY1uJWb;PsDql)NLu z(FFrHg8#(SapZYv8Wu*5U<+ms8W5{+NfMC;xo=23AJ>XDf!E*<9F%bs zd`q5+f*ZmwBk90DA(-Y*?=xtP>e`@o$65tbRQRHiON!)4u-Cri5qws#$W~I~2pWu? zN1ob(x26Nh6KN#HHfg9l!hUSUH;9W6h%Tb02wwjN6Qn)wrK?gr|gf9yg4eu3Yh;NdvDG0;}`ScSYh7Olav3IhQho_qm`?&hg_8#_cSZ)>7;(6b{l)@&1nX2`nw!KjAY?x zmF}BiiV zR2ziPx06G25#V2YKRHIl?P}(PacejDp4S_d+`clLm;7YrZ1bby*;(A+{aN9Syr+G) zG#zGQ3)Ym%s*NAXDZHGEoMjEusNikkeivV^L%TfNcaP|67k$tB*_u}KGdX;d?%wT` zK|zx%^1!z?j`hGfnWvL0T>7Gm-5jhFc#hCX0EPR&GiApN^H{+eaNT3%=|Km@0(N4@ z?c(U`YOzIbc$~eOUf14$nS>LUb}gNu8tFlpNi+0?A8g5f30<>NQ$ttmhL6-+Yae|Bv9WG4ze!YlC27=E@o#JM!q>*xbrPdzT7%9ZXW4D!+den|Qn7Zkm%m=I&DrZ41C@Sp4=tR>{qdGXd#%B5 z@8)l4B-qR;pQ0laduYyUF4HMqm#ZB^Y%lOp>W~(;8r$)(^45r3K-?a>*U7Dlj!zum zw-?xOa=;%aeq#3hpq|IxQsjpHv1J_-WEd%PJib9kX>X&i^ZYjVaraO}vs&<|zF$R8 zah1gGEpH6tynJg1n`F|ib6m3*V{GgS*h%1TBHYen{HO=bSy8*D7Okp_KZtc#1CG5fDHxb(V6fCueEpvc6RNHexuLcF#&_V;oBA4+X^$CFUp(! zpTNK)IR0eVlO1-*T994zxyH9x()t3zOP8?)riaS`hckOWEP(^~8CQF=3nwG}3mM`m z#_u1lmv=##iJJW>b%V_=`fMBYN%NSW8?a?N$X;hYf=f;GQ-^4d-tyETXWIZ!k;`jGQM^_IF$4`%JoUp&6^D_=0Z17+_Ao zew6%pRx~tB`g`zWobd$?<_edlZxlsi4>W^L07`RCX zdAGmuUG&kFpJT;iXnk|0B2#rb!FSv2RPf7s^_&1*)y?}*@`{`oA0FUzt$rK(E72u3 zZi633IwB`t}u~- zN4uATSAdJV=r+e5d)DRKPpl&k^_M^&KUzwinO!iNc1Q2}DYGy2lq_{*_ueeTX8r({ zNl(3oCHy~^JlfC{(I*HH?Y9P+#;$r?X}OHSaoy05>`B)f+{{M!@EtATl|uo;_R;aY z;D=+r4w{x5qppU%o8kuOT2uTVpk!$w6F6Atb-0YB3UeSF1e$(aYYqIn>3jU2q#PG$ z*~S5$0wH&x|KVRL#viy_#a@R#f1p3(>>brJk%x<~YW$-s4}S~%g3||VcneybWH;zm z(B`yfj9%wYvevig3vjm3^svl$KqP@pw;$_JYI%%tBf_o}T+aelXRo!PTeueo6=Ti- zT(eQuqdZ|5Gx+2a?o=ZHoDA6)4Umgl;FWc+9Fqg~Mfi9|4;7j1{#wKzji$Yo!7h=# zw8J{aC~X)1n44CHlW=pxpRcvcUl;hR$A@qc!ECrfzu`hhx~;KC|De^`kO{vE`wXae z(78w6mE+r#y2Ae=s?dwini*n3x`H1FIRyR`Y_Tumt&R-CF<11V5P= z;reeTaIB>hPwIrZ|GVACq>%G0egSo@j1Gx@7|8g##hcb!vN1-tu^c;`8r%^0>a$rW z!2`M@Er?dk<;-^{kt_F_K?_%xo7*?OZ!5{c!cGs}NG-Y2c9S$}!@pqRubVRh&JQ_{ zpNeJtbYOSYH$hj75`+d4=MM`Gj$gD+LGi%2*CpH``2CH468}WF@F#(1+9+S!%Z+KT z2R_83`B8)&RxA8KRs>%Rdq-wW#%A#vT2=wY*D*ZVl-259@kQGWu3ZD6@I?I1Lof?mOP6&5SOA|;T&1x zKMSSv788+`&S)Ot3u}I$i^X`&r@mKqQ8DUk;pZpn^oSs$2KHkOQ>tU)8j#i&aVXUF zX4SRcDZ9vN5c-AUlZ-X^DD%-rE5VjeOUiD%!$`tP>WP7j!ebm;?S|GI$o^v^vZmg~ z-mvlW`bM8IzU$)WkfnfYI7gF^IJgF{1`Q`no9ht0BJ?#HvMvk>hi);huC*BbPVLrb z&STvYohi|y`jl@=gdemCZ+^w-3RxBITMjZ@8OLBhE$-pSNTj~xdMaVTWG;7#pUiadss^6P4Y1*vWYxW@<^2A(K%}e!SqccW#X(4*n z_#G}~&jbe#@V7@-fU94yaNb?BcdoqaFmnhOp^C%IG5?6)dWl0@kXxW?cuI=Qy%X?e zI)|~mv6Yl`oP*CE5MN7l{t0`lp^HJ^82N;o?h@DN!t2Cd%&KB@#Ikw9))bq|4GoS7 zFB*aqlPwRmimym!7UqcV*>?KVj-CGprjmgE@ai%MEmJC60IqeKYB-Uv)pt9LnQY)ZkV!ny@1stBRR zoNxGMJw|SV{-i99$P9?@9obZo`yQP6>_c4@AMJmS4?RY$y#%jIvv1tTH)ZZ+*R%k$ zqabiWN;>`V4NQ{m4j<3joSVFBq!>c0bg%^^9rQfer;rbv{pvk@{us!N*OE8Vrf+5G z#E)-o>Km1}Smh0Vyew}idlb$(yzjJevx}{=vE#9sylgPrHMwTfB0{tQelp`#pM6B&HQFn<^`#XJHt| z20hTvxeZn_=mhNHspov;^?xG$mg)M4-p2@ck0D&zaB>^$gYxY!>;XotP>Yu(tvk5E zoG?Kkgdtt5!?7@A%yUv3qmRq}@e-M=6|{N8^UYPBwav4vAHX`EVOR%){)mMmwuLI3 z#kODM;^r< zHFwj0E`2NsTj&FFRA3%!^85=)pB}n~&zUUO$z#)tj%rZ+2=bwaPK18Vcl9^vt7DXS znJ;@tbyL0*eXfm*UN1b^D>^eRAbJfAcsdb4Xid1jE(`*%1Xi+caG=2^g1H30Pt*&x zJIRF2&7s2jJ@yeKi@cuw!N9L$;MSML2)->UOOp(dF1`$i1D}_Q!*?m)-Uts5KJVcW zLMac(v8|N*-8nHVBZ|i?{(dAj{>}>>>G{Q*uYp;cG$uswx2t5Q#!u#G4UMS=+ zk394=@{o4A_dGP1?TQ{4iG*#D-#u(>&#-|jcmVz{1tZ0G_zJxsa7Fx&N!;j-%X%yM zHyKJQzoo0HH_5|B@*vox?^Y}`S*dzxanJc2?2)o1l3!J}K$T4}UbFk&MGtEY?aDoF zS}`Df&V}BrbfT$A4_(vcSTH*uLpa#gH}*ZxxYyGho-e$Q_3yBPiTZZK!h<6)&yvG{~7TGo{oYCU*BWwcc^ zZk4LGMKuX>6MWK5kT+;rDV{O=RJ>DpYlc+61u~!P@7i_)i zz$1&ykud0jy1c8yAt&`LIuPu$R5-mpt!QKBICDbUE%YvsrDB(8#UWUZ=+twJe23zw zTdc@0eq1R7y4_9JvfCYNR$8 zPoLHeyv_xSt&JDhx?-l}&g8zHV)srtIA2o^t|`YN-bgtl3?v~@iQguEvaP0Nfj=yI z5G=mqa7oJJm^I}~E^W_|B`j%%_>4clGT?@l;PVqcjq|$)K7#lL4wSM5X4P4uccpE{ z*d?P!Wwtkt_ICS@186Z_*32%~=a$|*3yYMK=)JpzkF9`9Yb_fg?HVN3+5Av1=l zN5QOeW*;kvwD&=qO1Udi4h#*L6z=B??eYfy+mQWD+JXuLPClte{it((X8%Vy^zJgW zpX~rM3AT{zExZ!jI?4aSixW5{kd_+ANV#UOM_S2^cOySud47nU z@6#ETbj|qQ{u%>qeAD4$PrKm2n7kX`#k%wrboe9iYL*MG*#znI0d(v<*83P)g!i^$ zjks82#(~T)$XCxte{E3C9|!^H$Apg!@En98s-Fhkxqht{T33WLDxS~`x*l_?#915C z=Q@<1c_~h0UaF@KabZG^|9AH87jNDBVcVOo^dJRYY3C2I%yo{~oN=~uk{0~z>)vW1 za~Wx)LsCJ>NXFkx>de6QNA?lAi2~W$D(cG(dk@L4=QzE%C}EoGjg0Aq^eN1qM>00K z*KLzFK_|SNY)sZ(sqZp*wBh!qzl8Yu#@3`*yWLh1(v5wL3>~8=Ok84ovuIA$334G2 zxGD1v`w3kepT{!qXiKMT?x(JB8WixcgWWC|73SrdFa%+7mg-1->2x;!~G(A8LbzW1TTyBSX16FqYn6gaPy;SQOk968iZonsJdr!#QKI8djz*4pVhI7tlg@YXg%nB#*OVQ zXn|7=zPz~0_<+iWxqwNpu2loE3!{yza0^zG&IlZ!M5kT1!25FXXAq+UE7vZVrj_M^ znElu(8L{s!dQpljS~uIs%p-Iuu$yXu>F$x0w_#ag-hkpqAdJb zu!_9I^QGeH&VD^t>7L|yfBSmg#~~@c=z;xB-=wa}AIJ;f7$JP+!t?sRhMXmPj=nfHB4$PMUms=9J>=OTn7h%c zUO|fFyl0MAV996ToMs$sMzkiA$qf>~mKP;B@=DG;b7#89dam?>QxzUDPx&cFINUs$ zWAP)0o(=v~y1eh9O6`1&>hsJ-2`bNFS-^JrS zT2S}*1)n&r_3nT_kz2YLo1=PQoPlw`T^8+k;UHoc1Q%2;AAqr0P^@Xqsptg2fd}6K z>py%&Me#Wq+<0y>x18t;`QoS6;J)V%@o_HCC>b1`{MI&aiTi=a_7d+!?nr%J_{N3$ zjPo4f!DlKhC=}73Bs^!XzRoY$rIudc-7J}F+AE##tbWy!`3Wol&TqaGzH{;bxX3pR zi8HU_B!B3Y7r5?>7x;Mg7X`65(h9i}6Fl?Xf%)8%^ug#Q%_>I7lE<*Zv#>oZ6$}Y~ zKDZrtZrYO89KS!R++$$|VI3Mv`8Qr_Tx$X4G2FBTgQ1zFE&1q-a07xRby3QQFs`7& z>28sC*g>j1N}Y5@aGL@1X+LntzYP!Ox)s{Y8QZqGyrPpl#pzG7GY%aG8c)5sD}XI> zG-VP;@dUXC;1XwUb_sg0iKa;>z6hqu+Ko*a%R``-1F(6iW3g#rk8)&QbIpDSodRTUS{0iO@i~M%xeJ%(dP92rrf@W9VG}(Isu@8t@q^OT zv1c^(6B_{HWWIUVk}3FBm3)kK6PryI@Xm;{Eo5W?L{zScehq1jsETJ=NS{<%aPUak zI$yEn7~ephk$2OZ#C-)`;Ow{7rkp`k+~^M69_euIxDmJpChB(m>Ugb}a9APf1nuo5 z2=D2T=dZBi!^y2TR|WXcs88%3R6Tx?&Bn@rO5qZ-cwwstP3EsU+`Q^=nRoT600+dM z^rm{;JsYt~n{_slcI=qSmlD3PVzfaxgR}mlpi-s`Djn=0yUj;B!5?kqJ<_4waEkD# zyJOx)Fb+Sv)ScT$e)p_9k(Ffar256H?&1SH^Y^+-T3{K&SKZwnbq8&Guz9tSq_c&v zDobZ+JnK&U5M}+LZD#!eF1=@eXrd7G><`fPJ^O>!i7wjk(N4engJG>#$o3{K{=;cE z^o;Q-aIZW35H<^DD*dOoxdc_xslLiJi+-McBC&s3!T;1JJo_6H>W)^HerRYLc-}-o zEulSx&&WjW>&CC3gZRsS^#|9+hEnQ`{ty|-lD_;LQ@{Gd8vcxDf9NJ~M<)38mcD2} zYV|0Socr$mVf_>mdgHOKF`RD$i`(-%U*|cZvbTWVS2W&C2v1I<$ANr_{ zJ;oBgRG{bkhyD=tx5p#;IoYyw`0~u-SJ7aQ5me@)hE6o zn0WSyQCN8Pi2#23|6-2NeU#c5R9U$iN4RXB{Ty2iRWA0`f6fs*i{194pEs7+486+r zk8^})6}w~hs$SL~R#8_6R95`~tO?zCI2A?rGs@<@H7DY1-ZREsG2t1@yyqaw3iRoH z-dn|RY=QBsIZ)-{J|B5jaE5-L;vWzneu-!1*=tljkinv} zSvg_5_&(U+t3lqa%ZOg)@~+ZmPtG~?9jznv&Nq3t(|H$Rmqz~j!F#lIjn1{Y@zXW< zeQv?o3LRT_5&Z}sfNu>0zcLEy4ucsf%?^36I4~k*-VVv2x7tHzSr*+GJHtd72Y+ww~YxVV+cmcy@g9$(M@Y{O&(41VSg-w+^9o_ zx*w6tbJHD_EA(0U9_PKJ`7VC1Cutjep#m47$e|=$cgwPuz_gQUH-4>CjIV9-f(%0O zOY>~zkx%wCRr%H)L=B#y=OcL!Qmz|Yb7A`e0P|`axe6xx}H5y+Z(L2(X_)iSfm zy^Zaro~Sm?>UD@={XGpdtT$HtY-7nMbD!t5!FbcHDuZE<;VU*b@VZlmM9M&0Y40mP z?i|CoQLP3)+^_>+@Xy>>g}>;RaferyhIavebj(B|dSGD2 z8ZMV9s{GQX*gT-)QIj(+b6X&Gi4uosJp$Wp;1~fO zW#=k(^cSoj&=+&iQSH_Hy(0J@T1Fb{&yDN0I||JLkKuQHp(- zl+8eI$2gQl|LCr3B)Uav{7arry%_TeWNAug2u?GJAHSp$kLWm=DY`wVBBB?l;*n?8 zn)2OsiO_rz{jApeDz-{sn@T5Xq=7rDc%C{|_ju%tmBmAP)>?F%GMywVjwBv*19-;y zo%B^)+JXg}xWo@T=#!~WQD*%>t~b5zvHnNbL-W6!NZ;P@=a6s)HoKgczEJKy5Fzyh zHIepqHO%~R2iJ7os~jp~GpAwZjBzdzx#*Q|8}ZfJ z*AAEp>U!jwIhKts4BP1Nd)P0IpczG_1>fou3Ad^T4eU|!2`V-_Z*?^Ql2h^LZKtRP z(r^8f@W5F>=JonIDfs6EgDI_EUdJnRQNvOcU@5&5pA5(}Mb9rC$aotdk`;^2#1)9Y6^}3@U*50AMuUCk z482o-s5r>LsPlW`Y~15S2b(-n&%GEq2VtOg#XrhY!W{_bJi}?61y+Z0+mU{PZimLv zd^xW0EpkcM2$y5w;cs!j!!OW}+80XLwa@jYB5~7SBrkpw*(z(Gg!j0oy-0U;ki6!( zZ#HEggg!#FZUU%Z_8@L=uiN!s}xWM8Z)rB z=vTX57yAo@^6+LOlVZ*r!yB;%!@-`D_|P?%_`zz?U*kUXzrYYM{Kk(>fzA(z4y1t3 z@<*IJGe;AslJIGJz%v(qKu{A}^^9xdb{(R-tR0#vpJ2l1R7*Z@Nu0jU|zYvn)V z(oh#Mi#Eq>zc~urRe+wj7Qf#*ffL*@{%EuiZb%FH3u}JL>YTLTHbi#c_j1f3=r%g2 z6>&1c2hGLK>mByLK4%@fAH6-EAMnDWo^#Hz`*H@5@FAL(JWueCh5P|7rjADZk{>3J z+VCY~7)rY8Id*C2(UWc&I~k{9Xu`NFJ+@*<(%@k^vF8n0KTT{i z`wIsfy0x9yg>vvlxDC{E=+)Das2_Cu*(PMw74J&o&C-Bhf$JLahWI`3u_vpt z-3N=-%PqZ+Dzwn^u7wzLZQSLuL~Xpj5{`dboOQcm(^ zYYVm%5{_LieWQ$RA>~=B^1Ku7Q64SI+u~mX{f~+#ZKS1p}FB6!k>UYn+uwkG5>Ed$lZQ#Cg{>(1@@S^KNH+v)jVbq8N^*1v&Dy_sse ztbfBVCjD-p`uDVZc=S8u>F(s-7x2saM_({BGkTS0a+tsR0*VN)y1UN{_c1GdyNgbP zlxLqbHS9$1V-{OsRi5&neWBSfU*$2-_f%=h7?pkay*yk$`+}vdbNj4)=Bj;`z5Js* zoJ{(=tbH2ZShY`rjwd#l+7Dya)g#F_`*q^C9%ECkkw;=vr^)jM>)Ee<=gR!EXj{_n z8W=S_`yFR@@Ko?*?9(_5YR`TkoM_KJr{Z;GUMJpvoEPZZp^Tk{%&C8;0dJg21HTlK z#%oSEY&x$w0joxlFVnA$Kh}rrc{?hB7g#%h{>wSuZ^|2dCLqb*rW*s1_`g#vao=P@DLYap^m;3I>~4N z!Fd3uOF74RWotwm{21~aZm>?35$m%522cAD9?cQ^igO-R=7Gneb8)we%ke(#%lnd` zHgjZuQ6srIYoqSr8{A{Bh>A?!XT*BM1r~KB+`|9yP(1Wor^G3I$0_?N_QLT@_Ah1L z>+?imRn4=B4MRB=tg6NHsV{p^WQ$yPy>EPEk-l*}r#wq}X1oVjQseJL!g{;|bE>!F z*l2My6AX{%x(`*efqy;t_LPbxw(~p!#p5as1YNNj@ z_2jIjZqQ9j54f*kyYj5F#UmcEV-yTq>I|KA+DO$I@3qTdn%Eax_mi)6F!emVV}Zky zHfFD(ew!W~u3K;D#B=w+^;Vq;23=MA~baa8{!)i{NYNRUF?u?h2|FNZzJsJ3)s5rTA}o= zoeCQ@o`@6f>5KG%2)<$I1ANDRpv)PCv(g$%#Q>Xs`T+O$KERf_dzSdUTo82Iq^FhU zZ}@g}4fvE0YHztZmL6RWWKtRGB@Ie>p00&$K2Lhibp@CNo7z z5X%4f8$%5u_iSPAe`EC>tK0pL;?eyuKYKjT9Q%eB!iCr65o(RWo7)8S6l+AyuFHz@jvf-ujGNdrOf&(PvLH` zo0S@uF?`ySVGxf;jXge`ufmEh=5l-wqR1cACa6+Qm$dBr#1j~o1N15q}K%A<7yaK5H)5-Byt1b#E?}$Tw^&j>jf8HD0 zMWWhG)m@XQ;-;!o)jy@|_rL$M+fMw7$Q#=qtzismL3oDIDF7CL z)H8h>3wV9cf(3J$Zz}ILH;r<+KkW2Vph^&7%CKIMkVAvrLWIhgUciS zd9(Y!)X$kTAN~B*HoyA$R@%roNZZ<~tZ<8PZBb5!>K|pEZ}cW2vUbO_q*PTvU#^?E2T7V-+PHHEv{Ik^LFB#Z|?h%}r`Nsq|H^ zQsv0z-@o2xQ`$$pnmw1u>}3u$MDop&g!74e0{qDR>kg`j+^>0pWzf(T>brkA27b-G z%&BS&h;>Hxr}rj9=CNI2Ij?tR94w<Iu%F09it8!)qd3=8V{n%IkMR?Uk z>+;FRzNXg4d;KSiil6(pYa{#W$KPXZWF5c%8Oy24m z#bsf#o17)f)0I!mIMBb@abpbJ(Xn{_bmnX+0ijL#*}&;lyD) z$~arjox%60SWBW?C7FyG7v`XCYe!q{?82EDyV3vJcvSuAzgc&HM)7r|*5=3l z;@Nrbc~*JP`h{v6)p@hER|Sm-XF9*el?fvn5s@4V7cU1(aZoK9Sp7ft%@b_~yCRH}oKOc5(i%F2wh34A7#_jJs{kP+m znh|p&Nkfg>U*qVpcD?4q|39x?|3{DCf4lFHv*hFNKIf`^sCtHHTeX)6l7BIuJYxB` z^G`PJWc>Ws^AB>-$KPYmTYWsxhWUa%up5n`UGoQ+pyRpQ8~*2h;PMZ1&}BBrXKVH| zHrD6t{PKFDz4oUbPf@v-^^f!8qyNiTHPx`o)(6pCU{8>(;aY4~g%co)4~ZpX|Fx$6 zec!0%l0W7yN%8)vnqyvT((}CWYkg56RXCO2JzYcVBx*0BUT1557LVjBYiE_OtSwZ& zU+=Tj9bE1Fo^O~P)I4-QZ{6py`}zpnWobX^y^2@OUw^NcLH+H&374fR|NcWb_44mE z;2*-NIqP49`xT$6H+7Fz)SHx}^_-9IZTx80dQQL3*^jz-TT~rs-^bnq)7r;GJf zRlhahxV``V-H-40&kUh5PWk_?FV&9OnfulL>i)G^c)Kntlly-56)21U89jtk#Q(pw zH@6jL^uPc8|Mz?Q|5x@mJJna!di3ke|5%TB`REA$cxJfCnV~zpZaGVQ&+R!J*Mm9x z@WQ;V7v{E{t&{Hod-DjJ?*GTuaaoK0X&wK2AEK*%`Z@yZ<%q!S`%uUj%dO zJ9cFr}e>Zvb+AqIbMD2R~(P?Rdj8 zU8>q7B~l_E^3V-5E+jz`;1zVaR!EQpcoQU15~{BP)F zzdJ88?#^mo^0S(LS_}T+{12Yay#O=(#lyH-r`Rw;}`RU93b%+M{f8ecd zRQnG&c1k_|!;YQ%)2HMw=H|=i`sDgjGXtRAbRHf{H)j2 zxnuc)|Kb|Y!Dg5H&<}uj%GO|7&#(SN{{Vt5Kq7%VeG6=T{`vewezS-ed+W#Z^@(TH zSM$zC^?ZCE8yB?sd%lCpvteAQWF6nr56k=6$MEp-GQro`t9Ai=zO{+57Croa)14q+)vdU9GCzoKVkwPDmle6RQvV0YgY5dZ3M2qpX$kS zKKNXp|M)v|37k6BZnNvUe!5NFGK=qL`TD**!h>Zk{(3A%{Bhlvt-22|_ z$&zP%TPOOI|CBHJ$NS62d;RsFE&iTw zjo1HbyZF!f(BEI)|JCLF?O(tC*Z+I;`roWy|J#rC>wmZSXPo-&?c@2k{Q0-^<%7HY z_sMU+z2|R9_S!k#ZwX z`_pefhjU&%r>E_2^?&#qSN+55|MY7d*?*e-lYjYdvVZ!o|3&tnE&loMUa%U^Z`WTjFS{foA%=T+yw#-t3TCV@1qaQNb{W;1X-~Svykl#z%1dpor28lmEpI_-Da{g*Ze+K0ValcTbfK{$-0*hvh$dv=dmO zA913;+%K=ySig>ejUuwwCRF|B@ygGu_ly2~JEr$5e$BN%{(cuT{1yAPMUVV3H$N|= zvyIG6^@lB4J@4mtp7ht(SHJ%=Z=*6reqZh9UwlT|70^SV+V{zioAb}V!^b$R=%b2R z0lNGL-LP6S8A~;1{;wNW{*_;J{OiPDefsm*1f;*L#{9L$)tdb^FaDy2ANu;WdhQRJ z;MUIS*LqZQRei6<;Z*i`b$vC4>bV+fC7WYcZ`A~G95(*z_i|P5PQHKXzrVH;?_1f1 zzx+wAb6naLeSF0YrE%Nu=>pKfYAt{0vA-5cnuqZBYhA0afBQc_$4(qz@$VfktL{&H zjk3$1@=5NlD_!%?bB}PU=blwQ zu6uxoI-Oxy&lZdNG8FwIa!>E*BGVQ6Y&?(XlPx_bzy5X=pZM{m8b?KR6&-)A(brh3 zE2@u1K7P@4^<6Hlj#ZcZ_>c6c`Rl(sp|wB$RRs9QgW#n7jlZwm_4Vvs#p_6I`m1M? zIdsjZ!^Z4`55ejlLfHmq$bMGsnf+ z@9MK45)8u$xaB-?nbN8xg|hhU&q+NjRH=HJS}eQ4=4fjYLU|`1&L-Abj^4JxJN5h4I(=btG>2gB z9aPh6SdGqV0>1vDei~AkGoBosfEYUn?{O=h)93VNnlW^_PxMJv?$idw&J@uyaf0G` zt?r#2KC#2H(IF9-+BqOE$?cqVoz5tpoSx?^Qf$VIjbH$E0Z@ibv)_^W9>jyfH zFQWXEZq%;P;WKLyS}){5g73Gh(bMS|VZYe7rbQPaNAA05yfBxh5#Kwv6~9?~v**uDI0Q)$?=EPnyvvAx9i$#OZ*NqNz*>I?Ax-HRf-qdQxyPe3Vp*}t-_NhLQ z?ReUH8rx_Fo(1)0;0%NzDWRnIX4N+Q%B7?!%q#u1fLlV<3$6VIXf4@Gwv<-L-edo> z!6s73=N8nVTl4|1K8v?;AUKV$Xj^<9S@IdVD%Xq^IEXzWW*nQuTu&Wi_4!hLreD1{ zwNvWUEig0S;-=xtX9P8$ANfq3Dc67oh}t-owrED>Gq86nXZS^d);JQS2u8ykTaTKZ zCH2mXHJud{99lG`;WuUrNs~ZXy^CG@ziqZR$wpFs!ppUnxG1o%V)@Rjps_D2PtQt& ziw5G6JdI?2+A^_dJgWBsIwgl_%pm6r;qPSCEB6r;Y}%oU00H13;Os=`XI+F4J1Y1m zH!9mR#m)QFutN;^08*A>@_2OvV2>S5wCF60__;x8fzR_)JD|_oM)P z)UQ*~>Pb&fe2S{6UX-wJ=aiylIySJ%2_mT{>-Y&$Jtsf+wW0$vkeYN)le0lDaazR!`cd4b<4_FQ* z^Z3l$m2K6Q`^AfbCml9;H+Z^YyrC5~Q@KAiC50Zg!)+?p@myEu2^LOJ%!@?G(YQkj6dj~2 z_u8Q$CFA~G7=$~(m&Hm9C-lzh%Kp@*PccwrOzo!Y08MfYHdL7IDrEVT35mHD8Y<>m4Lj%TQfO}vp>V@N)Nm3I!=^=j#7?f(b#ruLa z7Nlq9ti^Wr$B234YA|uM{I0!l-(=4>*%06iWax&rnOScecxNzvWPh|+-&)#-?$g+k z&cZqvvVZoDo(X@2ygmEG{5a8c0sHoxGM5JHS1y0SzIL=yuMo44Y%&!>YW*8@PlG<{cANx?&RFF2jH~JC$HL4CK z)3LD__ntBzS=ktQU+nB7&clO_oc`I{x4&;&1(ht~@bakuQ z_sTI-c%1Mhdd5xFbkYJLgn>IW?r_WSz#ac=(y#Hb8G6gP9ok;j(+B|rX1ftIFWq5pviUX zy2k=0k*tmJK;BVV_+Y0wP|*HH+O`9Tsr}ETbS(XJOfsjTB$C< zzY!Wfq|}C!(vdLFKJ$#mMvtp`_GtWWRKPfDM}U@6=M`c1D%Cz>_M$ru!8kE z^hF+>HKB8wcR=~6 z#m+sm&vdyHKAx@`7vw{__<-6Nd`#tww)@}{6%VkFN+yoj7k#JT0sDv81Inyhw1-4} zekZh`v0>?wyoKG(VHZCBg(Nw@__D%FYnm4io)dPaGnK2eO z8$M&S>p_;U^drUMf`&BpSMDc}PAZrG1_7eHGv*fW(+oO=6oixRhj_}EJS}`Drh*Cs1eaWZW*J?)O#IJ zxs>!Bzdai`j{va?KgSXZ^%Fv*w#OKN;iG?}O$^=!Z`cAN&jpOZ_DW6FAR1_3*Mvv0 zgCGDfpKxN=d!^hP>!sF1V31uzACZ8jtctWhXa%sp;tfR_nx6z?v^U`S10vUq9^K7~ zlGQtVnYk|sIf^|Itl(~~G$9KoF5T=+7|H~Bfe31={edOYBk0^r9EY^EEX5QxuXe_; zVt|3S^3Fvuzo(xs3iAZz+)J&dl_CdhkEKBou+7{|wy}*ZB6=b^=-edoyGt>@V+tJU z*;=P_6&xl@hAe%g0?&MX)4WF(t(&i{}wsLahAZ00#&2KmG;xi z2sszFr;sJ!gfxxP^bn!##C=YK=PuD8TMrd{rD(#Bre?cgC#>KJ9S;B`J3t!~xn|p` zZHfY6Iy#4{J^>rd4)rwZ7I@ETe5!5ltWm5$HycX}5>WwY%mtyFJ#BD#K`gmJLZyva zU|;I#B|Kmo@aCM4n82$a0njUm^eI?WeFNXA_O1pb)bgk76RUq0Bh9d~aA8LQ!c0i-A0q&vG1F01@@|U8)NsNh|{ruhkQAvMNF7h&LDd6{%H(l27mMOJ-7QM=zmC*wXyM$vtg zN4z_Smtvkdr^H0AbnT^cINe!g>W*{ap_4hlr`DV8n~+IE*N*W#Q!?$-P)qg~UF(Yw ztCt2!$j`3Ap$VToTU^hJ-TPW(nU$f{UH!n&W1e{tc`=3|GS~0036}m42bRQX&Th4@ zts9^Bejhmp-R8Xudnc!DhUj7RdQA1A51xbF_OUB*eAL0TS~=l`7M8k}J4bN;EwMBq ze=2RUTO35iAiaciJ?w?TKs3NTI_RC@msW9+eI7XF+g@KyO3zPd4Ei|YntF$C6@ktug_tLdnXFeEpU`r{z7zD#kuOYDS#YTe+crC=!3JBip1m0 zreFxz64>S=)Qf?i#q$-~P>f{ZJ77^k%_=o|nt600wD3auasxvNOxPY%WS}nS^YEP- z>ja=Jb<8sdUD1>?Zn2;NZGw|^lVO|CmuiX}Kgt;4QOO+Jk7k9e6`7%;FLm*}A@fNc z#h!YA2W9EZz5x$WK`H3#7`@D}HG`czAI^)})t)&n7JzQgQN%ZFBW=;R7o|gc?-^EJIELwu zec$egPTyvuf*s|-Z&7KPgcRDYiyhuW!D6n;E~Yw+E&^dm)l-Wef;tP(N#a*48;JxT zn8zfv4`^3|9nHKhl#wT0hM=xwV)KZjU0~!CerjiMT;(%wIrAniJb%ruFa~u6?kq5Bw+|2m5_QF^qIej^+NL{^#PDk>TnH_nV8oxuTDY845UfyV|1b9 z(}ANipVhpK=|WEbDr|;;-tRnnk3Lk=yOsT;Ql>j(Gljs=jN1b+=AZ^NH`rYh1)Yj{ z+t2*5uY)yA)Ac`b{A#d(RdCYCQj>|2UqHLZo{PN${CG^C3TXE@p+^q#p=3=jv`jpr zA_M_C2m9Bc_80-@<^Aw{$~*4!g&n6Z!i5Hxh#JgG2-3Cv)WuIKvPRBvzh;vLa3kjt zUo(3Jo*U#UX_>4KiG_Q{Bz;*x1=s*2H|gMmsi_^{G4zPCvp$Ah^V5oO;)TbJdBpE( zCz|oZ)+@CA#J(k1StZ-s`vOfyS`=DXQ?a*brh7JY-k#Q00>{;eu z>PWw2v&Br0ri8!i=zMBYgN1$RoHOwo$h8v?ZmA7(w|jpkqlvzDd4Kt!cBZzXp@Z*RkHSyzqZNri)OR=W{Wzca zB8L5r=!dA45Ex76VPD|>vg~cQdvu2QMcE#8d|L-`G`SXgmUk*(XghN**;QPOnz>hG zp(#Rq;lA` zm{ar?*0jWTf}PYjM54>1XBVnI)@(6#O^3@--D`p=LSs?oAMFFID5YZ+JY!eGC?r&1 z=J$ch=t!&lMZc!+u`SpII7UP4Zl!CT2NXyE#@q;oLmmt=zObJk_1(mTyA)eUu-zQG zIGHP0A$aH)pg-z{gSCb&zumQw(gt%3(xgd<4*$r9@(WQ9< ze$T$O&?W3=h>>lPb_MVCLc0KspEV?^56p5#G}u)1?HaZdIv&&i7OOCJ zOMBrGAzA1t_e7DjGMq>r|UH{5@^x!^(z8u?zPT|CsKR0E1Py^{1 zuOf7E=0uHM=FAcHp6QVr1v%XaOXtd)?~M;;2!1?0q?wUAP4v*sM$hOp{T>r&G8Vk2 zM=)0U$fYPQm$htY0Y;YbqdQMf8cHV%p5fD?gNQDG{pRjXlrNa~e&G}gBiSKQr6Grt ztBP?k$+|0y5UX7v%bS9Gofq>mUW18^q9+=&F#B}ED_B^U8ryVX@L8V)TWcb50f@GA z7IX1sMXb#@oWZ}kK$!sW!~|9JNV(1#bBOL?=mB-{ct+`} zD>$i=s$_0>mUpi8xv%tJ_iL!P3E3ikZ5pvkc}zsN8%lz0Vh^519KnPM9ii5;K_YWTMwBKR(n*?e z;Xi@6Cti)g4VhfF*73`yBI~gNkXg|{e66voE#?j0%saUJZsf6M1G&at?L=py_nrLf zHFTZGl(<8C0j>%00b93fe-RiY1Q#*b6FNVh2G*@aaOP&`Rql4m^M=Zm^|SO-t=V8+ zKifOlDKOUD36v4&MhysG9ex4K$ur`FySiS@^oT10K%lL`3ydXPDUGg z(N9a_3ToqR((`oOPI%fgq8l53qdKqj3c%y?&h@5z^*!6H+!yh6TK6I^B}U*zK+aX~ zeuNzS6~L$L`eWn=bQ_J@JN*Q|%ARU1a5W7K$e)=pjSW0Pg#JJumoLIc1GYimEWCH< zN4;O>cc`?m26!%NH-pj4QLT-C8+B17`~4@XQ)E0$QnSLIZQ_e!ui~30X5jLi(m~v# zR6I*u(5@M^gCgC+DLwJ^y)cKY(aT4FY7&WOuun@gnI0t5@;+;?__M?2O3e%a{3;v1 zqN79~&}|SEME6m5Wc&bcgO>>^u%83Bw!tcdNZ?I5uYgQ&d0XEKqYrEp%}_LpfR=}U zT3N#=$7GVZEtc2HzToJ=EWH88?O^FLi{OXiySMSWY>Bnm4@|Yavg4)X1o+PF(=L~2 znICz3G!^D4nCoPNTk%tgGZLh!J;<)Ac&%9$C~7Wx5Rp!7+6UP-yf4skxMfPUYg?NP z9N&ABZaMzKo6{8<*-eAhkTzQt?@F-BpYGbo<@?lC6C!8uCwoz={Cffrhsw{Mp$9hV z0!BI=1w&-)iazDY7!b4=zhd>ErRWvw%8W2i01IWJ#=<7BA@Gx}Sn!?1-TEtrh|G8i z+d`kXo#{HDp2!-J;xEaaWzqdEDk7W1$}!Az?)d9-}Un#GqwbW{9Ol|h>{CPLG<+ij`v2JAg^E>M)fOUmNv8pNvYB! z#>_oj=HW9RH5k_An2*>gSUR#!x(v2QhTX)U?5%~@t1Vz+zKI&oY+Y>R96c)EyQPYI z<8bAEh)ItRAeOZ-+o8H69+Hdia+n|dkQm(xUgE{X6fK=^=pM;BGXalk#k~<*^^zD2 zGd3-~vW?s*7`=HQ-LNFa;^70|cbHr3_njH2K@dIE8mO^g#%+=90sMXy(kpRLqLP6h z=2m{ZyVSjzVVl%Do6rVMViFcE0Ix-OMK);-Um@}q-xvH3;u+!C(6}5tdQ}kn4N;R^ z=-Vd4UPF_%P3i_1Z=sfKu175E@DX3(9{eCg0N=gnAbyd^a)pLt+Zr+1Lxt_qH|unY zA0f&R<~5q_UVE$^*Us<_db)&IlSVKqy`~A^AL*(=e&UJkniCVLFD=)8S-Fs$n>Xzw(JV z15(56EHr~Yem1!d8*Yt00^PvZ#)ZZAiNo9sSrO8wIo=RGF|5UE6fQITqiu^mFwHi~ zJlK2@8*Phojn1!VD4@`x88y9w>@cVBF8z*8UWfjKPK15wvgz#y3sM>20c>1biJheA z?~Hzf#n?j<5I5NKf72+QC3Mh(*?qFHl+A`uBl-6dYP(1FL$v^RK?7#lQTvk>^3BKrP1_<&X@l+auUMo#q&$MKKO zO+@o^#%YyyEItys#{KNX_IOfbGdXzh+g+HS$zUTBGk4}ln`?-#id-f-RrHiJSE9R9 z$vdm|VS?PIqrkX!#V4Zxb|rslh0eh475*DaTZ;g!#0~N<2%l@GCaHU5IbjG#&`fs^ z!R|++!ellx=FiWrxU$kQ+}7%{+j)5sVtm$__O{xOtSB8r^Q=qvYKSz05~dZLEMd)0={gF zC)hTQ6k1<4>BobNVNr3mKJkhV_JtS@LIW%|@+XnG@M_%_uxDdh?U}MW3wpaS_8j}o zZuw!I8S&^L&K9P%vUy-`J<1r-# zE+0j{IFcWIUXGK;MwJ7+f0f7#{N(x+I2y`)*Fw9=fF?*I4H@dq7OXt^N~LANC}ZH^eY2}7W5eKeZ(2P zcSzFUd4KVT1p3GGlRJO$6B)!G4#?O42l>eza6)B1SdH)e(qi4L{< z^V^Sq|7L!&FLl4w-#>z%Jf$Q-@RKS0`od4vt_F_q|M&UHB%=pK$5!LxE|>;1JZP76 zyL_^zVE2IQWj8=Q`>bm)jQ5-QX19C>_m@hv{&V-VCPmSm0}rNZZ{SLdo&Z6$jueG| z6xzN5IJKMh!2li(?9*;czv^h#9fs3#7fiZ`!Qqs}2P)GiQd{MKOw~Tc2Q65{yMUhU z0l3fcy}93j8NC5VdGI`F(Zl4lo-FCb5@oA-0rCX+$%~ov{(b??;6~rCl785#S8H25 zqgcDnVL8t}*Dm~VJ+Fjfk|Ey{3jcFh?_;3*FpOINV3>sNZTTtQ}{hj)uPiXN-pEGICP-OfYeI`x$X zE5_^W_)L!-0Z@rg_Py~Q${)DNcZL3c3pcq9a^k-aZnB}$um;%SKbn{Pp7=?HoBaC! zA>3pgf}4yr-sIEY{*1FPtO+(v!jCp9Br4ow~I?NLS&IiRI% z(ytzyRG$E^!1E|m;Z$4jU+v-wNP^t-P~+0v0s~#`DNL9CJW?14*ct^?;HLWEoAAef zyJ?tnZ^5@eQ^3VNskSqK9ONpe5-PR)M}hF0Q}Q-7{Wc{|));cnm=A!4d>$ah_1u8R zJX6c8-HQ2kft(m0`hZ;3cu(t%gVdie2XNWB!-LGQh~NiO)47ueQz#g7j-s{#lu7}h zh>pUMxzFtIDgS~CJ-kqS)28`D2=oM%VBE7?gj6sA-o<-LGnOt;L)@;a{Q`%EinONS zLJv81W>;g@R=Cjjnj$ss`M?W>(=cRcYV@frZ%z{J7bf&JP;d4xOlX_xFdH}liqIS2 z)v9YZecr&@RDGpxg>FsT#I{z5rr30a5Dp5&>^WA49My zA%({S4=*mRv6dSv4u$~M-Rm9mPyL|@V)*ee0iT-+OUMws0zKi?MnKf38(!EPx9NcW zz$;KX6(gUS5TmDs3v5MN1B^oI!99h7v&^;wJAP`xLv+w#8>%&k53?###qedvTs&Gp zk+HkMzy_`3E*xn&MoIKa-Z!B-+={N6z2O5+%)U8QzmcMRhAOz|haz_=N9WrGm=8nk z7&t)!<~s$|dF0y()&3SW?JG!&LPYfrl2X?TV_99=O;DZ*f@;Rj4aTzeLs*A*=GwM` z-~{wal~;xgk+DnIL)kvMH=k{_O&c7gIrp?P z=9lYc8U*R|DBy@2q$ZT*vNGg?lm)v%PkwscOFtd&RKB&#=vz@cWJcL z#wXwj7=ZpjWJ=rcZcr_ubc)fX^il)H3k8}}H)T!*BFa7vtmr7aLyCyO>F5CxZ|->? z6czJI{{fI!Jk}BPdulyhi0Zk?n-04bt-6#0nrs?_)wD1ZWo_WX*20gSNMLo=#{-4J zNy@v-Z$0YSk4Xl6td6ConAXz5jS*zvQ!?{&XT`mSRBPJd+STujt>)TA4?0jQ2ZjNb z2rK@!r{0n404lWdfwPg_sqFn3Xva0(2b0WZP+V;N5D-1>bVZ1JBA~LO$xg z>$WJiXRNfw5UM?D_LJ04?+kVFy5JHWi+o|O*l^TRGu={wyKoUvp5UEPR0(a0!~_qN z@mgtUqs7hC=?Xy2hT~Mgf)!d`r~usp7Vx(B=(XgwskOY)fM>>MkpUt1KOAW(%fE0r zrX;u!Re2G-t1Rbz%*+G-wtrbkeK;lQK!-;_#^DVu%fv$r- zq6em-q6=s-fL=q($x6s7-f6%fDjY`um>K)X1E!2(SF>z`1sQX$rSzuaw`t$10cQlx zDLBNy+kr@&tr)BD^BtJfJLn^XJ`@_V-vcAAC3}M1Ab11;9Q$Mto*-qgAK?NBCMv-F zNh#ywo#TZMoph75h-`QZ&$NVwZ0R=u&Z1p79vm7pGhm_B>S6;74u~N_ao3lG$1qkd zv~N5YS`-Z1{o<(JljQwbpfiVH%M*;0HQWFe>n@7Z)rry>Jk+*h^JNy=_+%`Kzl%nt$2o7{==Aq3qk6v#WveZchfrHf< zf&%1NOtFc<;zr-Hbj%_ua6W^ug;wIAWvlVs9%2IV4I~HnBQOt4b~1HX9ybIeBeKs( z;S+cWU>Lj^_@`wo6nl%NF)|5fAHDq1h1AKyvgV&|@9&BWK+ysI~#Q#IvZobIr38CNl)>C;4#KsUl z+OO4~z;J(O4EvP`Oe?-dk6`$!CFt?y@&X2Vdk*)1wce>I2TxFI**1tpx=4UeJ3t%F z8isk^Ba=X-0KwLRWn9s(084^Il^I_hgo(l!NKa_?(Dh;312-Aoa{z?YkLWN8#xA(q zGS{^uBnf6<=E7$($Ao!cPq>i-?+(Em=Qr5Bb~_8dP|hE+fBN@s1}u1S?7)Y^Tu`HE?wVH?zDz7vKoFkytTt1Zpt= ziZBvpU>mgkBj&3s7zL>Q6F^=}24;B*X68Chw|W{3A77hN;4d9^9On%h95UO3H9F}Y z-Uv|yGYcOvSnOp#1v(jI7~y@ZJ7lr4L4p#!22RQsbN>mQf1=5$3SdUgxuxh47w(5H z!fbWJVkzf&A250asJIW0iVxBhMlIg_37l%Hv80hz2DA?511x{0Gx|h&f+42CIkJwS z1N*W71oR4|H4RY27V}Eh1a1QlIoq1g^Bl8h(6m!PvBvoToEP28jzhV(E%M9CNALzc z%xE*iGx|kS#^!LFtD5@UG^ZK);}b@a&x47up5`6o5N^ zTxkqGYc%DZ{3UcM`{vb1yOF3^F;i1Gq7Jadw&?lL+m<_}TvNjYPFTkJSa}zNW#w%F(Fl_fl7j;V* z6S^Jnm^&H+(C1)^u?RliCZA!sYZp`(3~YD}V5<`uANo{uB$Lvai(U-W26`ATn9}g$ zKpzj#HuNMO1BGF`0#dBoV04^Qyz zlM`BHKZajev0tC#ja+XrZu=dmL0MsR58(ANBQ5|~n3Id31bGgeQ&UG55Iork>=GJf z7G2f{>?}xD>!uZT!4nZ<)ZTmdOxUDpp@qf*6^qP3-xK+$##`hdTfPQ}e1t^t)ASiD zqzwq4U0hUW&NF0|MZdSUP#EYOSbtL=NIsBqq81|!w^2KimMi{o(|QMckyqmZ;Hl71 zKy)168q}fU3>@yM8hWDKW#|$Cs|*k09r3>faN z&r0_w*IHiT8NPXsX&jOi!l(QN8eI<*(Vhgdj1d(l1ul(l3^%qOxOG+Jn& zh6f@Y71lX%(KuD}&<11Z3~vo@CL52?PhK1>S^MzvNGB&ody&RH3&H&rTJL4Pe2>6E z^}zd?9}KrOtPXmOT8cv9mV4g(0vaJ67g};5 zXwXfz9Mhr_olM4O3%s{JavxMvH%QSd#+*;&JS!2R)2{d~*dg(*qLtLG?#aICq4zen zl^&l6whA&@1_rJ$E9PAhcp-aJ;2Ev<#FWsez6_zCodL7VDxiGu-Qcs}_X_U-&j;G( zdr-kxh6TJtdMTIyEGcH9&2N2h|Crb80)`~GJIE5RWbwlQw44Cmp2GhHWRfmUP=p6Y zB}ibYYUEoyFP7-CZp!zv09i?{8N9=9u?e_tpI1s4IF3E8wY$!s=|Ddse{H;0>^IN} zn-lIUV_eC1Pe9W}KfSLb@0VTCtw1bo+Tla!hMC-w@fWGW2Jyzf)XS~4j-Y#cq>FF+cFkMw6q#<)1zrXs^7( znr ziMeeGv%Ijm_>MtBe1JXU55<>IPwcH);ywf)GhBMm4bY-9cmcK&UL3KP&p1rbS#2JL`d6XH`-K^XR zS{sDk2K48Gr}{WzA&2@C{-jzvT>$fq?u362IdmBbJ@v3;;E`ynK^KjTIcWPV!LE-U z(JMdpcuWWKUGxe;%SG$MCwXP55ayG1-xB-rpR==jwuKZQUaDT^Thk$ZvqG#-z_l6bS#FJ?Y4Eex55{Z7UP6gn6#s)p#*r);`u0^1AmH) zI?^YkLk|Yz1c>g|0%OU7kV8Wb(0N^>|BeL39yUVoC-1lz*ji^AT+xLC%L4C|lZ7$In4?SOG z_-e*=)3ARvb8w&Efy+V=vU7OKAj8jVsn5s-;>ntZfaD`mLAZMZczK4IC_t{(tO*RO z=-p*^k(%WpIp7#B^lb2mJ&361btm|yLhNBJN`c7TV6ujW%(k2MOQHvQW7Y1W8a_s7 z4c8L1s3ysB8y0vvouYrc2wori@r5o4=E%N)fAbU*3;+quaepkt_uF0MK1XxWMU(WE zN735C)|e^jSkcqOj@J9YYEEibXtM+*T&QBmx_apV8{F>Fxe`|G+cY>>Lj~xfyqB7Z zPJ{&qWbg52M`QMSFt_^X?jH=bT)+;PB~?tox?6@;)U--$A=UuMWq}m6!L)QZg z%F{M2fI|iL7WJ_A9Il#H#{i@^g826Mq!@ECU~Cgt#1XgYz6hqDkBbT?x7D2_*=4~t z#uw+IkEN$1OT#YU*=#G`M{Gfmsa^U^+(QfcGPma32XY)N(yhxra6TrfcK}Pxe=0nS z-YW7;T+WPt?7z}w4jrRfgF}9BqeBSFt@wxLtfjQ%CI0XLx+9(!0bJ|`H1qRc zxB{Op-~GVLcm5~C-_^1^(YN-YsK%uquF$K+7{N#fE|w0CMTfmrUB83t0cE{I-$do0 z6(-|ez(^Rsk6knk17j?|6#qQ)cQvkeMVHO|jD3PLj(r+6U%~i8ra^ASU>750(tUH$ zQ@vxjce8xiw48Gtt1G%439`h&ver*X?%j7S9Oael8ldP8RPQ>*YLAFdIHQZ zA+nqUgeNG<(woUgMtrT2>ne*jt$0Qm?cNkLGQw;O!>u?8;6S_w`UVSb(+UXQ(krhy z;(3w3a1f~1pa)$%<9(e&o1&foChW9V=x0GF>n?kQHHPrqgIDgcs!(?=GiwtWi%Xi< zVzgV>Hw@?4eGIUwr9+`Q<{UOLKBE+u#s0~7_{4@F>bkZ9*8jZs%ADnK{5*pW*+WB3 z->i%1tVmuDpiU3HF>)l(b|iF#G)h*nuCEcl{_+DoRX3i5j)w}?qF6O3i)rQ3Z) zxmJEg7pX+gxa)7-2p^u2-Ir{c>ZJ%b>xKQVI*iQ5e9*KLq*oX31vW7p)qisZ!8&TM z81>4FGM+2V+sp*hLpyb0Eg$0q320ORAOwx;MbmaaSB=lL?zO+-*p8mvXs2BIq6T%%Qg|!~jM?-hF z{tn$!T7-P=Y0xLS?3TW;qA&L86##`w-*KNlx3F2T9VVC}J4qdfW;~2to4>YyGdxp5C^Ax;{^-Xe@ zw$R|fM);l9%E7&0skM7Y-!u!(=xmz)l~ExxSw2tx#XAZt%@CM!+&t(asajgG3rPAC zynn!ciB-+x^9p(Idui(T)qa4}6N!f#4YHu6_Jtk*UIvG>4H_Jl!R+L*(uy`2o=sAX zju!0;{V}jB#_&VqOK(tA0BCZ@uh4>4;GaT6H+r94QFVA1X5sZQBuckmzp$$nK!nWgjSm=6nu2s95kh{eiM*c1L3Ep7r`7&+ERT zi$loBy7ep3A@31f(dhD)_3?|0S1nhf$Q7-n?C-l9@*#2Ca&Fzl7yfmwE$3>4EIHSy z&Mp6ZPQZtw8gWpwp4B%5Rqgab`>M=40xKHW(5yrJdPr}Y$+n$-LY4T;nxQHEVPeHL zTV&-<)fq}_!a7s4P-i0j7UyQf#-^H1w#?bt3?K0?kZVv%6eZMXchQEeWRE{ zAE6*v16a=P;aTN0BZrgUVTk3#p6N$32TQc0w%8nnV|5fz%nrdfx+2pHF_gGAn9w@3 zSxQP+(luH%L$=@t_@di}xN=ExyTM@49mMgaD4?y+Zv#M2f~`Y@}jMML&Fy?J%K z1b5l;3Uuid1XUdq`SYKxE}>2c+R5(W#`xUuY&x(CCnM8DO%GEnKjI{#XLf3nHc?zN zfoFQAY0*idUvS@1PA?SZTsix(#XFGpqUPPOjni)#qOThB3^QUfoya@ZjTvJr&+*U^ zSAuQga6k*dmc^%ifnS>A*u~c7Iq%05c2*|Rb{V&PaKa6M-6XLEb!x!^uq2SBN9X`w zGJ9lHGz%o$oW7D+)Ac7Vo9OXE06>El$ZWxOIragDEz`{_S$ScT(JMc;NGuPOgBXJv( zPqE3SBsk__~@(;OG5`O!%3*#P!h#vrbD-M>GIdy%fs86riV&z#~P7ln%P zv-1gh#Bf@N-Ww^eb;#M!@@_P6<7U^|GtCkA_rEZ9Y3`Kb@=!f;u22Ik{s%Xwe+yf(!rq0z5nT|V1HXmC za;>~C_LQR4wJv2ke*B>+{J_^m*ZKKJxa;R1)D*%w{_^i{=C9@zSg>69@6KNh=$RsO zWoAeC!W5#tIAeJo#w&`ONCcVu{_`LwkMg>4ojp5-_uRDBn0f7RMjj&!8$G^%Zr%5v zXXlO8lHAJcmLAS+`~78oUJ_Gi)t8H`&VAko6TREIjrVv)p5y(pzMebHh1}ce48wZ1 zSk#wcaO>_O_w*i|I@6WcJ;n3r796Jci~Ct~IUh7O?)192e?O5od0LCrHzUS^V`6ab z+*agDF6R)9z08*Psb+#u5^st$=Z%f{ZaUbHzvd%j< zl85jIdh<%03QYjpTb{x6Lr^q%fo z;cn+$QMzq4V$R6rnjRO8&Tf8s83vQtAzlpL77zMD-tf!%Ko-P=WEv-`C(mIyXrjQ~ zFCP}6eMRjF)|0NYJTLjzZw!vJfLvfgy0iPqDJT{QUA5kfA}As89ty{@jtfK2zfxynUeU!p8OR) zi;8O=u*a<&FDswvuE1v^tXz}%xYKC+glUX5zn9Oj6+REI)p7Z3VUcoe5C1dACR6yj zbkvj2IiDpPL-o9y>}pt4pZR~s_e8HFq&)#kO<25-(}VFvcZ{X~B`YMV!o7=O^=^YU z$&4E0e>gv~(WgHrI#@ajQ*4o@kf!xH&T6`RAfCDJ{zeVj@NUTbtm@+ktIr@)0j#m$8lVsJT!v>@YGl2Xxy zzsFT;UGcYeb zK*Xn?A=(aH6i=qf#|Q%S46AsBD{>v^+n=IiNCd^gM=Kqfu*zGarK#U>Bwi1IgAfa-yP zx?uDiG@0~mu;#Kp5$W9G$3a3xl`k=6qAX|-k`5xtpCwZBOgCvbkNESWN3O>@MPC-Z zs%Y+5MGhFx@4B7b4?b(@9r-r+1TNmq_u#vd*Fsk*-*HR0rXEq*LN5-+M?VbK&RA|e z?fOIefpJ;{@2{1=GbL|zFY}kIt9RL5WWA1wK@V`8QX-=ANp{M;+0H!_=--no9;V#Xq3mC6~jHj~(M zuf#&8M}uD_#fq5oyU4ww&ra#XwFt6$c&0u6=SQdZAC77G>inouTma zJ!UP;OR~r3P>qpL39so7p#*lYdnHysrdJGBTtwbt)uh~8&>^1p$$WeORczVvi^L=BsNG!=!qXq(Rwo*^bB|^fStv#DNqel~NDri^I+}^gtjLF3b2+Zs5 zKF4jq2(CCF2ztO7h?}byuABAochNgnzVC?hggv%`GhwTY^B4UlbBPtXmkew19YO$6 zOq|Sd#XoH|gp#2XdNuwa0=t{3H)fES5!D9 zYhpI`(UbhjA|WpvuiGMUr0=I|;S&|ZLaf?xV(C}@DcBQxT;~?PBGyOfTS2zk)J1O4 z0cGGyE-RB^IeX@z3@CG`a+pyCOb89($xeug(1c1iw^&!c+a}*3-NCi#do-mNQB;@Q zs2%Qoo9u{gvNxGWa`QAs`~{iN{!IAwh73M(Q}^_@gjZ4pLB1s44xAVHboTeeQ>(qJ zmXyS(;r@spj^~e$Tnx-u!WVNF6a!|6>+7WP@xkIBq9U!Wy@gTth>G%cUY zgkjP^Zb}cd%E#{P#d}8YIUS-x;O$f=3%)QE=EKCY&CcQ@VeY3r{6ZGiGW@44+P={d zJc539G!|nn79ua9>kBSFf83(D$l329g+nOjE3~~NHx-w78{bOIS|rlBX00DtbBK|A zPSzZ4ra;@H|8b(G>vE_(cRRuSWp_e})?D~7fgkBljY6uSYjREc*}U-0HhK`Tf)Su@6Um)vqHDWrC4PMlO8 ztXFv4pife_AY@{l=1NL-l@pl5C-l5ujJXHnA<3uWu_Y;`@-FzRw0^(jWbecW^fdI} zN-(46!A$Vw1s_dO1rE>T8k6FblJlVamR%N--$1QF8~-`~sy2?4uw{h>`zmZ{dNz|O zMOg{NFOrR*w!dc@*)nd(ezK;Cb{RQMZv(oddGzTH7n-h9%uPA$wjv{ynMC9+qI1u& z?rq`qgCbd*J-NEXsc=nc>zOP+K3jYy?&*%{JCul+k=&7NidV>-@W(I(>m%CC#Ck`+ ziI6>8TqFgGHLH)GDu;-0X(K1`xxFfUyzFSmma?c|Ne-6h4`Up0!9?Qj#E~U|-g2)` zcs0B5#CR#G=W-1Cw^%s^Z-QKj{Xl*;y=HY*GXyBfBO*yz#VvC09lp!I@O%H{i2-$x zJn~-n`~2FjNd8aW#UYMOMbdhKFa8eeUh;RMKDjJ0#E%ht7X2vSkjzNFQoo0XeM)RZ zy9xRd@h!qyjP)zGT^BeIqfcrD3^Rq_b{{??qjZlv5+3&>-xl2;@({zr5xMII;`_vY z+(PBFUQ^AfHNanaLEaswzIMSFA_9-x=Q=xsu8EQ8(SDDya8_{ z-Ef^FdXYq_5cpkWP)nB=LS$xP5r5=8`aQx^e4iqQY>{<6b!5z&1A60tJ^h9};`ij- zQe(wE>eM_9;6)sB47qx1v_o^8IqTeeR0)dfu&7zuan3BI_(iVHsvIDZ- zsi{Dp7X8kDNIh}%VUj17a0+Tuc&Sic*r+?Wy6=^8otJw0_VM~5+J%@h0d)>4yXk9~ zki(IhqwBJJaPMhqff(E{MVq%UF7Zuga9|hxOzL-(M{X?*QZurPu$V1k;M{R8b{;fp zJH;nMcUeo?q$+)u`W?ndUOXvZ?>t&EHgzESsKc0sHWrG|$XLNdoHzMAke+(P35eH( zeu8F-Umgj|p0L`1==|1xI+6TcuMNF2FU%+ZTchz@T*k6CyqTvS*k`$XPb?YRmq$>h2&WJj1zwem+#q;ST#g>+5efyM0CpB5lg3+ii|PHXbC8X zCvxf7FGR-4v&H_=#(z$*LXinm_HjY(Y%spU*YGB3UAA|`{J0@V8=phJG!8| z6XxQx8?w&PMPf{^s88g@NLeudM!6hIM32L!R#>~_Sn-@mn~CIKa=;{0$ScWT4=)?J zmsjAG97`&pq35x_V)(5VKR>*AdQiNYcPYuh~m-2GU4x*HC8zV$HkaBp2#LBO>%= zt?~A8+XaH3?%m`FjvEJb@CGB}nr$!7Mh`TO-LvS8J9wKFS~%nN{jdd3f{H${I>qL< zka1HM9~FUGI9LS%5Zxx`zG$-#uY34Owt@&la(MhX0Y6%Ue{gI_uA@&3bdC5$3PUs8 z_0Sx3u(uL%iZH9>T0*h#O8HKP8!>PE1K3Z98N<7xW7QM^7I@RZCJ)m;>J+EUp3ogA z2EvF`fx@IO7{*_Do4qub&*)=EiOrs=q2RoOtnp_AAi2|!@9iGv`c8p-(QQ$*#ODVS zqsL%%tU&h@e!y-Rr|9}Dyhp!SrWT5##?7qsPG}F_M=mbbZ4O69p3#_B75Pr;;Ho~{)%W)i{KW7(YL&>3B6m1KR~`RpVvrb$)DwWv2L%h8ze?e8V=`n zH+dqUxFkT%(#xQ{3TeA@;)DAU`MAUyenuSrwiOn>9wylD3fuY;*)!#FXpd+>$6E z`LQwiv3y_k(VY@CsWs9Xv0VInL*nk`Se$}YPONe6iXJp6IkC=(XTJT&iFJPE#G>>+y4mZyA?~TJpT2#^tFk!V z7n_=6l}I5kwFnc&bfJqj+6LLF#R$$nCZD))F37t;$5*1{Qgixj`Gdk-88Q13dmBBl z@*@girya)NgpuaNc|%LyQ2;a`GW9LWL%XGdlx!p zQRTxja~}FgL-@XwTkRgzyB~F%-v7nkn{_9SW?Q>I$QNK6+xSwidO&~#qPwG&KwwN@ zTfnbB&y*uFBeSx0?R~~R7vC9QjLI^NQ1dpcoolW#9xxQxEad*_oXT>XvgjQv!l*=( zWBidhArPVbHF#`-MFi4`<=W_yE8*%JR!WSe>Bi3J|l=j7p+eXs(3;;`tnCf*y*MPP4^Px^^i{EEOi5b=RsBHJwO zr!l!<77X};!HmcerW`m|zye=M5YB5g){}AYEr=U%*zAYp_@>A@^#ps^-h+L~g?BU$ zvimV+eCn8HL*f;2u_h`Guo-`M$vsfEoRy3Mp2{Tu5`IgrEEOeR^uY$l5$6r0HvQSL z3tkaGLG3AcEyIC3TaQRm)rhp47j(cmper$J`JE&dPh^6)bBMrxX6#PJ+3si!eYy`@ zhCjdxKp8?Kfy_agq%)%I^rH5zj9Y-7P!1ntd(pxIE<3OyVD}7(m7h6xN!+Q&svODu zb?LN@%mqPEA72jKH}P!NCpB0t?qyMcFN8}#z8SMd{kfTNOpcv^#{n*5-&0gdwacae zzaBj#h$kBBYV2rtty9n|&gIo}g@W`jNuGX-zC++y4lheu{|9yY&elx?cSz}4T# zDcrDnG9QgW@~TFh`Rs0S3>|j|!Wnr?S823?*|q5A%N}u%E9?OPv9+fO7_0Bx0%!Fy zDqhiX)+!jgDkB6j3FbytWKIPeD+W|9IP8e%%!;3b?&P4s(7J!cYvB6%a%XKv;8y*l z7}tNpz2bjyB)C_s`9E;4HL$sp=^6~{wuT`DhK@W5i-%gx@DuMGoVem_FJ+uR1<1?m zxme1-du$UM&Tqb%VdM~ZxtzzfjA$!kM!=d`uaCcQZS|faQ)_f-j(NDqwMTTrMzFZ# zxy~onQ?zqIeCjz0`^-LIz3}_!3cUn639Q(zm*8ju?63ab3FnHz!*Z!`t_e8ZWV&Ep zG9(M*8qb7nvYa8iuw9b3nt+%bsLqCw22opKUdO_`5>w*`ilqx%Ow__jLaMCmM9QeP zA#)?MM>!(l!ao@aGw1)P2;Trxv(g_y8%HTwliwjfh97ZaEg&ff@PkB01~aCU3vp%6 z{=i;DhpN*#Woz?j>AhP!oi0o!S6FLy6y>d`+{4f20&LVHvXMu$T5XhN?qk1Hb;+L> zgI~7GhQ67W?S0ld=FpiCIPA^C>Ti;VGe1GLkWxBY`&iF~W6App^O}5PURl?t8MjEV z5`i8}(7I$QpYf|H!5r5RlN^Cv19BeBtJiFhu0U6Cyex$N1NQ|#39^-VGBMsk7_Fmn z8BXMd;YT@G^~ek=Cne(mk67Eh{Z>;&^0-yy9k`$q+u&c91pUCjPQ09KksTBPHbSU8 z7DFsPk8Qh;K&|l*%eTCRjg7is@8XA1PrT(C`9~65T2m6P3wstWy!-w@1a++Zr~E7J10ZmX8Yd&cLBaU)dE5g!+ym>l5(C`lO#&$hs$GXa&TxI>GpzFQWP zVVPFnm0y}3@I2XNtU%q#aVdBvDS z;L)i>vf?oe8b}n<@;hr1l*G2|PF71Q?CVq+83N^CUjZHOg>fx5_)kCcZxn2DY9a&h zLTrd8_Q_RUsb#k&UK12plBKd&=kTJSXRR^!5t;(rZeu9TfSY~uF06udxj_=jz8Nil zFzV==3Il7Xm#1V&7+A>${?CgKME%%7;II~CIe%eb@ryy8i;E-SmhAA|=jp;Na*=_F zl>P^?GEO&G1SHz5DRy`2c^g2LNxUL9F(8}h{ldU9{cG|sf=hVh>u_WCNF_V7SRcq% zueUBhc;jP!79?0k@l27CyFBjPC^pw5!ENvp2Wu&JVkc{X=TNTkpoR_AIu?fK7 zaZ}~o!NguzgXUBDpTj8`|N0Jx%-vuK>O1k!9X^TP|NLEpCA&G-xVZV!Js+-robwBR z(At#W-_7*h!yo7T5?{)u&Uf}bstA$LAP*27lPCL&x?VIxg6stv0oL_Pb!pUI;maEA z#v59^LEI{TBd0S3Wq2go4+6nX3w)=tm86@aeU7~~{>-{Xjh*ZViUM(lTZ7kz_guaP7HfxOKysb+{AOLG*(=fFvY%D*HFy>VE{FVS9b-~9t z`5Q%`WE{9pL;L@>SDts*76wxT;7I-svG?xJpzp#Gh=ppdNmEU50AOK5B zSP^_yeh`QE8LdjLKuCyxw6=y}e4cV#k&$Z)viI1tk0YJ4T^`$fo+wBXQqj(x6BlOf zbAJ3}HeCl zR%DBP0QGLR8?-S?dD_m8GpUK{ zoG17i&I9~zhBes_fIVBF2Ps&0}DHbYyxEswhHxeU=lHzmfnf5 z0t0ZGyGvAw7VXkUKxyv04lxIGLjQ!GWzGKij~&SD^RHri`2IXRZ+;zsM~amA_3r0i z`4s>9FRT1Teco=>lJr8TjU9$bi_$w_T-_<;&%NX^sbbayrsFZFu&^vFN@dt`x*7f= zTQrgX7YqA;v9SLa3;X{EENuTzEG&k`fXXR^ayGG2xLr$hEf&|N5SA4KoCtMG@ly6|4lAe$Si+y zxf0r3NNC2^0WxKz6-K!&(+z=%X*mpS8^Vb%eCcopYNu_ca}1vNi>enqe2fB~laO$@ zMVTVAmy+awb%+rI!O}ljobBlA8|!**V+wQJu`pOj-<*2!1$$O`%+TBse?rF{$t@Vz z$KsIWE}o-D{k(8zK*#w0LgprXE?o)?vAA+oxS=Lu>+Qe7!BVIA?{Kh>Cdgt?Huqcu z;!a$D49WZ+8GqTM*W=07>Re3iuNc^f3^|9RyK&0n3;2l$JMS9MB4 ziGgjyXXBR90#;MYe`4j>3p{Vmfa^z{F)PM8q1TuL;+{&eJqY(XHMMIqD2Wa4K*C(m zDPT$&WY(IUuBbQ+Szodj$;&bh0k}(ip$DReNlE0YZ)PUta4K|IXT+-0xOJ|F88K4M z?cq)Bn8b^6=bw`!%a2Gt=w)+ptqcF;PV`Sst2hXk1{l(2nQh3diE4HuCtwz{_i&c+F z+?70VOs*1JP+Y9IZGYopJqkx%oNaR0aM~~mH*$n~Mo3`EI(i#x$r^yd7bcN?Lp<=!KSmoq~jVqu?s(=p|g2@(MSfcUZ+lqm54f`|T7|0j6T*z{qrxt5_I@+YvD_7p zJ6d%<(|9z%m*S;+F?rEO3xWb`xdb~GdQV<_8Ca(?r9ow7kupZ$HO1@oq&RBgRQZip zroZUh3H2s+nnFkiPnzuiWz6R`@@-U^Ec zil2D_Z#eR%$2KdS6djoJLYVUq``#URRoqUMM!tL867u-3xmUH0i17S3xL18~_qNSC zA>U>{ZV@jA|B&T$hT{`HR<=6&>F>O&v45dnkCldz%ve0Ln$JDI#|qB3+&rfFbWz~< z!xc)*9jW8Tte?c$vdQLT+f3&XkFk0#S?|jQbsigXz(f(2%Ya0d<({}F%HpYV#%E`v zSKO<_0grwE0%WB`h{Mqfe$DCF;$#G4!L5#CI9ZdD2vCUJAXNdkM~Dx}dFFuWz(ycw zwlDq;S}x}We3?mC$d@H|=h<^IJr$>BNxUXhV2#>XIjJczBSsWFXJ8eI$Bc#L={&^Z z_u-rXADuVyMqqRvg^zUECrxeW_{O^Rl>42qPl*fIXM6-96IrOH}jDH6NB|L?Z7p zC7wbskj2KHb29RM4)c>s)~d?tn_Kl%ajQ-@l2rk$@eoolcIg!4Hl_%1_3Sv(`u61a z4}R!goX&GB-e}bhm)1e^uk{W%r=D_t7hH#B$SI9wgoW#n{90sn;g{w!4-8}*S%iko z_{}*jtYxHJfzm(=vOpZ;n|rRkww!JTw<>ygPGo4A!>>A!O!;SiRq`Ca@vH7v|CV3% zi_EtDdR29ik-Z1MD#>ebxZFRHBJ%I#?pZNVB_dms9oNTb4ZU2NoZL)_qnw{vNg-j1 z>CUAC0lS4I{}!9cL^U{!K@2x75%*u*s^$yH6~^#wBeRk3jN}vyfxo%&*hJ(rLVZ3Y=>m#rh<9~E!Lm*D+SExfN(Y5%Vm#%2#7EsG@1v+avFSG|3E@_3UD+O7 zy;Dc9<=~ak*+>!K>C~#zWC0G)FL?5;`rw{Sqj3JJ5+jE?rvBxv!RJ?4*@FvfX`7xJi0JynV{!|>C1GSWhE%*uj978ps_RNZZ zTr(Pp{Eb^xc<%Z)&hrm$Ri5zg+^PzR5tOR9RZo}f4k_ZZeM<`G@OOO17*tI7xa-># z&Gc!d$bz>n#T|DUviGOS?gp3r7LNJ2d^!Ed5G)GIhw1o($Ad%G+oK(CVYs}E;O=9w z!#hAO-)B`y=DnD_fH3a=OI}q2ua|3$KS|HQ#0^IY+L2J#|M|SCBM<+LP86(2isQup zYRtBw7_q9HaWygKDFOKpa53tpM~VJ%of8AjHfsp zQBcbJ_`z-pzJjG-I&Jcyv&ugBonLi;oQGc(Sy9Kj0((!r&=D_{njyIOg~yD=vFazl z_o_HJt9;AG!`T7!DJ$Ld9x|}Jc)sFRl}&k3p=nc`_*d$5kTMu4{3fWp668$`@x!g4 z28#ssO??+Opocd~{(?P+E%1w5wbh5u$>1YZoT{)pI0tlt`0B8Yyc)HQ$+d^G z{2Qn0@lQ@wQc{*JPSr-&#ZOUP>qlH+Gku4?4HN|;2F0mLtp?&A?|>(~n#Mq!s;=4P;#8fpL&07!&1pOUReF!n^d~mcCdUcR@+UTPTwycevCQ%Zhy3z(6+1(g z&;+wi_6=UI`-@){4oF@H)0h1-w`t9v7=3Y6wMq5IaR-%HvyKh;Fj~~zm5e=gH`|Ul z>qp>A9KQjME zZ=5Xvd;G#$ZYr!LYsoKvVF&KMS$H!T-D98ui)nqY1~46Yp7#cUVew zTe!>aH-D+d`5$qY8dtkyTrIGeS+T}ccJw z@4qmZIQrihOnih2gBc5Rsu~DN6(Hg$n6C)TC5Rqa`76xjo0G3Nq?A8DfDaU2V2Z>c zaRSUOPO&j|4%-cMAaOzR)Y!D~my;8mv(9U0QgM``Nml&k%$kg~F0uREamCO2B5qu! z9<;z;u$$Z0K@+r?p_d1;858FZ{3a0!=63uOzqx~)1wKe?CwWCLf$eby7qrhl9tzwL zhzEI#Z-DbNRNX9u_ZWRC2_vgR4Ky;M=0B4mFCW;iukv$j)|h0+%qyvHG|g|{2YkJD zS7@3YzVm-#GQ0QlpB5!W7{ zVP3Qb?q?e`cwFAACPCFNvpQ8*)_4?5gIa|-MgACUvusV0cIs;+$Q#S%4%=Wf`8z{n zU`{fYcUYlX@` z{?1oQJb_x57WQMsR~r1zR~mPbF-;J5v6r-ouek7X3KzWf;L$+dv$C*}EfiMWQB4JU zby{IG%ecNk6E3O!mRT>H<|_gOo8w7geNCro z)`DonCwHyS(o22qomz0cALDs-sM8@AJ`I}=P@o)^L+X#p$HG6`%?IM5V8EK&X(o>0 z8JL~qKunXFF(2G=ozPD^*)3lPAQTJ}F%`7v7r4tM^>*M{s15W~Q}iB$zRAG#iTSEM z0=7V1n;n80Bv`ji?yRpSg<7)k3LM~UIAbA8FC>o*rzW0tcf4Pd+-2T#)f&N*N^Aj7 zNj$IihoP2+k|@Z3V57!PS8Yv);Kells=*Bj&d)5nJRh(EIA>ISLJal_OH4%98*~LZ zM)$yLR>+!q1??!QAHrLCpEc4--zOoSGBpg$uDLBKdF7b#mSdR`NFAXr3{tb2wt|LR zgOP-A>eYom)yX5qz)bkut6kU&-hf@VR9`0_VX#Xdso{~g4C>libO2zP9co=t`uSaH@)4=ctu5=_s5yBn-KdI3&M zR%S`gmzux`K3^=p12P#`jJ=+8M00%En23!N+6>+nb_Y0~>LETHT967ZxK*d-W6&Z9 zVNO)5G zP$al)HSY*fl3O{_yKs9#8|b+v2Z6Y7ynD z{J@~_F2(1g{|a4>{@_%lS5M#U9QF;d8XIgWHVD{Kj$J*GneVyy@Wg6|HnuOu6}%e6 z#k_E=_bg^q8k@pfPV^p~e26c)5S-bUMH_!~>k{X0fzAo8@TsVjSG_vf%7+h%h1kAj_W(+~65EJwCV#CmGOUkp9#v{RxTttk=QtU^^Qa=f z;89Jz7S^tGB1}r(i|lFT-9}w zmk87251#+{3{HMI!``9j>S+oNGj$VCHR*N>YXVoyU6*-==bsy4|7@6LK!$LrvR`0n zCM;3mNxRnO9kI!DACX2YqD~~f`e#Wd;JY=ZiU0LV!@2A9w^JNEZx^i}8M!GG7+m>^Z8q@6+&P(esaLa6n*))~I2n)iJiVOs9 zwm$J&ufldJj%~j?aBs7UtnNHaK)gUKex$UHUSr%tU*zs$$sO;BMQ**ad|*`6u<(>UN_bQBOV+2W_Z{J+9}kW zhaMeY;B(HXAwqN_=aMP9)a@Jfin>qH9mU8XOL)?db;DdnJD(8VS{P!F-pq`@i!)lpI^`kE;wUqdmtP(!iy>| zMTQV4Yb%>>XgFv`15YOwb|RzaCn8O)JOH9&5b;A~O|h zaOB?!_e??s8^1L?xX&@MJL(FaRSQRt4;Ie`9|;5I58UYcftq?_pLg^Tr8_Nq3GDHS z>$T?vVbJ(T_{=Hkse#U{@XGB)0p6R>U+^!?9kmPur<=;rby@_gZM9MJoBFP%Efx8% z`tEIJ&v8Vr{eQ)Z5+aSlxTSd)&cWT!tNuV|b zn<`t!E&4c7IpmUQAQo>^f9=$?LMb?0aTOI<81biKO*+JhXDEF3EA(4&61U+r>Be*N z3}<4S={>c_{uIlA*E_`r3}ORYR!QoSR%;6gCEsILmhd~Usm&63_?ofyv~J0T((4)< z!-WeS+cFd%Dlx8pxa|_bS&Ae+Uf?8vjwBwR<}a{ByH_|>Z-}j%a`BI^1M$CDix}-h zoEff9Ro1YA9pzWW#64JKFORG>Ub!U;iFiAmMca0^X7C6?#GBG3S-#BT8gh;%BC^*( zoD%iIq`_|t6K+zyMu_V{PfX_kJ4WYTnu;|@e~vhga{TSkw&@eZ`(XA)T#$~)CQV60ta4x5zdWb(x7 zrH3k5M24995O3IY8@*HMQKJb4v*P#b48U~;RPAcNA7XJWTK|c26?y*o4;WLo;TA42 zdi&m)p=0b9N-`L3A{6vkFzSZflk&4V?r`ao4jT?Jk`yP;>D~Z0Cv9ud9&IElGxj`&aF8$2Yf$a!5Hhl+BcVX_e?KtcUbVK-35?w z?oAt&qfjZi)VxIv0ZQ zeQ2h@$G`bj)ob-X@vZi=w9jw<#<$vn1dI}tzw@m+Z;Bpq<^Lwmlsn_x?0^1+STsNX zt>Iq(`A^zkM)Y6)!PWDx|46C-`tbPn>J4eBMab+XE5SliSe=Hbu91WtQp5Z;-kgvq zF(@-OR0nVm3Qp;-o|ij&y3KCN zoQPF$`8iyqUa^Ejv%7!Yul><(HA(jHvHqA%T3oCAM5y2izQSHK`+@^DhcB~W zZ9h!MtMlRkzE6b1+e7zj@z$+RxwnYwnW%hq$m4CIG%~$!?$sBEb?@vlSr? zE#1QXxtYCjjxjAf-&*?xi_{-$)N8Px*ZVJ)_h}bL%Yo{_xgtk5X=if(-s?L1%^-bB zlToV|r?h`z* z`^9ip%qB@bAD5RX=~z1*ljUx;_A&8?kFDI!@qlP zhSB{i>+#-;hTWTs#IIT74ScfuC0wxG+p?tN-7a%h`*pyxK8`-{q&CsVGNjWXvgU-G z(m&8!+FxQ5Bi*CG3q~o~2lkI-7dl;kVc6Ez8rE3?ez^qGiT7H)cBf=*t_a%`t~#&# zs3N!%^xDQBAiM&ngG~0}l;hjnr(U0{*Gf>SZa|$N=?H@mbnglOi+@_LpLvZ!#NErA zs-eRxiHK(WpJ7ci1EvG4X~u>B1J-mRtSNfMM`w^7Qh1b%sc}c(-u~A) zS7Ra~{{`o2k74_taIQ+z{N!B4;Q1@(D!fkY;&0AXIOWAf1Lta^!j!IelAI5eZDbSx z;Y|`f!3`Dmg^U-Wxj#8qqyNmgioCV*2r^#SL$(#(6h~k=V(1}@rsHRc?*mnu67_^l z&^dP$iLDstXojWbdblu6>?yD+nXR-jJt+|n zF&z8@{&pf@4d~8<`Vs;j=Q^u+feml9C}-9Re_^dATG{rANC}V1CB(-`h7wMMpw=lj zRM=IVr{eykClj}cifJW6Day~KmKepS7&kE^@%8Neplm-iBHY3Fg9KgHX~# z%!a~p1Vq!A%LfT5>=)VZ#G1%q#H=*>20z%Wy%{dIQjtV4d7<_4@!)TH=D3QH;t;{H zY8p^Jvj&J=hE$Vt&{+87UmUAoWh#!<|xH2=Jky&!;d@!of?v5)=Uon4YXXqC8fE8d7y58gqlS3jw#PCKO8-Z%*}w z&ILOiizA+TPw@8OdlVJ6^vErMV+BL{&eA!nKi6JpR!&1~+E8x{5uCazQPGqp^b>>X zfiIKS3**l<@yVn?0J-Z2EtlMIix{mQ0`v^K(4?Aw+!t%y6pF*d9%I47gWt7U ztr@?oy)gxJni?=h7kY)VeEnSB2iFFv%}pz&vUDdb=f)AQx25-U@Th=&P+Inf`g3vq zdO{uvr*Y615I2Fx&odLafBWDT;dRWm(0!D9ZJf9p z9w4Gcny|GxmPkI{%xG(oi*~onx7H=Rn-+OpB1a)(GZ$=H+^Ie+<1GZDyxbjYlfDax z41&cdNY`cqCd(N@4nP$5M-NPE(C#bV(*i4}xnZN6F|Uv+89W}YYE^Ny!yy4$6V#Ye z!q8kV3CMhl+c0Dvkon9-NN%4#;Pi&_Y|`5U#)Zq#o0urPxo_l6T-0Kz9<8B2rFzXB zDPX1|MQ_V}O!$b9$2w7uZ;N;dwdTk~e?S}szP@E1LZN{b@fm-CPhgan@QSiz$c14G zjNR&4mzu>XV`u=0P=Szv9gehPU0F)@oKa@(+xe9Gb4zhqLWT6l>2tvc^>>N-j8&(f z$0@$JgSU*N*kouZ#6S4gE4>H_?h$bd#F3|U*kZ4>!Q-*c*vrvK@4cMy8!vonmw6s; zX>eCK%FWOtMHm}mX!LK7RQro8O~>vDmDCP<6~;7g!aG3GEh7&WxleDDZ=O`}x6r&} zaUm1oOa)^APE{Py`>UV3bM(}ce>pDr`y80_lD|I!1D@K%^aGUaGQk==IYeBcu(;u+ z+-aGiu4Y_^{$}b=fU^fylvCkhC~rX?8~YOWff~n-{ZGr@qKkyD^)8tvvo%YI%uVthiVf0a#X7c#_MD1h(Q&3!c;V@_B3? zhTwSBw_Un|hd@^NWY}neiO>@6l)COq1U7YgVSh5VOS=82FX@pw$dn=kk&y_4x)7hL z@Y1Dp3fKUyU-7H{Smr18Gj+mqCKLWhn74?g$3PmXZ!h2R>(AgRDC%Lk;AOovGK9&+ zC;wOInmCGS%otfVv%+45ZMX@(CPF$`wy)?c7BcZ$7UDAXkUf6P!j+0y4r^tBSv?@x>6=G>a)x%~Kuje)@~p=NLSXC>+7l)htb+PrS)CxH{sj4_eMRD- zZIH4SN^4SZ{4$&4CxD}nuxYI!%CoUrNwPXva_Uc8DJ^4$41G+LEV|NhwAZkfZ06_* zevc`_kMYY^U^e+%2e%lXg5;6H)nX7NY)XgT!f+w7^oVDt8>AcC<(os5W~qPSP@RA} zVhm?QTpa97Fz3id6OGBPFJE670}?;1qZDj~u|JvNi07>8`_FS3WoRa+mB5Nut&`~J zVJE;c2{u_Rld+F3sm8ZYpXWxldcc_v7pl)Z&A{0L0s{YvQ(mOtn(E7ueS`x+v zWCR?DU|U>B74DfDPkek`XN)|JpOC-oU9y^Cc4;ipz$^#;TV{cJQk9s0b}$~ka*tnF zOe`1H$%H67yN8;0oW2=r&OH;)CEyCSRz*Aua7kD^Ww)?CM%a68r{YCli1!!V8F~*w zFe_`em`y)8FHn#v7}~0~gnjNF;CK~=S@>)|Z>mm(f5O%_*d5cVj*a=B!21IBf%73> zUABw+*#Xb8`o*ORrdwLcdRS)-bh=A#d8`Dw^@2HK9L^gucLp+4u#mKG5>Ao5ZiZlS zm5dq|iJffGgXcN2eorKqyV+3-A^z5dct-mUKY$B$W3$lMeiw2YBj*p^Q(w7H?5t+T zL+3tfbf?gAxea4^Wo06kx;|DCzwi+t!WXbt5Xzkb&z71~#2(a_oEOyjxq`_ZG}L z&jNN8xj_@cPh_ZK0X%bf2*IQ+hhzA>-Sw~}wVwh5^4v3`Fb9WA611^JS(}s(h*^A* z9Qpc%tz`p%r}fZ@l~80ak#dISgS|ym8t(@e0+O`>`WgS94`p9qUm=jcIZB`CWeskg zUgAu#FI;Ob-5mwk)~1V4$9{Ar!#smbCnXCO7QLB9#1fQUU`-Yk9(TJg1M++o23N-i zaPN_^-4JJ2uXk2XgPDn@^9L5n6$j(eBm$yxemEA=juEX%1b!0vj%WeqgjMxKo)2G@dtr}})Ak{hqZPPq%ljhV9QWBHSxdZPy(~g1^e}3L3D;TUPE`i`5bWHByb_FP>D7RlI5bBS$JH{ToLrajHK!Qt_FHwTsjCcb?Q9axomD zM{O6k5s#f-iQjl!u=__lX(PPhV2e51(mHTW(>A1wx4)PG2d9B-bK%(`j(!yKn5R>|G-rj(x(lvDkPWm@fP~$}NE&%B2^0?q-1hfO3PIYP#U&$BJ~qZqb0{NOzOYI09@87~|XMt=30&v#QT=mm>_0l{W$Ajb2i zZ0sg1zr2>d|HY9CR(9~ez=O(j|AhxtQ2_HoSonx9L^Pso=$Wx3j;4c6^UvI;l)6{E zr%$Q&^|}8DmyP>HFLc%v5*@x#*4c|docyhLPw}r3U-dTlJM1%c4s)Ds@QHcl1hAX4 zGk?R(8o%N*H6?twa>ce3M6bQe5^1*Uyksza8tj7^-?tU386#1_QnOb;=?Z= zni-0bpL^C1G$%ZyX%N{Tg9l_%V(ovxSDLxoV9UJ?TDCL7cEfWBVNP`YrI}4v*``K_ z3VR!o(>U{Fsv&HX(bOKxj&T=P8w<0F#|0Gre*$l#XL=hPNYgH5@4#8*h`w~G03<*i z^ji;kHKaFX8PQKv()Ksr(q3Hgma+@QTdHM>FOB{LB}3kr@@{aRAS-^y#@c^kV_p05 zJ2sY(ScdIGg!~&HTM8dr$TU~@$SHjYT-`Jq9S`)`<+Y*!yh$efTT1Z#+S!&WjK`ak^ zSK}QZR_QrJ(nI`C+bvezri|5J*jVeogN?0YrNDIs@*k!qq+5iJonXt0&va_P{^T?D zV-KV&MOsFDrV;{-<`OJ~!R9ToBe+9CU*PZJl5r%c;=ljmEyYHucuR?oRlKF@;X!M2 zPCz`3J;R+-`FSOOPh1tm;5Sy4^{@8>FwIl&JjLe48ixrjNRy=lnjZv&^#aD3M5;BP zuCZl48C=J_5$p7*6+MDe8?LiZv|H>ffDT|Qjc~@-P)~}w@+gT>q{RnNAC#@UP$Yim z2Avvps#pQCCzdnCx2j{~)&@E2(Q6h^cyg^`V>zO*iW!qr)?sUI8visdr(72~33k)E zr^f?hJdsHE`44D}07Kv;FCo2%xqsv|&M|eV3u0qzC0fG0f*}moo6L1HEbs0T|?RQ1OgrUTA2+D zs+cE?3HzP3qn-jBm*DFWUy)qEKVS@L5?aSBCUydY$P~;cTAH|#8Q`!KTl%EMHjH9; znP8fRnSXGLje)OK&y9+vR`b%Dlr#*ft5g!;*gelsuINpLy@x%9f>b@Zj0c{Dihkn6 zIEP$+5T3|U&{v$%!OAMOcwX3|C(~Y8^X?X!)>rJnz`AT=Sw zTZcN8Ysrm>0Kdne_~tpqrx!;fHuR0Ee9+daYNiFXI8JF1_y@-+Y-x{W3diZI)!zPw zLmkYXd=^vK&V-R9M2|$(KKsU@uJsV&#DTyMVIH8a^DE7uG*s?&+$hARo#g832 zqi;N_cWGdU>prHev-76*_6v`Sl0n%1jz@L;f5M~U7YL7v_C$D89H-Qx{)OWdTDI~8 zPaakMh{WGGPJbZBf8#jiyXS0PEODB^qaqmo6_3jD5JnLsFsAkKgm-;pTShorG4^1- zoYn0Qx?m7@oQJgFWXOs64%*=)FrO4W!~ z<=O?QPGlg_aDv7}k68c2t)|v7cfRgQ=MRi(J*sMm|G=m+C;PjQD1Hw`WZD0WQKe_! z%fLy9bAR)j&REP`>kod@o{v8KgWr_c1NzvSor($dveL~~2)8OcGveXQuQ~Z7(*$VD z{uRfmyX=BRVZ~aLLLg)&N=@m8+OBS_TgMS!398 z&}BB~_6hh}^j~L{f*(8(yXepS4bk8(cgn8xU#ss~>24GT#R89=9OP6r0}|Z_lpfj? za=LgFJ2o-!J3d#xx#Y29?>fk)dz7cA-sa#_#b?sbTHpMpE9n+DVT~JCcdZcr1w_VVl~0 z*#ujdDV*#B?NIy;gF0`?DAGCzrx{H^;cpzXCK~7?`uaTc8W-8Jzp%5ADVEv1pP2#O zwTI77euBXPY+{`yo;lHf$zRGXgQ3^F%S6WS0E_zdDUJw+276}l$$S{iHJp7aqJvk`!a;aNSX>n|g+E%@N1dV_t>zW6v zI}nGsS>cne&>Qvf6;dpMTz%F3I)kSc^$;UBLeLJN>n7L{u8HLsXutaOkP=@hV&W%X zX$kfeB$U;r4wqB#IY?UgWK!gOEbJ-uBqh2`bf9QDcdUE@Hr7Ci-aauKSgmzJ<~{O|9LY)%Wa?-1OmLZI&)my> z369CGCv-D$v?{IQzlSXiiN4)3P_Jk=+e;IobK);`r$4c!d+Q0?0x$Fzwshwlh_PWI z>&DK%_)8=4m%6#IGCoOZpwuw(a?TUx3p07wMxMAX1}QwCQqABbWMK@=dfZ{Iw5*w4!YRuSJVkv*=1oa4z|+EkDxOnSz4T`zTd^UgO zyvV^;=b|-(_c*JMQ1tYV!$R(4;iD=4&RzOM6EUWown|8R`2)h)m_Z}P zK1L=HPS=q9{*V2w`{N&W%g--t`g8pJhma&8fBg0K`|m&VmmcvLzW@CP{!;R7P%{zZ z^vMDDQ^4p)PmIurJWPJgM>XdEhw_)sYqN*T*OV>+U|zwd-tH|Y?}Isogt&U;x=xy>a0~>k2C#$Mvv2(F(-5#DEI)~rz zs;UktpRmYlc>Hi*v0f=xJv)T>BJ1_hbZhta{VVv%_-Rt#_eSn$Ff6=ia&xcK+w^HV z4#9!C*nY&m!3wp;VJRHvvTu6R9j#`n-`WGI8(ys!k*r<=uNZ! z{TaGW*i(ETK70KB4Ck89dc=ErZ_OL`iZx{wH~Ku+E@?5L;%_xhaFo+R`7S1KEUz2jogu!OlH0GW;W4oG{l3&WFeO0!ANrOK9=>p zu+i5A>nRzT4oY%EC%D_>;r?ThrWeMVYkSvbHvi9<)NQ?8SMQP#`KEs`63W^M{UbpK zAY<6Sa1Y=q8C)Xw!-NtjZ ztyfUDgftj1#{8c1RHSdjKGZ`-kt9BM$2dcPwzxUyMUceR?+GR-@Mhz(X|wRZnBWCH ziOFyz?v-y#t6zBBhaf>YEhLB}yVe>A8qvDydN-b3J$*4F55m0Fa2}NND$QTTEpirOk?WpSB1=gG9}o<$a2g@g@Ka3qLVTMH2~lIryB-C!~F;U>E-YI zr-U21{}Tht3*v#W%OS1%*g31V)F5L3oyeCls2V%ok&J`Zg8>#nH4F2JoS}LgJQmM8fCmAm2?79mm+~CyqR6dR7*ss7 zWslD$8g8ZN!r{_wnh?I!CE!ojY5G`a7i0vGJ8pzN1-hY>FbP;w_4p-?3o4wbTE&Oj z#%S5W838Z6AQ=r};}{rLbLe4VMGegNmNmqqg$z4}69ufZCd0fQ8BVuFI+ec@54Hnw zp6WBk<`WZ-y6k2{X1>yaw8&E%UyfOJy3&GlFbRdEa3 zijl=srER|^hK(14dx9H zCcbRH>`-fV(shMCR=qkjCEVtQTf}DEh_``AFOny|w&nj!UCszj3IX!U)^OJ|mJKlh zZ^t@=jlia6R~_l7|He0^_75}^48yFT_9PrLbK+`k@#x?DsuML~1SwM@F^FLUf|$eZ z!F;$g=I;ACs+EOr^@&+O|HQzOuh#h2#Ru|>Z#ACK)!8C7we{a0xK{{huo3yb$D)0a z31+V9AxV1=ViReA@EqoN60Z8K!TGASR4a)K@*AWk7@rdZ;m~g-CKt#$o0N&2$1S3k z$lJnt;+Z^yksoO6gy%cQkN}XeEQmwm#iHxM!xE~dpQLnreD-}lYr#E7?)XOwznZZS zUiS3uQ(Jxeo^vL$AHezKtN~(RtVs2~U^HR;!%fvDrv|KtY=kdWd8|-uVJA12{3mb3 z_t5|HI7>L{JTdq2d-9Use5nTqga}5d?oAljlWX1UcZp&PUV;Lkh7~vVRACfjG%G{+ zeUGd&JMBB%V}&;|vivOOTH%f=pDHDCP)`9CxUI#~^18VRvF#CGMPfEf~50m%Rf{n6d0scrZ_Vz%m)2Y+Z@9-sB0Z9 zw0`^&gCTMC+Yx*qB=$?Koy@xo3R_&0G@@nx8KaWeG5Z6-);0B|`iDs+(i^p9gRU}{ z7xwyW3jTBzB#3z=yY;4cFgXgiU>-9if|nbxcSJ*R2v`kV)1xJlcZH{i-^HAh{e!d> zJ{CR+MumR7dL4mVqYm{FwOOA>-{3;;P0(fv?|KNPFY3HY=giVf4euGuR`N{rlbfYg zDf<#6JRwMY&j8LO);dpkd*R+_@w>*2ei@?5*d`PW zqZ@fj3k;;kU2|@_!*yw8Sth=ELeDv;)fA7Y_`Iwh7o0)VgTb@f$dMK&m=SvG_!;pX()HGrsthj1 zecRM-YY#J;U<5%6qPvh6jKkL7f)%G-E+SSrFq#OOr%GY=cWBW^9F2Nwtfd>p1s*LbYfgAI-XUb-ng!H?p6 z^?TqCS;IDQ@(wcR%{M#CJfy&&_o!ZSi|htkGxo3pC=PB6FaU+e_1lPHHva=Q0{75p z9aoDDI6&bwrZQ(j2!c53D?rpE8gTRh6Vz({s3k>T!`B3BpXo68y zzZ2qNxa3!tuz?|Kk&2f8HU`-jY=>9gBOOeELMBSC7@E$`hg1R9sbPjHFr=nH{H!6@ z0w-L>Y^Z$~=f2Lh4O_Bh-6Ky$9@O0PdV{fDz@1Tr{;Ycg;}IXc@bvI^n(K)b0S8^j zKp+)Pn@E25-HyLl`;sI4%a@9FZ*Rvt;!U+P_Br#?ku%LcQavD;a{VZNc{&4$a0-^&x5 z?y;GwPT)#Qn-0F9>x=8RLwyEV(H8hE8|)UD<<^&t z)g=}!xkNQvs&x?n1qYrt!IKUBEm)^-Ubfx3*K{d;tN9SqqGS4uw+|J`UVL6)XnXeV z^Ylq=7dx7=u~^u)!4=O$Jl;0&oMqW4uxRLIZ;GGXv*2=JwEf9Ft9nio4pm(hZ{T^V z{F32uDkXj!vd6WZ(3u17@0hp-GxNm&<-}t^){^MuDL7_ zs`XS%b*#+*J!YVCQZ3Qmz|F$ELZv|Obtt-!f39qFY_J&oihwIbi;yMr`q+@lfRnT7 z&@JpiZ6`OkdAX9=<$)1dl%`3x&AmJ6Z)r43Bl znsvEjHnnkk^BDXKuw))RxQqR)trb)HL%Q)X8@`(dO1@#m37bnA9!p6~x#NDXBW&(E zxeG1qs`R`-zXfIpKSJoeZ!@;@7Vy$!o&fvyOfA$UU7 z(Run#oVvDTrX6H8dggX%IBSok^@S^x--us_Po~cZ;@PmQ5W?bK`>&kSX%C+elEBxf zZK1lzXAL#w({y~%N0_tr#|du3?;j_r-Y+pig`39_WB@UO=~v2_!6VIUeuKDxmpR5^ z!KxxI*XBKXTKHY4d)Z~~l3#9@A&~b740#jzYv@lvu%>~{saPkm4)~A6^%ICscW_?V z#7%+{EO1;G?|#L*>zF3%(1mhM_6)xxw5Z<^CE!>@>)_lcm+VmEf&Db9u8-goUY8+j zV5hml@V0(ocps_VA;G3TEz*nnZu9HAZ+tJ^BU!F~<6khoBo^gQm&&0MltPS!n+oIW zW1777@TORf9t(@BVrh7cP*^zWD$v!kIOa(d3aqsDA{I@-689)>llg?y^X zkSEC!>@Q6Ytmx#3Z%PgkYtFH2TgD!EAB?v^5H-d=Xfp-?-|rLnb-V$W3IUk-cf*B4 zVAQ8_1~J8W#|KflTH@lhjtg+dD6r>WzhnGdU&l-Lw80Sb9v01N{Aqb1>7G)q7cXdY z3%;InIDk^Q6yvUXcH@A>=DctpVTN)riw((U4Z39(^4PBK%ibimRiqu&ea$eoCI5nH2%pNwm-wZyjkSI6YAVW{CN;)&<%}y2nhMH zxG1eyG6AS3VnXJiqL{gL&pLQUH?b;sFR?t-mZYY~3i7GbB%_WVd<62R^mOGPlVy(| z-f22ML1dW^`RP4`5dxm~Xr6$hh>ydXgG|RzaBow3Ta8QiQB~9EKuY-d&0ny__?+UD zCH5{JA2XnK(lU>Et%Gz#b$A56(0Z!|KEz6$-M#ZPgPezUr49I=+`;k-hTb~7yWBSS zITIL1T)ZZxxzjfnuZymj!gRpaKo~Z31Q5HlywrgZ7-V=lon8rWGz3P6(oH%-{9U0~cOuYPm4Pspn~0w?ZF2wp2`ZO#rvy-0^H9QBoq{zSRn zhkG@6vDri1r3b%pleTkHGp8SK?p=xd(l;Z8Y@kkEkXmxYTWT&D5TqfcRjla?=2yQ- ziEeC15b01)m>u?A3M;W05`*4WrL!NeInxHH@5uf{d5@s8_kJwYyA8eIw9+C7fgFMDNdin82#L8o^G|Hwdb2Zz#2ph?f2_gk z7(ac21IqYbaj|XcVP~3^YS7`SrT!6J#x5W5Iq@^D+!;QGEX5}=s}7kLumaRAWpm<5 zAePN5b5TMKcU%cq#0u!MB2#>QTN|!ReYWJj@NTlXq6;-fC;%`&bFV~I-{#yGE?($W z@zFKW5~A*=eQe;s!nMFu2rCf{;5*gd?Wp0LsN2ME3S90tZX(gJDv3Y2;Bd*Cadu)* zV~x~X;~%H5DGD_Hl>#37@yb&2G^5H5>%5sTNd;W~Z5nu%@v54-|9>-H`mHhG5| z>vI9mbr;)o?)C|UJ7LsPeg*E_rtwnc3xF;2tQ~M)$R$z^;?3ppb87nH7Ho|4TxJv zKyd=);Q1F$FF;}Yk~`MX-17-GPehKwJU0EPm(-Zk0N*kKIgHMTU>UDJu%;_66_V|;IF zugIxoe%#ukGszpGE8e364~LX zxhH!1oRJ|D8BN9UJ99sY0onrsdLbQg{Mg>j={o`<^L|N;75;}0_7(XZ@W&BYT&tae z4+Mys+Gi-rsi}&3^tEb(XGWok0f>ijX>KjM|7@j`v?CSAXTv5suyX~kY%t?s$7Zc> z2Lea9fyA5o^Css1uzv3^RpZ-IPy>X~fCjoGs(n*@F*V<#r zy~WtLwR@P1Xd8g)YV;&i@+P)SxZN?maJSgYIKYVW<{3o|AdTk`x!TmA^ zmNRIDv2#%*@ovcP6~k${&u?Fr>eh)xGp9q<%#xpB@1x6L)d?(OrQprDR;x2Bbj?|H zi4r@Vh9Z{v6o}zCdh`s`69f0oyWoMs399=@W8j%UeG=C6)u4&VONr zCU-;~u=&pg{*kDFUq^9|jsj=ho(IlxPR4Ud!~W5wEMrf?bQ z#e5SRKM(^6-)W_krbHUuW2m07F>ut1>-k*q@rI`&=uv9T+MAv@o%msxF=Jmo#SNH! z#t)PqQ%Yk#8=QJ`uJQEdi~Nk%s%TDFtLBSp5o@DV3g3y#~{sd ze8HdE2ji5iu`>tP{DNPEF>9tM$0uxJg5y{aOY7W%4|l+pdc%)e^;ENis+?Ka$gE~Y z=`F@rOM=5^3l{0tWwfD}8c1B1eo3zxS%Lw!@Mop7j8{irCts+5_4ykY&0}(mzi~>^ z^2wL^Jr6o{>b>d#bj1$K z+A$hBk*$V70IIT$2}dl2ZZ&mKAdL!>!rI3>PRH>3(1W1v?@rb{%|s$vCpX?4uvjMa zOh9MDG5dz@oX(y&9Xvmy!wWu1(s(1@o6XA)8DGAsvc{zfjZd-oS1HlNC>I zlhW=!bBkR@jhu1a+8rYSwykmviOA)w3&Nrd7Atg3jM0I-0S}RUInU+L%O++nXpCeL zqm02FFj>PUKev)6Y_nDzl3SrTo;&m^<}C(3XhpdM$GT)BYES)AV$zSChuTMN6TVLj(Pe9Z86@vveY`0R#Tv8Ow~m+L zO6773rvpsvaR6w?+omqUc8sYlK44~)$;bmrRonsNHFh`1LnR2!fd`sa_G!uUP?-e# zpK#yPL2)!*P`(gNsqhTuh)pL|fd~;^n>N^~2zi_^h>3kFaf948)7lQ`T;0s!Zzif< z{i8~Q*Yt4g!;HrXNN^2EOUjVp^*2^J9xiEp(7#~Ya)=EN)eC+57=Yx^rV$F~j}h=B zR}hT|AV)Wa`4R7wd>ZPQFT5i*ty{OBY7!f|j{B5zERiBDW8%SBUA!mH zt>erqKAjb{Fh}^i!71dkE-(}K#LaL<>@J;GtiMkMd3AJGhI7VG=A zr)6RdTN(X;me9hF>OR+};*T)AYHjq)kdl=Qjp!+eo zWe@iDDWCbtbvEhV!0*1V4{6rLo+d84|AnItZTMLI{o7&-%FTY-_H45Odt2{~UR%R` z;-$rt8G6l2bch~$V{a>banFNi7W^+f!i`D&$r*U7%`Yxp@%DC|s6HJ||6g$Fo-dQ( z#{=8#-(s8bR7YT&-vJ`R+Y7CK(yV`fHmW}B9;(kCzdy^X&w6C3(t227o9A#X+a+Ul z*fVym^K4-K;!Muq)xC#Tw*eomP)}F!-j1o4^3qu>oO9yuKj7F+=Rs>f9Va*8j3*DH z#%$cZU-{`}^6(*k+x>zldv0g#&f7Q67$ZIDVFWn$ilTU%QR&?9I?C}aYfR$fJM;80 zo#f2lz6c$-mbYa_CzmXW-(UC+ocl84T>Pbn$;4_#yX>wAL{VfAwSA8`M&@)LvCPAvMr!Tqa9C78~yZ|>joC33?0WX{SjK^s_ix%L0F_vP_T zl==Swgdzeerywk+92RNLHsMf`ncSJl-6mD6xjH%8q-oM789-ES4{lTx@WusnRZh=E zSuPQmLy$`Z(Un#1RZ&FT-}^~hns#B){qFwvo7bzI$sEr-^Lakc=lR?pO@W;StP%Lh z`%8pH6jzlfwZPFW#nt%3`WnB3$>?>73SU^Cf-#@$)8&9+NyXKyISCWZo)c%!W$)vp$czIk45YS=2TVK zkG$NbsI8U=Q6A84hFBbx@ZBNp1tKuM<%15oLrf?QVoQ#d7Eo6?;>N@-7DV2qm zR|CKBop=Vub+Ez=ViuELiKiSK6ZsY?`6jo71d=+NV(b|YdW}wSdh`{UTQJA9h797YBM#(!wnh+(z*}XwRskEVtkg=S z1CE5z>ncu|D$}*5DC32EoMsA;2>>p?ItM~xRBDM99x@C#D|Of$hN?~iZ2n9R*1V z6`~-zN3E2R))!(4R~BOIAbo1o4yOSnC3)8Gqx8H!A>tUYD@v6prS~8=HHD;zh%KCua_oXKHL!&vKPc%f1oM&JtN zLl7(EBR&L|cLvrMr?3Q#fbJ*_JeTkz+zVb&Diduj#DZ7J+QIGHsv=3NT8a#F39uG^ z){7hwG{WaqF=SQ)TarTQ3mAmiDp<;SCLzMY3~R5Q@xqxtq%@+aZU(rPLe_IFGPYw~ zgU6LE!Y5K{Im;)E5Wp(gOd1(+bzm?8MkVbNGgPVsEeP2R{QF3(orAavJ1elDip|xC z=9KxgaJj-LL3>=BgnXikt4jO`7Lir@b#crWV${SM69u~p-4@0(O>r?F#t@!R@K1H< zMWiAp3_vbDDz&4qAv6%ksSmv`7%3I9mP13Smzd+wH1SLZl2!%7mluO`bwL9W)mc^W zo5yo@sCRWtnj!V#kdMBWiy?UurjRhER#%{8C4=~&N(0KN%lJTvEUc%YSx`kpfUVtJ z0(&;b1?^ONu^M9)4VZ~4mWP0>7J|O0GpM8iz67Dj_yB08EEVv^H1KMX0tpArj%s+7 zS;g@5PO@QhKH3%65@MOyN35cMwr_!t!nMCNy z0;NhQ!FPGgUg(|pKz{by>UMHn86ASFJCQ;T3}LjNGlDf97O&c)!g23c{G8Pqs%>kU{V(hSB! z7bg9xs--A-hTH>K3o#zBE+r_~0*&6Fv>Lg*RfX1wsshCG)QHb2!n&lBLT&&Y^l;Q-DoQz_QHAJNLfT?Psl%=+Wy)8ICDeg&=3ONOzno$Qei?M38>E#@Q6bQoxX+*sYmrxs%OtH_ z0TXr?B8!ODgd|sfYlH#5N=XV19Z$+>abcDwDT9&w8ugekR$}N|SUvAcxIpiTAOvB- zBIOw1ja4E8ZkXw+0^~RYj%2g3S&-6SZ^Zgoo{DI|FV(URaz3Rp;d^u2z%|UL5}T5j z^9&u=)wKZ=S@s#S_kli`aUT9Zh{ePG=!dgOKX6V@AMZz!{>*Q^I-*Q3bT7fZ@5qc=mZ~s`2@V*zBq6gN`q+P^N@Dsy-x2K)uO# z*ddvqUw|d9@p2@`h$G8cs+OfqWYU(icqitbgeg)Q{3)cDv?_usC27N%0B&d7z`HCY z^o!D})h6DW!LgYZL-h;3j;%=;{B#m^GZ3&~Rf8LnO2VAbA_^v>%BWjKVJ7Or&<|PT zCMqNsvJ!YcQgWFI+FK77(^6?E4_jXqOS`qOSfSpHGY&`HBxSdX$ofs?3k_P@n?yM@ z5$?0IQny+wwwVDpiP;wR%yy$(X6C9OH(Cjw ztuTeuVjU0fPbXxqkThsar^GmyFgx5edWSoWl!XpxPSRMboA~xxFdl~Yv`IGp@D)5dW}-X0iT3012n6E z{@R*mU=h^mf-H`O;FpwUmtM+%IRvjL^ah2h+LY5#S|~-15%@QU;N{~2e*$?#9oDk6 zHQ<5S0`sL3+0flmwp3ltibcUvn5U3zBY`+OxDr4f_F1iX%Vd!yl4qD_e|PqG=?KcHYe3ZKxa^-&>UB>3{woA6Pi7> z1DW*@J0Ry$ZbqyIuddDz*lI~A1oDx<>~7ZE^l6rh6fzEm$2U_Y;PqL=Wol?tr%OTK zZ^xX8MG!+70e=}NqU=to2|6@miIw+R6Jav;S>7RKkZVJ>2;-G!E@1~?wCp zonK0;g&1QfVD)r0J{R_Me1be|K}@>2+!;h))2s)dW^+nPFjs5r$SQ}}VDM|C4F?Nc z(eXi}PAO&3j@^uVwT@sK?=4kXNxpCr+G)whn4k}F9BUFfs~vmFt5u4zXF1xA8bxR) zoBAbAOeL76P^v;v2s0_>Sr|rk%8h_}q}+1oiOQKWrjX7Z>k^JDA=6@q8!5Mh+(Thu zK>RM`P@mGA5em)u39rWC23U%Gny@eJ;=yh6L~l^-f}fT!OSgfm;LULjrM;mC1bu)s zqX8qVw;;D0{9vS#BJk`;nD6x^;D;JTj49qjqe`EFb|%nes9Ygqk@`}?6y`nFW~6b~ z@m{FbQ7GP__UX7}p@u^>KKMzP%Moi0cSndb1}&w?ux5x|9OjNaCS*;d>MEDON1>dH zlJIMQU3D$owRi^i(L%KsZo@WBSOMOYc<9CgOrgF+nU6M-NDJ$T$V@DftSg8>pe@#g zt1@OsB_uj`#sr{I8Xg%ODexoN`%xO5C+*@Kd~xH*^8r$2D!2Nihz3Rl z0UuS3afLjKa-h}(IAgClTJ)X898fEy=ysKoN{P%OMBdRhL|_%UQFbREH=sXoqp|A4NK@j+f)^q3Gg#FCG&fkDQcwrG z1BR#}Qj&?7!)YeN(gr^SX&RO#;x-2v8)Y!91tRe3v^(XDGjuV=&@Ck2M-_ly5XP!> z4!B3aY^1X~L_|igVPS;sK?&PFIX3$^wWh=-w}?}q@1U&YduT-hd4mucsMJ;JI9Zia znuVq@qN`?#7fT7 zVyunepK;Bghu(q6Y&f={LLkTsY~?AnIv*idIA?TN!Ee%BLXuJ=29g29VlHD?jmZ1`Jhx-O~FfD>1X^Z(ni>wMjH;+3;!D$uAH2V=#8Myz{dbCn2Qc&NKu)+ zj{1iXU29;M1+HKnAvt&P0c^Pjc2_sfOMaG$offl+1$VXHS zoF^ReK=%E-o4{D54!d1idql(%!3p{oM0A6GMXZzFx4UH6ZCg#ck&Vb8Y7g+Rj^4=N^caQRjL;h)Jg0cOOBn)hPR`q%5>_h= zAzXvf@eJyLNXuCxBzX=fl4215C?_$e6!d~#7S&!1VKWl;5_R0ffE91Q=Ay%Vnz-#Il=*8dM!uIH8Ot??tuJBSQmAu@>4&i#ljF$TT-g1 zTMLb_4mg%ZND%aQ5%9@XaCbFP@W2Bc6NnrzNi~!xVrJx+2OQ$t^^gYmJ%Hn! z2m*~J=5646A-!Yl zzzMp&OtFtgR%aK%Nf=NucOqoX!W>k`po1k@c_y7 zVUYP?0(sG86-u}P@p>WlezU=9Hqm;XoGZ3jAqz(!mm`)1wtesU+wL482c=UORE>K^AJPCk>FYQmk6)~pQ z7$FB}IejH~Iq-#YxW}uJ16Unr9A5C(^=k*zP=q%o@Whz75tfBok+3mRjD$(mA9ETc znt;Ar372S+Rlyqbu=(f!a$u8nlqN<7n_dQ5;)QJIGK7(o4yi^8K^{Wz;Da6%ssN`1 zYhtDd+zQOD^_#5i(63sswqvNIHY0NW z)L^}Vt`fK=NrPTNMnupWycV?%FL)`177Dot%a`LgA7cz(XQb*NZOci1m>RXw2>uzz z&}BxA%2X{Xhd}28UWm>IELBc1@QjvPV@8k79b>{w0d(;~{8s70p&|-;0vu~qV8>}x zGA0BIu?iw{F?bu$aC|pgVO3Xv4*|bcgUZ0*TMe*x=5xrQ#7NAgP8D(>reJ!;JWz>A zDq<)BiM~w@Srz9f_-X0wOhBs9hdEd})sVlW$WttrK_3OZWx`;RKv%^P8Ah(xhTT3I z>K26h1TZ#*$Sa}K28HQrs}nqkT8R5BnP3cK%w@>29}q)WmL9Z(7^_sY%_*z|{{|~p z0P7ZLxf662G!!-*l9gH&(gVJrvvSZA9`pqKGuHeH$h`PI@RGI&vX!eXNuuwjIn*A5 zyscJ)K9xdMglIX)MEEvD7eW?;o=%y9|FMmeMv$%+hJPLKoT6<$qTsD&>7-B>@}c^- z9KM;ca`3Myk5z4ePRL_M7Cg{RLg!ihb}=^Kh%UxKe`SGoUFJbOQ&Y6K#^aVE0uTY% zG%7?=p!wkhr{ztE2dRaAQIC3?#9@)hqc{amS$vofNsi!}q=UlY1~$G?7`7G#dxgYT)|5Fm%&SY{eFILVMc0si%ysup#1Q+}5<^&9%r@KClt zzv2FW-~V(Lu8?0EwtBddtTpvpiHNn>Dr~H^IOsPOCz0o?82P^xDPKm`SVx$WMM_Ib zgZ`4V&}A2ST_QPbxojVn#GyA@rb$<5LXeAn}v5{%MLigV%Z|kVBfHj-*lZD5BXoDDX55& zyuqYxFpSxDaW6Dg4OV7n~lXD$@C=FZV^{fG>{-JAJS7bk&+^bSFG?cr~X+Jxr zb=Pw6tg(F}Rn%{d*K_)R#pfHZ!%elJEuFBTq5r>cgg_llA-~nhf|5r3N;S^uEWR~9 zVJt4l;w7E6Z)VRo$J6G&p>_ctW#?yOJL*D$$IujFTpPv+0x$X9h|Kxva#?BJ`!3@Y z6fFM-x|F49s%&4tsyA#6G`Nwrc_qb2q*5$|mXGu=hxTffsx0l)H$0rPr)$d7*S^0F zKcQYw;(lRv9f6q@a4@0hlmv4Z#R>4vI=W9g99K4h!u{xcU-XHYpy zOIcle|vVWHH=-( z7?blxPa~dXp93yuaU=UR zEXG#XbrG~}+ZN9?-6%n!Aebui#VV^j$pT-@A1Ej*OU6)yvaGBiyDunD7jV^{s$f}V z%wJPMuGimC>4|a`filtXstoe&2ugqV_R?%+Ld zJNS)xmMtA0{&?-g_41`sc4Y04TiTnSS@L2@{+-NO&G?g_GhfcvZFu7Ak%EnLI`^5} zRoZXE_1nICzL?#nTe)%E;S=4e|M;l&58n=+&klVhJ>kN!J@;|j1Am(M=-Q&ER_vbi zNN>esngOpxwwDWb^;|xZ?sk9v_x|a_BKJ%Cw0}B&=*EXUUwpL5c4^p#PftF$w&$*y z9dG<`z>~j}ziJjtJG$~=->ReM4zxaZaCFt-+aEm8C%;Jbko2p0#zD`0^-p7L@r@@W z=U3e(c)9Q|R7X!~yMos&Z(3)F+a4eP+_7=yqgz(#9)10$Z`<6xapby_bMEdYEHgfT ze*V~2y8gc|oi^_L?7zn1{XD%x({G)*uFdvQi%#F)f(+%4%SG(K76QTfmV)e13NVyS zhmrhZ1pWRSLpj_Z@YZl;Rgnsfs`tMi)CXF+Zo!K+W+ahxl$Gcmx-z@Nrf0Qf4g@=? z;{q4ySG>iPcuS+em&K;x;ZSmTSyjjvsR*?5)`TkjWua)YB9utBi?}?z;VvC|uX`bUqhR_9|67*6=ciXq z{nfUo-|at`+YMW@w{~i`HpfjJL$hx0IcZX_Sf9Cr-yM-M{(7kA>Q%eHS3LQSpQ{|I znBMQpz7JPFEE+Jk`*UaGHx=Br@b=C_Ei=E{|H9xGPQU%rp@n~b<(@@HgOl#n9i8;p zrdF@A>jkG8{c z?6J3a{|Fm*OpIfG$`*y;<(LZe6B01CH!qJ)?pT9Tg^KTwm zS1`rneYKEoWnS}ur;}a2y-V+(KI`*F`=yeugBF>dbolP-KmWwK$y?^HXuq6p)6Z&~ zGvnTAH=l5RJ#+6$&*m-@=KlP_);_zpb(=fz<5!a9wFgga+0p0bPCr@cwl^`KUF2$K z{bSgivF0sf{rM#?K6I(ak6&K5O;lnUck^qOf|G||ylZA~&a%Bj#T_~+ehBoLF8k$$ zsg>0oznt55g+b>C9 zx)fN>{9>{9{$vI_I&k~I0V77fv55Lk6Ipk5UC;e_$E*H$Jyp$|-fxce`nq!bnx3;~ zCSFt;9$o(Mn3H>K3kpoiQ%Cc@D?YgR`Tl=PP073N{qAcf^*XSlTdy%EkDaM8y%`?9 zE1|w^-?_~Xs}GFt6zaBFv0%)}FFqRKez*_U_tKlg46>y5p?Tcvk) z@3D33ro*GGKFvw>tuYrULu0TS?0>w2@x4>87 zNhAV^f(9azB@nV4r2HMLAelo~rsbmqgZ_h#~Ip7lHZ*?Df{bQc? zql3YP6CT_=efYRpchtNrneoAc7gm4vW1^>G_FdgxPd`?3)YZvFr zar*h6Z8yEV<|#h>&e)a$-Or;A%sxv|KFYuBi&H25{B75ppI&+0o2}=H_WotnlxHrT z-Z}ovQX>jg|Wr4;naq0H~;wk57UZNlGoq< zbn;Bj@P6N!LwW_~b{u{BF?XAJZ?*TbONTA!TrhC{3j6i}+X8I@i+*k~XW2meOVj#| zd~JL0PuAYIdEkU4>0kHXx(gIk{^52%ScT&UR=hmm-WD`ZyYZ@De5s> zwD_4PX9kk<4+vTA-VuW5XH4At{#M1K$37V7czDXN2Tsilbl>Q{Wye<8&YpiQ$xGhX zbxfeOd;a7xtIwXYGhg}3 z#9m`6zUWi)%F7=t8fR_)&6F31mYf?jcgo*3tbFI;DRI``ckMan;$fSnx3?GE{)p|> z=lbP&^G2@ozB!+ZA54Am(u)i4k#{|GKD}#u=f1V){`lvy1*i4#BV(Q_3p0Ct^9FA3 zv*L{r&%9nW=9VY=eKhc#?8hIuWb1O|nP=Y|Ga^eW?^r#v&@%3ucV`F= zstyJ7LQ^hA&9SeJwb;!U98!e8dhhZ3Uv7KjinbrWyYA&Ymfm`*fPaP__J~?sEl0~D|Kl9xprSbT)YewBuD*o#~#$60= zd3M=;!+~Ccn@T>c?q@!E(-+qrR;-xtHSMYFHZkz@{SApA9>>Bqg!>IO^<)Ff5V(5 zr(M5zHWv)&o-e=UwQik$y3hAu5kIH=)5&YnN0%RaM=$A_LEABh-#M$rP?elcS zfge$Be$d|^SKP1ft$ycM_w+kMZI@I@x*k6BePYdyp+(=j)ZI>u_x$5QW>3c>WQ10~ z{qT0Y}UkE!R5 zRqh!5-PVhVGhH*sfB5T>cL#ocVAZ}I^*r2;)ma{HTbJHHf7WMJ`=trO!!2}l{udr@ z8R6kdz{9O2Jly=BKiJ=gP$0s?MQcYco@VZIQv(n8&pIA%GHJNIF#k#=G-;rX?a++Be~qpF4N!(DuvQy$f3&d;G!&FD{?f-v8CH z+M(ZnIq0$1cRt%D{CtPiGuG~`z3uK`-l_f%oZorph%uTk#{K=*@YlyWitpLGZCCl} z`(u;$J$Ju!L)X1Kx-RIHN}MPn7=PeH&-(l(JlvJU;I4vBO)x&1NP42lM1fGF>evi9 zF>|sTY5d0{)0YvBko;YpFY7LlT&eNlS!ubTS@W0i{Avdew5%73%W9ILibT7zYKY!M zB(7sF%WB2#;({SVT%f^%1jJ_#7l^m0d!(Hy7Uk@9=i0?vy2z&ppI^=`Q-DapQX^ZDg$x7HAN*~QUiiXHv z9Sp9-B=->9kpqKdwZ|7ALVQ(C9gv6qb|_pO&bJEkuJ(>PShpf?xeUKtr9yRpfGSXy z%v9l@hXV*F)Pb-(R-N)x`{ONwg0Vxc4f&SMAm6GU@|ud6CwUq4*VTi*wW|%F=Bx30 zMhN3=i~|7uJV4*r+L)IdV$8#E{F_l|n;nJg1Fq}Qh6&1$yOq(**m%3Cwi;I#oFZd& zcOqFmtncNEW8*hK>(@RzZZ)=vtF8kEkiVKKREInjA>KnKM_I_PQ4RUc01=85a*0?XlPDxIsk~5Bq$rXKWkRt;E*DEh;zF5F zst^@Q6r_iXHxI}gg5Hroi?Ia()@3bTuBx^ZxzJ>5RL_zQ6ZEfpzL6!%)elllp8W>O z-^SA2c7XTp4C$ljdjb`45`E`gi2A;`cRq(cpd8deQH~$twy-%|APrO(MGT z#_hxX|NGuFhf*Vz%jK}THg*a}?;L8{ugff3bJWabP`Wp1;+D^@*hC|>#NAhzKRuc> z((7<;Bh5IQ=6%UsbZ>3a%mY_%CbvHg%xmuYd2^G7_I#e(pZXQMc|UTOzRpb^w;joS z{PL39yn)=6sc(}8K5f##mDfv)u}zy{^-ukw)g70gZ&=3~_S?zQj4ge%b#{;E2wN+l XkJtSs46Mc2I(t_qu<+#mx_kd0(QA

Date: Tue, 23 Apr 2024 16:03:06 -0700 Subject: [PATCH 2/4] fix: Remove references to implicit batch for TRT 10 (#2773) --- .../dynamo/conversion/impl/normalization/ops.py | 9 +++------ .../dynamo/conversion/impl/select.py | 15 ++++----------- .../dynamo/conversion/impl/squeeze.py | 5 +---- .../dynamo/conversion/impl/unsqueeze.py | 10 +--------- 4 files changed, 9 insertions(+), 30 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index bbe566d0b7..b1e4fbf24c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -58,7 +58,7 @@ def batch_norm( # For BatchNorm1d, reshape 1d to 2d output_shape = input.shape - if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4: + if len(input.shape) < 4: assert ( len(get_dynamic_dims(input.shape)) <= 1 ), "BatchNorm1D with more than one dynamic dims is not currently supported." @@ -75,7 +75,7 @@ def batch_norm( output = layer.get_output(0) # For BatchNorm1d, reshape output back to 1d - if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4: + if len(output_shape) < 4: output = impl.shuffle.reshape( ctx, target, @@ -411,7 +411,7 @@ def softmax( input: TRTTensor, dim: Optional[Any] = None, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) + input_ranks = len(input.shape) if not isinstance(input, TRTTensor): raise RuntimeError( @@ -433,9 +433,6 @@ def get_softmax_dim(ndim: int) -> int: dim = cast(int, dim) dim = get_positive_dim(dim, input_ranks) - if ctx.net.has_implicit_batch_dimension: - assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." - dim -= 1 layer = ctx.net.add_softmax(input) layer.axes = 1 << dim diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6f827de2eb..2ec6420e0b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -40,19 +40,12 @@ def select( "of the TensorRT region!" ) - ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) + ranks = len(input.shape) dim = get_positive_dim(cast(int, dim), ranks) dynamic_shape = has_dynamic_shape(input.shape) - if ctx.net.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" - ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't select on negative shape dimension!" + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't select on negative shape dimension!" index = index if index >= input.shape[dim]: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py index cde4fdd90d..45bdefcd80 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py @@ -32,11 +32,8 @@ def squeeze( for dim in dims: dim = get_positive_dim( dim, - len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0), + len(input.shape), ) - if ctx.net.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 assert input.shape[dim] != -1, "We don't support squeeze dynamic dim." assert ( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index ce893f8d5b..d056b8f0e8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -29,17 +29,9 @@ def unsqueeze( dim = cast(int, dim) - input_shape_size = ( - len(input_val.shape) + 1 - if ctx.net.has_implicit_batch_dimension - else len(input_val.shape) - ) + input_shape_size = len(input_val.shape) dim = get_positive_dim(dim, input_shape_size + 1) - if ctx.net.has_implicit_batch_dimension: - assert dim != 0 - dim -= 1 - assert ( len(get_dynamic_dims(input_val.shape)) <= 1 ), "Currently we don't support unsqueeze with more than one dynamic dims." From cad3f94781bb2d58d4637ef7a8403b853ea6fb9d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 25 Apr 2024 10:49:25 -0700 Subject: [PATCH 3/4] chore: patches --- tests/py/dynamo/runtime/test_hw_compat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py index fa87c9947c..1d051c5fdc 100644 --- a/tests/py/dynamo/runtime/test_hw_compat.py +++ b/tests/py/dynamo/runtime/test_hw_compat.py @@ -74,6 +74,9 @@ def forward(self, x): not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8, "HW Compatibility is not supported on cards older than Ampere", ) + @unittest.skip( + "Skipping this test because the hw_compat.ts can't be generated using torch nightly" + ) def test_hw_compat_3080_build(self): inputs = [torch.randn(1, 3, 224, 224).cuda()] From ddd083365a85faffed8277a0efbd505bf907b51e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 25 Apr 2024 11:11:57 -0700 Subject: [PATCH 4/4] chore: linter fixes --- core/conversion/converters/impl/expand.cpp | 6 ++++-- core/conversion/evaluators/eval_util.cpp | 3 ++- core/plugins/impl/interpolate_plugin.h | 4 ++-- core/plugins/impl/normalize_plugin.h | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp index 49f18159dd..0e68768e15 100644 --- a/core/conversion/converters/impl/expand.cpp +++ b/core/conversion/converters/impl/expand.cpp @@ -44,7 +44,8 @@ bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor if (size != targetSize) { if (size != 1) { TORCHTRT_THROW_ERROR( - "The expanded size of tensor (" << targetSize << ")" << " must match the existing size (" << size << ")" + "The expanded size of tensor (" << targetSize << ")" + << " must match the existing size (" << size << ")" << " at dimension " << i); } } @@ -131,7 +132,8 @@ bool add_expand_dynamic( // if size == -1, we can't validate the expansion before setBindingDimensions. if (!(size == -1 || size == 1)) { TORCHTRT_THROW_ERROR( - "The expanded size of tensor (" << targetSize << ")" << " must match the existing size (" << size << ")" + "The expanded size of tensor (" << targetSize << ")" + << " must match the existing size (" << size << ")" << " at dimension " << i); } } diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index 2c6567aa95..9b6139073d 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -165,7 +165,8 @@ c10::optional toIValue(const torch::jit::Value* v) { void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) { if (!elem_type->isSubtypeOf(c10::NumberType::get()) && elem_type != c10::BoolType::get()) { std::stringstream error; - error << "Input must be of ints, floats, or bools, " << "got " << elem_type->repr_str(); + error << "Input must be of ints, floats, or bools, " + << "got " << elem_type->repr_str(); // special case empty list torch.tensor([]) if (elem_type->isSubtypeOf(c10::TensorType::get())) { if (empty_list) { diff --git a/core/plugins/impl/interpolate_plugin.h b/core/plugins/impl/interpolate_plugin.h index 661cee270f..ce009af03e 100644 --- a/core/plugins/impl/interpolate_plugin.h +++ b/core/plugins/impl/interpolate_plugin.h @@ -57,7 +57,7 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override {}; + void setPluginNamespace(const char* pluginNamespace) noexcept override{}; nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; @@ -117,7 +117,7 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* libNamespace) noexcept override {}; + void setPluginNamespace(const char* libNamespace) noexcept override{}; const char* getPluginName() const noexcept override; diff --git a/core/plugins/impl/normalize_plugin.h b/core/plugins/impl/normalize_plugin.h index 7e564b505b..5d51a68293 100644 --- a/core/plugins/impl/normalize_plugin.h +++ b/core/plugins/impl/normalize_plugin.h @@ -41,7 +41,7 @@ class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override {}; + void setPluginNamespace(const char* pluginNamespace) noexcept override{}; nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; @@ -101,7 +101,7 @@ class NormalizePluginCreator : public nvinfer1::IPluginCreator { const char* getPluginNamespace() const noexcept override; - void setPluginNamespace(const char* libNamespace) noexcept override {}; + void setPluginNamespace(const char* libNamespace) noexcept override{}; const char* getPluginName() const noexcept override;