# Testing Predictions of the MILD Variants

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

# Publication-style figure defaults: bold titles, axis labels and text at 15 pt.
plt.rcParams.update({
	"axes.titleweight": "bold",
	"font.weight": "bold",
	"axes.labelweight": "bold",
	"font.size": '15',
})

import torch
import os

from mild_hri.utils import *
from mild_hri.dataloaders import *

from scipy.stats import *

import pandas as pd
from statsmodels.stats.anova import AnovaRM
import statsmodels.api as sm
import pingouin as pg
import pylab as py

In [None]:
# `bp_hh_20hz_3joints_xvel` MSE logs: one box per (method, action), plus
# Mann-Whitney U significance brackets between every pair of methods.
bp_mse = np.load('/home/vignesh/playground/buetepage-phri/logs/mse/bp_hh_20hz_3joints_xvel.npz', allow_pickle=True)['arr_0']
mild_mse = np.load('../logs/mse/bp_hh_20hz_3joints_xvel.npz', allow_pickle=True)['arr_0']

# Row 0 is the Bütepage et al. baseline; the following rows come from the MILD file.
# Entries are indexed as mse[method, action] below; presumably each holds per-sample
# errors (TODO confirm against the script that writes these .npz files).
mse = np.concatenate([bp_mse[None], mild_mse], axis=0)
n_methods = len(mse)
n_actions = 4
width = 1/(n_methods+1)
spacing = width+0.05
# Offset from action index `a` to the middle of that action's group of boxes,
# used to centre the x tick labels under the group.
group_center = (n_methods-1)*spacing/2
colors = get_cmap()(np.linspace(0.1,1,n_methods))
boxplot_kwargs = dict(
	showfliers=False, widths=[width], patch_artist = True,
	medianprops = dict(color = "black", linewidth = 1.5),
	vert=True,
)
legend_props = { "size": 15, "weight":"bold" }

hatches = ['//','*', '+', 'x', 'o', 'O']
actions = ['Waving', 'Handshake', 'Rocket\nFistbump', 'Parachute\nFistbump']
methods = ['Bütepage et al. [10]', 'MILD']
# methods = ['Bütepage et al. [10]', 'MILD v1', 'MILD v2.1', 'MILD v2.2', 'MILD v3.1', 'MILD v3.2']

fig1 = plt.figure(figsize=(n_methods*4, 6))

ax_box = fig1.add_subplot(1,1,1)
boxes = []  # one legend handle per method
box_y = []  # box_y[i][a]: end of the upper whisker (whiskers[1]) for method i, action a
for i in range(n_methods):
	box_y.append([])
	for a in range(n_actions):
		# Values are scaled by 100 in place; the tests below use the scaled values
		# (Mann-Whitney U is rank-based, so the scaling does not affect p-values).
		mse[i, a] = np.array(mse[i, a])*100
		box = ax_box.boxplot(mse[i, a], positions=[a+i*spacing], boxprops = dict(facecolor=colors[i]), **boxplot_kwargs)
		box["boxes"][0].set(hatch=hatches[i])
		box_y[-1].append(box['whiskers'][1].get_ydata()[1])
		if a==0:
			boxes.append(box["boxes"][0])

# Pairwise Mann-Whitney U test per action: bracket + '*' (p < 0.05) or '**' (p < 0.01).
for i in range(n_methods):
	for j in range(i+1, n_methods):
		for a in range(n_actions):
			res = mannwhitneyu(mse[i][a], mse[j][a])
			if res.pvalue >= 0.05:
				continue
			ymax = max(box_y[i][a], box_y[j][a])
			x_i, x_j = a+i*spacing, a+j*spacing
			ax_box.plot([x_i, x_i, x_j, x_j], [ymax+0.3, ymax+0.5, ymax+0.5, ymax+0.3], 'k-')
			# Centre the stars over the bracket (midpoint of boxes i and j).
			stars = '**' if res.pvalue < 0.01 else '*'
			ax_box.text((x_i+x_j)/2, ymax+0.7, stars, horizontalalignment='center', verticalalignment='center')

ax_box.set_ylabel('Mean Squared Error (cm)')
ax_box.set_xticks(np.arange(len(actions))+group_center, actions)
ax_box.tick_params(axis='x', colors='black')
ax_box.tick_params(axis='y', colors='black')
ax_box.yaxis.label.set_color('black')
ax_box.set_ylim(-0.1, 10.5)
ax_box.set_xlim(-0.3, len(actions)-0.3)
ax_box.legend(boxes, methods, loc='upper center', ncol=n_methods, prop=legend_props)
plt.tight_layout()
plt.savefig('logs/plots/mse_bp_hh_20hz.pdf')
plt.show()


In [None]:
# `nuisihh` / `nuisiv2` MSE logs: one box per (method, action), plus
# Mann-Whitney U significance brackets between every pair of methods.
bp_mse = np.load('/home/vignesh/playground/buetepage-phri/logs/mse/nuisihh_3joints_xvel.npz', allow_pickle=True)['arr_0']
mild_mse = np.load('../logs/mse/nuisiv2_3joints_xvel.npz', allow_pickle=True)['arr_0']

# Row 0 is the Bütepage et al. baseline; the following rows come from the MILD file.
# Entries are indexed as mse[method, action] below; presumably each holds per-sample
# errors (TODO confirm against the script that writes these .npz files).
mse = np.concatenate([bp_mse[None], mild_mse], axis=0)
n_methods = len(mse)
n_actions = 4
width = 1/(n_methods+1)
spacing = width+0.05
# Offset from action index `a` to the middle of that action's group of boxes,
# used to centre the x tick labels under the group.
group_center = (n_methods-1)*spacing/2
colors = get_cmap()(np.linspace(0.1,1,n_methods))
boxplot_kwargs = dict(
	showfliers=False, widths=[width], patch_artist = True,
	medianprops = dict(color = "black", linewidth = 1.5),
	vert=True,
)
legend_props = { "size": 15, "weight":"bold" }

hatches = ['//','*', '+', 'x', 'o', 'O']
actions = ['Waving', 'Handshake', 'Rocket\nFistbump', 'Parachute\nFistbump']
methods = ['Bütepage et al. [10]', 'MILD']
# methods = ['Bütepage et al. [10]', 'MILD v1', 'MILD v2.1', 'MILD v2.2', 'MILD v3.1', 'MILD v3.2']

fig1 = plt.figure(figsize=(n_methods*4, 6))

ax_box = fig1.add_subplot(1,1,1)
boxes = []  # one legend handle per method
box_y = []  # box_y[i][a]: end of the upper whisker (whiskers[1]) for method i, action a
for i in range(n_methods):
	box_y.append([])
	for a in range(n_actions):
		# Values are scaled by 100 in place; the tests below use the scaled values
		# (Mann-Whitney U is rank-based, so the scaling does not affect p-values).
		mse[i, a] = np.array(mse[i, a])*100
		box = ax_box.boxplot(mse[i, a], positions=[a+i*spacing], boxprops = dict(facecolor=colors[i]), **boxplot_kwargs)
		box["boxes"][0].set(hatch=hatches[i])
		box_y[-1].append(box['whiskers'][1].get_ydata()[1])
		if a==0:
			boxes.append(box["boxes"][0])

# Pairwise Mann-Whitney U test per action: bracket + '*' (p < 0.05) or '**' (p < 0.01).
for i in range(n_methods):
	for j in range(i+1, n_methods):
		for a in range(n_actions):
			res = mannwhitneyu(mse[i][a], mse[j][a])
			if res.pvalue >= 0.05:
				continue
			ymax = max(box_y[i][a], box_y[j][a])
			x_i, x_j = a+i*spacing, a+j*spacing
			ax_box.plot([x_i, x_i, x_j, x_j], [ymax+0.3, ymax+0.5, ymax+0.5, ymax+0.3], 'k-')
			# Centre the stars over the bracket (midpoint of boxes i and j).
			stars = '**' if res.pvalue < 0.01 else '*'
			ax_box.text((x_i+x_j)/2, ymax+0.7, stars, horizontalalignment='center', verticalalignment='center')

ax_box.set_ylabel('Mean Squared Error (cm)')
ax_box.set_xticks(np.arange(len(actions))+group_center, actions)
ax_box.tick_params(axis='x', colors='black')
ax_box.tick_params(axis='y', colors='black')
ax_box.yaxis.label.set_color('black')
ax_box.set_ylim(-0.1, 15)
ax_box.set_xlim(-0.3, len(actions)-0.3)
ax_box.legend(boxes, methods, loc='upper center', ncol=n_methods, prop=legend_props)
plt.tight_layout()
plt.savefig('logs/plots/mse_nuisi_hh.pdf')
plt.show()


In [None]:
# `nuisipepper` / `nuisiv2_pepper` MSE logs: compare the Bütepage et al. baseline (row 0)
# with the MILD variants, compute per-action statistics, and draw pairwise
# Mann-Whitney U significance brackets.
bp_mse = np.load('/home/vignesh/playground/buetepage-phri/logs/mse/nuisipepper_3joints_xvel.npz', allow_pickle=True)['arr_0']
mild_mse = np.load('../logs/mse/nuisiv2_pepper_3joints_xvel_old.npz', allow_pickle=True)['arr_0']
mse = np.concatenate([bp_mse[None], mild_mse], axis=0)
n_methods = len(mse)
n_actions = 4

width = 1/(n_methods+1)
spacing = width+0.05
# Offset from action index `a` to the middle of that action's group of boxes,
# used to centre the x tick labels under the group.
group_center = (n_methods-1)*spacing/2
colors = get_cmap()(np.linspace(0.1,1,n_methods))
boxplot_kwargs = dict(
	showfliers=False, widths=[width], patch_artist = True,
	medianprops = dict(color = "black", linewidth = 1.5),
	vert=True,
)
legend_props = { "size": 15, "weight":"bold" }

hatches = ['//','*', '+', 'x', 'o', 'O']
actions = ['Waving', 'Handshake', 'Rocket\nFistbump', 'Parachute\nFistbump']
# methods = ['Bütepage et al. [10]', 'MILD']
methods = ['Bütepage et al. [10]', 'MILD v1', 'MILD v2.1', 'MILD v2.2', 'MILD v3.1', 'MILD v3.2']

fig1 = plt.figure(figsize=(n_methods*4, 6))

ax_box = fig1.add_subplot(1,1,1)
boxes = []  # one legend handle per method
box_y = []  # box_y[i][a]: end of the upper whisker (whiskers[1]) for method i, action a
for i in range(n_methods):
	box_y.append([])
	for a in range(n_actions):
		mse[i, a] = np.array(mse[i, a])
		box = ax_box.boxplot(mse[i, a], positions=[a+i*spacing], boxprops = dict(facecolor=colors[i]), **boxplot_kwargs)
		box["boxes"][0].set(hatch=hatches[i])
		box_y[-1].append(box['whiskers'][1].get_ydata()[1])
		if a==0:
			boxes.append(box["boxes"][0])
# Highest whisker over all boxes; every action's brackets are stacked above it.
ymax_ = max(max(row) for row in box_y)

# Method rows entering the repeated-measures ANOVA. Row 1 ('MILD v1') is left out,
# as in the original analysis. NOTE(review): confirm why v1 is excluded.
anova_methods = [0, 2, 3, 4, 5]
# Per action: (RM-ANOVA "MSE\tF\tp" summary, pingouin sphericity result,
# scipy one-way ANOVA over all methods). Kept for inspection; not printed.
anova_stats = []
# pvalues[a*(n_methods-1)+i, j-1] holds the Mann-Whitney U p-value of method i vs j (i < j).
pvalues = np.zeros((n_actions*(n_methods-1), n_methods-1))
for a in range(n_actions):
	n_samples = len(mse[0, a])
	df = pd.DataFrame({
		'MSE': np.concatenate(mse[anova_methods, a]),
		'participant': np.tile(np.arange(n_samples), len(anova_methods)),
		'method': np.repeat(np.arange(len(anova_methods))+1, n_samples),
	})
	anova = AnovaRM(data=df, depvar='MSE', subject='participant', within=['method']).fit().anova_table
	summary = f"MSE\t{anova['F Value']['method']:.3e}\t{anova['Pr > F']['method']:.3f}"
	spher = pg.sphericity(data=df, dv='MSE', subject='participant', within='method')
	anova_stats.append((summary, spher, f_oneway(*mse[:, a])))

	# Brackets start just above ymax_ and climb by 10% per significant pair.
	ymax = ymax_
	for i in range(n_methods):
		for j in range(i+1, n_methods):
			res = mannwhitneyu(mse[i][a], mse[j][a])
			row, col = a*(n_methods-1)+i, j-1
			print(row, col, f'{res.pvalue:.3e}')
			pvalues[row, col] = res.pvalue
			if res.pvalue < 0.05:
				x_i, x_j = a+i*spacing, a+j*spacing
				ax_box.plot([x_i, x_i, x_j, x_j], [ymax*1.05, ymax*1.1, ymax*1.1, ymax*1.05], 'k-')
				# Centre the stars over the bracket (midpoint of boxes i and j).
				stars = '**' if res.pvalue < 0.01 else '*'
				ax_box.text((x_i+x_j)/2, ymax*1.15, stars, horizontalalignment='center', verticalalignment='center')
				ymax += 0.1*ymax

# Upper-triangular p-value table per action: '--' below the diagonal, '0.' for
# p < 0.001, otherwise three decimals; every cell is tab-terminated.
pvalues[pvalues<0.001] = 0.
s = ''
for a in range(n_actions):
	for i in range(n_methods-1):
		cells = ['--']*i
		for j in range(i+1, n_methods):
			p = pvalues[a*(n_methods-1)+i, j-1]
			cells.append('0.' if p == 0 else f'{p:.3f}')
		s += ''.join(c+'\t' for c in cells) + '\n'
print(s)
ax_box.set_ylabel('Mean Squared Error (radians)')
ax_box.set_xticks(np.arange(len(actions))+group_center, actions)
ax_box.tick_params(axis='x', colors='black')
ax_box.tick_params(axis='y', colors='black')
ax_box.yaxis.label.set_color('black')
ax_box.legend(boxes, methods, loc='upper center', ncol=n_methods, prop=legend_props)
plt.tight_layout()
plt.savefig('logs/plots/mse_nuisi_pepper.pdf')
plt.show()


In [None]:
# `bp_pepper_20hz` MSE logs: compare the Bütepage et al. baseline (row 0) with the
# MILD variants, compute per-action statistics, and draw pairwise Mann-Whitney U
# significance brackets.
bp_mse = np.load('/home/vignesh/playground/buetepage-phri/logs/mse/bp_pepper_20hz_3joints_xvel.npz', allow_pickle=True)['arr_0']
mild_mse = np.load('../logs/mse/bp_pepper_20hz_3joints_xvel_old.npz', allow_pickle=True)['arr_0']
mse = np.concatenate([bp_mse[None], mild_mse], axis=0)
n_methods = len(mse)
n_actions = 4

width = 1/(n_methods+1)
spacing = width+0.05
# Offset from action index `a` to the middle of that action's group of boxes,
# used to centre the x tick labels under the group.
group_center = (n_methods-1)*spacing/2
colors = get_cmap()(np.linspace(0.1,1,n_methods))
boxplot_kwargs = dict(
	showfliers=False, widths=[width], patch_artist = True,
	medianprops = dict(color = "black", linewidth = 1.5),
	vert=True,
)
legend_props = { "size": 15, "weight":"bold" }

hatches = ['//','*', '+', 'x', 'o', 'O']
actions = ['Waving', 'Handshake', 'Rocket\nFistbump', 'Parachute\nFistbump']
# methods = ['Bütepage et al. [10]', 'MILD']
methods = ['Bütepage et al. [10]', 'MILD v1', 'MILD v2.1', 'MILD v2.2', 'MILD v3.1', 'MILD v3.2']

fig1 = plt.figure(figsize=(n_methods*4, 6))

ax_box = fig1.add_subplot(1,1,1)
boxes = []  # one legend handle per method
box_y = []  # box_y[i][a]: end of the upper whisker (whiskers[1]) for method i, action a
for i in range(n_methods):
	box_y.append([])
	for a in range(n_actions):
		mse[i, a] = np.array(mse[i, a])
		box = ax_box.boxplot(mse[i, a], positions=[a+i*spacing], boxprops = dict(facecolor=colors[i]), **boxplot_kwargs)
		box["boxes"][0].set(hatch=hatches[i])
		box_y[-1].append(box['whiskers'][1].get_ydata()[1])
		if a==0:
			boxes.append(box["boxes"][0])
# Highest whisker over THIS cell's boxes. Computing it here instead of reusing the
# value left over from a previous cell keeps the brackets scaled to this dataset.
ymax_ = max(max(row) for row in box_y)

# Method rows entering the repeated-measures ANOVA. Row 1 ('MILD v1') is left out,
# as in the original analysis. NOTE(review): confirm why v1 is excluded.
anova_methods = [0, 2, 3, 4, 5]
# Per action: (RM-ANOVA "MSE\tF\tp" summary, pingouin sphericity result,
# scipy one-way ANOVA over all methods). Kept for inspection; not printed.
anova_stats = []
# pvalues[a*(n_methods-1)+i, j-1] holds the Mann-Whitney U p-value of method i vs j (i < j).
pvalues = np.zeros((n_actions*(n_methods-1), n_methods-1))
for a in range(n_actions):
	n_samples = len(mse[0, a])
	df = pd.DataFrame({
		'MSE': np.concatenate(mse[anova_methods, a]),
		'participant': np.tile(np.arange(n_samples), len(anova_methods)),
		'method': np.repeat(np.arange(len(anova_methods))+1, n_samples),
	})
	anova = AnovaRM(data=df, depvar='MSE', subject='participant', within=['method']).fit().anova_table
	summary = f"MSE\t{anova['F Value']['method']:.3e}\t{anova['Pr > F']['method']:.3f}"
	spher = pg.sphericity(data=df, dv='MSE', subject='participant', within='method')
	anova_stats.append((summary, spher, f_oneway(*mse[:, a])))

	# Brackets start just above ymax_ and climb by 10% per significant pair.
	ymax = ymax_
	for i in range(n_methods):
		for j in range(i+1, n_methods):
			res = mannwhitneyu(mse[i][a], mse[j][a])
			row, col = a*(n_methods-1)+i, j-1
			print(row, col, f'{res.pvalue:.3e}')
			pvalues[row, col] = res.pvalue
			if res.pvalue < 0.05:
				x_i, x_j = a+i*spacing, a+j*spacing
				ax_box.plot([x_i, x_i, x_j, x_j], [ymax*1.05, ymax*1.1, ymax*1.1, ymax*1.05], 'k-')
				# Centre the stars over the bracket (midpoint of boxes i and j).
				stars = '**' if res.pvalue < 0.01 else '*'
				ax_box.text((x_i+x_j)/2, ymax*1.15, stars, horizontalalignment='center', verticalalignment='center')
				ymax += 0.1*ymax

# Upper-triangular p-value table per action: '--' below the diagonal, '0.' for
# p < 0.001, otherwise three decimals; every cell is tab-terminated.
pvalues[pvalues<0.001] = 0.
s = ''
for a in range(n_actions):
	for i in range(n_methods-1):
		cells = ['--']*i
		for j in range(i+1, n_methods):
			p = pvalues[a*(n_methods-1)+i, j-1]
			cells.append('0.' if p == 0 else f'{p:.3f}')
		s += ''.join(c+'\t' for c in cells) + '\n'
print(s)
ax_box.set_ylabel('Mean Squared Error (radians)')
ax_box.set_xticks(np.arange(len(actions))+group_center, actions)
ax_box.tick_params(axis='x', colors='black')
ax_box.tick_params(axis='y', colors='black')
ax_box.yaxis.label.set_color('black')
ax_box.legend(boxes, methods, loc='upper center', ncol=n_methods, prop=legend_props)
plt.tight_layout()
plt.savefig('logs/plots/mse_bp_pepper.pdf')
plt.show()


In [None]:
# `yumi_20hz` / `bp_yumi_20hz` MSE logs: compare the Bütepage et al. baseline (row 0)
# with the MILD variants, compute per-action statistics, and draw pairwise
# Mann-Whitney U significance brackets.
bp_mse = np.load('/home/vignesh/playground/buetepage-phri/logs/mse/yumi_20hz_3joints_xvel.npz', allow_pickle=True)['arr_0']
mild_mse = np.load('../logs/mse/bp_yumi_20hz_3joints_xvel.npz', allow_pickle=True)['arr_0']
mse = np.concatenate([bp_mse[None], mild_mse], axis=0)
n_methods = len(mse)
n_actions = 4

width = 1/(n_methods+1)
spacing = width+0.05
# Offset from action index `a` to the middle of that action's group of boxes,
# used to centre the x tick labels under the group.
group_center = (n_methods-1)*spacing/2
colors = get_cmap()(np.linspace(0.1,1,n_methods))
boxplot_kwargs = dict(
	showfliers=False, widths=[width], patch_artist = True,
	medianprops = dict(color = "black", linewidth = 1.5),
	vert=True,
)
legend_props = { "size": 15, "weight":"bold" }

hatches = ['//','*', '+', 'x', 'o', 'O']
actions = ['Waving', 'Handshake', 'Rocket\nFistbump', 'Parachute\nFistbump']
# methods = ['Bütepage et al. [10]', 'MILD']
methods = ['Bütepage et al. [10]', 'MILD v1', 'MILD v2.1', 'MILD v2.2', 'MILD v3.1', 'MILD v3.2']

fig1 = plt.figure(figsize=(n_methods*4, 6))

ax_box = fig1.add_subplot(1,1,1)
boxes = []  # one legend handle per method
box_y = []  # box_y[i][a]: end of the upper whisker (whiskers[1]) for method i, action a
for i in range(n_methods):
	box_y.append([])
	for a in range(n_actions):
		mse[i, a] = np.array(mse[i, a])
		box = ax_box.boxplot(mse[i, a], positions=[a+i*spacing], boxprops = dict(facecolor=colors[i]), **boxplot_kwargs)
		box["boxes"][0].set(hatch=hatches[i])
		box_y[-1].append(box['whiskers'][1].get_ydata()[1])
		if a==0:
			boxes.append(box["boxes"][0])
# Highest whisker over THIS cell's boxes. Computing it here instead of reusing the
# value left over from a previous cell keeps the brackets scaled to this dataset.
ymax_ = max(max(row) for row in box_y)

# Method rows entering the repeated-measures ANOVA. Row 1 ('MILD v1') is left out,
# as in the original analysis. NOTE(review): confirm why v1 is excluded.
anova_methods = [0, 2, 3, 4, 5]
# Per action: (RM-ANOVA "MSE\tF\tp" summary, pingouin sphericity result,
# scipy one-way ANOVA over all methods). Kept for inspection; not printed.
anova_stats = []
# pvalues[a*(n_methods-1)+i, j-1] holds the Mann-Whitney U p-value of method i vs j (i < j).
pvalues = np.zeros((n_actions*(n_methods-1), n_methods-1))
for a in range(n_actions):
	n_samples = len(mse[0, a])
	df = pd.DataFrame({
		'MSE': np.concatenate(mse[anova_methods, a]),
		'participant': np.tile(np.arange(n_samples), len(anova_methods)),
		'method': np.repeat(np.arange(len(anova_methods))+1, n_samples),
	})
	anova = AnovaRM(data=df, depvar='MSE', subject='participant', within=['method']).fit().anova_table
	summary = f"MSE\t{anova['F Value']['method']:.3e}\t{anova['Pr > F']['method']:.3f}"
	spher = pg.sphericity(data=df, dv='MSE', subject='participant', within='method')
	anova_stats.append((summary, spher, f_oneway(*mse[:, a])))

	# Brackets start just above ymax_ and climb by 10% per significant pair.
	ymax = ymax_
	for i in range(n_methods):
		for j in range(i+1, n_methods):
			res = mannwhitneyu(mse[i][a], mse[j][a])
			row, col = a*(n_methods-1)+i, j-1
			print(row, col, f'{res.pvalue:.3e}')
			pvalues[row, col] = res.pvalue
			if res.pvalue < 0.05:
				x_i, x_j = a+i*spacing, a+j*spacing
				ax_box.plot([x_i, x_i, x_j, x_j], [ymax*1.05, ymax*1.1, ymax*1.1, ymax*1.05], 'k-')
				# Centre the stars over the bracket (midpoint of boxes i and j).
				stars = '**' if res.pvalue < 0.01 else '*'
				ax_box.text((x_i+x_j)/2, ymax*1.15, stars, horizontalalignment='center', verticalalignment='center')
				ymax += 0.1*ymax

# Upper-triangular p-value table per action: '--' below the diagonal, '0.' for
# p < 0.001, otherwise three decimals; every cell is tab-terminated.
pvalues[pvalues<0.001] = 0.
s = ''
for a in range(n_actions):
	for i in range(n_methods-1):
		cells = ['--']*i
		for j in range(i+1, n_methods):
			p = pvalues[a*(n_methods-1)+i, j-1]
			cells.append('0.' if p == 0 else f'{p:.3f}')
		s += ''.join(c+'\t' for c in cells) + '\n'
print(s)
ax_box.set_ylabel('Mean Squared Error (radians)')
ax_box.set_xticks(np.arange(len(actions))+group_center, actions)
ax_box.tick_params(axis='x', colors='black')
ax_box.tick_params(axis='y', colors='black')
ax_box.yaxis.label.set_color('black')
ax_box.legend(boxes, methods, loc='upper center', ncol=n_methods, prop=legend_props)
plt.tight_layout()
# Yumi-specific output file; previously this overwrote mse_bp_pepper.pdf from the Pepper cell.
plt.savefig('logs/plots/mse_bp_yumi.pdf')
plt.show()


# "MILD" - Standard VAE Loss
$$\mathcal{L}_t = \mathbb{E}_{q_h}\log p(\boldsymbol{x}^h_t|\boldsymbol{z}^h_t) + \mathbb{E}_{q_r}\log p(\boldsymbol{x}^r_t|\boldsymbol{z}^r_t) + \mathcal{L}_{KL}$$

The following variants use an additional conditional reconstruction term $\mathcal{L}_{cond} = \mathbb{E}\log p(\boldsymbol{x}^r_t|\boldsymbol{\hat{z}}^r_t)$.
They differ mainly in how the expectation over $\boldsymbol{\hat{z}}^r_t$ is estimated, i.e. which distribution the Monte Carlo samples are drawn from. The reconstruction loss is then computed for the sampled points.

# Cond. Samples
- Samples are drawn from the HMM conditional distribution.
- Calculate the posterior distribution $\boldsymbol{\mu}_{\boldsymbol{z}}(\boldsymbol{x}^h_t),\boldsymbol{\Sigma}_{\boldsymbol{z}}(\boldsymbol{x}^h_t) = q(\boldsymbol{z}^h_t|\boldsymbol{x}^h_t)$
- Calculate the conditional distribution using the posterior mean and sample $\boldsymbol{\hat{z}}^r_t$ from this.
$$
  \boldsymbol{K}_i = {\color{orange}\boldsymbol{\Sigma}^{rh}_i}({\color{orange}\boldsymbol{\Sigma}^{hh}_i})^{-1} \\
    \boldsymbol{\hat{\mu}}^r_i = {\color{orange}\boldsymbol{\mu}^{r}_i} + \boldsymbol{K}_i({\color{orange}\boldsymbol{\mu}^h_i} - {\color{magenta}\boldsymbol{\mu}_{\boldsymbol{z}}(\boldsymbol{x}^h_t)})\\
    \boldsymbol{\hat{\Sigma}}^r_i = {\color{orange}\boldsymbol{\Sigma}^{rr}_i} - \boldsymbol{K}_i{\color{orange}\boldsymbol{\Sigma}^{hr}_i} + \boldsymbol{\hat{\mu}}^r_i(\boldsymbol{\hat{\mu}}^r_i)^T\\
    \boldsymbol{\hat{\mu}}^r_t = \sum_{i=1}^N {\color{orange}\bar \alpha_i^t} \hspace{0.2em} \boldsymbol{\hat{\mu}}^r_i\\
    \boldsymbol{\hat{\Sigma}}^r_t = \left[\sum_{i=1}^N {\color{orange}\bar \alpha_i^t} \hspace{0.2em} \boldsymbol{\hat{\Sigma}}^r_i\right]  - \boldsymbol{\hat{\mu}}^r_t(\boldsymbol{\hat{\mu}}^r_t)^T\\
    p(\boldsymbol{z}_t^r | q_t^h) = \mathcal{N}(\boldsymbol{z}^r_t;\boldsymbol{\hat{\mu}}^r_t, \boldsymbol{\hat{\Sigma}}^r_t) \\
    \mathcal{L}_{cond} = \mathbb{E}_{\boldsymbol{\hat{z}}^r_t \sim p(\boldsymbol{z}_t^r | q_t^h)}\log p(\boldsymbol{x}^r_t|\boldsymbol{\hat{z}}^r_t)
$$

# Diagonalized Cond. Samples
Same as above but using the diagonalized form of the conditional distribution covariance
$$p(\boldsymbol{z}_t^r | q_t^h) = \mathcal{N}(\boldsymbol{z}^r_t;\boldsymbol{\hat{\mu}}^r_t, diag(\boldsymbol{\hat{\Sigma}}^r_t))$$

# Posterior Sample Conditioning 
$$
{\color{magenta}\boldsymbol{z}^h_t} \sim q(\boldsymbol{z}^h_t|\boldsymbol{x}^h_t)\\
\boldsymbol{K}_i = {\color{orange}\boldsymbol{\Sigma}^{rh}_i}({\color{orange}\boldsymbol{\Sigma}^{hh}_i})^{-1} \\
    \boldsymbol{\hat{z}}^r_t = \sum_{i=1}^N {\color{orange}\bar \alpha_i^t} [{\color{orange}\boldsymbol{\mu}^{r}_i} + \boldsymbol{K}_i({\color{orange}\boldsymbol{\mu}^h_i} - {\color{magenta}\boldsymbol{z}^h_t})]\\
    \mathcal{L}_{cond} = \mathbb{E}_{\boldsymbol{\hat{z}}^r_t| \boldsymbol{z}^h_t \sim q(\boldsymbol{z}^h_t | \boldsymbol{x}^h_t)}\log p(\boldsymbol{x}^r_t|\boldsymbol{\hat{z}}^r_t)
$$


# "with Post. Cov."
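This variant conditions using the posterior covariance: the posterior covariance is added to each component's human-latent covariance when computing the gain.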
$$\boldsymbol{K}_i = {\color{orange}\boldsymbol{\Sigma}^{rh}_i}({\color{orange}\boldsymbol{\Sigma}^{hh}_i} + {\color{magenta}\boldsymbol{\Sigma}_{\boldsymbol{z}}(\boldsymbol{x}^h_t)})^{-1}$$