# Setup

In [None]:
import numpy as np
import torch

import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
import sys
sys.path.insert(0, '..')  # repo source code
from src.heat import sun_score_hk, sun_score_hk_old, sun_score_hk_stable, sun_score_hk_autograd, sun_score_hk_autograd_v2, sample_sun_hk, sun_hk

from src.utils import grab, wrap
from src.devices import set_device, get_device, summary

In [None]:
# Select the CPU through the repo's device helper module, then print the
# summary string that the helper reports for the active configuration.
import src.devices as devices  # repo-local device helpers

devices.set_device('cpu')
device_report = devices.summary()
print(device_report)

# SU(2) score

In [None]:
# Compare every SU(2) heat-kernel score implementation on a grid of angles
# theta in [-pi, pi], with one panel per width sigma. Overlapping curves mean
# the implementations agree at that width.
thetas = torch.linspace(-np.pi, np.pi, steps=101)

widths = [0.00001, 0.1, 0.5, 1.0]

# (legend label, score function, extra ax.plot kwargs). Order matters: it fixes
# both the legend order and the colour assigned to each curve. The 'new'
# implementation is drawn thicker so the others can be compared against it.
score_variants = [
    ('new', sun_score_hk, {'linewidth': 2.0}),
    ('old', sun_score_hk_old, {}),
    ('stable', sun_score_hk_stable, {}),
    ('autograd', sun_score_hk_autograd, {}),
    ('autograd (v2)', sun_score_hk_autograd_v2, {}),
]

fig, axes = plt.subplots(2, 2)
thetas_np = grab(thetas)  # x-axis values are the same for every curve

for width, ax in zip(widths, axes.flatten()):
    for label, score_fn, plot_kwargs in score_variants:
        # NOTE(gkanwar): sigma^2 is a useful normalization to get target score in
        # a constant range regardless of sigma(t).
        # A fresh column view and width tensor are built per call (as before),
        # so no implementation can leak tensor state into the next one.
        score = width**2 * score_fn(thetas[:, None], width=torch.tensor(width))
        ax.plot(thetas_np, grab(score)[:, 0], label=label, **plot_kwargs)
    ax.set_title(rf'$\sigma={width}$')
    ax.set_xlabel(r'$\theta$')
    ax.set_ylabel(r'$\sigma^2 \times$ score')
axes[0, 0].legend()
fig.tight_layout()  # keep the 2x2 titles/labels from overlapping
plt.show()