In [None]:
# default_exp utils_blitz

# utils_blitz

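> API details: helpers that convert standard PyTorch layers into their `blitz` Bayesian counterparts and toggle the `freeze` flag on Bayesian modules.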

In [None]:
#export
#hide
import warnings

import torch
from torch import nn

from blitz.modules import BayesianLinear
from blitz.modules import BayesianEmbedding, BayesianConv1d
from blitz.modules.base_bayesian_module import BayesianModule

In [None]:
#export
def convert_layer_to_bayesian(layer, config: dict):
    """Return a blitz Bayesian replacement for a single PyTorch layer.

    Supported conversions and what is carried over from ``layer``:
      * ``nn.Linear``    -> ``BayesianLinear``: in/out features, whether a bias is
        used, and the prior/posterior hyper-parameters read from ``config``.
      * ``nn.Embedding`` -> ``BayesianEmbedding``: ``num_embeddings`` and
        ``embedding_dim`` only.
      * ``nn.Conv1d``    -> ``BayesianConv1d``: channels, kernel size, stride,
        groups, padding, dilation and whether a bias is used.

    The new layer is freshly constructed; the weights of ``layer`` are NOT copied.
    Any other layer type triggers a ``UserWarning`` and is returned unchanged.

    NOTE(review): ``config`` is only applied to ``nn.Linear``; the Embedding and
    Conv1d replacements get whatever defaults the blitz constructors use. Embedding
    options (``padding_idx``, ``max_norm``, ...) and Conv1d ``padding_mode`` are not
    forwarded either.

    Args:
        layer: the module to convert.
        config: hyper-parameters for ``BayesianLinear``. When ``layer`` is an
            ``nn.Linear``, it must contain ``prior_sigma_1``, ``prior_sigma_2``,
            ``prior_pi``, ``posterior_mu_init`` and ``posterior_rho_init``.

    Returns:
        The new Bayesian layer, or ``layer`` itself if its type is unsupported.

    Raises:
        KeyError: if ``layer`` is an ``nn.Linear`` and a required key is missing
            from ``config``.
    """
    if isinstance(layer, nn.Linear):
        new_layer = BayesianLinear(
            layer.in_features,
            layer.out_features,
            # Keep a bias-free layer bias-free.
            bias=layer.bias is not None,
            prior_sigma_1=config["prior_sigma_1"],
            prior_sigma_2=config["prior_sigma_2"],
            prior_pi=config["prior_pi"],
            posterior_mu_init=config["posterior_mu_init"],
            posterior_rho_init=config["posterior_rho_init"],
        )
    elif isinstance(layer, nn.Embedding):
        new_layer = BayesianEmbedding(layer.num_embeddings, layer.embedding_dim)
    elif isinstance(layer, nn.Conv1d):
        new_layer = BayesianConv1d(
            layer.in_channels,
            layer.out_channels,
            # nn.Conv1d stores kernel_size as a 1-tuple; pass the scalar.
            kernel_size=layer.kernel_size[0],
            groups=layer.groups,
            # Without this a strided conv silently becomes stride 1.
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            bias=layer.bias is not None,
        )
    else:
        # Constructing a Warning instance does nothing; it has to be issued.
        warnings.warn(
            f"Could not find correct type for conversion of layer {layer} with type {type(layer)}",
            stacklevel=2,
        )
        new_layer = layer

    return new_layer

In [None]:
#export
def convert_to_bayesian_model(model, config: dict):
    """Recursively replace the leaf layers of ``model`` with Bayesian versions, in place.

    Children are handled as follows:
      * a child that is already a ``BayesianModule`` is left untouched, and so are
        its own submodules;
      * a child that has submodules of its own (a container) is descended into;
      * any other (leaf) child is passed to ``convert_layer_to_bayesian`` and the
        result is assigned back under the same attribute name.

    Args:
        model: the ``nn.Module`` to convert. It is modified in place.
        config: forwarded unchanged to ``convert_layer_to_bayesian``.

    Returns:
        ``model`` itself, for convenience / chaining.
    """
    if isinstance(model, BayesianModule):
        # Already Bayesian: never walk into its internal submodules.
        return model

    # Snapshot the children: entries of ``model._modules`` are replaced below.
    for child_name, child in list(model.named_children()):
        if isinstance(child, BayesianModule):
            # Checked before the container test: Bayesian modules may own
            # submodules (presumably blitz samplers/priors) that must not be
            # fed to convert_layer_to_bayesian.
            continue
        if next(child.children(), None) is not None:
            convert_to_bayesian_model(child, config)
        else:
            setattr(model, child_name, convert_layer_to_bayesian(child, config))

    return model


In [None]:
#export
def set_train_mode(model, mode):
    """Freeze or unfreeze every Bayesian module inside ``model`` (``model`` included).

    Each ``BayesianModule`` found gets ``freeze = not mode``. A truthy ``mode``
    therefore unfreezes the modules and a falsy one freezes them. Other modules are
    not modified.

    Args:
        model: the ``nn.Module`` tree to walk.
        mode: truthy to unfreeze the Bayesian modules, falsy to freeze them.
    """
    frozen = not mode
    # modules() yields ``model`` itself followed by all descendants, pre-order.
    for submodule in model.modules():
        if isinstance(submodule, BayesianModule):
            submodule.freeze = frozen