In [None]:
import torch

class MaskedLogitNetwork(torch.nn.Module):
    """Wrap a logit-producing network and suppress logits of occupied cells.

    The wrapped network's output has ``torch.exp(1000*x)`` subtracted from it
    element-wise, then everything after the batch dimension is flattened.
    In float32, an input value of -1 gives a penalty of ``exp(-1000) == 0``
    (logit unchanged), while an input value of +1 gives ``exp(1000) == inf``
    (logit becomes ``-inf``, i.e. zero probability after a softmax).

    Args:
        logit_model: module mapping the input ``x`` to unmasked logits that
            are broadcast-compatible with ``x``.
    """

    def __init__(self, logit_model):
        super().__init__()

        # Register the wrapped network as a submodule so that its parameters
        # follow ``.to(...)`` and show up in ``.parameters()``.
        self.logit_model = logit_model

    def forward(self, x):
        """Return masked logits flattened to ``(batch, -1)``."""
        # Use the wrapped submodule, not a notebook-level global of the same name.
        unmasked_logits = self.logit_model(x)
        # exp(1000*x) acts as a hard mask: ~0 for negative x, inf for positive x.
        return (unmasked_logits - torch.exp(1000*x)).flatten(1)

In [None]:
from torch.nn import Conv2d, Sequential, ReLU

# Three 3x3 convolutions with padding=1 (spatial size is preserved), mapping
# one input channel to one output channel through two 64-channel hidden layers.
logit_model = Sequential(
    Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
    ReLU(),
    Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
    ReLU(),
    Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1),
)

# Wrap the convolutional stack so that logits are masked by the input values.
model = MaskedLogitNetwork(logit_model=logit_model)

# Move all parameters to the GPU (Module.to works in place and returns the module).
model.to('cuda')

In [None]:
from torch.nn.functional import softmax
from torch.distributions import Categorical

def place_point(x, model, grid_size=5):
  """Sample one grid cell per batch element and mark it in ``x`` in place.

  The model's logits are turned into a categorical distribution over the
  flattened grid, one index is sampled per batch element, and 2 is added to
  ``x`` at the sampled cell (so a cell holding -1 becomes +1).

  Args:
    x: tensor indexed as (batch, 1, row, col); modified in place.
    model: callable returning logits of shape (batch, grid_size*grid_size).
    grid_size: side length of the square grid.

  Returns:
    Tuple ``(row_index, col_index)`` of 1-D integer tensors, one entry per
    batch element.
  """
  with torch.no_grad():
    logits = model(x)
    p = softmax(logits, dim=-1)
    index = Categorical(p).sample()
    # Convert the flat (row-major) index back to 2-D coordinates.
    row_index = index//grid_size
    col_index = index%grid_size
    # Build the coordinate range on x's own device instead of hard-coding
    # 'cuda', so the function also works on CPU or a non-default GPU.
    grid_range = torch.arange(grid_size, device=x.device)
    row_select = row_index[:, None, None, None] == grid_range[None, None, :, None]
    col_select = col_index[:, None, None, None] == grid_range[None, None, None, :]
    # row_select*col_select is a one-hot (batch, 1, grid, grid) mask of the chosen cell.
    x.add_(2*row_select*col_select)
    return row_index, col_index

In [None]:
def generate_pointset(model, num_samples, grid_size, num_points, device='cuda'):
  """Sequentially sample ``num_points`` grid cells for each of ``num_samples`` grids.

  Every grid starts filled with -1; each step delegates to ``place_point``,
  which samples a cell from the model and updates the grid in place.

  Returns:
    Tuple ``(samples, indices)``: the final grids with shape
    (num_samples, 1, grid_size, grid_size) and a long tensor of shape
    (num_samples, num_points, 2) holding the (row, col) chosen at each step.
  """
  board_shape = (num_samples, 1, grid_size, grid_size)
  samples = torch.full(board_shape, -1.0, device=device)
  indices = torch.zeros((num_samples, num_points, 2), dtype=torch.long, device=device)
  for step in range(num_points):
    rows, cols = place_point(samples, model, grid_size=grid_size)
    # Store the (row, col) pair for this step in one assignment.
    indices[:, step] = torch.stack((rows, cols), dim=-1)
  return samples, indices

In [None]:
# Draw 100 grids of size 5x5 with 10 sequentially placed points each.
samples, indices = generate_pointset(model=model, num_samples=100, grid_size=5, num_points=10)

In [None]:
from torch.nn import CrossEntropyLoss

# Shared criterion; the default reduction averages over the batch dimension.
cross_entropy = CrossEntropyLoss()
def calculate_loss(model, indices, grid_size, device='cuda'):
  """Sum of per-step cross-entropies of a recorded placement sequence.

  The placement sequence in ``indices`` is replayed from an empty grid
  (all cells -1). At every step the model's logits for the current grid are
  scored against the cell that was actually chosen, then that cell is marked
  by adding 2 (-1 -> +1), the same update ``place_point`` applies during
  sampling, so the model sees the same grids it saw when the points were drawn.

  Args:
    model: callable mapping a (num_samples, 1, grid_size, grid_size) tensor
      to logits of shape (num_samples, grid_size*grid_size).
    indices: long tensor of shape (num_samples, num_points, 2) holding the
      (row, col) of every placed point, in placement order.
    grid_size: side length of the square grid.
    device: device on which the replayed grids and the loss are created.

  Returns:
    Scalar tensor: the batch-mean cross-entropy summed over all steps.
  """
  num_samples, num_points,_ = indices.shape
  samples = -1*torch.ones((num_samples, 1, grid_size, grid_size), device=device)
  loss = torch.tensor(0.0, device=device)
  # Loop-invariant coordinate range used to build the one-hot cell masks.
  grid_range = torch.arange(grid_size, device=device)
  for i in range(num_points):
    logits = model(samples)
    # Row-major flat index; must use grid_size (not a literal 5) to match the
    # flattened logits for any grid size.
    target = indices[:, i, 0]*grid_size+indices[:, i, 1]
    loss = loss + cross_entropy(logits, target)
    with torch.no_grad():
      row_select = indices[:, i, 0][:, None, None, None] == grid_range[None, None, :, None]
      col_select = indices[:, i, 1][:, None, None, None] == grid_range[None, None, None, :]
      # Out-of-place update: the model may have saved `samples` for its
      # backward pass, and mutating it in place would invalidate that graph.
      samples = samples + 2*row_select*col_select
  return loss

In [None]:
# Score the sampled placement sequences under the current model.
calculate_loss(model=model, indices=indices, grid_size=5)