# Relational Concept Bottleneck Models

First of all we need to create a dataset. For this we will use a Hanoi dataset.

In [134]:
import sys
sys.path.append('..')
import torch

In [298]:
from datasets.hanoi import hanoi_toy_dataset
n_samples = 3
n_positions = 7
n_disks = 3
n_sizes = 4
random_seed = 42
fold = 1
X, labels_concepts, labels_tasks, q_names, tower_ids = hanoi_toy_dataset(
        mode='relational', n_samples=n_samples, random_seed=random_seed + fold,
        n_positions=n_positions, n_disks=n_disks, n_sizes=n_sizes)

We inspect the data:

In [299]:
tower_ids

[['0', '1', '2'], ['3', '4', '5'], ['6', '7', '8']]

In [300]:
q_names['concepts']

['top(0,1)',
 'top(0,2)',
 'top(1,0)',
 'top(1,2)',
 'top(2,0)',
 'top(2,1)',
 'top(3,4)',
 'top(3,5)',
 'top(4,3)',
 'top(4,5)',
 'top(5,3)',
 'top(5,4)',
 'top(6,7)',
 'top(6,8)',
 'top(7,6)',
 'top(7,8)',
 'top(8,6)',
 'top(8,7)',
 'larger(0,1)',
 'larger(0,2)',
 'larger(1,0)',
 'larger(1,2)',
 'larger(2,0)',
 'larger(2,1)',
 'larger(3,4)',
 'larger(3,5)',
 'larger(4,3)',
 'larger(4,5)',
 'larger(5,3)',
 'larger(5,4)',
 'larger(6,7)',
 'larger(6,8)',
 'larger(7,6)',
 'larger(7,8)',
 'larger(8,6)',
 'larger(8,7)']

In [301]:
q_names['tasks'][:10]

['correct(1)', 'correct(4)', 'correct(7)']

In [302]:
# these are input features (first feature: disk size, second feature: height)
print(X.shape)
X

torch.Size([1, 9, 2])


tensor([[[0.0149, 3.0075],
         [1.0172, 2.0045],
         [2.0811, 1.0612],
         [3.0011, 0.0556],
         [4.0468, 3.0650],
         [5.0955, 1.0581],
         [0.0653, 1.0874],
         [1.0621, 2.0555],
         [2.0404, 0.0429]]])

In [303]:
# these are concept labels to predict
print(labels_concepts.shape, len(q_names['concepts']))
labels_concepts, q_names['concepts']

torch.Size([1, 36]) 36


(tensor([[0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1.,
          1., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0.]]),
 ['top(0,1)',
  'top(0,2)',
  'top(1,0)',
  'top(1,2)',
  'top(2,0)',
  'top(2,1)',
  'top(3,4)',
  'top(3,5)',
  'top(4,3)',
  'top(4,5)',
  'top(5,3)',
  'top(5,4)',
  'top(6,7)',
  'top(6,8)',
  'top(7,6)',
  'top(7,8)',
  'top(8,6)',
  'top(8,7)',
  'larger(0,1)',
  'larger(0,2)',
  'larger(1,0)',
  'larger(1,2)',
  'larger(2,0)',
  'larger(2,1)',
  'larger(3,4)',
  'larger(3,5)',
  'larger(4,3)',
  'larger(4,5)',
  'larger(5,3)',
  'larger(5,4)',
  'larger(6,7)',
  'larger(6,8)',
  'larger(7,6)',
  'larger(7,8)',
  'larger(8,6)',
  'larger(8,7)'])

In [304]:
# these are class labels to predict (same as concept labels!)
print(labels_tasks.shape, len(q_names['tasks']))
labels_tasks, q_names['tasks']

torch.Size([1, 3]) 3


(tensor([[1., 0., 0.]]), ['correct(1)', 'correct(4)', 'correct(7)'])

Create a domain of disks:

In [305]:
from rcbm.logic.commons import Domain
n_samples = X.shape[1]
n_features = X.shape[2]
disks = Domain("disks", [f'{i}' for i in torch.arange(X.shape[1]).tolist()])
print(f'Number of constants: {len(disks.constants)}')
n_samples, disks.constants

Number of constants: 9


(9, ['0', '1', '2', '3', '4', '5', '6', '7', '8'])

In [306]:
from rcbm.logic.commons import Rule
body = [
    'top(X,Y)', 'top(Y,X)', 'top(X,Z)', 'top(Z,X)', 'top(Y,Z)', 'top(Z,Y)',
    'larger(X,Y)', 'larger(Y,X)', 'larger(X,Z)', 'larger(Z,X)', 'larger(Y,Z)', 'larger(Z,Y)',
]
head = ["correct(X)"]
rule = Rule("phi", body=body, head=head, var2domain={"X": "disks", "Y": "disks", "Z": "disks"})

rule.body

[('top', 'X', 'Y'),
 ('top', 'Y', 'X'),
 ('top', 'X', 'Z'),
 ('top', 'Z', 'X'),
 ('top', 'Y', 'Z'),
 ('top', 'Z', 'Y'),
 ('larger', 'X', 'Y'),
 ('larger', 'Y', 'X'),
 ('larger', 'X', 'Z'),
 ('larger', 'Z', 'X'),
 ('larger', 'Y', 'Z'),
 ('larger', 'Z', 'Y')]

In [307]:
rule.head

[('correct', 'X')]

In [308]:
from rcbm.logic.grounding import DomainGrounder
grounder = DomainGrounder({"disks": disks.constants}, [rule], manifolds_per_rule={"phi": tower_ids})
groundings = grounder.ground()
groundings['phi'][0]
# notice we have 2 instances of ground rules for each disk!
# in fact we have 2 possible substitutions for each disk X: (Y,Z) or (Z,Y)

[((('correct', '2'),),
  (('top', '2', '1'),
   ('top', '1', '2'),
   ('top', '2', '0'),
   ('top', '0', '2'),
   ('top', '1', '0'),
   ('top', '0', '1'),
   ('larger', '2', '1'),
   ('larger', '1', '2'),
   ('larger', '2', '0'),
   ('larger', '0', '2'),
   ('larger', '1', '0'),
   ('larger', '0', '1'))),
 ((('correct', '7'),),
  (('top', '7', '8'),
   ('top', '8', '7'),
   ('top', '7', '6'),
   ('top', '6', '7'),
   ('top', '8', '6'),
   ('top', '6', '8'),
   ('larger', '7', '8'),
   ('larger', '8', '7'),
   ('larger', '7', '6'),
   ('larger', '6', '7'),
   ('larger', '8', '6'),
   ('larger', '6', '8'))),
 ((('correct', '3'),),
  (('top', '3', '5'),
   ('top', '5', '3'),
   ('top', '3', '4'),
   ('top', '4', '3'),
   ('top', '5', '4'),
   ('top', '4', '5'),
   ('larger', '3', '5'),
   ('larger', '5', '3'),
   ('larger', '3', '4'),
   ('larger', '4', '3'),
   ('larger', '5', '4'),
   ('larger', '4', '5'))),
 ((('correct', '6'),),
  (('top', '6', '8'),
   ('top', '8', '6'),
   ('top', '

In [309]:
from rcbm.logic.indexing import DictBasedIndexer
from rcbm.logic.semantics import GodelTNorm
logic = GodelTNorm()
indexer = DictBasedIndexer(grounder.ground(), q_names, logic=logic)
indexer.atom_index

OrderedDict([(('correct', '0'), 0),
             (('correct', '1'), 1),
             (('correct', '2'), 2),
             (('correct', '3'), 3),
             (('correct', '4'), 4),
             (('correct', '5'), 5),
             (('correct', '6'), 6),
             (('correct', '7'), 7),
             (('correct', '8'), 8),
             (('larger', '0', '1'), 9),
             (('larger', '0', '2'), 10),
             (('larger', '1', '0'), 11),
             (('larger', '1', '2'), 12),
             (('larger', '2', '0'), 13),
             (('larger', '2', '1'), 14),
             (('larger', '3', '4'), 15),
             (('larger', '3', '5'), 16),
             (('larger', '4', '3'), 17),
             (('larger', '4', '5'), 18),
             (('larger', '5', '3'), 19),
             (('larger', '5', '4'), 20),
             (('larger', '6', '7'), 21),
             (('larger', '6', '8'), 22),
             (('larger', '7', '6'), 23),
             (('larger', '7', '8'), 24),
             (('large

In [310]:
len(indexer.indexed_queries['concepts']), len(indexer.indexed_queries['tasks'])

(36, 3)

In [311]:
emb_size = 16
n_concepts = len(rule.body)
n_classes = len(rule.head)

encoder = torch.nn.Sequential(
    torch.nn.Linear(n_features, emb_size),
    torch.nn.LeakyReLU(),
)
relation_classifiers = {}
for relation_name, relation_arity in indexer.relations_arity.items():
    relation_classifiers[relation_name] = torch.nn.Sequential(
        torch.nn.Linear(emb_size * relation_arity, emb_size), # notice different input size depending on relation arity!
        torch.nn.LeakyReLU(),
        torch.nn.Linear(emb_size, 1),
        torch.nn.Sigmoid()
    )
reasoner = torch.nn.Sequential(
    torch.nn.Linear(n_concepts, emb_size),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(emb_size, n_classes),
    torch.nn.Sigmoid()
)
model = torch.nn.Sequential(encoder, *relation_classifiers.values(), reasoner)
model

Sequential(
  (0): Sequential(
    (0): Linear(in_features=2, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (1): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (2): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (3): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (4): Sequential(
    (0): Linear(in_features=12, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=16, out_features=1, bias=True)
    (3): Sigmoid()
  )
)

## Training

In [312]:
embeddings = encoder(X.squeeze())
embeddings.shape

torch.Size([9, 16])

In [313]:
# relation/concept predictions
concept_predictions = indexer.predict_relations(encoders=relation_classifiers, embeddings=embeddings)
concept_predictions.shape, concept_predictions

(torch.Size([45, 1]),
 tensor([[0.5427],
         [0.5118],
         [0.4679],
         [0.4564],
         [0.4741],
         [0.4491],
         [0.5040],
         [0.5119],
         [0.4545],
         [0.5845],
         [0.5774],
         [0.5430],
         [0.5550],
         [0.5088],
         [0.5207],
         [0.5174],
         [0.5407],
         [0.5389],
         [0.5494],
         [0.5204],
         [0.5087],
         [0.5603],
         [0.5654],
         [0.5541],
         [0.5559],
         [0.5135],
         [0.5130],
         [0.4133],
         [0.3734],
         [0.4444],
         [0.3997],
         [0.4493],
         [0.4391],
         [0.4161],
         [0.3676],
         [0.3767],
         [0.3579],
         [0.3960],
         [0.4338],
         [0.4360],
         [0.3820],
         [0.4332],
         [0.3770],
         [0.4424],
         [0.4454]], grad_fn=<CatBackward0>))

In [340]:
c_preds = indexer.gather_and_concatenate(concept_predictions, indexer.indexed_queries["concepts"], 0)
c_preds

tensor([[0.4133],
        [0.3734],
        [0.4444],
        [0.3997],
        [0.4493],
        [0.4391],
        [0.4161],
        [0.3676],
        [0.3767],
        [0.3579],
        [0.3960],
        [0.4338],
        [0.4360],
        [0.3820],
        [0.4332],
        [0.3770],
        [0.4424],
        [0.4454],
        [0.5845],
        [0.5774],
        [0.5430],
        [0.5550],
        [0.5088],
        [0.5207],
        [0.5174],
        [0.5407],
        [0.5389],
        [0.5494],
        [0.5204],
        [0.5087],
        [0.5603],
        [0.5654],
        [0.5541],
        [0.5559],
        [0.5135],
        [0.5130]], grad_fn=<ViewBackward0>)

In [318]:
len(groundings['phi'][0]), len(groundings['phi'][0][0][1]), groundings['phi'][0][0]

(18,
 12,
 ((('correct', '2'),),
  (('top', '2', '1'),
   ('top', '1', '2'),
   ('top', '2', '0'),
   ('top', '0', '2'),
   ('top', '1', '0'),
   ('top', '0', '1'),
   ('larger', '2', '1'),
   ('larger', '1', '2'),
   ('larger', '2', '0'),
   ('larger', '0', '2'),
   ('larger', '1', '0'),
   ('larger', '0', '1'))))

In [319]:
concept_preds_gathered = concept_predictions[indexer.indexed_bodies['phi']].view(indexer.indexed_bodies['phi'].shape[0], -1 )
concept_preds_gathered.shape, concept_preds_gathered

(torch.Size([18, 12]),
 tensor([[0.4391, 0.3997, 0.4493, 0.3734, 0.4444, 0.4133, 0.5207, 0.5550, 0.5088,
          0.5774, 0.5430, 0.5845],
         [0.3770, 0.4454, 0.4332, 0.4360, 0.4424, 0.3820, 0.5559, 0.5130, 0.5541,
          0.5603, 0.5135, 0.5654],
         [0.3676, 0.3960, 0.4161, 0.3767, 0.4338, 0.3579, 0.5407, 0.5204, 0.5174,
          0.5389, 0.5087, 0.5494],
         [0.3820, 0.4424, 0.4360, 0.4332, 0.4454, 0.3770, 0.5654, 0.5135, 0.5603,
          0.5541, 0.5130, 0.5559],
         [0.4454, 0.3770, 0.4424, 0.3820, 0.4332, 0.4360, 0.5130, 0.5559, 0.5135,
          0.5654, 0.5541, 0.5603],
         [0.4133, 0.4444, 0.3734, 0.4493, 0.3997, 0.4391, 0.5845, 0.5430, 0.5774,
          0.5088, 0.5550, 0.5207],
         [0.3579, 0.4338, 0.3767, 0.4161, 0.3960, 0.3676, 0.5494, 0.5087, 0.5389,
          0.5174, 0.5204, 0.5407],
         [0.3767, 0.4161, 0.3579, 0.4338, 0.3676, 0.3960, 0.5389, 0.5174, 0.5494,
          0.5087, 0.5407, 0.5204],
         [0.3960, 0.3676, 0.4338, 0.3579,

In [329]:
grounding_preds = reasoner(concept_preds_gathered)
len(grounding_preds), len(indexer.indexed_heads['phi'].view(-1, 1)), indexer.indexed_heads['phi'].view(-1, 1), grounding_preds

(18,
 18,
 tensor([[2],
         [7],
         [3],
         [6],
         [8],
         [0],
         [4],
         [4],
         [5],
         [0],
         [6],
         [8],
         [3],
         [5],
         [2],
         [1],
         [1],
         [7]]),
 tensor([[0.4359],
         [0.4376],
         [0.4372],
         [0.4376],
         [0.4355],
         [0.4404],
         [0.4396],
         [0.4401],
         [0.4360],
         [0.4391],
         [0.4396],
         [0.4355],
         [0.4390],
         [0.4375],
         [0.4359],
         [0.4388],
         [0.4373],
         [0.4398]], grad_fn=<SigmoidBackward0>))

In [338]:
# this operation aggregates indicized predictions
# grouped_or = logic.disj_scatter(grounding_preds.view(-1, 1),
#                                 indexer.indexed_heads['phi'],
#                                 len(indexer.atom_index))
grounding_preds_agg = torch.zeros(len(indexer.atom_index), 1)
task_predictions = grounding_preds_agg.scatter_reduce(0, indexer.indexed_heads['phi'].view(-1, 1), grounding_preds.view(-1, 1), reduce='amax')
task_predictions.shape, grouped_or.shape, sum(grouped_or!=0), grouped_or

(torch.Size([45, 1]),
 torch.Size([45, 1]),
 tensor([9]),
 tensor([[0.4404],
         [0.4388],
         [0.4359],
         [0.4390],
         [0.4401],
         [0.4375],
         [0.4396],
         [0.4398],
         [0.4355],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000],
         [0.0000]], grad_fn=<ScatterReduceBackward0>))

In [339]:
y_preds = indexer.gather_and_concatenate(task_predictions, indexer.indexed_queries["tasks"], 0)
y_preds

tensor([[0.4388],
        [0.4401],
        [0.4398]], grad_fn=<ViewBackward0>)