In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SemanticReasoner(nn.Module):
    """Iteratively refine token embeddings by attending over a learned concept bank.

    On each of ``reasoning_steps`` iterations, every token vector is compared with
    every row of ``knowledge_bank`` by cosine similarity. The softmax of those
    scores (divided by ``temperature``) weights an average of the *raw* concept
    vectors. The token and its matched knowledge are concatenated and passed
    through ``update_layer``. The MLP output replaces the token vector for the
    next step; there is no residual connection.

    Args:
        embedding_dim: Size ``D`` of token and concept vectors.
        hidden_dim: Width of the hidden layer of the fusion MLP.
        num_concepts: Number of rows in the learnable knowledge bank.
        reasoning_steps: Number of refinement iterations. With 0, the input is
            returned unchanged in value.
        temperature: Softmax temperature applied to the cosine similarities
            (must be > 0). Cosine scores lie in [-1, 1], so at the default 1.0 the
            attention can be fairly flat across many concepts; values < 1
            sharpen it. The default reproduces the original behaviour exactly.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, embedding_dim, hidden_dim, num_concepts, reasoning_steps=3,
                 temperature=1.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.reasoning_steps = reasoning_steps
        self.temperature = temperature

        # External semantic knowledge base: one learnable vector per concept, [num_concepts, D].
        self.knowledge_bank = nn.Parameter(torch.randn(num_concepts, embedding_dim))

        # Fuses [token ; matched knowledge] (2D features) back into a D-dim token update.
        self.update_layer = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

    def forward(self, sentence_tensor):
        """Refine a batch of token embeddings.

        Args:
            sentence_tensor: Tensor of shape ``[batch_size, seq_len, embedding_dim]``.
                It may be non-contiguous, for example a transposed view.

        Returns:
            Tensor of shape ``[batch_size, seq_len, embedding_dim]``.

        Raises:
            ValueError: if the input is not 3-D or its last dim != ``embedding_dim``.
        """
        if sentence_tensor.dim() != 3:
            raise ValueError(
                "expected a 3-D tensor [batch, seq, dim], "
                f"got shape {tuple(sentence_tensor.shape)}"
            )
        B, T, D = sentence_tensor.shape
        if D != self.embedding_dim:
            raise ValueError(f"expected last dim {self.embedding_dim}, got {D}")

        # The concept-side normalisation does not depend on the tokens, so compute it once per call.
        norm_kb = F.normalize(self.knowledge_bank, dim=1)  # [num_concepts, D]

        # Use reshape (not view) so non-contiguous inputs are handled (copied if needed).
        flat_x = sentence_tensor.reshape(B * T, D)  # [B*T, D]

        for _ in range(self.reasoning_steps):
            # Cosine similarity of each token with each concept: [B*T, num_concepts]
            sim_scores = F.normalize(flat_x, dim=1) @ norm_kb.t()
            attention_weights = F.softmax(sim_scores / self.temperature, dim=1)
            # Weighted average of the un-normalised concepts: [B*T, D]
            matched_knowledge = attention_weights @ self.knowledge_bank
            combined = torch.cat([flat_x, matched_knowledge], dim=1)  # [B*T, 2D]
            flat_x = self.update_layer(combined)  # [B*T, D]

        return flat_x.reshape(B, T, D)  # Final refined tensor


In [3]:
if __name__ == "__main__":
    # Demo: one forward pass on random data to check the reasoner preserves [B, T, D].

    # Hyperparameters
    SEED = 0  # fixed so weights and dummy input are reproducible on Restart & Run All
    batch_size = 2
    seq_len = 6
    embedding_dim = 64
    hidden_dim = 128
    num_concepts = 100
    reasoning_steps = 4

    # Seed before creating the model so the randn-initialised knowledge bank is reproducible too.
    torch.manual_seed(SEED)

    # Create model
    model = SemanticReasoner(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_concepts=num_concepts,
        reasoning_steps=reasoning_steps
    )

    # Dummy sentence tensor: [batch_size, seq_len, embedding_dim]
    sentence_tensor = torch.randn(batch_size, seq_len, embedding_dim)

    # Forward pass only (no training here), so skip building the autograd graph.
    with torch.no_grad():
        output_tensor = model(sentence_tensor)

    assert output_tensor.shape == sentence_tensor.shape, "reasoner must preserve [B, T, D]"

    print("Input shape: ", sentence_tensor.shape)
    print("Output shape:", output_tensor.shape)

Input shape:  torch.Size([2, 6, 64])
Output shape: torch.Size([2, 6, 64])
