In [None]:
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from utils.Generator import Generator
from utils.SineDataSet import get_batch
from utils.SineDataSet import get_η

# Initialize the generator, its optimizer, and the noise level σ

In [None]:
# Seed the global torch RNG (all devices) so weight initialisation and the
# training-loop probe vectors are reproducible across runs. NOTE(review):
# this only covers get_η / get_batch if they draw from the torch RNG — confirm.
SEED = 0
torch.manual_seed(SEED)

G = Generator().cuda()

# Adam hyper-parameters for the generator.
lr, betas = 2e-4, (.5, .9)
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=betas)

# Empirical noise level σ: standard deviation of a large noise sample.
# get_η's argument is presumably the sample shape — TODO confirm in utils.
N_NOISE_SAMPLES = 100_000_000
with torch.no_grad():
    σ = torch.std(get_η([N_NOISE_SAMPLES])).item()
print(f"σ = {σ}")

# Finite-difference step size for the Monte-Carlo divergence term in the loss.
epsilon = 1e-3

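# Training & logging

The generator is trained without access to the clean targets. The loss is a Monte-Carlo estimate of Stein's unbiased risk estimate (SURE), using a random probe $b \sim \mathcal{N}(0, I)$ and a finite-difference step $\varepsilon$:

$$\mathcal{L} = \operatorname{mean}\big((G(y^\delta) - y^\delta)^2\big) - \sigma^2 + \frac{2\sigma^2}{\varepsilon}\,\operatorname{mean}\big(b \odot (G(y^\delta + \varepsilon b) - G(y^\delta))\big).$$

The ground truth $y$ is used only for logging: each step records the batch-mean error $\lVert y - G(y^\delta)\rVert$. Checkpoints and logs are written every 100 steps.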

In [None]:
# Per-step logs filled by the training loop: error to ground truth,
# elapsed wall-clock seconds, and generator loss.
mean_errors_to_groundtruth, ticks, losses_G = [], [], []

# Reference wall-clock timestamp; each logged tick is measured relative to it.
tick = time.time()

In [None]:
# Where checkpoints and logs are written. Created up front so the first save
# (after `checkpoint_every` steps) cannot fail on a missing directory.
save_dir = Path('trained_models/stein')
save_dir.mkdir(parents=True, exist_ok=True)

checkpoint_every = 100  # save model weights + logs every this many steps
max_steps = None        # int to stop automatically; None trains until interrupted

# Step count is cumulative across re-runs of this cell because the log lists persist.
while max_steps is None or len(mean_errors_to_groundtruth) < max_steps:
    with torch.no_grad():
        y, yδ, _ = get_batch()

    optimizer_G.zero_grad()
    y_approximation = G(yδ)
    # Random probe for the Monte-Carlo (finite-difference) divergence estimate.
    b = torch.randn_like(yδ)
    y_approximation_b = G(yδ + epsilon * b)

    # Monte-Carlo SURE: data-fit term on the noisy input, minus σ², plus
    # 2σ²/ε times the probe-projected finite difference. Clean y is not used here.
    loss = (
        torch.mean((y_approximation - yδ) ** 2)
        - σ**2
        + 2 * σ**2 / epsilon * torch.mean(b * (y_approximation_b - y_approximation))
    )
    losses_G.append(loss.item())

    loss.backward()
    optimizer_G.step()

    with torch.no_grad():
        # L2 error to ground truth taken over dim 1, then averaged over the batch.
        mean_errors_to_groundtruth.append(torch.norm(y - y_approximation, dim=1).mean().item())
        ticks.append(time.time() - tick)

    step = len(mean_errors_to_groundtruth)
    if step % checkpoint_every == 0:
        torch.save(G.state_dict(), save_dir / f'G_state_{step}.pt')
        np.save(save_dir / 'mean_errors_to_groundtruth.npy', mean_errors_to_groundtruth)
        np.save(save_dir / 'ticks.npy', ticks)
        np.save(save_dir / 'losses_G.npy', losses_G)