# Code skeleton (ChatGPT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ASNN(nn.Module):
    """Two-layer perceptron that maps a feature vector to one scalar score.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, pairwise_features):
        """Score each feature vector.

        The layers act on the last dimension only, so any leading dimensions
        are carried through and the result ends in a dimension of size 1.
        """
        activated = torch.relu(self.fc1(pairwise_features))
        return self.fc2(activated)

class Encoder(nn.Module):
    """Single-layer, batch-first LSTM over the input sequence.

    Args:
        input_dim: Feature size of each sequence element.
        hidden_dim: Hidden-state size of the LSTM.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

    def forward(self, x):
        """Run the LSTM over ``x`` of shape (batch_size, seq_len, input_dim).

        Returns:
            A tuple ``(output, hidden, cell)``: the per-step LSTM outputs and
            the final hidden and cell states as returned by ``nn.LSTM``.
        """
        per_step, final_state = self.lstm(x)
        last_hidden, last_cell = final_state
        return per_step, last_hidden, last_cell

class Decoder(nn.Module):
    """One-step LSTM decoder followed by a linear projection.

    Args:
        hidden_dim: Size of both the LSTM input and its hidden state.
        output_dim: Size of the projected output vector.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, input, hidden, cell):
        """Advance the decoder by one step.

        ``input`` is expected as (batch_size, 1, hidden_dim); the length-1
        sequence dimension is squeezed away before the projection.

        Returns:
            A tuple ``(logits, hidden, cell)`` holding the projected output
            and the updated LSTM state.
        """
        step_output, next_state = self.lstm(input, (hidden, cell))
        next_hidden, next_cell = next_state
        projected = self.fc(step_output.squeeze(1))
        return projected, next_hidden, next_cell

class PointerNetwork(nn.Module):
    """Greedy sequence decoder that orders stops using pairwise ASNN scores.

    The encoder LSTM encodes the stop sequence. At every decoding step the
    ASNN scores each candidate stop from the pairwise features between the
    current stop and that candidate, stops already chosen are masked out, and
    the best remaining stop is selected independently for each batch row.

    Args:
        input_dim: Size of each per-stop feature vector.
        hidden_dim: Hidden size shared by the encoder, decoder and ASNN.
        pairwise_dim: Size of each pairwise feature vector.
    """

    def __init__(self, input_dim, hidden_dim, pairwise_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim)
        self.decoder = Decoder(hidden_dim, hidden_dim)
        self.asnn = ASNN(pairwise_dim, hidden_dim)

    def forward(self, inputs: torch.Tensor, pairwise_features: torch.Tensor) -> torch.Tensor:
        """Greedily decode a visiting order over all stops.

        Args:
            inputs: Per-stop features, (batch_size, seq_len, input_dim).
            pairwise_features: Pair features,
                (batch_size, seq_len, seq_len, pairwise_dim). Entry
                ``[b, i, j]`` is read as the features for moving from stop
                ``i`` to stop ``j`` -- assumed convention, confirm against
                the data pipeline.

        Returns:
            Long tensor of shape (batch_size, seq_len). Each row is a
            permutation of ``0 .. seq_len - 1`` giving the predicted order.
        """
        encoder_outputs, hidden, cell = self.encoder(inputs)

        batch_size, seq_len, _ = inputs.size()
        device = inputs.device
        batch_idx = torch.arange(batch_size, device=device)
        visited = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
        predictions = []

        # Decoding is primed with stop 0: it is the "current" stop for the
        # first step but is not marked visited, so all seq_len stops
        # (including stop 0) are still emitted exactly once.
        current_stop = torch.zeros(batch_size, dtype=torch.long, device=device)
        decoder_input = encoder_outputs[:, 0, :].unsqueeze(1)

        for _ in range(seq_len):
            # NOTE(review): the projected decoder output is not consumed by
            # the scoring yet (the ASNN only accepts pairwise_dim inputs);
            # the recurrent state is still advanced on every step.
            _, hidden, cell = self.decoder(decoder_input, hidden, cell)

            # Features from each row's current stop to every candidate stop:
            # (batch_size, seq_len, pairwise_dim) -> scores (batch_size, seq_len).
            step_features = pairwise_features[batch_idx, current_stop]
            scores = self.asnn(step_features).squeeze(-1)

            # Mask per row, so a stop visited in one sequence stays available
            # in the others and indices keep referring to real stop positions.
            scores = scores.masked_fill(visited, float("-inf"))
            attention_probs = F.softmax(scores, dim=1)

            next_stop = attention_probs.argmax(dim=1)  # (batch_size,)
            predictions.append(next_stop)
            visited[batch_idx, next_stop] = True

            current_stop = next_stop
            decoder_input = encoder_outputs[batch_idx, next_stop, :].unsqueeze(1)

        return torch.stack(predictions, dim=1)

# Example usage: build a model, run it once on random data, and print the
# shape of the returned tensor.
input_dim, hidden_dim, pairwise_dim = 16, 32, 8
batch_size, seq_len = 4, 10

model = PointerNetwork(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    pairwise_dim=pairwise_dim,
)

# Random per-stop features and random features for every ordered stop pair.
inputs = torch.randn(batch_size, seq_len, input_dim)
pairwise_features = torch.randn(batch_size, seq_len, seq_len, pairwise_dim)

output = model(inputs, pairwise_features)
print(output.shape)  # shape of the predicted sequence indices
