In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torch.optim as optim
import numpy as np

In [2]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device.type)

# ---

# Seed PyTorch first and NumPy second so reruns draw the same random numbers.
# (NumPy's seed call stays last: it returns None, so the cell shows no extra output.)
torch.manual_seed(42)
np.random.seed(42)

cuda


In [None]:
class CNP(nn.Module):
    """Encoder / mean-aggregation / query network over frames of a scene.

    Each context frame is passed through ``self.cnn``, flattened to
    ``64*4*4`` features, concatenated with its 4-feature observation and
    encoded to a 1024-d representation. The representations are averaged
    over the frame axis and the average is paired with every target row to
    produce a 4-value prediction per target.

    Args:
        cnn: optional ``nn.Module`` mapping ``frames`` to a tensor holding
            exactly ``64*4*4`` values per frame. When omitted, a module must
            be attached afterwards as ``model.cnn = ...`` before ``forward``
            is called. The architecture is not defined in this class.
    """

    def __init__(self, cnn=None):
        super(CNP, self).__init__()

        # Register the frame feature extractor only when supplied, so that
        # attaching it later via plain attribute assignment keeps working.
        if cnn is not None:
            self.cnn = cnn

        # Input width: 4 observation features + 64*4*4 flattened CNN features.
        self.encoder = nn.Sequential(
            nn.Linear(4+64*4*4,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024)
        )

        # Input width: 1024-d averaged representation + 2 target features.
        # Output width 2*2: presumably (mean, raw sigma) for a 2-d quantity,
        # matching how log_prob_loss splits its input -- confirm with caller.
        self.query = nn.Sequential(
            nn.Linear(1024+2,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,1024),
            nn.ReLU(),
            nn.Linear(1024,2*2)
        )

    def forward(self, frames, observations, target):
        """Predict a 4-value output for every target row.

        Args:
            frames: batch of context frames; ``frames.shape[0]`` is the number
                of frames, the remaining layout is whatever ``self.cnn`` expects.
            observations: tensor of shape ``(frames.shape[0], 4)``.
            target: tensor of shape ``(n_targets, 2)``.

        Returns:
            Tensor of shape ``(n_targets, 4)``.

        Raises:
            AttributeError: if no ``cnn`` module has been attached.
        """
        cnn = getattr(self, "cnn", None)
        if cnn is None:
            raise AttributeError(
                "CNP has no 'cnn' module; pass cnn=... to the constructor "
                "or assign model.cnn before calling forward()"
            )

        # n < n_max frames of a scene along with momentary observations are
        # concatenated to constitute the encoder input. The CNN is evaluated
        # once; reshape (not view) tolerates non-contiguous outputs.
        scene_encodings = cnn(frames).reshape(frames.shape[0], 64*4*4)
        encoder_in = torch.cat((observations, scene_encodings), 1)
        r = self.encoder(encoder_in)

        # Aggregate over the frame axis, then pair the same average with
        # every target row.
        r_avg = torch.mean(r, dim=0)
        r_avgs = r_avg.repeat(target.shape[0], 1)
        r_avg_target = torch.cat((r_avgs, target), 1)
        query_out = self.query(r_avg_target)

        return query_out

    
def log_prob_loss(output, target):
    """Mean negative log-likelihood of ``target`` under a diagonal Gaussian.

    Args:
        output: tensor whose last dimension is split into two equal halves:
            the first half is the mean, the second half is an unconstrained
            value mapped to a positive standard deviation with softplus.
        target: tensor matching the shape of the mean half.

    Returns:
        Scalar tensor: the negated average, over all leading dimensions, of
        the log-probability summed across the last dimension.
    """
    mean, sigma = output.chunk(2, dim = -1)
    # softplus underflows to exactly 0 for strongly negative inputs, which
    # Normal rejects (scale must be > 0). Flooring at machine epsilon keeps
    # the loss finite and leaves ordinary sigma values unchanged.
    sigma = F.softplus(sigma).clamp_min(torch.finfo(sigma.dtype).eps)
    # Independent(..., 1) sums per-dimension log-probs over the last axis.
    dist = D.Independent(D.Normal(loc=mean, scale=sigma), 1)
    return -torch.mean(dist.log_prob(target))