## Models

# In [170]:
import enum
import numpy as np
import torch
import torch.nn.init as init
import pyro.distributions as dist
from typing import Any, Dict
import pyro

class LikelihoodDist(enum.Enum):
    """Observation likelihood families understood by ``make_likelihood_model``."""

    NORMAL = 'NORMAL'  # Gaussian around the network output.
    NB = 'NB'  # Negative binomial counts.
    ZINB = 'ZINB'  # Zero-inflated negative binomial counts.

def make_seasonal_frequencies(
    seasonality_periods: np.ndarray, num_harmonics: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
  """Return the unique Fourier frequencies for the given periods and harmonics.

  For a period ``p`` with ``h`` harmonics the candidate frequencies are
  ``1/p, 2/p, ..., h/p``. Candidates from all periods are concatenated and
  de-duplicated, keeping the first occurrence so the output order follows
  the input order.

  Args:
    seasonality_periods: rank-1 array-like of seasonal periods.
    num_harmonics: rank-1 array-like of the same length giving the number of
      harmonics per period; each entry must be at most half of its period.

  Returns:
    ``(frequencies, harmonics)``: float32 arrays of equal length, where
    ``harmonics[i]`` is the harmonic index that produced ``frequencies[i]``.

  Raises:
    ValueError: if the inputs differ in shape, are not rank 1, or a harmonic
      count exceeds half of its period.
  """
  seasonality_periods = np.array(seasonality_periods, dtype=np.float32)
  num_harmonics = np.array(num_harmonics, dtype=np.float32)
  # Validate shapes before any element-wise comparison: otherwise mismatched
  # inputs may broadcast and surface as a misleading harmonic error.
  if seasonality_periods.shape != num_harmonics.shape:
    raise ValueError('Number of seasonal periods and harmonics must be equal.')
  if num_harmonics.ndim != 1:
    raise ValueError(
        'Arguments `num_harmonics` and `seasonality_periods` must be rank 1.'
    )
  if np.any(num_harmonics > seasonality_periods / 2):
    raise ValueError('Harmonic cannot exceed half seasonal period.')
  if seasonality_periods.shape[0] == 0:
    # Same float32 dtype as the non-empty path.
    return (np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.float32))
  harmonics = [np.arange(1, h + 1, dtype=np.float32) for h in num_harmonics]
  frequencies = np.concatenate(
      [h / p for (h, p) in zip(harmonics, seasonality_periods)]
  )
  # np.unique returns sorted values; sorting the first-occurrence indices
  # restores the original input order.
  _, idx = np.unique(frequencies, return_index=True)
  idx_sort = np.sort(idx)
  unique_frequencies = frequencies[idx_sort]
  unique_harmonics = np.concatenate(harmonics)[idx_sort]
  return (unique_frequencies, unique_harmonics)


def make_seasonal_features(
    x: torch.Tensor,
    seasonality_periods: torch.Tensor,
    num_harmonics: torch.Tensor,
    rescale: bool = False
) -> torch.Tensor:
    """Return cosine and sine features for every seasonal frequency.

    Args:
      x: values of the input the seasons are defined over; flattened to a
        column of shape (N, 1).
      seasonality_periods: seasonal periods, forwarded to
        ``make_seasonal_frequencies``.
      num_harmonics: harmonics per period, forwarded to
        ``make_seasonal_frequencies``.
      rescale: if True, divide each feature by the harmonic index that
        produced its frequency.

    Returns:
      float32 tensor of shape (N, 2 * F): the F cosine columns followed by the
      F sine columns, F being the number of unique frequencies. The result is
      on ``x``'s device and differentiable with respect to ``x``.
    """
    x = torch.as_tensor(x)
    # Keep float64 precision for float64 inputs (as the NumPy path did);
    # everything else is computed in float32.
    compute_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
    column = x.reshape(-1, 1).to(compute_dtype)

    frequencies, harmonics = make_seasonal_frequencies(seasonality_periods, num_harmonics)
    freq = torch.as_tensor(frequencies, dtype=compute_dtype, device=column.device)

    # Stay in torch (no x.numpy()) so inputs that require grad or live on an
    # accelerator are supported.
    angles = 2 * torch.pi * freq * column
    features = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
    features = features.to(torch.float32)
    if not rescale:
        return features

    # Cosine and sine columns share the same harmonic, hence the 2x tile.
    denominator = torch.as_tensor(
        np.tile(harmonics, 2), dtype=torch.float32, device=features.device
    )
    return features / denominator


def make_fourier_features(x, max_degree, rescale=False):
    """Build dyadic Fourier features ``cos/sin(2*pi * 2**k * x)`` for k < max_degree.

    Args:
      x: tensor of any shape; flattened to a column of shape (N, 1).
      max_degree: number of dyadic frequencies (int or 0-d integer tensor).
      rescale: if True, divide the features of degree ``k`` by ``k + 1``.

    Returns:
      Tensor of shape (N, 2 * max_degree): all cosine columns, then all sine
      columns.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    x = torch.as_tensor(x).reshape(-1, 1)
    # Same dtype promotion as a float32 `degrees` tensor, but allocated on
    # x's device so accelerator inputs work.
    dtype = torch.promote_types(x.dtype, torch.float32)
    degrees = torch.arange(max_degree, dtype=dtype, device=x.device)
    y = 2 * torch.pi * 2**degrees * x

    features = torch.cat((torch.cos(y), torch.sin(y)), dim=-1)

    if rescale:
        denominator = torch.cat((degrees + 1, degrees + 1))
        features = features / denominator

    return features

from torch.distributions import Uniform, TransformedDistribution, SigmoidTransform, AffineTransform

def logistic_distribution(loc: float, scale: float):
    """Create a Logistic(loc, scale) distribution in PyTorch.

    The distribution is expressed as a transformed Uniform(0, 1): a uniform
    draw is mapped through the logit (inverse sigmoid) and then through the
    affine map ``loc + scale * z``.
    """
    to_logit = SigmoidTransform().inv
    shift_and_scale = AffineTransform(loc=loc, scale=scale)
    uniform_base = Uniform(0.0, 1.0)
    return TransformedDistribution(uniform_base, [to_logit, shift_and_scale])
def prior_model_fn(mlp_template):
    """Sample all model parameters from Logistic priors via ``pyro.sample``.

    Three scalar sites are drawn first, in this order, each with scale 10:
    ``log_noise_scale`` (loc 0), ``shape`` (loc -1.5) and
    ``inflated_loc_probs`` (loc 0). Then, for every ``(name, tensor)`` entry of
    ``mlp_template``, a site ``weights_<name>`` is drawn from an element-wise
    Logistic(0, 1) shaped like that tensor.

    NOTE(review): the weight sites are not wrapped in ``.to_event()``, so each
    element is a separate batch dimension as far as Pyro is concerned —
    confirm this is intended if the model is ever used inside ``pyro.plate``.

    Args:
      mlp_template: mapping of parameter name -> tensor; each tensor is used
        only through ``torch.ones_like`` to size its prior.

    Returns:
      ``(log_noise_scale, shape, inflated_loc_probs, flat_mlp_params)`` where
      ``flat_mlp_params`` maps each template name to its sampled tensor.
    """
    scalar_priors = (
        ("log_noise_scale", 0.0),
        ("shape", -1.5),
        ("inflated_loc_probs", 0.0),
    )
    log_noise_scale, shape, inflated_loc_probs = (
        pyro.sample(site, dist.Logistic(loc, 10.0)) for site, loc in scalar_priors
    )
    flat_mlp_params = {
        key: pyro.sample(f"weights_{key}", dist.Logistic(0.0, torch.ones_like(template)))
        for key, template in mlp_template.items()
    }
    return log_noise_scale, shape, inflated_loc_probs, flat_mlp_params
import torch.nn.functional as F
from pyro.distributions import (
    Normal, NegativeBinomial, ZeroInflatedNegativeBinomial, Independent
)

from pyro.distributions import ZeroInflatedNegativeBinomial

def _resolve_log_noise_scale(raw):
    """Return ``raw`` as a float32 tensor, averaged if it holds several values."""
    # as_tensor keeps the autograd graph (torch.tensor(tensor) would detach
    # and warn) and also accepts Python lists, which have no .mean().
    value = torch.as_tensor(raw, dtype=torch.float32)
    if value.ndim > 0 and value.shape[0] > 1:
        return value.mean()
    return value


def _predict_with_params(params, x, mlp):
    """Copy ``params[3:]`` positionally into ``mlp``'s state dict and run ``mlp(x)``."""
    # NOTE(review): load_state_dict copies values into the module in place,
    # so gradients presumably do not flow from `params` to the predictions;
    # torch.func.functional_call would keep them — confirm what inference needs.
    state = mlp.state_dict()
    state.update(zip(state.keys(), params[3:]))
    mlp.load_state_dict(state)
    return mlp(x)


def make_likelihood_model(params, x, mlp, mlp_template, distribution):
    """Build the observation likelihood of ``x`` under the given parameters.

    Args:
      params: sequence laid out as ``(log_noise_scale, shape,
        inflated_loc_logit, *mlp_values)``. ``params[3:]`` are copied
        positionally into ``mlp.state_dict()`` in its key order.
      x: inputs forwarded to ``mlp``.
      mlp: torch module producing the location predictions. Its parameters
        are overwritten in place.
      mlp_template: unused; kept for interface compatibility.
      distribution: a ``LikelihoodDist`` member.

    Returns:
      ``torch.distributions.Independent`` reinterpreting one batch dimension:
        NORMAL: Normal(mlp(x), 0.01 + exp(log_noise_scale)), where a
          multi-valued ``params[0]`` is averaged first.
        NB: NegativeBinomial with mean ``softplus(mlp(x))`` and dispersion
          ``softplus(params[1])`` (``total_count = 1 / dispersion``).
        ZINB: the NB above, zero-inflated with gate ``sigmoid(params[2])``.

    Raises:
      ValueError: if ``distribution`` is not a known likelihood.
    """
    if distribution == LikelihoodDist.NORMAL:
        log_noise_scale = _resolve_log_noise_scale(params[0])
        predictions = _predict_with_params(params, x, mlp)
        noise_scale = 0.01 + torch.exp(log_noise_scale)
        return torch.distributions.Independent(
            torch.distributions.Normal(predictions, noise_scale), 1
        )

    if distribution in (LikelihoodDist.NB, LikelihoodDist.ZINB):
        predictions = _predict_with_params(params, x, mlp)
        mean = torch.nn.functional.softplus(predictions)
        shape = torch.nn.functional.softplus(
            torch.as_tensor(params[1], dtype=mean.dtype, device=mean.device)
        )
        total_count = 1 / shape
        # torch's NegativeBinomial has mean = total_count * exp(logits); with
        # total_count = 1 / shape this requires logits = log(shape * mean).
        logits = torch.log(shape) + torch.log(mean)
        if distribution == LikelihoodDist.NB:
            likelihood = torch.distributions.NegativeBinomial(
                total_count=total_count, logits=logits
            )
        else:
            inflated_probs = torch.sigmoid(
                torch.as_tensor(params[2], dtype=mean.dtype, device=mean.device)
            )
            likelihood = ZeroInflatedNegativeBinomial(
                total_count=total_count, logits=logits, gate=inflated_probs
            )
        return torch.distributions.Independent(likelihood, 1)

    raise ValueError(f"Unknown likelihood distribution: {distribution}")
import torch
import torch.nn as nn



import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class BayesianNeuralField1D(nn.Module):
    """MLP over scaled inputs plus Fourier, seasonal and interaction features.

    Parameters whose shape depends on the data are created lazily, exactly
    once, on the first ``forward`` call and reused afterwards. These are
    ``log_scale_adjustment``, ``feature_scales``, ``logit_activation_weight``,
    ``hidden_layers`` and ``output_layer``. Run one forward pass before
    building an optimizer or reading ``state_dict()`` so they exist.

    Args:
      width: number of units in each hidden layer.
      depth: number of hidden layers. ``depth + 1`` layer scales are kept;
        the last one scales the output layer.
      input_scales: divisors applied to the input columns.
      fourier_degrees: per-column number of dyadic Fourier frequencies; a
        non-positive entry disables Fourier features for that column.
      interactions: indices of input columns whose product becomes a
        feature. A rank-1 index list yields one product column; a rank-2
        ``(k, m)`` array yields ``k`` product columns.
      num_seasonal_harmonics: harmonics per seasonal period.
      seasonality_periods: seasonal periods. Seasonal features are computed
        from input column 0.
    """

    def __init__(self, width, depth, input_scales, fourier_degrees, interactions,
                 num_seasonal_harmonics=None, seasonality_periods=None):
        super().__init__()

        self.width = width
        self.depth = depth
        self.input_scales = torch.tensor(input_scales, dtype=torch.float32)
        self.fourier_degrees = torch.tensor(fourier_degrees, dtype=torch.int64)
        self.interactions = torch.tensor(interactions, dtype=torch.long)

        self.num_seasonal_harmonics = torch.tensor(
            num_seasonal_harmonics if num_seasonal_harmonics is not None else [],
            dtype=torch.float32
        )
        self.seasonality_periods = torch.tensor(
            seasonality_periods if seasonality_periods is not None else [],
            dtype=torch.float32
        )

        # Input-dependent parameters: registered as empty placeholders (they
        # are skipped by state_dict) and materialized once in forward. This
        # keeps their state_dict names and position stable.
        self.register_parameter("log_scale_adjustment", None)
        self.register_parameter("logit_activation_weight", None)

        self.layer_scales = nn.ParameterList([
            nn.Parameter(torch.randn(())) for _ in range(depth + 1)
        ])
        self.hidden_layers = nn.ModuleList()
        self.feature_scales = nn.ParameterList()
        # Built together with hidden_layers on the first forward call.
        self.output_layer = None

    def make_layer_scale(self, layer_id):
        """Get the softplus-transformed scale parameter for a specific layer."""
        return F.softplus(self.layer_scales[layer_id])

    def activation_fn(self, x):
        """Blend ELU and tanh, weighted by ``sigmoid(logit_activation_weight)``."""
        activation_weight = torch.sigmoid(self.logit_activation_weight)
        return activation_weight * F.elu(x) + (1 - activation_weight) * torch.tanh(x)

    def forward(self, x):
        """Evaluate the field and return one value per input row.

        Args:
          x: array-like of shape (N,) or (N, D); a 1-D input is treated as a
            single column.

        Returns:
          Tensor of shape (N,).
        """
        # Convert before inspecting the rank so NumPy arrays and lists work;
        # as_tensor also keeps the autograd graph of tensor inputs.
        x = torch.as_tensor(x, dtype=torch.float32)
        if x.ndim == 1:
            x = x.unsqueeze(-1)

        if self.log_scale_adjustment is None:
            self.log_scale_adjustment = nn.Parameter(
                torch.randn(x.shape[-1:], dtype=torch.float32)
            )
        scaled_x = x / (self.input_scales * torch.exp(self.log_scale_adjustment))

        seasonal_features = make_seasonal_features(
            x[..., 0], self.seasonality_periods, self.num_seasonal_harmonics, rescale=False
        )

        fourier_features = [
            make_fourier_features(scaled_x[..., i], degree, rescale=False)
            for i, degree in enumerate(self.fourier_degrees) if degree > 0
        ]

        interacting_columns = scaled_x[..., self.interactions]
        if self.interactions.ndim >= 2:
            # (k, m) index array -> k product columns.
            interaction_features = torch.prod(interacting_columns, dim=-1)
        else:
            # Single index list -> one product column.
            interaction_features = torch.prod(interacting_columns, dim=-1, keepdim=True)
        features = [scaled_x, *fourier_features, seasonal_features, interaction_features]

        if len(self.feature_scales) == 0:
            for _ in range(len(features)):
                self.feature_scales.append(nn.Parameter(torch.randn(())))

        # Scales are indexed over all groups; empty groups are then dropped.
        features = [
            f * F.softplus(self.feature_scales[i]) for i, f in enumerate(features) if f.numel() > 0
        ]

        if self.logit_activation_weight is None:
            self.logit_activation_weight = nn.Parameter(torch.randn(1))

        h = torch.cat(features, dim=-1)
        if self.output_layer is None:
            input_dim = h.shape[-1]
            self.hidden_layers = nn.ModuleList([
                nn.Linear(self.width if i > 0 else input_dim, self.width)
                for i in range(self.depth)
            ])
            # With no hidden layers the output reads the features directly.
            self.output_layer = nn.Linear(self.width if self.depth > 0 else input_dim, 1)

        for layer_id, layer in enumerate(self.hidden_layers):
            layer_scale = self.make_layer_scale(layer_id)
            # Normalize by sqrt(fan-in) before each layer.
            h = h / (h.shape[-1] ** 0.5)
            h = self.activation_fn(layer_scale * layer(h))

        output_scale = self.make_layer_scale(self.depth)
        h = h / (h.shape[-1] ** 0.5)
        return output_scale * self.output_layer(h).squeeze(-1)