In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Normal

# Global plot styling for every figure below: white-grid background and a
# doubled font scale so labels stay legible in the rendered notebook.
sns.set(style='whitegrid', font_scale=2.)

In [None]:
class ToyDataset(torch.utils.data.Dataset):
    """Two-class 1-D toy dataset drawn from a pair of unit-scale Gaussians.

    ``n`` points are sampled from ``p1 = Normal(-1, 1)`` and labelled 1,
    followed by ``n`` points sampled from ``p2 = Normal(1, 1)`` and labelled 0.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, n):
        # The two known densities; kept on the instance so the exact ratio
        # can be evaluated later through ``log_ratio``.
        self.p1 = Normal(torch.tensor(-1.), torch.tensor(1.))
        self.p2 = Normal(torch.tensor(1.), torch.tensor(1.))

        # Draw from p1 first, then p2, each as an (n, 1) column.
        sample_shape = torch.Size([n, 1])
        from_p1 = self.p1.sample(sample_shape)
        from_p2 = self.p2.sample(sample_shape)

        # inputs: (2n, 1) tensor, p1 samples stacked on top of p2 samples.
        self.inputs = torch.cat([from_p1, from_p2])
        # targets: (2n,) float tensor, 1 for the p1 half and 0 for the p2 half.
        self.targets = torch.cat([torch.ones(n), torch.zeros(n)])

    @property
    def n_pos(self):
        """Number of samples whose target equals 1, as a Python int."""
        positives = self.targets == 1
        return int(positives.sum())

    def log_ratio(self, x):
        """Exact log density ratio ``log p1(x) - log p2(x)``."""
        log_p1 = self.p1.log_prob(x)
        log_p2 = self.p2.log_prob(x)
        return log_p1 - log_p2

    def __len__(self):
        """Total number of samples (both classes)."""
        return self.targets.shape[0]

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at ``idx``."""
        sample = self.inputs[idx]
        label = self.targets[idx]
        return sample, label

# Seed torch's global RNG once so the sampled dataset (and everything drawn
# from the same RNG afterwards) is reproducible under Restart & Run All.
torch.manual_seed(0)

n = 100
dataset = ToyDataset(n=n)

fig, ax = plt.subplots(figsize=(7,5))

x = torch.linspace(-10, 10, 1000)

# ToyDataset stores the p1 samples in the first n rows and the p2 samples in
# the remaining rows.
p1_samples = dataset.inputs[:n]
p2_samples = dataset.inputs[n:]

# Both components are built with scale 1, so the legend (mean, variance)
# must read (-1, 1) and (1, 1).
ax.plot(x, dataset.p1.log_prob(x).exp(), c='green', label=r'$\mathcal{N}(-1, 1)$')
ax.scatter(p1_samples, torch.zeros_like(p1_samples), c='green', alpha=.2)

ax.plot(x, dataset.p2.log_prob(x).exp(), c='blue', label=r'$\mathcal{N}(1, 1)$')
ax.scatter(p2_samples, torch.zeros_like(p2_samples), c='blue', alpha=.2)

ax.set_xlabel(r'$x$')
ax.set_ylabel('density')
ax.legend(loc='upper right')
fig.show()

In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Keep using the third GPU when the machine actually has one (the original
# setup); otherwise fall back to the default GPU, or to the CPU, so the
# notebook still runs on machines with fewer or no GPUs.
if torch.cuda.device_count() > 2:
    device = 'cuda:2'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class RatioEstimator(nn.Module):
    """Density-ratio estimator built on a binary classifier.

    A small MLP (1 -> 100 -> 100 -> 1 with tanh activations) outputs one
    logit per scalar input. After it has been trained as a classifier,
    ``ratio`` turns that logit into a ratio estimate by multiplying the
    classifier odds with a class-balance correction.

    Args:
        n: Total number of training samples.
        n_pos: Number of training samples labelled 1. Must be non-zero,
            otherwise ``ratio`` raises ``ZeroDivisionError``.
    """

    def __init__(self, n, n_pos):
        super().__init__()

        self.classifier = nn.Sequential(
            nn.Linear(1, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 1),
        )

        self.n = n
        self.n_pos = n_pos

    def forward(self, inputs):
        """Return the raw classifier logits for ``inputs`` (last dim of size 1)."""
        return self.classifier(inputs)

    @torch.no_grad()
    def ratio(self, inputs):
        """Estimate the density ratio at ``inputs`` without tracking gradients.

        Args:
            inputs: Tensor whose last dimension has size 1.

        Returns:
            Tensor with the trailing dimension squeezed away, equal to
            ``(n / n_pos - 1) * exp(logit)``.
        """
        logits = self(inputs).squeeze(-1)
        # (n - n_pos) / n_pos: corrects the odds for class imbalance.
        prior_ratio = self.n / self.n_pos - 1
        # For p = sigmoid(logit), p / (1 - p) == exp(logit). Using the logit
        # directly avoids dividing by (1 - p), which rounds to exactly 0 in
        # float32 once the sigmoid saturates and would yield inf.
        return prior_ratio * logits.exp()

# Build the classifier (told the class balance of the dataset), move it to
# the target device, then set up the binary cross-entropy loss on logits and
# the Adam optimiser.
estimator = RatioEstimator(n=len(dataset), n_pos=dataset.n_pos)
estimator = estimator.to(device)
criterion = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(estimator.parameters(), lr=1e-3)

# One loader serves every epoch; with shuffle=True each new pass over it
# draws a fresh permutation.
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for _ in tqdm(range(200), leave=False):
    for X, y in loader:
        X = X.to(device)
        y = y.to(device)

        # Model output is (batch, 1); drop the last dim to match the (batch,) targets.
        logits = estimator(X).squeeze(-1)

        optim.zero_grad()
        loss = criterion(logits, y)
        loss.backward()
        optim.step()

In [None]:
fig, ax = plt.subplots(figsize=(7,5))

# Exact ratio, evaluated from the dataset's known densities on the grid x.
true_ratio = dataset.log_ratio(x).exp()

# Learned ratio: the estimator takes a trailing feature dimension and inputs
# on `device`; the result is moved back to the CPU for plotting.
grid = x.unsqueeze(-1).to(device)
estimated_ratio = estimator.ratio(grid).cpu()

ax.plot(x, true_ratio, c='blue', label=r'$r(x)$')
ax.plot(x, estimated_ratio, c='red', label=r'$\widehat{r}(x)$')
ax.legend()
fig.show()