In [1]:
import torch
import torch.nn as nn

import numpy as np

In [5]:
class CKYParser(nn.Module):
    """Chart parser that fills a table of log inside scores with the CKY recurrence.

    Symbols are indexed ``0 .. S-1`` with the ``T`` terminals first and the
    ``NT`` nonterminals after them (``S = T + NT``).  The grammar handed to
    :meth:`forward` is read as a tensor of log-scores in which
    ``grammar[A, w, 0]`` scores the lexical rule ``A -> w`` and
    ``grammar[A, B, C]`` scores the binary rule ``A -> B C``.

    Args:
        T: number of terminal symbols.
        NT: number of nonterminal symbols.
        MAX_SEQ_LENGTH: longest sequence the preallocated chart can hold.
    """

    def __init__(self, T, NT, MAX_SEQ_LENGTH):
        # Zero-argument super() actually runs nn.Module.__init__, which is
        # required for __call__, .to(), buffers and state_dict to work.
        super().__init__()
        self.T = T
        self.NT = NT
        self.S = T + NT
        self.MAX_SEQ_LENGTH = MAX_SEQ_LENGTH
        # Registered as a non-persistent buffer so the chart follows the module
        # across devices without being written into checkpoints.
        self.register_buffer(
            "CKY_table",
            torch.zeros((self.S, MAX_SEQ_LENGTH, MAX_SEQ_LENGTH)),
            persistent=False,
        )

    def forward(self, grammar, sequence):
        """Fill the chart for ``sequence`` and return it.

        Args:
            grammar: tensor of log rule scores, expected to have shape
                ``(S, S, S)`` with the layout described on the class.
            sequence: 1-D tensor (or list) of terminal indices, at most
                ``MAX_SEQ_LENGTH`` long.

        Returns:
            ``self.CKY_table``, shape ``(S, MAX_SEQ_LENGTH, MAX_SEQ_LENGTH)``.
            For a nonterminal ``A`` and ``i <= j < len(sequence)``, entry
            ``[A, i, j]`` is the log inside score of ``A`` over tokens
            ``i..j`` (both inclusive).  Every other entry is ``-inf``.  The
            same tensor is reset and overwritten on every call.

        Raises:
            IndexError: if ``sequence`` is longer than ``MAX_SEQ_LENGTH``.
        """
        chart = self.CKY_table
        device = chart.device
        T = self.T

        # Bring both inputs onto the chart's device/dtype so indexed
        # assignment into the chart cannot fail on a mismatch.
        grammar = torch.as_tensor(grammar, dtype=chart.dtype, device=device)
        sequence = torch.as_tensor(sequence, dtype=torch.long, device=device)
        l = sequence.shape[0]
        if l > self.MAX_SEQ_LENGTH:
            raise IndexError(
                f"sequence length {l} exceeds MAX_SEQ_LENGTH={self.MAX_SEQ_LENGTH}"
            )

        chart.fill_(-float('inf'))  # log-space zero everywhere
        if l == 0:
            return chart

        # ===== Width 1: lexical rules A -> word on the diagonal =====
        positions = torch.arange(l, device=device)
        chart[T:, positions, positions] = grammar[T:, sequence, 0]  # (NT, l)

        # Terminal rows of the chart are never written and stay at -inf, so
        # only nonterminal children can contribute to a binary rule.
        binary_rules = grammar[T:, T:, T:]  # (NT, NT, NT)

        # ===== Widths 2..l: all start positions of one width at once =====
        for span in range(2, l + 1):
            i = positions[: l - span + 1]      # span starts, (n,)
            j = i + (span - 1)                 # inclusive span ends, (n,)
            offsets = positions[: span - 1]    # 0 .. span-2
            k = i.unsqueeze(1) + offsets.unsqueeze(0)  # (n, span-1), i <= k < j

            # Left child covers [i, k]; right child covers [k+1, j].
            left = chart[T:, i.unsqueeze(1), k]        # (NT, n, span-1)
            right = chart[T:, k + 1, j.unsqueeze(1)]   # (NT, n, span-1)

            # Reduce over split points first: pair[B, C, i] is the log-sum over
            # k of left[B] + right[C].  Doing this before adding the rule
            # scores keeps the largest intermediate at (NT, NT, n, span-1).
            pair = torch.logsumexp(
                left.unsqueeze(1) + right.unsqueeze(0), dim=-1
            )  # (NT, NT, n)

            # Then reduce over the children (B, C) of every parent A.
            chart[T:, i, j] = torch.logsumexp(
                binary_rules.unsqueeze(-1) + pair.unsqueeze(0), dim=(1, 2)
            )  # (NT, n)

        return chart
            

In [6]:
# 34 terminals + 16 nonterminals (50 symbols in total); chart sized for 512 tokens.
parser = CKYParser(T=34, NT=16, MAX_SEQ_LENGTH=512)

In [None]:
# Smoke run: all-zero 50x50x50 grammar scores and a 512-token sequence of symbol 0.
parser.forward(
    grammar=torch.zeros(50, 50, 50),
    sequence=torch.zeros(512, dtype=torch.long),
)