In [839]:
import torch
from torch.distributions import Categorical

dims  = [3, 5, 7, 3]
keep_blocks = torch.tensor([1, 1, 0, 0], dtype=torch.bool)

B = 4
logits = torch.randn(B, sum(dims))

device = logits.device
dims_t = torch.as_tensor(dims, dtype=torch.long, device=device)
keep_blocks = keep_blocks.to(device)

# --- build per-category mask ---
# Start with: kept blocks -> all True, masked blocks -> all False
per_cat_mask = torch.repeat_interleave(keep_blocks, dims_t)  # [sum(dims)]

# For masked blocks, allow exactly one category (within-index = 0)
offsets = torch.nn.functional.pad(dims_t.cumsum(0), (1, 0))[:-1]  # start index of each block
K = int(dims_t.sum().item())

allow_one = torch.zeros(K, dtype=torch.bool, device=device)
allow_one[offsets[~keep_blocks]] = True  # keep the first category of each masked block

# Final mask: all cats from kept blocks, plus one cat from each masked block
per_cat_mask = per_cat_mask | allow_one  # [sum(dims)]

# Broadcast to batch if needed
while per_cat_mask.dim() < logits.dim():
    per_cat_mask = per_cat_mask.unsqueeze(0)
per_cat_mask = per_cat_mask.expand_as(logits)

# Mask logits and build the single Categorical
masked_logits = logits.masked_fill(~per_cat_mask, float('-inf'))
dist = Categorical(logits=masked_logits.flatten())

sample = dist.sample()

def unflatten(dims, sample):
    indexes = []
    reminder = sample
    for i, d in enumerate(dims):
        idx = reminder % d
        indexes.append((idx))
        reminder //= d
    return indexes

unflatten(dims, sample)


[tensor(0), tensor(0), tensor(0), tensor(0)]

In [None]:
import torch
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# build all combos
dims = [3,4,5,6]
mask = [0,1,0,1]
keep_blocks = torch.tensor(mask, dtype=torch.bool)

B = 4
logits = torch.randn(B, sum(dims), device=device)

grids = torch.cartesian_prod(*[torch.arange(d, device=device) for d in dims])  # [315, 4]
# keep only combos with masked blocks == 0
mask_rows = torch.ones(len(dims), dtype=torch.bool, device=device)
mask_rows = ~keep_blocks
valid = (grids[:, mask_rows] == 0).all(dim=-1)
grids = grids[valid]  # [315', 4], here 315' = 3*5*1*1 = 15

# compute joint logits for each combo by summing per-block logits
offsets = torch.nn.functional.pad(torch.tensor(dims, device=device).cumsum(0), (1,0))[:-1]
pieces = []
for i, d in enumerate(dims):
    start = offsets[i]
    block_logits = logits[:, start:start+d]                      # [B, d]
    idx = grids[:, i].unsqueeze(0).expand(logits.size(0), -1)    # [B, 315']
    pieces.append(block_logits.gather(1, idx))                   # [B, 315']
joint_logits = sum(pieces)                                       # [B, 315']

joint_dist = Categorical(logits=joint_logits)                    # batched
joint_sample = joint_dist.sample()                               # [B], each in [0..315'-1]
# convert back to per-block indices:
idxs = grids[joint_sample]                                       # [B, 4], columns 2,3 are 0

print(idxs)

tensor([[0, 2, 0, 4],
        [0, 3, 0, 1],
        [0, 3, 0, 0],
        [0, 2, 0, 2]], device='cuda:0')


In [40]:
joint_sample

tensor([16, 19, 18, 14], device='cuda:0')

In [128]:
dist.sample()

tensor([[0, 2, 0, 0],
        [2, 3, 0, 0],
        [0, 4, 0, 0],
        [2, 3, 0, 0]])

In [136]:
a

tensor([[2, 3, 0, 0],
        [2, 1, 0, 0],
        [2, 4, 0, 0],
        [2, 0, 0, 0]])

In [None]:
import torch
from torch.distributions import Categorical

class MaskedFlatJointCat:
    """
    One Categorical over the joint (sum(dims)) with whole-block masking.
    - Keeps only categories from blocks where keep_blocks[i] is True.
    - Sampling is from a single Categorical over the KEPT categories.
    - Returned samples can be reshaped to per-block indices with masked blocks set to 0.
    - Log-probs are those of the single Categorical (proper renormalization over kept cats).
    """
    def __init__(self, logits, dims, keep_blocks):
        """
        logits: [..., sum(dims)]
        dims:   list/1D tensor of ints, e.g. [3,5,7,3]
        keep_blocks: 1D bool/0-1 tensor of shape [num_blocks]
        """
        self.logits = logits
        self.device = logits.device
        self.dtype  = logits.dtype

        self.dims = torch.as_tensor(dims, device=self.device, dtype=torch.long)
        self.num_blocks = self.dims.numel()

        keep = torch.as_tensor(keep_blocks, device=self.device, dtype=torch.bool)
        assert keep.shape == (self.num_blocks,), "keep_blocks must be [num_blocks]"
        self.keep = keep

        # Offsets so that full_flat = offsets[block] + within_block
        self.offsets = torch.nn.functional.pad(self.dims.cumsum(0), (1,0))[:-1]  # [num_blocks], exclusive cumsum

        # Per-category mask and kept indices in the FULL flat space
        per_cat_mask = torch.repeat_interleave(self.keep, self.dims)             # [sum(dims)]
        self.per_cat_mask = per_cat_mask
        self.kept_full_idx = per_cat_mask.nonzero(as_tuple=False).squeeze(-1)    # [K_kept]
        assert self.kept_full_idx.numel() > 0, "At least one block must be kept."

        # Inverse map: full -> kept (=-1 for dropped)
        K_full = int(self.dims.sum().item())
        self.full_to_kept = torch.full((K_full,), -1, device=self.device, dtype=torch.long)
        self.full_to_kept[self.kept_full_idx] = torch.arange(self.kept_full_idx.numel(), device=self.device)

        # Build the *single* categorical over kept categories
        kept_logits = logits.index_select(-1, self.kept_full_idx)                # [..., K_kept]
        self.dist = Categorical(logits=kept_logits)

    # ---------- mappings ----------
    def kept_to_full(self, k_kept):
        return self.kept_full_idx[k_kept]

    def full_to_block_pair(self, k_full):
        # block = index of offsets such that offsets[b] <= k_full < offsets[b]+dims[b]
        b = torch.bucketize(k_full, self.offsets[1:])        # [same shape as k_full]
        i = k_full - self.offsets[b]
        return b, i

    def block_pair_to_full(self, b, i):
        return self.offsets[b] + i

    def block_pair_to_kept(self, b, i):
        k_full = self.block_pair_to_full(b, i)
        return self.full_to_kept[k_full]                     # -1 if that (b,i) was masked out

    # ---------- sampling ----------
    def sample(self, sample_shape=torch.Size()):
        """
        Returns:
          per_block: sample reshaped to [..., num_blocks] (masked blocks are 0)
          info: dict with flat/kept indices to avoid ambiguity when within==0
        """
        batch_shape = self.logits.shape[:-1]
        k_kept = self.dist.sample(sample_shape)              # sample in KEPT space
        k_full = self.kept_to_full(k_kept)                   # map to FULL flat index
        b, i = self.full_to_block_pair(k_full)               # block id + within-block id

        per_block = torch.zeros(sample_shape + batch_shape + (self.num_blocks,),
                                dtype=torch.long, device=self.device)
        per_block[..., b] = i                                # masked blocks remain 0
        return per_block, {"k_kept": k_kept, "k_full": k_full, "block": b, "within": i}

    # ---------- log-prob ----------
    def log_prob_from_kept(self, k_kept):
        return self.dist.log_prob(k_kept)

    def log_prob_from_block_pair(self, b, i):
        """
        Log-prob of the assignment that selects block b with within-block index i.
        (This is the correct way to score a 'reshaped' sample.)
        """
        # invalid if block is masked
        bad_block = ~self.keep[b]
        # map (b,i) -> kept index
        k_kept = self.block_pair_to_kept(b, i)
        bad_pair = (k_kept < 0)

        lp = self.dist.log_prob(k_kept.clamp_min(0))         # temp value
        lp = torch.where(bad_block | bad_pair, torch.tensor(float('-inf'), device=self.device, dtype=self.dtype), lp)
        return lp

    def log_prob_from_full(self, k_full):
        """
        Log-prob directly from FULL flat index (0..sum(dims)-1).
        """
        k_kept = self.full_to_kept[k_full]
        bad = (k_kept < 0)
        lp = self.dist.log_prob(k_kept.clamp_min(0))
        lp = torch.where(bad, torch.tensor(float('-inf'), device=self.device, dtype=self.dtype), lp)
        return lp


dims = [3,5,7,3]
keep = torch.tensor([1,1,1,0])
B = 2
logits = torch.randn(B, sum(dims))


dist = MaskedFlatJointCat(logits, dims, keep)

# --- Sample
per_block, info = dist.sample()   # per_block: [B, 4], last two are 0 (masked)
print(per_block)                  # e.g. tensor([[0, 3, 0, 0],
                                  #               [2, 0, 0, 0]])
# info["k_full"] is the full flat index; info["block"], info["within"] disambiguate zeros.

# --- Log-prob
lp1 = dist.log_prob_from_kept(info["k_kept"])            # direct
lp2 = dist.log_prob_from_full(info["k_full"])            # via full flat index
lp3 = dist.log_prob_from_block_pair(info["block"], info["within"])  # via reshaped pair


tensor([[-0.0337,  0.0649, -0.1024,  0.9784,  0.2401, -0.2484,  0.3114,  0.7636,
         -1.0020, -1.1478,  1.1529,  0.1975, -0.9572,  1.3498, -0.9502, -1.3087,
         -0.0481,  0.4834],
        [ 0.0043, -0.4993,  1.7473,  1.0211,  0.2234,  0.0693, -0.9913,  0.6352,
         -0.0644,  0.2137, -2.0399, -1.4999, -0.4451, -0.7589, -1.1711, -1.6983,
         -1.2267, -1.0132]])
tensor([[2, 0, 2, 0],
        [2, 0, 2, 0]])


In [468]:
dist.sample()

(tensor([[1, 0, 3, 0],
         [1, 0, 3, 0]]),
 {'k_kept': tensor([11,  1]),
  'k_full': tensor([11,  1]),
  'block': tensor([2, 0]),
  'within': tensor([3, 1])})