{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# An efficient implementation of the Sinkhorn algorithm for the GPU\n",
"*Thomas Viehmann <tv@mathinf.eu>*\n",
"\n",
"Recently the Wasserstein distance has seen new applications in machine learning and deep learning. It commonly replaces the Kullback-Leibler divergence (often dubbed cross-entropy loss in the deep learning context). In contrast to the latter, Wasserstein distances not only consider the values of the probability distribution or density at any given point, but also incorporate spatial information, in terms of the underlying metric, about where the differences occur. Intuitively, the distance is smaller if probability mass moved to a nearby point or region and larger if probability mass moved far away.\n",
"\n",
"There are two predominant variants of Wasserstein distance approximations used in machine learning:\n",
"- Stochastically optimised online estimates of the Wasserstein distance. This is the concept underpinning many of the GAN applications that use (a heuristic approximation of) the Wasserstein distance as a *discriminator*. This line of work starts with the [Wasserstein GAN](https://arxiv.org/abs/1701.07875) as an improvement over the KL-based DCGAN and continues with refinements to how the Wasserstein distance is estimated in [WGAN-GP](https://arxiv.org/abs/1704.00028) and [SN-GAN](https://openreview.net/forum?id=B1QRgziT-).\n",
"- Direct computation of the Wasserstein distance as a replacement for the cross-entropy loss in mini-batch training. This is commonly done using the entropy regularised Wasserstein distance and the Sinkhorn iterations of [Cuturi](https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport). In the context of deep learning this has been proposed by [Frogner et al.](http://cbcl.mit.edu/wasserstein/), but there is also earlier work in image retrieval using the (non-regularised) Wasserstein distance, see e.g. [Y. Rubner et al](http://ai.stanford.edu/~rubner/emd/default.htm). A comprehensive treatment is given in [Peyré and Cuturi's book](https://arxiv.org/abs/1803.00567); R. Flamary's [Python Optimal Transport](https://github.com/rflamary/POT/) library provides implementations for many algorithms in this area.\n",
"\n",
"This code is concerned with this latter use of the Wasserstein distance. One of the challenges is the numerical stability of the Sinkhorn iteration and carrying that over to mini-batch computations efficiently. While the ingredients appear to be readily available, it seems that they have not been put together in recent implementations we observed.\n",
"\n",
"\n",
"The following is the code for [Thomas Viehmann: \n",
"Implementation of batched Sinkhorn iterations for entropy-regularized Wasserstein loss, arXiv 1907.01729](https://arxiv.org/abs/1907.01729). If you use the code in academic work, please cite this paper.\n",
"The paper has a self-contained writeup of the key calculations to derive the algorithm.\n",
"\n",
"First we need some imports."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import time\n",
"import torch\n",
"import torch.utils.cpp_extension\n",
"%matplotlib inline\n",
"\n",
"from matplotlib import pyplot\n",
"import matplotlib.transforms\n",
"\n",
"import ot # for comparison\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The kernel\n",
"\n",
"The following GPU kernel computes\n",
"$$\n",
" \\log v_{bj} := \\log \\nu_{bj} - \\operatorname{logsumexp}_{i} (-\\frac{1}{\\lambda} c_{ij} + \\log u_{bi}).\n",
"$$\n",
"\n",
"This has two key properties that shape our implementation:\n",
"- The overall reduction structure is akin to a matrix multiplication, i.e. memory accesses to $c_{ij}$ and $\\log u_{bi}$\n",
" to compute the result $\\log v_{bj}$, with the additional input $\\log \\nu$ following the same access pattern as the result. We parallelize in the independent dimensions ($b$ and $j$) and split the reduction over $i$ amongst multiple threads then combine their intermediate results. We have not employed tiling, which is commonly used to speed up the memory accesses for matrix multiplication.\n",
"\n",
"- In our implementation, the stabilisation of the `logsumexp` calculation is carried out in an online fashion, i.e. computing the stabilisation and the reduction result in a single pass, similar to the Welford algorithm for the variance.\n",
"\n",
"I explain a bit about the reduction (in particular the bits about `WARP_SHFL_XOR`) in [this blog post](http://lernapparat.de/sinkhorn-kernel/)."
]
},
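{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid, the single-pass stabilisation can be written out explicitly. This is a sketch reconstructed from the kernel code below rather than a quotation from the paper; the symbols $x_i$, $m_k$ and $s_k$ are notation introduced here for illustration. With $x_i := -\\frac{1}{\\lambda} c_{ij} + \\log u_{bi}$, each thread keeps a running maximum $m_k$ and a rescaled sum $s_k$, starting from $m_0 = -\\infty$, $s_0 = 0$:\n",
"$$\n",
" m_k = \\max(m_{k-1}, x_k), \\qquad s_k = s_{k-1} \\, e^{m_{k-1} - m_k} + e^{x_k - m_k},\n",
"$$\n",
"so that $\\operatorname{logsumexp}_{i \\le k}(x_i) = m_k + \\log s_k$ after every step. While the old maximum is still $-\\infty$ the kernel sets $s_k = 1$ directly, which appears to be there to avoid evaluating $-\\infty - (-\\infty)$.\n",
"\n",
"Two partial results $(m, s)$ and $(m', s')$, e.g. from two threads exchanged via `WARP_SHFL_XOR`, combine in the same way:\n",
"$$\n",
" m'' = \\max(m, m'), \\qquad s'' = s \\, e^{m - m''} + s' \\, e^{m' - m''}.\n",
"$$\n",
"Once the reduction over all $i$ is complete, the result is $\\log v_{bj} = \\log \\nu_{bj} - m - \\log s$."
]
},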
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"cuda_source = \"\"\"\n",
"\n",
"#include <torch/extension.h>\n",
"#include <ATen/core/TensorAccessor.h>\n",
"#include <ATen/cuda/CUDAContext.h>\n",
"\n",
"using at::RestrictPtrTraits;\n",
"using at::PackedTensorAccessor;\n",
"\n",
"#if defined(__HIP_PLATFORM_HCC__)\n",
"constexpr int WARP_SIZE = 64;\n",
"#else\n",
"constexpr int WARP_SIZE = 32;\n",
"#endif\n",
"\n",
"// The maximum number of threads in a block\n",
"#if defined(__HIP_PLATFORM_HCC__)\n",
"constexpr int MAX_BLOCK_SIZE = 256;\n",
"#else\n",
"constexpr int MAX_BLOCK_SIZE = 512;\n",
"#endif\n",
"\n",
"// Returns the index of the most significant 1 bit in `val`.\n",
"__device__ __forceinline__ int getMSB(int val) {\n",
" return 31 - __clz(val);\n",
"}\n",
"\n",
"// Number of threads in a block given an input size up to MAX_BLOCK_SIZE\n",
"static int getNumThreads(int nElem) {\n",
"#if defined(__HIP_PLATFORM_HCC__)\n",
" int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };\n",
"#else\n",
" int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };\n",
"#endif\n",
" for (int i = 0; i != 5; ++i) {\n",
" if (nElem <= threadSizes[i]) {\n",
" return threadSizes[i];\n",
" }\n",
" }\n",
" return MAX_BLOCK_SIZE;\n",
"}\n",
"\n",
"\n",
"template <typename T>\n",
"__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)\n",
"{\n",
"#if CUDA_VERSION >= 9000\n",
" return __shfl_xor_sync(mask, value, laneMask, width);\n",
"#else\n",
" return __shfl_xor(value, laneMask, width);\n",
"#endif\n",
"}\n",
"\n",
"// While this might be the most efficient sinkhorn step / logsumexp-matmul implementation I have seen,\n",
"// this is awfully inefficient compared to matrix multiplication and e.g. NVidia cutlass may provide\n",
"// many great ideas for improvement\n",
"template <typename scalar_t, typename index_t>\n",
"__global__ void sinkstep_kernel(\n",
" // compute log v_bj = log nu_bj - logsumexp_i (-dist_ij / lambda + log u_bi)\n",
" // for this, track max_bj = max_i (-dist_ij / lambda + log u_bi) alongside the rescaled sum\n",
" // i = reduction dim, using threadIdx.x\n",
" PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_v,\n",
" const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> dist,\n",
" const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_nu,\n",
" const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_u,\n",
" const scalar_t lambda) {\n",
"\n",
" using accscalar_t = scalar_t;\n",
"\n",
" __shared__ accscalar_t shared_mem[2 * WARP_SIZE];\n",
"\n",
" index_t b = blockIdx.y;\n",
" index_t j = blockIdx.x;\n",
" int tid = threadIdx.x;\n",
"\n",
" if (b >= log_u.size(0) || j >= log_v.size(1)) {\n",
" return;\n",
" }\n",
" // reduce within thread\n",
" accscalar_t max = -std::numeric_limits<accscalar_t>::infinity();\n",
" accscalar_t sumexp = 0;\n",
" \n",
" if (log_nu[b][j] == -std::numeric_limits<accscalar_t>::infinity()) {\n",
" if (tid == 0) {\n",
" log_v[b][j] = -std::numeric_limits<accscalar_t>::infinity();\n",
" }\n",
" return;\n",
" }\n",
"\n",
" for (index_t i = threadIdx.x; i < log_u.size(1); i += blockDim.x) {\n",
" accscalar_t oldmax = max;\n",
" accscalar_t value = -dist[i][j]/lambda + log_u[b][i];\n",
" max = max > value ? max : value;\n",
" if (oldmax == -std::numeric_limits<accscalar_t>::infinity()) {\n",
" // sumexp used to be 0, so the new max is value and we can set 1 here,\n",
" // because we will come back here again\n",
" sumexp = 1;\n",
" } else {\n",
" sumexp *= exp(oldmax - max);\n",
" sumexp += exp(value - max); // if oldmax was not -infinity, max is not either...\n",
" }\n",
" }\n",
"\n",
" // now we have one (max, sumexp) pair per thread.\n",
" // a butterfly reduction with WARP_SHFL_XOR combines them\n",
" // into one pair per warp\n",
" for (int i = 0; i < getMSB(WARP_SIZE); ++i) {\n",
" accscalar_t o_max = WARP_SHFL_XOR(max, 1 << i, WARP_SIZE);\n",
" accscalar_t o_sumexp = WARP_SHFL_XOR(sumexp, 1 << i, WARP_SIZE);\n",
" if (o_max > max) { // we're less concerned about divergence here\n",
" sumexp *= exp(max - o_max);\n",
" sumexp += o_sumexp;\n",
" max = o_max;\n",
" } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {\n",
" sumexp += o_sumexp * exp(o_max - max);\n",
" }\n",
" }\n",
" \n",
" __syncthreads();\n",
" // this writes each warps accumulation into shared memory\n",
" // there are at most WARP_SIZE items left because\n",
" // there are at most WARP_SIZE**2 threads at the beginning\n",
" if (tid % WARP_SIZE == 0) {\n",
" shared_mem[tid / WARP_SIZE * 2] = max;\n",
" shared_mem[tid / WARP_SIZE * 2 + 1] = sumexp;\n",
" }\n",
" __syncthreads();\n",
" if (tid < WARP_SIZE) {\n",
" max = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid] : -std::numeric_limits<accscalar_t>::infinity());\n",
" sumexp = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid + 1] : 0);\n",
" }\n",
" for (int i = 0; i < getMSB(WARP_SIZE); ++i) {\n",
" accscalar_t o_max = WARP_SHFL_XOR(max, 1 << i, WARP_SIZE);\n",
" accscalar_t o_sumexp = WARP_SHFL_XOR(sumexp, 1 << i, WARP_SIZE);\n",
" if (o_max > max) { // we're less concerned about divergence here\n",
" sumexp *= exp(max - o_max);\n",
" sumexp += o_sumexp;\n",
" max = o_max;\n",
" } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {\n",
" sumexp += o_sumexp * exp(o_max - max);\n",
" }\n",
" }\n",
"\n",
" if (tid == 0) {\n",
" log_v[b][j] = (max > -std::numeric_limits<accscalar_t>::infinity() ?\n",
" log_nu[b][j] - log(sumexp) - max : \n",
" -std::numeric_limits<accscalar_t>::infinity());\n",
" }\n",
"}\n",
"\n",
"template <typename scalar_t>\n",
"torch::Tensor sinkstep_cuda_template(const torch::Tensor& dist, const torch::Tensor& log_nu, const torch::Tensor& log_u,\n",
" const double lambda) {\n",
" TORCH_CHECK(dist.is_cuda(), \"need cuda tensors\");\n",
" TORCH_CHECK(dist.device() == log_nu.device() && dist.device() == log_u.device(), \"need tensors on same GPU\");\n",
" TORCH_CHECK(dist.dim()==2 && log_nu.dim()==2 && log_u.dim()==2, \"invalid sizes\");\n",
" TORCH_CHECK(dist.size(0) == log_u.size(1) &&\n",
" dist.size(1) == log_nu.size(1) &&\n",
" log_u.size(0) == log_nu.size(0), \"invalid sizes\");\n",
" auto log_v = torch::empty_like(log_nu);\n",
" using index_t = int32_t;\n",
" \n",
" auto log_v_a = log_v.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();\n",
" auto dist_a = dist.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();\n",
" auto log_nu_a = log_nu.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();\n",
" auto log_u_a = log_u.packed_accessor<scalar_t, 2, RestrictPtrTraits, index_t>();\n",
" \n",
" auto stream = at::cuda::getCurrentCUDAStream();\n",
"\n",
" int tf = getNumThreads(log_u.size(1));\n",
" dim3 blocks(log_v.size(1), log_u.size(0));\n",
" dim3 threads(tf);\n",
" \n",
" sinkstep_kernel<<<blocks, threads, 2*WARP_SIZE*sizeof(scalar_t), stream>>>(\n",
" log_v_a, dist_a, log_nu_a, log_u_a, static_cast<scalar_t>(lambda)\n",
" );\n",
"\n",
" return log_v;\n",
"}\n",
"\n",
"torch::Tensor sinkstep_cuda(const torch::Tensor& dist, const torch::Tensor& log_nu, const torch::Tensor& log_u,\n",
" const double lambda) {\n",
" return AT_DISPATCH_FLOATING_TYPES(log_u.scalar_type(), \"sinkstep\", [&] {\n",
" return sinkstep_cuda_template<scalar_t>(dist, log_nu, log_u, lambda);\n",
" });\n",
"}\n",
"\n",
"PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n",
" m.def(\"sinkstep\", &sinkstep_cuda, \"sinkhorn step\");\n",
"}\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Incorporating it in PyTorch\n",
"\n",
"We compile this into a PyTorch extension module and add a convenience function (with a \"manual\" fallback implementation for the CPU)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"wasserstein_ext = torch.utils.cpp_extension.load_inline(\"wasserstein\", cpp_sources=\"\", cuda_sources=cuda_source,\n",
" extra_cuda_cflags=[\"--expt-relaxed-constexpr\"] )\n",
"\n",
"def sinkstep(dist, log_nu, log_u, lam: float):\n",
" # dispatch to optimized GPU implementation for GPU tensors, slow fallback for CPU\n",
" if dist.is_cuda:\n",
" return wasserstein_ext.sinkstep(dist, log_nu, log_u, lam)\n",
" assert dist.dim() == 2 and log_nu.dim() == 2 and log_u.dim() == 2\n",
" assert dist.size(0) == log_u.size(1) and dist.size(1) == log_nu.size(1) and log_u.size(0) == log_nu.size(0)\n",
" log_v = log_nu.clone()\n",
" for b in range(log_u.size(0)):\n",
" log_v[b] -= torch.logsumexp(-dist/lam+log_u[b, :, None], 0)\n",
" return log_v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use this update step in a building block for the Sinkhorn iteration:"
]
},
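{
"cell_type": "markdown",
"metadata": {},
"source": [
"For orientation, the two `sinkstep` calls made per iteration in the class below can be read as the following pair of updates. This is a sketch in the notation of the kernel section; the kernel matrix $K$ is a name introduced here and does not appear in the code:\n",
"$$\n",
" \\log v_{bj} \\leftarrow \\log \\nu_{bj} - \\operatorname{logsumexp}_{i} (-\\frac{1}{\\lambda} c_{ij} + \\log u_{bi}), \\qquad\n",
" \\log u_{bi} \\leftarrow \\log \\mu_{bi} - \\operatorname{logsumexp}_{j} (-\\frac{1}{\\lambda} c_{ij} + \\log v_{bj}).\n",
"$$\n",
"With $K_{ij} := e^{-c_{ij}/\\lambda}$ these are the log-space form of the classical scaling updates $v_b = \\nu_b / (K^\\top u_b)$ and $u_b = \\mu_b / (K v_b)$ (division taken elementwise). The second update reuses the same function with the transposed cost matrix, which is why the code passes `dist.t()`."
]
},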
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class SinkhornOT(torch.autograd.Function):\n",
" @staticmethod\n",
" def forward(ctx, mu, nu, dist, lam=1e-3, N=100):\n",
" assert mu.dim() == 2 and nu.dim() == 2 and dist.dim() == 2\n",
" bs = mu.size(0)\n",
" d1, d2 = dist.size()\n",
" assert nu.size(0) == bs and mu.size(1) == d1 and nu.size(1) == d2\n",
" log_mu = mu.log()\n",
" log_nu = nu.log()\n",
" log_u = torch.full_like(mu, -math.log(d1))\n",
" log_v = torch.full_like(nu, -math.log(d2))\n",
" for i in range(N):\n",
" log_v = sinkstep(dist, log_nu, log_u, lam)\n",
" log_u = sinkstep(dist.t(), log_mu, log_v, lam)\n",
"\n",
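" # this is a slight abuse of sinkstep: with cost -log(dist) + dist/lam, marginal -log_v and lam=1, it computes\n",
" # (diag(exp(log_u)) @ (dist * exp(-dist/lam)) @ diag(exp(log_v))).sum() for each batch item\n",
" # in log space, i.e. without materialising a (batch x d1 x d2) tensor\n",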
" distances = (-sinkstep(-dist.log()+dist/lam, -log_v, log_u, 1.0)).logsumexp(1).exp()\n",
" ctx.log_v = log_v\n",
" ctx.log_u = log_u\n",
" ctx.dist = dist\n",
" ctx.lam = lam\n",
" return distances\n",
"\n",
" @staticmethod\n",
" def backward(ctx, grad_out):\n",
" return grad_out[:, None] * ctx.log_u * ctx.lam, grad_out[:, None] * ctx.log_v * ctx.lam, None, None, None\n"
]
},
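{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `distances` line in `forward` is compact, so here is how it can be unpacked. This is a derivation from the code as written; $\\tilde c$, $r$ and $P$ are names introduced here. Calling `sinkstep` with the modified cost $\\tilde c_{ij} = -\\log c_{ij} + \\frac{1}{\\lambda} c_{ij}$, the marginal $-\\log v$ and regularisation $1$ returns\n",
"$$\n",
" r_{bj} = -\\log v_{bj} - \\operatorname{logsumexp}_{i} (\\log c_{ij} - \\frac{1}{\\lambda} c_{ij} + \\log u_{bi}),\n",
"$$\n",
"so that $-r_{bj} = \\log \\sum_i u_{bi} \\, c_{ij} \\, e^{-c_{ij}/\\lambda} \\, v_{bj}$. Taking the `logsumexp` over $j$ and exponentiating then gives\n",
"$$\n",
" \\exp ( \\operatorname{logsumexp}_{j} (-r_{bj}) ) = \\sum_{i,j} u_{bi} \\, e^{-c_{ij}/\\lambda} \\, v_{bj} \\, c_{ij} = \\langle P_b, c \\rangle,\n",
"$$\n",
"where $P_{bij} = u_{bi} \\, e^{-c_{ij}/\\lambda} \\, v_{bj}$ is the coupling for batch item $b$. Entries with $c_{ij} = 0$ have $\\log c_{ij} = -\\infty$ and contribute nothing to the sum, as expected."
]
},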
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also define a function to get the coupling itself:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_coupling(mu, nu, dist, lam=1e-3, N=1000):\n",
" assert mu.dim() == 2 and nu.dim() == 2 and dist.dim() == 2\n",
" bs = mu.size(0)\n",
" d1, d2 = dist.size()\n",
" assert nu.size(0) == bs and mu.size(1) == d1 and nu.size(1) == d2\n",
" log_mu = mu.log()\n",
" log_nu = nu.log()\n",
" log_u = torch.full_like(mu, -math.log(d1))\n",
" log_v = torch.full_like(nu, -math.log(d2))\n",
" for i in range(N):\n",
" log_v = sinkstep(dist, log_nu, log_u, lam)\n",
" log_u = sinkstep(dist.t(), log_mu, log_v, lam)\n",
" return (log_v[:, None, :]-dist/lam+log_u[:, :, None]).exp()"
]
},
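{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on `get_coupling`, the returned tensor and its marginals can be written out. This follows from the code above under the assumption $N \\ge 1$; $P$ is a name introduced here. The function returns\n",
"$$\n",
" P_{bij} = \\exp ( \\log u_{bi} - \\frac{1}{\\lambda} c_{ij} + \\log v_{bj} ) = u_{bi} \\, e^{-c_{ij}/\\lambda} \\, v_{bj}.\n",
"$$\n",
"Because the last update in the loop is the one for $\\log u$, the row marginals match $\\mu$ up to floating point error, while the column marginals match $\\nu$ only approximately, with the gap shrinking as the iteration converges:\n",
"$$\n",
" \\sum_j P_{bij} = \\mu_{bi}, \\qquad \\sum_i P_{bij} \\approx \\nu_{bj}.\n",
"$$\n",
"Comparing `get_coupling(mu, nu, dist).sum(1)` against `nu` is therefore one way to judge whether $N$ was large enough."
]
},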
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define some test distributions. These are similar to examples from [Python Optimal Transport](https://github.com/rflamary/POT/)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "
Gusn3mCPR7Ec5s1cbF/+pu7UjXCltEqmL3QS+lZPGus7QK9aGjTtMFrju7jgmrJ+Dl4sWXw7+kbWBbXeq4obISOPA1zOoB3/8FSgtgzCxtWNjuk+p1FiejQfDnXk2JPXNZXSlbB5o3aM7iEYsJ8wpj6rqpjjk+zs3y8IcBf9PmRbjjVW0S8y9Gake8J1dbfeDbfdBvOZXBiUt5TOgbWe8nPKWUfHXsK57d9Cwx/jEsHrGYpj5N67WGGpVcgV2fwcxO8POTWqDfuwim7oFOj/xhmIL6dG/Xxrg7G9VefR0J8Qxh0bBFdA3pyj+3/5NP4z5VF1aZw80H+j6tdUIY8R7kXoRv7tMmNT/8gzZwnxWy+6D/fHMiIT5ujO5QvzMylZpKeWP3G7wT+w4Dmwxk3pB51jUCZUEWbH4HPmwHv78EDZrCwz/C41ugzV3a1YQ68nV35u7OYfwcd4HLV0p0rcVeebl48emgTxkTNYbPDn7GS1tfori8WO+ybIOzu3akO32/NkBfeSn8OAE+6QJ7F2gD+VkRuw76wyk57EjMZHzfCFyc6m9Tc0tymbpuKt/Ff8djbR7jv/3+az1DGuSch9V/14Yp2PgGhHWF8avhsVUQfYdVTfQxrncExWUmvtp1Vu9S7Jaz0ZnX+rzGjM4zWHV6FRNXTySzMFPvsmyH0Vm7UPDJXXD/1+ARACuf0Xagtr4Phdl6VwjYeffKad/sZ3N8OjteHlhvQx6cyz3HUxue4lzuOf7Z65/cE31Pvay3RmnHYftMOLxUa09sN1brB9+wjd6V3dDERbHsPXuZ7S8OxNPVxoaGsDFrzqzhb9v+RoBbADMHziTGP0bvkmyPlNoJ220fQuJ6cPGGruOg55PgU7etCg7ZvfJcZgGrDl/koZ5N6i3kd5zfwQO/PkBmUSZzhszRP+SlhNNb4ev7tK5hx37SrmKdEQf3zLH6kAeYOqA52QWlLFZ79XVuSMQQFg1bRJks49HfHmXd2XV6l2R7hNAG8Ht0mdYM2mIo7JylDaD205PaDpcO7Dbo529LwmgQjO8TWefrklKy6Oginlj/BCGeISwZuYRuId3qfL3VKi+DIz9q3cAW3Qnn90H/v2kTMgx/W+sjbCM6NWnAbdGBzN2aRFFpud7l2L02gW1YMnIJ0X7RPLPpGWbFzcIk1WiityS0A4ydr/Vc6/qYNmTIpz1h8VhI2lyvPXXsMugv5RSxJDaZuzuF0bCOx0spKC3gxS0v8t7e9xjUZBCLhy/W74rDolxt72FmJ/hhvHb/zg8qLnJ60WZH5ps2oDkZ+SUs2XNO71IcQpBHEAuGLWBM1BhmH5zNtPXTyClW3VxvWYMIGPGutqM14B9wMQ6+HA2f36bNc1tW950N7LKN/uVlh/lhXzIbnutPY/+6G8DpdM5pntn4DKdzT/NUp6cY33a8PqNPXj4Luz+HA19ps+k07QO9pkGLYVVO8mGL7pu9k3NZBWz+a39cnRx8uN16IqVkafxS3op9ixCPED4c8KFqt7eE0iLtXNnOWZB+ArxDtSbVruNrtTPmUG30pzOusHRvMg/3aFqnIb/6zGoe/PVBsoqymH3HbCa2m1i/IS8lnNmuzXg/syPs+VxrD5y0UetB03KE3YQ8wLSBzbmUW8RSNdhZvRFCcH/L+1k4dCEl5SU8vOphfjz5o+pvX1vObtD5z1pPnYd/0MbP2fAavN8afplRJ10z7a4bw/trT+LqZGDqgOZ18v7F5cW8G/su38V/R/vA9rzX7z1Cvepx7PTSQq39ffds7eo89wZa75luk8DX+qZHtJTbogPpHuHPh+tOcVenMKuZOMYRdAzuyNJRS3lp60v8e+e/iU2N5ZWer6jhjmtLCG0SlOjB2knaXZ9Cxilt6G9Lr8oav51vtenm6IUcRs7cxrQBzXl+qOUPMZNykvjr5r8SfzmecW3GMb3zdJwN9RQ42ecgdj7s/xIKsyCoFfScAu3uAxfH+IM7lJLN6E+282T/KP46rK
Xe5TicclM5cw/P5bODn9HEuwlv3f4WbQKsv+eWTTGZbvlI/EZNN3a1R//u6nh83Z2ZdHszi76vlJLvT37Pu7Hv4ubkxqxBs7g9/HaLrqNKJhMkboDYeXDyd20PoOVI6P64NjuOFV3cVB/ah/txV8dGzNt2mod6NCG8gWN8wVkLo8HIlA5T6NKwCy9tfYlHVj3CjE4z+HObP9vWzGjWrI6aW+3m08ktKiU5q4An+kfh6265veysoiymb5zOa7teo3PDzvw4+se6D/n8dNj2AXzcCb7+E5zfq42F/fRhuH8xRN7mcCFf6YVhLRFoX+qKPrqFdOPHUT/SP7w//933XyavnczFfKubUE65hl013ZSVmzBJLDbcwdqza3l91+vkleTxTJdneLjVw3W352IywZktsG8RHP9Fm+ygaV+t/22r0eCk5vqs9O7qE8zamMhPU/voNiKpoh3p/njqR96JfQejMPJi9xcZEzXGNmZLs0M3arqxq6C3lMtFl3lz95v8duY3Wge05o0+b9C8Qd2c3CXnPBz8VusaefmMdnK1w4PQZRwEqa5sVckvLmPAe5sI8nLl52l91CxUOkvOS+Yf2/7B/rT99Avvxz96/oMQzxC9y3I4KujNJKXk19O/8s6ed8grzWNK+ymMbzfe8idcy4oh/jc4sFgbD0OaIOI2bULiVqO07lfKDf1+5BJTFu/jhaExddbDSjGfSZpYfGwxHx/4GCeDE890eYaxLcaqtvt6pILeDCl5Kby+63W2X9hO+8D2/Kv3v2jRoIXlViClNhRB3Dda98iibPAJg44PaTd/y55AdgRTv97P2mOp/Dq9L9EN637mK6VmybnJvLrzVXZf2k3n4M78s+c/6+5oWPkDFfQ3UFxezIIjC5h/eD5GYWRG5xncH3M/RkuNx551Gg4thUPfQVYiOLlpe+0dHoRm/XUf992WpecVM+SDzTQN8OTHJ3pjNKi2YWsgpWR5wnL+u/e/FJQW8EjrR3iiwxOq330dU0FfBSklm1M28/aet0nJT2FYxDCe6/qcZdoW81K1AYwOf6/1mAGtaab9fdB6DLj51n4dCgA/x51nxpI4Xh7eksf7ReldjnKNrKIsPtz3IcsTlhPsEcyzXZ5lROQIdbK2jqigv058Vjzv7n2X3Rd3E+kbyd96/I2eoT1r96ZXMuDYz1rAn92utbs3bAft/gRtx4JfY8sUr/yBlJInFu9n7fFUvp3Uk+6Rtjlwmz2LS4vj/3b/H8ezjtM+qD0vdnuR9kHt9S7L7qigr3Ax/yKfHvyUnxN+xsfVhyc7PMm9Mffe+snWvEtaV8jjK7RxZ2Q5BERDm7uh7Z8gWF29WR9yi0oZ88l2rhSXsXJ6X4K91clsa2OSJlYkruCj/R+RUZjBkKZDeKrTU0T4Ruhdmt1w+KDPLMxk3uF5fBf/HQLBgy0fZFL7Sfi63kITSmYinFgJx1dCSiwgIbCF1te9zd3aZB7q0LTenbiUy12zttM+3I9vJvbASXW5tEpXSq/wxdEvWHR0ESXlJdwTfQ+T209W3TEtwGGDPqMwgy+OfMHSk0spKS/hruZ3MaXDlJv7pTKVQ/IeOPkbxP8OGRVXZIZ2gJiRWpu72nO3CssPpPDMdwcZ1zuCf41qrdqCrVhGYQZzD81l6cmlCAT3RN/DxHYTVeDXgsMF/fn883x17Ct+OPkDpaZSRkaOZFL7SUT6mjnbVH66NsbMqTVaP/fCy2Bw0sZ5jxmujTdjQ7M0OZLXVh5j/rbTPD+kBdMGRutdjlKDC/kXmHd4HssTlgMwOmo049qMM/9vVbnKYYL+eOZxFh5dyJozaxAIRjYbyeT2k2niU0MolxVre+2JG7Rgv3hQe9wzCJpXDCPafJDqLWMDTCbJ898fZNmB87w2pg2P9orQuyTFDBfzLzL/yHx+SviJkvISBjYZyF/a/IWOQR3VkZmZHCLo80vyGfj9QAzCwL0t7uXhVg9XfxhoKtfC/MxWOL0Fzu6A0gIQRm
jcHaIGQfOBENrJribvcBSl5SaeWLyf9SdSef++DtzdSaepHZWbllmYyTcnvmHJiSXkluTSOqA1D7d6mKERQ3E1Wn6cdnviEEEPsOPCDtoFtsPb5bqrJMvL4NJBLdDPbNf+rZwDM7CFduFSswEQ0UfttduJotJyHlsYy86kTP4xshUTb1NXHtuSgtICViat5Jvj35CYk4ivqy+jo0YzNnoszfzUZ1kVhwn6qwqzteEGknfDuZ2Qsg9Kr2jP+UdpgR5xuzamu089zg6l1Kui0nKe+S6O345cYnyfSP4xshUGdfWsTZFSsvvSbn44+QPrz62nzFRGh6AOjI4azdCIobfWc85OOUbQlxXDr89Cyl5IjwckCAM0bAtNemq3pn3AW53VdyTlJslrK4/xxY4zDG3TkHfGdrDofAVK/ckszGRF4gp+TviZxJxEnA3O3BZ2G8Mih9EvvJ/DD7HgGEEPMLsveDeC8G7QuBuEdQFXNdiVo5NSMn/bad767QQNfdz48IGOdItQV9DaKiklx7OO80viL6w+s5r0wnRcja70DevLgMYD6BfeDz83P73LrHeOE/SKcgNxydlM//YAKZcLeLJ/c6YNbI6bsxpUzpaVm8o5kHaA38/8zsbkjaQVpGEQBjoGdaRvWF/6hvUlxj/GIYZLVkGvKBXyikr514qjLNt/njA/d/4+shXD24aoLnx2QErJscxjbEjewNaUrRzPOg6Av5s/3UO60y2kG91DutPUp6ldft4q6BXlOjsSM/jPL8c4cSmP7hH+TB3YnNujA+0yABxVRmEG289vZ9fFXey5uIe0wjRAC/5OwZ3oFNyJtoFtaeXfyi7a91XQK0oVyspNfBubzCcbTpGaW0yrUB8m3RbJ8LahuLuoJh17IqXkbO5ZYlNjiUuLY3/qflLyUwAwCANRflG08m9FK/9WxPjH0KJBC5vr0aOCXlFuoKTMxE9x5/l8cyKJ6VfwcnViWNsQ7u4URvdIfzUnrZ3KKMzgaMZRDmcc5kjmEU5kniCzKPPq80HuQUT5RdHMtxkRvhFE+ETQxKcJIR4hlpuYyIJU0CuKGUwmya7TmSzff57fjlwiv7gML1cnekcF0C8miC5NGxAd7K1msrJj6QXpnMg6QWJ2IqeyT5GYncjpnNMUlBVcXcbJ4ESYVxhhXmGEeoYS5hVGiGcIDT0aEuwRTLBHsC5NQbUOeiHEMOAjwAjMk1K+dd3zouL5EUABME5Kud+c11ZFBb2it8KScjafTGfzyXS2nEznfHYhAJ4uRtqF+9IyxIfohl40D/KiaYAnwd6u6mIsOyWlJL0wndM5p0nOS756u5h/kQtXLpBVlPU/r/Fw8iDII4gAtwD83fzxd/PHz80PP1ft5uvqi4+Lj3Zz9cHL2QtXo2utzhHVKuiFEEbgJDAYSAFigQellMeuWWYE8BRa0PcAPpJS9jDntVVRQa9YEyklZzILOHDuMnHJ2RxMzuZUWj4FJeVXl3E2CkJ93Wno40qQtyuBXq74ebjg5+6Mr7szXm5OeLk64enqhLuzETdnA27ORlydDDgbDbg4GXAyCHUy2AYVlhWSVpBG6pVUUgtSSStII6Mwg/TCdLKKssgqzCKrKIuckhxM0lTt+1QeKay8e+Ut1XGjoHcy4/XdgQQpZVLFmy0BxgDXhvUY4EupfWvsEkL4CSFCgQgzXqsoVk0IQWSgJ5GBntzTWRsgzWSSXMwt4lRqHimXC0m5XMj57ELS84qIv5THtrwMcovKbnpdTgaB0SCu/lt5E0JgEGAQAlFRkxDaHDeCip8rHgft5/+/AVX+WOsvFfWVVBWfilvzPzzqCgRhQopCTOIKJkMBUhRgElcqHitEGorIL66bq7bNCfowIPma+yloe+01LRNm5msBEEJMBiYDNGmixnpXrJvBIAjzcyfMz73aZcpNktzCUnIKS8kvLuNKcRlXSsooLDFRVFpOUVk5JWUmSspMlJabKCmXlJWbKDdJykyS8sqblEipHV
mUmyQSrt7Xfv7/jwFce4x+7RH7H47da3lqTtb2DRyWz42fddMv6Kv64r7+U65uGXNeqz0o5RxgDmhNN2bUpShWzWgQNPB0oYGni96lKA7OnKBPARpfcz8cuGDmMi5mvFZRFEWpQ+Z0EI4FooUQkUIIF+ABYMV1y6wA/iw0PYEcKeVFM1+rKIqi1KEa9+illGVCiGnAarQukguklEeFEFMqnp8NrELrcZOA1r3ysRu9tk62RFEURamSumBKURTFDtyoe6W6tltRFMXOqaBXFEWxcyroFUVR7JwKekVRFDtnlSdjhRDpwNlbfHkgkGHBcmyBI24zOOZ2O+I2g2Nu981uc1MpZVBVT1hl0NeGEGJvdWee7ZUjbjM45nY74jaDY263JbdZNd0oiqLYORX0iqIods4eg36O3gXowBG3GRxzux1xm8Ext9ti22x3bfSKoijKH9njHr2iKIpyDRX0iqIods5ugl4IMUwIES+ESBBCvKR3PXVFCNFYCLFRCHFcCHFUCDGj4nF/IcRaIcSpin8b6F2rpQkhjEKIA0KIlRX3HWGb/YQQPwghTlR85r3sfbuFEM9U/G4fEUJ8K4Rws8dtFkIsEEKkCSGOXPNYtdsphHi5It/ihRBDb2ZddhH0FZOQzwKGA62BB4UQrfWtqs6UAc9JKVsBPYGpFdv6ErBeShkNrK+4b29mAMevue8I2/wR8LuUsiXQAW377Xa7hRBhwHSgq5SyLdrw5g9gn9v8BTDsuseq3M6Kv/EHgDYVr/m0IvfMYhdBzzUTmEspS4DKScjtjpTyopRyf8XPeWh/+GFo27uoYrFFwF26FFhHhBDhwEhg3jUP2/s2+wC3A/MBpJQlUsps7Hy70ebJcBdCOAEeaLPS2d02Sym3AFnXPVzddo4Blkgpi6WUp9Hm/uhu7rrsJeirm5zcrgkhIoBOwG6gYcWsXlT8G6xjaXXhQ+CvgOmax+x9m5sB6cDCiiareUIIT+x4u6WU54H3gHPARbTZ6tZgx9t8neq2s1YZZy9Bb/Yk5PZCCOEF/Ag8LaXM1bueuiSEuBNIk1Lu07uWeuYEdAY+k1J2Aq5gH00W1apokx4DRAKNAE8hxCP6VmUVapVx9hL05kxgbjeEEM5oIf+1lHJZxcOpQojQiudDgTS96qsDfYDRQogzaM1yA4UQi7HvbQbt9zpFSrm74v4PaMFvz9t9B3BaSpkupSwFlgG9se9tvlZ121mrjLOXoHeYSciFEAKtzfa4lPL9a55aAfyl4ue/AD/Xd211RUr5spQyXEoZgfbZbpBSPoIdbzOAlPISkCyEiKl4aBBwDPve7nNATyGER8Xv+iC081D2vM3Xqm47VwAPCCFchRCRQDSwx+x3lVLaxQ1tcvKTQCLwd73rqcPt7It2yHYIiKu4jQAC0M7Sn6r411/vWuto+/sDKyt+tvttBjoCeys+75+ABva+3cCrwAngCPAV4GqP2wx8i3YeohRtj33CjbYT+HtFvsUDw29mXWoIBEVRFDtnL003iqIoSjVU0CuKotg5FfSKoih2TgW9oiiKnVNBryiKYudU0CuKotg5FfSKoih27v8ByhJZML+dtKsAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# some test distribution densities\n",
"n = 100\n",
"lam = 1e-3\n",
"x = torch.linspace(0, 100, n)\n",
"mu1 = torch.distributions.Normal(20., 10.).log_prob(x).exp()\n",
"mu2 = torch.distributions.Normal(60., 30.).log_prob(x).exp()\n",
"mu3 = torch.distributions.Normal(40., 20.).log_prob(x).exp()\n",
"mu1 /= mu1.sum()\n",
"mu2 /= mu2.sum()\n",
"mu3 /= mu3.sum()\n",
"mu123 = torch.stack([mu1, mu2, mu3], dim=0)\n",
"mu231 = torch.stack([mu2, mu3, mu1], dim=0)\n",
"cost = (x[None, :]-x[:, None])**2\n",
"cost /= cost.max()\n",
"pyplot.plot(mu1, label=\"$\\mu_1$\")\n",
"pyplot.plot(mu2, label=\"$\\mu_2$\")\n",
"pyplot.plot(mu3, label=\"$\\mu_3$\")\n",
"pyplot.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We run a sanity check for the distance:\n",
"(This will take longer than you might expect, as it computes a rather large gradient numerically, but it finishes in $<1$ minute on a GTX 1080)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OK? True took 32 sec\n"
]
}
],
"source": [
"t = time.time()\n",
"device = \"cuda\"\n",
"res = torch.autograd.gradcheck(lambda x: SinkhornOT.apply(x.softmax(1), \n",
" mu231.to(device=device, dtype=torch.double),\n",
" cost.to(device=device, dtype=torch.double),\n",
" lam, 500),\n",
" (mu123.log().to(device=device, dtype=torch.double).requires_grad_(),))\n",
"print(\"OK? {} took {:.0f} sec\".format(res, time.time()-t))"
]
},
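{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the calling pattern of the check above easier to see, here is a minimal sketch of `torch.autograd.gradcheck` on a toy function. `toy_loss` is a hypothetical stand-in for `SinkhornOT.apply` and is not part of this notebook's implementation. The relevant points are that the inputs are double precision, have `requires_grad` set, and are passed as a tuple; `gradcheck` then returns `True` if the analytic and numerical gradients agree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def toy_loss(logits):\n",
"    # hypothetical stand-in for SinkhornOT.apply:\n",
"    # squared distance of softmax(logits) to the uniform distribution, per row\n",
"    p = logits.softmax(1)\n",
"    return ((p - 1.0 / p.shape[1]) ** 2).sum(1)\n",
"\n",
"# gradcheck expects double precision; float32 finite differences are too noisy\n",
"logits = torch.randn(3, 5, dtype=torch.double, requires_grad=True)\n",
"print(torch.autograd.gradcheck(toy_loss, (logits,)))"
]
},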
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We might also check that sinkstep is the same on GPU and CPU (Kai Zhao pointed out that this was not the case for an earlier versions of this notebook, thank you, and indeed, there was a bug in the CPU implementation.)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"res_cpu = sinkstep(cost.cpu(), mu123.log().cpu(), mu231.log().cpu(), lam)\n",
"res_gpu = sinkstep(cost.to(device), mu123.log().to(device), mu231.log().to(device), lam).cpu()\n",
"assert (res_cpu - res_gpu).abs().max() < 1e-5"
]
},
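{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tolerance-based comparison like the one above can also be written with `torch.testing.assert_close`, assuming a PyTorch version that provides it. The sketch below uses `toy_step`, a hypothetical log-domain reduction standing in for `sinkstep`, and skips the comparison when no GPU is available. The tolerances are illustrative, not tuned values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def toy_step(cost, log_mu, lam):\n",
"    # hypothetical stand-in for sinkstep: one log-sum-exp reduction over the last axis\n",
"    return torch.logsumexp(-cost[None] / lam + log_mu[:, None, :], dim=2)\n",
"\n",
"torch.manual_seed(0)\n",
"toy_cost = torch.rand(4, 4)\n",
"toy_log_mu = torch.randn(2, 4).log_softmax(1)\n",
"toy_cpu = toy_step(toy_cost, toy_log_mu, 0.1)\n",
"if torch.cuda.is_available():\n",
"    toy_gpu = toy_step(toy_cost.cuda(), toy_log_mu.cuda(), 0.1).cpu()\n",
"    # raises an AssertionError with a readable report if the results differ\n",
"    torch.testing.assert_close(toy_cpu, toy_gpu, rtol=1e-4, atol=1e-5)"
]
},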
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can visiualize the coupling along with the marginals:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAI/CAYAAADQs2XyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAACHz0lEQVR4nOzdd5hdV3n3/e+99t7nnOkzkka9N/cubIwLmBZTDQkhEBIImDgESEKe8CQQkrzJ8yR5eVMJgUAINiUQDEkoTiB0jLFxk3tVsYrVR33qKXvv9f6xzozGQrLGtjRnyu9zXXNp5px9Zu7jIv201r3vZd57RERERGT8uEYXICIiIjLdKICJiIiIjDMFMBEREZFxpgAmIiIiMs4UwERERETGmQKYiIiIyDiLG13AMzFr1iy/dOnSRpchIuPonnvu2ee97250HSIiJ9OkCmBLly5l7dq1jS5DRMaRmW1tdA0iIiebtiBFRERExpkCmIiIiMg4UwATERERGWcKYCIiIiLjTAFMREREZJwpgImIiIiMMwUwERERkXGmACYiIiIyzibVIFaRk2GgkrLj0BC7DpcZqmbEzmguRiye0cy8jiYiZ40uUUREpjgFMJny0izntif28/1H93D3lgM8vrvvuNcWYsdFi7u4bOVMfu6suaya0zaOlYqIyHShACZT1r7+Ctffupn/uGc7e/sqtBQiLlzSxSvOnsfSWc3M72yiKYnIvae/nLL1wCAbe/q5/Yn9/M131/M3313PBYs7+eWLF/O6CxaQRNqxFxGRk0MBTKacw4M1/vGHG/jCnVuppDkvOX0Ob7hoAVedPptiHB33dS8Y9XlPX5lv3LeTG+9+kv/9Hw/ysR9t5H0vXcU15y3AaYtSRESeI/PeN7qGMVuzZo3XYdxyPN57vnrvDv7yW49xcLDK6y5YwHuuWsmK7tbn9D2//1gPf/e99Ty2q5eLlnTx//78OazW1uS4MbN7vPdrGl2HiMjJpAAmU8K+/grv//cHuHndXi5c3Mn/fd3ZnDW/46R9/zz3fPW+Hfz5Nx9loJLyvpeu5jdfuEKrYeNAAUxEpiJtQcqk99ON+/idL9/P4aEa/+eas/iVS5ac9GDknPGGixZy1Wnd/MlNj/DX31nHnZsP8PdvPI+ZrcWT+rNERGTqU1exTGr/esdWfvWGu+hoSvjGey7jrZcuPaWrUjNbi3zszRfwF68/mzs27efV/3grj+/uPWU/T0REpiYFMJmUstzzZ//1CH/89Yd50epuvv6eyzhjXvu4/Gwz4y2XLOGrv/kCcu/5xU/czq0b9o3LzxYRkalBAUwmnTTLef+/P8BnbtvC2y9byqfeuobW4vjvpp+9oIOvvfsy5nc28WufuYv/fnDnuNcgIiKTkwKYTCrVNOe3b7yPr923g/e/fDX/z2vOaujk+vmdTfz7b17KBYs7+Z0b7+cb9+9oWC0iIjJ5KIDJpJHlnt/98v1866Hd/NGrzuC9L17V6JIAaC8lfPbtF3PRki5+98sKYSIicmIKYDIpeO/50Nce4psP7eKPXnUG77xieaNLeoqWYsxn3/48nrd0Br/3lQe4eV1Po0sSEZEJTAFMJoX/79vruPHubbz3qpUTLnwNay7EfPpta1g9p43f/MK93PfkwUaXJCIiE5QCmEx4X7rrST754yf45UsW83svX93ocp5WWynhs+94Ht1tRd7x2bvZdmCw0SWJiMgEpAAmE9pPN+7jj7/+MFeu7ub/vPYszCb+5PnZbSU+946LyT1c+7m76SvXGl2SiIhMMGMKYGZ2tZmtM7ONZvaBYzxvZvbR+vMPmtmFo567wcx6zOzh43zv95uZN7NZz/5tyFS0df8Av/nFe1k2q4WP/fIFxNHk+fvCslktfPyXL+SJvQP8zo33k+WT58gvERE59U74J5qZRcDHgVcAZwJvNrMzj7rsFcCq+sd1wCdGPfdZ4OrjfO9FwMuAJ59p4TK1DVUz3vWFew
G4/m3Po72UNLiiZ+7yVbP409ecyQ8f7+EffrCh0eWIiMgEMpYlhYuBjd77Td77KnAjcM1R11wDfN4HdwCdZjYPwHt/C3DgON/774HfB7Q8ICO89/zR1x/m8d29fOSXzmfxzOZGl/Ss/crzl/ALFy7kH3+4gR+v39vockREZIIYSwBbAGwb9fX2+mPP9JqnMLPXAju89w+MoQaZRr589zb+897t/NaLV3HV6bMbXc5zYmb8+evO5rQ5bbzvxvvYeWio0SWJiMgEMJYAdqyu56NXrMZyzZGLzZqBDwF/csIfbnadma01s7V792oFYarb2NPHn/7XI1y+cha/85KJMWj1uWoqRPzTWy6kmua8T/1gIiLC2ALYdmDRqK8XAkcfejeWa0ZbASwDHjCzLfXr7zWzuUdf6L3/lPd+jfd+TXd39xjKlcmqkmb81pfup7kQ83dvPK+hRwydbMu7W/nz15/NXVsO8MkfP9HockREpMHGEsDuBlaZ2TIzKwBvAm466pqbgLfW74Z8PnDYe7/reN/Qe/+Q9362936p934pIcBd6L3f/ezehkwFf/XtdTy2q5e/+oVzmd1eanQ5J93rzl/Aa8+bz99/bz33bzvU6HJERKSBThjAvPcp8F7gO8BjwFe894+Y2bvM7F31y74FbAI2Av8CvHv49Wb2JeB24DQz225m157k9yBTwB2b9nP9rZv5lecv5qVnzml0OaeEmfF/X3c2c9pLvO/G+xiqZo0uSUREGsS8nzz9KGvWrPFr165tdBlykg1UUq7+h1twZvzP71xBcyFudEmn1O1P7OfN/3IH77hsGX/ymqMnusjRzOwe7/2aRtchInIyTZ7JljJlffh/Hmf7wSH++g3nTfnwBXDpipm89dIlfOanm7lr8/EmtIiIyFSmACYNdcem/fzrHVt5x2XLuHjZjEaXM27+4OrTWdDZxO//xwPaihQRmYYUwKRhyrWMP/zqQyye0cz7X35ao8sZVy3FmL96w7ls2T/IR36wvtHliIjIOFMAk4b5+I82smnfAH/x+rNpKkSNLmfcvWDFLN64ZiGf/slmHtvV2+hyRERkHCmASUOs293HJ25+gp+/cAFXrJq+890++Ioz6GhK+OBXHyLXgFYRkWlDAUzGXTjr8SHaSjF/9KrpfRdgV0uBP371Gdy/7RBfvHNro8sREZFxogAm4+5r9+3g7i0H+YOrT2dGS6HR5TTc685fwGUrZ/LX31nH/v5Ko8sREZFxoAAm4+rwUI2//NZjnLeokzeuWXTiF0wDZsafvfYsBqsZf/XtdY0uR0RExoECmIyrv//eevYPVPnza87GTaGzHp+rlbPbeMfly/jy2m3c9+TBRpcjIiKnmAKYjJsNe/r41zu28ssXL+achR2NLmfC+e2XrGJOe5E/+cYjasgXEZniFMBkXHjv+b/ffIyWQsTvTbOZX2PVWoz54CvO4KEdh/nqfTsaXY6IiJxCCmAyLn60rodb1u/ld166Wo33T+O1583nvEWd/NW3H2egkja6HBEROUUUwOSUq6Y5f/7fj7G8u4W3Xrqk0eVMaM4Zf/LqM+npq/DJHz/R6HJEROQUUQCTU+7f7tzKpn0D/NGrziCJ9J/ciVy0pIvXnDefT92yiR2HhhpdjoiInAL601BOqd5yjX/4wQZesGImV502u9HlTBp/cPVpeOBvv6uxFCIiU5ECmJxSn7z5CQ4O1vjgK87ATGMnxmphVzNvf8FSvnbfDh7dqXMiRUSmGgUwOWV2Hhri+ls387rz52vsxLPw7hetpL2U8OFvP97oUkRE5CRTAJNT5iPfX4/38P6f09iJZ6OjOeG9V63klvV7+cmGvY0uR0RETiIFMDklNvb08R/3bOdXL13Cwq7mRpczaf3qpUtY0NnEX317Hd5rOKuIyFShACanxN98Zz3NhZh3v2hFo0uZ1EpJxPteuoqHdhzm2w/vbnQ5IiJykiiAyUn3wLZDfPuR3bzzimXMbC02upxJ7+cvXM
jK2a38zXfXkWZ5o8sREZGTQAFMTrq//s46ZrQUeOcVyxtdypQQOeP9L1/NE3sH+JqOKBIRmRIUwOSkumPTfm7duI93v2gFrcW40eVMGT931lzOXdjBR76/gWqqVTARkclOAUxOGu89f/fd9cxpL/Irz9eRQyeTmfF7Lz+NHYeG+MrabY0uR0REniMFMDlpbt24j7u2HOA9V62klESNLmfKuXLVLC5a0sXHfriRci1rdDkiIvIcKIDJSeG952+/u575HSV+6XmLGl3OlGRm/N7LVrO7t8yNdz3Z6HJEROQ5GFMAM7OrzWydmW00sw8c43kzs4/Wn3/QzC4c9dwNZtZjZg8f9Zq/NrPH69d/zcw6n/O7kYa5ef1e7t92iN96ySqKsVa/TpVLV8zkkmUz+PjNTzBU1SqYiMhkdcIAZmYR8HHgFcCZwJvN7MyjLnsFsKr+cR3wiVHPfRa4+hjf+nvA2d77c4H1wAefafEyMXjv+cj3N7Cgs4lfuHBho8uZ0syM//Wy1eztq/BvWgUTEZm0xrICdjGw0Xu/yXtfBW4ErjnqmmuAz/vgDqDTzOYBeO9vAQ4c/U2999/13qf1L+8A9Cf3JPXj9Xt5YNsh3vvilRRi7Wqfapcsn8nzl8/gkz9+Qr1gIiKT1Fj+tFwAjL7tanv9sWd6zdN5B/A/z+B6mSC0+tUYv/OSsAqmXjARkclpLAHMjvHY0YfSjeWaY39zsw8BKfDF4zx/nZmtNbO1e/fqQOKJ5pYN+7h/2yHec5VWv8bTpStmcvGyGXxCq2AiIpPSWP7E3A6Mvq1tIbDzWVzzM8zsbcCrgbf445w07L3/lPd+jfd+TXd39xjKlfHivecff7CB+R0l3nCRVr/G2/tesoo9vRXNBRMRmYTGEsDuBlaZ2TIzKwBvAm466pqbgLfW74Z8PnDYe7/r6b6pmV0N/AHwWu/94LOoXRrszs0HWLv1IO960QqtfjXApStmsmZJF5+8+QlNxxcRmWRO+KdmvVH+vcB3gMeAr3jvHzGzd5nZu+qXfQvYBGwE/gV49/DrzexLwO3AaWa23cyurT/1MaAN+J6Z3W9mnzxZb0rGx8d+uJHutiJvXKO5X41gZrznxSvZebjM13VGpIjIpDKmw/q8998ihKzRj31y1OceeM9xXvvm4zy+cuxlykRz75MHuXXjPj70yjM09b6BXrS6m7MXtPNPN2/k5y9cQBxpJVJEZDLQ79byrHz8hxvpak745UsWN7qUac3MeO9VK9myf5BvPvS0u/4iIjKBKIDJM/bYrl5+8HgPb79sGS3FMS2iyin08jPnsmp2Kx//0UbyfEw3H4uISIMpgMkz9ombn6ClEPG2S5c2uhQBnDPefdUK1u/p54eP9zS6HBERGQMFMHlGtuwb4L8f3MmvXLqEjuak0eVI3WvOnc/Crib+6eaNHGeii4iITCAKYPKM/PMtm4gjx7WXL2t0KTJKHDl+48rl3PvkIe7a/DMnf4mIyASjACZj1tNb5j/v2c4vXrSQ2W2lRpcjR/nFNYuY1Vrgn25+otGliIjICSiAyZhdf9tm0jznN65c0ehS5BhKScTbL1vGj9fv5ZGdhxtdjoiIPA0FMBmT3nKNf7vjSV55zjwWz2xudDlyHL966RJaizH//ONNjS5FRESehgKYjMkX73iSvkrKu16o1a+JrL2U8JZLFvPfD+5k2wGd8CUiMlEpgMkJlWsZ19+6mStWzeLsBR2NLkdO4O2XLSNyxqd/olUwEZGJSgFMTuhr9+1gX39Fq1+TxNyOEq+/YAFfXruN/f2VRpcjIiLHoAAmTyvPPf9yyybOXtDOC1bMbHQ5MkbXXbmcci3nc7dvbXQpIiJyDApg8rS+99geNu0b4LorV2BmjS5Hxmjl7DZeesYc/vX2LQxVs0aXIyIiR1EAk6f1qVs2sbCriVeePbfRpcgzdN2Vyzk4WOM/7tnW6FJEROQoCmByXPdsPcA9Ww9y7eXLiCP9pzLZPG9pF+
cv6uTTt24m0yHdIiITiv5UleP65x9voqMp4Y1rFjW6FHkWzIzfuHI5W/cP8p1Hdje6HBERGUUBTI5p874BvvfYHn71+UtoKcaNLkeepZefNZclM5v551s26ZBuEZEJRAFMjun6WzeROMdbX7Ck0aXIcxA5452XL+OBbYdYu/Vgo8sREZE6BTD5GQcGqvzHPdt53QXzdej2FPCGixbR2ZzwqVs0mFVEZKJQAJOf8YU7tlKu5bzziuWNLkVOgqZCxK8+fwnff2wPm/cNNLocERFBAUyOUq5lfP72LVx1Wjer57Q1uhw5SX710iUkznH9rVoFExGZCBTA5Cm+cf8O9vVXtfo1xcxuK/G6C+bz72u3c2Cg2uhyRESmPQUwGeG959M/2cwZ83Ts0FT0ziuWU0lzvniHjicSEWk0BTAZ8eP1e9nQ08+vX7FMxw5NQavntPHC1d187vatVFIdTyQi0kgKYDLi+ls3M7utyKvPnd/oUuQUeecVy9jXX+Gm+3c2uhQRkWltTAHMzK42s3VmttHMPnCM583MPlp//kEzu3DUczeYWY+ZPXzUa2aY2ffMbEP9167n/nbk2Xp8dy8/2bCPt71gKYVYuXyqunzlLE6b08b1t27WYFYRkQY64Z+0ZhYBHwdeAZwJvNnMzjzqslcAq+of1wGfGPXcZ4Grj/GtPwD8wHu/CvhB/WtpkOt/spmmJOItlyxudClyCpkZ116xjMd393Hrxn2NLkdEZNoay1LHxcBG7/0m730VuBG45qhrrgE+74M7gE4zmwfgvb8FOHCM73sN8Ln6558DXvcs6peToKevzDfu38kbLlpIZ3Oh0eXIKXbN+fOZ1Vrk+ls3N7oUEZFpaywBbAGwbdTX2+uPPdNrjjbHe78LoP7r7DHUIqfAF+54kmqW8/bLlja6FBkHxTgMZr153V429vQ1uhwRkWlpLAHsWLfDHd08MpZrnhUzu87M1prZ2r17956MbymjlGsZX7xjKy89YzbLu1sbXY6Mk7c8fzGF2HHDbVsaXYqIyLQ0lgC2HVg06uuFwNG3UI3lmqPtGd6mrP/ac6yLvPef8t6v8d6v6e7uHkO58kx8/b4d7B+o8o7LlzW6FBlHs1qL/PwFC/jqvds5qMGsIiLjbiwB7G5glZktM7MC8CbgpqOuuQl4a/1uyOcDh4e3F5/GTcDb6p+/DfjGM6hbTgLvPTfcFgavXrpcg1enm3dcvoxyLeff7nqy0aWIiEw7Jwxg3vsUeC/wHeAx4Cve+0fM7F1m9q76Zd8CNgEbgX8B3j38ejP7EnA7cJqZbTeza+tPfRh4mZltAF5W/1rG0U827GP9nn6uvVyDV6ej1XPauGLVLD730y1U07zR5YiITCvxWC7y3n+LELJGP/bJUZ974D3Hee2bj/P4fuAlY65UTrrrb93MrNYirzlvXqNLkQa59vJl/Npn7uabD+3k9RcsbHQ5IiLThiZuTlMbe/r48fq9vPXSJRTjqNHlSIO8cHU3K2e3ajCriMg4UwCbpq6/dQuF2Gnw6jRnZrzjsmU8vKOXuzYfa1yfiIicCgpg09DBgSpfvXc7rz9/ATNbi40uRxrs9RcsoLM50WBWEZFxpAA2Df3bXU9SSXONnhAAmgrhCKrvPbaHrfsHGl2OiMi0oAA2zVTTnM/fvoUrVs3itLltjS5HJoi3XrqUyIzP/nRLo0sREZkWFMCmmW89tIs9vRXecZlWv+SIOe0lXn3uPP597Xb6yrVGlyMiMuUpgE0jw4NXl3e38MLVOlVAnuray5fTX0n58t3bTnyxiIg8Jwpg08g9Ww/y4PbDvP2yZTinwavyVOcs7OB5S7v47E+3kOUaSSEiciopgE0j19+6mY6mhF+4cEGjS5EJ6h2XLWP7wSG+9+juRpciIjKlKYBNE9sODPKdR3bz5osX01wY0wEIMg29/Ky5LOxq4oZbtzS6FBGRKU0BbJr43E+3YGa89dIljS5FJrDIGb/2gqXcteUAD20/3O
hyRESmLAWwaWC4sfqV58xjfmdTo8uRCe6Nz1tESyHihts0mFVE5FRRAJsG/mPtNvoqKddq8KqMQXsp4RfXLOK/H9zJnt5yo8sREZmSFMCmuCz3fOanW7hwcSfnL+psdDkySbz9sqWkuecLd2xtdCkiIlOSAtgU98PHe9i6f5BrL1/e6FJkElkys4WXnjGHL9yxlXIta3Q5IiJTjgLYFHf9rZtY0NnEz501p9GlyCRz7eXLODhY4+v37Wh0KSIiU44C2BT2yM7D3LHpAG97wRLiSP+q5Zm5ZNkMzprfzg23bcZ7DWYVETmZ9KfyFHbDrVtoLkT80vMWN7oUmYTMjHdctoz1e/r5yYZ9jS5HRGRKUQCbonr6yvzXAzt5w0UL6WhKGl2OTFKvPm8e3W1FjaQQETnJFMCmqC/cvpVanvP2yzR6Qp69Yhzxq89fws3r9rKxp6/R5YiITBkKYFNQuZbxhTuf5CWnz2bZrJZGlyOT3FsuWUwxdtxw25ZGlyIiMmUogE1B37h/BwcGqrxDg1flJJjZWuT1FyzgP+/ZzoGBaqPLERGZEhTAphjvPdffupkz5rVz6fKZjS5Hpoh3XL6MSprzb3dqMKuIyMmgADbF3LpxH+v39HPt5csws0aXI1PE6jltXLFqFp+/fSvVNG90OSIik54C2BTz6Z9sZlZrkdecN6/RpcgUc+3ly+jpq/BfD+xsdCkiIpOeAtgUsmFPHz9ev5e3XbqEYhw1uhyZYl64upuVs1u5/lYNZhURea7GFMDM7GozW2dmG83sA8d43szso/XnHzSzC0/0WjM738zuMLP7zWytmV18ct7S9HXDbZspxo63PH9Jo0uRKcjMuPbyZTy6q5fbN+1vdDkiIpPaCQOYmUXAx4FXAGcCbzazM4+67BXAqvrHdcAnxvDavwL+zHt/PvAn9a/lWdrfX+E/793BL1y0kBkthUaXI1PU6y9YwIyWAjfcqsGsIiLPxVhWwC4GNnrvN3nvq8CNwDVHXXMN8Hkf3AF0mtm8E7zWA+31zzsANZY8B1+880mqac47NHhVTqFSEvErz1/C9x/rYdPe/kaXIyIyaY0lgC0Ato36env9sbFc83SvfR/w12a2Dfgb4INjrlqeolzL+PztW3jRaaFHR+RU+tXnL6EQOa7XKpiIyLM2lgB2rFkGR3fgHu+ap3vtbwK/671fBPwucP0xf7jZdfUesbV79+4dQ7nTz03372Rff5Vfv2J5o0uRaaC7rcjrLpjPf967nYMazCoi8qyMJYBtBxaN+nohP7tdeLxrnu61bwO+Wv/83wnblT/De/8p7/0a7/2a7u7uMZQ7vXjv+fStmzhjXjsvWKHBqzI+3nnFcsq1nC9qMKuIyLMylgB2N7DKzJaZWQF4E3DTUdfcBLy1fjfk84HD3vtdJ3jtTuCF9c9fDGx4ju9lWrplQxi8+k4NXpVxtHpOG1eu7uZzt2+lkmaNLkdEZNI5YQDz3qfAe4HvAI8BX/HeP2Jm7zKzd9Uv+xawCdgI/Avw7qd7bf01vw78rZk9APwl4e5JeYY+/ZNNzG4r8prz5je6FJlmfv2KZeztq/CN+3X/jIjIMxWP5SLv/bcIIWv0Y58c9bkH3jPW19YfvxW46JkUK0/12K5efrJhH//7506jEGumroyvy1fO4vS5bVz/k8384kULtQIrIvIM6E/tSezTP9lMcyHiLZcsbnQpMg0ND2Zdt6ePWzbsa3Q5IiKTigLYJLX7cJmbHtjBG9csorNZg1elMV57/nxmtxX59E82NboUEZFJRQFskvrsT7eQ5Z5rL9fgVWmcYhzxa5ct5Scb9vHozt5GlyMiMmkogE1C/ZWUL965lVecPY9FM5obXY5Mc2+5eAnNhUirYCIiz4AC2CT05bu30VdOeecVWv2SxutoTvil5y3ipgd2suvwUKPLERGZFBTAJplalnPDrZu5eOkMLljc1ehyRAB4x2XL8MBnbtvS6FJERCYFBbBJ5lsP7WLHoS
Guu1LHDsnEsWhGM688Zx7/dueT9JZrjS5HRGTCUwCbRLz3/POPN7Giu4UXnz670eWIPMVvXLmc/krKl+58stGliIhMeApgk8htG/fz6K5errtyOc5p6KVMLGcv6OCylTO54bbNVNO80eWIiExoCmCTyD/f8gTdbUVed8GCRpcickzXXbmCPb0VvnH/jkaXIiIyoSmATRIP7zjMTzbs4+2XLaUYR40uR+SYrlw1izPmtfOpWzaR577R5YiITFgKYJPEP9+yidZizFsuWdLoUkSOy8z4jSuXs6Gnnx8+3tPockREJiwFsEngyf2DfPPBnbzlksV0NCWNLkfkab363Hks6Gzikz9+otGliIhMWApgk8C//GQTkTPeoWOHZBKII8evX7GMtVsPsnbLgUaXIyIyISmATXD7+it8Ze02Xn/BAua0lxpdjsiYvPF5i+hqTrQKJiJyHApgE9xnb9tCNcv5jReuaHQpImPWXIh52wuW8v3Heli3u6/R5YiITDgKYBNYX7nG527fwtVnzWVFd2ujyxF5Rt526VKaC5FWwUREjkEBbAL74p1P0ldO+c0XafVLJp+ulgJvvngxNz2wk20HBhtdjojIhKIANkGVaxnX37qZy1fO4tyFnY0uR+RZeecVy3AGn7plU6NLERGZUBTAJqj/vHc7e/sqWv2SSW1eRxM/f8FCvrJ2G3v7Ko0uR0RkwlAAm4DSLOeTP36C8xd18oIVMxtdjshz8hsvXE41y7nhts2NLkVEZMJQAJuAQs/MEO+5aiVmOnRbJrfl3a288px5/OvtWzk8WGt0OSIiE4IC2AST555/uvkJTp/bxktOn93ockROive8aCX9lZTP/nRLo0sREZkQFMAmmO8+upuNPf28+6qVOKfVL5kazpzfzkvPmM1nfrqZgUra6HJERBpOAWwC8d7zsR9tZOnMZl51zrxGlyNyUr3nqpUcGqzxxTu3NroUEZGGG1MAM7OrzWydmW00sw8c43kzs4/Wn3/QzC4cy2vN7Lfqzz1iZn/13N/O5PajdT08vKOXd1+1kkirXzLFXLC4i8tWzuTf7nySPPeNLkdEpKHiE11gZhHwceBlwHbgbjO7yXv/6KjLXgGsqn9cAnwCuOTpXmtmVwHXAOd67ytmNq0bnrz3fPQHG1nY1cTrL1jQ6HJETom/eN05dDQl2l4XkWlvLCtgFwMbvfebvPdV4EZCcBrtGuDzPrgD6DSzeSd47W8CH/beVwC89z0n4f1MWrdu3Mf92w7xmy9aQRJpZ1impqWzWuhqKTS6DBGRhhvLn/QLgG2jvt5ef2ws1zzda1cDV5jZnWb2YzN73jMpfKr5xx9sZF5HiTdctLDRpYiIiMgpNpYAdqy9gqMbOI53zdO9Nga6gOcD/xv4ih1j6JWZXWdma81s7d69e8dQ7uRz+xP7uWvLAX7jyuUU46jR5YiIiMgpNpYAth1YNOrrhcDOMV7zdK/dDny1vm15F5ADs47+4d77T3nv13jv13R3d4+h3MnnI99fz+y2Im+6eHGjSxEREZFxMJYAdjewysyWmVkBeBNw01HX3AS8tX435POBw977XSd47deBFwOY2WqgAOx7rm9osrn9if3cufkA737RCkqJVr9ERESmgxPeBem9T83svcB3gAi4wXv/iJm9q/78J4FvAa8ENgKDwNuf7rX1b30DcIOZPQxUgbd576fdven/8AOtfomIiEw3JwxgAN77bxFC1ujHPjnqcw+8Z6yvrT9eBX7lmRQ71dyxaT93bDrAn77mTK1+iYiITCOad9Ag3nv+7rvrmdOu1S8REZHpRgGsQW7duI+7thzgvVet1OqXiIjINKMA1gDee/72u+uZ31Hijc9bdOIXiIiIyJSiANYAP1rXw/3bDvFbL1mluV8iIiLTkALYOMvzsPq1aEaTpt6LiIhMUwpg4+zbj+zmkZ29/O5LV+vMRxERkWlKCWAcpVnO3353Hatmt3LN+UcfpykiIiLThQLYOPrafT
t4Yu8Av/fy04jcsY7JFBERkelAAWycVNKMj3x/A+cu7ODnzprT6HJERESkgRTAxskX73iSHYeG+N8/dxpmWv0SERGZzhTAxkFfucbHfrSRy1fO4opV3Y0uR0RERBpMAWwc/MstmzgwUOUPrj690aWIiIjIBKAAdort7avw6Vs386pz53HOwo5GlyMiIiITgALYKfaR76+nmua8/+WnNboUERERmSAUwE6hjT193Hj3Nn7l+UtYNqul0eWIiIjIBKEAdgp9+H8epzmJ+O2XrGp0KSIiIjKBKICdIrc/sZ/vP9bDb161ghkthUaXIyIiIhOIAtgpkOeev/jWo8zvKPGOy5Y1uhwRERGZYBTAToH/vHc7D+/o5Q9ecTqlJGp0OSIiIjLBKICdZAOVlL/6zjouWNzJa8+b3+hyREREZAJSADvJPvnjJ9jbV+GPX32mjhwSERGRY1IAO4m2HRjkU7ds4rXnzefCxV2NLkdEREQmKAWwk+gvv/UYzowPvlJHDomIiMjxKYCdJLdt3Mf/PLyb9754JfM6mhpdjoiIiExgCmAnQS3L+dObHmHxjGauvVxjJ0REROTpjSmAmdnVZrbOzDaa2QeO8byZ2Ufrzz9oZhc+g9e+38y8mc16bm+lcT730y1s6Onnj199psZOiIiIyAmdMICZWQR8HHgFcCbwZjM786jLXgGsqn9cB3xiLK81s0XAy4Ann/M7aZA9vWU+8v0NXHVaNy89Y3ajyxEREZFJYCwrYBcDG733m7z3VeBG4JqjrrkG+LwP7gA6zWzeGF7798DvA/65vpFG+fNvPkY1y/nT156lsRMiIiIyJmMJYAuAbaO+3l5/bCzXHPe1ZvZaYIf3/oFnWPOEcdvGffzXAzt594tWsGRmS6PLERERkUkiHsM1x1rWOXrF6njXHPNxM2sGPgS8/IQ/3Ow6wrYmixcvPtHl46Zcy/ijrz/MkpnNvOuFKxpdjoiIiEwiY1kB2w4sGvX1QmDnGK853uMrgGXAA2a2pf74vWY29+gf7r3/lPd+jfd+TXd39xjKHR//dPMTbN43wJ+/7mw13ouIiMgzMpYAdjewysyWmVkBeBNw01HX3AS8tX435POBw977Xcd7rff+Ie/9bO/9Uu/9UkJQu9B7v/tkvbFTaWNPH5+4eSOvO38+V6yaOKFQREREJocTbkF671Mzey/wHSACbvDeP2Jm76o//0ngW8ArgY3AIPD2p3vtKXkn4yTPPX/41YdpLsT80auPvhlURERE5MTG0gOG9/5bhJA1+rFPjvrcA+8Z62uPcc3SsdQxEfzbXU9y15YD/H+/cA6zWouNLkdEREQmIU3CfwZ2Hhriw//zOJetnMkb1yw68QtEREREjkEBbIy89/zR1x8myz3/7+vP1cwvERERedYUwMboa/ft4IeP9/D+nzuNxTObG12OiIiITGIKYGOw+3CZP73pES5a0sWvvWBpo8sRERGRSU4B7AS893zgqw9SzXL+5hfPI3LaehQREZHnRgHsBL6ydhs3r9vLB64+nWWzdNyQiIiIPHcKYE9j6/4B/s9/Pcqly2fy1kuXNrocERERmSIUwI4jzXJ+98v345zxt288D6etRxERETlJxjSIdTr6xM1PcO+Th/iHN53P/M6mRpcjIiIiU4hWwI7hnq0H+cgPNvDa8+ZzzfkLGl2OiIiITDEKYEc5PFTjt790H/M7S/z5689udDkiIiIyBWkLchTvPX/41YfY01vm3991Ke2lpNEliYiIyBSkFbBRvnjnk3zzoV38r5ev5oLFXY0uR0RERKYoBbC6h7Yf5v/816Ncubqbd125otHliIiIyBSmAAYcHqzx7n+7h5mtBT7yS+dr5ISIiIicUtO+ByzPPb/37/ez61CZL//GpcxoKTS6JBEREZnipv0K2D/+cCPff6yHP3rVGVy0RH1fIiIicupN6wD2g8f28JEfrOfnL1jA216wtNHliIiIyDQxbQPYxp5+3vfl+zlzXjt/+f
PnYKa+LxERERkf0zKAHRqs8s7P3U0xdvzzr15EKYkaXZKIiIhMI9OuCb+W5bz7i/ey81CZL113CQu7mhtdkoiIiEwz0yqAee/50Nce4qdP7OdvfvE8Lloyo9EliYiIyDQ0rbYgP/bDjXxl7XZ+68UrecNFCxtdjoiIiExT0yaAfe2+7fzt98Idj//rZasbXY6IiIhMY9MigP1oXQ//+98f5NLlM/nwL5yrOx5FRESkoaZ8ALtn60F+8wv3cNrcNj711osoxFP+LYuIiMgEN6Y0YmZXm9k6M9toZh84xvNmZh+tP/+gmV14otea2V+b2eP1679mZp0n5R2N8tiuXt7x2buZ217is2+/mLZScrJ/hIiIiMgzdsIAZmYR8HHgFcCZwJvN7MyjLnsFsKr+cR3wiTG89nvA2d77c4H1wAef87sZZWNPP7/y6TtpLkT867WX0N1WPJnfXkRERORZG8sK2MXARu/9Ju99FbgRuOaoa64BPu+DO4BOM5v3dK/13n/Xe5/WX38HcNJuS9y6f4C3fPoOzIwvvvMSFs3QrC8RERGZOMYSwBYA20Z9vb3+2FiuGctrAd4B/M8YahmTg4M1mpKIL77zEpZ3t56sbysiIiJyUoxlEOuxbhn0Y7zmhK81sw8BKfDFY/5ws+sI25osXrz4RLUCcP6iTr7/v15IHKnhXkRERCaesSSU7cCiUV8vBHaO8Zqnfa2ZvQ14NfAW7/3RoQ4A7/2nvPdrvPdruru7x1BuoPAlIiIiE9VYUsrdwCozW2ZmBeBNwE1HXXMT8Nb63ZDPBw5773c93WvN7GrgD4DXeu8HT9L7EREREZnwTrgF6b1Pzey9wHeACLjBe/+Imb2r/vwngW8BrwQ2AoPA25/utfVv/TGgCHyvPhj1Du/9u07mmxMRERGZiOw4O38T0po1a/zatWsbXYaIjCMzu8d7v6bRdYiInExqlBIREREZZwpgIiIiIuNMAUxERERknCmAiYiIiIwzBTARERGRcTap7oI0s73A1jFePgvYdwrLOZVU+/ibrHXD5K19rHUv8d6PfQqziMgkMKkC2DNhZmsn663rqn38Tda6YfLWPlnrFhE5GbQFKSIiIjLOFMBERERExtlUDmCfanQBz4FqH3+TtW6YvLVP1rpFRJ6zKdsDJiIiIjJRTeUVMBEREZEJaUoGMDO72szWmdlGM/tAo+s5HjNbZGY/MrPHzOwRM/ud+uMzzOx7Zrah/mtXo2s9HjOLzOw+M/vv+teTonYz6zSz/zCzx+v//C+dDLWb2e/W/1t52My+ZGaliVq3md1gZj1m9vCox45bq5l9sP7/7Doz+7nGVC0iMj6mXAAzswj4OPAK4EzgzWZ2ZmOrOq4U+D3v/RnA84H31Gv9APAD7/0q4Af1ryeq3wEeG/X1ZKn9H4Bve+9PB84jvIcJXbuZLQB+G1jjvT8biIA3MXHr/ixw9VGPHbPW+n/3bwLOqr/mn+r/L4uITElTLoABFwMbvfebvPdV4EbgmgbXdEze+13e+3vrn/cRQsACQr2fq1/2OeB1DSnwBMxsIfAq4NOjHp7wtZtZO3AlcD2A977qvT/EJKgdiIEmM4uBZmAnE7Ru7/0twIGjHj5erdcAN3rvK977zcBGwv/LIiJT0lQMYAuAbaO+3l5/bEIzs6XABcCdwBzv/S4IIQ2Y3cDSns5HgN8H8lGPTYbalwN7gc/Ut08/bWYtTPDavfc7gL8BngR2AYe9999lgtd9lOPVOin/vxURebamYgCzYzw2oW/1NLNW4D+B93nvextdz1iY2auBHu/9PY2u5VmIgQuBT3jvLwAGmDjbdsdV75e6BlgGzAdazOxXGlvVSTPp/r8VEXkupmIA2w4sGvX1QsI2zYRkZgkhfH3Re//V+sN7zGxe/fl5QE+j6nsalwGvNbMthG3eF5vZF5gctW8Htnvv76x//R+EQDbRa38psNl7v9d7XwO+CryAiV/3aMerdVL9fysi8lxNxQB2N7
DKzJaZWYHQ2HtTg2s6JjMzQh/SY977vxv11E3A2+qfvw34xnjXdiLe+w967xd675cS/hn/0Hv/K0yO2ncD28zstPpDLwEeZeLX/iTwfDNrrv+38xJC3+BEr3u049V6E/AmMyua2TJgFXBXA+oTERkXU3IQq5m9ktCfFAE3eO//orEVHZuZXQ78BHiII31Uf0joA/sKsJjwh+4veu+PbmaeMMzsRcD7vfevNrOZTILazex8ws0DBWAT8HbCX0gmdO1m9mfALxHuoL0PeCfQygSs28y+BLwImAXsAf4f4Oscp1Yz+xDwDsJ7e5/3/n/Gv2oRkfExJQOYiIiIyEQ2FbcgRURERCY0BTARERGRcaYAJiIiIjLOFMBERERExpkCmIiIiMg4UwATERERGWcKYCIiIiLjTAFMREREZJwpgImIiIiMMwUwERERkXGmACYiIiIyzhTARERERMaZApiIiIjIOFMAExERERlnCmAiIiIi40wBTERERGScKYCJiIiIjDMFMBEREZFxpgAmIiIiMs4UwERERETGmQKYiIiIyDhTABMREREZZwpgIiIiIuNMAUxERERknCmAiYiIiIwzBTARERGRcaYAJiIiIjLOFMBERERExpkCmIiIiMg4UwATERERGWcKYCIiIiLjbMIEMDO72szWmdlGM/tAo+sREREROVXMe9/oGjCzCFgPvAzYDtwNvNl7/2hDCxMRERE5BSbKCtjFwEbv/SbvfRW4EbimwTWJiIiInBJxowuoWwBsG/X1duCSp3tBUmjxpeau8EV9Ec+GF/O8D495sOEVPu9HPV5/cuTzI78eWRH0I99XRMZPmQGqvmKNruNkmzVrll+6dGmjyxCRcXTPPffs8953H+u5iRLAjvWb7c/EHzO7DrgOoFjs4OIL3v3UC/L6dfmRsOXSHLzH0hzyHMs81NL6YxmkGeQ5Pk0h95CmkGUhiNVqoZAsBx++uc/9yOfhAaU0kZPpTv+DRpdwSixdupS1a9c2ugwRGUdmtvV4z02UALYdWDTq64XAzqMv8t5/CvgUQEfzfB/1ViEyvDNw9V8BLHxuHvJCFB5LPOZ9CGBpHFbGakcCmNVDGdUaPs0wn4fv4z3m6iHNe4wMiI4EMTOFMBEREXlGJkoAuxtYZWbLgB3Am4BfftpXeI+r1PDOYbHDm2GRhdBkQFRvbxu1tubNIALDHdlqjAzLPN65EMScCytjeY6PorCaVl8VY9SqGFlO+EH1VbHwyUhtIiIiIsczIQKY9z41s/cC3wEi4Abv/SNP+6Isww72YnEcwpZz+DiCyOGjCOqhjNjh66tk3hkeg2R0KvOYB8tyyD2W5WG7MvNhVWx4dSzNsCzD12rhNbUUPxzKsixsX3oLYcxQGBMREZHjmhABDMB7/y3gW8/gBSEM5XkIYWaY9/jcYbnHE2OR4b1hWX31C8IKWTQqgJnhPZgLW4mWOZwL/WIQmvg9YGYQuXq2qj/nDJ/lYZHNcnwG5oZ7xdyRLcp6vSIiIiIwgQLYM5Z7/FAZi6Kw4uUManEISnGMxVFYFUvCCpklET4OK2V54kb6xHxkYNQfAzzk9QZ+aw7N/C4tjqyKueqRXrGwWpZBrRb6w6o18HkIZVl9G1MN/CIiInKUSRvAvM/xlQreHDgLwStJwmpVHIcQFjksSUI4KyRYEuOdw/k4hK9CBBjejDyp9485jvSNDY+3yELzvss8rppAHn61NMdq2UgYI47DdmSaho/h/jE18ItMSWZ2NfAPhNaJT3vvP9zgkkRkkpi0AQwfVpTM5ZA7vCOsOsHItp/Prb49GH4NdzSG5nyL6ndJ5r5+92QE5snj+t2UxkgQG+khi8K1lnt8vXnfpXlYXcs8lsRYvVeMau1I036Whab9+qoY3td7xnI18ItMUvUTPD7OqBM8zOwmneAhImMxeQMYQJ7hvQEZmMPX0tCXVUtHVsXC9qQ70qwfRbhCIYytKCS4JMbHDl9M8JGRFyPy2JHHNrItGU
LZU3+05aE/zFJwqcdyT1TJsNRjtQxXSUMjf6UaQlma4avV0Ohfq4atyTwPK2QQwhlHbVMqjIlMZCMneACY2fAJHpM6gHnvw46CiJxSkzuAwaiQUu+zygDzmDe8haZ5sixsTXoXrrew3Whm9cdDSMO7sG1pPlwS+3DXZH01zBsjs8ZseNJ+BD4Cy8L4C5d63PA1w6tbLh1p4CfPw8+MsjAQtv4ePIRw5vLQwD/qPSmIiUxIz/gEj8ngz/7rUZ48MMgrz5nHy86cQ0dT0uiSRKakyR/Ahg2HFLN6IzwhJY1eTTILQau+QkatVl8Zi7BCgo8crljA1xv282KMj40sCY37PjLyQj2IRVb/gKwQtjMtDwHPZeBqCZZ7XM1jqcdVM6JK2IJ05fpYi1oaGveHx1vUtyl9fVXMj4y30FalyAR0whM8Rp/esXjx4vGo6TnrbivyvUf38MPHe0gi44pV3bzqnHm8VGFM5KSaOgFs2Ohg4oe3KAnBxtxTtyjrd1Ba5KBSD2KVGhZH+CTGmgph0GtTTF6IyAsubHVGkMZGHoUVsTzmSL+YWdierH9EtdDAH9UionI4EikqZ6F3rJLhylXIcqxShTSEr+G7KalZ/VgkqzfwHzXeQiFMpJFOeILH6NM71qxZMyn+h33PVSt594tWcP+2Q3zroV1866HdI2HsylXdvPysOVx12mxmt5caXarIpDb1AtjRjl4Zy11o3M/CX1Utz/F5VN+qrG8P5jE2POurPlHf0hxXi7AsIo8tjKuI6ytgOfVBr0B999CHHc2w1ZmHa/PYY7kjL7gQyio5rhRjWY4rh7sqqdawSr2Bv1KFfHgAbHrU0Fetiok02DM/wWOSMDMuWNzFBYu7+MNXnsH92w7xzQd38T8P7+YHj/cAcNb8dl58+mxedNpszl/USeTUNybyTEz9ADZsOJyMXhXL60NWXYqv1sKqWGXUqlh9ezIqFKC+KhaVEnCOrDkZWRVLmxzeGWnJ6kEL8gQwSOvHI4WfHRr3Xc1hObgaRNUw3iIeyrE0hLKonIbANxhWxaxaCytk3uMr1RC+amkIZsOhDDXwi4ynZ3WCxyQ0Oox96FVn8PjuPn60rocfPd7Dx3+0kX/84Ua6mhNeuLqbq06fzWUrZzGrtdjoskUmvOkTwEY7unE/d0AWGvepr4plYaI+0c+uihFHuMhGAlVe7wVz9a3I4YPA/XDzfv08cG+En+EAD9HIWIt6037GyAgMq4VxFZZGIzcMkB45+qg+M3ZkUr8a+EXG3zM+wWOSMzPOmNfOGfPaefeLVnJosMotG/Zx8+M93Lx+L1+/P+zAnjanjUtXzOQFK2ZyyfKZ6h0TOYbpGcCGHesOytyHBn1z+CyrB580rGIlcViJMiMaKuLiCJKYqCk08GfNSVgBK45eFSOsiMWQJyF8Da+O5cUQ4vBQTaN6074jqsZY5onLRVzmicqeeCgNzfyDNSzLsHINVw2N+75aHRkAO7xVqQZ+ETnVOpsLvPa8+bz2vPlkueehHYf56RP7uP2J/dx495N89qdbcAZnze/gBStmcumKmTxv6QxaitP7jx4RmO4BbNgxG/frs8UAS9Mwcb9i4a5J57BqFVyExRHRUDHcQVku4pOIvJTgKhF54rDckSVGVgyDXLHQvJ/H9T6xaPjnAhguI8wWy4xoKNxRGZUhGXK41BP3R7jUEw3WcOVamMY/FIXVsVoNovrg15qNBDE18IvIqRY54/xFnZy/qJN3v2gllTTjgW0hkP30if185rYt/PMtm4idceb8di5a0sWaJTO4aEkXczvU0C/TjwLYsQyPrPBHtiiHJ+6HqfbZyCwv7+MwJiwKSSpMw88hL4Rp+XlMlDiymuHSEL4sBT/cKzY81qJ+BJI38EkIZsN3VGYJZEWHZZAUXb1nLCIqF3BpTjRQCEGsUm/gzzKsEoa9WpqGsRZHDX3VqpiInErFOOLiZTO4eNkM3vdSGKpm3LP1ILdv2sfaLQf50l1P8p
nbtgCwoLOJNUu7WLOki4uWzOC0uW1q6pcpTwHseI6zKuaz4bsi0yPnUEb1Pq1CASKHSxJcIYEoImkaXhWLSZsTfGzUWiOyJDTtp00hjKVN9VWxGPLEh/4xV+8by8OgV8vBVa2+KhbGWrgUCv3FsDo2mIetylpONFCtr45Vjswaq29VqoFfRMZbUyHi8lWzuHzVLABqWc6jO3tZu/Ug92w9wO1P7Ocb9R6y1mLM+Ys6OW9RB+ct7OS8RZ3M0dgLmWIUwMbqRI37ZvhqNcwW8x7zHqIoBLQ0wnlPZBYGvEaGZfU+s8jIc49Lwt/2slFHHvmo3iLmwlR+6uMu8mx4+9JwNcCDS62+khbjqnn9sRwzC9P903rjfp5j5sK8sfpUfgAjUwO/iIybJHKctyiEq2svX4b3nu0Hh7hn60HWbj3AvVsP8ckfbyKrr9bPaS9y7sKwxXnuwg7OXdBJR7Oa+2XyUgB7po7TuB+2JHO8S7FqFZ8koYG/UgnnTyYJri/BxxFxb5i2n5Vi0pYYHw2vikFaMrKmsD2ZNoUQlhd8aNx3kDWF1bE0J9yFmUNUNcghqkRE5QiXQjJQwKWeZNATDeVE1Zx4IPSMuaHayBmVlCvhXMrh8ymzTA38IjLuzIxFM5pZNKOZ112wAIByLeORnb08sO0QD24/xIPbD/O9R/eMvGbpzGbOW9TJOQs6OHtBB2fOb6e9pFAmk4MC2LN11BZleKgeysyNbFGGZv36TLEkBhfhBsNcMVcsEA0UIXZE5QJ5wVFrcdSqLgQurH7HpI1sS2ZFD5HHOz8y9DXNwt2UrmK4muGqUBs0LDWSfk885IiqnmKvw1JP0h/jynGYxO9cGK9RdqGJP3PHbuDP6u9VDfwiMk5KScRFS7q4aEnXyGOHh2o8vOMw99dD2V2bD4xsXQIsntHMWfPb6x8dnDW/XVP7ZUJSADsVRk3c9xmjhsD60LgPWOrCUFbvIY5IvCdPIlwtJqpG5InhUhdWxZqNqFJfFauFrcc8qa+KWT2IWb13LApjLvIkzBXLikZUDb1jtWYLq2ODEfFQIayK9RfrQ19rYeBrLcXKlSNDX+srYqN7xrQqJiKN0tGUcNnKWVy2ctbIY3v7Kjyy8zCP7Ozl0Z29PLLzMP/z8O6R52e1FjlzJJSFYLZ4RrMa/aWhFMBOpqMb93NGpuCPnEMJEJXD9mSS1MdaGFGhQBTHxKUChfqB4KXWAnniSFsiai1hnEW1PayKpc1G2hKGwGYteRhxUfCQhL6u3EIt1ZqDzLDU4QYdLoO4PyIqQ1SBQl+Cq3kK/TnxUE5UyYh6q2HW2GBl5MBwX61iuQ+/qoFfRCaQ7rYiLzotHIs0rK9c47FdfSPB7JGdvfzLLZtI63+BbEoiVs9t44y5bZw2t43T57Zz+tw2uloKjXobMs0ogJ1qo3rGfAaYCw3v5jBq+DwPW5QQjh3yPqyU1SJw4JJo5Fu5gpHHLoyuiCyMuUg8eSFsQXrnwwQN57HIYwa5hYZ8n3oy8+RpuL3SxyHImTcsJTwWGXkS+sksy4nMoFzDoghzbiRwjR7FAWrgF5GJp62UjIzBGFZJM9bv7uexXb08vruPx3f38p1HdnPj3dtGrpnbXuL0eSGUnTG3ndPntbF8ViuF2B3rx4g8awpg4+Vngli9x8qlYPX+K+dC0/5Q6BWLBosQOaJigULT8KpYQp44ai0RtWYjKxi1NldfFYtImz15DFlLBpGHQo5LclycQVOYA5ZlRpo5fGoMlSPIjHjAiMqOqBqR9Ca41FPoayIue6JyTtJXw6U5rr+KVWu4Whoa+PMcKpXjN/APr4w95Z+BiMj4K8YR5yzs4JyFHSOPee/Z21cZCWSP7+rjsd193LZxH7Us/J6VRMbyWa2sntvG6tmtrJoTApq2MeW5UAAbb6MOBQ9fHjVfLErDxP3hOyidQVLAFeuN+wMlfOyIW4oUmmOyoqNSDj1j1fbQhJ8VAOdCT1jsMeeJopxSsUbkcs
w8ziDNHJVaTJ4b1cECtYrDyhFJa/g+abMRD4Um/rxouKoniR1RJcEqNZxZaOA3O9LAXwUshMyRsylHhzARkQnEzJjdXmJ2e4krV3ePPF7LcjbvGxhZLVu/u4/7tx3kvx440vBfjB0rZ7eyek4bq+a0ctqcNlbPaWNBZxNOwUxOQAGs0Y6eup/VD+w2N3IGpVG/wzKKMO/DlmDmcbWEKImwLAx4jaoR8aCRFSGq1FfFWh1pU0yW5NSaEyzOiZOMOM4wgyjKw28UzVWyQkTWlFEthVWxtDkiqkBUNip9cVgV642Iy564nJP0Nx0Z+lpN65P46437oxr4UQO/iEwySeRYXQ9U14x6fKCSsrGnn3V7+tiwp491e/q5Y9N+vnbfjpFrmgsRq+orZeHXVlbNVjCTp1IAmwiOO3U/rJJZJQxtxY06i7KQECVJaNxvKoJzlFqL5KWErBhR7YjJE6PSbqGBv+SodkTkCVTbMmrNKXGS0dJcIYkzSi0pxTgl90bujSx39FcKVNOYaiVmsL8AqZEcjoiGHPFgROFwTFSFYm+JaCgnHsqI+yphAv9A+akN/MeYwB/CmLYoRWTyaCnGIwNkRzs8VGNjTx/r9/SzbncfG3r6uHndXv7jnu0j1zQlEStmt7BqdhsrZ7eOhDRtZU5PCmAT0TGm7pvLIePI1H0IvVZpikE4AskZloUGfu8gLzi8i45Myo+s3sDvyImpFY1KkpF5I3I5iTeceRKX4eu/GdSSlEEX7grKUkfNQ5640KxPfRI/jrhg4WYAA6vlxGZQOaqB36pHnaXJU1fF1MAvIpNUR1PCRUtmcNGSGU95/NBglY09/Wzo6WfDnn427u3nzqNWzAqxY/mslpEVs5X1j6UzW9T8P4UpgE1ko/rFwqoYR6buZzneVUMYq1SxyMFgGRdHuEJCdLgIcUSpuUDWHIdVsfaIPA6rYmlzTNoUU+soUI09fW0ZrjklSjLaWsokUU5TUqM5qdJRLEN76FcbqBWoZhFD1YShwSJ5zdHXGxOVHfFgTOFwjKtB4XBT2KYczEl6UyzNifrCuZSuWjv2BP5aWn+79bsttU0pIpNcZ3OBNUtnsGbpU4NZX7nGE3sH2LCnbySgHd1jFjljycxmVnYfCWWrZrexYnYLzQX98T3Z6d/gZPEzzfvHmLofVfBRFA4Hrx8GHjWVcIWEuJAQ9xfxiSMZSKg1O9Imo1IOd1JWUyOtOmpNGUNxThpntBSqtCYVCi6jszCEw5NjZN7oq5U4UGmmksXsbW+lUk4oDyTUWsK2ZFoy4qGIpN9RSkIDf8EZrhLjhuJwRmU6vBKWhen7+XDgGnXeps81fV9Eppy2UsL5i8LZlqMNVlM27R1gY0//yMeGnj5++HjPyAwzgAWdTayY3crK7tBjtrL+ueaYTR4KYJPVqOb94YPBRx73PmxLujQEnHovluU5PnZhi3AgJmtyRJVwB2U8FJru06aISn/MUCGnv63ErqZ2SoUaXc1DJC4bCWQ5RnuhTO6N2OVUWmL6Wwv0t5ZIU0eto4ArG/GAo9wbVsWKh2KiiicZyEn6W3C1nKi/cqSBf6gcDguv1sKKWJpqrIWITCvNhZiz62dbjlbLcrbuD8FseCtzY08/d23eT7l25PfFWa0FVnS3jvSYraz3m81pL4a762XCUACbzJ5uVcxVw3ND9an7UYQVCpgzklKJJInxSUxTSxhrUWsvkjU5as2OSqcjK8RUOmPSlhKHWjyHulqIk4wZHQN0loZoL5SZW+olsYzW9gqJZZTzhKEsoZIn7BxqZ6BWZN9gM4f7mskqEdGBhKhsJH0xxYMRrgqlQ8UwZ6w/Je6rQJrjBoaglkKthlUqYdu1Wg3vN/dHtijVMyYi00QSuXqYauPqs488nueeHYeGnrJitnFvP//1wE56y+nIdW3FOKyYzR4dzlpZ2KUbABpFAWwqeZqp+yPP1afuW5ZgWdjeszgidobLYiyL8ZGRJWHavuWGZVCLE2pJxO
E4J8sdtTyiFNUougxnnuaoSmQ5rXGFYp5SKcY0xzWsfiRSpZQwkBt5OcJH4fDKkBEdcdnwcbhJwNVycGC1DCuHg8x9nmPOoN4rhrN6EGPk/T71/YuITA/OGYtmNLNoRjNXnX7kKCbvPXv7K08NZj39/Hj9U+/MLMaO5d0hkK2eE1bMVs1pZcmMZuJINwCcSgpgU9HPNO9nYYxFlo1M3ffOhVWxwQRcRNwXmvaTQkKxtYhPImrtBbJifVWsw5EXIiqdCYebWznQ7NncNYsoyenqGKCtWKGjMMTcpj4KLmVOsZeSq4WJ/V0RNR+xZ14b5SyhZ7CN/QPNVKoxfQdLuLIj6Q1jLVwVSgdKRFVPoS9M4LdaRtRXxtIQyny5AnmGr9Ygz/FpeuxtSgUyEZmmzIzZbSVmt5V4wYpZT3nu8GCtvoV55AaAe7Ye5KZRNwAUIsfy7paRxv/Vc0Kv2ZKZLSQKZieFAthUd4wwNjJ131mY0WWGL5fDjLE4PnIEUl8TeSmm0JRQ6E/ICo540JE2G9U2o1otkhU9B7wx1JyQtTg6CmUcno7iIB3RECWr0uKqZBh9TU3UfMSutk62D3XRlxbZ3DKTwUrC4KEm0uaYqGL42BGVPVnByOOwKlYAXDUbGUw70rA/fCB4mkJ9XIea90VEjq+jOeGiJV1ctKTrKY8PVFKe2NvP+j2h8X/jnn4e3H6Ybz60a+S30iQyls0Ks8xWzWnl9LlhWO2SmS3aynyGFMCmk2PNFxt9MHiW1QelHpm67yoxrlzAVTN8EhFVkjDUtcVR6Q+zvyp9JYaaimxtbWFXZzuFQsrc9tl0FIboKgwxpxh6xdqicugZi8osa95HOU9oiysMZAX2tLexf2YzlUpC78wSVnEkvRaGvVagdDDcXZn0ZST9Ka6WEfVVoJbiKlWoVEMIq9aOTN/Phk8XGDXWQoFMROSYWoox5y7s5NyFnU95fKia8cTeEMo27AkB7eGdh/nWw0eCWTF2rJoTjmUaDmWnzW1jbntJzf/HoQA2HY1eFcsBM3xK6BlzVRiqT92PohDE4piokGDOEbc0QRKTlwpk7QXy2FHpiklLRrUtptKVUCvCE7Pa8E0ZTe1lFnQdpiWuclr7HmbEA8yK+5hbOkSEx7XmRJZzKGthf9bK4ayJ9QNz6UuLbO3tYv+hVmrlmIG9BVwFiocSCodi4oqndKBIVMmJB2ohjKUZNjAUxnLUalDfoiS1+halaYtSROQZaipEx7wzc6iasbGnn8d397J+Tx+P1w8x/+q9R4bMtpdiTpsbwthpc9o4fV47p89to62UjPfbmHAUwOSpIy0ywHw4jxLC6Io8D+dROgtT7dMMl3tw4JKovk3owDu8C2dR+tiRVo0hX2J3lNNUqFGKa/QVSpQLCZHlJJbRGQ2Ah8RSOqMBCpYyWCrSnxVJ87BVOlAp0JcZVolCU35kpOXwN6qoGlFIHEnkQq9YOGW83sAfesWoupFeMbNRQUzN+yIiz1pTIeKchR2cs/CpwezQYJV1u/tGQtn6PX184/6d9I26K3PJzGbOnNcePuaHj+m2WqYAJsFTQkg9iA1P3Xdp2N4DrFwJK2NJTNRbCMNe9xfDqlhzgbS1QF5wVDoi0pKj2lag2lGgrwh3zZyBL+UU2it0d/TTklRZ0b6PGckAs5I+5saHaXNl1rRsIsIz0Fqkr7tEX9bElkUz6U2beLK/i919bQyWE3r3N2EVo3gwIelNiMue0oHm0MDfmxL318KMsYHQwE+5ElbGslEN/Fmu5n0RkZOos7nAJctncsnymSOPee/ZdbjM47t7eXRnL4/uCr/+z8O7R67pak5CGBsOZfM6WN49dZv+FcDkZ/3MfDE7cgdlVh9rUTEsroSDwcv1IDZQxNXvpoz7S+TFiGp7TKUvIiuCq0RkpYjqYMTuNKJQrBG5nP5Sgchy5saHSSxlQXSYZhcCX4Sn7CN2FNsZ9EU2tMxhc3s3PeVWNpS6KVcShppKpE0R8a
DhzRFVwgiNPDKiSkzsPdQywumV4F2E1eeJWf1tqnlfROTUMTPmdzYxv7OJF58+Z+Tx/krK47uOBLJHd/Xyudu3Uk3DX4iLsePsBR2ct7CT8xZ1cP6iThbPaJ4SK2UKYHJix5i6P7JFaaN+zbKw0uQccZrhk5hoqEBSv4MyGYxIi0a13VE90ERaLPHYzCasKeP+9gXc2b6U1qTC6tYeOuIhZsV9dMe9YbuSjGarsLywl+64j0PNzcxvOsxAWmTLzBns7W+hPFTg0P4SrmIUDkYkfVFYFTtYJKrmJL0ZcX8VV82wgTLuKati+ciwV19LNXlfRGQctBbjnzkrM81yNu0b4JGdh3l4Ry8PbDvEF+/cyg23hd+PO5sTzl3YyfkLOzhvUbhpoLut2Ki38KwpgMnYHL0qlhNWwwCqtTAodchh0UBYFTscY1FEVEiIS0WIIppaS/hCTNpaoNYWkxWNoRkxaVNCtaPEhhkd5KWcx2fPoa2pzJL2g5zWuoeueIAzizvojAZZmhyirT7ctdwKNYxtaTs7al3sS9u5v28hvdUmNh2cQd+hZvxgRKknxlUcpQMRhZGtyiJWy4l6K2FbtZbCYDn0jEVhm3LkA9QzJiIyTuLIsXpOuJPy9ReEx2pZzvo9fTyw7TAPbj/E/dsO8bEf7WX4eMwFnU1cuKSLS5bN4PnLZ7Ciu3XCr5IpgMmz9zOT9+tfZ0dmdBmEFbI4xpxhtWzkPzpXdeSxEVUADMyRlYyBuIlyU4Esd+Te6CoMEuGZEfczNz7E3Kgfhyeq/7w2V2ZufJiCZRxubuJwoYlqHrHb5fSXilTyJly5/v0LRjzkwQrhgPDIEQ3FWDXFudCsT7kSQmX9PEryPLwnNe+LiDREEjnOmt/BWfM7+OVLFgPh4PLhFbL7tx/izk37+a/6MNlZrQUuWTaTS5bP4JJlM1k1uxU3weaUjWsAM7NFwOeBuYQ/xT7lvf8HM5sBfBlYCmwB3ui9PzietclzMDqIpcPhpD6DK02xSoQ3wwbi0CtWKBAVC/g4otTahE8i0taEWuvwqlhCVizQ39HEvZ0zyUueH81aRalUY357L8vb9tMZD3J6007aozLdUS8L4vBxZnEXALs72jiQtbI3beexwXn01kpsPDSLg33NVAcKxHuT+tT9mKS3RDLkKR5sq29VVnGDVajWcIPlsNJXnzNmWTZyQLjmi4mINE5zIebiZTO4eFnYvvTes2X/IHdu2s+dmw9wx6b9fPOh8GdCV3PCxctmcPnKWbzkjDnM72xqZOnA+K+ApcDvee/vNbM24B4z+x7wa8APvPcfNrMPAB8A/mCca5PnalQI8Wka+sbqTfvmrD4GwvDlCpaEMObKVYgcrr9E3FIkL0TEgwWykqPc70gGHGkJyrVm+ppyttZianlEV3GQjniIsh+g0w3SbIOUDGZERRyOJXEvFX+I/ZmxpLCXQ1kzc4qL2dw6k57BNnYlHaRDMT4KoS8dMPAQVcPdNnFkuHIE3oe+Nu/DNmvNwopYFP4m5YdvTtCqmIhIQ5mFKf3LZrXwposX471n24Eh7ti8nzs3hUD2nUf28MffeISzF7Tz8jPn8rIz53D63LaGbFeab+AfGGb2DeBj9Y8Xee93mdk84Gbv/WlP99p2m+EvsZeMR5nyXIz6j9qiKHwSReFzM6xQgMhhxSIUEnwSk7eVRs6irLWGxv3yDEdWhGqnp9qZ40s5LbMGaS5WWdx+kGUt++mIh1hd2kWLq9DpBmlzVXJvVHHk3rEj7WJ/1kpPrZ31A7PprTax9VAXfQMlsv6EZH84Cql4AJJ+TzzkKR3McLWwKmZDNaxaw4ZXxapVqKX4ka1KnUd5Mtzpf0CvPzCx9gpOgjVr1vi1a9c2ugyRaW1jTz/ffXQ333t0D/dvO4T3sLCriZedOYefO2sulyybcVLDmJnd471fc8znGhXAzGwpcAtwNvCk975z1HMHvfddx3kpoAA2qQ3/x11fGcMcFrkjwayQhM
9LRXypiE9iso5SmC/WmVBtdaRNUJ5lZCVPdWaO66rQ3Fzh9Fk9dBUGObtlJ6uKu+l0gyxPyhTNkRARmTGY19ibewbzmEeq89lamcX2ShcP7p/PQKXAob2tuN6YpN8o7TWiiqd0MCfpz4mHMpJDZawWpu5TS8NMsUoF73046FyzxZ4TBTARGQ89fWV+8FgP33t0D7du3Ec1zVnR3cKvXbaMn79gAS3F575J+HQBrCFN+GbWCvwn8D7vfe9Y06aZXQdcB1Ci+dQVKKfWMcZajDzufThsO8qPHLydZUSxw1XrvWQ+Iqo38OfFMPurlhfpG4zZ6GbRWqyS5hF9WYlZSR9VdtJiVWa4Ki3OyLwnAZpdSnfUC0WILKe/o0BvtYk0dwwWi1RKCRDhKoZ3jrTkSAbDtH9Xy4mTCKvUwtT9yIVG/chBLQ3vpZYCUX2bEvWKiYhMILPbSrz54sW8+eLFDFRSvvPIbj770y388dcf5q++/Thvet4i3nrpUhbNODV5Y9xXwMwsAf4b+I73/u/qj61DW5DT16gVMTiyVTncJ2aRg2IxLAs3lfClQmjcby/hE0e1Iw6rYiULq2IFqM7IyTtqJE01Fs06RFtS4bT2PSwr7qUzGmR5oYeSZbRYSmJQ9kZfnlD2MRuqc9lT62BbeQbremfTXy2wZ28Hvj8m7o0o7TdcFZr25ySDnnggo3C4itUyXF8Zq9bC4eDV6s9M3X9K4z4okB2DVsBEpFG899z75CE++9Mt/M9Du8i959rLl/H7V5/+rCbyT5gVMAtLXdcDjw2Hr7qbgLcBH67/+o3xrEsa7Gcm79fvoMyy0LxvDqvWwupXpQqFBBfHJEPVcCTSQJGkNTTuR5WYrABx2VEdKJC2JDyZO4qlGjlG2uqYWzxMd9wLrky3S5kRHRngl/mUudEmDhULbCnOYmahn4O1Zh6MFrCvv4XB5hJlK+CqYN6RFTyF2LDM49KYJPOhp80M8/5npu6HN6ap+5OBmf018BqgCjwBvN17f6j+3AeBawnLt7/tvf9Oo+oUkZPHzLhoSRcXLeli9yvP4KM/3MC//GQzd285yMd++QIWdp281bBxXQEzs8uBnwAPMXLbGH8I3Al8BVgMPAn8ovf+wNN9L62ATXHHWhVzFn6NojBXrBiOQPJNRSgWyAsxaUdx5CzKWouj1mxUZkBW9FRnZri2GqXmKktmHByZur+wcIDOaJD5yUFKVqNAjjPPYJ6wP2+hnBfYWJlDT7WNHeVONh2ayVA1oX9fC24gIu4zSgcMVyH0ig3lxP1ZaNyv5bj+oTDo9TirYuoXe6qJsgJmZi8Hfui9T83s/wPw3v+BmZ0JfAm4GJgPfB9Y7X39bxDHoRUwkcnpmw/u4g/+80Gcwb9eewnnLeoc82snzAqY9/5W4Hi/sSpNyRHHWxWDI837g1HYlkwSLHK4OKa4twSRo9TSRN6UkJcSqp3hKKShGRG11pi0pYl13W3kpZyHu+cxu72f2c19nN++nRlxP6cXdzE/7qMzTjnb1QC4omkXZe/ZnRXZMGsuh7Jm7p23hF1D7ezqbefQ3lasElHuiYgHHIXeiKYDMVHFUziYEJVT3FAtNO6nGZTLkKahiT9NMfMa9DrBeO+/O+rLO4A31D+/BrjRe18BNpvZRkIYu32cSxSRcfCqc+dx9oJ2rvn4bXzmts185E0XnJTvq0n4MjkMh5FRzftGfcYYNXxe3+qr31FpzuG8x9KcxBlRweGjGJc6XM3wkQtDX62Z7dWYg81N5N7oSMr0tZZYlIRVsUXxofrUfYgIh4PPjPpJLGVx0wGaohrOPGnuqFRiKnkTaZMjL4TG/agKeVwkHkqI+2Pi2GFpjsVRaNKv1rBRd1DifVghy/2Rxv3R718a5R2EYdEACwiBbNj2+mMiMkUtmdnCS8+Yw3cf2U0ty59VP9jRFMBkchm1Mua9ARk+c5hL671i1XD00eDQyFZl4UARIkexpYm8FIdVsY6EvGCUOyNqLc
3UWpt5YEYnedFz86xVtLSWmdE8xGmde2iPy6xq2sPsuJc2N8TcqJf58WFWJHup+Yj9XS1smz2Tw1kzj/TP40ClhZ297Rzc3wqViMLeiHgwnENZOlAiqnqKB1uJyinRQHXkYHA/VMaGV8Wq1RDKRp9FqTB20pnZ9wkncxztQ977b9Sv+RBhiPQXh192jOuP+S9l9J3bixcvfs71ikjjlGsZ/ZWUgUpKZ3PhOX8/BTCZvI4TxgAsy8Pdk85BrRrCWKVKVCzgCgmuHIa9RuVi6BVrcURVIysa5bRIX1vMUFsRM09HoUzRpQC4OGdu1E9iOTNcSsGM7jyEsr68RGtUpqfWTnsym41RzmClwIBvIRtw+MhhuRFVDJdF5ImBGVEWgpZloSnf19+b5Xn4PPeMjOqQk8p7/9Kne97M3ga8GniJP9Iwux1YNOqyhcDO43z/TwGfgtAD9pwLFpGG+Oq92/nvB3fxuy9dfVLCFyiAyVQx6jxKqLeO+Tw08ed5CGJZBtUqliRE1Vo4i3KoRqEYkzUnFA/HZAWjeMCRNifUWhKenFnCFz2PzphLa0uZmS2DLG/bR2tUYWlpPzPifkpWo80NAbAgOUh33MusuJ9FzQfpS0ts6OjmcLnEwcMtDB0q4MqO0r6YqExYFTtUDL1ih1pxlRQ3WMMNDIUm/cEy5PWm/VptZFUs9IqhVbFTyMyuJhyJ9kLv/eCop24C/s3M/o7QhL8KuKsBJYrIKdZbrvGxH27kM7dt5uKlM3jvi1eetO+tACZTy+hVsZywohQWr7BqNQQyZ1gcg3NYISGKY+JCgUJTEZKYlvYSWVNM2hxRnhGRFYzyrBYGWpo51N7F5u6ZFIs1ls08wJxSHwubDnJGaSfNrsL8+CAtlnJ6YQ+1ZkfZx2xp7+ZQ1syGoTls7Oumt1Jix95O8oGY+GBM5UCEq0DzXkdcLpD0phQOJ2GuWBSF4FiuQMWFPrdqdVSvmMZZnEIfA4rA9+rDou/w3r/Le/+ImX0FeJSwNfmeE90BKSKTS5Z7/n3tNv7mu+vYP1DlDRcu5IOvPIPInbwbtBXAZGobnrpPvY+KDPNhm88sbAGS+/C1M0gzoshwaY7lCT6CrBCa9qOK4WpGhRKDxQKbvLGvuYW9La30txZpjSsjg14TSylZjcw7InJaXIVZST8DTUWa4yoD1YTBYpGyK1G2mKgS5opFZU+xkJAnjqiak8QOq6ZYEofQmGVhO3X4/MksC/Vr2v5J570/7l91vfd/AfzFOJYjIuOgp6/MN+7byVfWbmNDTz9rlnTxmV+7mHMWdpz0n6UAJlPfz4y0GBW60nTkLErfH86itN4w6DVKEgrNJYgjWtpL5MWItCWi3BmRJxHlme30NrWzv62bh2YuxBUzZs7op6s0xJzmXpY176fZVZmTHKbZVVhd2sXq0i7KecKezg76sxJPDs3gyf4u+ipF9u1rw5cj4oMxxUMJURma9iVhe7I3I+mt4aoprjfMFbNKFV+uhLtCq7UQwobDGGh7UkRkDMq1jO8/tof/vGc7t2zYR5Z7zlvUyUfffAGvOXfeST2cezQFMJl+vH9qGLP6pPoox7sUS1P88LDXWg3imLiW4gsJ0WCCqxXJE8OymLQZoopR8QlZIWa/wWBzgWoe0RTVaI0qdESDlKxGs6vQ7spUXUSLq1L2Cc2uSlNU40C1mSw3BstFypQwYqKy4WqOeMhjHizzRLHDahnEUbgVL8/DKliWhwWw+gofjLxFERE5Sl+5xk+f2M/N63r45oO76C2nzG0vcd2Vy/mFCxewcnbbKa9BAUymN39k8KnPCDcb5vXxD8MfzqBWw8UxbiAhGgh3UCb9RbJiRK01orwvzP4q720ibSrxZFsbW7tm4QoZ3V19tBcqzCwNsKDpEEWX0hEPkliGs5yFxYPMSAZoiaoMZAV2tnewf1YzlUpCZWYRV3YUDkUUD0dEFU/pQCGsivXViPqrWCULTftpFlbEalV8lmNHN+1rRU
xEpqk89zy6q5cfr9/LLev3cs/Wg6S5p6UQ8bIz5/ALFy3kBStmndQerxNRABMZy9T9ofKRqftxjEWOwv4mSGJKTUWaWwrkhYhqV4G0ZFTaHZUZRbIC9HSX2N2UkXRU2N7VSXNSZWXbPjqTQWbEA8wp7AdgZWkPmXfsauukp9rGoVoTT8yYRX+5SO/BZiqHEqIho9YahV6xQ47i4ZionJMcjELTfhLDUBSm7ZuFAGkGWR4m7Q9/rRAmIlOY954dh4a4e8sBfrJ+H7ds2Mu+/ioAZ81v59evXM4LV3dz4eIuCvFzH6r6bCiAiRxt9NR9QvO+mcfjw9R9QsO8r1axPMe8JwJcNcLHDld1mAdw5AXAHFmTUas6dqQRSSGlnCa0F8vMLA4wr9hK4jKKlhJZTs1HJJbRFNXoKJYpRBlp5hgE8lKEZRFRxfCRkcdGXPF4B1EtJypGuCQOZ0/G9SBWq4ap+1kW7gLV2ZMiMsWkWc6ju3pZu+Ug9zx5kHu2HGR3bxmAGS0Frlg1iytXdXPF6lnMbis1uNpAAUzkeJ7SKza8TelCj5g5rFzBOxfuTuyNMRdRPFgaOSC8pb4qVutISEuOapuj0tFEXoQdM1p5ssnj21JaOocoJjUWtPfSGldoS8p0JkPMdAN0dg6Re+NAWwuHZ5foq5bYObedSiVm6FCJ+HBEVHaU9oVZYqVDRQq9zUSVjPhQBVdNscFyqDXNwrFHWRbOoayHsCNN+wpiIjI5HB6qcd+TB7ln60HWbjnI/dsOMVQLv5fN7yjxvGUzWLOki4uWdHHmvHbcOG4tjpUCmMhYjBr0GjJZBj4KT2VZuJsyiiDPwEVYLSWqpURJjGWeuBgRVcNZlGnR8GZkg0Y1ixmwEuViQiHOqBTC/5KtUQUcNLsqkeU48zRFNZrjKtU8olyK2Zc7UgrkBUdUCWdbWgZ4RxwbrppD7HD1VTpzKT7Pwpaq93gyyB315TpAIyxEZOLpLdd4eMdhHtp+mId2HObhHYfZsj/MRnYGZ85v55eet4iL6oFrfmdTgyseGwUwkWdiVEAZWTnKffgY2eazsMJUrWLOEVfCUUjx4SKFQyEwlQ7FZEWj2mZUO0Kv2N6uEntKng2tNZrbKhTilO6WAYpxSimqUYpq5N4xszRAmkckLqe/tcBQpcBgWxGqjsqhiKQvIi5DaX9EVPUUDpdI+lNcJSU6PARZhhuq4CsVyPLwa57js1zbkyLSUE8XtgAWdDZx9oJ23nDRQi5Y3MX5izppKU7OKDM5qxaZCJ5yFuXwFmUIZVaphIGpZtjAIDjDFYtEpSJEEYW2JnwSkbYVqLaHI5CGZjqyoqPaETHUWWCglNPfVaJQTJnRMsispn4KUcbM4gCReWYWB6h5R3+tyJ62Nsq1mMMdLQz2JbhBR9rkcFWj1GwU+hzxUEIxdrhajosibLhHDEJwTNOwmmd+1FFOatgXkZPPe8/2g0M8tquXx3f38fjuXh7d2XvcsHX2gg7OWdDBzNZiA6s+uRTARE6Go8+izN1I4z5pCmaYudDAX7+L0uIo/A/oISs6vDOyIrjMcKkjKxnVvEStkFMpJ/S1FCkmKZ2lIQpRhsPjLKeaxyQuhyRlqFijkkPuYqo1w1XBvJHHjrgJoIirhQn70UCMVdJwd2f9nEyqtfA5AFE9UGo1TESevYFKyro9fSFs7Qph6/FdffRV0pFrlsxs5oy5UzdsHYsCmMjJdJxVMXOGr9bCxH3nsIEEoogoSYhKxXAweEtpZFWs1hq2KMtdjqwQUW1PGOhoorfg2dOZ4goZTc1V2prKROZpSaqUYk+pM6XWFlFOY3o7S6RpRN/hAtGAIxoySgdiXAVKB2MK/UWioYzCoRJWy7CBMq5cxddq2FAZ732YJTa8NalmfRF5GmmWs2X/IOv39LF+T99I2Np6YHDkt43WYszpc9t43QULOH1eG6fPbef0uW2TdhvxuZh+71hkvBzduG/13ipzYYK9c0em2Ef1g7bjiD
jzWOrJCw5vYXsS7wAjKhqVOCIvOMoGzuUU44xinBKTk7iMxGVELifLjVoW0Z86MovxzlGrGFECtaqFURoOokqCix1RLRxlZIBP03DeZB6F96FmfRGpy3LPtgODrNvTx4Y9fazf08/6PX1s2jtANQt/8TSDpTNbOHN+Oz9/4UJOn9vGGfPaWdjVdMqO9plsFMBETrUTBbFaGlbF6ndSRpUqrj+BJCYeKJHHjlJ7Qq3F1VfFIvIC1NpiBltK9Bdz9rWluCQnKaQUCynOwuyyyDyl5ippISVtjhkqxVjNUWs14gFHPOQotodm/eLhInF/jaiSHTlvcqiMr9VCf1i1Fpr107R+sDk6/FtkCsvzMMx0Q08f63b3h7DV08eGPf1U0nzkugWdTZw2t40XntbNaXPaWD2njRXdrTQVogZWP/EpgImMl9Hbkzlghk8Zmbbvq1XMDD8Ypu0TOeK+JogcSUsTeVNCXowpHUzIio5Ku6PaZmRNEZUuR554Km0xaWuKizJamqrEUU4cVTHz1NKIoVJClkZUigVqQ2FbMisYUTX8WmhyxEM5BQOr5bjIYeUq1GrhoG8fVsgY2ZZEjfoik1ya5Tx5YJANPf1s7AlBa+Pe8Hm5diRozW0vsWpOK7/6/CWsntPG6rltrJzdSus03D48GfRPTaRRvK+Hlzw07ZOFAa9HT9t3DjPDeY+lOYkzooIDH2OZIxsCvJEnRi010oojK3h6axEWeeI4I4pyvDfyvL70n+Tk9Wn9aauR1cAyw0euvuVZJKrmxI5wvFG1diR4DY/ZyI/8xqz+MJGJr5rmbNk/wIY9/Wzo6QuBa08/m/cd2TqEMMh05Zw23nLJTFbNbmXl7FZWzWmjoylpYPVTjwKYSCM9pWnfgOzY0/YHYkhiLI4pHKyPsmgukTeHVbFqR0JeCGdQ1loishJUO8JWZbUlx5dyiHOiYoYzT9yU4psgb3aUm2IsM9K2sCIWDxrV9gRXhdKhmKQ/IxpKiYsFSDNsYAgq1RDC6jPEMFc/ZxJtS4o02EAl5Yn6CtbGnn429PTzRE8/Ww8MktXbB8xgUVczq2a38qLTu1k1u41Vs1tZoRWtcaN/yiITxYl6xbIMi9KwOuYcDrA8x1UzcJAloWnfckdagzw28ip4c2QGPjPyyEMU+sOceYg9WSHHZ0ZWrK+OeSOqGC7ypGWH1f9i7CoxVjOsGmPeh8PK6yM28B7zdqQ3TNuSIqfcgYHqSMgKQauPJ3r62Xm4PHJN7Iyls1pYPaeNV54zj1VzworWiu5WSol6tBpJAUxkojleEMtyvKvfnegcVqthQ6FfLBkqkcQRSX+RrBSTlSIqhyPyGKrtRtockRWg1h7hI8iacnziw/d29Z9WyvFJ2Mr0sWGpkZWMeMiIByNKTREu9RQOFXBDNaxcww0UQoN+uYyv1WtLw2wfbUuKPHfee3b3luu9Wf0jvVkbe/o5MFAdua4piVgxu4WLl81gZX3bcOXsNpbMbCaJXAPfgRyPApjIRHWM7UlgZFCqr1axKAoT94eGwnFHAyWiQoIvJSR9JfLEUR2IqTU50iajUjHyBGptjqzkyRPwpbCCRpLjC5DFjjwJAcxH4XVxCbwLd0tCMZxtORCFvrRaGlbAzIXw5etHM42s4Gk1TOREstyz/eDgSMga/vWJnn76Rw0s7WhKWDm7lZefOSesZM1uZdXsVuZ3NE3IA6fl+BTARCaD4YZ9GNnmC5P2s9AcX61BFMKOeQ95TmSGSyIwcNWIqOrAO/IkNNxnJSMveNJahHfgE4+PPJYbeAODPAlbi5mHWs3IC+BSRx4nJM6IvcdqWVhEi0OzPhC2TNMwYT/Uq2n6IgC1LGfrSCN8/8idh5v2PnW0w+y2Iitnt/ILFy4YWc1aObuVWa0FzdGaIhTARCaLUSti4ct683stxaLwN+SRMyjjGNdfgCjCHSpSKBbICzFN7QWyxFHtiEmbjLRkVNsdPoa0GbKix0eQFzwYYZWs6MmKYU
yFS42sGBGVPcmAo9gU4ao5hUKMG6pBpYqLXJieXy7XB81m+Nrw3+DVoC/TQzUNQWv98B2H9V837xuglh35f2DRjCZWdrdy+cqZrJrdxor69qHuOJz6FMBEJit/ZGXJZ4RZYrUUy3P88IgIZ/WzHnNcmkDssNjhY4flocE+jyFPjDwCb4aPfVgRq7eN+Pog/jwJq3BZITzo0nCGJQZ5MQbvcXkOcYxZBpHD53n4+W5Ug77IFFLLcjbvG2D9niMha0N9tEM66o7DxTPCHYcvOWMOq2a3snpOG8u7W2gu6I/h6Ur/5kUms6dp2CfLjhx3FEVYkoQp+3FENFgiL8bkTTGF3oQ8MaptjrQEWdFIWwzvICuCj+s/w0Eee9LmcGh4noRDvl0t9IxF5YS4PyFOIizNwzDZ4QO+K5UQDIen6GtUhUwyee7ZfnCIdfVzDtftDh+b9vWPrGg5gyUzW0KP1llzWFXfNlw5W3ccys9SABOZCo43T8xZOAzcDB9FWKUKkcNVargkxpcKuKEiPomIKglpkwvbkrVwN6TlkBUsbEsmHqKwTWl56BHDDFcDyyPiJMe7BMs9rprh0voh5BDujsyy0KxPhs8d2o6Uich7z97+Cut397NuTx/rdveybk+YDj9YzUauW9jVxGlz2njxGbNHjt9Z3t2ioCVjpgAmMtUcd8L+8CHb9TCW55j3RGb42JGMNOtHmHfkcWjWj4qQJyGIUe/PH5bX/6zJioQjlTy4aowrRCRpDlGY4u9zj+UZDIX5YUamY4yk4Sppxsaefh7b1cdju3p5bFcvj+/ue8p4h1mtBVbPaeOXnrcoBK25IWxpWKk8V/ovSGQqOtGE/TTFm2FJDINFnHMkA00khYS8mFBoL5DHjmp7HLYkS0atJayEZaUjwStPwMcARpqGYbB5lBDVPN6FbcmoEOPMIK2vHlRr4WxJOHKHpEKYnGL7+isjIWs4cG3s6R/p0yrGjtPntvGyM+Zw+ry2kbA1q7XY4MplqlIAE5nqjl4Rc/mRbUkI87sih1ViyH0YKTHosMQRJw7zBoTxFXkEPqqPw3D1Bv16kz5RCGN5AhC2MF3m8EmEj6MwLsNFWJTh8wiyvF4LWgmTk8Z7z7YDQzy88zAP7TjMIztD6NrbVxm5Zm57iTPmtfHi02dzxrx2zpjXzrJZLUSaoyXjSAFMZDo4ulk/96Fj+CnN+vUtw6GEqFzCxRFuqEReislLEclAHJr1WxxZob4tWRzuBaM+N8yoNYNLQ19YVnIksSMxOzIvrJqEMyThyFBZrYTJs5Dnnq0HBnl4x2Ee3hEC18M7DtNbDmNPYmesmtPGlau6OWNeG2fOa+f0ee3MaCk0uHIRBTCR6eV4d01CGJ4aRWGLMM2wyBGlGa6ckJcKWOrJE4flEWnR1c+ODFuNeUK4SzICCoaPPGnNhcGuGbg0xqoOqxRCVqvfmekhBD+ynylVZLThla37th3koe2HeXjnYR7Z0UtffUp8IXKcNreNV507n7MXtHPOgg5Wz2lTU7xMWApgItPR6CCW1wd+ZUdWoKxWxecRVgnN+i4PZ0a6yIEVcFVPmoZhYT7yYSxFNGr7xhPunPT1sRa1iMgZvpgcmdafplgWHVkFy0I9aBFMgMNDNR7cfoj7njzE/dvCx3BzfCF2nDGvnWsumM/Z8zs4ux62CrHOPJTJQwFMZDrzvj7F1fBpmKxP/UBtM4NaGlbFCglRuRgm65dL+EJM1pQQt8b42Ki21nvERnrA6p+PhLKIvOqwtIirRESAZTmhcz8Lv0KoQaadNMt5fHcf928bDlwHeWLvwMjzK2e38uLTZ3PB4k7OX9TJ6jltOmBaJj0FMBE57ugK6odrmzOoOojy0KwPOGdEiSPPjag6HLRsZIL+yLc2wztPHhs+NnxWn8QfubBM5iz8bGchAMqUV65lPLDtEHdvOcCdmw9w79aDDNRnbM1sKXD+ok5ed/4Czl/cybkLO3Usj0xJCmAiEvzMtmS9L6
t+nuPwaphlGRbHWKmAq2b42BFVC2QFR1ZypCVXP7rIRmaG5Ulo1k+bIlziwINlHmoplqYQ1aBc/1laBJtyess17tl6kLs3H+DuLQd4YNthqln4F716Tiuvv3ABz1s6gwsWdbFoRpMOm5ZpQQFMRJ5q9LZkRshhuQ9N+q6+QhXVsDz0jFkSE5vhClHo6bL61qMd2YL09Qb9fHiYa8GRF+JwV2Qch5EUURRW3dAfvpNduZZxz9aD/GTDPm7buI9Hdh4m9xA54+wFHbztBUu4eNlM1izpokt3JMo0pQAmIsc26rBvfD2MDTfPexcGutYHrDrnII2J66toeeJGpukPzwkzH7Yj85hwhmQxIgeiJMa8x+dZuFOy1qg3LM9Wnnse393HrRv38pMN+7h7ywHKtZzYGRcu7uK9L17FxUtncMHiTlo0QV4EaFAAM7MIWAvs8N6/2sxmAF8GlgJbgDd67w82ojYRGaUeqHyWhQRVXw3zzrBaCrUqxDGWZkRJjKsUcdUCeRKR1mLygiMvGFlyZCUMM7Kiw7IY5wxXLmBm4bzIKAr9YDLhHR6qcfO6Hn74eA+3bdzHvv5wh+LK2a286XmLuWLVLC5ZPlNH9ogcR6P+z/gd4DGgvf71B4AfeO8/bGYfqH/9Bw2qTUSOdtRqGHn9bknnMNKRuxiJI6wW4QCXRuD8U7Yin/ItneGdgatP4jfDIs1smsi2Hxzk+4/u4XuP7eHOTQdIc8/MlgKXr5rF5StncfmqWczraGp0mSKTwrgHMDNbCLwK+Avgf9UfvgZ4Uf3zzwE3owAmMrHU75QcWQ3z+ZG7J+FIg36a4eOIOPP4xJE3xVgWh6OLonCeJEa4I9I78lIcwlcthWo13A05gZjZ+4G/Brq99/vqj30QuJawJvjb3vvvNLDEU8Z7z2O7+vj2w7v43mM9PLarFwirXO+8YjkvO3MO5y/q1BE+Is9CI1bAPgL8PtA26rE53vtdAN77XWY2uwF1iciJHG9cRbUGUVghM8DiGBdH+DwGZ7iofmdk8UiTvXdGHllY/Yo9Po7ARUykJnwzWwS8DHhy1GNnAm8CzgLmA983s9Xe+ykzzn/bgUG+cf8OvnH/Tjb09OMM1iyZwR++8nReduZcls1qaXSJIpPeuAYwM3s10OO9v8fMXvQsXn8dcB1AieaTW5yIjM1R4ypGDvfO87BCVp90b+Vw4LYj3P3m66sk/qhp5T5y+NhjUYTF0UTKXwB/T/gL4zdGPXYNcKP3vgJsNrONwMXA7Q2o76Q5MFDlmw/u5Ov37+SeraEF93lLu/jz153NK86ey8zWYoMrFJlaxnsF7DLgtWb2SqAEtJvZF4A9Zjavvvo1D+g51ou9958CPgXQbjN0YIlII9X7wkYf7u29HznnEe+xOIY8J/I+NNgDPqmfHxnXA1lk5ERYEmOFZMI04ZvZawk3Cj1w1FyqBcAdo77eXn9s0vHec/eWg/zrHVv59sO7qGWe0+a08ftXn8Zrz5vPwi79RVfkVBnXAOa9/yDwQYD6Ctj7vfe/YmZ/DbwN+HD9128c73uIyAQyqgfM5x4zD94fCWJmWJpB/Yghy+LQPxaDecOG/xplQP0OyfFcAjOz7wNzj/HUh4A/BF5+rJcd47Fj/oVw9Kr94sWLn2WVJ19/JeVr9+3gC7dvZd2ePtpLMW+9dCm/uGYhp89tP/E3EJHnbKLcH/xh4Ctmdi2h1+IXG1yPiIzVqDskR7qg6gdsjxxl5P3IHY4+icgL8VMji9Xvhkzicd2C9N6/9FiPm9k5wDJgePVrIXCvmV1MWPFaNOryhcDO43z/kVX7NWvWNHzVfk9vmX/+8Sa+fPeTDFQzzl7Qzl/9wrm85rz5NBV0B6rIeGpYAPPe30y42xHv/X7gJY2qRUSeo1HN+XgL9wZmGT4PE+6pOog9VgvBzDmHj+zI6pmFIa3mJsYkfO/9Q8DIzUBmtgVY473fZ2Y3Af9mZn9HaMJfBdzVkELHaOehIT754y
e48e5tZLnntefN520vWMp5Czt07I9Ig0yUFTARmeyGx1Tk9YWeLAt/uNfSI435lTj0i5nhotCwPzJ2InTrT4T89bS894+Y2VeAR4EUeM9EvQOyp7fM339/A/9xzzYA3nDRQn7zhStZPFO9XSKNpgAmIifP6IGtWcboFi9fP+/R6tuTzizMBUuiI3dIOjdhmvBH894vPerrvyDMMpyQqmnOZ27bzEd/sIFa5nnT8xbzrhetYEGnhqSKTBQKYCJyco1aCbP6nZFkWVjYqjfmk+eQ5xgOH/kJGbomqx+v38uf3fQIm/YN8JLTZ/PHrz6TpZrbJTLhKICJyMk3akSFUe+39x4qVcjysDVpBsNHEOU+LJy5p/2u8jQGqyn/978f5Ut3bWPZrBY+82vP46rTNdNaZKJSABORU2p4PIUnG9l+JI2wNMYDlufhTEmvlbBn66Hth/mdG+9j8/4B3vXCFfzuy1ZRjHVXo8hEpgAmIqfG8MR8I9wdmbswMR/waYbV0hC6kvipZ0rKM/KVtdv40NceYmZLkS++8xJesGJWo0sSkTFQABORU6t+biRkkFqYC1Y/msjyGJIY81HoG4sUwsbKe8/ff38DH/3BBq5YNYt/fPMFdDYXGl2WiIyRApiInHo+B6s3eOV5GEWR+5FmfLLQDzbhZ1BMEHnu+eBXH+LLa7fxixct5C9//hySSA10IpOJApiInFqjh7RmGd4cpGn9DskYqyWQezyxhoKOgfee//Pfj/Lltdt471Ur+b2Xr9Y/N5FJSH9lEpFTr94P5nNf7werr4JlefgYXgkbda0c28d/tJHP/nQL77hsmcKXyCSmFTARGR/egw2HsFGzwdIUAHMOPxzC5JhuXtfD33x3Pa87fz5/9KozFL5EJjEFMBEZP/7IlHzMwnywNA1BLHKQRU89pFtG9PSVef+/P8Dpc9v48C+ci3MKXyKTmbYgRaQhvK+vhHkfPh9eGVMCO6Y/+trD9FdS/vHNF1BKNONLZLJTABOR8TMcuPKwBemzDGpp2IbMcyzTFuSx3LZxH999dA+/9eJVrJrT1uhyROQkUAATkcbJ/chZkWRZPaA1uqiJxXvPn3/zMRZ2NXHt5csaXY6InCQKYCIy/nweVsFg5I5In+WQZiiBPdWP1+/lsV29vO+lq7X1KDKFqAlfRBrD5+AtbEUCNjyeQvnrKa6/dTOz24q89rz5jS5FRE4irYCJyPgbPRds+Gvvj8wCEwB2HhriJxv28ZZLllCI9du1yFSiFTARaSif5eF8yCwLc8A0iHXENx/cBcA152v1S2Sq0V+pRKQxvD8yFwyOnA8pI7776G7Omt/O0lktjS5FRE4yBTARaSyfH7kbMlcT/rCBSsp9Tx7iytXdjS5FRE4BBTARaaiR8yGzrD6MtdEVTQxrtx4kzT0vWDGz0aWIyCmgACYijXN0v5dXE/6wh7YfAuC8RZ0NrUNETg0FMBFprOGZYHkeGvK1BAbAwzt6WTarhfZS0uhSROQUUACT/7+9+42VrK7vOP7+sLssrIhAFQSWrUuzaND6B690ta2lQCKxTdc21ewDIhrtRmOs2gcthKTGBySWmEZNWpMN2mBrpVtEITb4B1prmhToqqgLK7KKhZWVxWqVWtl/99sHc9Drdi572Zn9zZy571cyuTPnzJn7/e7M/uYzv3PmXGkq1BOzYeYvAO7f+xgbTj9p0mVIOkYMYJImrzsQ3/OADRyaLx76wU9Z/yy//SjNKgOYpMlaeBzYoUOTq2OKPPLjx9l/aJ51p62ZdCmSjhEDmKTp8MQB+J6Ile/9+HEAznrGiROuRNKxYgCTNHlP/GmiKg8BA/Z2Aez0k1dPuBJJx4oBTNJUKM+C/zP/9ZP9ADzzJAOYNKsMYJKmhyEMgB92AeyUNZ6CQppVBjBJ06M8DxjAY48f5IRVx7F65YpJlyLpGDGASZoOngX/Zx7bd5CTVjv7Jc0yA5ikqVH+LUgA/nffQdYc7+yXNMsMYJI0ZX
564JABTJpxBjBJmjI/PTDP6lUGMGmWrZx0AZIEdOcC8zgwgP0HD7F6pZ+PpVnm/3BJmjL7D85z/AqHZ2mW+T9ckhaR5O1J7ktyT5JrFyy/Ksmubt2rxv17DxwqVq3IuB9W0hRxF6QkDZHkt4FNwAural+S07vl5wObgecDZwG3JTmvqsb2l8QPzhcrjvPzsTTLmv8PT3JKkhuTfCPJziQvT3Jaks8nub/7eWrruiTpMG8F3ltV+wCqam+3fBNwQ1Xtq6oHgF3AheP8xYfm550Bk2bcJD5ifQD4TFU9D3gRsBO4Eri9qjYAt3e3JS03NVUnATsP+M0kdyb51yQv65afDTy04H67u2Vjc3C+OO44A5g0y5rugkxyMvBK4A0AVbUf2J9kE3BRd7frgS8Af9ayNknLT5LbgGcPWXU1g/HxVGAj8DJgW5JzgWHJaGhyTLIF2AKwbt26Jdc1P1+siAFMmmWtjwE7F3gU+JskLwK+BLwDOKOq9gBU1Z4njrWQpGOpqi5dbF2StwI3VVUBdyWZB57JYMbrnAV3XQs8vMjjbwW2AszNzS15em++YIUzYNJMa70LciVwAfChqnoJ8BOewu7GJFuSbE+y/QD7jlWNkgTwKeBigCTnAccD3wduATYnWZ1kPbABuGucv3i+aug0m6TZ0TqA7QZ2V9Wd3e0bGQSyR5KcCdD93Dts46raWlVzVTW3itVNCpa0bH0EODfJDuAG4IoauAfYBtwLfAZ42zi/AQmDQ+HiLkhppjXdBVlV30vyUJLnVtV9wCUMBrF7gSuA93Y/b25ZlyQdrjtG9fJF1l0DXHMMfzfmL2m2TeI8YG8HPpbkeODbwBsZzMRtS/Im4EHgtROoS5KmwnyBh4BJs615AKuqu4G5IasuaVyKJE2loohHgUkzzVMtS9IUchekNNsMYJI0ZabrfLSSjgUDmCRNmcIZMGnWGcAkaSqZwKRZZgCTJElqzAAmSZLUmAFMkqaMB+FLs88AJklTyIPwpdlmAJMkSWrMACZJktSYAUySJKkxA5gkSVJjBjBJkqTGDGCSJEmNGcAkSZIaM4BJkiQ1ZgCTJElqzAAmSZLUmAFMkiSpMQOYJElSYwYwSZKkxgxgkiRJjRnAJEmSGjOASZIkNWYAkyRJaswAJkmS1JgBTJIkqTEDmCRJUmMGMEmSpMYMYJIkSY0ZwCRJkhozgEmSJDVmAJMkSWrMACZJktSYAUyShkjy4iR3JLk7yfYkFy5Yd1WSXUnuS/KqSdYpqZ9WTroASZpS1wLvqapbk7y6u31RkvOBzcDzgbOA25KcV1WHJlirpJ5xBkyShivg5O76M4CHu+ubgBuqal9VPQDsAi4csr0kLcoZMEka7p3AZ5O8j8GH1Vd0y88G7lhwv93dMklaMgOYpGUryW3As4esuhq4BHhXVX0iyeuADwOXAhly/1rk8bcAWwDWrVs3lpolzQYDmKRlq6ouXWxdko8C7+hu/iNwXXd9N3DOgruu5ee7Jw9//K3AVoC5ubmhIU3S8uQxYJI03MPAb3XXLwbu767fAmxOsjrJemADcNcE6pPUY81nwJK8C3gzgyn7rwNvBNYA/wA8B/gO8Lqq+mHr2iRpgT8CPpBkJfA43a7EqronyTbgXuAg8Da/ASnpqWo6A5bkbOCPgbmqegGwgsHXua8Ebq+qDcDt3W1Jmpiq+reqemlVvaiqfq2qvrRg3TVV9StV9dyqunWSdUrqp0nsglwJnNh9qlzDYJp/E3B9t/564DUTqEuSJKmJpgGsqr4LvA94ENgD/KiqPgecUVV7uvvsAU5vWZckSVJLrXdBnspgtms9gzNIPy3J5U9h+y3dnwTZfoB9x6pMSZKkY6r1LshLgQeq6tGqOgDcxODkho8kOROg+7l32MZVtbWq5qpqbhWrmxUtSZI0Tq0D2IPAxiRrkoTBiQ53Mvha9xXdfa4Abm5clyRJUjNNT0NRVXcmuRH4MoOvb3+FwUkKTwK2JXkTg5D22p
Z1SZIktdT8PGBV9W7g3Yct3sdgNkySJGnmeSZ8SZKkxgxgkiRJjRnAJEmSGjOASZIkNWYAkyRJaswAJkmS1JgBTJIkqTEDmCRJUmMGMEmSpMYMYJIkSY0ZwCRJkhozgEmSJDVmAJMkSWrMACZJktSYAUySJKkxA5gkSVJjBjBJkqTGDGCSJEmNGcAkSZIaM4BJkiQ1ZgCTJElqzAAmSZLUmAFMkiSpMQOYJElSYwYwSZKkxgxgkiRJjRnAJEmSGjOASZIkNWYAkyRJaswAJkmS1JgBTNKyluS1Se5JMp9k7rB1VyXZleS+JK9asPylSb7erftgkrSvXFKfGcAkLXc7gD8AvrhwYZLzgc3A84HLgL9OsqJb/SFgC7Chu1zWrFpJM2HlpAuQpEmqqp0AQyaxNgE3VNU+4IEku4ALk3wHOLmq/r3b7qPAa4Bbx1XTtX/4q5z+9BPG9XCSppABTJKGOxu4Y8Ht3d2yA931w5ePzcXPO2OcDydpChnAJM28JLcBzx6y6uqqunmxzYYsqydZPuz3bmGwq5J169YtoVJJy4UBTNLMq6pLj2Kz3cA5C26vBR7ulq8dsnzY790KbAWYm5sbGtIkLU8ehC9Jw90CbE6yOsl6Bgfb31VVe4DHkmzsvv34emCxWTRJGsoAJmlZS/L7SXYDLwf+KclnAarqHmAbcC/wGeBtVXWo2+ytwHXALuBbjPEAfEnLg7sgJS1rVfVJ4JOLrLsGuGbI8u3AC45xaZJmmDNgkiRJjRnAJEmSGjOASZIkNWYAkyRJaixV/Tw1TZJHgZ8A3590LWPwTPrfxyz0ALPRR997+OWqetakixi3bsz6z6ewSd+fxyOxv36zv6VZdDzrbQADSLK9quYmXceoZqGPWegBZqOPWehBs/882l+/2d/o3AUpSZLUmAFMkiSpsb4HsK2TLmBMZqGPWegBZqOPWehBs/882l+/2d+Ien0MmCRJUh/1fQZMkiSpd3obwJJcluS+JLuSXDnpepYiyTlJ/iXJziT3JHlHt/y0JJ9Pcn/389RJ13okSVYk+UqST3e3+9jDKUluTPKN7jl5ed/6SPKu7rW0I8nHk5zQtx6WmyONXRn4YLf+a0kuWOq202DE/j6SZG+SHW2rXrqj7W+x8X8ajdDjCUnuSvLVrsf3tK/+yEZ5jXbrf+H976hVVe8uwArgW8C5wPHAV4HzJ13XEuo+E7igu/504JvA+cC1wJXd8iuBv5h0rUvo5U+Avwc+3d3uYw/XA2/urh8PnNKnPoCzgQeAE7vb24A39KmH5XZZytgFvBq4FQiwEbhzqdtO+jJKf926VwIXADsm3csxeP6Gjv+T7mnMPQY4qbu+CrgT2Djpnsb5Gu3W/8L739Fe+joDdiGwq6q+XVX7gRuATROu6Yiqak9Vfbm7/hiwk8Gb6CYGYYDu52smUuASJVkL/A5w3YLFfevhZAaD/YcBqmp/Vf03PesDWAmcmGQlsAZ4mP71sJwsZezaBHy0Bu4ATkly5hK3nbRR+qOqvgj8oGnFT81R9/ck4/+0GaXHqqr/6e6zqrtM24HmI71GF3n/Oyp9DWBnAw8tuL2b6XwhLyrJc4CXMPiEcEZV7YFBSANOn2BpS/F+4E+B+QXL+tbDucCjwN90U8nXJXkaPeqjqr4LvA94ENgD/KiqPkePeliGljJ2LXafPox7o/TXB2Pp77Dxf9qM1GO3e+5uYC/w+aqath5HfQ7fz/9//zsqfQ1gGbJs2lL2opKcBHwCeGdV/XjS9TwVSX4X2FtVX5p0LSNayWBXx4eq6iUM/qzVVB5Ts5ju2K5NwHrgLOBpSS6fbFU6gqWMXYvdpw/j3ij99cHI/fVg/B+px6o6VFUvBtYCFyZ5wXjLG9lR9zfu97++BrDdwDkLbq9lsOtl6iVZxeA/38eq6qZu8SMLpjfPZPDJYVr9OvB7Sb7DYOr24iR/R796gMFraPeCT2c3MghkferjUuCBqnq0qg4ANwGvoF89LD
dLGbsWu08fxr1R+uuDkfpbZPyfNmN5DrtDOr4AXDb2CkczSn+Lvf8dlb4GsP8ANiRZn+R4YDNwy4RrOqIkYXDM0c6q+ssFq24BruiuXwHc3Lq2paqqq6pqbVU9h8G/+z9X1eX0qAeAqvoe8FCS53aLLgHupV99PAhsTLKme21dwuC4kj71sNwsZey6BXh9902sjQx2Le9Z4raTNkp/fXDU/T3J+D9tRunxWUlOAUhyIoMPid9oWPtSHHV/T/L+d3RGOYJ/khcG31L4JoNvM1w96XqWWPNvMJjq/Bpwd3d5NfBLwO3A/d3P0yZd6xL7uYiffwuydz0ALwa2d8/Hp4BT+9YH8B4GA9wO4G+B1X3rYbldho1dwFuAt3TXA/xVt/7rwNyTbTttlxH7+ziD4xkPMJiFeNOk+xlXf4uN/5PuZ8w9vhD4StfjDuDPJ93LuF+jCx7jIkb8FqRnwpckSWqsr7sgJUmSessAJkmS1JgBTJIkqTEDmCRJUmMGMEmSpMYMYJIkSY0ZwCRJkhozgEmSJDX2f5T/naBw2XrOAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x720 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"coupling = get_coupling(mu123.cuda(), mu231.cuda(), cost.cuda())\n",
"pyplot.figure(figsize=(10,10))\n",
"pyplot.subplot(2, 2, 1)\n",
"pyplot.plot(mu2.cpu())\n",
"pyplot.subplot(2, 2, 4)\n",
"pyplot.plot(mu1.cpu(), transform=matplotlib.transforms.Affine2D().rotate_deg(270) + pyplot.gca().transData)\n",
"pyplot.subplot(2, 2, 3)\n",
"pyplot.imshow(coupling[0].cpu());\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This looks a lot like the coupling form Python Optimal Transport and in fact all three match results computed with POT:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.2526288628578186e-07"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAvBUlEQVR4nO29W6xk2Xnf9/vWZe9dVefW3dNz4XAYitaElGRYkUNENyMQTDtwFMHSiww5UMAEAviSRLJhQKGSpzwEEBAjsIAEBgg5BhMLtgVZsATBsC2MozwEMCXqAl1IUSJFYTjDnu6e7j7XqtqXtb48rF3nHB5293RP387p+n5AYZ/adar2qtP9399a3/ouoqoYhvH84571AAzDeDqY2A1jTTCxG8aaYGI3jDXBxG4Ya4KJ3TDWhEcSu4j8DRH5koh8WUQ+/bgGZRjG40fe7z67iHjgT4C/DrwF/Bbwt1X1C49veIZhPC7CI7z3PwG+rKp/BiAi/wz4YeCeYo/VTJvpJVjdXxQEQLWcU0V09fzkHJx9Xo66+hCLCzIMAJYc0Wkrd3vtUcT+KvC1U8/fAr777C+JyKeATwHUkx2+66/8JJJBsiKDjseMGzKSFOkGSIp0PeSMtOWoXQ85QUpoPxSxpwRZQTM6HoGTG4JhrBmf0zfu+dqjiP1ud49vUpmqfgb4DMDWxqta3enKO+Ub367eoU4RHyErUnvIIE1CcoZ+QNJ4HAY0Z2QYjsUv9xN/GcgjfFXDuPg8itjfAl479fyDwNfv9wbJil/0aHAggnpBRU7EP57DKTgpwnWg2eNE0JzBCeIdMqRyZ3G+3HWyri6CJhCX0Tz6HzWXzzfBG2vMo4j9t4DXReRbgLeBHwP+y/u+I2Xc7iEaPDgHwaPOQXCjyN3xDUCdgHflxgDkJiBapvwkRVJC+gQpI/0wTvk7VMclgOb3nu6D3QCMteF9i11VBxH574B/A3jg/1TVP7rvm3JCD4+QGMF7iAG8G48egpaP8pD9yvq7b1gwSNZR7AE3ZBgyrhvF7n2Z8jtXpvniIKVi1VNCVRAS4MsNAMziG2vDo1h2VPVfAf/qgd+QFW07SGU6Th8Q5yAEJHg0eCQG1LmyZneOHEfRh2Lt1ZUbgKiS1ZfjUNb5rk+QtIh/SMXip4R0PTokJCfoh2L9+x5VNQefsTY8ktgfFtWMLhao94gI+DKNlxCK4L2DGIvwU1Wm+k0ozjvxIEIOggaHOk4s/rg755IiSXFdRlLGtQMyZKQbilc/JVjdbLxDsqJDmRWQMqgU0WOiN54/nqrYUdCsCAkVV/bVR8ebaEazL+eGYrlxDqdabghZ0eCQ4NCoZD+KXihxgAJZBEK5IUhyuMqP23oR6SukT0hblyn/sisiH7f0pB/QlJCUji3+8VTfLL7xHPB0xQ5l3a4CFCGpOMQNZR/dCeI9KoKM63hZVOVcU4H3uCaiVSBXnlSDBiF5h3qK1XecbOupIrlYfNeXfX3fZmTI+OVo9Rddmeovu7KlN4xbeymfTPVH0R+Lf/xsw7hIPH2xw4lQRMb1sgMSooJCsfYray5SLLEI+FFsSXE5F097HL31vgi8rOm1ePNFyk1AynkJivpi9TW4IvrxSAxlnd/1xdoPqSwtUjqZ6ku5lqZUlhBm8Y0LxLMR+4rj0NfRcuroNRdXrCtAW9b3GuO4pq9wwUMVcVWEGPCTiHpHmkbUC8PEkyPkKORQPPt6KuVHlBKtlyG0Ou7/Z1yv+OWA61JZ5y86ZEjIsi1LjbYt0/tR/JqyOfeMC8OzFftZVM9Ye47X90KPZj8a1OKFL9twuQTcBFcceV5wQZAsIK4sCWRl6UFd8eeJk/EHkFz29t2gaBB865Do8c6hQz4O4jne0htFzzCu80XNuWece86X2OHu1p6Epm9e20sY9+ePIs45XFOjwRMmFRo9aRLHtb0jNY4chaEep/axCH+Yrtb3Uiz+4HFDxPeKbxvcoISjhEuKP+
qRPuOWHfRDOQ4D2vdlS2815YcyQzGLb5wjzp/Yz3IskhIGi2hZ22ct62jnyrraj9Y+hOKYi+WrSVJEfbHgKmTvyAjE8eO9HE/xVcCVoDxyDymO1t4Lri/+A5cyOJAhlGt3/ehUbJHkj8dcfA+6umdhFt941px/sa84K/pjay/jVFrQrkO8h2WJ0POHFT4UT36si5WvppFcOYapIwehnzhy4Niznz1oBamR4yn+cgDJ4FuPJIiLerT4Gd9l/DLh5n3Z2lu2Zcq/bMuavm2LR/+UZ19Xcfxm9Y2nyMUR+4rTwtBxGy+PltRnlL5kw4lA1RePetcjVSyPPqO1R1Ikx7JPn2NJwMmiEAQNFNF7jvP4JIMbBEmQmvJzqAXfOeLcESqHaxPeCwxl92CVrCN5PKZ0bPGP/RKn4/QN4wly8cR+Fj1ximm6y2tdB32ANiAh4NsODR53WKPREw4jGh39LJBqYZiUdX2uII3r+1QBUqw/jFZfwfWCS+Baj28DvoM4L+v8eLCB6zN+3hWLPx8tftuVNf4wnCTq9MM4fgviMZ4cF1/scBen3kr85bl0/sSpN1+MwTl1OR41ZftuVpHrQL8RSBNHPxH6GeSqbN+ph1SXlNvsKVF7uWy3u24Ufg9h7nGdUm04fA/xIOC7TDiIZUtv3pa9/LYD3x0H7JBXAj9l8S1Jx3iMPB9iP8up9T2Ue8A3OPX86EhzchyW64eEiwHXRnIdCFNPXHhSJfiFkCP4WbH0Q1PEr0GPPfs56jjFF2QQ+g3BDRCPBBmgOgz4TglHE/wylf38+RmvftedVOJJyfbxjcfK8yn2FacsvmZKRN0AiCt75eKQxRJEkMMK8Y7QFEsfm4o8Ha39ZnHqtduOFIVuS8h12bZLTslVeeBAfamJJ4NALjcKlyAc+jLNP/SEOcRFJh40+DYTDlqkT7jDZbH082UpwdWNyTvjw/bxjUfh+Rb7Wc4E7YgbC1y4stcu3hcH2hCQnHEpI20sCTzBgUZyJUh25KpM3VMtpEZIjZI95Ho1twdEyU2ZTKiDYaDMFDagW3jizBFapZp5XK+EgxrXp7K0GC2+9sXhqF2P5DxG71larvHwrJfY4a7WHih17JzAwpWwXF9y68V74p0GgidOG9R7Jls1ufJ024GhcfSzMm0fJtBvOHKEtJHLNH82IE4ZpHjhu95DEqR1uIXDd0I8iLgO6t2I75R6dzqu83v8okcWHW7ZwpDQZbH+x+K3qb7xgKyf2M9yZv9ePOhQkmxQBT8KKIRSKs978ILrSqCO6zySHZIdrgco63vEkaOSKVN7iRm8IiGDBxUleSV3JW7f9SVZx/UlRdd3jqpy+GUgzCP+KJZ4/TFhh+WyRP90XRmnTfWN98DEvmIUhq726FcJOU7QZVuceaFYetmPiPfUk4YqBnRakyeRNCne/KGRsr6vHN1OmfJ3WxmtFZ0kfJ1wdUeMJyHBKTmO2oAODjn0uE6I+4GwDMSDing4JS4y1X7Ct4mwN67zjxYnFn+1nWfOPeMumNjvxn1CdNWVstWl6IYiwSMpl730PiJZR6svpEpLNl5dcu1TkmLNnSKiiCScywSXSdkRQiIlR+cjqXeoD6RlsfSpgjR3qCtWX0VwXcKLIP1QUoCHoWznrQQ/DBayaxxjYr8fp9f3q4IbYxFLFYd0Xcm7P5pDFfEh4JuaGAPNrEGjp9+qSLWj3fYMjdBtBYZpYJgp880GrTNuo8eHxKTpaaqenY0FAN3gSdmxbCOLNqALjz/w+NZT7QZcB83taVnn7yX8MhEOO9y8K9t5bVdmKovl3UN2zeKvFSb2B+Uuwv+G2PyuKzn3bQch4Lq+OPn6KRo9rqtIjcf3jn4pdH2JwkuNJ4WMVgJNjxOYVR3RJZyUa877iuUQOGor5tOabhnIsYhdveCXJZknLEuKbwgOFz3iBOl9mYHcJWRXExa8s0aY2N8PZ6f5eWxqkTL4vhTK7LsSobdYQgg0+w0aA/VmTZoEhpmn2y
yRet1ORapgfqnmsFLe3e7xdWI6bZlWPU0Y2K6XbNdL8pbQJ8/ByxXdEDg8aNDe4XcDrnVUezXxqCYeKvV+xi8z1W6H6zNuf1Eq7i6W6JCg72yNv0aY2B+F+wXtdF05rmrpzedICMTDhlBFqmlNtVGRpoH2oDj1lkspzr0hkpvAnOK8q3xiFlumoWcnLsgISYUuB25e2mAxRK5f2qRdRo52K8KBJ+4L/YYnzB2pdvhlpvZSnHrelQKbi7HmX9eX4iBjXX0L130+MbE/Tr6p0s6qZl4ac+9LAgwh4NqOOK8ITSQc1uTaUx1EUiUs7wipcXQ7U5aTCV/b2uDtzR3qpufSxpzKJ640RzhRdqoFW3HJZmzpsufOpUmZ7s9rFvsVbuGodss6v7k1rvN3Z/hlJu6f7OPLsoWuPwnZHRtrropxmMW/+JjYHzdnk3Jy6UizymyTpf+GSjsSA7GqoIpUswkaPdNLk+LUuxTop0J7KdBtBdrNmmuXa+Kkh8swix0vzm4zCy3bfkHjepY50qvnZrfJjXaDW8sZ13a3WCwjixs1vhWam4GwUJrbYZzyD4SDFrcckMN52cf3Y+hu54pPwiz+hcfE/jQ4k4a7qqJbmlOMfeiGVNpXx0BUJVQBN9RUjSMsPfFA6Dcc7VFFair+/KjC14kbOxtMY8/Ls31mvmMWWia+p3Y9Lzf7zELHLHYcdDU3Jht0beRwq8YtheVtR1g4qj1PdVgRjzLV3hTXJ/z+soTszpfFurdjMY6+t/X9BcXE/rQ4u74fveLfkJTjXSl1VdeIdzRjUs5kNiFPI2kWaXciw0RYXKlIDdy52vDuJPP1F7aZTVte3d7j1ekuV6tDXq+uEyUxcy2devbSjHmu+JOjlzgYar5y5wqH84aDOw1x1xP3A/VtT1jA9N0a12aq3RbpBtxBce7pclnaaaVUliSrZhqr0tom+nOLif1Zcdqjf3p9vyq44cbqukMoOTUplf70CmFRUnSHutTGT7VnmSfsTiraPnBnY8LVyRG70ymz0PJS3CdK+fza9bzS7LGTIlmFvemEd6pNjiYNw2ZkmDjCQkhNwLcwmXhcl6mmVZnmzxuk7UrH3Ha1vu/u3kzDhH+uMLE/a1S/uejG8fp+DNM9KGG6rorUdV2s/bRBq0C/05ys7yeB9vIWtze3eGc78/uXO5pZx4cu32GnXvBtG++w6Zf8pembNNLjt8v1bg5b3B42eLvd4atHV3h3scHXb22TFoF4PRIWgebdSJhDs5uIByV4xx8Wp547WpStvLb9ZosPZvHPCSb288T9im74VV57Lkk5qsjSEwFfFUsflg7RElcvg6PLFYul5y2n7DYTKjewExdESWz7I3b8nEZ6dvycqWuJkgguc6kqEXwHy5pd3WBYetSVPPwcPFXjqCaOWHl8m3Del/177yCVuvpF9DKm48rJ+t5E/8wQfYp//C25rN8tn3hq13tuEDluPSXeHyfllPV96YEnkxK0kzcm5EkpuNFverqZo70sDA0sX8zkSWZ69YhZ0/GtO+/ySrPHa81tPhDvMHMtm25JVseRVhzlmi8vX2YvTfijvVe4s5xw/dY26SASdz31bSEcweTdTGgz9a0e1w74vVPBO31pp6VjQw0L1X2yfE7fYF9vy91eM8t+EVhN9UVK+u0Yn39ylOP9e98PuKOIWzSEeUWcBXzv6SelokZqHHNmLKY10SeWKeBEmbqOShIz6fCivCSH9Dh23JyDPGHqOm52m3wxDNyczTioZmQfidPyuWFRwn99G6igtMt2grThOHuQ5Eu5b0vOeSaYZb+oyEl1HaBYdxGkGi19XUFdoVUkb0zQ2tPt1KRaWFz2pAaWV4RhQ+kvDcTtlp3NBR/ZucXlas53zN5m5lpei7eIkvCjKG+mLXbTlK+2V3lzcZlr8y3evH2JdhGR6zV+KUxuSNnHv5MJR5lqv8cftsiyR44WJRV3uTzZyrPknMeGWfbnkXvl3w/D8XaehFAq5x6OWX
kHU3IViAcNqXHEw0A/g3Y/0l4K3NyuWfaBy9MFlRt4MR7wcthj0y152bdMRXg9LkiqfL1+i7enO3xt6wq/N/sQ15eb/El9lW4RyaEmzIUcHHHi0CDE6AhH487CWEL7OClnFLo4q6z7JDHL/jyx6ku/mjaP1p4YS7ReXUEI6GyCxkDaqklNoNsJdBuOblNoL8EwU/oXBtxk4NWru2zVS75t6x1erPZ5Ld7matinkZ5GBpYaOMgTdtOUL7cvcaef8oe7H2Cvbbjx7hZ6FKhue6o9Ie4rzR0lLDP1ra7k49v6/rHySJZdRF4D/i/gZcoC6zOq+nMichn458CHgT8H/paq3nlcgzbeB/cK3On6Iv75ONXfr8baehNi8NSzCXlaMWxWtDuRfiosrkaGSeStl2t0krj2yiZXZ0d856W3+QvNDf5CdZ3Xwpyp9Gw7JXPIwfQt5qr8wfYL3Bo2+PdXv5Wvz7f58q0XOLw9we8Guncd4cgxbVxx6q08+nu+BOssStAOwwB9Dynb+v4x8SDT+AH4e6r6OyKyCfy2iPw68F8Db6jqz4rIp4FPA//Dkxuq8dCcraY7tr8+HbSDd0jK+GWHW9T4RbH2cREZaiEeFqfe/u4V7swu8WdXrrCzueCDm7u8vnmTl+I+31q/w8y1XHEZJ8pVf8COm1NtJXZnU/544xXeurrD24fb3Ly9RToMLG4G/CIwuelL8M6tCWGRiPul+Ia0HbJoS0LOuJWntn//SDz0NF5EfgX438fHD6jqNRF5BfgNVf3o/d5r0/hzwt2m+6sW2FVEYkRnE6giw/aEXHuWL0SGRlhcdXSb0F3O8ELL1uaC77j6Di/V+3zf5pe54g/5aNxn0wVqiTiEO3nBXla+0l/iC+2rfG15md++9SHuzCfsX9vEzx2T645wCJPbpc5ePBgI+8u7OvW060z09+CxOehE5MPAdwGfA15S1WsAo+BfvMd7PgV8CqBh+jCXM54Ud7P440sCqCoC0A8EQKvy3yQ1HskQD4WwdLTzhr3Nit9cVMwmHdde2OZyNecvbXyNK/6QV8Mdtt0SELzAjp/zev0Om24JwO2tGX8YXuFoWXE4meKPPN22o9p3VPuBZjcS5om415zE5w8JXSzK9H6VjmuJOQ/EA1t2EdkA/l/gf1HVXxaRXVXdOfX6HVW9dL/PMMt+jrmHc0+qqgTxNM2Jc6+JDJs1/VaknzkWVxzDVFi8qKRppn5lzuZ0yXdceYcPTW7zbc3Xeb26zmXXcdWXG0evmSPN/Nmwwe20wW8efoTr7RZ/cOsV7uzNSLdr6nc94RCmN5SwVCY3O1yXCXfmxQ9xtDCn3hke2bKLSAT+BfALqvrL4+nrIvLKqWn8jcczXOOZcCYx55vW91Ja2shYzTaoIknxbQACw9yBCEPjWbopy0nN72fh2myLO1tTbk03eDXe4SPxXSKZTZfJwI5bUpH4DyfvcCke0ebAtbrl63GbZZjgD3ypq7cQsq8IrdIEKe2xq4iMhTVp22/unIMl5ZzmPS27iAjwWeC2qv6dU+f/V+DWKQfdZVX96ft9lln2C8ZZa78qs7Wy+Kv1/bRBq0jabkoG3pXSKWf5QumE217J5J2e2faSj1y5xUvNAf/x5p+z4+d8rHqHqRtoRHHAXvYc5Iqv9C/y1fYqX52/wBfuvMT+vGFxbQO3cEyuC2EO03cz8SgR93rCYVcq7qz5+v5RLfv3A/8V8Aci8nvjuf8R+FngF0XkJ4A3gR99DGM1zhNnK+rKOD1ehekOA4TSi05iRHLGR4/kybi+98SpIMnRtxVHrecrCrdmU6JLXK0OmLqWHT/nVX/IVGDbJTbdAifFw7/hy/r++mSTLydHt4i0qWaYC6LC0ATUCRocIfrjoB2BMj4oY3Vj4FEee22f/n5rggXVGA/H3az9qjdeVZWtvNkUvCdvTtAq0G1XDBuedtPTXhqt/QuZNMlMXypJOR+7fJ2X6gNen1zn5bDLjp+z45b06o4Dd7
7SvcidYcbv7L7GXjfhzRuXGQ4j8Vag2hOqPaW5rcR5pr7d4doBtzcfi2uOHXO6/rjazjdYe3guxG/hssbj457184exA42UVtPe49oOYqCZl6CdaiNSHZWgnbB0DBPHYthgPi2fdWO6CUDfeLxkrrgFM9cxlY6lzrnqDzjIDVPXcXuYEV3i5tGM23GTVEdSVTrlDBNBcsS1niopritdfOhKbIEC4k6Efhyi+5xjYjfeP/eon7+aLcowFKu/WOJjxDUV8VaN1pHpO6VpxuJtT2ocRy+8wP70Cl+68ip+s+fS9hGvbd3hanPIR6fX2fBLXg67VJL49uZtEsKH6lvspQlffulF3p7vcO1gkzu3N9CjQHMj4pcwuVERlkpza5OwHHvkzdtSaWe+svZntvCeU6eeid14dO4SpgucaoO9PKmmW9e4KtLsNlBF6jsTUu1p7kT6ibC4Ghg2AreuVNy5MmV7c87RlZoX6kOaWceOn/OhcIdaEq+FXbIKH62vcXNziz/Zfpkvbb3E9fkG70x2YOHJIeCXUtJx557aO0Ll8UehzELarow9Jei6sXY+J91wnyPBm9iNx8/ZNtiM4llV02192ScPntD1hOAJ8xKp1+xF+qmj3Xa0O1MONyb8f1d2kOnAv3/hw8yqjo9u32AzLPlgdYdNv8CRaVzPt9Q32b604NbGjK9Mr7LXNby1c4lF61lerQhzT3PLEw8r6v1MtbuBXw6E3UWppHs4R3MuSTlnw3Ofg3W9id14cpxOwwV0OFVxZ7EswTqHR4gI/nZDCIE4Bu2kzYZuK9JvOBZXPMM0cOvFmhvTzNsvbzObtHz08k1eafb4cHOL16pbXPGHfHvzNsscub2xwW6a8oVLH+B2N+MLV17icF6zd2NC3HfUdzz1bUd1GGiaUEQfPJzuiDsW1CyDlxNrf+q7XSRM7MbT4y718+n70gm3bYuDD5CuR4aM62riPOC7UvXWt47UeJbzDXabGb+5P2MybXl564BXpntcrua8Wt/Bo9Sup1fPpTindgPLS4GDWcNbcZvFTk2/WdFtOeJh6bkXFpFmM+LbTLwzhuceLsr4xm64crpm/gVc15vYjafLmfX9sWiWbVnfHx4dr+9dFfExUk0atI5srpJyrlQMtbC4OqGfTfjq5W2+vP0y1XbLBy7vcame87HN60x9x4eqW0QZ+NjkGr163rx8hd1hyp8dXOHa/haHBw3zm6XYxuR6RVgo03cDfpGpbzfIolTPlWVbHHldX6x9P1w40ZvYjWfL3ZJysp68NiblSEp4EVz01CKE2oEEwpHgBqFbBvqF583Oc32yyUHfMAsdH97YoHYDm36JG9N9Jq7jcj0nbTq8y+xmoZsGUF9q6YnHLz0ahDCPhCbgjipk2SGLVYec7htFfwHi8S2oxjh/3C8F13tkOoHg0Y1SZmvYrscW2J7ljjDMhOVVJTWKXG2p6p5XL+2xXS14uTlgKyxwonjJHA41N7sN9rsJb+7tsGgr2htT3EJobjriETS3M9V+Jh4OxN0l0vbIwfzeYbnPUPQWVGNcLO6WlLPy5q8KbIaAAG7sjee6iAwRyYF+CSpCaoSlq1jUga+Lstc0dDlwpa6O++IlHBuhI6vj0nRBFRI3twOpCnS9kKOAOrIXNAA0+GXAZ717WC6cW2eeid04v4wltE9H6kFfylY5KT3vvUf2amLwxKZmMm3ITaDfLp1ylpc9qQ4sX9jmoIFbl6+QZ4m4WarpzqqOq5NDgmQ+tHGHPntenB3SDoF3rm6ybCO7dxr8gafaC9R3AmGuTN9t8G2mur0szrz9OYy98LTrz6Uzz8RunH9Ot8EWKdF6Y8SeugFJqVj8vkf6Ab+IyJDRKiC5JlcO1JcwWnX0S6FPwq4K7SRQuUQdBmo/EF1iKy4ZgqPLnkUVuZUcg4904pHsyFHwvSMsBddVuNYjfUKcQ/K42wAIfRmrnOr08wwFb2I3LhZne+ONGXgqDulKs0nxDjma40Jgsls64d
bvFk9+txPoJ45229NvzeimM/700jZaZ+LOkhgTW9MltU94l9msW8LlTLsVONypWbxQoQvP8qrHLxzNuw7flq63vs1Ud1rcvMPNlyU4ZxiTcFRLD79naOlN7MbF5S7TfBmGUfg96v1x/fy47NAY8PMJ9SQQ55Hu0NFtCK4tlXY6beirjAh0sWdnsiS6RKwT1FCHgYM4sGginTS4pQN1hCW4VDz5kiqCFzyUfny9L/35Uioiz+5UqvDTrY1vYjcuPmcTcs7k3av3SN9DCIS2iD7sNTSzimEW6LY8QyMsbwZSDe3lyLJSbm8lXDMQ64G6GooH32XqekAuLRn6Ms13naOfOXzraXcc8ShS79XE/Sl+0eP2S218dziHnNBlW8Y1DE91y87Ebjw/3DX9tkyZpSuJLzpflICdgwpXV4SmptpsyHWg3qtKBd2DUj67vSwMk0C7NTBsDMRqYNq0NHFgUvUMyTGvKobBs6wq3NKRoyMeCUMtNLUjHJW4AGl7XM4wpOJrGMauOKu+d0/B0pvYjeeTM+2vVw0zBE72wocB6frSDLOOSJ/IlScsK1ItxCPHMBG6rVhmABNlOWuQmAn1cHwpAWgyOSidQpoIqRaGiScelRuHX9bUVUD6hNsLJ1Vy+5Ji+zQ89yZ24/nmbDLO2P1Wuu6kys5haY8VdmsIgXh7ikbPZKdhaDztTqDbEPpNR7flSBOl3/YQFD8bcC4TpyVVdqgDaXD0G4GwJcRDYZh6wsKTo8MvE5VzpaCGdyUSz5WCH5rSE02vNbEb68cqNNetUnD7ktrqBIZUovZiIIjg2oAo+M7jO4frS7CO6z25gqF3JK9QZcQpmkvraoKSaxgySBJyBDc4fOfK5y1L+K9UsXTiDW1xKnbdScurovzHJnoTu7FenC20IYIOjNb+JFgH55C9puTaTxqapiLPSrZcmjiWlzxDLbSXPTlCt6XkStFJhqBolUkVpKkwzBy+ldGJB/1kLIk9C4SjAX9Q4Q6XxcrPS3VcjsNv4XFZeRO7sd6cTsQZ024VSk67d2j2JVhGS6nrCLhUKtr6RlDvyJWgIqMld+RKwYN6RbKgDrKHXJdLDlNQL/jWo05KOa8M4mQMEHInhT5WNfIeg+BN7IZx1ot/ttf9YjlWz434qsLXFXHaoFVgstOQK8fyUiBVQrftGCYwTCBNtQi9UjQo/aYiUyFXguthmHj80tFsOOJhRTxo8E2NtB0uhOJnWCxL//rT23TvU/QmdsM4zZkknBKXS+kyA6MVzmMvvEDwQq4ClRdS41DvkARlwS1F6GPDXARUIMcym0hNOTm0pba+6z2uLZKUrj+u1KsAKR/7GN7vFp2J3TDuxiosVwQd8kl0nvdI62GxhBjw8yU+eMJ+8eA32zVD4+k3Pf20WPl+w5FjsfbqFFwRfD8ThgZydHQbSj+L1FOPXySq2iNtwoVQkmrmi2Lp+7Gn3fuY1pvYDeN+3K1cdkplTT0WpMS7Yse9J6jiJhFJEUkBN0ix5lVZu6sXUq3FyodSejvVAMIwgOvLNCAvIk4EXcayO9D3ZTaR8/tex5vYDeNBOCN6zVpEmHNpfDkMEAKu63Ex4A4nVHuRNA1U+5FUC+3WaOFnQg6QA+UG4DgWvHoYGo/6Cr+MVE5wXcKJlKo480Xx1A9DSax5CMGb2A3jYbiH6DWlsl++En0/oMsKWdRIn8m1R3IkVYIopEoYJpDj+LGesm+PoKJ0gyMExbcRDQ5p6+M1vKRUxrHqXTdmAb4XJnbDeD/cqzY+jHXptHTE6ctD64D0Dbny+C6Qo9BtCKkqU/wcxrR3AQ1lLa9ecH3Adb5U260CDiB4WLbj9tzoOHwAL72J3TAehVU4bioVNTSVCDwdhtIBZxlgWeNiJLZDSbNty3ad68v0vp+WaXwOpfRV9sBE0KBIcvhO8V1VOuQO+XjtLqtOusf96u5v4U3shvE4OFs3T/Q4Hn+VfCMiSPDgS5VcFSHVDkmOoXfkqOPavdTQg3
IDEKVs6wm4Npa21DmfxNKnVOr06Riqew9M7IbxODkuqHEqn77rS9LLqrtt10PwuEWpjhu3KvppIE0c/VTIQUhVUW2OZTovyeNrh2QIweFVkazHwT/HMQD3wcRuGE+CsfvN6YQbxhz2YydbDLic0egofrqAihuddmXLrphuTkJuo5Cjx1UBiaHUvBsDcUhpFQN0V0zshvGkOFM26xuq54iUcFzvCcsObSr8tCJtVKTa0294sodUl7h79ZClpMuqLxF0QfV4xiDeF7/BcNeS8YCJ3TCePGOyTdmmy8fr92NLPzaTdK6E25KLBZcoqBfUabHyULz1TtDgypZc8EjwqObymffhgcUuIh74PPC2qv6QiFwG/jnwYeDPgb+lqnfez9/CMJ57Tlv5XCLndKFFoCmVphdth29rfF3h+rpM1zfDGHUnReRjNF5x7FUA+LYplr3tkPbegnf3fOWb+Sngi6eefxp4Q1VfB94YnxuGcT9UT8pPpVRaR/VDSXjpeqTtoe1wywHXDvg243pFEsjogNNj615SbDX4ErL7Hpb9gcQuIh8E/gvg50+d/mHgs+PPnwV+5CG/tmGsJ2PDSh3LS2vXk9u2dJM5PEIO57i9I/zunLjXEvc64sFAPMz4VnGpxNbn6MiVR5uI1hXUdam2cw8edBr/D4CfBjZPnXtJVa+Vses1EXnxbm8UkU8BnwJomD7g5QxjTThVNIN+LI9FKZ4hOeMWoXS3Ca7kwDgHqQhanYCXkokXcnHW8QhiF5EfAm6o6m+LyA889HdR/QzwGShdXB/2/Ybx3HI2d37lsFsxjI68GMZ21Q7JkVy5Y4edOkHjWFEnxpMOuHfhQSz79wN/U0R+EGiALRH5J8B1EXlltOqvADfez/c1jLXn9J78qf70qxJVMiRc8GjyxVtPCbTJYbTwwZWml8Hfz7C/95pdVX9GVT+oqh8Gfgz4d6r648CvAp8cf+2TwK+8z69qGMbKcTeWntJhKI927F+3aJFFcdz55YDriuNuLJhXtva8e2TLfi9+FvhFEfkJ4E3gRx/hswzDWFn4M1l0khLqBBli2YtPAZrISr4r7zwx8Ehr9m8ci/4G8Bvjz7eATzz0FzIM496cCcBhlfAyxr9LH0u5a+8Rn8dSd6PARe47jbcIOsM4b5yy8EDpAguIH0o9+5RO1t+h9JeD0VP/hKbxhmE8Kc5aeFVUdexY48px1b/OlzTY9/LAmdgN47xyyktPP8AYU685HUfLSQysak/ez6qDid0wzj+ax4LzDlUtW3IplU6wziHJCk4axsVnVfYqK5BKDfnswXes0t0JperNe03jHyYRxjCMZ4UWJ52msQdcSiW0dnzIGGZ7P8yyG8ZFYZzOa2LMgXfFO+9LZJ1IvO/bTeyGcRE4NZ1feedJCR298gLFW3+fctImdsO4SKzKRqfSh06GAYZxQz7dfypvYjeMi4bmk55zYztnkbEV1X0c8+agM4yLxHEv+YyuAm1G0ZfXbBpvGM8PqmXtToK+B+/G+vHJLLthPHfoqfX5mBZrvd4M4zlFsxanXErIIMVRdx/Bm2U3jIvKyrqrFuv+HoE1ZtkN4yKyyopLqVSUFbPshvH8shJ2PvHQ3w+z7IZxkVllxKU0BtWYZTeM5xrVlUf+3r9jYjeMC47mlYPu/kE1JnbDuMisLDoUZ51ZdsN4znmAoBoTu2E8B+gD7LOb2A3jecL22Q3jOWZV2CLl+y3ZTeyG8VzwHgE1YGI3jOcHtaAawzAwsRvGc0Np93zv103shvE88B577GBiN4znh/dw0pnYDWNNMLEbxppgYjeMNcHEbhhrwgOJXUR2ROSXROSPReSLIvK9InJZRH5dRP50PF560oM1DOM+PKast58D/rWqfgz4TuCLwKeBN1T1deCN8blhGOeU9xS7iGwB/ynwjwBUtVPVXeCHgc+Ov/ZZ4EeezBANw3gcPIhl/whwE/jHIvK7IvLzIjIDXlLVawDj8cW7vVlEPiUinxeRz/
e0j23ghmE8HA8i9gD8ZeAfqup3AUc8xJRdVT+jqh9X1Y9H6vc5TMMwHpUHEftbwFuq+rnx+S9RxH9dRF4BGI83nswQDcN4HLyn2FX1HeBrIvLR8dQngC8Avwp8cjz3SeBXnsgIDcN4LDxok4j/HvgFEamAPwP+G8qN4hdF5CeAN4EffTJDNAzjcfBAYlfV3wM+fpeXPvFYR2MYxhPDIugMY00wsRvGmmBiN4w1wcRuGGuCid0w1gQTu2GsCSZ2w1gTTOyGsSaY2A1jTTCxG8aaYGI3jDXBxG4Ya4KJ3TDWBBO7YawJJnbDWBNM7IaxJpjYDWNNMLEbxppgYjeMNcHEbhhrgondMNYEE7thrAkmdsNYE0zshrEmmNgNY00wsRvGmmBiN4w1wcRuGGuCid0w1gQTu2GsCSZ2w1gTTOyGsSaY2A1jTTCxG8aa8EBiF5G/KyJ/JCJ/KCL/VEQaEbksIr8uIn86Hi896cEahvH+eU+xi8irwE8CH1fVvwh44MeATwNvqOrrwBvjc8MwzikPOo0PwEREAjAFvg78MPDZ8fXPAj/y2EdnGMZj4z3FrqpvA38feBO4Buyp6r8FXlLVa+PvXANevNv7ReRTIvJ5Efl8T/v4Rm4YxkPxINP4SxQr/i3AB4CZiPz4g15AVT+jqh9X1Y9H6vc/UsMwHokHmcb/NeCrqnpTVXvgl4HvA66LyCsA4/HGkxumYRiPyoOI/U3ge0RkKiICfAL4IvCrwCfH3/kk8CtPZoiGYTwOwnv9gqp+TkR+CfgdYAB+F/gMsAH8ooj8BOWG8KNPcqCGYTwaoqpP7WJbclm/Wz7x1K5nGOvG5/QN9vW23O01i6AzjDXBxG4Ya4KJ3TDWBBO7YawJJnbDWBNM7IaxJpjYDWNNMLEbxppgYjeMNcHEbhhrgondMNYEE7thrAkmdsNYE0zshrEmmNgNY00wsRvGmmBiN4w1wcRuGGuCid0w1gQTu2GsCSZ2w1gTTOyGsSaY2A1jTTCxG8aaYGI3jDXBxG4Ya4KJ3TDWBBO7YawJJnbDWBNM7IaxJpjYDWNNMLEbxppgYjeMNcHEbhhrgondMNYEE7thrAkmdsNYE0RVn97FRG4CR8C7T+2ij84LXJzxXqSxwsUa70UZ63+gqlfv9sJTFTuAiHxeVT/+VC/6CFyk8V6kscLFGu9FGuu9sGm8YawJJnbDWBOehdg/8wyu+ShcpPFepLHCxRrvRRrrXXnqa3bDMJ4NNo03jDXBxG4Ya8JTE7uI/A0R+ZKIfFlEPv20rvugiMhrIvL/iMgXReSPROSnxvOXReTXReRPx+OlZz3WFSLiReR3ReTXxufneaw7IvJLIvLH49/4e8/reEXk747/B/5QRP6piDTndawPw1MRu4h44P8A/nPg24G/LSLf/jSu/RAMwN9T1W8Dvgf4b8cxfhp4Q1VfB94Yn58Xfgr44qnn53msPwf8a1X9GPCdlHGfu/GKyKvATwIfV9W/CHjgxziHY31oVPWJP4DvBf7Nqec/A/zM07j2I4z5V4C/DnwJeGU89wrwpWc9tnEsH6T8p/urwK+N587rWLeArzI6hE+dP3fjBV4FvgZcBgLwa8B/dh7H+rCPpzWNX/0BV7w1njuXiMiHge8CPge8pKrXAMbji89waKf5B8BPA/nUufM61o8AN4F/PC47fl5EZpzD8arq28DfB94ErgF7qvpvOYdjfVieltjlLufO5Z6fiGwA/wL4O6q6/6zHczdE5IeAG6r62896LA9IAP4y8A9V9bso+RHncho8rsV/GPgW4APATER+/NmO6vHwtMT+FvDaqecfBL7+lK79wIhIpAj9F1T1l8fT10XklfH1V4Abz2p8p/h+4G+KyJ8D/wz4qyLyTzifY4Xy7/+Wqn5ufP5LFPGfx/H+NeCrqnpTVXvgl4Hv43yO9aF4WmL/LeB1EfkWEakoDo9ffUrXfiBERIB/BHxRVf+3Uy/9KvDJ8edPUtbyzxRV/RlV/a
Cqfpjyt/x3qvrjnMOxAqjqO8DXROSj46lPAF/gfI73TeB7RGQ6/p/4BMWZeB7H+nA8RcfHDwJ/AnwF+J+etbPiLuP7K5Slxe8Dvzc+fhC4QnGE/el4vPysx3pm3D/AiYPu3I4V+I+Az49/338JXDqv4wX+Z+CPgT8E/m+gPq9jfZiHhcsaxppgEXSGsSaY2A1jTTCxG8aaYGI3jDXBxG4Ya4KJ3TDWBBO7YawJ/z9bPgtvR75eGwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"o_coupling12 = torch.tensor(ot.bregman.sinkhorn_stabilized(mu1.cpu(), mu2.cpu(), cost.cpu(), reg=1e-3))\n",
"o_coupling23 = torch.tensor(ot.bregman.sinkhorn_stabilized(mu2.cpu(), mu3.cpu(), cost.cpu(), reg=1e-3))\n",
"o_coupling31 = torch.tensor(ot.bregman.sinkhorn_stabilized(mu3.cpu(), mu1.cpu(), cost.cpu(), reg=1e-3))\n",
"pyplot.imshow(o_coupling12)\n",
"o_coupling = torch.stack([o_coupling12, o_coupling23, o_coupling31], dim=0)\n",
"(o_coupling.float() - coupling.cpu()).abs().max().item()"
]
},
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance comparison to existing implementations\n",
    "\n",
    "We copy the code from Dazac's recent [blog post](https://github.com/dfdazac/wassdistance/) to compare performance.\n",
    "\n",
    "Dazac uses early stopping, but checking the stopping criterion (`err.item() < thresh`) introduces a synchronization point between GPU and CPU after each iteration. I modified the code to take the distance matrix as an argument."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"46 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
},
{
"data": {
"text/plain": [
"68"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAwYklEQVR4nO29a4yk2Xnf93vOOe+lqvo2PZe9UyTttWRZhiyFiSgrCATTDhxZMP1FhhwoYAIBzIckVhwHNpl8ygcDBGIY1ofAAGHFYGLBtkATJiEYtoVNZCBAQIgSGUviiiItksvdHc61e/pS9d7OefLhvNXdOzszO7Nz6556fkDj7arq6jrbO//3Oee5iqpiGMazj3vaCzAM48lgYjeMFcHEbhgrgondMFYEE7thrAgmdsNYER5K7CLyF0XkGyLyLRH51KNalGEYjx55v3F2EfHAHwJ/AXgT+C3gr6nq1x/d8gzDeFSEh3jvfwR8S1X/CEBE/inwceCuYi+LmdbVFpLGG0xS0BNfLL+H45tQfmwYxnvTcEinrdzptYcR+0vA9048fhP4idt/SEQ+CXwSoC43+eif+q+RboCoSNcjQ4R+gL5HY4K+G699FnyMaFLQlH+hZfwZxl35sr5219ceRux3unu8S4mq+lngswCb0xdV+ngsWO/yG5wgTpCYUCdIjCCSf67v83UpeiEL30RvGA/Ew4j9TeCVE49fBt5+z3cN6fh7EfCjj9B7EEFUQQRVRWJEk0dU0aSIS1nwy/ea4A3jvnkYsf8W8KqIfAh4C/h54D+/5ztSQtoOnAN3YmPgHSqSXxcBDYj3ECP4Pp/tl49jHM/66Xh7b6I3jPfkfYtdVQcR+W+Bfw144H9X1d+/95tA+gENHtRlqy6j6L2Akyx61XxGCD7v2mPKj8efVSIkB0QQB9h53jDei4ex7KjqvwT+5f2/IUHbIVrkrXrw2covt/MiEDyoHot+ubXvejQlZBjyDSNGJLp3OvFM9IZxVx5K7A9MUrTr8nbduSx677PoId8A/GjpC5+1GzwkRbzPjrshoK5HUkS7HkkJBUQUVEz0hnEXnqzYVXOYDVDnEeeykFWzV947kFH4XsCD4vMhAfK23g35fcMAzkOK0DlICR2G/Haz9IbxLp6o2FWV1LbZQjuXheo8UhZZ8OMZHedQGZ12xXhOr7KllxiRIcEQka6HmCB0EPNjvVO4zkRvGE/YskP2pMcx1t4LFKCD5LM5oG4MxSVBnMv6FBnP9aOll/zzCvlIMF5VXN7eA5JSfk9Mtr03DJ7CNl6HHjTvyzXG7HALAR235TLE8Zxe5m19GVAPWnj0ZLhOwQ0pi70frX3XIzEhTYum5eOYf/cYstMYjxNzxjUZxirw5C07HCXGiGgOo41WnYJ8FgcIo/UP4zk+5WSbHJ4DUUiFQ9J4A3DuOEynigzZwjPIcdguxuUC0LQs+LM4vbEaPHmx6/F2WiMQyRbXDcgwZAee9/l7748891oXqPf57O4d6gQNo2DrMAq8hKS4rkKGlHPw+xyqkzbn3Eu/vFruvbFaPBXLfiyoY9EL/sjKo5q38KpIP8bigx+ts0OdAmMCjjBaekGcIklJBMQnHOQbxbhzkCGimhBJOWY/RgFENN94lmsywRvPIE9H7EtOiP6klcfl1FmcG1Nm81legkfrEikDWngkBdQJqXSoQCpzYo5U+fe6vkAGRYaE64Z8tm9ypZ202YOvfT9e87leY3pnKu471mkYZ5enK/YlqqN5l3xRGS3ueB2r4ySFnDs/Ik5y8g3ks7zPlj6N2/sUBEnguoQGhwwJ7xw6pFxlN8QxAuCPIgFCn9fgxnO9Jiu6MZ4JTofYl4zneV06zBmv0SFJ0TGLTro+x+b7Mos4lqgTYh3QIKRAPtN7IRWQCo8kh0SIk4AkxTdltviLPsfs2y5v89su+wuWXvwxWQcwa2+caU6X2OFIRDp6znNpq+Qw3bLyrQg5ZDd+79JxaE
6TI/ls7VOQHLYTQPINQ8Ybh68cEpVQ+iz6eciOvOChH3IqrpOjIhxVRYhm7Y0zy+kT+5Kxrn0ZJstC0+PXxpZWMsRs9cNY9x4cEgtS4YiVy5a+EJJXVEALyZ2uvM/OvCKLP0wC0ifcosJ1Q47Rj510tG1zY42uy59xMmYPVmZrnAlOr9jhWEAaURUgZgvvJIu1Czk81xZICDldNnikr9HC42YFsXCgHq2EFPK2XsWhY9TORUUS+M7jhry9923EtRE/77NTb17AEKEJYy5+duplR+LSk2/be+N0c7rFfpKlpWfc2ovms7SO2/xlOWwMOcFmCNkzX3pEQdSRCmEgh+5SGMN2To6En0LeAaRCcKVDC4f0CV94GBKuCDkHv+1y0U3XH5famhffOOWcHbHDsdceUE0gDh2TZnCChDaf65sK8Q4/r3FFwE9K0qQgVR43C1n0E0fyEKt8rh8mY2ZeEkTB9R7fBdwAYV7hBqXYr5GY8PttTtZpuyz4vj8SPsNgabnGqeRsif0ky0w8GR1mKcfkdWnpk8818OOOwEnecnvvkJg99TI68JKCFoyW/rhrpjrBDQDZmQcFrk+oCK4bcMFnZ2Eb8rFiGI7KbRlk/MVWhGOcDs6u2OHI0h+f592R5x7IDjbvoChxZYErC/xBhRaeuFaRCke/li19PxViKaSSbPnLsV5HR2ufwLc5fBcWBb5XwmHCtwm/GPDzDukG3KKFIaJNA0unniXrGKeAsy32Jbdl4h2l3fR9vgGQt/2yLIcdAi44JPox8caRiqxsHZtm4DhxlgdJjG2ucygvDnnbr2E88zvBtT7P0xpi3lEsvfXj9l5jfHdq7jvWbxiPj2dD7EvuFaMXQYsCaVoIAb9oIXj8fvbcF2slqfQMM88wcQyVMEyylY91tvJ9lT9G1nP4znceGcC3ntCAb5XicIrrE8V+j+sT7qCBIeIWbU7N7fqjkltddu2xEJ7xBHi2xL7k9pz7Zb790toOQxbW8kwfPAFIpQfKfGhXQb0DlFRm650t+NjYViAVIBHSuP33Lah3uD7vLVyvFKosB2PIsqmmk9y4g7GmX5ffY8k6xmPj2RT7kjuJPo7tp5dTZ8amGa7tccHjD2uKOhAnBcMskEpHt+6IhdCv5VTcOIHkQUMWfKxzfb0Mgutl9OA7XA/lQYHrlXJ/husS4bBHFj3SduP5fkCXRTnjVn+5/T+2+CZ+4+F5tsW+5A7JOUdOsmU/vL6DkFNm/TzgJhW+rYiVR1JBLB2Iy0k5IZ/Xk4NUaE7J9ZobaaR8vg9TQYYc0vOdEivBt0pZOnzl8QuPcy6n6IqgY1EOR+22FPEce/JN8MZDshpiP8lyvNTSwg/DOGiC3DgjpiPRu67PHvy2IpUe3xSkUvCNJxXQrwmxEmIFscrneg15u9+vKaJ5FyBR6DYF10Fx6PBNQbGoKA6m+DYS9rucqXfY4IaILpp8ExrP98v0XDSZtTfeN6sndrhrGi4ch+t07pGyzM68wxrvPf5ggpaBcFiSKkfbePpJ3t5LzI68oVA0KKlKOSwwlu+5MWzn54JvhGLuKQ4cvgnUuyFv9XcLpI+4/SInCi0a6Ps8FGPpzQeL2xvvi9UU+0lu65qDynGJ7TJGDkjwuQa+8ZSqpNIjsSRMHKEV+kYYasF1ghbCMBXUK6nUHMYr9Mjqxzpb/H4q+BaGqcf1UM0cvlPKvRLpEv6wzc022u4oRVfbLg/GOJmpBxa7N94TE/uSZdhuGPI2fyCXxTqBts1594dz8B6/V+FDIKxN0KpgWCsZZoFh5mk3HEMN3Vautus2Uz7nTyIUCbyCU/oodIOD3uHmDtcJ5a1R9Due0CrVrYowj4TDHneQU3PdvMlOvaa5cy89s/jGXTCx34k7ldcuz/eq0I0TZxcOhkiAsVZegYBvc5FNLAUVRyqVQX0O4RWaRQ9ISNnRB2gpdLicnuuEvhNiIYTGUe57iknANRFf5i2+FAUyDGjTWrts474wsd+Ne5
3rG58Lb+YLxDtkt6SoSoqyoFqr0TLQb1bEytGc88RSaM85Yu3oN5RhmtA64aYDoewpNgdUBVUhJeFwUaCD4PZDPt/vBYrDQDhU6ls1vlHK3Q7XRfytRRb/onln0s7JDjvm1DMwsd8fdzjXE/OZHu+P03FjxAFaFYQxfVZ99thrcAyD5EEX6hh0bKhRCM4nnFO8T3g/pvyn/DNaOtTl8txYCjiHbwBKXJ8onSBtxPkcxqNp8/irfshjtEanHpjoVx0T+4NwWzouSd9ZYut9PtcXgbBXQ/AUNydo4Zls1MTK0W0F+mlO1Ok2SuJE6bYKtFT8Rof3idmkI/gEG7nNddMHui4wbwMHBwWucZS7AddBfaPAt1DvzgiLRNjvc1HOonvv871t81cKE/v74cQWf9kR96gNdkr5bB9jjtcnRYInqOKr4qhWXmLOvx0GIQVHqpRYBrSMpEoQUeowUPhIHQZiLczbksMiMrSB1pX4VhDNHn3wFJU7arTpg8t98/uxwccyqnBixLVGsH56q4OJ/WE52QZ7SO8osxWRHCrzDrdoIATcbkVZlUzWSoZZwTD1NFueWAnN+ZJUwsG5ir064TZ6ympgWrfMyp6t6YILa4cMybG4WNANnv2DCXFw7O8W+MZR7haEw4JyX6n21giLRLnb47qIuzXPs/Tmi3yeP5m0Y9b+mcfE/qi43aEXswdfYsqlsf2Qu+i0Ha4skKbCzStCXYxpubnkNlaAOuJE6By0SQghMit7Kj+wXc1xkj9rUMeN2Yx2CFybrNG1gcWkJBw4hqkQS0cxzxNxfJcoVXPNvWpO1GHsvd+LTcZZAUzsj4O7TLpRN7a/bj20be5kWwT84YRUBoqDilg6qluBWAnt9UA/DTQbFd/bWEOmA7ONhjIMXJjOKXxks1xACRtVw5Acu9sTDtuS+WHFfL/Azx3lTsjn+hsFvlXqm+v4NlHstbhFnwty2g7temhb9KhnviXtPEu8p9hF5BXg/wCeJ7ujP6uqvywi28A/Az4IfAf4q6q68/iWega5vWceZIeeE2h9Dt2FAAeH+BBw+1MoAuWtCanyNLdKhonQnHN0m4F+w3PQetx0oPCJWdlxsT5g5jtmoaV2PQdDRZsKrrZrXFussTufcGtnhjaeft3jm2z1w0KpJ45wWBIOCtxhizTd2KAjwliJlxdt03GeBe7Hsg/A31TV3xGRdeC3ReQ3gP8SeE1VPyMinwI+Bfztx7fUM85Ja3+yD34at8vLTrkh4GPCF7mPfawDYRHo9hz9GrR7BcMkcGWvgCpxdXuNSdlzaXbAWtEy8x0T37NZNMxCx35dcWO64LArubkxo2097XaBb4XqZkE4DFR7JeXelDCPFHtt3urvL3J7rcXiqL2WzcI727yn2FX1MnB5/H5fRF4HXgI+Dvz0+GOfA34TE/t7s+yblzhOyx2nz7BoclrufgneU9yscqLObEKalsRZQbtVMEwciwuBYQKLSwWHk8TuhQlr05aXN2/xwuQWF8t9XihvAeAk0aaCy90me0PNt/Yust9WXL++jh4GypuecjdQ7nvqm4GwUKqb1ejUK4+Lcro+35T6Pt8ArL3WmeKBzuwi8kHgx4AvA8+NNwJU9bKIXLrLez4JfBKgZvpQi33muFda7tEQitw3zw0RGRIkCBMPBIYqz6uOlaftp+xMK+ZNydW1NbYnc65N15mFlnNhDsDUdxQSGdY8h5OSwkf2m4qD6ZR+I9DtO7p1T1jAZM3hW6VaK3FtJOxlsbtFezQPz6z92UL0Pv+niMga8G+Bv6OqXxCRXVXdOvH6jqqeu9fv2JBt/Qn52MOs99lnOaV2WYQjDiny5Bspi1x2Wxbo2pRUBvpzOVmnOe/pp0JzXug3lH4rUm43rE0bXt2+zmax4AenV5i6li0/p5DIfqppUsEb7XmutBu8ebjFm7ubLA4r3NUSvxAmV/P5fnIjEeaR8laHm3fIvEUW7Xs79Uz0T5Qv62vs6U2502v3ZdlFpA
D+OfCrqvqF8ekrIvLCaNVfAK4+muWuOLdbe5eyU29MhgGytRfBdYHCC770qK/GApzcA08GT58m7KwF/r3ARt3gRZn5lg9UN5m6Fi+JqWu5VO5RuEhwufvt7mTCNVmnWwSQnJ+fQg7jxdpR7BeEeYnfL/MRZF7AMObnq+ZMvZg46plv1v5UcD/eeAF+BXhdVf/eiZe+BHwC+Mx4/eJjWeEqcjJmf/vZflmEc3CIeI/brfDBE65OoAisrdfEaUG3GWg3PP1axeH5mltT5Y8uXcJPIi9f3GGrWvDq+lVeKG9xIezxavV9mAHbsB8nfO8D2+z0M35v70X22po3r28RDwuK64FiL1DtltQ7U8IiUd3o8ny8W3NcP6DzxbjN73PmXj/YNv8UcD+W/aeA/wL4XRH52vjc/0QW+a+JyC8CbwA/91hWaNw7bq8K3o1dcgM+Ka4dkKHC9QWh9aDkZhquIE49l8MG+9OSOvQkFZwkZq5j5lrW3YJaemrXsxdqAHaHKUmFW9OaPZkRq4AGh3ohHApQ4ptEBdmTLzIW4rS5t59Ino1nPfOfKvd9Zn8U2Jn9EXGnc71347k+t9KSukLrEq1K4kZFrAPN+Zyss7iYG2w0FxNpFpmcX3Bp44DnZ3v8ybXvsxnmfLi8SimRWnoArsZ15qniDxfPc6Xd4DsH21ze3WCxXxGuloSFUF+F0CiTG5GwSBS7LW4+Nty4n6QdE/5D89BnduOUccdae4e43KtOvM8Wtcljr6StCWWBbyfEyuP6IrfQGhz9mqMZHG91gXlfMPE9F8qK58MtSj/noj+kkMRz/oCE8Hy4xY3pGq/XL/KH9SXeXtvkctikOSyAfL5HPcUiN/EMhQcvuaWXc2Mp8FjHK0vrnqxn/hPALPuzwu3W3ufBlksvPiEgkxotAmkjp+d250qGicsZehtCvwHt+YhOI+cu7rNRt/zIuctsFXP+RH2ZLT9nwzXU0tNowaGWXBs2+HZ7kWvdOl/feZ79tuTmlQ1k4amueYoDqHaUai9RHCbK3RbXDLi9eQ7bHc4taecRYpZ9FbjdqbcsxHGSW1d5lxNjvMcfzHMPvVtTUlVQb1V0G9mh1+x4+jXP7nyLnWmkjZ5z9QK/nXip3GGrepuLfsFUFqy7wFxvcKP+DjdTzddmP8D1fp0vr32Qm4spV2ZbtPuBYeYYdjzlnkM9+EVBMU7ClTRmD5L7/wl9LiJ0J1J0jUeCif1ZZRxpnbfHY895VURyRR7jiGlfFrimJuyVVNOSeqdgmDrmO55Ye65fu8SVifKtSxeZTVs+dO4GL093eaW+yYfLa6y7Bef9IYVE/kz9XZqq4IVyl1txwte3X+TKYp3v7W6xuztB9grq6wVhDpNrJb5VJtfXjpJ2pOnH2H2Xy2/vZu3N0r8vbBu/Stxlq3+0za9rdFKhVcmwVRNrz+JiPt/PnxeGqdI9N1Butry4fYs/fe5tXqlv8h9Ovs2WW/ChIlHgAUgkvjsou6nia80P8I3583xr/yLfvn6eZr+ivFzg58L0iuaknetD7qS71+SEnaY7zstvW+upd5/YNt7I3Ck9N42Ciemoo410PYUqvgxIglg53JBz8Zt5Qb8R+M5OzeXtDbbWFnx9+0W2i0N+dPYGG77hxbBDnWdZse46Xi2/z5af83y5x3OTfa4s1vn2+nmaRUG/WeEbodkuCXNlslNSHETCQY/fa3P77MNFtvAn22uZtX9gzLKvOnez9mFM0Z1OIHjS+gytPP1WTb+Wu+u054R+HZrnIjobeOmFHbYncz567tu8WO7ww9VbPO9bNp1nKiWt9uyngSux4Pe7F7nSb/L/3Pxj3GxmvPH9bfQgUF8NlLeyU6/ejYSDSLnT5Kaa+4e5Em8+ir/r3t1Tb8VFb5bduDv3tPZjMU4IOOfQLoyOtQKJ4AZHWAgkzzBzvJW2uTLZoB0C5+tDrmxs8kKxwweKmzznDygl4YBKIi+FHWauZX+z5sZ0BsDOfMJesUa/7u
nXHP0sUBx6JhOPbxPlTpGTdvYKGCIsxky9fshOPmu2cU/Mshvv5qS19/kMflSMU5XZ6k/y+T5NCvrNmjhxzC8EYg3z54Q4VfrnOibrLR/Y3uFPbFzlw5Nr/On6e2y5Ba+EniLfWuhV+e4wYTdN+Z35B3mzPcfru8/x1o1N+r2K8kogLM/3TT7f+0Uk3Fogiy479Zqxk27XH53vVzF8Z5bdeDBuT89dEmNutlGEfL4fIq4v81SrJqAixFpQ5xgmglKyWHi+k4RmKNhdn5DUcSHsAW8zlYEtp3gRtn3D1PUc1pdZz43xCS5xpV5jX9cY5h5Rh18IEAiNp/KCnxe50UcRcr29X07AdflqY7GOMMtu3B/3OtuXRb5OJ2gRiJsTUhVozxf0y6SdTeg2leFCTzHr+dClG5yr5/wHm9/lYtjPTjzXUkukENhPnt1U8dZwjtcXL/F2u8m/u/EiB03FweU1/KFnck0o9pV6Ryn3IsXhQNhtkG5A9ufv6JmvXbcS6blm2Y2H53Zrn0bHWIz53Ow9EhPiHT7GbG2HKcUk4LtAWOTzfTMUDGuBb8s216YzJr7nUrXP1LXEcIuX/AFTEWqf2HZztlzLzLV8v9wC4EY74+vR0cxLmlTlaTtOSB5SkQfn+WbAxXTcM38M2b2rp96KWXuz7Mb7R+TdxTjO5bi998gkl92m2QSdFAxrufS2nzoW5x1xAotLiTRRqktz1qcNf3zrOq9Md/hAdYMPlteYSceWWxCR3GxDC/6gfYFbw5T/79ZLXF+scfnGJsOtkuKWp7ouFIfK5HoiNEp1oz3umd/1aNMez8MbQ3jPklPPLLvxeLi9n17MxS9H47C6Pg++nC9wZYmbVIS9mjgpKPdL+qnDL7Lo5+2Ma9MJTR+4sT7jcKPCk3i+uMVz/oBaEq+EnqjKi+EW+6nkQrHP1X6Dr9av8ObaJruzGSmUDAcOcIQFSMrlt0VUXBvy3DsRcIKOLb/yf8tt/olnELPsxqPjxLkeyJ58J0gYHWhFiVS5rVYap9122zWxdCzO+5y0c14Y1nNbrWp7wcYsW/vtcs4PTr/Pul9wMexREOnxRHW83Z/j+rDOdxfn+aP98+zMJ+xcW0cWnvqKxzcwuZYz9eqdgXA44A9a5LBB2v7dmXrLhJ0zmKlnlt14MpycgcdtvfJbAd8cV+LtlbgQmNzMTr36+pRYB9rtgm7maLYD3bl1bmzOuPncjOm05ebFKReqA3587bts+wNeCrusu55Xi+skhLemG3xv4zxvdtt8dfMVbjZT3lzbRheeYZLDd0NdUO17ytpTFB5pBpyTHLeH0ZPfoSrjiPvRk3+GBH83TOzG4+NOTr3RsUdKLOdTi3P4lHvlu3ZCVXvKg4Jux9GtOdqbU5rphP/3+gw/iXztwstsVA1/fP0aF4oDLoR9tsMBSR0z1/JCuQtbsDNMWStb9tqa729s0swD7flAse8pb3mq3YJirtQ3J7g2z7qnH3CHC3SIudHGMzQLz8RuPH7eUX47DsHss4dcFk3e6u/tI94TbtRIUVBOKnRaEWcl3VaZK/EuFMRJyeXnat6cJr753EU2Zg1/7Nx1Pjy9zgvlLT5cXWHbH/CnqjfpNXB1fZ39OOF3L77MzW7K6zee59b+hMXNmvKGo9h3TK7VuRjnWoFvI2GnyP3+Dj30fXbsLdt7n+G4vYndePLosVg0gqjkG4AqMratEtUcyhsSkpQwD6AFsRJIjlh7mnbG9bWaWwc139nY5tLsgA/OLrERGl4odykkb82TCudCbp893yy5Uc24UqyzqGr6g0CsclgwVrnstp54fBMJe9VR2S1Ne9QrX06W3Z6hc72J3Xg63NZsYyka6br8/JiwI97jq4oQPOVanoW3vjUhlZ7m/Dgd5/wae+trXN+8yO9eeIly2vPBCzdZLxt+aP0K677h5fImtev5ocllevVcvrTF280WlxcbfHfnHPN5xeGVirAQJleqsey2OuqV7/fz8Et3MM9Wfj
kL7/bOuadY9CZ243QwFuRojDlmT8x98lVHR1lA5g6CJ4iQykDlhDD2uvOd4AahTSX9LPBt3WZS9wzJs1a0fGBykzXfUriBQiIOZauY0ybP4VrJXhi42Tm6qUeGsZcejuJQUC8UhccvCpxzObToF7n4Rtq89n449X3yLfRmnE7u1kHXOaSq8jn/tqSdfr2k2wwMEzlK2mkuKHGSCBca6knHixt7XKgP2S4POVfMjz5uHkuutOsc9BXf2d1m0ZY01yf4Q0d1w1HuQbmnTG4O+EWi2GlyW629wzEtt31Hos7T2uJb6M04e9zeQXdsrbWM4R/dDFqPSwntSso+4vqKYZb/WQ91TqWNlad1NfuTgreBZihoJoGkjonvmPqOyg1sFQsqFzmcVRyUPVd6TywCXSxQl3+XJE9ROVwsSV0gpAT9cf9+BcSNacRwqpx5Jnbj9LPM1Dtqmx1zI81hQETQpsmZekVJWVcUVUE9G5N2zlVHSTuxdjTb57g82+KNzQQbPeWk58LGIZOi52J9QHCRF2d5+u3FySGLoeDqxTUOFyWLvYrD3UCYO+rrE3yjTK9X+EYpbza4dhx13XY5QWf04r9rIs5TEr2J3Tg73G7tk2ZLGlO2uv2Q++UvAq7toQhU3YCWAddXxInD9Z5+JvjW03VCOwvcEKjLnsoP1L7nXLmgcBFXHZIqwbvErapmJyQWriZWAYkO3wiu94RKcX2JLzxEzT3yRUYrnyfx5iSd+FQLcEzsxtlktPZ5tNS4vR8nzYjIWNvucWP77MneBK0C9bWaYeLpNwLteu6G022tM58oX9/eRMtEtdlQlgMbdUsVBoIk1ssWv64s6o5FVzDfrNHW0W4HfONY3KjwC2WyUxEOI8X+2EOv7XDzJp/nF83xNJynkKRjYjfONrdb+2W//GHITr22zWG8JqfqFvsTQllQrtdUayXDzNPs+ezUm+f22W0vtFUibjmmVc9G3TALHaVfsFUvWAwFu2VP0wcOwpS+c6j3ubGGcxR19uCXIrjG41SR3ufjxzgNR+HY0j+h3vgmduPZ4fb0XDnRSy8lcNn6Swj4tsMdVoS6pDgoGWpPuZdn4TU7gVhBd67gZq1cXxvw04GiiEyqDhHNrflcolpviYOndQXSOYapI8wd5Zaj2vMU84pqp8I1kXBrktNxD+bokLvl3nHK7WOy8iZ249njXum5TZvDdmXupefKAn+zoigLqrUJqQ5MzpXE2rHYdgwTR7dV0q8VNOuRbiMQisjatKEIkbrsURUOJyXD4GmmFX3j6dcd/Z6j2IehEopFovKC6yJecvkvMMbqR2fjyTP9YxC8id149tGT4S+ByNimKmXLmjT3p1fFNQGJiVR6XF8y1Dlxp18T+r1Av+cZauXmeomERKgGnNMxUgBSJJSxiM45UiGod/QLIfkK3ypVcLmhRlkgbZdTcdv23V1yH7GVN7Ebq8EoGj3Rokr7IXvLFzl0x8Eh4h1hp4YQKK4te+rVDLNAv+ZpNxz91NGeK0mV0m0VaKnIdMCHRCgHXK0ME8/QO/pFoNvw+Ba69VxbP506fJuo6oBf9Lj9BlmU0HXI6MSjI2fkPcLEHBO7sZosZ+Ell4dIRvJQyeRBOhi9+lIEcCB9gQwlkjy+dYgKscpJNrFUhl4YCoUyISFBkvwRoqQqAY5hChqg7XLozw0FqfS5F3/wyMJnx91tPfMeVWKOid1YXW5vqzWQPfjjeVrmi6NGGyEEQl1RTSpSXRI3SmLlac4FYim05zyxhm5dSRXESYIygYBOI7FKNLVDeqGfOXwL/VrAN1DvBor9iuKgxu9VSNNByN1xWTTIyT74D3GeN7EbBrxzMk7MyTAaYxYaQMz96kQVF1O29l2g9EKqBPWe2JJ75/cgyTEkAa9oGMXpQD2kMgfehomgAn0niHokgQwJJ4Kc6H2vyzl8D+nAu2+xi4gHvgK8pao/KyLbwD8DPgh8B/irqrrzwCswjNPC3WbcD0O+LhZQ5B75fq
/MM+6XbbU2KlLlcyFO7ejWswWPNfQzRUMWuQrESvP3QXC9MEwd7cJRbniqPY9fJKqbJdL2uFsFMkT08PDYgfc+Z9s9iGX/JeB1YGN8/CngNVX9jIh8anz8tx/g9xnG6eYOM+4ZB0miua2WABLyiCxXjdfOjyYc3CCoCKnIgQBxkEIWffKgokgliGYduz4X+sQ64ASkKRHpc7POcU15Oz867kTuW/D3JXYReRn4S8DfAf6H8emPAz89fv854DcxsRvPGnfIxz/y5HuHju2yXZMz9dz+hLIIVOsVw7QgThzthieWQrchaIB+mq/qFXWgBfQeUhBiKfSNEEuHbwuq0uPbiCtC9iUcjv3vu+6Bz/H3a9n/PvC3gPUTzz2nqpfz30Mvi8ilO71RRD4JfBKgZnqfH2cYp5CT1XeiY/rrWIQzTsORGKEo8EPENSVxWiCxJNYC+NHCZ0sfK4GgqFPUgyQhTgARXA/JuzxDLziKtoQxGUePdhfjOT5yXxb+PcUuIj8LXFXV3xaRn37wv49+Fvgs5OYVD/p+wzh13KGHnnbd2GlnFH3fI/MCOSzxhxWpCoSDklQ62kNPKoRuTUiFkEpIoxLVQSyhnwmxBEkB3yswxTcD3rvcEy+EHJMfhtz6OulRC++7cT+W/aeAvywiPwPUwIaI/GPgioi8MFr1F4Cr7/dvZxhnjrsV4Jwot5UiQFPimg43tsnWMuC6klQ53JC39/1UiPV4zA+Ah1jnbb3EseVWF/LjPuIYy3qHITfI6IecK6AC9zCn7yl2Vf008GmA0bL/j6r6CyLyvwKfAD4zXr/4EH86wzi73B62Uznqn8fYokr6gFNFi0CpSio8MhSk0uFbx1BnCx+rnHarbvQJhuyKG6YODeD7EpzLgldFXe52e9zi+u48TJz9M8CvicgvAm8AP/cQv8swzjZ3s/Qub/HF+5wH7x2+7fDe4+d5tHWYlzm9duropxyf5xkn03roI8TCIUNBCI5ijPeLSO6BNww58+4e1bIPJHZV/U2y1x1VvQFY90jDuJ2Tlj45RBQlj5US71Hfjm2yHX5IuQovKpI8qDsaPY3krT3ksJ16IZVCTA5fBaQu8md1Xd7Oew/DHXtNApZBZxiPh3dY+hOddJZJOt4jbRZ9mNf4uiRMS4ppSZx4wponBaGfCOqAMS4/VELyDkkFAN45XD/krrtdh3QmdsN4eryr6Can4eJdnoIzJuc4l1tlo5qdcUUWtrps1XWpY8k3gFQ4XOHR4JGYB2rcCxO7YTwJTsbol40yVfOZux/G1lktfl7lhhpNTSo9fr1AgzBMHMlL3to7iJUDHX9HV+ebRFPBoVl2wzgdnGykMW7rGYZxqoxDnMspuE323vvSk6LgihxWe4eFd+R+9sEhweeafEzshnF6WDbSiPGdwy9izF1xyyKfw4uA9BEtPG6ocjpt7dFAFr4TUuFI9Xh+r/OknLthYjeMp8UdzvLAkWddVJEmIFFxZUDUocGRON7OLyfVsDzvm2U3jFPKycq6pOPXWNgSQpZu8ARAC48MSipdtuhekJRLZVNwuKo4Hot1B0zshvG0ub0FtqZxTlyOwTOOrmYIuOCOgu8qLjv7RMDlBhr3MOwmdsM4VYyJOIzdaejGSjrnkFjgvEOGkHPgl/eIpcC9mGU3jDPBbRb+qJrO+/w1RMQ5GCLLiLr67I0H0OCOPfV3wMRuGKeNk467MdVW+gG8Qtfn8dDB43zuSZ+O3nNvTOyGcRo54bg7svDO5TP8EBARHNlphxs76NxjCw8mdsM4vZyw8KjmFNuYx1kxROjH3vbD/Q2GNLEbxmlmKfh+yBa8G5134nI9uyoyOuaWZ/e7YWI3jDOCptG6i0DKDSskJoi5Wy331rqJ3TBOPSd73i3z6LusbAHwLsfh3fGU6jvxHvcCwzBOBSe97cssuxjRlLJ1T+meXWrALLthnB1Uj4tnxo40Ii5n2I3huHuF4EzshnEWSScm04wZdsR7m3YTu2GcJZ
bnd5XcCKMfwPe5Br4I97TsdmY3jLPGct7bOBXmaNJrSvfsG29iN4yziCY05TN8Fnoct/F2ZjeMZ5OlhY/L9lZ3/1Gz7IZxVtE8bYaUsnXXk3Wv78bEbhhnkWUfu6RZ6DGOCTd3f4uJ3TDOKrd73s2yG8YzzLiV13G4o1l2w3iGOZremsyyG8azjabceNIsu2E8wxwNkEz3zJ4DE7thPBPoyX7zd8HEbhjPCKp6r128id0wzjx60kF3dyxd1jCeBTRXwt0Ls+yG8UzxkGd2EdkSkc+LyB+IyOsi8pMisi0ivyEi3xyv5x7Zeg3DeGA06SMJvf0y8K9U9YeAHwVeBz4FvKaqrwKvjY8Nw3ga3MdEmPcUu4hsAP8J8Cv5d2qnqrvAx4HPjT/2OeCvvM9lGobxKNB7t6W6H8v+YeAa8I9E5Ksi8g9FZAY8p6qXAcbrpTu9WUQ+KSJfEZGv9LQPtnjDMB4Z9yP2APw48A9U9ceAQx5gy66qn1XVj6jqRwqq97lMwzAelvsR+5vAm6r65fHx58nivyIiLwCM16uPZ4mGYTwK3lPsqvp94Hsi8oPjUx8Dvg58CfjE+NwngC8+lhUahnF/vIeT7n6Tav474FdFpAT+CPivyDeKXxORXwTeAH7uIZZpGMZj5r7ErqpfAz5yh5c+9khXYxjGY8My6AxjRTCxG8aKYGI3jBXBxG4YK4KJ3TBWBBO7YawIJnbDWBFM7IaxIpjYDWNFMLEbxopgYjeMFcHEbhgrgondMFYEE7thrAgmdsNYEUzshrEimNgNY0UwsRvGimBiN4wVwcRuGCuCid0wVgQTu2GsCCZ2w1gRTOyGsSKY2A1jRTCxG8aKYGI3jBXBxG4YK4KJ3TBWBBO7YawIJnbDWBFM7IaxIpjYDWNFMLEbxopwX2IXkb8hIr8vIr8nIv9ERGoR2RaR3xCRb47Xc497sYZhvH/eU+wi8hLw14GPqOqPAB74eeBTwGuq+irw2vjYMIxTyv1u4wMwEZEATIG3gY8Dnxtf/xzwVx756gzDeGS8p9hV9S3g7wJvAJeBW6r6b4DnVPXy+DOXgUt3er+IfFJEviIiX+lpH93KDcN4IO5nG3+ObMU/BLwIzETkF+73A1T1s6r6EVX9SEH1/ldqGMZDcT/b+D8PfFtVr6lqD3wB+LPAFRF5AWC8Xn18yzQM42G5H7G/AXxURKYiIsDHgNeBLwGfGH/mE8AXH88SDcN4FIT3+gFV/bKIfB74HWAAvgp8FlgDfk1EfpF8Q/i5x7lQwzAeDlHVJ/ZhG7KtPyEfe2KfZxirxpf1Nfb0ptzpNcugM4wVwcRuGCuCid0wVgQTu2GsCCZ2w1gRTOyGsSKY2A1jRTCxG8aKYGI3jBXBxG4YK4KJ3TBWBBO7YawIJnbDWBFM7IaxIpjYDWNFMLEbxopgYjeMFcHEbhgrgondMFYEE7thrAgmdsNYEUzshrEimNgNY0UwsRvGimBiN4wVwcRuGCuCid0wVgQTu2GsCCZ2w1gRTOyGsSKY2A1jRTCxG8aKYGI3jBXBxG4YK4KJ3TBWBBO7YawIJnbDWBFEVZ/ch4lcAw6B60/sQx+eC5yd9Z6ltcLZWu9ZWesPqOrFO73wRMUOICJfUdWPPNEPfQjO0nrP0lrhbK33LK31btg23jBWBBO7YawIT0Psn30Kn/kwnKX1nqW1wtla71la6x154md2wzCeDraNN4wVwcRuGCvCExO7iPxFEfmGiHxLRD71pD73fhGRV0Tk/xaR10Xk90Xkl8bnt0XkN0Tkm+P13NNe6xIR8SLyVRH59fHxaV7rloh8XkT+YPwb/+RpXa+I/I3x38Dvicg/EZH6tK71QXgiYhcRD/xvwH8G/DDw10Tkh5/EZz8AA/A3VfVPAh8F/ptxjZ8CXlPVV4HXxsenhV8CXj/x+DSv9ZeBf6WqPwT8KHndp269IvIS8NeBj6jqjwAe+HlO4VofGFV97F/ATwL/+sTjTwOffhKf/RBr/i
LwF4BvAC+Mz70AfONpr21cy8vkf3R/Dvj18bnTutYN4NuMDuETz5+69QIvAd8DtoEA/Drwn57GtT7o15Paxi//gEveHJ87lYjIB4EfA74MPKeqlwHG66WnuLST/H3gbwHpxHOnda0fBq4B/2g8dvxDEZlxCterqm8Bfxd4A7gM3FLVf8MpXOuD8qTELnd47lTG/ERkDfjnwH+vqntPez13QkR+Friqqr/9tNdynwTgx4F/oKo/Rq6POJXb4PEs/nHgQ8CLwExEfuHprurR8KTE/ibwyonHLwNvP6HPvm9EpCAL/VdV9Qvj01dE5IXx9ReAq09rfSf4KeAvi8h3gH8K/DkR+ceczrVC/v//pqp+eXz8ebL4T+N6/zzwbVW9pqo98AXgz3I61/pAPCmx/xbwqoh8SERKssPjS0/os+8LERHgV4DXVfXvnXjpS8Anxu8/QT7LP1VU9dOq+rKqfpD8t/y/VPUXOIVrBVDV7wPfE5EfHJ/6GPB1Tud63wA+KiLT8d/Ex8jOxNO41gfjCTo+fgb4Q+DfA//z03ZW3GF9/zH5aPHvgK+NXz8DnCc7wr45Xref9lpvW/dPc+ygO7VrBf4M8JXx7/svgHOndb3A/wL8AfB7wP8JVKd1rQ/yZemyhrEiWAadYawIJnbDWBFM7IaxIpjYDWNFMLEbxopgYjeMFcHEbhgrwv8P4h0ik6MxsFcAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Copyright 2018 Daniel Dazac\n",
"# MIT Licensed\n",
"# License and source: https://github.com/dfdazac/wassdistance/\n",
"class SinkhornDistance(torch.nn.Module):\n",
" r\"\"\"\n",
" Given two empirical measures each with :math:`P_1` locations\n",
" :math:`x\\in\\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\\in\\mathbb{R}^{D_2}`,\n",
" outputs an approximation of the regularized OT cost for point clouds.\n",
" Args:\n",
" eps (float): regularization coefficient\n",
" max_iter (int): maximum number of Sinkhorn iterations\n",
" reduction (string, optional): Specifies the reduction to apply to the output:\n",
" 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,\n",
" 'mean': the sum of the output will be divided by the number of\n",
" elements in the output, 'sum': the output will be summed. Default: 'none'\n",
" Shape:\n",
" - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`\n",
" - Output: :math:`(N)` or :math:`()`, depending on `reduction`\n",
" \"\"\"\n",
" def __init__(self, eps, max_iter, reduction='none'):\n",
" super(SinkhornDistance, self).__init__()\n",
" self.eps = eps\n",
" self.max_iter = max_iter\n",
" self.reduction = reduction\n",
"\n",
" def forward(self, mu, nu, C):\n",
" u = torch.zeros_like(mu)\n",
" v = torch.zeros_like(nu)\n",
" # To check if algorithm terminates because of threshold\n",
" # or max iterations reached\n",
" actual_nits = 0\n",
" # Stopping criterion\n",
" thresh = 1e-1\n",
"\n",
" # Sinkhorn iterations\n",
" for i in range(self.max_iter):\n",
" u1 = u # useful to check the update\n",
" u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u\n",
" v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v\n",
" err = (u - u1).abs().sum(-1).mean()\n",
"\n",
" actual_nits += 1\n",
" if err.item() < thresh:\n",
" break\n",
"\n",
" U, V = u, v\n",
" # Transport plan pi = diag(a)*K*diag(b)\n",
" pi = torch.exp(self.M(C, U, V))\n",
" # Sinkhorn distance\n",
" cost = torch.sum(pi * C, dim=(-2, -1))\n",
" self.actual_nits = actual_nits\n",
" if self.reduction == 'mean':\n",
" cost = cost.mean()\n",
" elif self.reduction == 'sum':\n",
" cost = cost.sum()\n",
"\n",
" return cost, pi, C\n",
"\n",
" def M(self, C, u, v):\n",
" \"Modified cost for logarithmic updates\"\n",
" \"$M_{ij} = (-c_{ij} + u_i + v_j) / \\epsilon$\"\n",
" return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps\n",
"\n",
" @staticmethod\n",
" def ave(u, u1, tau):\n",
" \"Barycenter subroutine, used by kinetic acceleration through extrapolation.\"\n",
" return tau * u + (1 - tau) * u1\n",
"\n",
"n = 100\n",
"x = torch.linspace(0, 100, n)\n",
"mu1 = torch.distributions.Normal(20., 10.).log_prob(x).exp()\n",
"mu2 = torch.distributions.Normal(60., 30.).log_prob(x).exp()\n",
"mu1 /= mu1.sum()\n",
"mu2 /= mu2.sum()\n",
"mu1, mu2, cost = mu1.cuda(), mu2.cuda(), cost.cuda()\n",
"sinkhorn = SinkhornDistance(eps=1e-3, max_iter=200)\n",
"def x():\n",
" mu1_ = mu1.detach().requires_grad_()\n",
" dist, P, C = sinkhorn(mu1_, mu2, cost)\n",
" gr, = torch.autograd.grad(dist, mu1_)\n",
" torch.cuda.synchronize()\n",
"\n",
"dist, P, C = sinkhorn(mu1.cuda(), mu2.cuda(), cost.cuda())\n",
"torch.cuda.synchronize()\n",
"x()\n",
"%timeit x()\n",
"pyplot.imshow(P.cpu())\n",
"sinkhorn.actual_nits\n"
]
},
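  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the log-domain updates performed in the `forward` loop above can be written out as follows. This is transcribed from the code (the `1e-8` added inside the logarithms is omitted for readability), with $\\epsilon$ denoting `eps` and $C$ the cost matrix; the $v$ update uses the already-updated $u$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "M_{ij}(u, v) &= \\frac{-C_{ij} + u_i + v_j}{\\epsilon}, \\\\\\\\\n",
    "u_i &\\leftarrow u_i + \\epsilon \\Big( \\log \\mu_i - \\log \\sum_j \\exp M_{ij}(u, v) \\Big), \\\\\\\\\n",
    "v_j &\\leftarrow v_j + \\epsilon \\Big( \\log \\nu_j - \\log \\sum_i \\exp M_{ij}(u, v) \\Big), \\\\\\\\\n",
    "\\pi_{ij} &= \\exp M_{ij}(u, v), \\qquad \\text{cost} = \\sum_{i,j} \\pi_{ij} C_{ij}.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The loop stops early once $\\sum_i |u_i^{\\text{new}} - u_i^{\\text{old}}|$ falls below `thresh = 1e-1`; with the setup above it stopped after 68 of the allowed 200 iterations."
   ]
  },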
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.08 ms ± 2.68 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"def y():\n",
" mu1_ = mu1.detach().requires_grad_()\n",
" l = SinkhornOT.apply(mu1_.unsqueeze(0), mu2.unsqueeze(0), cost, 1e-3, 200)\n",
" gr, = torch.autograd.grad(l.sum(), mu1_)\n",
" torch.cuda.synchronize()\n",
"y()\n",
"%timeit y()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this problem size and forward + backward, we achieve a speedup factor of approximately 6.5 when doing about 3 times as many iterations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Barycenters\n",
"\n",
"We can also do barycenters. Let's go 2d to do so. I use relative small $N$ because at the time of writing, my GPU is partially occupied by a long-running training."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0143], device='cuda:0')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD6CAYAAABnLjEDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAALX0lEQVR4nO3dX4ild33H8fen20gsUsw2k7BkQ8eLpRikJjCkKelFSVzYpuLmRkjAsheBvbEQQZBNCwXvciXe9GapwQVFCSjsEgRZVkMpSMxook26xk1LqotLZmIR2xtp9NuLebTDzuyes3P+zDnzfb/g8JznOTP7fGbYz/zO7zfPOZOqQtLB93v7HUDSfFh2qQnLLjVh2aUmLLvUhGWXmpio7ElOJHkjyZtJzkwrlKTpy15/z57kEPBj4DhwFXgZeLKq/u1Gn3PnnXfW6urqns4nabS33nqLd955J7s99vsT/LsPAm9W1X8AJPkqcBK4YdlXV1dZX1+f4JSSbmZtbe2Gj03yNP4e4Kfb9q8OxyQtoEnKvttThR1zgiSnk6wnWd/c3JzgdJImMUnZrwL3bts/Cvzs+g+qqrNVtVZVaysrKxOcTtIkJin7y8CxJB9I8h7gCeDCdGJJmrY9L9BV1btJ/hb4JnAIeK6qXp9aMklTNclqPFX1DeAbU8oiaYa8gk5qwrJLTVh2qQnLLjVh2aUmLLvUhGWXmrDsUhOWXWrCsktNWHapCcsuNWHZpSYsu9TERC9x1WjJrm/0eVP+ZV3NgiO71IRll5qw7FITztlvwV7m3/M6j/N8jeLILjVh2aUmLLvUhGWXmnCBbjCvxbdZ2S2/i3bazpFdasKyS01YdqmJtnP2ZZ+jj+P6r9E5fG+O7FITll1qwrJLTbSZs3eYo4/iHL43R3apCcsuNWHZpSZGlj3Jc0k2kry27djhJBeTXBm2d8w2pqRJjTOyfxE4cd2xM8ClqjoGXBr2tWSS7Ljp4BpZ9qr6Z+C/rjt8Ejg33D8HPD7dWJKmba9z9rur6hrAsL3rRh+Y5HSS9STrm5ubezydpEnNfIGuqs5W1VpVra2srMz6dJJuYK9lfzvJEYBhuzG9SJNzLirttNeyXwBODfdPAeenE0fSrIzzq7evAN8B/iTJ1SRPAc8Cx5NcAY4P+5IW2Mhr46vqyRs89OiUs0iaoTYvhNF4fLHMweXlslITll1qwrJLTVh2qQnLLjVh2aUmLLvUhGWXmjgQF9X4QhdpNEd2qQnLLjVh2aUmLLvUhGWXmrDsUhOWXWrCsktNWHapCcsuNWHZpSYsu9TEgXghzPXvgOoLY6SdHNmlJiy71IRll5o4EHN2TY9/AebgcmSXmrDsUhOWXWrCsktNWHapCcsuNWHZpSZGlj3JvUm+neRykteTPD0cP5zkYpIrw/aO2ceVtFfjjOzvAp+uqg8CDwGfTHIfcAa4VFXHgEvD/kKoqh037eT3qZeRZa+qa1X1/eH+fwOXgXuAk8C54cPOAY/PKKOkKbilOXuSVeAB4CXg7qq6Bls/EIC7pp5O0tSMXfYk7wO+Bnyqqn55C593Osl6kvXNzc29ZJQ0BWOVPcltbBX9y1X19eHw20mODI8fATZ2+9yqOltVa1W1trKyMo3MkvZgnNX4AF8ALlfV57Y9dAE4Ndw/BZyffrzpcSFK3Y3zEteHgb8B/jXJq8OxvwOeBZ5P8hTwE+DjM0koaSpGlr2q/gW40Zu6PTrdOJJmxSvopCbavlNNx3ekda2iN0d2qQnLLjVh2aUm2s7Zr7fbfHaZ5vHOxzWKI7vUhGWXmrDsUhOWXWrCBbqbGLXoNa8FPBffNA2O7FITll1qwrJLTThnn4Bz6dlyTWS6HNmlJiy71IRll5pwzq59sUgvMhqV5aDM6R3ZpSYsu9SEZZeasOxSEy7QaeoWafFtGnb7epZx0c6RXWrCsktNWHapCefsmthBm6OP4/qveRnm8I7sUhOWXWrCsktNOG
eXpmAZ5vCO7FITll1qwrJLTYwse5Lbk3w3yQ+SvJ7ks8Pxw0kuJrkybO+YfVxJezXOyP4r4JGq+jBwP3AiyUPAGeBSVR0DLg37OuCS7Lhpp0X8Po0se235n2H3tuFWwEng3HD8HPD4LAJKmo6x5uxJDiV5FdgALlbVS8DdVXUNYNjedYPPPZ1kPcn65ubmlGJLulVjlb2qfl1V9wNHgQeTfGjcE1TV2apaq6q1lZWVPcaUNKlbWo2vql8ALwIngLeTHAEYthvTDidpesZZjV9J8v7h/nuBjwA/Ai4Ap4YPOwWcn1FGSVMwzuWyR4BzSQ6x9cPh+ap6Icl3gOeTPAX8BPj4DHNKmtDIslfVD4EHdjn+c+DRWYSSNH1eQSc14avedFOLcDGIpsORXWrCsktNWHapCefs0pzs97vZOLJLTVh2qQnLLjVh2aUmLLvUhGWXmrDsUhOWXWrCsktNWHapCcsuNWHZpSZ8IYw0J/v9Z5wd2aUmLLvUhGWXmrDsUhOWXWrCsktNWHapCcsuNeFFNbqp6y8E8S/ELC9HdqkJyy41YdmlJpyzSzOw3y962Y0ju9SEZZeaGLvsSQ4leSXJC8P+4SQXk1wZtnfMLqakSd3KyP40cHnb/hngUlUdAy4N+zrgqmrHTcthrLInOQr8NfBP2w6fBM4N988Bj081maSpGndk/zzwGeA3247dXVXXAIbtXbt9YpLTSdaTrG9ubk6SVdIERpY9yUeBjar63l5OUFVnq2qtqtZWVlb28k9ImoJxfs/+MPCxJI8BtwN/mORLwNtJjlTVtSRHgI1ZBpU0mZEje1U9U1VHq2oVeAL4VlV9ArgAnBo+7BRwfmYptdBcsFuO78Ekv2d/Fjie5ApwfNiXtKBu6XLZqnoReHG4/3Pg0elHkjQLXkEnNeELYTR1Hd7wYlHn5TfjyC41YdmlJiy71IRzds3cbvPbRZ7HL+N8fByO7FITll1qwrJLTVh2qQkX6LQv9rIItpdFvYO62LYXjuxSE5ZdasKyS004Z9fScP49GUd2qQnLLjVh2aUmLLvUhGWXmrDsUhOWXWrCsktNWHapCcsuNWHZpSYsu9SEZZeasOxSE5ZdasKyS01YdqkJyy41YdmlJiy71IRll5rIPN+xM8km8J/AncA7czvx5JYp7zJlheXKuwxZ/7iqVnZ7YK5l/91Jk/WqWpv7ifdomfIuU1ZYrrzLlHU3Po2XmrDsUhP7Vfaz+3TevVqmvMuUFZYr7zJl3WFf5uyS5s+n8VITcy97khNJ3kjyZpIz8z7/zSR5LslGkte2HTuc5GKSK8P2jv3M+FtJ7k3y7SSXk7ye5Onh+KLmvT3Jd5P8YMj72eH4QuYFSHIoyStJXhj2FzbrOOZa9iSHgH8E/gq4D3gyyX3zzDDCF4ET1x07A1yqqmPApWF/EbwLfLqqPgg8BHxy+F4uat5fAY9U1YeB+4ETSR5icfMCPA1c3ra/yFlHq6q53YA/B765bf8Z4Jl5Zhgj4yrw2rb9N4Ajw/0jwBv7nfEGuc8Dx5chL/AHwPeBP1vUvMBRtgr9CPDCMv1fuNFt3k/j7wF+um3/6nBskd1dVdcAhu1d+5xnhySrwAPASyxw3uFp8avABnCxqhY57+eBzwC/2XZsUbOOZd5lzy7H/HXABJK8D/ga8Kmq+uV+57mZqvp1Vd3P1qj5YJIP7XOkXSX5KLBRVd/b7yzTNO+yXwXu3bZ/FPjZnDPcqreTHAEYthv7nOd3ktzGVtG/XFVfHw4vbN7fqqpfAC+ytT6yiHkfBj6W5C3gq8AjSb7EYmYd27zL/jJwLMkHkrwHeAK4MOcMt+oCcGq4f4qtufG+SxLgC8DlqvrctocWNe9KkvcP998LfAT4EQuYt6qeqaqjVbXK1v/Rb1XVJ1jArLdkHxY+HgN+DPw78Pf7vWhxXbavANeA/2XrWchTwB+xtVBzZdge3u+cQ9a/YGsK9EPg1eH22ALn/VPglSHva8A/DMcXMu
+23H/J/y/QLXTWUTevoJOa8Ao6qQnLLjVh2aUmLLvUhGWXmrDsUhOWXWrCsktN/B+xRi0uAJLuNgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"N = 50\n",
"a, b, c = torch.zeros(3, N, N, device=\"cuda\")\n",
"x = torch.linspace(-5, 5, N, device=\"cuda\")\n",
"a[N//5:-N//5, N//5:-N//5] = 1\n",
"b[(x[None]**2+x[:,None]**2 > 4) & (x[None]**2+x[:,None]**2 < 9)] = 1\n",
"c[((x[None]-2)**2+(x[:,None]-2)**2 < 4) | ((x[None]+2)**2+(x[:,None]+2)**2 < 4)] = 1\n",
"pyplot.imshow(c.cpu(), cmap=pyplot.cm.gray_r)\n",
"coords = torch.stack([x[None, :].expand(N, N), x[:, None].expand(N, N)], 2).view(-1, 2)\n",
"dist = ((coords[None]-coords[:, None])**2).sum(-1)\n",
"dist /= dist.max()\n",
"a = (a / a.sum()).view(1, -1)\n",
"b = (c / b.sum()).view(1, -1)\n",
"c = (c / c.sum()).view(1, -1)\n",
"SinkhornOT.apply(a, b, dist, 1e-3, 200)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def get_barycenter(mu, dist, weights, lam=1e-3, N=1000):\n",
" assert mu.dim() == 2 and dist.dim() == 2 and weights.dim() == 1\n",
" bs = mu.size(0)\n",
" d1, d2 = dist.size()\n",
" assert mu.size(1) == d1 and d1 == d2 and weights.size(0) == bs\n",
" log_mu = mu.log()\n",
" log_u = torch.full_like(mu, -math.log(d1))\n",
" zeros = torch.zeros_like(log_u)\n",
" for i in range(N):\n",
" log_v = sinkstep(dist.t(), log_mu, log_u, lam)\n",
" log_u = sinkstep(dist, zeros, log_v, lam)\n",
" a = torch.sum(-weights[:, None] * log_u, dim=0, keepdim=True)\n",
" log_u += a\n",
" return (log_v[:, None, :]-dist/lam+log_u[:, :, None]).exp()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's fast enough to just use baricenters for interpolation:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"res = []\n",
"for i in torch.linspace(0, 1, 10):\n",
" res.append(get_barycenter(torch.cat([a, b, c], 0), dist, torch.tensor([i*0.9, (1-i)*0.9, 0], device=\"cuda\"), N=100))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fc899008be0>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAB2CAYAAABMKevGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAB1iElEQVR4nO29W6is3ZoW9oyahzrXPK3///une+/svuiLmEAUNkbom0YT0oli50bRoHRA6BsFJQbd7U3IRaAhIOYiudhESQcl2qCgiCDSSRMCQe02BtN2TBrbuHfvf//rMA8163z6crHWM+bzvXOMr76asw7fXGs8UFTNmlXf4a1xeJ/3ecc7XJZlSEhISEhISEhISEhISDgsaoe+gISEhISEhISEhISEhIREzhISEhISEhISEhISEiqBRM4SEhISEhISEhISEhIqgETOEhISEhISEhISEhISKoBEzhISEhISEhISEhISEiqARM4SEhISEhISEhISEhIqgGeRM+fcTzvn/oVz7jedc9/Z1kUlJCQkJCQkJCQkJCR8anBP3efMOXcE4P8B8O8D+D6Afwzgj2ZZ9s+3d3kJCQkJCQkJCQkJCQmfBp6jnP1uAL+ZZdm/zLJsBuCvA/iZ7VxWQkJCQkJCQkJCQkLCp4XjZ3z3RwF8T/7+PoB/t+gLr169yr71rW8945QJCQkJCQkJCQkJCQkvF7/2a7/2Nsuyz0L/ew45c4H3HuVIOud+DsDPAcA3v/lN/Oqv/uozThlHLD3TudBlJiQkJCQkJCQkJCQk7B/Ouf8v9r/npDV+H8A35O8fA/AD+6Esy76bZdm3syz79mefBQnik5Fl2UaPhM2R7JeQkJCQkJCQkJCwHzxHOfvHAH7COffjAH4bwB8B8J9s5aoKoERBicNqtfLvAe8Vs9iD/094gJIvtW8I1obJlutRhtwmOyYkJCQkJCQkfNp4MjnLsmzhnPtTAP4+gCMAfyXLsl/f2pU9Pp9/VkLGZ7625KxWq6FWq/nXStD4uU8ZlpSF7EwoGQvZ7VO3ZQgh0huCcy5o64RySLZLSEhISEhI+BjwHOUMWZb9PQB/b0vXEjvHIyK2XC6xWq2wWCywXC4xn8/969Vq5YnY0dERjo6OcHp6ilqthpOTE0/Wjo6OcoTtU4IlYFQdSXKV7PIZQI7k6iOpk+9hFccY2Q0RiVqt5v+OKZOfmj0VMZK7jvDqc+h/CQ8om7qcbJeQkJCQkLA7PIuc7RLWmV0ul/6xWCywWCwwm838sxI0ADg6OsLJyQmOj49Rr9dxfHyM09NTHB8f4/j4/W3TIf5UCJoSBas6Kukl8eV7/A4J7fHxcY78WrJLuwIfvyO3johZkhsiZ7QbgBzZDRE2/v2xoyjNNmZzAI/sVIbgfgr2VIRImH0v9BnaKWRr+5mE90iENyEhISFhU1SSnKljS8JAIjadTjGbzTCbzTAajTCfzzEajfz7i8UCzjkcHR2hXq/j9PQU7XYbp6enaLVa/r1Go+EJG4CcM/wxwpIFS3ZVgVTiu1gs/PePj4/hnPNk9+TkBCcnJ54Ik6zxc0ouPkbYAEIR4VXiS6USeCBjSnKtMmnTcoGP26Z8VvKl6u66gj9WxVXiG/p/lmXRVN2PBaF1pEWkN4ZYoECJ26ccSCh6r8wxiuz1KdiyDBLhTUhI+NhRKXKmjhjVG6pio9EIs9kMw+EQ4/EYk8kE9/f3mE6nGAwGmEwmnlAAwMnJCer1Our1Onq9Hur1OrrdLtrtNprNJnq9Hk5PTz3poJr2sTlpljyQiNGu8/kck8kkR3z1vdls5o9DAtZoNHBycpIjuc1m0793enrqSZqSi48F1qaaZrtarTCfzz3xnc/nj96jGgnA2ydGcvmeKpWx9ZMvGSGiq+m1y+Uy90y7W1XSrjWNkdwY8f2YyG8spbbsw2IdyQ2lNX/MNtXnov+tIxMh+2z63ktHWYK7KemN2WfT9z8GbGq7TfEx2y4hYd
+oFDkDkHO85vO5JwwkYP1+H4PBAOPxGHd3d56kTSYT/1nnHE5OTtBsNlGv1zEajdBsNjGbzTCdTtHpdHB0dITlchkkZR/bIEOnliRhsVh4e02n06ACOR6PMZ/PMZvN/KDONFESsUajgVarhdPTUyyXS9Trda9wkFAQHwtBs6mhqjxqmq0qkPo/Tb0F4JWyer3u10cyWECSu1qtcHx8/GjtH/HS7RpKYbaptapA6v+VsCk5UyXSqpL6HlXeLMtyRM0qQC8NVnVUwmsJbSz1NqZGhght6DXt+LGkOcfIrr637rXCktiy74Xa5ku1axG5LfqM/X8ZsqWvQ0olbRv6/EvDOiK2LaIWsmmZz37s2MS+n5JdEsqjMuRMnQcShOFwiLu7O4zHY3z99dcYjUa4vr7G3d0dhsMhbm9vMZ1OPTkj8QCA4+NjNBoNr5w1m01cXV3h/Pwc5+fnGI1G6HQ6WC6XaDabaDabaDQa0XU+LxGaWkdFbDAYYDqdot/vexXy7u4uR4CVsFH5AeCJA211enrqVcnz83O02220222cnZ2hXq+j3W779Eeb6vgSoc4slTASXNpvsVh4lXc2m3kFks980KGgOsZUUVUlm81mLi23Xq/n0nF1rd9Lba9KyEh0SWin06nv09Pp1Kc3W1WS9qRNSbpIwqwqSfvS7laxtDZ9SXYNqY8MCNC+IeKrpFjJLvC48q2m3tp1qLS9JcNK2PSYLwHWpjEl15LbIrILhJXHUIElq/jGPveSgoubqLchghtTd0OEls/r3lun+FbdrmWDAuv+XocygYEYKX5p6m+ZwMCuUcY+636DqmEfdtsl9mHbSpAzncxIJCaTCYbDIe7v7zEYDHB9fY3BYIC3b9/i7u4Oo9HIk7PhcOgdZBKJWq3mCcRsNkOz2cw5Ko1GA6vVCq1Wy0faqVh8DGul1KZ0eGkrErL7+3tPgCeTiSdu0+kU4/E4l9YIwNun0Wh44jufz70tSVhUlVQ146WvmbIO73w+x3g8zqXZkqTR3iFyRodO252SM6aGzmYznJ6eYrVaoV6ve8Xt5OTEO8TAe/UNeHk2tc6uqrq0LdXb6XSaUyD5WUvO2MZIzlSB5BpJkrOTkxOvSq5WK/838JBuSrwk22q/V7sy2KLEVgmbrotUdRdAkIhZBVLJsKbqrlYr/zng5ThmhFV0NT3Z2swqkSGCFiIBRQSMdo+tR7VEjceuKkIBhJiqGyNqZdXIGOnie6FCTPqb6Pf0+FVDyF6692uI3Op3y2Id6dLXZclw6BiHRlEwIGavbRCObRCxKhO1sgGDqqHIfruybWXImabcDQYDDIdDvH37Fl9//TVub2/xve99D/1+H2/evPGpjYPBIOfEcaIEkCNbnU4HjUYD19fXuLy8xMXFBRaLBc7Pz+Gc846fcw6np6d+IixKl6gy1JmgszsYDDAajfD69Wvc39/jzZs3nqDd3Nx4MqwFV+gME3S6SMyonDUaDVxeXuLs7Ay9Xg+z2QztdhtZlqHZbKLVagGAV89eUtScsDYlIVOSSzv2+31PcKfTqVd+SM7YRtnOVMGhbev1um+3Z2dnaLfb6HQ6mM1maDQayLLMkwwAL4742oAMgwK032w282tKJ5NJjqjRjiHlTNUbS85UlWQ6brvd9ook/0/CpmoQUH27qj1VeWQ75MO2x5Caplto0FENETFdF0mVnIGGUJVcq6rx+FUE75/EQUmupjCrKpllmX8OETaFkjGrROpzyN60I/+2iqaeo0qw/V7X6dq05XVkzapnRaQ39p61d+g3sIRZz1UF2CBXLFiwTonUZ0UR2bLPMSIc+z2qutY3pICHyC7/Dn2/LGL3XEReY2Q4FnwoOs8+oDaL2fElIGRrfX+bODg5U6eXa8z6/b4nED/84Q9xfX2NH/zgB/49qj+j0SiX8mTJGR+LxQInJyfeEZlOpzg9PcVkMkGj0fBRYkbQqUgAL3OtlFXMxuMx7u/vvf36/T5++MMfemJxe3uL+X
yO4XCYc0DUnsBD9HwymXgFYjqdotFoeBI4mUxQq9UwHo9xdHTk16w557BarfxE95KIr22jTLklIbu9vUW/38f19XWOnGlaozofSs5oU932gQ7ucDj0ayVHo5EnKK1WC1mWeZIGvDzia206mUwwGo1yJPf29ta3KSq5DKQoqdDtHkLKGRUcBhS0cut0OkW9Xker1fLpo1TRSC4AVNIpU1ilXEkZ0745ZnKcjZEzW1EUQJCcqTpGJZLpt1zzq4WZqAxT6VUyXVW7WrLLPs2+OJlMfBu2apqmh9qiNUB+30irjml6rSq/dj1qrGJuFQsGaZ/XFGbOHaryWnJRtkorn2NrIC0hjq1Dte08lFFTBdsq0dXsFTs+xlJs16lqZdWvIiKmfTykAut79rz7hg3I2KCVDbQUkbR1xGNTBWwTu1t7H3L+smNgLGign68CQvaP2Tj0+eeiMuSMAwvXO9HZfffuHd6+fesVn+vra69a0PHVaBGhBiQ542eY5rhYLNDpdLyT0W634ZxDo9HINeiXBB181fEdDAbepnd3d3j79i1ub2/9+yRxmvJko706mDIaPp/PfSEQpk2dnp76dEcA3kF2znlHukoT3DrEHN/RaIS7uzvc3Nzg7u4O7969w3g89vbkQ9uoTRdT54BOFx8kvlyHuVgsfHtmyii/Czz8PlUnvjqx0ZFgyu1gMPBpyzc3N7m0UY4PGnXXiCbwMICqQsP1eVTGlJSxkA2PVa/Xc4UsYlHiKsIqEhwj7+/vMRqNPPll+7VOnC2wQvCe1Wnls10zSSLGMeD09NS3f/4ODCRoMRY9TxVgx1ESCQ0iaPq37rNpt80oImch5UzJglUlOe4qGV4ul358tcev2lhg2yjVXQYNdLxcZ0cdR6yzGlqbZ9Uxuz5SVWC1O7/H81QlUGPVHc7d+lB1ch3BXUfOrBoTUmliJNi2b1WDQ989JDQgoIHqmB35Hf2+PZ7FOhXM/h1TbEJBCB1D9DOxc+wSto2GyO5LUNBCYzbHjl2NCQclZ1bhGY1GGAwGeP36Na6vr/H9738f3/ve9/Du3Tv84Ac/8GvQOKBbEmEHFT4WiwWOjo68czccDlGr1dDv9wHAl+E/Pj72JfZJ0KoyYJSFDtR0JG5vb/H27Vu8e/cOv/3bv427uzt89dVXPh1vMBj4FEa75iTmTEynU6+MnZyc+LVWPFa328VqtfIRZv0d6Ei8BLvGFJ7r62v0+3189dVXuL6+xs3NjVfO6ADzEYsShQZUdcRGoxFOTk4wmUzQarW8AtJut7FcLtHpdLxDrYrESyASbGN0ePv9Pm5vb3F7e4sf/vCHGA6HuUAM+3xIlbABBHUM1BlQx5bFVgaDAZrNJtrttq/qulgs0Gw2/TmYzltFZ5fQ1DuShru7O79Ol31zOBzm1vDpGkhbCIT3qg9VdNShVdty3aTuL8n00cVigXq9nvvtTk5OAFSvvWobZZot56C7uzs/l7CfK6kIEQlLQouUBbuGT8kZK+VqRWIWYDo5OQna9NC2tc6ZZnRMp1NflIrqpF0Lqcfg6zLkDEAwVZQ21/WRoUADA2VcQ83jVWHZg871LPilQYOQn1S0ns+q5USIjMUeti2HAg06XtgMBXu+fUL7rGZZMTA4mUxy5GIdsd1EOQv105g6BoSJMMlCqPAVv7Nv21r/SVPqbbu03zs07G+hbdkGyrIs27qCVgnlzDpq4/HYKzqDwcBPihpBVzVCmTmhEQU6sNPp1L93f38P5xzu7u783mfD4dATDaY42mhnlaEDhE6CGu1leuNwOPSpcrqOx3YUvffVauUdB3Y4596nK2qFO6bd9Xo970h0u12fYsqBxB6/qrDKmbZTpjfyoSl4IcdXoQ4DUz75WTokqpI553zaXavVQq1W8+ulWITlJZDekBJp2ymDBnQ0bJ8PreEhrANG2zBlmb8F2yL/ZvEKBmboCNdqNd/2q4hQv6czQZuqXTUdz6br8DhKJIqcXKpfLASka4ioRL
LtrlbvC9s457wDoYS3SmNBSDnTYkmj0chnJNgCNUXpTyHCqw6tttmQQ8vxgEqk9gESspBtqwLbRmlPkgnORzbN1pKJIiKhkewYAbYKjq6JpBKp52M713Haku1DwLZRElz1lWhPm2ar3+frEHQMKEvOrAqsATISW16DKukcZw9lU/VHtY1qZpGSM7WjHiP0WhFTsmJKGVGkTrItMy2fnwcexm4bHN4HrE0ZSGAQJjafH4qgWbto+2YQh+MA8NB+tz3WHoycWQeNEx1T7l6/fo3Xr1/jzZs3XqXgoBNKd+Ax9fjAA+PVwYkKzmQyyRGNXq/nUx01esnjVGmSi8GqPNxy4Pr62qtnTG/kAK6KWciWdqCgEkbb6OL46XQKABiNRjg+PsZi8X5rA62OCTw06CoTCY0oMnhAQsZ0xnfv3uH6+tqrPtbxDQ3ewIMtdTIiUeBgS4d3uVx6B2a5XGI4HHolmO0UAOr1OoBqFwfRfq8BGaoRVCAHgwFubm68amFV8pBzpvdtJzFbTIHKL1NHdT1frVbDfD73jgUHZXX6qmZXmyo2Go18IIb27Pf7ucIqVn0MTYY2YqvkTB2vyWTyKKI4Ho+9gjabzbxtOf7yt7FVRw8N20Y144IKL1NuSc44hqpCtM65sGlHfIScWqugUflttVo+TZ/znE23qQpBU3WXYynV3Ovr6xw5s2nLNgARgu3/+l7IoeVYQBuT1HKtZKvV8tvG0OnVtb08xyFggwdso9wDtt/v+y1dNBU81N/LKjyhsSBkWwA5AqxjL7eCYWErzncM8PL32DfxtcTfjqO0baiv8/v6bI9NhBx/+zpG0EL2VhtzfNCiYiQNJMQatAxdzzYRsqdWs7aF/GJz0CFhgw5U0ulrca4Dtp9GfvC0RqoE6vje39/7oiCq7ugPadMdQscmdDAnWRiPx3DOod/v+wH47u4OJycnGI/HXpHgGpSXAB2suYCdAwyj5rSnTRMrIhIxqFLD3xCAX2t2f3/vU8isKkniUXWEoueMnPOhUUqNoMdsagmE2lA/S0JCwuucw3A4RJZlGAwGOD4+RqvV8gVu6PRWXe21jq9GezWNUYkZVTN+P2RP67TpxMR+r44eJzV+j9FGBhbYR3S9atXsap0tKra6lQMfWq2RbauImAFhx4A2sKqXDYAB8BFGOrXT6dSvVa3ypAw8VnnUprru2aaHrhtHtb2GiC8dVpJgDdJowEsDObVaDbPZzH9O0/AObduQ6qWqhLZTq5yVJRLWobXOZ0idtASYqiQAb3M+OLZaJe/QqpmmimpWh9oyRs70ODGUVXJCJE1TdPkAHqo+k4zR1jHiuM/AuPWhbJ9XJTI0dm3S19apZ4S1sSVnfPA3VsJGH20+n3sb77PIlbZRDSLYQIxNqa8KaBvaVP0QcoNQius2bFoZ5UzTmhhNY/l3m8rIzlPmh+RnNE2MDq9zDoPBwEfI7u7uUK/XMR6P/doTToD7juRsClV5Qul3JGiWSNBJ0t/Dooic8rvz+dxfx3g8BgAMBgMfOR+NRqjX67m9u3QgrqpdAUQnQEskVJEoM9ioNE5HzRYM0c8A74kDn09PTz3RZnGLsv3i0FCbUm1UJ01tattoSIXkMx/st/o/Ja8AvHLGv7mNxmQy8fvMKdmucntVcqSOr7WnrjUrSyKAsALBc2o6rqpgSuSoPE6nU185V8mZXkcVbKvjoU0Z00eogEUZ2LapxBeAd7aUnJE8aIodI+a6/1+orVahzYbaqLUp5yZtG2VtCpRfH2XXn/GatLAK1TQSCNteeU9VIGjaRpWgxcgZv7sO69QcfR1SJ0l8bbvV6tmh1NVQeuM+7Gztaef7UBbHU8jZpgpajJxpNggAPyZwvFgsFp6c8XpVVd4VtI3FxlFdUvOUtrlL2N9A15nShiS8GoTfJg5CzkLEjIuDmTbCxexW5dmEmOn5gAeClmWZV3nu7++91H5zc4PT01P0+300Gg10Oh20Wq0XofAAjyO9JGVMbw
qRXTsg8jhAPsKrBFURUhMmkwmA9+SsXq/7c7MCIVUeRtCqCjuwhNZF0Z66Cfo6YqbvW/KlBI02JfEF3tvWOefJLvsHBzm7VuLQDplC7ak2VdVM9+SKrS0tcx5GtHS9GG3Lv/kMwBMHPlvFrkya2qFgI+h23YkSCTqZ69qoOvS0lQ1SxRRfXg8dYNrSkrIyyt2+YcfDkHKmRFcruRXdi3W+dDJX4mvJFB0w67jQ0Q0Rw6rY0sKqEiFypsGYEHFXhBzaEMEPKQ86BnCM0DV7GhHXc4RUHB1H7Ll3iZDjS5tqcRWSs6c6kXZ+Kkse+LDzXZFN+burMqTH0/PuAtambI/a3znX8n70u+sQuvZ1bUfv25JfptoCyC3D4W/N+ZTrp5nqqHt47rq92jaq7dL69lUhZ0C+rdutNEIBxm1f88GVs1jEJxQ53wZD1WikVuSxE27I8FVydC2sU6GTn6o6tkpOjJjxtSVo9px0JNZNFC/F0bUI3RfvjfdkU5pi0bRY+6Ejy89ap0HJjF6DOmVVdHSLEGor1q72nkJRdJtaB+SJhbUtoc6ZJSxqT3vNVYYlvvbe9H9FE6ElCXwvFKRRO9vPqB1D9jyEQ7spQv3P2jPUTmOwNuU5QkEqfpZtPvYb8H/Waa4iQgGaUHuNzUn2WErKQgqhEmENslplhjaNEYcQMYvNlfzurhAal9Smakd9bx3ZDV23DSQAj9PhYjayvoEeUwmGEjP+P7R+eNc+mP09dY5Su9q1kE+ZF8qSNPqqtp3rdYYC5PStleiS7DKAq/e5C7vG2qkGaELBwqrMs3Y8ZZqoVXsV27TlwcgZb073NqPCQ6WFrN/+cE/98bRBssGH1BBN/bMRyapPfLSpva9QZUbr7Jaxqx1krXNGUsg1EBrFe6oScgjYgcTm82sAwcryIWLG17b9hJxgvm8dXTuYhWxZVXsS9l5spJd9LkQ6Q4EZtVHMUbMOiTq9lqSE7BeLmFcBvF51HjTaq2t1Y20l1qeLCJo9P1/Trvo9IL9HIp0E65xVycbq7GpARp9D7ZTf3QQhu/I4OraGHFytNKgbUmvBFe0X9rz7grY52tQGmmwVvNj8ZK87RGxDZElTnfkdS37VgeWDqoNuSs3jKZRU7HrMsP5QKIBgiUUZJTI0FvA1ob+Jjr1Kumx71hRHtalmz3A5RJZlvg1rwQUeZx9KT4jsWtV/kz5fdM0hQhx6Vrvq70Jb6b6T9Ktpd147VR8twrJL2Lk11lY180C/d0iE+q/2q21wkiLsnZzZH8nmoeqeRkW5vc+Bdfh0ErYqnTaWqjgOIcRsqukNMeUq5PQ+5fyx96rU4TaFHUxU3bHrILdxf2XamY08qsNR5TZKhKJntu89d+B7Sn9VAqET3b4crqci5kxYBy1EcGORv6fYTx0ybZ+0pzq7MQJRFcScX/vgZ/V7644bi4JbZylEMGhXtanu4ce/1xEze/x92N/255Cqam0eunZLGGyApsjGFnYcte1USYSWzLZrNvV3ocOs59g1rN1if/OzsWPo9ZaxY5Hqq6mJlpTxWddG0UEHkFujpsfjfeyS9PI5Ngc9ZW6KkWH9Xyy4GPpeaE2fppFzfLJjsKo+sevaFWyfjj32fV0h2MBEUTvYFQ6a1qiESBcEhxau20F8W+fXyCjPZVMoD91Q1kEbtEbPVbnS9NBtkgl7HbEJWFFFZywES3iVlIWqMoac33XHD9nBTor2t7KDsxI0tW3VbBxydmN9TxXrbZ47ZhNLynQTYCURVYVtq6FUzU1IhH6u6L6tY0EokVCHl06uXfNQVdva4EwZNTd2HL3HIruGnGKNnlub6mbJGlBQdUiPTcQi9tuGPb8luNo+yxIJ/s8StNj/Q/YO2VRJRMiuVDKzLL9pu66v2mewLBREKCIPMbvatsn37OsQadCggg1iWWWX9qRNWViB648JG2Tg73MIX6yIRPD/ZRGzs/1bbW2/p/O/2lbJGX1nncdIzva9tKSoDRbZdZvz/y
YIpY3bMWkftjv4JtQ2xemlrk06NHRwtuW0rRJpbapFE3isp0wqdmB27qGaUCgtpMpOGZBXeGxRAJLedSRik3SG2Gc0KstBVvfbsOlMVbZryKa6ds+20W31/xh50BQmbp/BrTUajUbOiagaQQs5ZpoiZoNbu45K2gpibJ/1eh3NZjP3YLu1xLfMte3L4V2nQoa+s01oG7UpYWyf7XYbzWYTrVbLb/2iRIKpd6Hsk5DqHlJBtoGQc2MJ2jbsF5u7Qgok2yrHU22rrVbLV3FuNpu5rSC0b/G6OfbS/sDjfex2hZiTu4lDGbJb0Xv6P1XLda4KjautVsurZsvl0hfUor9CcOzVgiCqXPJ8u8A6wms/t+mxgcckLXYv6/wpjqPcO261WuWUM/18KCuMj33Pa0V2OxQps+ePkbR94eD7nIVe7wNWLo+pDy8JdkCJ5UevQ+jeN62URKdCN6CschpTDGrTonzjXQ8odnDWyHmI9FYZ1m5WKYu1UwYR7HtlESJoIcXM2rXqhBcIO7wxpwJ4vL5E3y/6O/R5HUtDNlUFwrbXkOMamxesg70PhJw0e20hpeE5KFIhbYqYBmjUrrwuVXn0+vV4lhwfgkjwf/uCnfc16GXbq1bAs2sPef2qBgMP1d303rZt1xjhDZEI+73Y8Z56jSF/imNnSDUH4APyJLwkZ/QdnHOPAvX79hOfSr6A4jHK2roMQdP9y9S+bHO0VSgIb32Xp9xbQhi7suNByVlowrGTDDspEK6o9Jxza0QytKCakZqqO2YWoYkvBN5XlmVBx5ewDnAsMsmH5uwzEsmorioRVbartSEHt1AhAILtuUwb1cnMPtvX2k5tZDcUMa+awkOEgge2KICdVNQebKcWMWe2KIKujgM3Sm+32+h0On4LDavwlBkLDhWFtOvNirIPykSfN70HHa/Z/xktp03VrpoqRoTIpF57aO3fLhxePlubFiln64jZU0mvbadq01arhU6ng3a7jVar5dsqx3J11HT9NoDcvMe5T8eNXbThENHdttNddN2qDlriwP7PLXQajYa3KVWJxWLhl10wi4LXzeOsViuv+PB69qmchf7eBdQniwUPqH7RthxPnXu/FYwuZSHZ5XHYfuv1ulfhNeC0jzF2E4IbsvVzya7aVcfW4+Njv/0TtyVirQb2edvP7J5yNhB6iHlLz08cWjVTxNZTErsmtwcjZ3aQ1CggX9vo6rZIGQ2uE5SeP+boVtHhVYQ6s0ZXrINpHbR1KkQocq3ftal3TGcgeaBdlZzZ41YVoWvV17Xaw35PSibKHCukStrfUsluiPBqiph+v4ooipyvc3DLBBti76vD65zzToS201hKoyWJ9lpCr/fRvu15rcNbZNOYPcsSiZADoWoZHVxrU9tWreoXctqt0qFrqvahSBS1UzuWliVp9rpD5DOk6tj0W44BJAS63pCp15qCzeNyTOF7x8fHOYdkV+12034fgx1/Q/+3wS6rlLH/q011rsqyDPP53BMI3XeV5Mw5h9PTUyyXS7+R/WKx8OdZFwjZFp7rH5W5vlggzM796lM1Go3cZt6r1cpXjiY5Y0CJv8/p6WluXbcqPrzXfY2rRe9tC7EgbWhs1fV4umbb7rkIPGxmH5oTqqKcJSUvjL2TMztQ6qQTm3B0se22zm3z+Ok80IGw534psBO6LWygzuk6h9MeV5/te+r42t9Tf9OXsjaKWOfshxyq2PdCiKmSGpTQAEZoTURs/U7VsW6i2GQtZIjYxj7Dfm9tSruyvdoCC3rMMs6lvY59OLtlCyuwr8dsG3IWQvcVcs44fqpNbVtV0mtT76hQ8z54Dn7Hlta217ctxGy6CYpsqq91HNCxVOcou3ZHbarkTNdyklQwqs5y8rVaDfV63duc3yWZ2JY9YyS3TBvdNkLkzAa7rLJLx3c0GmE+n2M4HOaKbGVZ5p1kKj0kZyR3uyQS64hubFwqY+t1Y61VzbSt2kAC+z3Xa49Go9yaeN2T6+joCLPZzA
c8Q9WzdwVrP6vk7OMaioIJSs4A5LZN0aweHiNWzCih+jiYckayxUmn1Wqh2+1iOByi1+thMpmg2WzmpFoAwUG9LHTC43mbzSa63S56vR7Oz8/R6/XQ6XT8pPcSHN7QAGmdTiWdnIw5sajEXXSfITKm5EFTxDqdDs7Pz3F2dpazK68jtt6kKohFrjT1NbaWjo6WDuwh21qVi//XoAADB0xhOjs7e2TTsm31ORPythGyb4jslr2m2Gftcfnb0XHodDrodruPbKopTZpaXYYI8TqUXIcI/L6g52R/t8GYMsTMtlfenzoNTGFqtVro9Xp+bLU25XVoqnCoMIwqPXSYleAx9XwXY0nI4S0i4vx/UWDAjqH2MyFlh/273W6j2+2i3W6j1+v5cYHEjJFzOrx0gknSrANMkrdcLv18QKdv0yBTWZSN3uv5Q4EFhSXotq/Fglycq+h76FjqnPMq2XQ6xWAwyJEztlUAuSp4/O58Ps+lje6CpFm7rCO6RaQs1Bbta33PKthaoELbK229Wq0wGo1yJJekQtdIsd3TfrZI1KHSGtehTDuN/S/Ubq3PQdvyb7Y3rteze9fyOAzGhPpdldSzquMQdlpLzpxz3wDwPwL4EQArAN/Nsuy/cc5dAvgbAL4F4F8B+MNZlt2UPbE69pZMaCrc6ekpJpOJl2ef2imt48eGbxWemBJRVRJhYSPYdhNNHVA5oITuzTqaodf82/6WNk2ETllof6OXALYXVSBDtuRnaVclvnZi02PHHnouS7ZVjbSponquUBS1yLmM/b1NxJyokEPFzxUpu0VkwtrTRnc1nVGDGFbdUdup0qPPOtlpGjH/tr/7LrGODBT1/dD39f0im2qa2Lp0RpIvOrrq9E4mE29rVdhOTk6CTsyusI6U0SZKcmNjZ+y9UEBC1V1tq6pE6BozADnbcW1UjJyR1HKMWSwWfo6lM7erNhoiFUUouo5Y5oF9T0mvtaldTsHfcDab+bU89pkKBYAc4bVre3blANtxvegcmxKH0OdC44HalfdvN+7mGMgqglrp2G5Jo8e16zt3rfgUHbvsee04UPS52Hs6rtnCHzp22rXaaitVx4uU64Rqo4xytgDwZ7Ms+yfOuS6AX3PO/QMA/ymAX86y7Becc98B8B0Af77MSUMEidHA+XyOV69eYbVaod/vP2pwnDTYeYHizmMncQ4gjPCenZ3h8vISr169wqtXr3B1dYVer4d2u52L8vJYVYV2aFUFW62Wj2SzeAQnHgA+KkWbxo6tr3UgoX3oNNCmZ2dn+Oyzz3B1dYXLy0sfPVebliG9ZSN9u0AogKBOkqZq0PY2UqVrDuy1hxxnHYSpmnU6HfR6PVxcXODy8hLn5+e4uLjw0V7m9CsZAOKV0UJKqSVF1lncpj21rdqiBLx/pmZYmxY5GpYAW+dBSS77BNUy2pUFFkLOBcchOhOa46+To7YZjm/sl0omtmXbEFmxD/uZdQuv1xEI+/uRhFGFbLfb3p52PAWQS8UZj8eYz+e59TyTycSP8WwTVIm63S6WyyUajUZuDNqmTa0dioIIZYKG6wII+mzHGio6zO6ggkZHmKrjcDjEZDLBaDTythwMBp6wsX3yN9NtFur1OgB4cqKO3bZsWibdjs8MbgHxIgExxUz/p+MplVdNX+52u/6Z4ygJxHA49LZkGh7TRKlSOPdejef9TKdTr/zwt9mXMxwKvj3lt1sXsFHSwAACbad+RqPRAABfLn86nWI4HPr2qETWju2qBoWqTu9aiSwKIGxKvGL/D40toTRGVcxUGeda0lChIp3jQxke9l6r7NtWDfsitmvJWZZlXwH46sPre+fcbwD4UQA/A+CnPnzsFwH8CkqSMyDv+FrC1Ov1MJ1O0ev1/MQ9mUzgnMulOAKPB27b0HgONnxbRZAV2pguwoGFkfNdOKm7gHZ0du5Q3ne9Xvf58JreyMmwTJoJkE9rsJXEtOod05l035IyqXdFA2UoQrWL38c6olotyVb1tF
Gtdce157COtW4uq2mNJA8hhZfQwVcntaI0PF2PqK/VSdoW7CRkHzEywd/d/tYxIqHPdrKjUtZut32/t2uirLPAFDwWV2BKiUZ/aS9VrrMs84rPttf2rGtLIUJh+846xyP0e4QIqCrlOpZqYRVVy+gEM2VsPB5jMpnkyBnHKNqLY5f2E7Vr7B6eY9sYMYuda51za4mZ2pVjiV0PpWvMWKmtVqvlFMfhcOgJGckElTM6cWyLbOOnp6eeEHNO2JVKEXJMY/bSdqrBhNhYFDqeDVDY9e1aAIT9lO2QtmNVwfF47G1tCQUAn2oWqjyrY+2hHGEdO8sGPGOkN6SW2YJutEeWZV7JZX/X4hW6hYpzLrpGqmprpYrs+NTjxcgZ+3qounFo3zJCbVbGr0qoFjZac+ac+xaA3wXgHwL44gNxQ5ZlXznnPo985+cA/BwAfPOb37T/84MmI4TL5RKXl5deOeOi28lkglqt5sutkqCRsOkxCeuk6QBCR/f8/NwrEWdnZz6SZhWelwC1py4gpwPKUsGLxcJPRqqmxCLAIedX1R06EkXrotaVew/J7/o+X4cmjNh1PgchsqtpW0X7C1H1iU2IIcfODspapIbrdvRBgqZFK9TZ1hK6nOz0wc/YIIkWkbH3sC27hgiTqkr64P2UVShiSoRVlNkfaM9ut+udXyW8odQ7OrOj0cgXW+BkyWvVUtLaJlSZ2HZ7Vbsq8dVzWcLL99bZU6/ZuQe1SteZagBBgwiWTDByTnWn3+97kkYHbjqdetuzjaxWK38s3hvX+Gj64zYRImTWZlaJLEohKyLTdk0UbUeiq1sRcPwmiZhOp7i/v8dkMsH9/X2OYDDCzmvTdbMcP2azmVfTNL1xF3NfyB5lCFrsOCESbUmZDVbSjropMtPsqEAOBgPfRknOSCo0wEF/hIqZfm6XSs9TESIWZYmZ2jaWysx5XsdJptkyvVbtqPOQJWQhxeylItT39XXIB7DFPyy51Y3QQz5GLHumanas2vVUAaXJmXOuA+BvAvgzWZb1yw4yWZZ9F8B3AeDb3/62/wX4fXWaOp0OnHP4/PPP/WBZq9XQaDSwWq1wf38P51wu/YUpTx/OpdfrHyxAoZHdq6srnJ+f40d+5EfwxRdf4NWrVzg/P/fpDU9VeQrsl3veBdTRphI5m81wdnbmF+D2+3045zCbzXLXYgfB0LWrAskJjwSBZOz8/Nzb8vLyEmdnZ9HUO7VfaEC2ETNObnYgC6kE27SnkjOdhKbTKRqNRm5C5sSXfUiBW+d8qApRq9VyxI+E4fLyEhcXF77N8n275kRtxr7BwVzX8zDgwevTlB9N12Q63rYU5HWTkO6Dp9fGewqliIaObx0IHpNObrfb9Xbksy2sAiA3CdLZZYoTU3SoANlF7SQsjUYDy+USzWYTwMNkqdf43PYa6g9KYJR8q6IaShtTJ9gSE16rbirLdkgbakCG6aEkVmrH4XCIwWCAyWSCfr/v7WnLlJNMkDCQRGh/4TVtK5AQum/tp2pTFoXg94qOqTa2x+Y9cu4hGWM6I/8mOWV/Ho1Gnkzc3d155UxJLp04guMrifXR0VGOnLGABbB9pSfUnux7sXPHVPOQk6vBQ5IvtlVbSIk+BdWdu7u7XLCAtqRyTvvxt+M4wXHLBsS0r+0DsbFaSVnsN40RiBDZ5VzB9qoFxyaTiR8ftfiH2tGSViVnMWJ2aGJhx0dg80Jb1iew7VWzLjT9mH0+lEpvr0H7SpHdEimqLkqRM+fcCd4Ts7+WZdnf+vD21865Lz+oZl8CeP2UC1Bnn1V9uKbg/PzcR6fv7u48qdDBmxOkJRXqtGuuOaOPZ2dn6PV6Xi2LRc5DsCqPfU/vjc92Ytk2SdNz6YRvy4S3Wi0sFgsMBoOcksJrjykU6qCo2sEBmkoEnTS7ka9VY9RWOgjrxKbv8/OhydeStG3aN0QmOAlpxTj+jwEFkokisqs2ZR+wxT/smkE6/FqwgtCUEBIxpgVrSonmqf
P8nGR1jQRJyrZIhEJtalND9XmxWAQdjZCTYR0+Pbau4QmliWmKKI+vihkDQsPhMJc+puXK2U45udLGdAC1H2zLyShyTNXG1gnWPhhTJUIP/a1sOrMtVqP2pENBB5ikYjweeyeYThwdOYKEgUE2tmWbGqVBum20VUsi7PjFz6xTIRQhYmYDQHYrEnWCaU86uuPx2D80BU9Jro6heg20dcjZs2l720aRs6rzZWg+jR1Dfys7ttjy7rHiH1pMhcRM10aR5NJGQD4opnPqvlPxYoTM/h27pph/YhUzDQDp2Mr5SDfp1sIfqiYWkQobrK0CKXtOP1jX1jluWT+D74dSGW0RKttPQv+rgh0TyqNMtUYH4C8D+I0sy/6i/OvvAPhZAL/w4flvb3pybZinp6c+SrpcLtFqtXy649nZGY6OjnB7e4t6vY67uzsMh0Pc398/WhQJPDRyOrCc4DSF6csvv8T5+Tm++c1v4tWrV/jss89wfn6eSxVTh7RI5bGDiN6fnYh1oLMddRtg52Y0CwAuLy99usp4PM6tO9N0Iw6eer96fRrV4STHtXqvXr3yCs+XX37piyyQWNCx4LEsIeP5SSA04qu2pQ3p6DabTT9B2D3UnhuttJEtVcw6nQ4WiwV6vV5uUqbzaAm5vrbkTNUdkjHatNvt4osvvsDZ2Rmurq5wdnb2qHQ+VTASCXXQ7u/v/d9K1Cw5Y8EGrr9sNptYrVY+LY+pFc8laXaSV6dJNyxdLpf+nCRooWCIJSdKIOiUUcFhEZWLiwu8evUKnU4HFxcXnqyR6LM9kkQwmj4ej3F/f+8VH5Iz3aeH96eFHAD4ojG8XpsK+xSbahuzZFRTYnQ937riKnps4OH3VlKmqWFUdS4vL3PbZ1AByrIs1x5px36/79PG7u/vvXNs1+9pMKhWq/mNfvlbtVotT3y3EUSwNrUkVwMHqu4CeYfdQn/nUECGwQGmglLVZdEaOsB00Pr9vie2akeumVLSFbomKmfT6TT3zEIWmvb+HFh72vu38+Imc6K2UT6H0u05XnJc4zYEmno3GAxy6Yy0Y2wNGc/HOUkdaLbhXe4tZW0VC75YW9l+H/scn63yrqqOBl9ZUEYVSK53VIKrqYy8Hp6PQc116pl+bx9Y1yZj42gZUqZkjGMr50X6trqRvAZgi+wR8ltDzwnVRRnl7CcB/HEA/8w5908/vPcX8J6U/ZJz7k8A+NcA/tBTL0InKgBoNpvIsswXBgGAwWDgF5lqBTTt/NrRNarDtVZMu+v1evjiiy+8o8bqYho9tx1RCRgH3VAkIxShtJEmjWZvU+nhIMFJg44vHf75fI5er4csyzAYDPy1apEVq/boZEqCpSWyWZGNxEwrM2oxAE4YGhmj/TgAMRKsE6PN8acd6cQvFguvUmRZ9igFb1t2VTKhJK3ZbOby6dk+tD0Cj8mZjewq2WXqHVNCqUbaVEbnXI4YUlFQNeLu7i5Xwc2SM7bN2WzmU/DYJuhch9Ian2tXtYFGtukYMpjA9mKLrISi5sAD6dFUSRamUFWXQQPt8zwPo720I1PGlJxROaPdlUyo+rxardButz1pZzU3OvYxJ+kp9gw5vBqs4P813Xbd8dQxo51sQRVVITXzgO2SKWJ8Ho1GnlSwbdpy7/y9NZgzm81wdHTk7aj7om3bCbZkwqqR+n87VpY5ns5RWlAhtBk6+zvbGtMZx+OxryhIxUcL1tgIO88fCooVlYDfhYIWG6NDZCHmeOpnbZu3almo+IdmFmiFS7t3lLUjQbVWVbOiFLxDKJE6HsU+H/oeEFbMlJxxbtdCMrQf58RQxcVQ2wrZKqSa7ZJUFLXDTb5b5rM6DoSyR0LFP2JpniGEbJYI2stCmWqN/xuAWMv7fc+9AB0IOAm1223vENZqNR/F7/V6OD4+RqfTQb/f9w4yHXmd0DXl7uzszK8zY7GKL774Ap1OB59//nkuBz20LkpVMl3MqtXauOZEnUir4Kn8zxQdfk5t8Vx78n
iMZvEaa7WaJ7lUdzhwcw2a5oPzeCSUdBS0wMjV1RU6nY5Xd7jWjIRQVUgAOTLBwZtpYhpFZ/SSn+H10EmkGnd+fu6LnFDpseT3qXbVtmmJGYvXdLvd3P5Czjnv/FLxAfITM69NFT8GB7gGqtfreXWHqoS2UXUK2A4Z+R0Oh7i7u8NoNMLt7a134rjeZzqd+u9yYqAzqA4eSa+Ss20pPVY5sylHuvZF07Ji51VFg+SdvxUJL5UzqjskFZwU2S6plg0GA69QvH371q+P4v9Ho5FPfVTnh+mt/B+3r+Bvx+tjf1Ml6ykIESmt9KUTv45PRdFeS3bV2aXd2E65fQb7PFUs9l9mOYzHY1xfX2M8HnuySzsWFQnIssyTMudcbs8pW8hCnZLn9vsQ2Q2pPaqcrfuNrBKpYwqVa7ZZvrbreCaTCW5vb33xj1BFQdtv2M60rSgxs3tOWWVjG4gFEPQ9SyZC/cK2UavuqBLJAAILq3CssRUuue6RdrRKRcgemrVAh7lMFb1tIhYwi7X9IrKrr/W3sGq8pt1zfmfQgPZksFLHx1AARdulJbX60O/twpbr7FLm/aLP2jGF7ZXBZl0iAeBRv4yRf+BxllOozyQi9vKwUbXGXUEHBAA54gLA75fBtMPb21vc39/j+vrar1XQRc/WQWPqEvcw63a7nkDwmFoERB1qdgjtLLqoXTemZJSDICnhIKYTRLvdzkWfOAiqPZ5qSw4AdG65zsymNXLy4uBKO+pWBTrh6eJ/EtpXr16h3W7js88+8yrP+fm5d7R5bxxgNf98OBxiNpt5IkFnmE6HrdrGAY1kptFoYDgc+j2V5vM5ms2mJ+dMySuKaJeBEgkWp+l2uwDg1RPeG6OytKM67fxtaBPd5kCLfyg50zZKIkECyPaoqXej0Qj9ft+TspubG69OkHhpQRA636PRyBeQIVlXkmQdKxvxfkr7VHLWarV86iWDMVTNVYUKTaC8JjplbKskC0wPY1vlOlN+1jnnlZvZbOYJ2P39vbfp9fV1roIb0x7tOhT+vnxvuVxiMBjAOZdbf6ntNDahlrWnEl6d7FWJ1IIFtCm/HzqmdRw0IEJCpql3tCfTcZgqxr5Mksv2yP9pWmhonY6OxUwNZaXHer3u0/HUGeY4+hwoiVHVQB8a+FnnyGmfYXCFZJ3BHgYSmXrH/zOrgG3SFv+wdmQwy6rNmoapUXmbihdSibZpzxBJs//bxKbaVnV9OddCa7VgkntN++b4qOv0bHsMpYbyfZuOZ7+7axVS7Ruypdozdn59T9s2g9UMInDu5fjCfkn1m74Rg6ohW1hbriNoSoz3TTJsMDJky9g16Wdtu+c8Qdtq8IzzL+3H/lykyhL297VqZFLNXg4qQc6A/IBLh0kn2/l87gcDpn9wjxZubKjkjIMJU++azaZ3fOmwkSxplTiNgnKC09Q7rYZFh4MPXgMbPiMhWhJ5uXy/eaqmF9nJZlu2ZCSZ9sqyzJMKRv3VMWW0W9fPaNRM00NZROXq6grNZtNvjso1E1qwgmSCBIYRSq0yNhgMcH9/7yPC3EZBlTNez+npKSaTiVcfqAoy6kQyrxHV5zi/fFYHoNlsYrFYoNVq5ar3ZVnm2wrvm+1BB2U6ZUy3oVPGNFHd1iGUzkiCy4mRTvBwOMTt7a0nZ7QnnREl31QkmMrK9siF8qPRyCvZ7B9KJtQ+m9pUiaqqyuyL8/nctz1OUBwT7HHUYVYHwu4NRwKsexnSsWIf0H2i+v1+jqiprenQalSd10TbsA1MJhPvZDOqHCpk8VTEiISqZ5qJkGUPhX+Ax+s+aFM6ZzqW0uHVtZG2XL6mhbLwB0mtquRaLCC0pof3pETXqj32d+D3tzWOriMT69JE7TH097BFarSwigYO2G6475bdFFntqOtR9LfltdhMEI3Il02Z2qZNY0RCU4NDsIqbDSRo4IvtluOcFv/QbAJNw7PrekLkTAMHMed5H4RinS1pL9qzSCGybVvX8NmAD/0FLf
6ha6RCpCJkRwC5awsRs33B2s3+z/5dhpjpezaVUYt/0F+g7exymU3blI4BlqQlVB+VIWcAHg3OHBzoSGlZeC4q50CrC585CdLRY3l8OrucBLWUsC7O58CsOelMHxkMBri5ufHRdDoeutknQQVQ1xGxSMbnn3/uF9TT+SVpe44ywfsHHsgJQVWHka9ut4t+v49ut+udezo8qgLQTrTjxcWFd3ovLi7QaDTQ6/UeVRXjZEiVIWTH4XCIN2/e+FQ8rpEaDAa5yoJKzpQoDgYDdLtd7/SdnZ35qLQO9iRuT7Ep28XJyYmfRHq9Ho6O3peg5uSVZRlGo5FPwbJKqiUQJA3tdtsruVdXV7mN0akq2Vx0XbzO9mjtqEUCtNy7RtR53Ol0inq97ifWyWTi/1YlexuFQSzR5bosTXN1znlVmg6qKqi8dpJJttFOp+PTXTudDnq9nlfOGZDR1FAluFTJ3r596/9mAEaLVoTW82hb0RTr5XKJ+/t7ZFnmA0pHR0fodrtwzuX6/HNIhdrUOqdUylkNF3io0miJEK9FC24whYlqDgMxVM95LqbeUc159+5djuTqGMMxO+TE2XsCHlKztV0wOEeF16Y+PYegWQfVFllh0IAqnV1DyGPwWdNDOZ9R0dEiPOfn5/73I2lgoIVFVGJ2ZNAllpJYRHY1rTFElJ8DVRmsbWMpjjGnXOfGEHHQ4h+0LYkZCRjnIFvURwuLFSlmhLY1Jbf2oQ71NoIGateyj1C7tHYNEQfOVTpuMtCrQRYt/qHjY8iOMXJoyUTo732pZ5bgFv1uZQIzNp2ZQWZb/IOBqqLKlkXql9oydF32O4moVReVImeEEjRGD5vNpp8IOcExwh4qCKILVzWVUNe2aNRCFTNNv+MgfnNz49UdOh1v3771pbVD5IwTOVPGuC1Ap9Pxjr1Gs3k9OqA+l6ApoQHgq8cxtY3OtjrjunZPyRknOlXJmM7EgZtpEJrKSHJGVYLpi+/evcNwOPTPVNA0xVK3TuBvT8dIiytwAgGATqcDALkKbs9Revg9HVxZdr7VankHaTQa+egX1zXQgVcHjamkJOZUctUBVnWHJCTLslxqLdfysHrp7e0thsOhd4Y13URJtzoKmlICvCc8k8kEx8fHGA6HOD4+9qmGnEQ0ZXZTm2q7VnvaxftKJmhDtSXbAvs5+zQdXhb54dodOhccE3jvdPQZZKFNqZbRGdZKgkUFKEh6+DkSQI3O78IJVnvGiqxo+nRMNeP/bfEf3XeLRCJUYIEOGxUzLZlPZdLaUZUea0sl47SXLdQQcmK24QTHSIQtDqLE0kbKVb1XVUeLf2gBEAZCNMtAszPYFjWFUQluGQc2pJ7tQzWzNgXypMA+QoqyEjgtVGPVXbZNZgFYxUzboo6NaocYyeV9cOy0Ss8+7BiDDeqWIRdWLeP4rj4U+zmzlzSbyCrgtj8WETOL0HgasuM+7Wrb4bqxxQYhtC2Hin9wTNOUWA38hcjYuv5tCe+67yRUD5UlZ/pgQ+bg2+l0clEaLQFPaAdgiohWF9IJluCAotF0krLXr1/j7u4ONzc3XqFghJ3OB5UegtFnVtrrdru4u7tDt9vFYrHA+fm574QkbOqMbCPdSY/BiUTz8kejEVqtll9Dp7ZUcsZILxU0ddqoUnLg4fXTidJUsTdv3qDf7+P6+hpv3rzBYDDwdqRzrIuJQ6Sba+dIkFhQgApLq9XyhM1GWZ/iuClx4fdbrZZP4eJvTbLGgitUgmh72tKu3WEJclUg6WTQCeGArRul9vt93Nzc4O3btzk7klzYTWhtH9E2ok4y2yFJLlN/j4+PfeVBttOnOMHaLunYt1ot3/YGgwGA92WZVbFTRZcRc7ZjBl+YHkolkum3/D/tyd+G7e7u7g7v3r3zdqRapuljSqhCtiTYThgAYcow09IajcajtVJPnTjVEeC90YniBukku7r/kCXpSpRpU6rrJGa07cXFha+CyT5Gh5epycPhEDc3N17xYVptSD
GLRYKVlPFZlZ4Y0X0OQuqOJQFsdyTf7DN0REPqjjq57E+0J5UzBhjYz0ludfuBUEXBUOrTOqgTrQ5hUbT+uYTXOqsaGNV1Tla10N9FSbHt//QNOI5yXqK9mJrMFFtVwjU1lu2tSKHgNSkBOQRJK6uc0XZAfgN6/YwSB7ZX2pZ9ncE7qmWh9Y7WjpvYUov6xB6HglXS2Cf0mkLjhx2bdR1flmU+iBtK8Y6pZjGE+qm1WxVsmbAelSRnhKpKOpBz0LWRdXXidWLkd/UYOhHweyQDTHe4v7/H119/jX6/j+9///u4vb31zjCJm67nsQ4wzzkcDtFoNPz6FZazZ6oUI/i1Ws2vo2IHtqmJT7UhX5NE8RwkWZoio+mh1mFjsQWul2CEXfOmVTHjYuvb21sMBgN8/fXXuL29xbt37zypUDuGJstQdJ8kmsoE8N7RXywWfm2fFnjRdTc81qZ25G8KwBcb4Ro+KkksVKC59/wenQnaT9dAKYHg5MhrV/WRDtv19TXu7u5wfX2N169f54pW6OJ2rYBnHXLel40QU6G+v79HrVbD/f29V52ZxmmL9mxqS+2jVD15rb1eD7Vaza/VOj4+9mSdypRzzq9z0mps3GOL6ba6BlKLtZAoUK29vr72acpMrWXqqKbXavGPWESdvzk/yxLwurmtVsazE+9zCK86VyRllpzZcU/7Ofs00xlJIFikhunhmpLKVFQSCSq5TGXkOildH6brHmOpY0rElTxoIQslKCFS8RxY0sq2yLa0XD7sxcfrVWeM479mepCMce2uEoqjo6NcMEsVXB0fQyrFOqVHoYrPunVn+js815Z8LlIi6ZwX/RaWQLCtMtDFtg8gRyL4rJt020p4of4YChpo39HPWX9kX05wTHXUvq5zmP2sVcuoOpKUMWjArVrsesfQOj2ijA0s0SlDzLbRLmOIEVz7GXsNlgxrIIG2pa+ky2eKAi1FxKzIBqHP73sNX8LTUWlyBoQjZhzA6RCHJiUbtQilT+hndZLSvaJYXOHNmze4vb3F9fW1T8NjJE7z1UPVGrWiIyMlHOw4AALv0/Gcc97x3UY6Hr+r0UpVgjQKzH26lJzpoE1ixxQHOh5qW+BBgaQzRrLLKoLX19e4ubnxa850PY9dTByKSvH3ZwXKLHuvBvb7fU8mmDZIx5zXROLxVKiSA8BXb2QqG9NoSM5oSxIaJWeabsdJkE6aknK2Sy2ycH9/71MaqThq9VC1Y5Gzxbahail/s9Fo5IuvqBK4jXUU6mxlWZZTe1itsdPp5NKfeN7Q2j3akASCtuXDrtnTIgtMaVQ7UgnSqGaRs2BtwPRPVS1ja3u2pfhof7Wpolw7yKCCtgVVd/h53ZScypnuvcXfgGSBTi/VQRKyWCrjJs6rdX6t2rMtG8bsadUzjoe0m814AB7Sl9jnda8tDSYwBY/jspYhpz1pQ7Uj22OITMUcMPYla09+t6xD+BybWtvaFNGQQ1xG3dH9DKnys89xTNRxzKoUfFj7hGDn5FBb3qUdY/ZcR9B4HaHPWWVHU8XZ3tl/bSqjkgrblmL3rvOGnUMOregUzWcxoqbfi40dmrXFebZIsbYEVxH6n7Whvc5D2zVhM1SenAGP5WR1kGMN2HYU+55+R9MZWSnw3bt3ePPmDb73ve/h3bt3+K3f+i1P1BhZZ7EQmyes16CTOtevNJtNzOdzn64ym81wdXWFWq2GXq/nryvLMp/a9twIkab0qeOwXC69MxwiRKEIp1Ugea/8LidFLev++vVr3NzceAXy+vraVxJktc3YZGnvg7amY08SROeGlRS1Ip9WcHyqPTVQwOMCj9dAnp2d5aKJ+h06aiRiIQWSbVaDBSRk19fXXtFlKh73jmJaqm4/QNIVc17Zn7Qf0NkcDoc4OjryRWPo5GsxhOcQNLYp4P36QCoP/O0A5CLdtKUWVqDaqOui7FYVPAcdC+6zxdRkriNlYRUWW7DV24ocrpCjxEANj6VOti7+fi6x0HNqQR
AGfaju8hx0rghGdZXMsqCCVrVlAEHL5ZOEMf375uYGt7e3uQqXRWW1ixw4DU4wgKCEt0g54zGfOm5ashsqsrJarXwxkhAxYx/XAALVHa4vpT35mzC74vb2NlcplHOTjiv2ftdFxVUxCxHdkIK0bYfOEixVJDXYpr+BtSl/C5IyBmRIdhk0YJotlUcW/6AdNQi4KTEtUnn0t7Ht8TnjZew61FYxwqup6EB+GwISBrZV3bCbwUcGsnQds6Yo67pgvVciRtD0HvS9UACnbDBnWyhSz2JKmooCSsbYLpndQ4Wc83sROVtnxyKUacuJtFUTL4KcKWwHKdOwQqRMwQFVlR7dd4vre1gtK7SeJxZNt4P28fExVquVV3kYna7Vari8vMTR0VFuktHS5c8d0K0DoYMOI+jWnjrYhAb8UHSGNqFDqsoE1Qk6dbb0Lgcnm/fPa+HaMp6X32N0VAuwaPqKdYKfak8bJGCAgGRKqwAqWbcRX10LqcUaNFWQgzXbGtsl1R1Vy2yqnE2PiIGfseSak4au7bGpeM8lFLRbKB0PgE+f4/tqSyVnjJhr0Q+7BlL7N22pSg/biqbp8FHGOYi9Z9PG7Lq1baiQtKd10NSRpVrPLQpIWJ1zubU7usm6qhF03KgYazBLUxdVLdMtA2LETO2m5EKzBnidOjZYB7go6vwUW/K3sO1TCcViscilNRKa/mgDCCwAxLQx3qeOl+tUHksmLCkrQ3RDhGGbNozZNaQq2DRHJRMh5VI3QdZn7mPGMVOzMXR8tI5waDyz9277ZujvIod6F7a0do2pj6FAtdpVx1i2WS30Q3Jm13mWIROhv0NQe+7adkUo8hXXkTT9P9ur2pf9zir/oTTYpxCz2PyRyNfLxIsjZ8DjDvEcaGdhZSyu5Xnz5g2+/vprvH37Fm/evMltpqrromKTJPCg9NBh50AHwEeearUaxuMx2u22r1LH9Dg+P3V9j4WSMToGJBgx4hJSBfR9a0c6wCy0wHV6WryC6XgkVkooykTcaDf9DInvzc0Njo6OcHZ25gts0CndRuRSiSkHYq4/Wa1WviCJdTY1t98qkJrGqLbUFFvaku2TUWFNZ9SB3w7wIeJNMGoNwK8HYgVPdbpDjuJz7QjAVxOlU0ZypY4VHUslZ1oQgE4bHWMenw4Z+zfTQVm8gv2alULt4nZ14NapE7wH7Sdse5rWGCtk8VyCpgEDW72UtqOTxc/T4VUiRnJGJZK/B/CgvGnQhfbr9/s+LVQjwyHnH4irPao4829NbQ2t1dh2eqPakynoVjmjkkz1jGO2Lf6h6g7Xb1KB1D3MVC3T9qjrHi2xIMo4cjYIo7+LBhK27RSHggdWOeND74tt1hb/4D57LP7BYBfJGIt+3N/fPyqiEitTvm68tMEC24atc73NtmhtyeOGCG8R0bXkgZkcmm7PrIOjoyNvO84B7Pu6flQzNCzptwjNQWVIxSEIRkw14//s3zqX6bjKNp5lWS7YYtfV27Hxqff/XP8moRp4keRsW9BJSotXUC1jig4dDy38EVq8yWMCD52XSg//x7UCw+EQwHvng/uGvHv3DkdHR+j1er4KYqvV8oOpHve5sNHhMg58LKqkk5EWW9B1Znd3d76KG523UN6/vYbYoESbayRwOp3i6Ogopy7ZYi2cpJ4LPa8O4EwLtI6TEjoSMRvp1HvTdVF0gFlsQdeZMdXEVskKTZKxaDCJOtuqpo9pCXgbNd2WA6dOMIBH+5Bp1UvaXKuv6gTI7+lvoOmMXOPY7/c9udW+relO1nkrq1DY9DEdZ5SYhUjFc4gZbcP7p9rIPf/4Ph0rtj1GyemcMUWU79GJ43VSdaQttXgFSbyqZtbxiJFc6/wC+cpy+l2rSIZSyJ4DGwVnO6NCwy1JdE0hbc/gmqbaacEfEl32L6uEq5Jrlf9Ye1zXF7Wvx2wZst+2glk6L4bW4eheT8CDwqd93W6LQ1s2Gg0/t1olV1OJbWptWWJmg2z6/zJqz66Ihc
7HlpTpekiO6xznnXsoOMZ2StsyqMDAg6YwKjHbpIBKkQ20bW07IPBUhNTGosC0jhMkZbrelMsEQuM/8HjtJ98jPgVipmNEwnt88uSMA4ymxjEirM6bXdgecoLtsRWcvPldVVHu7+/92h6Wt9dJmY7mtjudHWBi5yhzXiVnsWILofSxIjLBv4vAiYcOp13bY9e86GT8XHuGBvAsy6eH6nnsIK/r9QiNoun+WDadUZUsGyywg3zMntZZIzFTB57OTIhQxAIST7GhEmbaEECuUhivUSPvjKprGolOouzb7N+6B1fR/luWWIUUilAgg9dI55KvLaGwdtzWxKT20W0K9P+8R45BWvWOaU1UJbU4kaa70gEu2sfMqo+x+yx6T8ku3w8RC/t7bcuW6vjaIivso7oGk2t3WJnR7rFJuzr3UDQotJdZaK3euvZYZM9Q2wx9Z5cOsrWnJWi0L/AQ9NJql9o2dV842p/jvbVhaO1oUSCQz6EAgg3uWRJXJoCzTYTmlNAyBP6uum5cAw5ajOr4+DinMtr94GyK8jqCVsYGVSQVIWKmSqR+zrZpnZcYONb5MzRXJyQoPllyppOcTcN7+/atfzDtyRKmTVI/Qg6zLsjnotu3b9/i+PgYFxcXODs78xFXDqYx5WpbeMpxeV+69osRdZbLZ0qjbjLNlCclZuuUMzuA839MKWKZd6YBkhDyt2N0NrS27qmwxMuSs9hn7XvaHqn08Npvbm58lVBNZ1RnOJa7rnaKpd3yf5xkVEGj86hVIG0ZbxKp59rQpuGoWqNpTqFIsVUgaUdWuByPx14FZ7VQBmB0k2lNVY45ckUBGf2NlUDEilhsk6BZoqvbHZyenmI6naLVavnroc20oqAqkHQs2L/ZDqg6skAS+zvtHFqLW5SGt+6+leCGlDJL0LYBVRroaOnWLavVytuN16DptlqdkWl4/D/bGccn3X6ARVTUjqE1dmrDTR1iGwwq851twDqxGnhhCm6WZV7xpZNLpZIVbrkWku1b9zHjsgNWAGYqnl3rGevLRWMlwf+VGfdsVsS2oeNhiOza9hIq/qGK+Wq18nM4nxmMCQWvivrzU9tT0dy5LxQRM52jGAxUu2u9AF3mwfnIZmXo/FIm0GLfL2Obos9UjRRXHfuy1ydLzgjrEDNlhxF2q/I8N0JLJ5iTLtMpmS9vF4Nbp7uKHckSNLWj7i8Tup91k2ToXIQSm5AqEUsh03Ns054h0rXucxZ6D+ui6qECCxYhZ8OqEkA+XccqE6GiDmrHbUBtwmtRx9f+7naiVCdIr58E0/ZtW3AhVuK9LDGztgUe+nooHU2fN40yl7ElnTXgYT2fOhVUAGyKk13Azutnf7JkPRQgiKWMhexUdO9Fav5zHcAyUKKtyhkLrCgBz7LMO75My9VCKlwDCSDnsIXsuK7ggtqgLDGLBbYOgZByxu1OOPYByCmVJGe0rRZT0SI/mgoa2guuyE4xYhZql3xfX9vP7Hqutu0zpuAwXVyVSK7RpcrLMYLpjLb4R1FabawdhhCbv8vMl/uwpz5C79HGvGZV1fXBOcxmSYTsSBtYlO2jVfULE56HT5acKaHgQMQUPEYydR2FLoDl4tennBN4SMXjug9utnt/f49Wq+UX1bdaLb/g3KZTVAXWjoy20Y5cb0ZbFqXqPCXyRrto+p1WNrTpf7sgFoqnDpJKJtSOLF6ha80Gg8GjiVMXZBeRC3tOgnYMkTK2/VARi20GDVR9tGu27OdCz3r/akem3qkdGWEPldXeVlSY96JO3yaT81OgNtFtIxjF1d9OSZw+8zsAcimh7E9Uztge+b51iJ9CdDVwwGvUZ9px3f1vC7Qj8LDxPN+fz+d+C4gsy3LbdpBAKOHVlFCqO9ysm+se1Y5cnxxziK3NYmSjDCkLOabbhHVulehmWeaVRfZ7ALlqo1qsgqqELf7BdbjarzUgWJQCu4lzXETK7BrkfWS7aD/W9FASXv2MpjAr6a3VarlS+XwdK6JSNK
+EUCZgELs3vrbv7RqW8OprQtNENfMgy7JH+8GVXe9oX1tsassQid+nHT8GHMJWnyw5A8KpZJqzzsE/VADhOc6UdVhsGWVdVK9rifjdKnUqex9cPGzLlNuJct2gtEkETh1yOkG65sVG/qqGovZgN/bVhe0xtccSmrI25URur+cpxOQ5UIdcnXL7Gft5Xhvbgq4jVRuq46ETpn6/aKIsskPsevm/fThvluAyksviALx2nlML1KjNLcm1WzloBTdd0xOKqMfUiCJioba0yqpVT3dhS22HmopH0AHmfdLOmtZEBVL7tWZmqJJrq+DF+rO14br3YyTX3quS8l2RNas2aNEavsfr06qXJLwkc9q31ZahteGhlNoYnjo/xPq2td+2bRkiZ7qRtHMP288wRZTFP5iCyz5O29k1j+uCmpu0zdDnrE2svWLtcNeEV9uqtk2trEy1nG1Vi3+o/cpmFOwbVfIlE/L4pMkZ8Fix0EXaGu0goXiuc68TpCU1PLc+dpVGtk2oDZXk2j1mVJ2IRdEtYvbW1AJL9ELpeE+J9u0TllTY/bjsuigluDEbhYhVmesIXde+bRYjX0XRQbWH2tEGPUJ7HsXWkIaCAPZ9+3mNqtvfZl/OWyjiTAdOrz12LbRHyI760Gqam0aEN7Vl0T3a1NZtwao9iuVy6ckZ8BBF1wI1HKfopNkAoBZI0qqrodTvUD/fJJgVso0WJbIqwbahv5cWTciyzBej4Fps5/JrIVlIhdsW6DxtN3cvWs9Zdjx7zv9jfWqbwQM7Jmo6o6pnJGf8n6bcqj01zXZdem1sHi0bLHjKvYZe7wKx4ITaV9VIBmDYVp17KLikxEwrMz4lcPoxYxcByo8Bnyw5s+qAqiwclHRzZH5Hn7d5HSRpen51Hrd93m1AbcjJUCdMTdOhvM97LYq2FTkjhKbhaafm36GJscqd3xJ1jQizvDadEFuZcVvnt1Anis6prYi4L7uuO4fty3Q2WPKdD1Ze1cImIdXMpjSGJtRQ+9S0rND/1HlS2+6KWPC8et0hoqskSNuiVmaMVV4tWtvzHMVH1R77vhIh2jJUGGYbsOOI3VqA1x0qUEOb2ErA3MuMfdymXq8juNtyhNWW6oDusk2qCqnVK3WuVcdXFUiSB9qPNlWiu2/VzN5bTEXbBaxqxqI1TBNlMFSL1VBV4zipa8OV5Ma2cNgHYiok73mXc05IMWM7tH2cKaGnp6e5uZtjZ6iK8nMDnkrM+XfRZ/m8z/m6LOy1cE10FbBuGdGu7fjJkjMiFlU7hGIQ+rGr1JFiUIcupliVteW6KPE6e4SUAutsVMmmIadL188pUacd16k8IXLxVKhqEHI+qwa1oQY71NnYVRn7GOxkr+1yH20yRG7WIRQwUkW/aP3hJniK7TWazWcAO2uTShJ1wraOhF6Xfk8DBnysS3naF6wysMtx0tqR4wqJ2XK5zPUJkm5+Rvu1Fq2IEbJ9qhIhdSykmm37nFbdIelluijbolbH1LQ8bZOx4lmxeWWT64z9BuuIRcyeuwbPpetxtTqnBodoYx0v1X5PGR+LbMb/P/f+EraDXdnykydnh4YdWDVSqJEafraqUHJmS4XrmrmnHtu+du5hw2cFB0xORNzHRasoVc2OIdsVOcPW6XiuQxwLCmh7pB1Z4EAJb1VgVXBdxxmqLGij6uvsWMbOVHQJS8h0TYi2yX3Ysoz6yHtQJVyrhaqCGyJoZRELLKy7fht0sVHtXRKLEEmLRbCVSGhKqN1o2qrhVo3bBbEIOfWhxy6VMwA5YsZ0MaaMAciNMSSzNqPAVl19TuGnTaL2RWTMKpH7IGhMuW00Gv6cNuWW6bbOOV8qn2twQ9VCNw1gKaGIqTuh99QuIbVsn2Oj/R1pt8Vi4dfnMZ2ZvsZyuXyk2mo6Y4yYFdlLr2mTNhwjtKFx8ZBzd9G5leweAjp/H9Jmnzw5086v5Wc58dMR4Ge3eV5VJUgm9KE7zFfNEY4hNMDsI+Jlo1yslMaH7t
u0zwjcJrBpdfoIRS11EKbTWEZdXIdQsEAJRcgRPrQ9rdqtlSbLKhTW+Q7Ze5PIp3V+bUpjiJhVxY5KckMq7nOVx1CbDR3HOhrq+OpYvU9n2F6/hW2HVr1V5Sy0tsw6sJs4bWXbqCVlatt9OCQ8P23I61Zyxv9rW7TrokIbTAPrHdqYndjvi+xo25cquCH1cdfjo/YHLVjD1FsdzzUrwwb+ygRadFwM9V37GogHMBQxIlHkKO8CdszmfGfTtjVt2c41RUsOYnO1HUvK2Gzdfdh7OtR8bc8VI48KVX0PgaJr3Jf9PllyFhoE1IEqSofbxEkLnVehAwDVCW5iSme4Ss6bRWwACl2vbdxFTtkm0N+QxIz590py1ZZVIBQhlCVmu4IGK2JqT6jkelVg7RdLsVXUajVfxfC57TE0oagtLUGrWntU9cwWVontHVWEGLko8z37Wu0UK/+/j3Fy3TxgAyxKbjd1hvWcRQGDdTYNzWF2XU1RauM27anHYoGamCOv/Zd2s3a0SniZdLCQY1yGBCsh498xUrHrQKAel9fDIiDOOV9dUK9rPB57RTxW6r2M0mPfKyJlITvq+yHHV22363RbC+sP6h587Ndst7YvF80zIWK2zmaKsmNmrE0e2oekXWPXWDWE2qWdl3Z53Z8sOQPyTtPJyQmazSYajQba7Tba7TZarZaXsrlg25ajfs55dUf5ZrOJbreLTqeDXq+HTqeDdrvtVZ8qOsJAfkCwRJMbXJIgMX0FeIiKhqLPZQchHbCtLfU35F4u6oBUFXaCt6muulZJbWjXmIUinHoOOwmEzklyy37BgIGmx1ShTRY5yPa+rX1Dk1aIGOh5aD/rLNuJR8cWq4yzX1cp1VZVhxC5tRFha0sgn4qiTkhIkQg5K/pdPYc6S1piXe14iDYZG7us4hMiuLH2qXMMkC8yo3bc9Br1+FbJ1cc+Mgx4bNpA7xF4SAtVcqtVGtUpDh2XdlJ1zipz+h09t61oGrp2IJyyvM6O27Ynj8e2r/vqWQWXKqMSs1iAIOY4W1Kr9lxHykJ2t/egQSz7eh8EQ/uIbt2gRYCyLItWBtVjaDvUNhVSd2Nz2KbXzudQNhjtuU9SFCI3NvMhVJCrCtBrP4QdP2lyBjwYng5wo9HIPXTip8Py3PNpx1Eyw00h1emoYnTdItTh7No5bdjrnAsd1Pi3HeCtTZTs6j4vSiZ0IXRV7ahOfSjNVgeHxWIRbBNKHtTJs5+xCClm+ju+tDTbogmB/VmdNwtLKNTZUBuvO2fIeYul4lUZNn2LrzU6r8TC3pPt+zFiRhT9dvaxazKxCUJOV1GGgSUTsfExRnSLbMjzAA/Bg1AaXlGmyC5hCfk6JTJEaPmstgwFT4Bicmvnmdhvxme1n7Wl2ngf9tT71uvXQAsQLnqmx+BxQtC2pn2e/wMQJb9FsESiSM3dJUJjN8drguvCLWHVY6j9bEAqZJ91JK3ofZ7T3oNti0Xq4y6Jbui6LNmxbZOvDwnrJ2pgvGid+DZtWZqcOeeOAPwqgN/OsuwPOOcuAfwNAN8C8K8A/OEsy262dmU7hnXq6/U6Op0OZrMZLi4uMBqNMBgMcHNzgyzLMB6PATxEhlXmLns+4CHCpdH0druNbreL8/NznJ2d4ezsDO122+9JUlVnWDs6I9q8n/l8jm63m9u8FgBms1lusNKIE2GjxDwXYVMeaMtGo4Fms4lOp4Nut4tut+vtqMpZ1ewIPI7YkajPZjO0Wi2/GJ77AHGC4PYEAHJELERCi5wN2kaDBc1mE61Wy6uQDBxooZoqOMTWqdJJiW1yNpuhXq9juVz6zaeBvNLDKDydnHXOsr0GPb/akqnKzWbTP7S4ShWDLzZQoET99PQUq9XKV4QD8koR/+ZzSLGwJC10fuBh+wEN+NCeVHNDQYOq2JFQZ4kFBnQDa9qaa6fY7lRhs45xkcMXOr/Od6rkqh0PoUKGCJptP/bz6uBpxTy1HxWjECGx5GSToKsNYhWp4v
uwoyWoem+r1So351rlW/uXjuu27cXstu657DXz99S2yYB1SIncFdQm+h6LgMznc4xGo9z4xD4FwO8pZ0mxLTpVRj3b1Lfka52D1I6hQPW++rf2V6qR9Xodzr3PerKZBFUjZ7x2ZoPtY87ZRDn70wB+A0Dvw9/fAfDLWZb9gnPuOx/+/vNbvbodw05WTGns9Xro9/ueKGVZhuFwCCC/sa9NJSs6j3YYdcLpAJNQdDodtFot7wRXlZgR1hlVUtHpdHJ7JK1WK0wmk1yEDyi2XUiZ4PskZrQlHTYlE+xMNtpRFXtap54TFNNC5/O5J2dcBO+c82lS9jh0jmljplOEFrmHIlmWTDAtlA5cVZUzG/VkwIV2XCwWaDabyLIst/cPHTmSXUZGdZLg5Mzz2AnE/oaccDS9Vx+h9aRVsGUscqxtYrVaodFoeDuw7UynU29L/R8dZPteETEDHpMJToyhzIaqFfuxDqf2LQYL2A41sk3HTtNJ2UaVaAB5BaQoDU+vxQZ/lJxZh0PteAh7htoir30+n/tAi7axo6MjH6xiH9Z1QlaBA8KpVOsUMz7Yf9WWoZTlfak+2q/0fRtooR1pH+DB+bTrSdVppr+jQcBNyVmISGiAkHZj/7YZRPsKZNHv0EAxbUCyyL3NtP9q+re2OWtDItbunkrM+MzfmvZbN+9s25baHm0gg0E99tlQynxVECK7++rjpciZc+7HAPx+AP8VgP/sw9s/A+CnPrz+RQC/ghdEzmxk9vj4GM1mE/P5HL1eD+fn5xgOh+h2u1gul7i7u8NqtcJ8PvcOHCdFThJAvFSsSsu6nqfRaOSUHiVn3OCQA8WhnY4YNOpFcrZYLNDpdHy53sFggCzLMJlMAOQVM50k1p2Hzzphc3K0a824ZpB7u+wzWrQpbKCA97NcLtHpdHKLuJ1znqRxAlBCoUSCNg45cJbMWEJBlUftWDVHWBEjFbRju90GgJzjyzVBltyyT9u1GDrB6r3baDr7uCq6JLk2XTqWRnQohCZTjlVZ9n69BZ0V9n0Afm0Q7abP2sdDDrGeG3hwjjQVVBUzXc9axSBWaIxaLpeo1+s+qMI5hDako8f5hQ4Lx0pNAYoVHSi6jpDDpoGXKiqQet3ap2kTVczYXrXKqC2qFHOWQ46zXkMoAKNKQCz4Ytem7BLWUWffCxHc09NTAA+qqwYC1FkuIhpW6SjjXFuSCzzYV/t6SM3dF8nlM+cR/tbL5TJHzADkfEFdI1lkM+BpClro3kPkTPu7FkdToruv8dK2PV4L5wSqZjZgUgWE/E0GsWIF+/R7z0VZ5ewvAfhzALry3hdZln0FAFmWfeWc+3wrV7Rn6IRF5+3Vq1e+sYzHY68E9ft91Go1DIdDv3eNRoOBBzVNCVUoms6iH69evcJnn33mH1dXVzg7O/MErWpOhwWvS6Nx3W4XtVoNn3/+uW+8WZah1Wp5FXI4HGI8HmM+n2M6neYmBjuI2YavaVZ0Lnq9nrfdq1evcHV1hfPzc7Tbbe8Mc6KsGthGAPjBK8synJ2d4fj4GNPpNDcYjEYjHB0dYTKZ4PT09NH+NNYhCS32VgeDgw6VXKbV0obn5+fodrtotVqVLlCjEyptSOLKaB2JZqPRwGQy8ZF43SuJNlTnriiFmXa0ygTTFy8vL9HpdHBxceH7vQYN9uW8rYMl68D7fZOILMvQaDS8Usa9urgXGosN6B4/1kFm6miof2u7tCT75OQEnU4HzWYT5+fn6PV6fpzU/n1oUmGVClbPa7Va/vqOj499mi3bHlNtZ7OZJ7m0HQMvmjZqHb+iawCQWytBR63Vavl+3e12fXtVUnEo8NrpwNXrda9OaBo77cdMAvZdBlG1/YVIWkgNUlilxrZNjabX63X0er1c6rKdd/ZhU1UtNEWPYyAJB/fu0u0yypCz2Po1fS90TXy2r7XPc47jXMM+v8+sDW17eq0MlgPwpJyBAp13bTsD8MhmfE3EiEmob4f+tjZVex4dHXkb2k
yiXY6Vdj5hH+Y8oEEWm/ZZJahNtY5Cs9nMpdxu25ZryZlz7g8AeJ1l2a85535q0xM4534OwM8BwDe/+c1Nv75TcBDTiBLXTHU6HZ/WuFwuvYLGCYDRJp0AeEw+64M/oE6MnU4nV51RB6KqOBvroPfLiAyjxK1WC7PZDNPpFIPBAABy6aE2qk672ii7EkCNYHDQUcWMdqTTpmTiJSmQbIur1QqdTscPaLPZDLVaDdPpNHcfSrY06m4X0dt2SdvQVlxjpmvNVMm1k2NVbKkOidqRKiAAtNttH1Fmegr7tDpcMYKmjoh1MtTxVsUulGZrnbYq9W+1o9pwuVyi2Wx6B5kOnrah+Xzu7WedPbU58HiNj9pBVW46Egyesa8rKbNrIA8NdUqoWGsamdqBNuM6NAapdO8uTdGzjp86NTGSq78l56Dj42NvT67JDW0wfwh7WnKhdjw9PfXXpM6nzs3qLFtyVmS/IpIbIhK0p1Yk1vn7kGnLtCHtlGUPFQcBeCW2Vqt5m2nJ+DKENqYKhRRInXtjdmXbo3Km2Qf7DlLbuRVAri/zfc4nob75XNWsKBgYex2zqY6Vh1hTavuw2s76KFWEVctpx1328TLK2U8C+IPOuf8IQANAzzn3VwF87Zz78oNq9iWA16EvZ1n2XQDfBYBvf/vblbM+De6cQ6PRQK1Ww+XlpR/4F4sFer2eT4viBDYej72joqWRuW5AJxYOLiw6Uq/X8fnnn+P8/BxffPEFvvGNb+Dq6gqff/45zs7O0O12vbOspKKq0M5HJ4oDPaNLtVoN9/f3ODo6Qr/fR6vVwt3dnV+TpqWRNUKsx9f0OxZLocpzcXGBL774AmdnZ/6Zyhmdkao4byGoY0oHhPYDkEvTHA6HqNfrGI/HuL+/x3A49OqFOimMIFtoBJLOLwlYp9PB+fk5Op0OvvzyS3S7XVxeXqLX66HVavlBqUqEglDniYoF+zajXJPJxBP46XSKdruN6XSK8Xjst8ugihFSL3QCYURao2eaFsrf6/LyEu12G5eXl769WrWnKlDHk+vLNOV1Npuh0WhgNpthMpl49Zt9mKm3VIE0Rc86MfacHEPU8dWqqxw7tdiPbjdSlbRl3g/wUGgHQK4i72KxQLvdzpU212I1tJ2uAaItAeSCL6FzW2VcHV9Vyakid7tdTzKqsBbSRt01wMn1ozYIYPtrzGm2pAMIE4oYKbNtVdso5yUdKw9J0GywmESWRX1U8SkiZTofr1POYtei16HXpuRH5yWSCl2Ltk9b6vVpgJh9KJbCGFPJypCwIuUsdM8homYDXZwPNfCrY8Qubal243MoS6rKxAyIr93UcWDbtlxLzrIs+3kAP//hAn8KwH+eZdkfc8791wB+FsAvfHj+21u7qj1CG/Tx8bFP3Wm321gul56o3d/f+0EkyzKMRiOvYDClTOVZ/ngaWeMk2Gq18Nlnn+H8/Nyn4F1cXORSx6pOJizUIWFkqdVq+UFrOp16507XNmhqmU0rI1TZVEm5Xq97InFxcYGrqyv0ej2/dk/X9lRNobCw0WJOAM45tNvtXKSJhG00GuWqEdJR1r2UlFRoeq11fBlBp5LbbrdxdnZWWKCmirbUCZUTAtMHqViwXzKdlrYLkTONxFsFkufgMe06M67pOT8/94qPpujYlKkqQe+PYPTdfVC02H7m87lfHK9EQ9fnhlL0iJjDq21UAwiqnO0rRecpsMECdfC5xocpt+yv3FdT5xNLbEPkgudTB4L21Igv5yNd92zX7h06pVHvB3hQfmzfpvKtio9d96POs3WgY+Qi5ORq+7LE166L1GBNlRRIgtfE/lxEXovUn9jr2LXY55DSw2f2F7s2d9+21HPZjBHa77kKWZm/Q9cTei/Wdi2R2JdvGbOftd1LQGh8tWPDNvGcfc5+AcAvOef+BIB/DeAPbeeSDgOdyFqtln9vMpmgXq9jMpnk1k8NBgM453ypeFYrY6oeHRim3pHwnZ+fo9Vq4csvv8T5+T
k+//xzr5gxX70q0ctNYMlZrfawbw+j8M1m00fj6dyNx2Ocnp76SLxG3QmbekcHrdFo4OLiwitnV1dXXunRdJ2XQnTVhs453w4A5KrSMTDA9ZCtVsuvAVL1gkqkkjOd8EgiVJUgOWu1Wt6OtDXbc5XtqJE6tScLfKhDynU/JGcMtNg1LHSW6biwTyop42/DSC+JGZUePmvVy6q2Sb0mtSNJxfHxsV+vwnZGNcgGBuzaM0vOLJFQ8qdRatpUAzOaile1sdISC/7evKfVauVJhWYM2HU/NiXPKmihyLr+ZpacWeLLh9p7n85bEaxDT6JB+4UIWIiMFRGNWFpjiDyESIR1fDlXVYGcqe30+tUxLrJNGbJRREDsNYSux14bH2y3VQhiaV/m39YG6whYyC5l3yu6pqL3Y21X72MfCNkPiK9PrDKK2uy2sRE5y7LsV/C+KiOyLHsH4Pdt/YoOADupUZlgVKndbiPLstz6sMFggLdv3/q0PKoWdN50fRkXW5OcdTodfPHFF+h2u7i6usLV1ZXf6yzkcLwUsPNrKg8nfOccOp0OTk5OMBgMcH5+jouLC4zHY/T7fcxmM4xGI0wmE+8ME3TONGWE6YokuywQQKWCqWVVSncqA9pQHToAj9K7Wq2WTysbjUY55YzFLegscxBkmyI50Ig508M0fbLb7eZSdKrgbJSBTgZ8TVuyYl69XvepZZbQMkBA+6lyxoCDOg4a3aVKp9Wx2G5DVRqraktrQ07qulZPiRftFFrvE1I0VPEJEQhLZqj4qPJDZ9gqA1WBncjZr+kYM1PDEo2i9VEh5ceez/5mllDYVEebolOlNmlJro28A4+Lo5QhGUUEI+bg2r9jZM2mjB3aljbYsu7+yyo+68hH0bXECBufrW1D39snQjYEnkbAnvKZstcW+1+oDe8TocBACFUja2VJ8C7wHOXso4NOXpqa55zD+fm5/xwjxc45jMdjDAYDX62MRS1IJpScdbtdXFxcoNVq4erqKre+h5+t2uL2TaGDqtpR03g0tUwrDrbbbW9HJWeM7jLlk0S3Xq/j7Owstx0BHe9D5/s/ByFyoYMb1TQSeaZFsfoRi7DwNQe8GDnTtqp7w1kyUXXVTBFzSLRvaZtUxUcJmabm0Qnkb6HKmSUQuviaFQ6rUnChLLTtqeOq/VvTRZWIWaKh5AzIb5xsSUOIQKgaoeTW9pGqwV6T9mM+h5SM0BqW2LoWey7r1Nq/rRpqyUQVYdui2pHEbRPSsU7xCREGPhcRtX1E1J+KkO2U7BKbEo5NHOqQPcoQtdh394l1117WDrsmIGVtfAisO/ehf+Oy2Md1JnImoME1kuuc86lKvV4PFxcXuLy8xHg8xs3NDUajkd9kmRF4OhBcI8F0RW5w3Wg0cH5+7lMdGVmvaorOpuBkaVM8mI40nU5xcXHhFR8+TyYTTCaTR+RM15zonmZUekgkSDRYgOQlqo8EnSe+VhLAFFElFdykmqRC1R9OBnTCaE9NlaTio6SCv99LbZPqhPBvOma2MplVd2KKjx47RiDseyFV4qXYMuQkqSNsVYuYumPXSKkttW2F8vhD6s9LtKN1hNm/Y2RhU1JRhkjo6yo6wEWw11hEEtaRik0d5CLnNmS7qtqx6BoPQS6K7FRVGwLrCdBTsakC+VLx0q9/H0jkzEAHWjrEwIOCRkeLig/TGrXiIB0xEgeuNWHBBb6nJXhVmXgpDkcReP10pOjoM0pMoqFrV1TxseSM5I4klkSMxNlusPix2FCjm5rGAzxUCTw5OfHpeiRkTDdTUkFyFlJ8LJFWcvESiZkiFGEHHuyxWj1s8EsCEXpYR9gqD6GHBnleGqGwsO1Rya8qk5ZMxNLL9Lhl1IdQitNLtaP9uyhNat17RccPEbai55cEq1qE7jtEXvU7TznXpu+9FKy79lDa577O/akg2SGBSOQsADpeWZbl1CxN+5rP5zg7O/MFQTStkY4Z1R2WlmcFNyUVfNCJ4/k/BqgdNRWJ63
1YIETXR+naFcKu59FUOyVkmq7zUh03i5j6o2lQWu44lFpmj6WL/lUltkSs6iljm8A6biEViDZeV6lMj6G2Ah5vWmvP97HY0RK1dWlR60hFGdLwsdiQWEcmLJ6iWMRs9bHYkHjqfZax6cdmq6ci2SEhYX9I5CyCULSWZIBlo1lMgEUY6Bhr+hgruGnKGIsIvJR8/+dAHVVdZ0HFJ8syv9myrVBGhFLGQgqFpgF+bLbU9sjnGIEIrU2xx7Ak1truY3OEiXXkQokv/7ZkY1M14mOzIWHJRSyyvm6NSkhJKjrXx4yi+1QbP+c4CQ9IdkpISKgiEjkrQMjJIhmgesM0skajkSulryliVt15qetPngMlBnRwNc0xtn+PkocQma1KNad9IeYQ2/UoofUpVskJ2e1TsSNQTALKrE8JKR9Fx/7YUaReFKWXlTlGwnsk+yQkJCR8/EjkbA2UVGiKHlMebWqZOsJl1px8SpOtTSezzyFCod8LpYnx708NofZTdsF7IhKPUdYmdm1LQjkkmyUkJCQkJJRDImclEXLeGBFmUQEgr1LEHqHjfWoI2TNULMB+/lMnZUXYxB7Jdk9DsltCQkJCQkLCLpHI2YYIqT/2tf28PifkESNbZdOgEsJINktISEhISEhIeHlI5OwZWOcAJwf56Ui2S0hISEhISEhI+NSQyNmWkMhEQkJCQkJCQkJCQsJz4La52/vakzn3BsAQwNu9nTQhoTxeIbXNhOoitc+EqiK1zYSqIrXNhKri38iy7LPQP/ZKzgDAOferWZZ9e68nTUgogdQ2E6qM1D4TqorUNhOqitQ2E14iaoe+gISEhISEhISEhISEhIREzhISEhISEhISEhISEiqBQ5Cz7x7gnAkJZZDaZkKVkdpnQlWR2mZCVZHaZsKLw97XnCUkJCQkJCQkJCQkJCQ8RkprTEhISEhISEhISEhIqAD2Rs6ccz/tnPsXzrnfdM59Z1/nTUggnHN/xTn32jn3f8l7l865f+Cc+38/PF/I/37+Q3v9F865/+AwV53wKcA59w3n3P/inPsN59yvO+f+9If3U/tMOCiccw3n3D9yzv2fH9rmf/nh/dQ2EyoB59yRc+7/cM793Q9/p7aZ8KKxF3LmnDsC8N8C+A8B/A4Af9Q59zv2ce6EBMH/AOCnzXvfAfDLWZb9BIBf/vA3PrTPPwLg3/rwnf/uQztOSNgFFgD+bJZl/yaA3wPgT35og6l9JhwaUwC/N8uyfwfA7wTw086534PUNhOqgz8N4Dfk79Q2E1409qWc/W4Av5ll2b/MsmwG4K8D+Jk9nTshAQCQZdn/CuDavP0zAH7xw+tfBPAfy/t/PcuyaZZlvwXgN/G+HSckbB1Zln2VZdk/+fD6Hu8djR9Fap8JB0b2HoMPf558eGRIbTOhAnDO/RiA3w/gv5e3U9tMeNHYFzn7UQDfk7+//+G9hIRD44ssy74C3jvIAD7/8H5qswkHgXPuWwB+F4B/iNQ+EyqAD2lj/xTAawD/IMuy1DYTqoK/BODPAVjJe6ltJrxo7IucucB7qUxkQpWR2mzC3uGc6wD4mwD+TJZl/aKPBt5L7TNhJ8iybJll2e8E8GMAfrdz7t8u+Hhqmwl7gXPuDwB4nWXZr5X9SuC91DYTKod9kbPvA/iG/P1jAH6wp3MnJBTha+fclwDw4fn1h/dTm03YK5xzJ3hPzP5almV/68PbqX0mVAZZlt0C+BW8X6+T2mbCofGTAP6gc+5f4f1ymd/rnPurSG0z4YVjX+TsHwP4CefcjzvnTvF+Qebf2dO5ExKK8HcA/OyH1z8L4G/L+3/EOVd3zv04gJ8A8I8OcH0JnwCccw7AXwbwG1mW/UX5V2qfCQeFc+4z59z5h9dNAP8egP8bqW0mHBhZlv18lmU/lmXZt/Der/yfsyz7Y0htM+GF43gfJ8mybOGc+1MA/j6AIwB/JcuyX9/HuRMSCOfc/wTgpwC8cs59H8B/AeAXAPySc+5PAPjXAP
4QAGRZ9uvOuV8C8M/xvpLen8yybHmQC0/4FPCTAP44gH/2YW0PAPwFpPaZcHh8CeAXP1S1qwH4pSzL/q5z7n9HapsJ1UQaNxNeNFyWpXTbhISEhISEhISEhISEQ2Nvm1AnJCQkJCQkJCQkJCQkxJHIWUJCQkJCQkJCQkJCQgWQyFlCQkJCQkJCQkJCQkIFkMhZQkJCQkJCQkJCQkJCBZDIWUJCQkJCQkJCQkJCQgWQyFlCQkJCQkJCQkJCQkIFkMhZQkJCQkJCQkJCQkJCBZDIWUJCQkJCQkJCQkJCQgXw/wNxlBC0dD8RVgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 1080x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
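"# For each result in res, sum r[0] over dim 1 and reshape it to an N x N image;\n",
"# concatenate the images side by side (dim 1) and show them in inverted grayscale.\n",
"pyplot.figure(figsize=(15,5))\n",
"pyplot.imshow(torch.cat([r[0].sum(1).view(N, N).cpu() for r in res], 1), cmap=pyplot.cm.gray_r)"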
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1+"
}
},
"nbformat": 4,
"nbformat_minor": 2
}