# In [1]:
import import_ipynb
import models

# In [2]:
from typing import Union
import torch
import numpy as np
import tensorflow_probability as tfp
from typing import Any, Dict, Tuple, List
from torch.distributions import Normal
tfd = tfp.distributions
from scipy.optimize import root_scalar
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
from typing import Callable, Tuple

# Type alias for values that may be either a torch tensor or a NumPy array.
ArrayT = Union[torch.Tensor, np.ndarray]

def permute_dataset(features: torch.Tensor, target: torch.Tensor, seed: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Shuffle ``features`` and ``target`` with one shared random permutation.

    The permutation has length ``target.size(0)`` and is applied to the first
    dimension of both tensors, so rows stay aligned.

    Note: this calls ``torch.manual_seed(seed)``, i.e. it re-seeds torch's
    global random number generator as a side effect.

    Returns:
        The pair ``(features, target)`` reindexed by the permutation.
    """
    torch.manual_seed(seed)
    order = torch.randperm(target.size(0))
    shuffled_features = features[order]
    shuffled_target = target[order]
    return shuffled_features, shuffled_target

def chandrupatla_torch(f, low, high, tol=1e-5, max_iter=60):
    """Find a root of ``f`` inside ``[low, high]`` by element-wise bisection.

    Despite the name, the bracket is simply halved at its midpoint on every
    iteration; no interpolation step is performed.

    Args:
        f: Callable mapping a tensor of candidate points to a tensor of
            function values.
        low: Lower end of the bracket (Python number or tensor).
        high: Upper end of the bracket (Python number or tensor).
        tol: Stop once ``|f(mid)| < tol`` or the bracket is narrower than
            ``tol``.
        max_iter: Maximum number of halving steps.

    Returns:
        The midpoint of the final bracket, squeezed along its last dimension.

    Raises:
        ValueError: If ``f(low)`` and ``f(high)`` have the same strict sign.
        RuntimeError: If no convergence is reached within ``max_iter`` steps.
    """

    def _as_endpoint(value):
        # torch.tensor(tensor) emits a copy-construct warning; clone instead.
        if isinstance(value, torch.Tensor):
            return value.detach().clone().unsqueeze(0)
        return torch.tensor(value).unsqueeze(0)

    a, b = _as_endpoint(low), _as_endpoint(high)
    fa, fb = f(a), f(b)

    if (fa * fb > 0).any():
        # Message: "The function must have different signs at the interval ends."
        raise ValueError("Функция должна иметь разные знаки на концах отрезка.")

    converged = torch.zeros_like(fa, dtype=torch.bool)

    # The midpoint evaluated for the convergence check is the same point the
    # next iteration bisects at, so it is carried over instead of re-evaluated.
    mid = (a + b) / 2
    f_mid = f(mid)

    for _ in range(max_iter):
        # Move the left end only when the midpoint lies strictly on the same
        # side of the root as `a`. Comparing against `fa` (rather than testing
        # `f_mid * fb < 0`) keeps the root bracketed when `f` is exactly zero
        # at the midpoint or at `high`; the old test then shrank the bracket
        # towards `low` and returned a point that is not a root.
        move_left_end = fa * f_mid > 0
        a = torch.where(move_left_end, mid, a)
        fa = torch.where(move_left_end, f_mid, fa)

        b = torch.where(~move_left_end, mid, b)
        fb = torch.where(~move_left_end, f_mid, fb)

        mid = (a + b) / 2
        f_mid = f(mid)

        newly_converged = (torch.abs(f_mid) < tol) | (torch.abs(b - a) < tol)
        converged = converged | newly_converged

        if converged.all():
            return mid.squeeze(-1)

    # Message: "Failed to find a root."
    raise RuntimeError("Не удалось найти корень.")


def _normal_quantile_via_root(means, scales, q, axis=(0, 1)):
    """Quantile ``q`` of an equally weighted normal mixture, via root-finding.

    The mixture CDF is taken as the mean of the component CDFs
    ``Normal(means, scales).cdf(x)``; the root of ``cdf(x) - q`` is searched
    in ``[min(means) - 5 * max(scales), max(means) + 5 * max(scales)]``.

    Args:
        means: Component locations.
        scales: Component scales.
        q: Target probability level.
        axis: Dimensions the component CDFs are averaged over when the CDF
            tensor has enough dimensions; otherwise all entries are averaged.

    Returns:
        The root returned by ``chandrupatla_torch``.
    """
    components = Normal(means, scales)
    widest_scale = torch.max(scales)
    lower_bound = torch.min(means) - 5 * widest_scale
    upper_bound = torch.max(means) + 5 * widest_scale

    def quantile_root_fn(x):
        cdf_val = components.cdf(torch.tensor(x, dtype=torch.float32))
        # Check that the requested axes actually exist before reducing over them.
        axes_exist = cdf_val.dim() >= max(axis) + 1
        averaged = cdf_val.mean(dim=axis) if axes_exist else cdf_val.mean()
        return averaged - q

    return chandrupatla_torch(
        quantile_root_fn, lower_bound, upper_bound, tol=1e-5, max_iter=60
    )
def _approximate_normal_quantile(means, scales, q, axis=(0, 1)):
    if len(means.shape) > 1:
        mixture_mean = means.mean(dim=axis)
        variance_term = (scales**2 + means**2).mean(dim=axis) - mixture_mean**2
    else:
        mixture_mean = means.mean()
        variance_term = (scales**2 + means**2).mean() - mixture_mean**2

    valid_mask = variance_term >= 0
    safe_scale = torch.where(valid_mask, torch.sqrt(variance_term), torch.tensor(1.0))

    n = Normal(mixture_mean, safe_scale)

    quantiles = torch.where(valid_mask, n.icdf(torch.tensor(q)), torch.tensor(float('nan')))
    return quantiles

def _get_percentile_normal(means, scales, quantiles, axis=None, approximate=False):
    """Compute one mixture quantile per entry of ``quantiles``.

    Args:
        means: Component locations.
        scales: Component scales.
        quantiles: Iterable of probability levels.
        axis: Ignored. It is always overwritten with ``(0, 1)`` for
            multi-dimensional ``means`` and ``(0,)`` otherwise.
        approximate: Use the moment-matched approximation instead of
            root-finding.

    Returns:
        A list with one result per requested quantile.
    """
    axis = (0, 1) if means.dim() > 1 else (0,)

    if approximate:
        quantile_fn = _approximate_normal_quantile
    else:
        quantile_fn = _normal_quantile_via_root

    results = []
    for q in quantiles:
        results.append(quantile_fn(means, scales, q, axis))
    return results

def _make_forecast_inner(model_args, distribution):
    """Build the function that maps (params, inputs) to likelihood parameters.

    Used for MAP and VI. The returned ``forecast_inner(params, x_subset)``
    rebuilds the model from ``model_args`` on every call, constructs the
    likelihood via ``models.make_likelihood_model`` and returns:

    * NORMAL: ``(loc, scale)``
    * NB: ``(total_count, logits)``
    * ZINB: ``(total_count, logits, gate_logits)``

    Raises:
        TypeError: (from ``forecast_inner``) for any other distribution.
    """
    def forecast_inner(params, x_subset):
        mlp, mlp_template = make_model(**model_args)
        likelihood = models.make_likelihood_model(
            params, x_subset, mlp, mlp_template, distribution
        )

        if distribution == models.LikelihoodDist.NORMAL:
            normal = likelihood.base_dist
            return (normal.loc, normal.scale)

        if distribution == models.LikelihoodDist.NB:
            neg_binomial = likelihood.base_dist
            return (neg_binomial.total_count, neg_binomial.logits)

        if distribution == models.LikelihoodDist.ZINB:
            zero_inflated = likelihood.base_dist
            return (
                zero_inflated.base_dist.total_count,
                zero_inflated.base_dist.logits,
                zero_inflated.gate_logits,
            )

        raise TypeError('Distribution must be one of NORMAL, NB, or ZINB.')

    return forecast_inner
import torch

def forecast_parameters_batched(
    features: torch.Tensor,
    params: dict,
    distribution: str,
    forecast_inner: callable,
    batchsize: int = 1024,
):
    """Compute distribution parameters over ``features`` in batches.

    ``forecast_inner(params, batch)`` is called on consecutive slices of
    ``batchsize`` rows (plus a final shorter slice when rows remain) and the
    per-batch outputs are joined:

    * NORMAL: ``loc`` and ``scale`` concatenated along the last dimension.
    * NB: ``total_count`` and ``logits`` concatenated along dimension 0.
    * ZINB: ``total_count`` taken from the first batch only; ``logits`` and the
      third parameter concatenated along the last dimension.

    Raises:
        TypeError: If ``distribution`` is none of NORMAL, NB, ZINB.
    """
    # One list per returned parameter; at most three parameters are supported.
    collected = [[], [], []]

    total_rows = features.shape[0]
    full_batches = total_rows // batchsize
    for batch_idx in range(full_batches + 1):
        start = batch_idx * batchsize
        is_tail = batch_idx == full_batches
        if is_tail:
            batch = features[start:]
            # The tail is empty when the row count is a multiple of batchsize.
            if batch.shape[0] == 0:
                continue
        else:
            batch = features[start:start + batchsize]

        for idx, fc_param in enumerate(forecast_inner(params, batch)):
            collected[idx].append(fc_param)

    if distribution == models.LikelihoodDist.NORMAL:
        loc = torch.cat(collected[0], dim=-1)
        scale = torch.cat(collected[1], dim=-1)
        return (loc, scale)

    if distribution == models.LikelihoodDist.NB:
        total_count = torch.cat(collected[0], dim=0)
        logits = torch.cat(collected[1], dim=0)
        return (total_count, logits)

    if distribution == models.LikelihoodDist.ZINB:
        total_count = collected[0][0]
        logits = torch.cat(collected[1], dim=-1)
        zero_mass = torch.cat(collected[2], dim=-1)
        return (total_count, logits, zero_mass)

    raise TypeError('Distribution must be NORMAL, NB, or ZINB.')
def softplus_inverse(x):
    """Return the inverse of softplus, i.e. ``y`` such that ``softplus(y) == x``.

    Mathematically this is ``log(exp(x) - 1)``. It is evaluated as
    ``x + log(1 - exp(-x))`` so that large ``x`` does not overflow ``exp`` and
    small positive ``x`` does not lose precision in ``exp(x) - 1``.

    Args:
        x: Tensor of positive values.

    Returns:
        Tensor of the same shape; ``-inf`` at ``x == 0`` and NaN for ``x < 0``.
    """
    return x + torch.log(-torch.expm1(-x))


def make_vi_init(prior_d):
    """Construct a surrogate posterior init function.

    ``prior_d`` is traced once with pyro to collect the values of its sample
    sites; those values only serve as shape templates.

    The returned ``_fn(seed=None)`` optionally seeds pyro's RNG and produces,
    for every prior site (in trace order), two entries in this order:

    1. An initial location: zeros for sites whose value is not 2-D, otherwise a
       draw from ``Uniform(-2, 2)`` passed through an affine transform with
       location 0 and a scale of ones shaped like the site.
    2. An initial inverse-softplus scale: the constant
       ``log(exp(0.3) - 1)`` broadcast to the site's shape.
    """
    prior_trace = pyro.poutine.trace(prior_d).get_trace()
    prior_trace.compute_log_prob()
    samples = {}
    for name, site in prior_trace.nodes.items():
        if site['type'] == 'sample':
            samples[name] = site['value']

    def _fn(seed=None):
        if seed is not None:
            pyro.set_rng_seed(seed)

        # Softplus of this value is 0.3, the initial surrogate scale.
        inv_softplus_scale = torch.log(torch.exp(torch.tensor(0.3)) - 1)

        init_values = {}
        for i, template in enumerate(samples.values()):
            if len(template.shape) == 2:
                loc_name = f"initial_weight_matrix_{i}"
                loc_dist = dist.TransformedDistribution(
                    dist.Uniform(-2, 2),
                    dist.transforms.AffineTransform(0.0, torch.ones_like(template))
                )
            else:
                loc_name = f"zero_initial_mean_for_bias_or_transformed_scale_{i}"
                loc_dist = dist.Delta(torch.zeros_like(template))
            init_values[loc_name] = pyro.sample(loc_name, loc_dist)

            scale_name = f"initial_inv_softplus_surrogate_scale_{i}"
            init_values[scale_name] = pyro.sample(
                scale_name,
                dist.Delta(inv_softplus_scale * torch.ones_like(template))
            )

        return init_values

    return _fn

def make_model(
    width, depth, input_scales, num_seasonal_harmonics, seasonality_periods,
    init_x, fourier_degrees, interactions
):
    """Build a ``models.BayesianNeuralField1D`` and return it with its state dict.

    The network is called once on a float32 zero tensor of shape ``init_x``
    before the state dict is read (presumably so lazily created parameters
    exist; confirm against ``models``).

    Returns:
        ``(mlp, mlp.state_dict())``.
    """
    field_kwargs = dict(
        width=width,
        depth=depth,
        input_scales=input_scales,
        fourier_degrees=fourier_degrees,
        interactions=interactions,
        num_seasonal_harmonics=num_seasonal_harmonics,
        seasonality_periods=seasonality_periods,
    )
    mlp = models.BayesianNeuralField1D(**field_kwargs)

    mlp(torch.zeros(init_x, dtype=torch.float32))

    template = mlp.state_dict()
    return mlp, template


import pyro.distributions as dist
from typing import Any, Dict
import pyro
import functools
def make_prior(**kwargs):
    """Return a zero-argument prior model built from the model arguments.

    ``likelihood_distribution`` is dropped from ``kwargs`` if present; the rest
    is forwarded to ``make_model``. The returned callable evaluates
    ``models.prior_model_fn`` on the model's state-dict template.
    """
    model_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key != "likelihood_distribution"
    }
    _, mlp_template = make_model(**model_kwargs)

    def prior_model():
        return models.prior_model_fn(mlp_template)

    return prior_model


def build_observation_distribution(distribution, forecast_params):
    """
    Returns (zero inflated) Negative Binomial distribution given parameters.

    Args:
        distribution: Either the string ``'NB'`` / ``'ZINB'`` or an enum member
            whose ``name`` is ``'NB'`` / ``'ZINB'`` (such as
            ``models.LikelihoodDist.NB``).
        forecast_params: Tuple of total_count, logits, and optionally maybe_zero_mass

    Returns:
        A ``torch.distributions.NegativeBinomial`` for NB, or a
        ``ZeroInflatedNegativeBinomial`` for ZINB.

    Raises:
        ValueError: If ``distribution`` is neither NB nor ZINB.
    """
    # Imported locally because the file never imports NegativeBinomial.
    from torch.distributions import NegativeBinomial

    total_count, logits, *maybe_zero_mass = forecast_params

    # Enum members expose their member name via `.name`; plain strings are
    # used as they are. This lets callers pass either form.
    kind = getattr(distribution, 'name', distribution)

    if distribution == 'NB' or kind == 'NB':
        return NegativeBinomial(total_count=total_count.unsqueeze(-1), logits=logits)

    if distribution == 'ZINB' or kind == 'ZINB':
        # NOTE(review): the third forecast parameter is used directly as the
        # gate; upstream code names it `gate_logits`, so a sigmoid may be
        # needed here. Confirm against the caller.
        inflated_loc_probs = maybe_zero_mass[0]
        # The wrapper's constructor takes `probs`; sigmoid(logits) is the
        # matching success probability.
        return ZeroInflatedNegativeBinomial(
            total_count=total_count.unsqueeze(-1),
            probs=torch.sigmoid(logits),
            gate=inflated_loc_probs,
        )

    raise ValueError(f'Unknown distribution: {distribution}')
class ZeroInflatedNegativeBinomial:
    """Negative Binomial distribution with extra probability mass at zero.

    Wraps a scalar ``torch.distributions.NegativeBinomial`` (``base_dist``)
    together with a ``gate``, the extra probability of observing zero.

    Args:
        total_count: Single-element tensor; converted with ``.item()``.
        probs: Single-element tensor of success probabilities. Exactly one of
            ``probs`` and ``logits`` must be given.
        gate: Zero-inflation probability.
        logits: Single-element tensor of success logits (keyword-only
            alternative to ``probs``).

    Raises:
        ValueError: If both or neither of ``probs`` / ``logits`` are given.
        TypeError: If ``gate`` is missing.
    """

    def __init__(self, total_count, probs=None, gate=None, *, logits=None):
        # Imported locally because the file never imports NegativeBinomial.
        from torch.distributions import NegativeBinomial

        if (probs is None) == (logits is None):
            raise ValueError("Exactly one of `probs` or `logits` must be specified.")
        if gate is None:
            raise TypeError("`gate` is required.")

        if probs is not None:
            self.base_dist = NegativeBinomial(
                total_count=total_count.item(), probs=probs.item()
            )
        else:
            self.base_dist = NegativeBinomial(
                total_count=total_count.item(), logits=logits.item()
            )
        self.gate = gate

    def cdf(self, x):
        """Evaluate the zero-inflated CDF at ``x`` via ``custom_cdf``."""
        return custom_cdf(self, x)

    def prob_zero(self):
        """Calculate the probability of zero occurrences."""
        pmf_zero = self.base_dist.log_prob(torch.tensor(0.0)).exp()
        return self.gate + (1 - self.gate) * pmf_zero

    @property
    def mean(self):
        # NOTE: this is the mean of the base Negative Binomial; the gate is
        # not taken into account.
        return self.base_dist.mean

    @property
    def stddev(self):
        # NOTE: this is the standard deviation of the base Negative Binomial;
        # the gate is not taken into account.
        return self.base_dist.stddev
def custom_cdf(dist, x):
    """Compute the CDF for Zero-Inflated Negative Binomial distribution.

    The base distribution's PMF is evaluated on ``0 .. max(x)`` and summed
    cumulatively; the value at each ``x`` (truncated to an integer index) is
    then mixed with the gate: ``gate + (1 - gate) * base_cdf``.

    Args:
        dist: Object exposing ``base_dist.log_prob`` and ``gate``.
        x: Tensor of non-negative evaluation points.
    """
    upper = x.max().item() + 1
    support = torch.arange(0, upper)
    base_cdf = dist.base_dist.log_prob(support).exp().cumsum(dim=0)

    cdf_at_x = base_cdf[x.long()]
    return dist.gate + (1 - dist.gate) * cdf_at_x

def negative_binomial_cdf(n, p, x):
    """Custom CDF for Negative Binomial distribution.

    Args:
        n: Single-element tensor holding the total count.
        p: Single-element tensor holding the success probability.
        x: Evaluation points (tensor or anything ``torch.tensor`` accepts);
            values are truncated to integer indices.

    Returns:
        Tensor of CDF values, one per entry of ``x``.
    """
    # Imported locally because the file never imports NegativeBinomial.
    from torch.distributions import NegativeBinomial

    x_tensor = x if isinstance(x, torch.Tensor) else torch.tensor(x)

    # Build the distribution once and evaluate the PMF on the whole support
    # 0..max(x) in a single call instead of one distribution per point.
    neg_binomial = NegativeBinomial(total_count=n.item(), probs=p.item())
    support = torch.arange(
        int(x_tensor.max().item()) + 1, dtype=torch.get_default_dtype()
    )
    cdf_values = neg_binomial.log_prob(support).exp().cumsum(dim=0)

    return cdf_values[x_tensor.long()]

def _get_nb_quantiles_root(dist, q, ensemble_axes=(0, 1, 2)):
    """Returns (zero inflated) Negative Binomial quantiles via root-finding.

    The root of ``cdf(x) - q`` is searched on ``[0, high]`` with
    ``chandrupatla_torch`` and rounded up to the next integer.

    Args:
        dist: A ``ZeroInflatedNegativeBinomial`` or an object exposing
            ``total_count``, ``probs``, ``mean`` and ``stddev``.
        q: Probability level, as a Python float or a tensor.
        ensemble_axes: Dimensions the CDF is averaged over when it has more
            than one dimension.

    Returns:
        The quantile; for the zero-inflated case it is 0 wherever the
        probability of zero already exceeds ``q``.
    """
    is_zero_inflated = isinstance(dist, ZeroInflatedNegativeBinomial)

    def cdf_diff(x):
        if is_zero_inflated:
            cdf_val = custom_cdf(dist, x)
        else:
            cdf_val = negative_binomial_cdf(
                dist.total_count.unsqueeze(0), dist.probs.unsqueeze(0), x
            )

        if cdf_val.dim() > 1:
            return cdf_val.mean(dim=ensemble_axes) - q
        return cdf_val - q

    low = 0.0
    high_mean = torch.as_tensor(dist.mean)
    high_stddev = torch.as_tensor(dist.stddev)

    # NOTE(review): the upper bracket is a heuristic (mean plus roughly one
    # standard deviation); for high `q` it may not contain the quantile, in
    # which case the root finder raises ValueError.
    if is_zero_inflated:
        # torch.sqrt only accepts tensors, so a Python-float `q` must be
        # wrapped first.
        tail_factor = torch.sqrt(torch.as_tensor(1 - q))
        high = (high_mean + 1.1 * tail_factor * high_stddev).max()
    else:
        high = (high_mean + 1.1 * high_stddev).max()

    res = chandrupatla_torch(cdf_diff, low, high)

    if is_zero_inflated:
        # A tensor of zeros (rather than the scalar 0) keeps torch.where
        # working on torch versions without scalar-argument support.
        return torch.ceil(
            torch.where(dist.prob_zero() > q, torch.zeros_like(res), res)
        )
    return torch.ceil(res)
def fit_vi(
    features: torch.Tensor,
    target: torch.Tensor,
    seed: int,
    observation_model: str,
    model_args: dict[str, Any],
    ensemble_size: int,
    learning_rate: float,
    num_epochs: int,
    sample_size_divergence: int,
    sample_size_posterior: int,
    kl_weight: float,
    batch_size: Union[int,None] = None,
):
    """Fit BNF using an ensemble VI.

    Builds the summed log-likelihood function and the prior from
    ``model_args`` and delegates to ``ensemble_vi``.
    ``sample_size_divergence`` is forwarded as ``sample_size`` and
    ``sample_size_posterior`` as ``num_samples``.

    Returns:
        Whatever ``ensemble_vi`` returns.
    """
    distribution = models.LikelihoodDist(observation_model)

    def _neg_energy_fn(params, x, y):
        # The model is rebuilt from `model_args` on every evaluation.
        likelihood = models.make_likelihood_model(
            params, x, *make_model(**model_args), distribution
        )
        return likelihood.log_prob(y).sum()

    vi_options = dict(
        prior_d=make_prior(**model_args),
        ensemble_size=ensemble_size,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        seed=seed,
        sample_size=sample_size_divergence,
        num_samples=sample_size_posterior,
        kl_weight=kl_weight,
        batch_size=batch_size,
    )
    return ensemble_vi(features, target, _neg_energy_fn, **vi_options)

def fit_map(
    features: torch.Tensor,
    target: torch.Tensor,
    seed: int,
    observation_model: str,
    model_args: dict[str, Any],
    num_particles: int,
    learning_rate: float,
    num_epochs: int,
    prior_weight: float = 1.0,
    batch_size: Union[int, None] = None,
    num_splits: int = 1,
):
    """Fit BNF with an ensemble of MAP estimates.

    The ensemble is trained in ``num_splits`` consecutive calls to
    ``ensemble_map``, each with ``num_particles // num_splits`` members and
    seed ``seed + split_index``.

    Args:
        features: Model inputs.
        target: Observations.
        seed: Seeds torch's global RNG and the per-split seeds.
        observation_model: Value accepted by ``models.LikelihoodDist``.
        model_args: Keyword arguments for ``make_model`` / ``make_prior``.
        num_particles: Total ensemble size across all splits.
        learning_rate: Optimizer learning rate.
        num_epochs: Number of passes over the data.
        prior_weight: Weight of the prior log-probability in the objective.
        batch_size: Mini-batch size; ``None`` is forwarded unchanged.
        num_splits: Number of sequential sub-ensembles.

    Returns:
        ``(params, losses)``: ``params`` maps each parameter name to the
        members' values stacked along a new leading dimension; ``losses`` is
        the concatenation of all splits' loss tensors along dimension 0.
    """
    torch.manual_seed(seed)
    distribution = models.LikelihoodDist(observation_model)

    def _neg_energy_fn(params, x, y):
        # Summed log-likelihood of `y`; the model is rebuilt on every call.
        model = models.make_likelihood_model(
            params, x, *make_model(**model_args), distribution
        )
        return model.log_prob(y).sum()

    # NOTE(review): an earlier revision computed the standard deviation of the
    # non-NaN targets here but never used it; the initial log noise scale
    # below is the constant log(0.5 / 2). Confirm whether it was meant to
    # depend on the target scale.

    def _make_init_fn(prior_d):
        """Construct a surrogate posterior init function."""
        # Trace the prior once; the sampled values serve as shape templates.
        trace = pyro.poutine.trace(prior_d).get_trace()
        trace.compute_log_prob()
        samples = {
            name: site['value']
            for name, site in trace.nodes.items()
            if site['type'] == 'sample'
        }

        def _fn(seed=None):
            if seed is not None:
                pyro.set_rng_seed(seed)

            init_values = {}
            for i, template in enumerate(samples.values()):
                if i == 0:
                    # The first prior site is initialized as the log noise scale.
                    name = "zero_initial_log_noise_scale"
                    site_dist = dist.Delta(
                        torch.ones_like(template) * torch.log(torch.tensor(0.5) / 2)
                    )
                elif len(template.shape) != 2:
                    name = f"zero_initial_mean_for_bias_or_transformed_scale_{i}"
                    site_dist = dist.Delta(torch.zeros_like(template))
                else:
                    name = f"initial_weight_matrix_{i}"
                    site_dist = dist.TransformedDistribution(
                        dist.Uniform(-2, 2),
                        dist.transforms.AffineTransform(0.0, torch.ones_like(template))
                    )
                init_values[name] = pyro.sample(name, site_dist)
            return init_values

        return _fn

    prior = make_prior(**model_args)
    ensemble_size = num_particles // num_splits

    member_params = []
    losses = []
    for split_idx in range(num_splits):
        # Rebuilt per split on purpose: tracing the prior consumes random
        # numbers, so hoisting it would change the random stream.
        init_fn = _make_init_fn(prior)

        params_i, losses_i = ensemble_map(
            features,
            target,
            _neg_energy_fn,
            prior_d=prior,
            init_fn=init_fn,
            ensemble_size=ensemble_size,
            learning_rate=learning_rate,
            num_epochs=num_epochs,
            seed=seed + split_idx,
            batch_size=batch_size,
            prior_weight=prior_weight,
        )

        member_params.extend(
            {k: v.detach() for k, v in p.items()} for p in params_i
        )
        losses.extend(losses_i)

    # Regroup from one dict per member to one stacked tensor per parameter.
    grouped = {}
    for member in member_params:
        for k, v in member.items():
            grouped.setdefault(k, []).append(v)

    params = {k: torch.stack(v, dim=0) for k, v in grouped.items()}
    losses = torch.cat(losses, dim=0)

    return params, losses


def predict_bnf(
    features: torch.Tensor,
    observation_model: str,
    params: dict[str, torch.Tensor],
    model_args: dict[str, Any],
    quantiles: torch.Tensor,
    ensemble_dims: int = 2,
    approximate_quantiles: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Predict new data from an existing BNF fit.

    Args:
        features: Inputs to predict at.
        observation_model: Value accepted by ``models.LikelihoodDist``.
        params: Fitted parameters; each value is indexed along its first
            dimension.
        model_args: Keyword arguments for ``make_model``.
        quantiles: Probability levels to evaluate.
        ensemble_dims: Number of leading indices of ``params`` that are
            evaluated; must be at least 1.
        approximate_quantiles: For the NORMAL model, use the moment-matched
            approximation instead of root-finding.

    Returns:
        ``(forecast_means, forecast_quantiles)``.

    Raises:
        ValueError: If the distribution is not NORMAL, NB or ZINB.
    """
    distribution = models.LikelihoodDist(observation_model)
    assert ensemble_dims >= 1

    forecast_inner = _make_forecast_inner(model_args, distribution)

    # NOTE(review): the loop runs over range(ensemble_dims) and indexes params
    # by that counter, so only the first `ensemble_dims` entries along
    # dimension 0 are evaluated. Presumably it should cover all ensemble
    # members; confirm the intended meaning of `ensemble_dims`.
    forecast_params_list = []
    for i in range(ensemble_dims):
        params_i = {k: v[i] for k, v in params.items()}
        forecast_params_i = forecast_inner(list(params_i.values()), features)

        if isinstance(forecast_params_i, tuple):
            forecast_params_list.extend(forecast_params_i)
        else:
            forecast_params_list.append(forecast_params_i)

    forecast_params = torch.stack(forecast_params_list, dim=0)

    if distribution == models.LikelihoodDist.NORMAL:
        # NOTE(review): rows 0 and 1 are the loc and scale of the first
        # evaluated index only; later indices are stacked but unused here.
        means, scales = forecast_params[0], forecast_params[1]
        forecast_means = means
        forecast_quantiles = _get_percentile_normal(
            forecast_means,
            scales,
            quantiles,
            axis=(0, 1),
            approximate=approximate_quantiles,
        )

    elif distribution in [models.LikelihoodDist.NB, models.LikelihoodDist.ZINB]:
        # The helper is defined without a leading underscore; the previous
        # call used a name that does not exist.
        obs_d = build_observation_distribution(distribution, forecast_params)
        # `mean` is a property on both distribution classes, not a method.
        forecast_means = obs_d.mean
        forecast_quantiles = torch.stack(
            [_get_nb_quantiles_root(obs_d, q, ensemble_axes=(0, 1)) for q in quantiles]
        )

    else:
        raise ValueError(f'Unknown distribution: {distribution}')

    return forecast_means, forecast_quantiles
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
def compute_prior_log_prob(prior_samples):
    """Sum Logistic prior log-densities over all entries of ``prior_samples``.

    The prior depends only on an entry's position in iteration order:

    * positions 0 and 2: ``Logistic(0.0, 1.0)``
    * position 1: ``Logistic(-1.5, 1.0)``
    * all later positions: ``Logistic(-1, ones_like(param))``

    Returns:
        The total log-probability (``0`` for an empty mapping).
    """
    total = 0
    for position, param in enumerate(prior_samples.values()):
        if position == 1:
            site_prior = dist.Logistic(-1.5, 1.0)
        elif position in (0, 2):
            site_prior = dist.Logistic(0.0, 1.0)
        else:
            site_prior = dist.Logistic(-1, torch.ones_like(param))
        total = total + site_prior.log_prob(param).sum()

    return total
def ensemble_map(
    features: torch.Tensor,
    target: torch.Tensor,
    neg_energy_fn: callable,
    prior_d: torch.distributions.Distribution,
    init_fn: callable,
    ensemble_size: int,
    learning_rate: float,
    num_epochs: int,
    seed: int,
    batch_size: Union[int,None] = None,
    prior_weight: float = 1.0,
) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor]:
    """Fit an ensemble of MAP estimates in PyTorch.

    Every member is initialized by ``init_fn(member_seed)`` and trained with
    its own Adam optimizer on shuffled mini-batches. The per-member objective
    is the negative of ``neg_energy_fn(...) * (len(batch) / batch_size)`` plus,
    unless ``prior_weight`` is 0, ``prior_weight`` times
    ``compute_prior_log_prob``.

    Args:
        features: Model inputs.
        target: Observations.
        neg_energy_fn: ``(param_list, x_batch, y_batch) -> log-likelihood``.
        prior_d: Unused; kept for interface compatibility.
        init_fn: ``seed -> dict`` of initial parameter tensors.
        ensemble_size: Number of members.
        learning_rate: Adam learning rate.
        num_epochs: Number of passes over the data.
        seed: Seeds the member init seeds, the zero-init jitter and the
            data shuffling, so repeated calls with the same seed agree
            (provided ``init_fn`` is deterministic given its seed).
        batch_size: Mini-batch size; defaults to the full data set.
        prior_weight: Weight of the prior term.

    Returns:
        ``(params, losses)``: a list with one detached parameter dict per
        member, and a loss tensor indexed as ``[epoch, batch, member]``.
    """
    if batch_size is None:
        batch_size = target.shape[0]

    # A private generator makes `seed` effective without touching the global
    # RNG state; previously `seed` was ignored entirely.
    rng = torch.Generator().manual_seed(int(seed))

    init_seeds = torch.randint(
        0, 2**32 - 1, (ensemble_size,), dtype=torch.int64, generator=rng
    )
    members = [init_fn(int(member_seed)) for member_seed in init_seeds]

    for member in members:
        for key, value in member.items():
            if torch.all(value == 0):
                # Break the symmetry of all-zero initial values with a small
                # random perturbation.
                jitter = torch.randn(
                    value.shape, generator=rng, dtype=value.dtype
                ).to(value.device)
                start = value + jitter * 1e-3
            else:
                start = value.clone()
            member[key] = torch.nn.Parameter(start, requires_grad=True)

    # One optimizer per member: a member's step can then never touch another
    # member's parameters or optimizer state.
    optimizers = [
        optim.Adam(list(member.values()), lr=learning_rate) for member in members
    ]

    data_loader = DataLoader(
        TensorDataset(features, target),
        batch_size=batch_size,
        shuffle=True,
        generator=rng,
    )

    def _loss(member, x_batch, y_batch):
        log_likelihood = neg_energy_fn(list(member.values()), x_batch, y_batch)
        scaled = log_likelihood * (y_batch.shape[0] / batch_size)
        if prior_weight == 0.0:
            return -scaled
        return -(scaled + compute_prior_log_prob(member) * prior_weight)

    all_losses = []
    for _ in range(num_epochs):
        epoch_losses = []
        for batch_features, batch_target in data_loader:
            batch_losses = []
            for member, optimizer in zip(members, optimizers):
                optimizer.zero_grad()
                loss = _loss(member, batch_features, batch_target)
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())
            epoch_losses.append(batch_losses)
        all_losses.append(epoch_losses)

    updated_params = [
        {k: v.detach().clone() for k, v in member.items()} for member in members
    ]
    return updated_params, torch.tensor(all_losses)
def ensemble_vi(
    features: torch.Tensor,
    target: torch.Tensor,
    neg_energy_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], float],
    prior_d: Callable[[torch.Tensor], Normal],
    ensemble_size: int,
    learning_rate: float,
    num_epochs: int,
    seed: int,
    sample_size: int = 10,
    num_samples: int = 30,
    kl_weight: float = 1.0,
    batch_size: int = None,
) -> Tuple[List[Dict[str, torch.Tensor]], torch.Tensor, Dict[str, torch.Tensor]]:
    """Fit an ensemble of mean-field normal surrogate posteriors.

    Each member starts from ``make_vi_init(prior_d)`` and is optimized with
    Adam. The initial-value dict is read as consecutive (location,
    inverse-softplus scale) pairs; each pair defines
    ``Normal(loc, 1e-4 + softplus(raw_scale))``.

    Args:
        features: Model inputs (tensor or array-like).
        target: Observations (tensor or array-like).
        neg_energy_fn: ``(param_list, x_batch, y_batch) -> log-likelihood``.
        prior_d: Prior model passed to ``make_vi_init``.
        ensemble_size: Number of members.
        learning_rate: Adam learning rate.
        num_epochs: Number of optimization steps per member.
        seed: Seeds the member init seeds and the per-member mini-batch draws.
        sample_size: Currently unused; one posterior sample is drawn per step.
        num_samples: Number of posterior samples returned per member.
        kl_weight: Divides the (rescaled) log-likelihood term.
        batch_size: Mini-batch size; ``None`` uses the full data set.

    Returns:
        ``(fit_params, losses, predictions)``: one detached parameter dict per
        member, a loss tensor indexed as ``[member, step]``, and posterior
        samples per location key stacked over members.
    """
    features = torch.as_tensor(features)
    target = torch.as_tensor(target)

    # Private generator: makes `seed` effective without re-seeding the global
    # RNG at every step.
    rng = torch.Generator().manual_seed(int(seed))

    def _target_log_prob_fn_inner(params, x_batch, y_batch):
        return compute_prior_log_prob(params) + (
            neg_energy_fn(list(params.values()), x_batch, y_batch)
            * (target.shape[0] / y_batch.shape[0])
            / kl_weight
        )

    def _target_log_prob_fn(params, batch_rng):
        if batch_size is None:
            return _target_log_prob_fn_inner(params, features, target)

        # A fresh random subset per call. The previous code re-seeded the
        # global RNG with the same value each step, which selected the same
        # subset (and the same sampling noise) on every step.
        indices = torch.randperm(features.size(0), generator=batch_rng)[:batch_size]
        return _target_log_prob_fn_inner(params, features[indices], target[indices])

    def make_surrogate_posterior(params: dict, batch_ndims=1):
        posterior = {}
        keys = list(params.keys())
        # Keys come in (location, raw scale) pairs; the posterior is keyed by
        # the location name.
        for i in range(0, len(keys), 2):
            mean_key = keys[i]
            raw_scale = params[keys[i + 1]]
            std = 0.0001 + F.softplus(raw_scale)
            posterior[mean_key] = Normal(params[mean_key], std)
        return posterior

    def sample_from_posterior(posterior, num_samples=10, reparameterized=False):
        samples = {}
        for key, site_posterior in posterior.items():
            if reparameterized:
                # rsample keeps the graph so gradients reach loc and scale.
                samples[key] = site_posterior.rsample((num_samples,))
            else:
                samples[key] = site_posterior.sample((num_samples,))
        return samples

    # The init function was previously referenced without being defined.
    init_fn = make_vi_init(prior_d)
    init_seeds = torch.randint(
        0, 2**32 - 1, (ensemble_size,), dtype=torch.int64, generator=rng
    )
    init_params = [init_fn(int(member_seed)) for member_seed in init_seeds]

    def fit_surrogate_posterior(member_init, member_seed):
        # Leaf copies that require grad; the values produced by the init
        # function do not track gradients.
        variational = {
            key: (
                value.detach().clone().requires_grad_(True)
                if isinstance(value, torch.Tensor)
                else value
            )
            for key, value in member_init.items()
        }
        optimizer = torch.optim.Adam(
            [p for p in variational.values() if isinstance(p, torch.Tensor)],
            lr=learning_rate,
        )
        batch_rng = torch.Generator().manual_seed(int(member_seed))
        losses = []

        for _ in range(num_epochs):
            optimizer.zero_grad()

            posterior = make_surrogate_posterior(variational)
            samples = sample_from_posterior(posterior, 1, reparameterized=True)
            # NOTE(review): the objective is the negative target log-prob at
            # one posterior sample; no surrogate entropy term is included.
            loss = -_target_log_prob_fn(samples, batch_rng)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        trained = {
            key: (value.detach() if isinstance(value, torch.Tensor) else value)
            for key, value in variational.items()
        }
        return trained, torch.tensor(losses)

    def predict(surrogate_params, num_samples):
        posterior = make_surrogate_posterior(surrogate_params)
        return sample_from_posterior(posterior, num_samples)

    opt_seeds = torch.randint(
        0, 2**32 - 1, (ensemble_size,), dtype=torch.int64, generator=rng
    )

    fit_params = []
    fit_losses = []

    # Each member trains from its own initial values; indexing the init list
    # with a random seed could reuse one member's init and skip another's.
    for member_init, member_seed in zip(init_params, opt_seeds):
        trained_params, train_losses = fit_surrogate_posterior(member_init, member_seed)
        fit_params.append(trained_params)
        fit_losses.append(train_losses)

    predictions = [predict(params, num_samples) for params in fit_params]

    losses = torch.stack(fit_losses)
    predictions = {
        key: torch.stack([p[key] for p in predictions]) for key in predictions[0]
    }

    return fit_params, losses, predictions