# Draw two chains, called wheat and barley, from a 2D unit multivariate Gaussian (zero mean, identity covariance)

In [1]:
import numpy as np
import matplotlib.pyplot as plt


# Fix the global NumPy RNG so every draw in this notebook is reproducible under
# Restart & Run All (the truth sample further down also uses np.random).
SEED = 42
np.random.seed(SEED)

# Target distribution for both chains: 2D Gaussian with zero mean and identity covariance.
num_samples = 20000
mean = np.zeros(2)
cov = np.eye(2)

# Two independent chains of `num_samples` draws each from the same distribution.
wheat_chain = np.random.multivariate_normal(mean, cov, num_samples)
barley_chain = np.random.multivariate_normal(mean, cov, num_samples)

# 'Harvest' the two chains

In [2]:
from CombineHarvester import Harvest


In [3]:
# Instantiate one Harvest per chain, each with its own output path.
# n_flows differs between the two (1 for wheat, 2 for barley); presumably the
# number of flows trained on that chain -- confirm against CombineHarvester docs.
Wheat = Harvest('output/Wheat', chain=wheat_chain, n_flows=1)
Barley = Harvest('output/Barley', chain=barley_chain, n_flows=2)


In [None]:
# Train the flows for each Harvest. This is the expensive step of the
# notebook; progress is reported in the cell output.
Wheat.harvest()
Barley.harvest()

Training the flows


 10%|██████▌                                                          | 10/100 [00:10<01:05,  1.37it/s, train=2.8494112, val=2.856931]

# 'Combine' the two chains

In [None]:
from CombineHarvester import Combine

# Bundle the two trained Harvest objects into a single Combine object.
Grain = Combine(Wheat, Barley)

In [None]:
# combine() yields a pair of weight arrays; below they are used as per-sample
# weights for the wheat and barley chains respectively.
wheat_weights, barley_weights = Grain.combine()

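# Let's compare the two sets of weighted chains to the truth: the product of the two unit Gaussian densities, $\mathcal{N}(0, I)\,\mathcal{N}(0, I) \propto \exp(-\lVert x \rVert^2)$, is a normal distribution with mean 0 and $\mathrm{cov} = \mathrm{diag}(1/2, 1/2)$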

In [None]:
# Reference sample from the analytic answer: the product of two N(mean, cov)
# densities with identical mean and covariance is N(mean, cov / 2).
truth_chains = np.random.multivariate_normal(mean, cov / 2, num_samples)

In [None]:
import getdist.plots as gdplt
from getdist import MCSamples

# Wrap each chain for getdist: the harvested chains carry the weights
# returned by combine(), the truth sample is unweighted.
weighted_wheat = MCSamples(samples=wheat_chain, weights=wheat_weights)
weighted_barley = MCSamples(samples=barley_chain, weights=barley_weights)
truth = MCSamples(samples=truth_chains)

plot_roots = [weighted_wheat, weighted_barley, truth]
plot_labels = ['Weighted Wheat', 'Weighted Barley', 'Truth']

# Overlay all three as unfilled contours on a single triangle plot.
g = gdplt.get_subplot_plotter()
g.triangle_plot(plot_roots, filled=False, legend_labels=plot_labels)
plt.show()

# You can also save and reload the trained flows

In [None]:
# Persist the trained flows of both Harvests; they are reloaded below
# without calling harvest() again.
Wheat.save_models()
Barley.save_models()

In [None]:
# Fresh Harvest objects with the same output paths, chains and n_flows as the
# originals -- presumably required so load_models() finds the saved flows.
Wheat_2 = Harvest('output/Wheat', chain=wheat_chain, n_flows=1)
Barley_2 = Harvest('output/Barley', chain=barley_chain, n_flows=2)

In [None]:
# Restore the previously saved flows instead of retraining.
Wheat_2.load_models()
Barley_2.load_models()

In [None]:
# Combine the reloaded Harvests exactly as the originals were combined.
Grain_2 = Combine(Wheat_2, Barley_2)

In [None]:
# Recompute the per-chain weights from the reloaded flows.
wheat_weights_2, barley_weights_2 = Grain_2.combine()

# Sanity check for the save/reload round trip: compare against the weights from
# the originally trained flows. Displays True when both pairs agree within
# np.allclose's default tolerance. Non-fatal on purpose; assumes combine() is
# deterministic given the flows -- TODO confirm.
np.allclose(wheat_weights_2, wheat_weights) and np.allclose(barley_weights_2, barley_weights)

In [None]:
import getdist.plots as gdplt
from getdist import MCSamples

# Same triangle plot as before, but weighted with the output of the reloaded
# flows; it is expected to match the previous figure.
weighted_wheat_2 = MCSamples(samples=wheat_chain, weights=wheat_weights_2)
weighted_barley_2 = MCSamples(samples=barley_chain, weights=barley_weights_2)
truth = MCSamples(samples=truth_chains)

reloaded_roots = [weighted_wheat_2, weighted_barley_2, truth]
reloaded_labels = ['Weighted Wheat', 'Weighted Barley', 'Truth']

g = gdplt.get_subplot_plotter()
g.triangle_plot(reloaded_roots, filled=False, legend_labels=reloaded_labels)
plt.show()