In [1]:
"""
By Zhewen Hou
"""
import numpy as np
import torch
from torch import Tensor
from torch.distributions.normal import Normal
from scipy.stats import norm
from sklearn import *
import time
import sys

In [2]:
# numpy implementation
# def dgmm1d(x, mu, sigma, pi):
#    pdf_gmm = np.sum([pi[k] * norm.pdf(x, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    return pdf_gmm

# pdf of a 1D Gaussian mixture
def dgmm1d(x: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """Probability density of a 1-D Gaussian mixture evaluated at ``x``.

    :param x: points at which the density is evaluated.
    :param mu: component means, one entry per component along dim 0.
    :param sigma: component standard deviations, indexed like ``mu``.
    :param pi: mixing proportions, indexed like ``mu``.
    :return: sum over components k of pi[k] * N(x; mu[k], sigma[k]),
        with the broadcast shape of ``x`` and a single component's parameters.
    """
    terms = []
    for idx in range(len(mu)):
        component = Normal(mu[idx], sigma[idx])
        # the density is obtained by exponentiating the log-density
        terms.append(pi[idx] * component.log_prob(x).exp())
    return torch.stack(terms).sum(dim=0)

In [3]:
# numpy implementation
# def pgmm1d(x, mu, sigma, pi):
#    cdf_gmm = np.sum([pi[k] * norm.cdf(x, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    return cdf_gmm

# cumulative distribution function of a 1D Gaussian mixture
def pgmm1d(x: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """CDF of a 1-D Gaussian mixture evaluated at ``x``.

    :param x: points at which the CDF is evaluated.
    :param mu: component means, one entry per component along dim 0.
    :param sigma: component standard deviations, indexed like ``mu``.
    :param pi: mixing proportions, indexed like ``mu``.
    :return: sum over components k of pi[k] * Phi((x - mu[k]) / sigma[k]),
        with the broadcast shape of ``x`` and a single component's parameters.
    """
    n_components = len(mu)
    weighted_cdfs = torch.stack(
        [pi[j] * Normal(mu[j], sigma[j]).cdf(x) for j in range(n_components)]
    )
    return weighted_cdfs.sum(dim=0)

In [4]:
# numpy implementation: quantile of 1D Gaussian Mixture
#def qgmm1d(q, mu, sigma, pi):
#    ppf_full = np.array([norm.ppf(q, loc=mu[k], scale=sigma[k]) for k in range(len(mu))]).flatten()
#    ppf_full.sort()
#    cdf_gmm = np.sum([pi[k] * norm.cdf(ppf_full, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    ## 1D linear interpolation
#    ppf_gmm = np.interp(q, cdf_gmm, ppf_full)
#    return ppf_gmm

# batched, differentiable piecewise-linear interpolation (PyTorch counterpart of np.interp)
def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
    """Piecewise-linear interpolation, row by row.

    Row ``i`` of ``x`` is interpolated against the breakpoints ``(xp[i], fp[i])``,
    much like ``np.interp(x[i], xp[i], fp[i])``, but batched and differentiable.

    Differences from ``np.interp``:
      * queries below ``xp[:, 0]`` or above ``xp[:, -1]`` are extrapolated
        linearly from the first / last segment instead of being clamped to
        ``fp[:, 0]`` / ``fp[:, -1]``;
      * zero-width segments (repeated ``xp`` values) get slope 0, so neither
        the result nor the gradients contain inf/NaN.

    :param x: (B, Q) query points.
    :param xp: (B, P) breakpoint x-coordinates, P >= 2; assumed non-decreasing
        along dim 1 (the segment lookup counts breakpoints <= x).
    :param fp: (B, P) breakpoint values.
    :return: (B, Q) interpolated values.
    """
    dx = xp[:, 1:] - xp[:, :-1]
    dy = fp[:, 1:] - fp[:, :-1]
    flat = dx == 0
    # Divide by a safe denominator and mask the result. Masking only dy / dx
    # would still send NaN gradients (0 * inf) back through the division, even
    # for segments that are never selected below.
    safe_dx = torch.where(flat, torch.ones_like(dx), dx)
    m = torch.where(flat, torch.zeros_like(dy), dy / safe_dx)
    b = fp[:, :-1] - m * xp[:, :-1]

    # Segment of each query = index of the last breakpoint <= x, clamped so
    # out-of-range queries use the first / last segment.
    indices = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), dim=2) - 1
    indices = torch.clamp(indices, 0, m.size(1) - 1)

    return torch.gather(m, 1, indices) * x + torch.gather(b, 1, indices)


# quantile function (inverse CDF) of a batch of 1D Gaussian mixtures
def qgmm1d(q: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """Quantile function of a batch of 1-D Gaussian mixtures.

    The mixture CDF is evaluated on a grid made of every component's own
    quantiles at the levels ``q``. That CDF is then inverted by linear
    interpolation.

    :param q: (m,) quantile levels.
    :param mu: (n, k) component means, one mixture per row.
    :param sigma: (n, k) component standard deviations.
    :param pi: (n, k) mixing proportions.
    :return: (n, m) mixture quantiles. ``squeeze()`` is applied, so size-1
        dimensions are dropped (e.g. the result is (m,) when n == 1).
    """
    n_mix, n_comp = mu.size(0), mu.size(1)
    n_lev = q.size(0)
    # (m, n, k): quantile of every component of every mixture at every level
    levels = q.view(n_lev, 1, 1).expand(n_lev, n_mix, n_comp)
    comp_quantiles = Normal(mu, sigma).icdf(levels)
    # (n, m*k): all candidate grid points of each mixture, sorted ascending
    grid = comp_quantiles.permute(1, 0, 2).reshape(n_mix, n_lev * n_comp)
    grid = grid.sort(dim=1).values
    # (n, m*k): mixture CDF sum_k pi_k * Phi((x - mu_k) / sigma_k) at each grid point;
    # non-decreasing along dim 1 because the grid is sorted and a CDF is monotone
    comp_cdfs = Normal(mu.unsqueeze(1), sigma.unsqueeze(1)).cdf(grid.unsqueeze(2))
    mix_cdf = (pi.unsqueeze(1) * comp_cdfs).sum(dim=2)
    # invert the CDF: interpolate the grid, as a function of mix_cdf, at the levels q
    return interp(q.expand(n_mix, -1), mix_cdf, grid).squeeze()

In [5]:
def batch_dWasserstein(Y: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor, q: Tensor, p: int = 2) -> Tensor:
    """p-Wasserstein distance between each GMM and its empirical quantile function.

    The distance is approximated on the grid of quantile levels ``q`` as
    ``(mean_j |Q_gmm(q_j) - Y[:, j]| ** p) ** (1 / p)``.

    :param Y: (n, m) empirical quantiles at the levels in ``q`` (anything
        broadcastable against (n, m) also works).
    :param mu: (n, k) component means of the GMMs.
    :param sigma: (n, k) component standard deviations of the GMMs.
    :param pi: (n, k) mixing proportions of the GMMs.
    :param q: (m,) quantile levels.
    :param p: order of the Wasserstein distance.
    :return: (n,) distances, one per GMM.
    """
    # qgmm1d squeezes its (n, m) result, which drops a dimension when n == 1 or
    # m == 1. Restore (n, m) so the subtraction below can never broadcast
    # (n,) against (n, 1) into an (n, n) matrix.
    ppf_gmm = qgmm1d(q, mu, sigma, pi).reshape(mu.size(0), -1)
    # |.| ** p rather than (.) ** p: for odd p the signed differences would give
    # a wrong (p=1) or NaN (p=3) distance. For even p the values are unchanged.
    Wp = torch.mean(torch.abs(ppf_gmm - Y) ** p, dim=1) ** (1 / p)
    return Wp


# main functions for fitting Wasserstein mixture regression
def WDL(Y, q_vec, mu, sd, pi):
    """
    Wasserstein loss for the GMM regression: the squared 2-Wasserstein distance
    between each empirical quantile function and the quantile function of its
    GMM (approximated on the grid q_vec), averaged over the n distributions.

    n: number of observations
    m: number of percentiles
    k: number of components in the GMM
    :param Y: n x m, empirical quantiles at the levels in q_vec
    :param q_vec: m, quantile levels
    :param mu: n x k, mean of the components in the GMM
    :param sd: n x k, standard deviation of the components in the GMM
    :param pi: n x k, mixing proportion of the components in the GMM
    :return: scalar tensor, mean over the n distributions of the squared 2-Wasserstein distance
    """
    # Same value as torch.mean(batch_dWasserstein(Y, mu, sd, pi, q_vec) ** 2),
    # but without the sqrt-then-square round trip. That round trip makes the
    # gradient NaN (0 * inf) whenever a distance is exactly 0, and it also
    # loses a little precision.
    # qgmm1d squeezes size-1 dims (n == 1 or m == 1); restore the (n, m) layout.
    ppf_gmm = qgmm1d(q_vec, mu, sd, pi).reshape(mu.size(0), -1)
    w2_sq = torch.mean((ppf_gmm - Y) ** 2, dim=1)  # per-distribution squared W2
    return torch.mean(w2_sq)

#### Testing Below

In [6]:
# Data simulation. This cell is disabled: it is a string literal, so nothing in it runs.
# When enabled, it draws covariates X and, for each distribution, samples n_sample
# points from a 2-component GMM whose means, sds and weights depend on that row
# of X. It stores the sorted samples in Y and converts each row into empirical
# quantiles Q_mat at the levels q_vec.
'''
# set parameters
SEED = 234234
K = 2 # number of components in the GMM
n_dist = 300 ## number of distributions
omega = 0.1 # default 0.1
n_sample = 300

# simulate the data
np.random.seed(SEED)
torch.random.manual_seed(SEED)
X = np.random.uniform(size=(n_dist, 3)) * 2 - 1 # generate X
Y = np.zeros((n_dist, n_sample)) # generate Y


## simulate Y
for i in range(n_dist):
    mu_1 = X[i, 0]
    mu_2 = 2 * X[i, 1]**2 + 2
    mu_true = [mu_1, mu_2]
    #print('mu:', mu_true)
    sig_1 = np.abs(X[i, 1]) + 0.5
    sig_2 = np.abs(X[i, 0]) + 0.5
    sig_true = [sig_1, sig_2]
    #print('sig:', sig_true)
    pi_1 = 1 / (1 + np.exp(X[i, 2]))
    pi_true = [pi_1, 1-pi_1]
    #print('pi:', pi_true)
    ## simulate noise
    eps_noise = np.random.normal(loc=0, scale=omega, size=1)
    ## simulate responses
    var_gaussian = np.array([np.random.normal(loc=mu_true[k]+eps_noise,
                                              scale=sig_true[k],
                                              size=n_sample) for k in range(K)]).T
    var_mult = np.random.choice(range(K), size=n_sample, replace=True, p=pi_true)
    var_mult = np.eye(K)[var_mult]
    var_GMM = np.sum(var_mult * var_gaussian, axis=1)
    Y[i] = np.sort(var_GMM)
    
print(Y.shape)
print(Y)

# prepare for model fitting
K_mix = 2
n_iter = 1000
lr = 1e-1
v_lr = np.array([1] + [lr] * n_iter)
n_dist = Y.shape[0]
n_levs = 100
q_vec = torch.arange(1, n_levs, dtype=torch.float64) / n_levs # quantile levels
print('q_vec', q_vec)
## transform Y
Q_mat = np.array([np.quantile(Y[i], q_vec) for i in range(n_dist)]) # calculate the empirical quantile function
'''

(300, 300)
[[-3.09883251 -2.86643922 -2.51388866 ...  4.31112254  4.3739572
   5.20220006]
 [-2.91661178 -2.65928491 -2.44640414 ...  4.53591171  4.56868362
   4.62578348]
 [-2.4340326  -2.40057418 -2.28007935 ...  5.02492275  5.16380261
   5.78538448]
 ...
 [-2.46941656 -2.20831643 -2.10139391 ...  5.88877085  6.20330087
   6.4078245 ]
 [-2.84557891 -2.6724476  -2.48777096 ...  4.07628841  4.22882558
   4.26913273]
 [-3.17172267 -3.08769588 -3.08434324 ...  4.32627668  4.42752034
   4.526383  ]]
q_vec tensor([0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900,
        0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800,
        0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600, 0.2700,
        0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500, 0.3600,
        0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500,
        0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300, 0.5400,
        0.55

In [7]:
# Step 1. This cell is disabled: it is a string literal, so nothing in it runs.
# It builds the train/validation split, the true GMM parameters of the training
# rows, and the 9-column feature map (X, |X|, X**2). It then times WDL on 100
# random parameter draws. The last draw leaves alpha_matrix / mu_matrix /
# sd_matrix defined, and those are the tensors the step-2 cell optimizes.
'''
# step 1. train-val split
n_train = int(0.8 * n_dist)
loc_train = np.random.choice(n_dist, n_train, replace=False)
loc_val = np.setdiff1d(np.arange(n_dist), loc_train)


X_train = torch.tensor(X[loc_train])
# generate true parameters
mu_true = torch.cat((X_train[:, 0].unsqueeze(dim=1), (2 * X_train[:, 1] ** 2 + 2).unsqueeze(dim=1)), dim=1)
sd_true = torch.cat(
    (torch.abs(X_train[:, 1]).unsqueeze(dim=1) + 0.5, torch.abs(X_train[:, 0]).unsqueeze(dim=1) + 0.5), dim=1)
alpha_true = torch.cat((torch.zeros_like(X_train[:, 2].unsqueeze(dim=1)), X_train[:, 2].unsqueeze(dim=1)), dim=1)
pi_true = torch.softmax(alpha_true, dim=1)

# generate features
X_train = torch.cat((X_train, torch.abs(X_train), X_train ** 2), 1)
y_train = torch.tensor(Q_mat[loc_train])
X_val = torch.tensor(X[loc_val])
X_val = torch.cat((X_val, torch.abs(X_val), X_val ** 2), 1)
y_val = torch.tensor(Q_mat[loc_val])


# test the speed of the loss function
start = time.time()

for i in range(100):
    alpha_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)
    mu_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)
    sd_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)

    alpha = torch.matmul(X_train, alpha_matrix)
    pi = torch.softmax(alpha, dim=1)

    mu = torch.matmul(X_train, mu_matrix)
    
    sd = torch.abs(torch.matmul(X_train, sd_matrix))

    loss = WDL(y_train, q_vec, mu, sd, pi)

    print(loss)
print(time.time() - start)
'''

tensor(4.6643, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(1.4755, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(4.3608, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(2.8812, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(6.0676, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(4.1523, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(2.8222, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(5.4872, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(3.2802, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(4.1873, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(3.5498, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(2.6493, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(6.3114, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(5.0874, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(12.4420, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(3.9997, dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(3.9931, dtype=to

In [8]:
# Step 2. This cell is disabled: it is a string literal, so nothing in it runs.
# It fits the coefficient matrices that map the features to the GMM parameters
# (pi = softmax(X @ alpha_matrix), mu = X @ mu_matrix,
# sd = |abs(X) @ sd_matrix|). Training minimizes WDL with SGD (lr 0.6, StepLR
# step_size 170, gamma 0.84) for 1000 iterations and needs the variables from step 1.
'''
# step 2. train the model
# initialize the optimizer
optimizer = torch.optim.SGD([alpha_matrix, mu_matrix, sd_matrix], lr=0.6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=170, gamma=0.84)

# train the model
start = time.time()
for k in range(1000):
    """alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
    mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))"""

    alpha = torch.matmul(X_train, alpha_matrix)
    mu = torch.matmul(X_train, mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))

    pi = torch.softmax(alpha, dim=1)

    optimizer.zero_grad()
    loss = WDL(y_train, q_vec, mu, sd, pi)
    print(k, loss.item())
    loss.backward()
    optimizer.step()
    scheduler.step()
end = time.time()
print(end - start)

print(mu)
'''

0 7.680827736958281
1 1.4953733142883936
2 0.8903266381366824
3 0.7522665855918405
4 0.6710365909077746
5 0.6144841417018095
6 0.5742080114865169
7 0.5443134943569208
8 0.5219079492396036
9 0.5045083459416578
10 0.490555600545359
11 0.47903784840924074
12 0.4693277394143237
13 0.4608812965751459
14 0.453397342155681
15 0.44663065216368936
16 0.4404143981473953
17 0.4346335542102026
18 0.42921066264910884
19 0.424084975169027
20 0.41921745318801584
21 0.41458188412818564
22 0.41014186086025883
23 0.405878811238565
24 0.40177200164886845
25 0.39781512107280065
26 0.3939923916591563
27 0.39030338083712146
28 0.38673767590730523
29 0.38329830855231933
30 0.3799810842033363
31 0.3767588206190716
32 0.37363722823770573
33 0.3706204708419906
34 0.3677066754209915
35 0.36489828297820875
36 0.3621896742750117
37 0.3595685966125013
38 0.3570516113984413
39 0.35464455808546586
40 0.35234502680068236
41 0.3501622599364519
42 0.34809768744946507
43 0.34618489038970174
44 0.34441622281541406
45 0.34

355 0.1895183329727786
356 0.18932277248442636
357 0.1891239648128644
358 0.18890705646461
359 0.18869572204945897
360 0.1884817137626405
361 0.18825945111999248
362 0.18803898828545937
363 0.1878180477331018
364 0.18759639299113318
365 0.18737389100565593
366 0.18715061743589673
367 0.18692257502887571
368 0.18667982650662138
369 0.1864419886113408
370 0.18620252313927022
371 0.18596063459579604
372 0.18571522595053586
373 0.1854631545454047
374 0.18521378188009266
375 0.18496209171057476
376 0.18469232580371026
377 0.18442817421071153
378 0.18416588440951343
379 0.18390497350955842
380 0.18364538468716846
381 0.18338721231894012
382 0.1831281241410145
383 0.18286776658359682
384 0.18260457553256795
385 0.1823319664266481
386 0.1820439525760095
387 0.18175966261025897
388 0.1814761731427011
389 0.18119618630688966
390 0.18091866690885877
391 0.18064198274696228
392 0.18036531010903736
393 0.18008949535838065
394 0.17981528089348575
395 0.17954134090304905
396 0.17926187643906044
397 0

701 0.14076283585374053
702 0.14070845357320433
703 0.14065416292766345
704 0.14059991886335435
705 0.1405457571912446
706 0.1404917667949744
707 0.14043813593466903
708 0.14038472753981715
709 0.14033138687802454
710 0.14027801461877626
711 0.14022466997215668
712 0.1401713555989042
713 0.14011797624292732
714 0.14006467793536231
715 0.14001140843054688
716 0.13995808369366547
717 0.13990481086785056
718 0.13985165947932182
719 0.13979857259447445
720 0.13974560955923582
721 0.13969275593755032
722 0.13964003534842936
723 0.13958737314164624
724 0.13953476435132775
725 0.13948233847305339
726 0.1394299938468571
727 0.13937777021793835
728 0.13932559034420933
729 0.1392733562000332
730 0.13922099094817855
731 0.13916849173486026
732 0.13911606504643317
733 0.13906371329146752
734 0.13901200227635419
735 0.13896040199451723
736 0.1389089227859938
737 0.1388575644135198
738 0.1388063202951665
739 0.1387552767467167
740 0.13870432345276953
741 0.13865347263659433
742 0.13860266212387504
7

KeyboardInterrupt: 

In [9]:
#print(X_train.shape)

torch.Size([240, 9])


In [11]:
# Step 3. This cell is disabled: it is a string literal, so nothing in it runs.
# It recomputes the fitted GMM parameters on the training set, then prints their
# deviation from the true parameters and the final WDL loss.
'''
# step 3. evaluate the model
alpha = torch.matmul(X_train, alpha_matrix)
mu = torch.matmul(X_train, mu_matrix)
sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))
pi = torch.softmax(alpha, dim=1)

# calculate the loss
loss = WDL(y_train, q_vec, mu, sd, pi)
# calculate the difference between estimators and real values
print((mu - mu_true).detach().numpy())
print((sd - sd_true).detach().numpy())
print((pi - pi_true).detach().numpy())
print(loss.item())
print()
'''

[[ 6.66676478e-01 -2.84067649e-01]
 [ 2.96840067e-01 -1.44553624e+00]
 [ 1.38995844e+00  5.06455550e-01]
 [ 1.12654493e+00  4.77331162e-01]
 [ 9.64291138e-01  1.58315217e-01]
 [ 1.81584360e-01  9.39345539e-01]
 [ 1.25845916e+00  2.35090734e-01]
 [ 5.91216707e-01  7.41017517e-01]
 [ 3.45591939e-01  1.25174015e+00]
 [ 2.51078522e-01 -5.05201003e-01]
 [-1.21620893e-01  8.20351780e-01]
 [ 7.35562548e-01  8.44545562e-01]
 [ 4.86298687e-01  1.71218464e+00]
 [ 4.67980513e-01  1.39678620e+00]
 [ 5.06183345e-01  8.41167563e-01]
 [ 2.36277541e-01  7.46228737e-01]
 [ 4.38339687e-01  1.17203483e-01]
 [ 6.32575823e-01  1.34099146e+00]
 [ 2.30677905e-01  1.51226974e-01]
 [ 9.95972070e-01  1.62690942e-01]
 [ 7.01800142e-01 -8.03170859e-02]
 [ 6.60889340e-01 -5.47164647e-01]
 [ 1.14072911e+00  8.68564505e-01]
 [ 1.01782373e+00  7.42340494e-01]
 [ 4.54636171e-01  8.62823627e-01]
 [ 4.08696982e-01  6.08927207e-01]
 [ 1.08478362e+00  5.49710235e-01]
 [ 1.07833246e+00  2.42432466e-01]
 [ 6.86554632e-01  6

In [10]:
# Disabled (string literal): 500 further training iterations. They reuse the
# step-2 optimizer and scheduler, with the learning rate overridden to 0.5.
# Inside the loop, the first three assignments (feature-subset models) are
# immediately overwritten by the full-feature versions, so only the latter
# take effect. The second string keeps the feature-subset variant on its own.
"""optimizer.param_groups[0]['lr'] = 0.5
for k in range(500):
    alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
    mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))

    alpha = torch.matmul(X_train, alpha_matrix)
    mu = torch.matmul(X_train, mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))

    pi = torch.softmax(alpha, dim=1)

    optimizer.zero_grad()
    loss = WDL(y_train, q_vec, mu, sd, pi)
    print(k, loss.item())
    loss.backward()
    optimizer.step()
    scheduler.step()"""


"""alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))"""
# WDL_example.txt
# Displaying WDL_example.txt.

SyntaxError: invalid syntax (1710302979.py, line 26)