# 查看使用SRNet-Modular得出的表达式 (View the expressions obtained with SRNet-Modular)

# In [1]:
import os
import re

import torch
import sympy as sp

from utils import load_pickle
from load_data import load_sr_data
from SRNet import srnets
from neural_network import neural_networks

def load_srnet(path_dir, net, args):
    """Rebuild a trained SRNet from the artifacts stored under ``path_dir``.

    Args:
        path_dir: directory containing the ``sr_param``, ``srnet`` and
            ``srnet_genes`` files.
        net: network handed to the SRNet constructor as ``neural_network``.
        args: object whose ``srnet_name`` attribute selects the entry of
            ``srnets``; that entry must unpack into exactly two items, the
            second being the SRNet class.

    Returns:
        The SRNet instance with its saved state dict and genes restored.
    """
    def artifact(name):
        # All artifacts live side by side in the same directory.
        return os.path.join(path_dir, name)

    params = torch.load(artifact("sr_param"))
    _, model_cls = srnets[args.srnet_name]
    model = model_cls(sr_param=params, neural_network=net)

    state = torch.load(artifact("srnet"))
    model.load_state_dict(state)

    # NOTE(review): load_pickle presumably unpickles the file; only use on trusted outputs.
    genes = load_pickle(artifact("srnet_genes"))
    model.assign_genes(genes)
    return model

# Matches an input symbol such as ``x0`` or ``x12`` in a printed sympy expression.
_INPUT_SYMBOL = re.compile(r"x(\d+)")


def _format_expr(expr):
    """Render a sympy expression compactly for display.

    The expression is simplified, evaluated to 4 significant digits, and every
    numeric atom is rounded to 2 decimals. The result is returned as a string
    with sympy's ``**`` power operator written as ``^``.
    """
    expr = sp.simplify(expr).evalf(4)
    rounded = expr.xreplace({n: round(n, 2) for n in expr.atoms(sp.Number)})
    return str(rounded).replace("**", "^")


def _rename_layer_inputs(expr_str, layer_idx):
    """Relabel input symbols ``x<k>`` as ``(hat{h}^<layer_idx>)_<k>``."""
    return _INPUT_SYMBOL.sub(
        lambda m: "(hat{h}^" + f"{layer_idx})_{m.group(1)}", expr_str
    )


event_dir = "./output/compared"
net_dir = "./output/MLPs"
datasets = ["kkk0", "feynman0", "529_pollen"]

for dataset in datasets:
    srnet_dir = os.path.join(event_dir, f"mlp-{dataset}")
    net_path = os.path.join(net_dir, f"MLP-{dataset}", "mlp", "mlp")
    # NOTE(review): torch.load defaults to weights_only=True in recent PyTorch,
    # which may refuse to unpickle this non-tensor args object; pass
    # weights_only=False (trusted files only) if loading fails.
    args = torch.load(os.path.join(srnet_dir, "args"))
    train_dataset, _ = load_sr_data(args.dataset)
    n_inputs, n_outputs = train_dataset.X.size(-1), 1

    # load net
    nn_class = neural_networks["MLP"]
    net = nn_class(n_inputs, n_outputs)
    net.load_state_dict(torch.load(net_path))
    net.eval()

    # load sr
    srnet = load_srnet(srnet_dir, net, args)

    raw_exprs = [layer.expr(mul_w=False)[0] for layer in srnet.explained_layers]
    exps = [_format_expr(e) for e in raw_exprs]
    # Every layer after the first gets its x<k> inputs relabelled with that
    # layer's index (presumably the outputs of the preceding layer -- confirm).
    exps = exps[:1] + [
        _rename_layer_inputs(e, layer_idx)
        for layer_idx, e in enumerate(exps[1:], start=1)
    ]
    print(f'Dataset {dataset} Each Layer Expression:{exps}')

    final_expr = _format_expr(srnet.expr()[0])
    print(f'Dataset {dataset} Final Expression:{final_expr}')