In [1]:
%load_ext autoreload
%autoreload 2
from matplotlib import pyplot as plt
from matplotlib import gridspec
import matplotlib
%matplotlib inline

# Global matplotlib style for every figure in this notebook: all text is typeset by LaTeX.
# `bm` gives bold math symbols, `amsmath` the AMS environments, and `lmodern` selects the
# Latin Modern fonts. With text.usetex, matplotlib only honours the generic families
# (serif / sans-serif / cursive / monospace) or a few named TeX fonts; an unknown value such
# as 'lmodern' is ignored and replaced by serif (Computer Modern). The Latin Modern choice
# therefore has to be made in the LaTeX preamble, which is loaded after matplotlib's own
# font setup.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{bm} \usepackage{amsmath} \usepackage{lmodern}"

params = {'text.usetex' : True,
          'font.size' : 12,
          'font.family' : 'serif',
          }
plt.rcParams.update(params)

# Hex color palette for plotted series.
# NOTE(review): index 6 repeats index 0 ('#d3494e') -- confirm this is intentional.
colors = ['#d3494e', '#1f77b4', '#ff7f0e', '#2ca02c','#7D3C98', '#873600', '#d3494e', '#9467bd', '#e377c2', '#7f7f7f'] 
# Matplotlib line/marker format strings and marker sizes, keyed by the same integers 0-6
# (presumably a series/method index used by later cells -- verify against callers).
fmts={6:'-o',1:'-s',2:'-d', 3:'-s',4:'-d', 0:'-^',5:'--P'}
marksz={6:4,1:4,2:5,3:4,4:5,0:5,5:5}

In [2]:
import torch
import numpy as np

# Experiment grid used by the figure below.
# Probability the reference policy assigns to response 1 (response 2 gets 1 - this value).
reference_policy_1_list: list = [0.01, 0.1, 0.3, 0.6, 0.9, 0.99]
# Preference value `a` on the x-axis: 99 evenly spaced points from 0.01 to 0.99 (endpoints
# 0 and 1 are excluded because the RLHF log-odds reward diverges there).
preference_coef: torch.Tensor = torch.linspace(0.01, 0.99, steps=99)
# KL-regularisation strengths tau, one figure column each, from largest to smallest.
kl_coef_list: list = [1e-1, 1e-2, 1e-3, 1e-4]

def RLHF_reward(
    kl_coef: float,
    preference_coef: torch.Tensor
):
    """Implicit rewards for the two responses under the RLHF formulation.

    With ``a = preference_coef`` and ``tau = kl_coef``:

        r_1 = log(a / (1 - a)) / (2 * tau),    r_2 = -r_1

    i.e. the log-odds of ``a`` scaled by ``1 / (2 * tau)``; the two rewards are antisymmetric.

    Args:
        kl_coef: KL-regularisation strength tau (must be non-zero).
        preference_coef: tensor of preference values; values of exactly 0 or 1 give +-inf.

    Returns:
        Tuple ``(r_1, r_2)`` of tensors with the same shape as ``preference_coef``.
    """
    log_odds = torch.log(preference_coef / (1. - preference_coef))
    scaled = log_odds / (2 * kl_coef)
    return scaled, -1. * scaled

def _neg_logaddexp_const(values: torch.Tensor, constant: float) -> torch.Tensor:
    """Elementwise ``-log(exp(values) + exp(constant))``, computed stably via ``logaddexp``."""
    return -torch.logaddexp(values, torch.full_like(values, constant))


def Nash_RS_implicit_reward_sumexp(
    kl_coef: float,
    preference_coef: torch.Tensor,
    reference_policy_1: float = .5
):
    """Implicit rewards for the two responses under the Nash-RS formulation.

    With ``a = preference_coef``, ``tau = kl_coef``, ``pi_1 = reference_policy_1`` and
    ``pi_2 = 1 - pi_1``:

        r_1 = -log( pi_2 * exp(-a / tau)       + pi_1 * exp(-1 / (2 * tau)) )
        r_2 = -log( pi_1 * exp(-(1 - a) / tau) + pi_2 * exp(-1 / (2 * tau)) )

    Both terms are evaluated in log space (``log pi - x / tau``) and combined with
    ``torch.logaddexp``, so very small ``tau`` (exponents of order 1e4) does not overflow.

    Args:
        kl_coef: KL-regularisation strength tau (must be non-zero).
        preference_coef: preference values ``a``; any shape (including 0-dim) or an
            array-like convertible with ``torch.as_tensor``.
        reference_policy_1: probability the reference policy assigns to response 1.

    Returns:
        Tuple ``(r_1, r_2)`` of tensors with the same shape as ``preference_coef``.
    """
    preference_coef = torch.as_tensor(preference_coef)
    reference_policy_2: float = 1. - reference_policy_1

    # np.log rather than math.log: a degenerate reference policy (0 or 1) yields -inf
    # (with a numpy warning) instead of raising, and logaddexp handles -inf correctly.
    log_pi_1 = np.log(reference_policy_1)
    log_pi_2 = np.log(reference_policy_2)
    # The -1/(2 tau) term is the same for both rewards; only its log-prior weight differs.
    tie_term = 1. / (2 * kl_coef)

    reward_1 = _neg_logaddexp_const(
        log_pi_2 - preference_coef / kl_coef, log_pi_1 - tie_term
    )
    # Mirror image of reward_1: swap the roles of the two responses (pi_1 <-> pi_2, a <-> 1 - a).
    reward_2 = _neg_logaddexp_const(
        log_pi_1 - (1. - preference_coef) / kl_coef, log_pi_2 - tie_term
    )
    return reward_1, reward_2

def _plot_reward_pair(ax, x, rewards, method):
    """Plot the (r_1, r_2) reward tensors of one method against x, labelled '<method> $r_i$'."""
    for i, reward in enumerate(rewards, start=1):
        ax.plot(x, reward.cpu().detach().numpy(), label=f"{method} $r_{i}$")


# One panel per (reference policy, KL coefficient): rows sweep reference_policy_1,
# columns sweep kl_coef. The figure is only written to PDF, so no high in-memory dpi is
# needed; the vector output's resolution comes from savefig below.
fig = plt.figure(figsize=(21, 30))
gs = gridspec.GridSpec(len(reference_policy_1_list), len(kl_coef_list), figure=fig)
gs.update(left=0.05, right=0.95, wspace=0.35, hspace=0.35)

# The x-axis values are the same for every curve; convert them once.
preference_np = preference_coef.cpu().detach().numpy()

for index_1, reference_policy_1 in enumerate(reference_policy_1_list):
    for index_2, kl_coef in enumerate(kl_coef_list):

        ax = fig.add_subplot(gs[index_1, index_2])
        ax.grid(alpha=0.5, linestyle='--')

        # Plot order (RLHF r_1, r_2, then NashRS r_1, r_2) fixes the default color cycle.
        reward_1, reward_2 = RLHF_reward(kl_coef, preference_coef)
        _plot_reward_pair(ax, preference_np, (reward_1, reward_2), "RLHF")

        reward_1, reward_2 = Nash_RS_implicit_reward_sumexp(kl_coef, preference_coef, reference_policy_1=reference_policy_1)
        _plot_reward_pair(ax, preference_np, (reward_1, reward_2), "NashRS")

        # The RLHF reward scales with 1/kl_coef, so magnitudes span orders of magnitude
        # and change sign; symlog keeps both signs and the wide range readable.
        ax.set_yscale("symlog")
        ax.legend(framealpha=0.1, fontsize=16, loc=(0.5, 0.2))
        ax.set_xlabel('preference ratio $a$', fontsize=25)
        # Raw f-string: "\p" in a plain string literal is an invalid escape sequence.
        ax.set_title(rf"$\pi_1$={reference_policy_1}, $\tau$={kl_coef}", fontsize=32)

fig.savefig('implicit_reward_visualization.pdf', bbox_inches='tight', dpi=1200, format='pdf')
# Close explicitly so the large figure is neither kept in memory nor rendered inline.
plt.close(fig)