# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Attentive Neural Process model
class AttentiveNeuralProcess(nn.Module):
    """Neural Process with attention-weighted pooling of context representations.

    Each context pair ``(x_i, y_i)`` is encoded into an ``r_dim`` representation
    ``r_i``. A small scoring network maps ``[x_i, r_i]`` to a scalar, and the
    scores are softmax-normalised over the context points. The weighted sum of
    the ``r_i`` is mapped to the mean and log-variance of a diagonal Gaussian
    over the latent ``z``. A reparameterised sample of ``z`` is concatenated to
    every target input and decoded into a prediction of ``y``.

    Note: this is a simplified variant. Unlike the ANP of Kim et al. (2019), the
    pooling weights depend only on the context, not on the target inputs (there
    is no target-to-context cross-attention).

    Args:
        x_dim: Feature size of the inputs ``x``.
        y_dim: Feature size of the outputs ``y``.
        r_dim: Size of the per-context-point representation ``r``.
        z_dim: Size of the latent variable ``z``.
        h_dim: Width of the hidden layers.
    """

    def __init__(self, x_dim, y_dim, r_dim, z_dim, h_dim):
        super().__init__()
        # Kept on the instance so forward() never depends on a module global.
        self.z_dim = z_dim

        # Encoder network: (x_i, y_i) -> r_i. The final projection to r_dim is
        # what the attention and aggregator networks below consume.
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + y_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, r_dim),
        )

        # Attention scorer: [x_i, r_i] -> one unnormalised score per context point.
        self.attention = nn.Sequential(
            nn.Linear(x_dim + r_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
        )

        # Latent aggregator: pooled r -> concatenated (mean, log variance) of z.
        self.aggregator = nn.Sequential(
            nn.Linear(r_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, z_dim * 2),
        )

        # Decoder network: (target x, z) -> predicted y.
        self.decoder = nn.Sequential(
            nn.Linear(x_dim + z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, y_dim),
        )

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.

        Sampling through this affine transform keeps ``z`` differentiable with
        respect to ``mean`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        z = mean + epsilon * std
        return z

    def forward(self, context_x, context_y, target_x):
        """Predict target outputs from a context set.

        Args:
            context_x: Context inputs, shape ``(batch, n_context, x_dim)`` or
                ``(n_context, x_dim)``.
            context_y: Context outputs, shape ``(batch, n_context, y_dim)`` or
                ``(n_context, y_dim)``.
            target_x: Target inputs, shape ``(batch, n_target, x_dim)`` or
                ``(n_target, x_dim)``.

        Returns:
            ``(target_y, mean, logvar)``. ``target_y`` has shape
            ``(batch, n_target, y_dim)``. ``mean`` and ``logvar`` parameterise
            ``q(z)`` and have shape ``(batch, z_dim)``. For unbatched (2-D)
            inputs the leading batch dimension is dropped from all three.

        Raises:
            ValueError: If the inputs are not all 2-D or all 3-D.
        """
        ranks = {context_x.dim(), context_y.dim(), target_x.dim()}
        if ranks == {2}:
            unbatched = True
            context_x, context_y, target_x = (
                t.unsqueeze(0) for t in (context_x, context_y, target_x)
            )
        elif ranks == {3}:
            unbatched = False
        else:
            raise ValueError(
                "context_x, context_y and target_x must all be 2-D "
                "(points, features) or all 3-D (batch, points, features)"
            )

        # Encode context pairs: (B, N_c, r_dim).
        r = self.encoder(torch.cat([context_x, context_y], dim=-1))

        # Score every context point and normalise over the context-points axis.
        scores = self.attention(torch.cat([context_x, r], dim=-1))  # (B, N_c, 1)
        attention_weights = torch.softmax(scores, dim=1)
        r_weighted = (attention_weights * r).sum(dim=1)  # (B, r_dim)

        # Map the pooled representation to q(z) and draw a sample.
        z_params = self.aggregator(r_weighted)
        mean = z_params[..., : self.z_dim]
        logvar = z_params[..., self.z_dim :]
        z = self.reparameterize(mean, logvar)

        # Broadcast one z per batch element across all target points (no copy).
        z_repeated = z.unsqueeze(1).expand(-1, target_x.size(1), -1)
        target_y = self.decoder(torch.cat([target_x, z_repeated], dim=-1))

        if unbatched:
            return target_y.squeeze(0), mean.squeeze(0), logvar.squeeze(0)
        return target_y, mean, logvar

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model dimensions and hyperparameters.
x_dim, y_dim = 1, 1  # feature sizes of the inputs x and the outputs y
r_dim = 32  # size of the per-context-point representation r
z_dim = 16  # size of the latent variable z
h_dim = 64  # width of the hidden layers

# Build the model and place its parameters on the selected device.
model = AttentiveNeuralProcess(
    x_dim=x_dim,
    y_dim=y_dim,
    r_dim=r_dim,
    z_dim=z_dim,
    h_dim=h_dim,
).to(device)

# Mean-squared error between predicted and observed target y.
# NOTE(review): the model also returns the latent mean/logvar; a latent
# Neural Process objective usually adds a KL term. Confirm how the training
# code combines them.
criterion = nn.MSELoss()

# Adam over every model parameter, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

