In [1]:
import os

import numpy as np
import torch
import torch.optim as opt
from sklearn.datasets import fetch_california_housing
from torch.distributions import Normal
from torch.utils.data import Dataset, DataLoader

from MDN import MDN, SDN, distribution_log_prob_loss

In [12]:
class LossPrinter():
    """Accumulate per-step losses and print their running average periodically.

    The reporting window is ``n_iterations // n`` steps (at least one step), so
    roughly ``n`` lines are printed over ``n_iterations`` calls.

    Args:
        n_iterations: Total number of training steps that will be reported.
        n: Desired number of progress lines over the whole run.

    Raises:
        ValueError: If ``n`` is not a positive number.
    """

    def __init__(self, n_iterations, n):
        if n <= 0:
            raise ValueError("n must be a positive number of progress reports")
        self.n_iterations = n_iterations
        self.n = n
        self.i = 0
        self.cum_loss = 0
        # Clamp to 1: when n > n_iterations the integer division yields 0 and
        # the modulo in __call__ would raise ZeroDivisionError.
        self.threshold = max(1, self.n_iterations // self.n)

    def __call__(self, loss):
        """Record one loss value; print and reset the average at window end.

        Args:
            loss: Numeric loss of the step that just finished.
        """
        self.cum_loss += loss
        self.i += 1
        if self.i % self.threshold == 0:
            print(f"Iteration {self.i}\tloss {self.cum_loss / self.threshold}")
            self.cum_loss = 0

In [2]:
# Hyper-parameters read by the model-construction and training cells.
num_steps = 5000
batch_size = 256
learning_rate = 1e-4
n_hidden = 4

# Seed the global NumPy and PyTorch generators so that a Restart & Run All
# reproduces the same mini-batches (np.random.choice) and the same results
# from PyTorch's random number generator.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU when one is available; everything below follows `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the California housing data and keep features / targets as float32
# tensors on the chosen device.
data = fetch_california_housing()
X = torch.as_tensor(data["data"], dtype=torch.float32, device=device)
y = torch.as_tensor(data["target"], dtype=torch.float32, device=device)

In [24]:
# Build the SDN (imported from the local MDN module): its first argument is the
# size of X's last dimension, its second is n_hidden. Then move it to `device`.
model = SDN(X.shape[-1], n_hidden)
model = model.to(device)
# To resume from a previously saved state dict, uncomment the next line:
# model.load_state_dict(torch.load("run/sdn_state_dict.pt"))
# Adam over all model parameters with the configured learning rate.
optim = opt.Adam(params=model.parameters(), lr=learning_rate)

In [56]:
# Train the SDN on random mini-batches, then save its state dict.
# Create the checkpoint folder up front: otherwise a missing "run/" directory
# would only surface in torch.save, after all training steps have been spent.
os.makedirs("run", exist_ok=True)
# Make the training mode explicit (a no-op for a freshly constructed module).
model.train()
for i in range(num_steps):
    # batch_size row indices drawn uniformly at random; np.random.choice
    # samples with replacement by default, so a row can repeat within a batch.
    batch_idx = np.random.choice(X.shape[0], batch_size)
    dist = model(X[batch_idx])
    loss = distribution_log_prob_loss(dist, y[batch_idx])
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Print the current mini-batch loss every 1000 steps.
    if i % 1000 == 0:
        print(loss.item())
torch.save(model.state_dict(), "run/sdn_state_dict.pt")

0.9008913040161133
0.8253390789031982
1.0706696510314941
0.8600386381149292
0.8180643320083618


In [3]:
# Build the MDN (third constructor argument is 3 -- presumably the number of
# mixture components; confirm against the MDN module) and move it to `device`.
model = MDN(X.shape[-1], n_hidden, 3).to(device)

# Warm-start from an earlier run only when its checkpoint exists. The file is
# written by the MDN training cell further down, so an unconditional load
# makes a fresh Restart & Run All fail before any training has happened.
mdn_checkpoint = "run/mdn_state_dict.pt"
if os.path.exists(mdn_checkpoint):
    # Map the stored tensors straight onto the device used everywhere else.
    model.load_state_dict(torch.load(mdn_checkpoint, map_location=device))
else:
    print(f"No checkpoint found at {mdn_checkpoint}; starting from freshly initialised weights")
optim = opt.Adam(model.parameters(), lr=learning_rate)

In [20]:
# Train the MDN on random mini-batches, reporting the average loss five times
# over the run, then save its state dict.
f = 1  # multiplier on num_steps for longer runs
total_steps = num_steps * f
# Create the checkpoint folder up front so a missing "run/" directory cannot
# make torch.save fail after all training steps have been spent.
os.makedirs("run", exist_ok=True)
# The selection cell further down calls model.eval(); switch back explicitly so
# re-running this cell afterwards does not train a model left in eval mode.
model.train()
lp = LossPrinter(total_steps, 5)
for i in range(total_steps):
    # batch_size row indices drawn uniformly at random; np.random.choice
    # samples with replacement by default, so a row can repeat within a batch.
    batch_idx = np.random.choice(X.shape[0], batch_size)
    pi, dist = model(X[batch_idx])
    # y[batch_idx, None] appends a trailing axis of size 1 to the targets.
    loss = distribution_log_prob_loss(dist, y[batch_idx, None], pi)
    optim.zero_grad()
    loss.backward()
    optim.step()
    lp(loss.item())
torch.save(model.state_dict(), "run/mdn_state_dict.pt")

torch.Size([256, 3])
torch.Size([256])


In [6]:
def mixture_normal_cdf(pi, dist, value):
    """Evaluate the CDF of a mixture as the weighted sum of component CDFs.

    Args:
        pi: Tensor of mixture weights; its last dimension indexes components.
        dist: Distribution object whose ``cdf`` result broadcasts against
            ``pi`` (e.g. a ``Normal`` holding one entry per component).
        value: Point(s) at which to evaluate the CDF. May be a tensor that
            broadcasts against the component dimension (for example shape
            ``(N, 1)``), or a plain Python number.

    Returns:
        Tensor of mixture CDF values with the component dimension summed out.
    """
    if torch.is_tensor(value):
        # Keep dtype and shape as given; only make sure the value lives on the
        # same device as the weights (no-op when it already does).
        value = value.to(pi.device)
    else:
        # Distribution.cdf validates that its argument is a tensor, so wrap
        # Python scalars / sequences using the weights' dtype and device.
        value = torch.as_tensor(value, dtype=pi.dtype, device=pi.device)
    return torch.sum(pi * dist.cdf(value), dim=-1)

def function_with_no_name(pi, dist, z=1.2815):
    """Weighted average of each component's ``mean + z * stddev``.

    Args:
        pi: Tensor of mixture weights; its last dimension indexes components.
        dist: Distribution object exposing ``mean`` and ``stddev`` that
            broadcast against ``pi``.
        z: Multiplier applied to the standard deviation. The default 1.2815 is
            approximately the 0.9 quantile of the standard normal, so for
            normal components each term is that component's ~90th percentile.

    Returns:
        Tensor with the component (last) dimension summed out.

    Note:
        A weighted average of component percentiles is in general NOT the
        percentile of the mixture itself; treat the result as a heuristic.
    """
    # Reduce over the last axis (same convention as mixture_normal_cdf) so
    # that 1-D, unbatched mixtures work as well as (batch, components) inputs.
    return torch.sum(pi * (dist.mean + z * dist.stddev), dim=-1)

class TensorDataset(Dataset):
    """Map-style dataset pairing rows of ``data`` with entries of ``target``.

    Note that this notebook-local class is distinct from
    ``torch.utils.data.TensorDataset``; it can additionally return the row
    index with every sample.

    Args:
        data: Indexable container with a ``shape``; its first dimension
            enumerates samples.
        target: Indexable container with the same first-dimension length.
        return_index: When true, ``__getitem__`` returns
            ``(index, data_row, target_row)`` instead of
            ``(data_row, target_row)``.

    Raises:
        ValueError: If ``data`` and ``target`` differ in first-dimension length.
    """

    def __init__(self, data, target, return_index=False):
        if data.shape[0] != target.shape[0]:
            raise ValueError("Data and Target must have the same length")
        self.data = data
        self.target = target
        self.return_index = return_index

    def __len__(self):
        """Number of samples (size of the first dimension of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the sample at ``idx``; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        size = len(self)
        requested = idx
        if idx < 0:
            idx = size + idx
        # Without this check an index below -len would still be negative after
        # the shift and silently wrap around to a different, wrong sample.
        if not 0 <= idx < size:
            raise IndexError(f"index {requested} is out of range for dataset of length {size}")
        if self.return_index:
            return idx, self.data[idx], self.target[idx]
        return self.data[idx], self.target[idx]

# Collect the row indices of every house for which the mixture model assigns
# more than `min_probability` to the target being at or below `threshold`,
# i.e. mixture CDF(threshold) > min_probability.
dataloader = DataLoader(TensorDataset(X, y, return_index=True), batch_size=batch_size, shuffle=False)
model.eval()
# Floating-point threshold created on the model's device: a CPU tensor cannot
# be combined with CUDA-resident distribution parameters.
threshold = torch.as_tensor([2.0], device=device).reshape(-1, 1)
min_probability = 0.9
possible_houses = []
with torch.no_grad():
    for idx_batched, X_batched, y_batched in dataloader:
        # SDN variant of the same selection, kept for reference:
        # dist = model(X_batched)
        # possible_houses.append(idx_batched[dist.mean + 1.28*dist.stddev <= threshold])
        pi, dist = model(X_batched)
        prob_at_or_below = mixture_normal_cdf(pi, dist, threshold)
        # The integer indices are collated into a CPU tensor by the DataLoader,
        # so the boolean mask must be moved to the CPU before indexing.
        selected = (prob_at_or_below > min_probability).cpu()
        possible_houses.append(idx_batched[selected])
# Concatenate the per-batch index tensors into one 1-D tensor.
possible_houses = torch.hstack(possible_houses)

In [8]:
# Show the selected row indices, then the observed targets of ten of them,
# starting at offset `idx` within the selection.
idx = 0
window = slice(idx, idx + 10)
print(possible_houses)
selected_targets = y[possible_houses]
print(selected_targets[window])

tensor([    8,    14,    15,  ..., 20637, 20638, 20639])
tensor([2.2670, 1.5920, 1.4000, 1.5870, 1.4750, 1.5980, 1.1390, 1.0750, 1.0550,
        1.0890])


In [11]:
# Small hand-made mixture for sanity-checking the helpers: three normal
# components with equal weights.
mu = torch.tensor([1, 5, 9], dtype=torch.float32)
std = torch.tensor([1, 1.5, 3], dtype=torch.float32)
w = torch.full_like(mu, 1. / 3)
n = Normal(loc=mu, scale=std)
# Scratch check kept for reference (each component's CDF at mean + 1.2815 * std):
# n.cdf(torch.as_tensor(mu + 1.2815 * std))

In [20]:
# Weight-averaged (mean + 1.2815 * stddev) over the components of the toy mixture.
foo = torch.sum(w * (n.mean + 1.2815 * n.stddev))

In [34]:
# Mixture CDF of the toy mixture at 9, 10, 11 and 12. The query points form a
# (4, 1) column so each one broadcasts against the three components.
query_points = torch.as_tensor([9, 10, 11, 12]).unsqueeze(-1)
mixture_normal_cdf(w, n, query_points)

ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([1, 4]) vs torch.Size([3]).