In [12]:
import torch
import logging
from typing import *

In [13]:
logging.getLogger().setLevel(logging.DEBUG)

In [14]:
base_val = torch.as_tensor([
    [ #false
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
    ], [ #succ
        [0,1,0,0,0],
        [0,0,1,0,0],
        [0,0,0,1,0],
        [0,0,0,0,1],
        [0,0,0,0,0]
    ], [ #p
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
    ], [ #q
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
        [0,0,0,0,0],
    ]
], dtype=torch.float)

In [15]:

class Rulebook(NamedTuple):
    body_predicates : torch.Tensor
    variable_choices : torch.Tensor

    def to(self, device : torch.device):
        self.body_predicates = self.body_predicates.to(device)
        self.variable_choices = self.variable_choices.to(device)


In [16]:
body_predicates = torch.as_tensor([[0, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
                                    [0, 1, 1, 1, 0, 0, 2, 1, 0, 0, 2]])
variable_choices = torch.as_tensor([[0, 1, 2, 3, 4, 5, 6, 7 ,0, 0, 2],
                                    [0, 1, 7, 3, 3, 2, 6, 7 ,0, 7, 5]])


In [17]:
rulebook = Rulebook(
        body_predicates=torch.as_tensor([
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[1,1],[1,1],[1,1],[1,1],[3,3],[2,0],[1,1],[0,0],[1,2],[2,1],[2,2]] for _ in range(0,2)],
            [[[1,1],[1,1],[1,1],[1,1],[3,3],[2,0],[1,1],[0,0],[1,2],[2,1],[2,2]] for _ in range(0,2)],
        ]),
        variable_choices=torch.as_tensor([
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[0,0] for _ in range(0,11)] for _ in range(0,2)],
            [[[0,0],[1,0],[2,7],[5,4],[3,3],[2,0],[1,1],[0,0],[1,2],[2,1],[2,2]] for _ in range(0,2)],
            [[[0,0],[1,0],[2,7],[5,4],[3,3],[2,0],[1,1],[0,0],[1,2],[2,1],[2,2]] for _ in range(0,2)],
        ])
    )

In [18]:
pred_names = ['false', 'succ', 'p', 'q']
def var_choices(n : int, vars : int = 3) -> List[int]:
    return [int(n) // vars, n % vars]
def rule_str(rs : List[int], predicate : int, rulebook : Rulebook) -> str:
    lines = []
    for clause in range(0, rulebook.body_predicates.shape[1]):
        ret = []
        for i in range(0, rulebook.body_predicates.shape[3]):
            vs = ','.join(map(lambda v: chr(ord('A')+v),  var_choices(rulebook.variable_choices[predicate,clause,rs[clause],i])))
            ret.append(f'{pred_names[rulebook.body_predicates[predicate,clause,rs[clause],i]]}({vs})')
        lines.append(f"{pred_names[predicate]}(A, B) :- {','.join(ret)}")
    return '\n'.join(lines)

In [19]:
print(rule_str([2,3], predicate=2, rulebook = rulebook))

p(A, B) :- false(A,C),false(C,B)
p(A, B) :- false(B,C),false(B,B)


In [20]:
def extend_val(val : torch.Tensor, vars : int = 3) -> torch.Tensor:
    i = 0
    ret = []
    shape = list(val.shape) + [val.shape[-1] for _ in range(0, vars - 2)]
    valt = val.transpose(1, 2)
    for arg1 in range(0, vars):
        for arg2 in range(0, vars):
            v = val
            if arg1 == arg2:
                v = v.diagonal(dim1=1,dim2=2)
                for _ in range(0, arg1):
                    v = v.unsqueeze(1)
                while len(v.shape) < vars+1:
                    v = v.unsqueeze(-1)
            else:
                if arg2 < arg1:
                    v = valt
                unused = (x for x in range(0, vars) if x not in {arg1, arg2})
                for u in unused:
                    v = v.unsqueeze(u + 1)
            logging.debug(f"{i=} {arg1=} {arg2=} {v.shape=}")
            v = torch.broadcast_to(v, shape) #type: ignore
            v = v.unsqueeze(1)
            ret.append(v)
            i += 1
    return torch.cat(ret, dim=1)

In [21]:
val2 = extend_val(base_val)
val2.shape

DEBUG:root:i=0 arg1=0 arg2=0 v.shape=torch.Size([4, 5, 1, 1])
DEBUG:root:i=1 arg1=0 arg2=1 v.shape=torch.Size([4, 5, 5, 1])
DEBUG:root:i=2 arg1=0 arg2=2 v.shape=torch.Size([4, 5, 1, 5])
DEBUG:root:i=3 arg1=1 arg2=0 v.shape=torch.Size([4, 5, 5, 1])
DEBUG:root:i=4 arg1=1 arg2=1 v.shape=torch.Size([4, 1, 5, 1])
DEBUG:root:i=5 arg1=1 arg2=2 v.shape=torch.Size([4, 1, 5, 5])
DEBUG:root:i=6 arg1=2 arg2=0 v.shape=torch.Size([4, 5, 1, 5])
DEBUG:root:i=7 arg1=2 arg2=1 v.shape=torch.Size([4, 1, 5, 5])
DEBUG:root:i=8 arg1=2 arg2=2 v.shape=torch.Size([4, 1, 1, 5])


torch.Size([4, 9, 5, 5, 5])

In [22]:
def disjuction2(a : torch.Tensor, b : torch.Tensor) -> torch.Tensor:
    return 1 - (1 - a) * (1 - b)

def disjunction_dim(a : torch.Tensor, dim : int = -1) -> torch.Tensor:
    return 1 - ((1 - a).prod(dim = dim))

In [23]:

def conjuction2(a : torch.Tensor, b : torch.Tensor) -> torch.Tensor:
    return a * b
def conjunction_dim(a : torch.Tensor, dim : int = -1) -> torch.Tensor:
    return a.prod(dim=dim)

In [24]:
disjuction2(0, 1)

1

In [25]:
def infer_single_step(ex_val : torch.Tensor, rules : Rulebook, rule_weights : torch.Tensor) -> torch.Tensor:
    logging.debug(f"{ex_val.shape=} {rules.body_predicates.shape=} {rules.variable_choices.shape=}")
    ex_val = ex_val[rules.body_predicates, rules.variable_choices]
    logging.debug(f"{ex_val.shape=}")
    #conjuction of body predictes
    ex_val = conjunction_dim(ex_val, dim = 3)
    #existential quantification
    ex_val = disjunction_dim(ex_val, dim = -1)
    #rule weighing
    rule_weights = rule_weights.softmax(-1).unsqueeze(-1).unsqueeze(-1)
    ex_val = ex_val * rule_weights
    ex_val = ex_val.sum(dim = 2)
    #disjunction on clauses
    ex_val = disjunction_dim(ex_val, dim = 1)
    return ex_val

In [27]:
val = base_val
weights = torch.rand(4, 2, 11)
        #predicate, clause, rule
for _ in range(0,5):
    val2 = extend_val(val)
    val2 = infer_single_step(val2, rulebook, weights)
    val = disjuction2(val, val2)
    del val2

DEBUG:root:i=0 arg1=0 arg2=0 v.shape=torch.Size([4, 5, 1, 1])
DEBUG:root:i=1 arg1=0 arg2=1 v.shape=torch.Size([4, 5, 5, 1])
DEBUG:root:i=2 arg1=0 arg2=2 v.shape=torch.Size([4, 5, 1, 5])
DEBUG:root:i=3 arg1=1 arg2=0 v.shape=torch.Size([4, 5, 5, 1])
DEBUG:root:i=4 arg1=1 arg2=1 v.shape=torch.Size([4, 1, 5, 1])
DEBUG:root:i=5 arg1=1 arg2=2 v.shape=torch.Size([4, 1, 5, 5])
DEBUG:root:i=6 arg1=2 arg2=0 v.shape=torch.Size([4, 5, 1, 5])
DEBUG:root:i=7 arg1=2 arg2=1 v.shape=torch.Size([4, 1, 5, 5])
DEBUG:root:i=8 arg1=2 arg2=2 v.shape=torch.Size([4, 1, 1, 5])


IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [6, 2, 11, 2], [4, 2, 11, 2]

In [None]:
val.shape

torch.Size([4, 5, 5])

In [None]:
targets = torch.as_tensor([
        [2,0,2],
        [2,1,3],
        [2,2,4],
        [2,1,1],
        [2,3,2],
        [2,0,0],
        [2,0,1]
    ])
target_values = torch.as_tensor([
        1.0,
        1.0,
        1.0,
        0.0,
        0.0,
        0.0,
        0.0,])


In [None]:
torch.index_select(val, 0, targets)

IndexError: index_select(): Index is supposed to be a vector

In [None]:
val[targets[:,0],targets[:,1],targets[:,2]]

tensor([0.9647, 0.9756, 0.9717, 0.9756, 0.9943, 0.6638, 0.9871])