#### Imports ####

In [None]:
import os
import torch
import numpy as np
import normflows as nf
import hydra
import matplotlib
from matplotlib import pyplot as plt
from tqdm import tqdm

from flows import RealNVP, NeuralSplineFlow
from prefflow import PrefFlow
from plotter import Plotter
from target import set_up_problem
from misc import convert_to_ranking

#### Load config file to set-up experiment and algorithm details ####

In [None]:
# Compose the experiment configuration from conf/config.yaml (with the
# "+db=mysql" override) via Hydra's compose API.
with hydra.initialize(version_base=None, config_path="conf"):
    # NOTE(review): output_folder is not used in this part of the notebook;
    # presumably consumed by later cells or helpers — confirm.
    output_folder = os.path.join(os.getcwd(), 'temporary_outputs')
    cfg = hydra.compose(config_name="config.yaml", overrides=["+db=mysql"])

# If plots should not be shown during training, switch matplotlib to the
# non-GUI Agg backend.
if not cfg.plot.showduringtraining:
    matplotlib.use('Agg')

#### Device and Precision ####

In [None]:
# Floating-point precision for newly created tensors: float64 when
# cfg.device.precision_double is set, float32 otherwise.
_default_dtype = torch.float64 if cfg.device.precision_double else torch.float32
torch.set_default_dtype(_default_dtype)

# Device on which models and data are placed (taken from cfg.device.device).
device = torch.device(cfg.device.device)

#### Random seeds ####

In [None]:
import random

# Seed PyTorch, NumPy and Python's `random` module with the same value so
# runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(cfg.exp.seed)

#### Target belief density ####

In [None]:
# Problem definition: the 'onemoon' target, requested in 2 dimensions.
target_name, D = 'onemoon', 2
# set_up_problem also hands back D, which overrides the requested value
# (presumably it may differ for some targets — TODO confirm).
target, bounds, uniform, D, normalize = set_up_problem(target_name, D)

#### Base distribution and Flow architecture ####

In [None]:
# Base distribution: a diagonal Gaussian in D dimensions with fixed
# (non-trainable) parameters.
q0 = nf.distributions.DiagGaussian(D, trainable=False)
# Presumably the number of flow layers — TODO confirm against flows.py.
nflows = cfg.params.nflows
if cfg.params.flow == "realnvp":
    nfm = RealNVP(nflows, D, q0, device, cfg.device.precision_double)
elif cfg.params.flow == "neuralsplineflow":
    nfm = NeuralSplineFlow(nflows, D, q0, device, cfg.device.precision_double)
else:
    # Fail fast here instead of leaving `nfm` undefined, which would only
    # surface later as a confusing NameError.
    raise ValueError(
        f"Unknown flow architecture {cfg.params.flow!r}; "
        "expected 'realnvp' or 'neuralsplineflow'"
    )

#### Data generation 1/2 ####

In [None]:
# Empirical mean and per-dimension standard deviation of the target density,
# estimated from 10000 samples and cast to float64. They parameterize the
# Gaussian component of the "mixture_uniform_gaussian" proposal.
target_sample = target.sample(10000)
target_mean, target_std = (
    target_sample.mean(dim=0).double(),
    target_sample.std(dim=0).double(),
)
def sample_alternatives(n, k=2, distribution="uniform"):
    """Draw k*n candidate points to be presented to the (simulated) expert.

    Args:
        n: number of points per group.
        k: number of groups (alternatives per query).
        distribution: one of
            - "uniform": k*n points from the module-level `uniform` distribution;
            - "target": k*n points from the target density itself;
            - "mixture_uniform_gaussian": each of the k groups of n points is
              drawn, with probability cfg.exp.mixture_success_prob, from a
              Gaussian built from target_mean / target_std, and otherwise from
              `uniform`.

    Returns:
        Tensor of the k*n points concatenated along dim 0, on `device`. With
        n=1 this is one row per alternative, presumably shape (k, D).

    Raises:
        ValueError: if `distribution` is not one of the names above.
    """
    if distribution == "uniform":
        return uniform.sample(torch.tensor([k * n])).to(device)
    if distribution == "target":
        return target.sample(k * n).to(device)
    if distribution == "mixture_uniform_gaussian":
        # NOTE(review): target_std is used as the covariance diagonal, so the
        # Gaussian's per-dimension std is sqrt(target_std); confirm intended.
        target_gaussian = torch.distributions.MultivariateNormal(
            target_mean, target_std * torch.eye(D)
        )
        howoftentarget = cfg.exp.mixture_success_prob
        samples = []
        for _ in range(k):
            if np.random.sample() <= howoftentarget:
                x = target_gaussian.sample((n,))
            else:
                x = uniform.sample(torch.tensor([n]))
            # Move every group to `device` so torch.cat never mixes CPU and
            # accelerator tensors (the Gaussian samples are created on the
            # device of target_mean, not necessarily `device`).
            samples.append(x.to(device))
        return torch.cat(samples, dim=0)
    raise ValueError(
        f"Unknown distribution {distribution!r}; expected 'uniform', "
        "'target' or 'mixture_uniform_gaussian'"
    )
def expert_feedback_pairwise(comp, s=None):
    """Simulate the expert choosing between the two alternatives in `comp`.

    The target log-density of comp[0] and comp[1] is compared. When `s` is
    given, each score is first perturbed by independent Exponential(s) noise;
    with s=None the comparison is noise-free.

    Returns:
        Long tensor of shape (1,) on `device`: 1 when comp[0]'s (noisy) score is
        >= comp[1]'s, else 0.
    """
    if s is None:
        noise = (0, 0)
    else:
        noise = torch.distributions.Exponential(s).sample((2,)).to(device)
    scores = target.log_prob(comp).to(device)
    first_preferred = torch.ge(scores[0] + noise[0], scores[1] + noise[1])
    return first_preferred.long().view(1).to(device)
def expert_feedback_ranking(alternatives, s=None):
    """Simulate the expert ranking k alternatives from best to worst.

    Each alternative is scored by its target log-density. When `s` is given,
    each score is perturbed by independent Exponential(s) noise. With the
    default s=None the ranking is noise-free, matching
    expert_feedback_pairwise; previously s=None crashed inside
    Exponential(None).

    Args:
        alternatives: points to rank, one per row along dim 0 (k rows).
        s: rate of the Exponential noise, or None for no noise.

    Returns:
        Tensor of shape (k,) on `device` holding the row indices of
        `alternatives`, sorted by decreasing (noisy) score.
    """
    k = alternatives.shape[0]
    # Draw the noise before evaluating log_prob to keep the original RNG order.
    noise = 0 if s is None else torch.distributions.Exponential(s).sample((k,)).to(device)
    logprobs = target.log_prob(alternatives).to(device) + noise
    _, ranking_inds = torch.sort(logprobs, descending=True)
    return ranking_inds.view(k).to(device)
def generate_dataset(N, s=None, distribution="uniform"):
    """Simulate N pairwise comparisons answered by the expert.

    Args:
        N: number of comparisons (at least one is always generated).
        s: expert noise rate passed to expert_feedback_pairwise (None = no noise).
        distribution: proposal used by sample_alternatives.

    Returns:
        X: tensor of shape (2, D, N) — the two alternatives of every comparison,
           with the comparison index on the last axis.
        Y: tensor of shape (N,) — the expert's answer for each comparison.
    """
    comps, answers = [], []
    for _ in range(max(N, 1)):
        comp = sample_alternatives(1, 2, distribution)
        comps.append(comp)
        answers.append(expert_feedback_pairwise(comp, s))
    # Collect first and concatenate once; growing X/Y with torch.cat inside
    # the loop copied the whole dataset on every iteration (O(N^2)).
    X = torch.stack(comps, dim=2)
    Y = torch.cat(answers, dim=0)
    return X, Y
def generate_dataset_ranking(N, k, s=None, distribution="uniform"):
    """Simulate N rankings of k alternatives each, answered by the expert.

    Args:
        N: number of ranking queries (at least one is always generated).
        k: number of alternatives per query.
        s: expert noise rate passed to expert_feedback_ranking (None = no noise).
        distribution: proposal used by sample_alternatives.

    Returns:
        Tensor reshaped to (k, -1, N) on `device`, built by convert_to_ranking
        from the raw alternatives (k, D, N) and rankings (N, k). Presumably
        (k, D, N) with alternatives reordered by rank — TODO confirm against
        misc.convert_to_ranking.
    """
    alternative_sets, rankings = [], []
    for _ in range(max(N, 1)):
        alternatives = sample_alternatives(1, k, distribution)
        alternative_sets.append(alternatives)
        rankings.append(expert_feedback_ranking(alternatives, s).view(1, k))
    # Collect first and concatenate once instead of an O(N^2) torch.cat loop.
    X = torch.stack(alternative_sets, dim=2)  # (k, D, N)
    Y = torch.cat(rankings, dim=0)  # (N, k)
    # .cpu() is required before .numpy(): numpy() raises for tensors on a GPU.
    Xdata = convert_to_ranking(X.cpu().numpy(), Y.cpu().numpy())
    # Return on `device`, consistent with generate_dataset.
    return torch.from_numpy(Xdata).view(k, -1, N).to(device)

#### Data generation 2/2 ####

In [None]:
# Simulate expert feedback on n queries. With more than two alternatives per
# query (cfg.data.k > 2) the expert provides rankings; otherwise pairwise
# comparisons. true_s is the simulated expert's noise parameter and
# cfg.exp.lambda_dist the proposal distribution for the alternatives.
n = cfg.data.n
true_s = cfg.exp.true_s
ranking = bool(cfg.data.k > 2)
if not ranking:
    dataset = generate_dataset(N=n, s=true_s, distribution=cfg.exp.lambda_dist)
else:
    k = cfg.data.k
    dataset = generate_dataset_ranking(
        N=n, k=k, s=true_s, distribution=cfg.exp.lambda_dist
    )
        
def minibatch(dataset, batch_size, ranking):
    """Draw a random minibatch of queries without replacement.

    Args:
        dataset: for ranking data, a 3-D tensor whose last axis indexes
            queries; for pairwise data, a pair (X, Y) with X of shape
            (2, D, N) and Y of shape (N,).
        batch_size: number of queries to draw. Values larger than the dataset
            size return every query, in random order.
        ranking: whether `dataset` holds ranking data.

    Returns:
        Same structure as `dataset`, restricted to the sampled queries.
    """
    # Size the permutation from the dataset itself rather than the global `n`,
    # so the function is correct for any dataset passed in.
    num_queries = dataset.shape[2] if ranking else dataset[0].shape[2]
    indices = torch.randperm(num_queries)[:batch_size]
    if ranking:
        return dataset[:, :, indices]
    X, Y = dataset
    return X[:, :, indices], Y[indices]

#### Initialize preferential flow ####

In [None]:
# Wrap the normalizing flow `nfm` in the preferential-learning model. The
# model's noise parameter comes from cfg.modelparams.s, and `ranking` tells it
# which kind of feedback the dataset contains (TODO confirm against PrefFlow).
prefflow = PrefFlow(
    nfm,
    D=D,
    s=cfg.modelparams.s,
    ranking=ranking,
    device=device,
    precision_double=cfg.device.precision_double,
)

#### Initialize optimizer ####

In [None]:
# Per-iteration training-loss history (appended to by the training loop) and
# minibatch size.
loss_hist = np.array([])
batch_size = cfg.params.batch_size

# Resolve the optimizer class case-insensitively among torch.optim's Optimizer
# subclasses. str.capitalize() would turn "sgd" into "Sgd" and "adamw" into
# "Adamw", neither of which exists in torch.optim.
_optimizer_classes = {
    _attr.lower(): getattr(torch.optim, _attr)
    for _attr in dir(torch.optim)
    if isinstance(getattr(torch.optim, _attr), type)
    and issubclass(getattr(torch.optim, _attr), torch.optim.Optimizer)
}
try:
    optimizer = _optimizer_classes[cfg.params.optimizer.lower()]
except KeyError:
    raise ValueError(
        f"Unknown optimizer {cfg.params.optimizer!r}; "
        f"choose one of {sorted(_optimizer_classes)}"
    ) from None
optimizer_prefflow = optimizer(
    [{'params': prefflow.parameters()}],
    lr=cfg.params.lr,
    weight_decay=cfg.params.weight_decay,
)

#### Initialize plotter ####

In [None]:
# Plotting helper, configured with the problem dimensionality D and the
# `bounds` returned by set_up_problem (presumably the plotting domain).
plotter = Plotter(
    D,
    bounds,
)

#### Training loop: SGD FS-MAP ####

In [None]:
# Minibatch gradient training: each iteration takes one optimizer step on the
# negative log-posterior returned by prefflow.logposterior.
for it in tqdm(range(cfg.params.max_iter), disable=not cfg.plot.progressbar_show):

    # Sample a random minibatch of queries
    batch = minibatch(dataset, batch_size, ranking)

    # Update flow parameters
    prefflow.train()
    optimizer_prefflow.zero_grad()
    loss = -prefflow.logposterior(batch, cfg.modelparams.weightprior)
    # Skip the update on a NaN/inf loss so non-finite gradients never reach
    # the parameters; the value is still recorded in loss_hist below.
    if torch.isfinite(loss):
        loss.backward()
        optimizer_prefflow.step()
    # Convert once; reused for the history and the periodic printout.
    loss_np = loss.detach().cpu().numpy()
    loss_hist = np.append(loss_hist, loss_np)

    # Every cfg.plot.show_iter iterations: report the loss and plot the
    # learned density (optionally overlaid with all data points).
    if (it + 1) % cfg.plot.show_iter == 0:
        print("loss: " + str(loss_np))
        showdata = (
            minibatch(dataset, batch_size=n, ranking=ranking)
            if cfg.plot.showdatapoints
            else None
        )
        probmassinarea = plotter.plot_moon(target, prefflow, data=showdata, cfg=cfg)
        plt.show()

#### Plot loss trajectory ####

In [None]:
# Plot the recorded training loss against the iteration index.
_loss_fig, _loss_ax = plt.subplots(figsize=(15, 15))
_loss_ax.plot(loss_hist, label='loss')
_loss_ax.legend()
plt.show()