In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output

from utils.SineDataSet import get_batch
from utils.Generator import Generator
from utils.Critic import Critic
from utils.loss_functions import get_Cη_loss, get_Cyδ_loss, get_G_loss

# Initialize generator, critics, and optimizers

In [None]:
# One generator and two critics, each moved to the GPU right after construction.
# In the training loop Cη is trained on (η, η_approximation) and Cyδ on
# (yδ, y_renoised).
G, Cη, Cyδ = Generator().cuda(), Critic().cuda(), Critic().cuda()

# All three networks share the same Adam hyperparameters.
lr = 2e-4
betas = (.5, .9)
adam_settings = {'lr': lr, 'betas': betas}

optimizer_G = torch.optim.Adam(G.parameters(), **adam_settings)
optimizer_Cη = torch.optim.Adam(Cη.parameters(), **adam_settings)
optimizer_Cyδ = torch.optim.Adam(Cyδ.parameters(), **adam_settings)

# Training & logging

In [None]:
# Training history, filled by the training loop. Re-running this cell resets it.
# mean_errors_to_groundtruth: batch mean of ||y - y_approximation|| per log entry.
# ticks: seconds elapsed since `tick` at the moment each error was logged.
mean_errors_to_groundtruth, ticks = [], []

# One scalar loss per optimizer step for the two critics and the generator.
losses_Cη, losses_Cyδ, losses_G = [], [], []

# Wall-clock reference point against which `ticks` entries are measured.
tick = time.time()

In [None]:
# Checkpoints and training curves are written to this directory. Create it up
# front: neither torch.save nor np.save creates missing parent directories, so
# without this the first checkpoint would raise after 100 iterations of work.
checkpoint_dir = 'trained_models/gtfd'
os.makedirs(checkpoint_dir, exist_ok=True)

# Alternating critic / generator updates. The loop has no stopping condition;
# it runs until the kernel is interrupted, persisting progress periodically.
while True:
    # Sample a batch and run the generator without autograd. These outputs are
    # only consumed by the critic losses and the error log; get_G_loss is handed
    # G and yδ itself rather than these tensors.
    with torch.no_grad():
        y, yδ, η = get_batch()
        y_approximation, η_approximation, y_renoised = G.apply_for_training(yδ)
        # Exactly one (error, elapsed time) entry per iteration. The former
        # second append after the optimizer steps re-logged this same stale
        # error value, doubling the log length relative to the loss logs.
        mean_errors_to_groundtruth.append(torch.norm(y - y_approximation, dim=1).mean().item())
        ticks.append(time.time() - tick)

    # Critic step on η versus η_approximation.
    optimizer_Cη.zero_grad()
    loss_Cη = get_Cη_loss(Cη, η, η_approximation)
    losses_Cη.append(loss_Cη.item())
    loss_Cη.backward()
    optimizer_Cη.step()

    # Critic step on yδ versus y_renoised.
    optimizer_Cyδ.zero_grad()
    loss_Cyδ = get_Cyδ_loss(Cyδ, yδ, y_renoised)
    losses_Cyδ.append(loss_Cyδ.item())
    loss_Cyδ.backward()
    optimizer_Cyδ.step()

    # Generator step against both critics.
    optimizer_G.zero_grad()
    loss_G = get_G_loss(G, Cη, Cyδ, yδ)
    losses_G.append(loss_G.item())
    loss_G.backward()
    optimizer_G.step()

    # Every 100 logged iterations, checkpoint the generator and overwrite the
    # saved curves. The count comes from the log length, so it keeps increasing
    # if this cell is interrupted and re-run without resetting the logs.
    n_logged = len(mean_errors_to_groundtruth)
    if n_logged % 100 == 0:
        torch.save(G.state_dict(), f'{checkpoint_dir}/G_state_{n_logged}.pt')
        np.save(f'{checkpoint_dir}/mean_errors_to_groundtruth.npy', mean_errors_to_groundtruth)
        np.save(f'{checkpoint_dir}/ticks.npy', ticks)
        np.save(f'{checkpoint_dir}/losses_Cη.npy', losses_Cη)
        np.save(f'{checkpoint_dir}/losses_Cyδ.npy', losses_Cyδ)
        np.save(f'{checkpoint_dir}/losses_G.npy', losses_G)