In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import requests
import json


### TENSOR LOGIC Experiment 2: REASONING WITH DOT PRODUCTS
### Aligned with Section 5: 'Reasoning in Embedding Space' 
### Dataset: mledoze/countries (Open Database License, ODbL v1.0). URL: https://raw.githubusercontent.com/mledoze/countries/master/countries.json


#### LOAD & PROCESS DATA

In [2]:
# Source of the country records; each entry is read for its common name,
# its list of capitals, and its region.
url = "https://raw.githubusercontent.com/mledoze/countries/master/countries.json"

# A timeout keeps "Restart & Run All" from hanging forever on a stalled
# connection, and raise_for_status() surfaces HTTP errors (404/5xx) directly
# instead of as a confusing JSON decoding failure on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
data = response.json()
print(f"Downloaded data for {len(data)} countries.")

triples = []      # (subject, relation, object) facts
entities = set()  # every distinct capital, country and region name
relations = ["is_capital_of", "is_located_in"]

print("\nExtracting Facts...")
for entry in data:
    # `or {}` / `or []` also cover keys that are present but null.
    country = (entry.get("name") or {}).get("common", "")
    capital_list = entry.get("capital") or []
    # Only the first listed capital is used.
    capital = capital_list[0] if capital_list else None
    region = entry.get("region", "")

    # Skip entries missing any of the three fields so both facts stay paired.
    if country and capital and region:
        entities.update((country, capital, region))
        # Fact 1: Capital -> Country
        triples.append((capital, "is_capital_of", country))
        # Fact 2: Country -> Region
        triples.append((country, "is_located_in", region))

# Fail here with a clear message rather than later with an empty embedding table.
if not triples:
    raise ValueError("No (capital, country, region) facts could be extracted from the dataset.")

# Index Maps (sorted so the entity -> index assignment is deterministic)
entity_list = sorted(entities)
ent_to_idx = {e: i for i, e in enumerate(entity_list)}
idx_to_ent = {i: e for i, e in enumerate(entity_list)}
rel_to_idx = {r: i for i, r in enumerate(relations)}

NUM_ENTITIES = len(entity_list)
NUM_RELATIONS = len(relations)

print(f"Entities: {NUM_ENTITIES}")
print(f"Facts: {len(triples)}")


Downloaded data for 250 countries.

Extracting Facts...
Entities: 489
Facts: 490


#### We model truth as the dot product between prediction and target

In [3]:
class TensorLogicDotProduct(nn.Module):
    """Bilinear knowledge-graph model scored with dot products.

    Every entity is a unit-norm embedding vector and every relation is a
    square matrix. A (subject, relation) query is answered by contracting the
    subject vector with the relation matrix and comparing the resulting
    vector against all entity embeddings with a dot product.

    Args:
        num_entities: Number of rows in the entity embedding table.
        embedding_dim: Length of each entity vector; relation matrices are
            ``embedding_dim x embedding_dim``.
        num_relations: Number of relation matrices. When omitted, the
            module-level ``NUM_RELATIONS`` is used so existing two-argument
            calls keep working.
    """

    def __init__(self, num_entities, embedding_dim=64, num_relations=None):
        super().__init__()
        if num_relations is None:
            # Backward-compatible fallback to the notebook-level constant.
            num_relations = NUM_RELATIONS

        # 1. Entity Embeddings
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)

        # 2. Relation Matrices (Transformation Tensors); values are filled in below.
        self.relation_matrices = nn.Parameter(
            torch.empty(num_relations, embedding_dim, embedding_dim)
        )

        # Initialization
        nn.init.xavier_uniform_(self.entity_embeddings.weight)
        # Initialise each relation matrix on its own. Calling xavier_uniform_ on
        # the whole 3-D tensor would compute the fans as for a convolution kernel
        # (fan_in = dim * dim, fan_out = num_relations * dim) and shrink the
        # matrices well below the usual per-matrix Xavier scale.
        with torch.no_grad():
            for relation_index in range(num_relations):
                nn.init.xavier_uniform_(self.relation_matrices[relation_index])

        # Normalize embeddings to unit vectors (prevents hubness)
        self.normalize_entity_embeddings()

    def normalize_entity_embeddings(self):
        """Rescale every entity embedding to unit L2 norm, in place."""
        with torch.no_grad():
            weight = self.entity_embeddings.weight
            # In-place copy keeps the same Parameter tensor instead of rebinding `.data`.
            weight.copy_(nn.functional.normalize(weight, p=2, dim=1))

    def forward(self, subject_indices, relation_indices):
        """Map (subject, relation) index pairs to predicted object vectors.

        Args:
            subject_indices: 1-D integer tensor of entity indices, one per query.
            relation_indices: 1-D integer tensor of relation indices, aligned
                with ``subject_indices``.

        Returns:
            Tensor of shape (batch, embedding_dim): each subject vector
            multiplied by its relation matrix.
        """
        subj_vecs = self.entity_embeddings(subject_indices)
        rel_mats = self.relation_matrices[relation_indices]

        # Tensor contraction (vector x matrix) written as an Einstein summation:
        # b=batch, i=input_dim, j=output_dim
        return torch.einsum('bi,bij->bj', subj_vecs, rel_mats)

    def score_all(self, pred_vecs):
        """Score predicted vectors against every entity embedding.

        Args:
            pred_vecs: Tensor of shape (batch, embedding_dim).

        Returns:
            Tensor of shape (batch, num_entities) holding the dot product of
            each predicted vector with each entity embedding.
        """
        all_emb = self.entity_embeddings.weight
        return torch.matmul(pred_vecs, all_emb.T)


#### TRAINING (CROSS ENTROPY)
#### Cross-entropy over the dot-product scores maximizes the softmax probability P(Object | Subject, Relation): it raises the correct object's score relative to the scores of all other entities.

In [4]:
EMBEDDING_DIM = 64
SEED = 42  # fixed seed so a fresh "Restart & Run All" starts from the same initialization

# Seed before the model is built: parameter initialization draws from
# torch's global random number generator.
torch.manual_seed(SEED)

model = TensorLogicDotProduct(NUM_ENTITIES, EMBEDDING_DIM)
optimizer = optim.Adam(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()  # Softmax + NLL

# Prepare Data (Ensuring LongTensor for indices)
subjects = torch.tensor([ent_to_idx[t[0]] for t in triples], dtype=torch.long)
rels = torch.tensor([rel_to_idx[t[1]] for t in triples], dtype=torch.long)
objects = torch.tensor([ent_to_idx[t[2]] for t in triples], dtype=torch.long)

print("\nTRAINING (Optimizing Dot Product Probability)...")

# Full-batch training: every epoch uses all triples at once.
epochs = 500
for epoch in range(epochs):
    optimizer.zero_grad()

    # 1. Predict Vector
    pred_vecs = model(subjects, rels)

    # 2. Score against all entities (Dot Product)
    logits = model.score_all(pred_vecs)

    # 3. Loss
    loss = criterion(logits, objects)

    loss.backward()
    optimizer.step()

    # CRITICAL: Re-normalize embeddings after each update (prevents hubness).
    # The in-place copy updates the existing parameter tensor rather than
    # rebinding `.data`.
    with torch.no_grad():
        entity_weight = model.entity_embeddings.weight
        entity_weight.copy_(nn.functional.normalize(entity_weight, p=2, dim=1))

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

# Note: this is the loss computed in the last epoch, before its final optimizer step.
print(f"Training Complete. Final Loss: {loss.item():.4f}")

# ========== VERIFICATION 1: Check all embeddings are unit vectors ==========
print("\n" + "="*60)
print("VERIFICATION: Embedding Norms (Hubness Check)")
print("="*60)
# Detach: this is a read-only check, no gradient tracking is needed.
norms = torch.norm(model.entity_embeddings.weight.detach(), p=2, dim=1)
print(f"Min norm:  {norms.min():.6f}")
print(f"Max norm:  {norms.max():.6f}")
print(f"Mean norm: {norms.mean():.6f}")
print(f"Std norm:  {norms.std():.6f}")
print(f"\nAll norms ≈ 1.0? {torch.allclose(norms, torch.ones_like(norms), atol=1e-5)}")

# Show norm distribution for key entities
print("\nSample Entity Norms:")
sample_entities = ['Asia', 'Europe', 'Africa', 'Americas', 'Oceania', 'Antarctic']
for entity_name in sample_entities:
    # Names missing from the entity index are skipped without a message.
    if entity_name in ent_to_idx:
        idx = ent_to_idx[entity_name]
        norm = norms[idx].item()
        print(f"  {entity_name:12s}: {norm:.6f}")
print("="*60 + "\n")


TRAINING (Optimizing Dot Product Probability)...
Epoch 0: Loss = 6.1942
Epoch 100: Loss = 1.0020
Epoch 200: Loss = 0.1504
Epoch 300: Loss = 0.0595
Epoch 400: Loss = 0.0350
Training Complete. Final Loss: 0.0242

VERIFICATION: Embedding Norms (Hubness Check)
Min norm:  1.000000
Max norm:  1.000000
Mean norm: 1.000000
Std norm:  0.000000

All norms ≈ 1.0? True

Sample Entity Norms:
  Asia        : 1.000000
  Europe      : 1.000000
  Africa      : 1.000000
  Americas    : 1.000000
  Oceania     : 1.000000
  Antarctic   : 1.000000



#### EVALUATING COMPOSITIONALITY
#### Reasoning in embedding space can now be carried out by forward or backward chaining over the embedded rules

In [5]:
def test_inference_dot(city_name, top_k=3):
    """Answer a two-hop query by chaining relation matrices and print the top matches.

    The city's embedding is multiplied by the ``is_capital_of`` matrix and then
    by the ``is_located_in`` matrix; the resulting vector is compared with every
    entity embedding by dot product and the highest-scoring entities are printed.

    Reads the notebook-level ``model``, ``ent_to_idx``, ``idx_to_ent`` and
    ``rel_to_idx``.

    Args:
        city_name: Name of the entity to start the chain from.
        top_k: Number of best-scoring entities to print (default 3).

    Returns:
        None. Results are printed.
    """
    if city_name not in ent_to_idx:
        # Report the miss explicitly instead of silently printing nothing.
        print(f"\nQuery: What continent is {city_name} in?")
        print(f"  - skipped: '{city_name}' is not a known entity")
        return

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # 1. Get Embeddings & Matrices
        city_idx = torch.tensor([ent_to_idx[city_name]], dtype=torch.long)
        city_vec = model.entity_embeddings(city_idx)

        M_CapitalOf = model.relation_matrices[rel_to_idx["is_capital_of"]]
        M_LocatedIn = model.relation_matrices[rel_to_idx["is_located_in"]]

        # 2. LOGICAL COMPOSITION (CHAINING)
        # City -> (CapitalOf) -> Country -> (LocatedIn) -> Region
        country_pred_vec = torch.einsum('bi,ij->bj', city_vec, M_CapitalOf)
        region_pred_vec = torch.einsum('bj,jk->bk', country_pred_vec, M_LocatedIn)

        # 3. DOT PRODUCT SEARCH (Maximum Inner Product)
        all_emb = model.entity_embeddings.weight
        # squeeze(0) drops only the batch axis, so the result stays 1-D even
        # when there is a single entity.
        scores = torch.matmul(region_pred_vec, all_emb.T).squeeze(0)

        # Get Top Predictions (topk returns them in descending score order)
        k = max(0, min(top_k, scores.numel()))
        top_scores, top_indices = torch.topk(scores, k)

    print(f"\nQuery: What continent is {city_name} in?")
    print(f"Method: ArgMax( Dot( Vec({city_name}) x M_cap x M_loc, All_Entities ) )")
    for score, i in zip(top_scores.tolist(), top_indices.tolist()):
        name = idx_to_ent[i]
        print(f"  - {name} (score: {score:.2f})")


#### Test on diverse cities


In [6]:
# Entities to run through the two-hop query, one per line for easy editing.
test_cities = [
    "Tokyo",
    "Berlin",
    "Cairo",
    "Lima",
    "Canberra",
    "New Delhi",
    "King Edward Point",
]

# Each call prints its own query header and top-scoring entities.
for city in test_cities:
    test_inference_dot(city)


Query: What continent is Tokyo in?
Method: ArgMax( Dot( Vec(Tokyo) x M_cap x M_loc, All_Entities ) )
  - Asia (score: 147.50)
  - Europe (score: 82.79)
  - Alofi (score: 25.59)

Query: What continent is Berlin in?
Method: ArgMax( Dot( Vec(Berlin) x M_cap x M_loc, All_Entities ) )
  - Europe (score: 108.62)
  - Oceania (score: 50.09)
  - Asia (score: 36.31)

Query: What continent is Cairo in?
Method: ArgMax( Dot( Vec(Cairo) x M_cap x M_loc, All_Entities ) )
  - Africa (score: 127.38)
  - Oceania (score: 94.68)
  - Antarctic (score: 43.71)

Query: What continent is Lima in?
Method: ArgMax( Dot( Vec(Lima) x M_cap x M_loc, All_Entities ) )
  - Americas (score: 110.34)
  - Africa (score: 25.26)
  - Jamaica (score: 25.22)

Query: What continent is Canberra in?
Method: ArgMax( Dot( Vec(Canberra) x M_cap x M_loc, All_Entities ) )
  - Oceania (score: 178.73)
  - Europe (score: 100.39)
  - Africa (score: 98.36)

Query: What continent is New Delhi in?
Method: ArgMax( Dot( Vec(New Delhi) x M_cap 