In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Normal

# Global plotting theme for every figure below: white-grid style, font scale 2.
sns.set(font_scale=2., style='whitegrid')

In [None]:
class ToyDataset(torch.utils.data.Dataset):
    """Two-class 1-D toy dataset for density-ratio estimation.

    Positive samples (target 1) are drawn from ``p1 = N(0, 1)`` and negative
    samples (target 0) from ``p2 = N(-1, 2^2)``. In ``inputs`` / ``targets``
    the ``n_pos`` positives come first, followed by the ``n - n_pos``
    negatives.

    Args:
        n: Total number of samples.
        n_pos: Number of positive samples; defaults to ``n // 2``.

    Raises:
        ValueError: If ``n_pos`` lies outside ``[0, n]``.
    """

    def __init__(self, n, n_pos=None):
        # Float scales keep loc and scale in the same floating dtype; integer
        # tensors would leave the distribution with an int64 scale.
        self.p1 = Normal(torch.tensor(0.), torch.tensor(1.))
        self.p2 = Normal(torch.tensor(-1.), torch.tensor(2.))

        self.n_pos = n_pos if n_pos is not None else n // 2
        if not 0 <= self.n_pos <= n:
            raise ValueError(
                f"n_pos must satisfy 0 <= n_pos <= n, got n_pos={self.n_pos}, n={n}"
            )
        n_neg = n - self.n_pos

        # Shape (n, 1): positives first, then negatives.
        self.inputs = torch.cat([
            self.p1.sample(torch.Size([self.n_pos, 1])),
            self.p2.sample(torch.Size([n_neg, 1]))
        ])
        # Shape (n,): float labels, 1 for p1 samples and 0 for p2 samples.
        self.targets = torch.cat([
            torch.ones(self.n_pos),
            torch.zeros(n_neg),
        ])

    def log_ratio(self, x):
        """Return the exact log density ratio ``log p1(x) - log p2(x)``."""
        return self.p1.log_prob(x) - self.p2.log_prob(x)

    def __len__(self):
        """Return the total number of samples."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at ``idx``."""
        return self.inputs[idx], self.targets[idx]

# Seed the global RNG so the sampled toy data is reproducible across runs.
torch.manual_seed(0)

# Deliberately imbalanced: 20 positives (p1) versus 180 negatives (p2).
dataset = ToyDataset(n=200, n_pos=20)

fig, ax = plt.subplots(figsize=(7,5))

# Positives occupy the first n_pos rows of dataset.inputs, so split there
# (not at the midpoint) to colour each sample by its true source.
_n = dataset.n_pos
x = torch.linspace(-5, 5, 1000)
ax.plot(x, dataset.p1.log_prob(x).exp(), c='green', label=r'$\mathcal{N}(0, 1)$')
ax.scatter(dataset.inputs[:_n], torch.zeros_like(dataset.inputs[:_n]), c='green', alpha=.1)

ax.plot(x, dataset.p2.log_prob(x).exp(), c='blue', label=r'$\mathcal{N}(-1, 2^2)$')
ax.scatter(dataset.inputs[_n:], torch.zeros_like(dataset.inputs[_n:]), c='blue', alpha=.1)
ax.legend()
# plt.show() renders under the notebook's inline backend; Figure.show() is
# meant for GUI backends.
plt.show()

In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Prefer the originally targeted GPU (index 7); fall back to the default CUDA
# device, or to the CPU, so the notebook still runs on smaller machines.
if torch.cuda.is_available():
    device = 'cuda:7' if torch.cuda.device_count() > 7 else 'cuda'
else:
    device = 'cpu'

class RatioEstimator(nn.Module):
    """Binary classifier whose output is turned into a density-ratio estimate.

    The network is an MLP (1 -> 64 -> 64 -> 1, Tanh activations) that maps a
    scalar input to a single logit. ``ratio`` converts the logit into the
    odds ``p / (1 - p)``, optionally corrected for class imbalance.
    """

    def __init__(self):
        super().__init__()
        
        self.classifier = nn.Sequential(
            nn.Linear(1, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self, inputs):
        """Return raw logits with the same leading shape as ``inputs`` and last dim 1."""
        return self.classifier(inputs)

    @torch.no_grad()
    def ratio(self, inputs, n=None, n_pos=None):
        """Estimate the ratio from the classifier odds, without tracking gradients.

        Args:
            inputs: Tensor whose last dimension is 1 (fed to ``forward``).
            n: Optional total number of training samples.
            n_pos: Optional number of positive training samples.

        Returns:
            ``prior_ratio * p / (1 - p)`` with the trailing singleton dimension
            squeezed away, where ``p = sigmoid(logit)``. ``prior_ratio`` is
            ``n / n_pos - 1`` (negatives per positive) when both ``n`` and
            ``n_pos`` are given, and ``1`` otherwise.
        """
        p = self(inputs).squeeze(-1).sigmoid()
        prior_ratio = 1.
        if n is not None and n_pos is not None:
            # Use the arguments: the module has no `n` / `n_pos` attributes.
            # NOTE(review): presumably meant for a classifier trained without
            # pos_weight re-weighting — confirm against the training setup.
            prior_ratio = n / n_pos - 1
        
        return (prior_ratio * p) / (1 - p)

estimator = RatioEstimator().to(device)
# pos_weight = n / n_pos - 1 = (#negatives / #positives) up-weights the rare
# positive class. The criterion is moved to `device` as well so its pos_weight
# buffer lives on the same device as the logits and targets.
criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor(len(dataset) / dataset.n_pos - 1)
).to(device)
# Unweighted alternative:
# criterion = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(estimator.parameters(), lr=1e-3, weight_decay=1e-4)

def plot_rx(title=None):
    """Plot the exact ratio r(x) against the estimator's current estimate.

    Reads the module-level ``x`` (evaluation grid), ``dataset`` (exact
    log-ratio), ``estimator`` and ``device``.

    Args:
        title: Optional axes title.
    """
    fig, ax = plt.subplots(figsize=(7,5))

    ax.plot(x, dataset.log_ratio(x).exp(), c='blue', label=r'$r(x)$')
    # The estimator expects a trailing feature dimension and lives on `device`;
    # bring the result back to the CPU for matplotlib.
    ax.plot(x, estimator.ratio(x.unsqueeze(-1).to(device)).cpu(), c='red', label=r'$\widehat{r}(x)$')
    ax.set(title=title)
    ax.legend()
    # plt.show() renders immediately under the notebook's inline backend;
    # Figure.show() is meant for GUI backends.
    plt.show()
    # This function is called repeatedly during training; release the figure
    # so open figures do not accumulate.
    plt.close(fig)

# A single DataLoader serves every epoch; with shuffle=True each pass over it
# draws a fresh permutation of the dataset.
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

for e in tqdm(range(200), leave=False):
    for X, y in train_loader:
        X = X.to(device)
        y = y.to(device)

        optim.zero_grad()
        # Drop the trailing singleton dim so logits line up with the targets.
        logits = estimator(X).squeeze(-1)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()

    # Snapshot the current ratio estimate every tenth epoch.
    if not e % 10:
        plot_rx(title=rf'Epoch ${e}$')

In [None]:
# Final comparison of the exact and estimated ratio after training (no title).
plot_rx()