In [1]:
import sys
sys.path.append('..')

import utils, selex_dca, indep_sites
import adabmDCA

  from tqdm.autonotebook import tqdm


In [2]:
import torch
from models.energy_model import EnergyModel

class IndepSites(EnergyModel):
    def __init__(
        self,
        h: torch.Tensor # Lxq tensor
    ):
        super().__init__()
        n_dim = h.dim()
        if n_dim != 2:
            raise ValueError(f"Expected tensor with 2 dimensions, got {n_dim}.")

        self.h = torch.nn.Parameter(h)

    def get_n_states(self):
        return self.h.size(1)

    def get_sequence_length(self):
        return self.h.size(0)

    def compute_energy(
        self,
        x: torch.Tensor
    ):

        L = self.get_sequence_length()
        q = self.get_n_states()
        x_flat = x.view(-1, L * q)
        bias_flat = self.h.view(L * q)

        return - x_flat @ bias_flat

    def set_zerosum_gauge(self):
        self.h = self.h - self.h.mean(dim=1, keepdim=True)

In [3]:
h = torch.Tensor([[1,2],[30,40]])

In [4]:
x = torch.Tensor([[1, 0],[0,1]])

In [5]:
model = IndepSites(h)

In [6]:
model.compute_energy(x)

tensor([-41.], grad_fn=<MvBackward0>)

In [7]:
import torch
from models.energy_model import EnergyModel

class Potts(EnergyModel):
    def __init__(
        self,
        J: torch.Tensor, # LxqxLxq tensor
        h: torch.Tensor, # Lxq tensor
    ):
        super().__init__()
        sz_h = h.size()
        sz_J = J.size()
        if len(sz_h) != 2:
            raise ValueError(f"Expected tensor with 2 dimensions, got {len(sz_h)}.")
        if len(sz_J) != 4:
            raise ValueError(f"Expected tensor with 2 dimensions, got {len(sz_J)}.")
        if not (sz_J[0:2] == sz_J[2:4] == sz_h):
            raise ValueError("Wrong tensor dimensions")
        
        self.h = torch.nn.Parameter(h)

        L, q = sz_h
        mask = torch.ones(L, q, L, q)
        mask[torch.arange(L), :, torch.arange(L), :] = 0
        # set the (i,i) blocks to zero
        J = J * mask
        self.J = torch.nn.Parameter(J)

    def get_n_states(self):
        return self.h.size(1)

    def get_sequence_length(self):
        return self.h.size(0)

    def compute_energy(
        self,
        x: torch.Tensor
    ):
        L = self.get_sequence_length()
        q = self.get_n_states()
        # the -1 accounts for possible batch index along dimension 0
        x_flat = x.view(-1, L * q)
        bias_flat = self.h.view(L * q)
        couplings_flat = self.J.reshape(L * q, L * q)
        bias_term = x_flat @ bias_flat
        coupling_term = torch.sum(x_flat * (x_flat @ couplings_flat), dim=1)
        return - bias_term - 0.5 * coupling_term

    def set_zerosum_gauge(self):
        h = self.h
        self.h = h - h.mean(dim=1, keepdim=True)
        J = self.J
        self.J -= J.mean(dim=1, keepdim=True) + \
                    J.mean(dim=3, keepdim=True) - \
                    J.mean(dim=(1, 3), keepdim=True)    

In [49]:
import torch
from models.energy_model import EnergyModel

# used as dummy for checks
class InfiniteEnergy(EnergyModel):
    def __init__(self):
        super().__init__()

    def compute_energy(
        self,
        x: torch.Tensor
    ):
        if x.dim() == 2:
            return torch.full((1,), torch.inf)
        elif x.dim() == 3:
            return torch.full((x.size(0),), torch.inf)
        else:
            raise ValueError(f"Expected tensor `x` of dimension either 2 or 3, got {x.dim()}")

In [82]:
import torch

class MultiModeDistribution(torch.nn.Module):
    def __init__(
        self,
        *modes,
        normalized: bool = True
    ):
        super().__init__()
        self.modes = torch.nn.ModuleList(modes)
        self.normalized = normalized

    def get_n_modes(self):
        return len(self.modes)

    def compute_logprobabilities(
        self,
        x: torch.Tensor,  # batch_size * L * q
    ) -> torch.Tensor: # batch_size * n_modes
        minusE = torch.Tensor()
        for mode in self.modes:
            minusEw = - mode.compute_energy(x)
            minusE = torch.cat((minusE, minusEw[:,None]), dim=1)
        logp = minusE 
        if self.normalized == True:
            logp -= minusE.logsumexp(dim=1, keepdim=True)
        return logp    

In [165]:
import torch

def store_ancestors(parent):
    N = parent.size(0)

    ancestors = []
    offset = torch.zeros(N, dtype=torch.long)
    length = torch.zeros(N, dtype=torch.long)

    cur_offset = 0

    for v in range(N):
        path = [v]
        p = parent[v]

        while p != -1:
            path.append(p)
            p = parent[p]

        offset[v] = cur_offset
        length[v] = len(path)
        ancestors.extend(path)
        cur_offset += len(path)

    ancestors_flat = torch.tensor(ancestors, dtype=torch.long)
    return ancestors_flat, offset, length

def ancestors_of(v, ancestors_flat, offset, length):
    return ancestors_flat[offset[v] : offset[v] + length[v]]


class Tree:
    def __init__(self,
                 parent: torch.Tensor | None = None,  # parent[v] is the index of v's parent. the root is -1
                 nodename = None, 
                ):
        if parent is None:
            parent = torch.IntTensor()
            
        assert((parent >= -1).all())
        if nodename is None:
            nodename = [str(v) for v in range(len(parent))]
        assert(len(nodename) == parent.size(0))
        self.parent = parent
        self.nodename = nodename
        self.ancestors_flat, self.offset, self.length = store_ancestors(parent)

    def ancestors_of(self, v):  # returns a torch vector with v and its ancestors
        assert(v < self.get_n_nodes()), f"v={v}"
        return self.ancestors_flat[self.offset[v] : self.offset[v] + self.length[v]]

    def get_n_nodes(self):   # excluding the root
        return len(self.parent)

    def get_depth(self):
        return self.length.max().item()

    def parent(self, v):
        assert(v < self.get_n_nodes())
        return self.parent[v]

    def add_node(self, parent_node, name = None):
        N = self.get_n_nodes
        if name is None:
            name = str(N)
        if type(parent_node) is str:
            if parent_node == "root":
                parent_node = -1
            else:
                parent_node = self.nodename.index(parent_node)
            
        self.parent = torch.cat((self.parent, torch.IntTensor([parent_node])))
        self.nodename.append(name)
        self.ancestors_flat, self.offset, self.length = store_ancestors(self.parent)  # re-computing everything, not efficient but ok  
        

class RoundTree:
    def __init__(
        self,
        n_modes: int | None = None,
        tree: Tree | None = None,  
        selected_modes: torch.BoolTensor | None = None,   # (n_rounds * n_modes) modes selected for at each round
    ):
        if n_modes is None:
            if selected_modes is None:
                raise ValueError("Must provide either the selected modes or the total number of modes")
            else:
                n_modes = selected_modes.size(1)
        if tree is None:
            tree = Tree()
        if selected_modes is None:
            selected_modes = torch.BoolTensor()
            
        n_rounds = tree.parent.size(0)
        assert(selected_modes.size(0) == n_rounds)

        self.n_modes = n_modes
        self.tree = tree
        self.selected_modes = selected_modes   

    def add_node(self, parent_node, selected_modes, name = None):
        assert selected_modes.size(0) == self.n_modes, f"Number of modes in `selected_modes`, {selected_modes.size(0)}, different from the expected {self.n_modes}"
        self.tree.add_node(parent_node, name=name)
        self.selected_modes = torch.cat((self.selected_modes, selected_modes[None,:]))

    def parent(self, t):
        return self.tree.parent(t)

    def ancestors_of(self, v):
        return self.tree.ancestors_of(v)

    def get_n_modes(self):
        return self.n_modes

In [158]:
import torch

class MultiRoundDistribution(torch.nn.Module):
     def __init__(
        self,
        round_zero: EnergyModel,
        selection: MultiModeDistribution,
        round_tree: RoundTree
     ):
        if selection.get_n_modes() != round_tree.get_n_modes():
            raise ValueError(f"Number of modes must coincide for selection probability, got {selection.get_n_modes()} and {round_tree.get_n_modes()}")
        super().__init__()
        self.round_zero = round_zero
        self.selection = selection
        self.round_tree = round_tree

     # compute $\sum_{\tau \in \mathcal A(t)} \log p_{s,\tau}
     def selection_energy_up_to_round(self, x, t):
        if t == -1:
            return torch.zeros(x.size(0))
        ancestors = self.round_tree.ancestors_of(t)
        logps_modes = self.selection.compute_logprobabilities(x)
        ancestors = self.round_tree.ancestors_of(t)
        selected = self.round_tree.selected_modes[ancestors]
        # first pick only the selected rounds, then (log)sum(exp) over modes, then sum over rounds
        return - (logps_modes[:,None,:] + torch.log(selected)).logsumexp(dim=-1).sum(1)

     # compute sum_tau log p_{s,tau}
     @torch.compile
     def compute_energy_up_to_round(self, x, t):
         logNs0 = - self.round_zero.compute_energy(x)
         logps = - self.selection_energy_up_to_round(x, t)
         return - (logps + logNs0)

In [139]:
round_tree = RoundTree(n_modes=3)
round_tree.add_node(-1, torch.BoolTensor([1, 0, 1]))
round_tree.add_node(0, torch.BoolTensor([1, 1, 0]))
round_tree.add_node(-1, torch.BoolTensor([0, 0, 0]))

round_tree.selected_modes

tensor([[ True, False,  True],
        [ True,  True, False],
        [False, False, False]])

In [140]:
q = 4
L = 20

k = torch.randn(L, q)

n_modes = 3
potts_models = [
    Potts(torch.randn(L, q, L, q), torch.randn(L, q))
          for _ in range(n_modes)]

Ns0 = IndepSites(k)
ps = MultiModeDistribution(*potts_models)
Nst = MultiRoundDistribution(Ns0, ps, round_tree)

In [141]:
from adabmDCA.functional import one_hot
M = 5
x_ = torch.randint(q, (M, L))
x = one_hot(x_, num_classes=q)

In [142]:
t = -1
en = Nst.compute_energy_up_to_round(x, t)
avg_en = en.mean()
avg_en.backward()

tensor([ 4.9829,  8.4547, -0.9469,  1.4852,  8.2614],
       grad_fn=<CompiledFunctionBackward>)
tensor([0., 0., 0., 0., 0.])


In [143]:
def perceptron(L):
    return torch.nn.Sequential(
        torch.nn.Linear(L, 1),
        torch.torch.nn.Softmax()
    )

nns = [perceptron(L) for _ in range(n_modes)]
ps = MultiModeDistribution(*nns)
Nst = MultiRoundDistribution(Ns0, ps, round_tree)

In [144]:
en = Nst.compute_energy_up_to_round(x, t)
avg_en = en.mean()
avg_en.backward()

tensor([ 4.9829,  8.4547, -0.9469,  1.4852,  8.2614],
       grad_fn=<CompiledFunctionBackward>)
tensor([0., 0., 0., 0., 0.])


In [167]:
q = 4
L = 45

k = torch.randn(L, q)
h = torch.randn(L, q)
J = torch.randn(L, q, L, q)

round_tree = RoundTree(n_modes=2)
round_tree.add_node(-1, torch.BoolTensor([1, 0]))
round_tree.add_node(0, torch.BoolTensor([1, 0]))

Ns0 = IndepSites(k)
potts = Potts(J, h)
ps = MultiModeDistribution(potts, InfiniteEnergy(), normalized=False)
Nst = MultiRoundDistribution(Ns0, ps, round_tree)

In [177]:
M = 50
x_ = torch.randint(q, (M, L))
x = one_hot(x_, num_classes=q)

params = {"bias_Ns0": k, "bias_ps": h, "couplings_ps": potts.J}

for t in range(0):
    params_t = selex_dca.get_params_at_round(params, t)
    en_adabm = adabmDCA.statmech.compute_energy(x, params_t)
    en_class = Nst.compute_energy_up_to_round(x, t-1)
    assert torch.allclose(en_adabm, en_class)

In [180]:
t = 1
params_t = selex_dca.get_params_at_round(params, t)
en_adabm = adabmDCA.statmech.compute_energy(x, params_t)
en_class = Nst.compute_energy_up_to_round(x, t-1)
en_adabm - en_class

tensor([-2.8610e-06,  0.0000e+00,  3.8147e-06,  1.4305e-06,  1.9073e-06,
         2.8610e-06,  0.0000e+00,  3.8147e-06,  3.8147e-06, -9.5367e-07,
         1.9073e-06,  2.8610e-06,  0.0000e+00,  3.8147e-06,  3.8147e-06,
        -1.4305e-06, -3.8147e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        -9.5367e-07,  0.0000e+00, -4.7684e-07,  0.0000e+00,  0.0000e+00,
        -8.5831e-06,  1.9073e-06,  1.9073e-06, -1.9073e-06,  2.1458e-06,
         1.9073e-06,  0.0000e+00,  3.8147e-06,  1.6689e-06,  0.0000e+00,
         2.3842e-07, -1.9073e-06,  1.9073e-06,  0.0000e+00,  3.8147e-06,
        -1.9073e-06, -4.7684e-07, -1.9073e-06,  2.8610e-06, -1.9073e-06,
         7.6294e-06, -1.9073e-06,  1.9073e-06, -1.9073e-06, -3.8147e-06],
       grad_fn=<SubBackward0>)