# Resampled Importance Sampling Test

In [None]:
# %matplotlib widget
import matplotlib.pyplot as plt
import math
import numpy as np

In [None]:
def target_density(xs):
    """Evaluate the target density f(x) = 2 - 2x (a normalized linear ramp on [0, 1])."""
    return 2 * (1 - xs)

def sample_p1(us):
    """Turn uniforms into samples of p1 (uniform on [0, 1)).

    The inverse CDF of a uniform [0, 1) density is the identity, so the
    input is handed back unchanged (same object, no copy).
    """
    samples = us
    return samples

def p1(xs):
    """Evaluate the pdf of p1, the uniform density on [0, 1).

    Returns a float array of ones shaped like ``xs``. ``np.shape`` is used
    instead of ``xs.shape`` so scalars and Python sequences are accepted,
    like the other density functions here.

    NOTE: there is no support check; the result is 1 even outside [0, 1).
    """
    return np.full(np.shape(xs), 1.0)

def sample_p2(us):
    """Turn uniforms in [0, 1) into samples of p2 (uniform on [0, 0.5)).

    The inverse CDF of a uniform density on [0, 0.5) just halves its input.
    """
    return us / 2

def p2(xs):
    """Evaluate the pdf of p2: 2 where x < 0.5, otherwise 0.

    NOTE(review): only x < 0.5 is tested, so negative x also yields 2. All
    points evaluated in this notebook lie in [0, 1), so this is not hit.
    """
    belowHalf = xs < 0.5
    return np.where(belowHalf, 2, 0)

xs = np.linspace(0, 1, 100)
p1s = p1(xs)
p2s = p2(xs)


def _plotDensityPanel(fig, panelIndex, title, gridXs, densityValues):
    """Add panel ``panelIndex`` of a 1x2 grid to ``fig`` and draw one density.

    Every panel uses the same [0, 1] x [0, 2.5] view so the proposals can be
    compared by eye. Returns the created Axes.
    """
    ax = fig.add_subplot(1, 2, panelIndex)
    ax.title.set_text(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 2.5)
    ax.grid(True)
    ax.plot(gridXs, densityValues)
    return ax


# Side-by-side plots of the two proposal densities p1 and p2.
figure = plt.figure(figsize=(12, 4))
fig_p1 = _plotDensityPanel(figure, 1, 'p1', xs, p1s)
fig_p2 = _plotDensityPanel(figure, 2, 'p2', xs, p2s)
figure.show()

In [None]:
def sampleRIS(us):
    """Run resampled importance sampling (RIS) once per column of ``us``.

    Each column of ``us`` supplies the M random numbers for one output sample.
    The first ``M // 2`` rows become candidates through ``sample_p1`` and the
    remaining ``M - M // 2`` rows through ``sample_p2``. Each candidate x gets
    the resampling weight ``target_density(x) / p(x)``, where p is the proposal
    that produced it. One candidate per column is then selected with
    probability proportional to its weight.

    Args:
        us: 2-D array of shape (M, numSamples). The callers in this notebook
            pass ``np.random.rand(M, numSamples)``, i.e. uniforms in [0, 1).

    Returns:
        Tuple ``(samples, biased_ws, naive_ws, mis_ws)`` of 1-D arrays of
        length numSamples. Here f = target_density, y is the selected
        candidate and sum(w) is the sum of that column's resampling weights.

        - ``samples``: the selected candidate y of each column.
        - ``biased_ws``: sum(w) / (M * f(y)).
        - ``naive_ws``: sum(w) / (Z(y) * f(y)), where Z(y) counts the
          candidates whose proposal pdf is non-zero at y.
        - ``mis_ws``: m(y) * sum(w) / f(y), where m(y) is the pdf at y of
          the proposal that produced y divided by the sum of all M
          candidates' proposal pdfs at y (balance heuristic).
    """
    numSamples = us.shape[1]
    M = us.shape[0]
    half_M = M // 2
    # Number of candidates drawn from each proposal. For odd M, p2 receives
    # the extra candidate, so the two counts must be tracked separately.
    numP1 = half_M
    numP2 = M - half_M

    # Candidate generation and weighting are elementwise, so the whole
    # (M, numSamples) array is processed at once instead of column by column.
    candidatesP1 = sample_p1(us[:numP1])
    candidatesP2 = sample_p2(us[numP1:])
    candidates = np.concatenate([candidatesP1, candidatesP2], axis=0)
    weights = np.concatenate([target_density(candidatesP1) / p1(candidatesP1),
                              target_density(candidatesP2) / p2(candidatesP2)], axis=0)

    def resample(weightsPerSample):
        # Pick one candidate row with probability proportional to its weight.
        probs = weightsPerSample / np.sum(weightsPerSample)
        return np.random.choice(np.arange(M), p=probs)

    indexPerSampleValues = np.apply_along_axis(resample, 0, weights)
    samples = candidates[indexPerSampleValues, np.arange(numSamples)]

    targetDensityValues = target_density(samples)
    sumWeightsPerSampleValues = np.sum(weights, axis=0)

    biased_ws = sumWeightsPerSampleValues / (M * targetDensityValues)

    # Z(y): how many of the M candidates could have produced y.
    numNonZeroDistsPerSamplesValues = (np.where(p1(samples) > 0.0, numP1, 0)
                                       + np.where(p2(samples) > 0.0, numP2, 0))
    naive_ws = sumWeightsPerSampleValues / (numNonZeroDistsPerSamplesValues * targetDensityValues)

    # Balance heuristic over all M candidates; rows below numP1 came from p1.
    p1AtSamples = p1(samples)
    p2AtSamples = p2(samples)
    denom = numP1 * p1AtSamples + numP2 * p2AtSamples
    misWeightPerSampleValues = np.where(indexPerSampleValues < numP1, p1AtSamples, p2AtSamples) / denom
    mis_ws = misWeightPerSampleValues * sumWeightsPerSampleValues / targetDensityValues

    return samples, biased_ws, naive_ws, mis_ws

In [None]:
# For several candidate counts M, compare the per-bin average RIS contribution
# weight (faint curves) with 1 / (empirical density of the selected samples)
# (solid curves). For an unbiased contribution weight the two are expected to
# coincide. The black curve is the analytic 1 / f(x) of the target density.
figure = plt.figure(figsize=(12, 4))
figBiased = figure.add_subplot(1, 2, 1)
figBiased.title.set_text('Biased RIS')
figBiased.set_ylim(0, 4)
figBiased.grid(True)
figNaive = figure.add_subplot(1, 2, 2)
figNaive.title.set_text('Naive Unbiased RIS')
figNaive.set_ylim(0, 4)
figNaive.grid(True)

# Stop just short of x = 1, where f(x) = 0 and 1 / f(x) diverges.
analytic_xs = np.linspace(0, 0.999, 100)
analytic_ys = target_density(analytic_xs)
figBiased.plot(analytic_xs, 1 / analytic_ys, label='1 / f(x)', color='k')
figNaive.plot(analytic_xs, 1 / analytic_ys, label='1 / f(x)', color='k')

numSamples = 50000
numBins = 100
configs = ((2, 'hotpink'), (4, 'red'), (10, 'mediumpurple'), (20, 'dodgerblue'))
for M, plotColor in configs:
    us = np.random.rand(M, numSamples)
    ris_xs, ris_biased_ws, ris_naive_ws, _ = sampleRIS(us)

    hist, histBounds = np.histogram(ris_xs, bins=numBins, range=(0, 1), density=True)
    # 1 / empirical density; empty bins stay 0 instead of dividing by zero.
    invHist = np.divide(1.0, hist, out=np.zeros_like(hist), where=hist != 0)
    hist_xs = 0.5 * (histBounds[:-1] + histBounds[1:])

    # Bin index of every sample on the uniform [0, 1) grid (x == 1 is clamped
    # into the last bin). bincount replaces the per-sample Python loops and
    # accumulates in the same sample order.
    binIdx = np.minimum((numBins * ris_xs).astype(int), numBins - 1)
    numSamplesPerBin = np.bincount(binIdx, minlength=numBins)
    hasSamples = numSamplesPerBin > 0

    avg_biased_ws = np.divide(np.bincount(binIdx, weights=ris_biased_ws, minlength=numBins),
                              numSamplesPerBin, out=np.zeros(numBins), where=hasSamples)
    figBiased.plot(hist_xs, avg_biased_ws, color=plotColor, alpha=0.5)
    figBiased.plot(hist_xs, invHist, label='M=' + str(M), color=plotColor)

    avg_naive_ws = np.divide(np.bincount(binIdx, weights=ris_naive_ws, minlength=numBins),
                             numSamplesPerBin, out=np.zeros(numBins), where=hasSamples)
    figNaive.plot(hist_xs, avg_naive_ws, color=plotColor, alpha=0.5)
    figNaive.plot(hist_xs, invHist, label='M=' + str(M), color=plotColor)

# One legend column per M plus one for the analytic 1 / f(x) curve.
handles, labels = figBiased.get_legend_handles_labels()
figure.legend(handles, labels, ncol=len(configs) + 1, bbox_to_anchor=(0.5, 1.0), loc='lower center')
figure.show()

In [None]:
def target_density(xs):
    """Evaluate the target density f(x) = 2 - 2x (a normalized linear ramp on [0, 1])."""
    return 2 * (1 - xs)

def sample_p1(us):
    """Turn uniforms into samples of p1 (uniform on [0, 1)).

    The inverse CDF of a uniform [0, 1) density is the identity, so the
    input is handed back unchanged (same object, no copy).
    """
    samples = us
    return samples

def p1(xs):
    """Evaluate the pdf of p1, the uniform density on [0, 1).

    Returns a float array of ones shaped like ``xs``. ``np.shape`` is used
    instead of ``xs.shape`` so scalars and Python sequences are accepted,
    like the other density functions here.

    NOTE: there is no support check; the result is 1 even outside [0, 1).
    """
    return np.full(np.shape(xs), 1.0)

def sample_p2(us):
    """Inverse-CDF sampling of the piecewise-constant p2.

    Uniforms below 0.999 map linearly onto [0, 0.5), and the remaining 0.001
    of the unit interval maps linearly onto [0.5, 1). Both branches are
    evaluated for every element and then selected per element.
    """
    lowerPiece = 0.5 * (us / 0.999)
    upperPiece = 0.5 + 0.5 * (us - 0.999) / 0.001
    return np.where(us < 0.999, lowerPiece, upperPiece)

def p2(xs):
    """Evaluate the piecewise-constant pdf of p2.

    The pdf is 1.998 where x < 0.5 (probability mass 0.999) and 0.002
    elsewhere (mass 0.001 on [0.5, 1)).

    NOTE(review): negative x also takes the 1.998 branch; the points evaluated
    in this notebook lie in [0, 1).
    """
    highDensity = 1.998
    lowDensity = 0.002
    return np.where(xs < 0.5, highDensity, lowDensity)



xs = np.linspace(0, 1, 100)
p1s = p1(xs)
p2s = p2(xs)


def _plotDensityPanel(fig, panelIndex, title, gridXs, densityValues):
    """Add panel ``panelIndex`` of a 1x2 grid to ``fig`` and draw one density.

    Every panel uses the same [0, 1] x [0, 2.5] view so the proposals can be
    compared by eye. Returns the created Axes.
    """
    ax = fig.add_subplot(1, 2, panelIndex)
    ax.title.set_text(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 2.5)
    ax.grid(True)
    ax.plot(gridXs, densityValues)
    return ax


# Side-by-side plots of the two proposal densities p1 and p2.
figure = plt.figure(figsize=(12, 4))
fig_p1 = _plotDensityPanel(figure, 1, 'p1', xs, p1s)
fig_p2 = _plotDensityPanel(figure, 2, 'p2', xs, p2s)
figure.show()

In [None]:
# For several candidate counts M, compare the per-bin average of the naive and
# MIS unbiased RIS contribution weights (faint curves) with 1 / (empirical
# density of the selected samples) (solid curves). For an unbiased contribution
# weight the two are expected to coincide. The black curve is the analytic
# 1 / f(x) of the target density.
figure = plt.figure(figsize=(12, 4))
figNaive = figure.add_subplot(1, 2, 1)
figNaive.title.set_text('Naive Unbiased RIS')
figNaive.set_ylim(0, 4)
figNaive.grid(True)
figMIS = figure.add_subplot(1, 2, 2)
figMIS.title.set_text('MIS Unbiased RIS')
figMIS.set_ylim(0, 4)
figMIS.grid(True)

# Stop just short of x = 1, where f(x) = 0 and 1 / f(x) diverges.
analytic_xs = np.linspace(0, 0.999, 100)
analytic_ys = target_density(analytic_xs)
figNaive.plot(analytic_xs, 1 / analytic_ys, label='1 / f(x)', color='k')
figMIS.plot(analytic_xs, 1 / analytic_ys, label='1 / f(x)', color='k')

numSamples = 50000
numBins = 100
configs = ((2, 'hotpink'), (4, 'red'), (10, 'mediumpurple'), (20, 'dodgerblue'))
for M, plotColor in configs:
    us = np.random.rand(M, numSamples)
    ris_xs, _, ris_naive_ws, ris_mis_ws = sampleRIS(us)

    hist, histBounds = np.histogram(ris_xs, bins=numBins, range=(0, 1), density=True)
    # 1 / empirical density; empty bins stay 0 instead of dividing by zero.
    invHist = np.divide(1.0, hist, out=np.zeros_like(hist), where=hist != 0)
    hist_xs = 0.5 * (histBounds[:-1] + histBounds[1:])

    # Bin index of every sample on the uniform [0, 1) grid (x == 1 is clamped
    # into the last bin). bincount replaces the per-sample Python loops and
    # accumulates in the same sample order.
    binIdx = np.minimum((numBins * ris_xs).astype(int), numBins - 1)
    numSamplesPerBin = np.bincount(binIdx, minlength=numBins)
    hasSamples = numSamplesPerBin > 0

    avg_naive_ws = np.divide(np.bincount(binIdx, weights=ris_naive_ws, minlength=numBins),
                             numSamplesPerBin, out=np.zeros(numBins), where=hasSamples)
    figNaive.plot(hist_xs, avg_naive_ws, color=plotColor, alpha=0.5)
    figNaive.plot(hist_xs, invHist, label='M=' + str(M), color=plotColor)

    avg_mis_ws = np.divide(np.bincount(binIdx, weights=ris_mis_ws, minlength=numBins),
                           numSamplesPerBin, out=np.zeros(numBins), where=hasSamples)
    figMIS.plot(hist_xs, avg_mis_ws, color=plotColor, alpha=0.5)
    figMIS.plot(hist_xs, invHist, label='M=' + str(M), color=plotColor)

# One legend column per M plus one for the analytic 1 / f(x) curve.
handles, labels = figMIS.get_legend_handles_labels()
figure.legend(handles, labels, ncol=len(configs) + 1, bbox_to_anchor=(0.5, 1.0), loc='lower center')
figure.show()