In [1]:
from maf import *

In [2]:
# Command-line style configuration for the experiment. Every option below ends
# up as an attribute on the parsed namespace used by the rest of the notebook.
parser = argparse.ArgumentParser()
# action
parser.add_argument('--train', action='store_true', help='Train a flow.')
parser.add_argument('--evaluate', action='store_true', help='Evaluate a flow.')
parser.add_argument('--restore_file', type=str, help='Path to model to restore.')
parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
parser.add_argument('--data_dir', default='./data/', help='Location of datasets.')
parser.add_argument('--output_dir', default='./results/maf_mnist')
parser.add_argument('--results_file', default='results.txt', help='Filename where to store settings and test results.')
parser.add_argument('--no_cuda', action='store_true', help='Do not use cuda.')
# data
parser.add_argument('--dataset', default='toy', help='Which dataset to use.')
parser.add_argument('--flip_toy_var_order', action='store_true', help='Whether to flip the toy dataset variable order to (x2, x1).')
parser.add_argument('--seed', type=int, default=1, help='Random seed to use.')
# model
parser.add_argument('--model', default='maf', help='Which model to use: made, maf.')
# made parameters
parser.add_argument('--n_blocks', type=int, default=5, help='Number of blocks to stack in a model (MADE in MAF; Coupling+BN in RealNVP).')
parser.add_argument('--n_components', type=int, default=1, help='Number of Gaussian clusters for mixture of gaussians models.')
parser.add_argument('--hidden_size', type=int, default=100, help='Hidden layer size for MADE (and each MADE block in an MAF).')
parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.')
parser.add_argument('--activation_fn', type=str, default='relu', help='What activation function to use in the MADEs.')
parser.add_argument('--input_order', type=str, default='sequential', help='What input order to use (sequential | random).')
parser.add_argument('--conditional', default=False, action='store_true', help='Whether to use a conditional model.')
parser.add_argument('--no_batch_norm', action='store_true')
# training params
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--n_epochs', type=int, default=50)
# type=int added: without it a value given on the command line is parsed as a
# string, while the default stays an int, so the attribute's type depended on
# whether the flag was passed.
parser.add_argument('--start_epoch', type=int, default=0, help='Starting epoch (for logging; to be overwritten when restoring file.')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
parser.add_argument('--log_interval', type=int, default=1000, help='How often to show loss statistics and save samples.')


_StoreAction(option_strings=['--log_interval'], dest='log_interval', nargs=None, const=None, default=1000, type=<class 'int'>, choices=None, help='How often to show loss statistics and save samples.', metavar=None)

In [3]:
# Emulate the command line inside the notebook: train a MAF on MNIST for 50 epochs.
args = parser.parse_args([
    "--train",
    "--model=maf",
    "--dataset=MNIST",
    "--n_epochs=50",
])


In [4]:
# Seed torch's RNG from the parsed seed.
torch.manual_seed(args.seed)
# The device is pinned to CPU here; the --no_cuda flag is not consulted.
args.device = torch.device('cpu')

if args.conditional:
    assert args.dataset in ['MNIST', 'CIFAR10'], 'Conditional inputs only available for labeled datasets MNIST and CIFAR10.'

train_dataloader, test_dataloader = fetch_dataloaders(args.dataset, args.batch_size, args.device, args.flip_toy_var_order)

# Copy dataset metadata onto args so later cells can read it from one place.
_train_set = train_dataloader.dataset
args.input_size = _train_set.input_size
args.input_dims = _train_set.input_dims
if args.conditional:
    args.cond_label_size = _train_set.label_size
else:
    args.cond_label_size = None

In [5]:
# Build the MAF from the parsed hyper-parameters; batch norm is enabled unless
# --no_batch_norm was given.
model = MAF(args.n_blocks, args.input_size, args.hidden_size, args.n_hidden, args.cond_label_size,
                    args.activation_fn, args.input_order, batch_norm=not args.no_batch_norm)

In [6]:
# Move the model to the configured device, then create the optimizer over its
# parameters (Adam with the parsed learning rate and a fixed weight decay of 1e-6).
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

In [7]:
# Resolve the results file inside the output directory. The guard makes the
# cell safe to re-run: once results_file already lives under output_dir it is
# not prefixed a second time.
_output_dir_abs = os.path.abspath(args.output_dir)
if not os.path.abspath(args.results_file).startswith(_output_dir_abs + os.sep):
    args.results_file = os.path.join(args.output_dir, args.results_file)

# Appending fails if the directory is missing, so create it up front.
os.makedirs(os.path.dirname(args.results_file) or '.', exist_ok=True)

_settings_report = pprint.pformat(args.__dict__)

print('Loaded settings and model:')
print(_settings_report)
print(model)

# Append the same report to the results file; the context manager guarantees
# the handle is flushed and closed (the previous version leaked two handles).
with open(args.results_file, 'a') as _results_fh:
    print(_settings_report, file=_results_fh)
    print(model, file=_results_fh)

Loaded settings and model:
{'activation_fn': 'relu',
 'batch_size': 100,
 'cond_label_size': None,
 'conditional': False,
 'data_dir': './data/',
 'dataset': 'MNIST',
 'device': device(type='cpu'),
 'evaluate': False,
 'flip_toy_var_order': False,
 'generate': False,
 'hidden_size': 100,
 'input_dims': (1, 28, 28),
 'input_order': 'sequential',
 'input_size': 784,
 'log_interval': 1000,
 'lr': 0.0001,
 'model': 'maf',
 'n_blocks': 5,
 'n_components': 1,
 'n_epochs': 50,
 'n_hidden': 1,
 'no_batch_norm': False,
 'no_cuda': False,
 'output_dir': './results/maf_mnist',
 'restore_file': None,
 'results_file': './results/maf_mnist/results.txt',
 'seed': 1,
 'start_epoch': 0,
 'train': True}
MAF(
  (net): FlowSequential(
    (0): MADE(
      (net_input): MaskedLinear(in_features=784, out_features=100, bias=True)
      (net): Sequential(
        (0): ReLU()
        (1): MaskedLinear(in_features=100, out_features=100, bias=True)
        (2): ReLU()
        (3): MaskedLinear(in_features=100, ou

In [8]:
# Run the training/evaluation routine imported from `maf`; it prints the
# per-epoch loss and evaluation log-probability shown in the output below.
train_and_evaluate(model, train_dataloader, test_dataloader, optimizer, args)

epoch   0 / 50, step    0 / 600; loss 1689.1567
Evaluate (epoch 0) -- logp(x) = -1662.797 +/- 2.399
epoch   1 / 50, step    0 / 600; loss 1662.8104
Evaluate (epoch 1) -- logp(x) = -1647.280 +/- 2.587
epoch   2 / 50, step    0 / 600; loss 1610.2786
Evaluate (epoch 2) -- logp(x) = -1633.524 +/- 2.519
epoch   3 / 50, step    0 / 600; loss 1622.4014
Evaluate (epoch 3) -- logp(x) = -1625.806 +/- 2.490
epoch   4 / 50, step    0 / 600; loss 1603.7141
Evaluate (epoch 4) -- logp(x) = -1620.161 +/- 2.479
epoch   5 / 50, step    0 / 600; loss 1604.9921
Evaluate (epoch 5) -- logp(x) = -1615.881 +/- 2.658
epoch   6 / 50, step    0 / 600; loss 1627.3715
Evaluate (epoch 6) -- logp(x) = -1612.920 +/- 2.551
epoch   7 / 50, step    0 / 600; loss 1586.4985
Evaluate (epoch 7) -- logp(x) = -1610.415 +/- 2.596
epoch   8 / 50, step    0 / 600; loss 1598.1616
Evaluate (epoch 8) -- logp(x) = -1608.782 +/- 2.634
epoch   9 / 50, step    0 / 600; loss 1573.5225
Evaluate (epoch 9) -- logp(x) = -1607.058 +/- 2.655


In [9]:
# Call the `generate` helper from `maf` with the dataset's `lam` attribute.
generate(model, train_dataloader.dataset.lam, args)

In [10]:
import numpy as np
import pdb  # not used in this cell; kept in case other cells rely on it

# Push the whole training set through the flow and collect, per example, the
# latent code u (all_Z) and the log-abs-det-Jacobian term the model returns
# alongside it (all_J).
model.eval()

# Seed each list with an empty float64 array so the stacked result keeps the
# float64 dtype and the (0, input_size) shape even for an empty loader.
_latent_batches = [np.zeros((0, args.input_size))]
_logdet_batches = [np.zeros((0, args.input_size))]

# no_grad: only forward values are needed, so skip building autograd graphs.
with torch.no_grad():
    for data in train_dataloader:
        # The first element of a batch is the input; any label that follows is
        # not needed because the model is called on x alone.
        x = data[0]
        x = x.view(x.shape[0], -1).to(args.device)
        u, sum_log_abs_det_jacobians = model(x)
        # .cpu() keeps the conversion valid if the device is ever not the CPU.
        _latent_batches.append(u.cpu().numpy())
        _logdet_batches.append(sum_log_abs_det_jacobians.cpu().numpy())

# Stack once at the end; stacking inside the loop re-copied the growing array
# on every batch.
all_Z = np.vstack(_latent_batches)
all_J = np.vstack(_logdet_batches)

In [11]:
import seaborn as sns

In [12]:
# Sanity check: one row per collected example, one column per latent dimension.
all_Z.shape

(60000, 784)

In [13]:
# Persist the latent codes as comma-separated text in the working directory.
np.savetxt("Z.csv", all_Z, delimiter=',')

In [15]:
from sklearn.decomposition import PCA

In [17]:
# Project the latent codes onto their first two principal components for plotting.
pca = PCA(n_components=2)
X=pca.fit_transform(all_Z)

In [18]:
# Scatter plot of the two PCA components of the latent codes.
plt.scatter(X[:, 0], X[:, 1])
plt.show()

  plt.show()


In [19]:
# Switch matplotlib to the Tk GUI backend before plotting again.
import tkinter
import matplotlib
matplotlib.use('TkAgg')

In [20]:
# Same PCA scatter plot as before, redrawn after the backend switch.
plt.scatter(X[:, 0], X[:, 1])
plt.show()

In [25]:
# Per-dimension mean of the latent codes (averaged over examples).
np.mean(all_Z, axis=0)

array([-1.88765944e-02, -5.51762767e-02,  1.12748560e-02, -6.07067198e-02,
        3.23312868e-03,  2.22165896e-02, -9.21576362e-03, -1.26263284e-03,
        6.64300100e-03, -1.09491504e-02,  5.87760935e-02, -2.63461892e-02,
       -2.52269711e-04,  3.63298760e-02,  2.81959522e-02, -1.36317622e-02,
       -6.44017911e-03,  3.84784945e-02,  2.46643076e-02, -1.91733178e-02,
       -1.62476609e-02, -4.40315149e-02, -2.45038210e-02,  3.92330037e-02,
        3.61201974e-02, -2.28748634e-02, -3.17254403e-02,  1.65577713e-02,
       -2.62301030e-02,  6.57028353e-05,  3.15733209e-02,  1.69996257e-02,
        8.14984135e-03,  3.83423505e-02, -7.45518909e-02,  2.40184378e-03,
       -1.48085615e-02,  7.04247820e-03,  6.40777527e-03,  9.51270697e-03,
       -2.77288046e-02, -1.08414332e-03,  9.06339071e-03, -1.47761939e-02,
        1.14787243e-02,  2.43947661e-02,  2.31564527e-02, -9.03986044e-03,
       -3.64575656e-02, -1.79194306e-02,  4.06408540e-02,  1.61644810e-02,
        1.51516607e-02, -

In [27]:
# Per-dimension standard deviation of the latent codes (over examples).
np.std(all_Z, axis=0)

array([1.04976354, 1.00542878, 1.00180423, 1.01737874, 1.00549958,
       1.01039342, 1.01202564, 1.05135164, 0.99023927, 1.01122651,
       0.97200982, 1.04300541, 1.00258576, 1.01150667, 1.01926005,
       1.00373497, 1.02045127, 0.96275423, 0.95577446, 1.02327059,
       1.01158343, 1.01839834, 1.00299485, 0.94772549, 1.01429682,
       1.03504428, 1.01945449, 1.00892511, 1.03829897, 1.01120557,
       0.973932  , 0.94712612, 1.00198936, 0.94327696, 1.06532138,
       0.98187298, 1.02548556, 0.98881249, 1.00499885, 0.95672944,
       0.96856386, 0.99974339, 0.99565594, 0.97428335, 0.98018999,
       0.96985408, 0.9761144 , 0.9928966 , 1.01821399, 1.03637732,
       0.98142833, 1.02742561, 0.97636353, 1.00743014, 0.98820771,
       1.05074385, 0.98745265, 0.98092373, 1.0151592 , 1.03221518,
       0.99625426, 1.00757363, 1.03211199, 1.00540985, 1.00389014,
       0.96000655, 0.96114561, 0.96395515, 0.95503512, 0.94675772,
       0.91731205, 0.92953513, 0.9238405 , 0.89820126, 0.91839

In [28]:
# One Exp(1) draw per collected latent row. The count is taken from all_Z
# instead of a hard-coded 60000 so the weights always match the data.
# NOTE(review): numpy's global RNG is not seeded anywhere visible, so these
# weights differ between runs.
weights = np.random.exponential(size=all_Z.shape[0])

In [29]:
# Rescale the weights so they sum to one.
normalized_weights = weights / weights.sum()

In [31]:
# Weighted mean of every latent dimension. A single matrix-vector product
# replaces one Python-level dot product per column; the result is still
# exposed as a list of scalars, as before.
weighted_ml_mean = list(normalized_weights @ all_Z)

In [33]:
# Compare the plain per-dimension mean (x-axis) with the weighted mean (y-axis).
plt.plot(np.mean(all_Z, axis=0), weighted_ml_mean, 'x')
plt.show()

In [36]:
# Shape check: the outer product of one centred row with itself.
np.outer(all_Z[0, :] - weighted_ml_mean, (all_Z[0, :] - weighted_ml_mean)).shape

(784, 784)

In [37]:
# Weighted covariance: sum_i w_i * (z_i - mean)(z_i - mean)^T.
# Written as one matrix product instead of a Python loop over every row; the
# result is the same sum up to floating-point rounding.
# Note: this materialises two temporary arrays the size of all_Z.
_centered_Z = all_Z - np.asarray(weighted_ml_mean)
weighted_ml_cov = (_centered_Z * normalized_weights[:, None]).T @ _centered_Z
    

In [38]:
# Show the weighted covariance matrix as an image with a colour scale.
plt.imshow(weighted_ml_cov)
plt.colorbar()
plt.show()

In [43]:
# Inspect the configured number of mixture components.
args.n_components

1

In [44]:
# Draw n_row**2 latent vectors from the model's base distribution; squeeze()
# removes the singleton axis introduced by the sample shape (n_row**2, 1).
n_row = 10
u = model.base_dist.sample((n_row**2, 1)).squeeze()
u.shape

torch.Size([100, 784])

In [46]:
# Map the latent draws back through the flow's inverse.
samples, _ = model.inverse(u)
# Despite its name, `log_probs` holds the argsort indices of the sample
# log-probabilities, ordered from highest to lowest.
log_probs = model.log_prob(samples).sort(0)[1].flip(0)

In [48]:
# Reorder the samples using the indices computed above.
# NOTE: re-running this cell on its own permutes the already-reordered samples again.
samples = samples[log_probs]

In [50]:
# Reshape each flat sample to the dataset's image dimensions.
samples = samples.view(samples.shape[0], *args.input_dims)
samples.shape

torch.Size([100, 1, 28, 28])

In [53]:
# Inspect the dataset's `lam` constant used in the image conversion below.
train_dataloader.dataset.lam

1e-06

In [55]:
# Squash through a sigmoid and rescale with `lam`.
# NOTE(review): presumably this inverts a logit-style preprocessing applied by
# the dataset — confirm against the data-loading code.
samples = (torch.sigmoid(samples) - train_dataloader.dataset.lam) / (1 - 2 * train_dataloader.dataset.lam)
# Write the samples as a 10-per-row image grid in the working directory.
save_image(samples, 'basic_samples.png', nrow=10, normalize=True)

In [57]:
# Draw 100 latent vectors from a Gaussian with the weighted mean and covariance.
u2 = np.random.multivariate_normal(mean=weighted_ml_mean, cov=weighted_ml_cov, size=100)

In [60]:
# Convert the numpy draws to a torch tensor.
# NOTE: rebinding `u2` makes this cell fail if re-run without re-running the previous one.
u2 = torch.from_numpy(u2)


In [65]:
# Shape check on the converted latent draws.
u2.shape

torch.Size([100, 784])

In [70]:
# Same sample-to-image pipeline as for the base-distribution draws, applied to
# the latents drawn from the weighted Gaussian.
# Cast to float32 before inverting the flow.
samples2, _ = model.inverse(u2.float())
# `log_probs` again holds sort indices (highest log-probability first).
log_probs = model.log_prob(samples2).sort(0)[1].flip(0)
samples2 = samples2[log_probs]
samples2 = samples2.view(samples2.shape[0], *args.input_dims)
# NOTE(review): presumably inverts the dataset's logit-style preprocessing — confirm.
samples2 = (torch.sigmoid(samples2) - train_dataloader.dataset.lam) / (1 - 2 * train_dataloader.dataset.lam)
save_image(samples2, 'wlb_samples.png', nrow=10, normalize=True)

In [72]:
# Recompute the raw inverse samples (this overwrites the image-space `samples2`
# from the previous cell) and display their log-probabilities under the model.
samples2, _ = model.inverse(u2.float())
model.log_prob(samples2)

tensor([-1705.3569, -1704.7877, -1570.6821, -1491.4120, -1455.0160, -1805.1970,
        -1625.3429, -1642.6885, -1618.1909, -1685.9089, -1356.2938, -1669.7491,
        -1603.5383, -1556.9873, -1550.9833, -1475.4789, -1629.1084, -1616.4868,
        -1564.8269, -1562.3669, -1553.0282, -1601.2928, -1485.4480, -1666.8209,
        -1533.0057, -1648.0332, -1746.2894, -1572.7708, -1641.4694, -1542.1995,
        -1552.6294, -1400.4916, -1781.2429, -1560.3727, -1661.6193, -1650.3163,
        -1658.0621, -1674.6748, -1554.9468, -1605.6161, -1401.5372, -1545.7150,
        -1635.0905, -1567.3464, -1548.8174, -1648.7311, -1602.8871, -1570.4633,
        -1657.1458, -1480.4958, -1609.6390, -1734.4767, -1612.4224, -1419.8280,
        -1674.7410, -1397.0052, -1551.0342, -1568.6194, -1481.4043, -1617.1187,
        -1584.9845, -1525.0138, -1566.6632, -1564.9041, -1373.5930, -1607.3364,
        -1549.4485, -1614.0040, -1589.5908, -1683.3782, -1443.3512, -1615.9037,
        -1538.1742, -1659.5015, -1633.89

In [74]:
# Unweighted second-moment matrix about zero: (1/N) * sum_i z_i z_i^T.
# The rows are deliberately not mean-centred, matching the previous loop.
# One matrix product replaces the per-row Python loop of outer products; the
# result is the same up to floating-point rounding.
unweighted_ml_cov = (all_Z.T @ all_Z) / all_Z.shape[0]
    

In [76]:
# Draw 100 latents from a zero-mean Gaussian with the unweighted second-moment
# matrix, then run the same sample-to-image pipeline as above.
u3 = np.random.multivariate_normal(mean=np.zeros(all_Z.shape[1]), cov=unweighted_ml_cov, size=100)
u3 = torch.from_numpy(u3)
samples3, _ = model.inverse(u3.float())
# `log_probs` holds sort indices (highest log-probability first).
log_probs = model.log_prob(samples3).sort(0)[1].flip(0)
samples3 = samples3[log_probs]
samples3 = samples3.view(samples3.shape[0], *args.input_dims)
# NOTE(review): presumably inverts the dataset's logit-style preprocessing — confirm.
samples3 = (torch.sigmoid(samples3) - train_dataloader.dataset.lam) / (1 - 2 * train_dataloader.dataset.lam)
# NOTE(review): the file name says "weighted" although the covariance used here
# is the unweighted one — confirm the intended name.
save_image(samples3, 'raw_weighted_samples.png', nrow=10, normalize=True)

In [77]:
# Display the full settings namespace.
args

Namespace(activation_fn='relu', batch_size=100, cond_label_size=None, conditional=False, data_dir='./data/', dataset='MNIST', device=device(type='cpu'), evaluate=False, flip_toy_var_order=False, generate=False, hidden_size=100, input_dims=(1, 28, 28), input_order='sequential', input_size=784, log_interval=1000, lr=0.0001, model='maf', n_blocks=5, n_components=1, n_epochs=50, n_hidden=1, no_batch_norm=False, no_cuda=False, output_dir='./results/maf_mnist', restore_file=None, results_file='./results/maf_mnist/results.txt', seed=1, start_epoch=0, train=True)

In [None]:
def generate_wlb_basic(model, dataset_lam, args, step=None, n_row=10):
    """Sample an n_row x n_row grid of images from the model and save it as a PNG.

    Latent vectors are drawn from ``model.base_dist``, mapped through
    ``model.inverse``, ordered from highest to lowest ``model.log_prob``,
    reshaped to ``args.input_dims``, passed through a sigmoid and rescaled
    with ``dataset_lam``, then written with ``save_image``.

    Args:
        model: object providing ``eval()``, ``base_dist.sample(shape)``,
            ``inverse(u)`` (returns a pair whose first item is the samples)
            and ``log_prob(x)``.
        dataset_lam: scalar used in the rescaling
            ``(sigmoid(x) - lam) / (1 - 2 * lam)``.
        args: namespace with ``n_components``, ``input_dims`` and ``output_dir``.
        step: optional epoch number; when given, ``_epoch_<step>`` is appended
            to the file name.
        n_row: number of images per row; ``n_row**2`` samples are generated.

    Returns:
        None. The image is written to
        ``<args.output_dir>/generated_samples[_epoch_<step>].png``.
    """
    model.eval()

    # Sampling needs no gradients; skip building autograd graphs.
    with torch.no_grad():
        # squeeze(1) drops only the component axis when it is a singleton; a
        # blanket squeeze() would also drop the batch axis when n_row == 1.
        u = model.base_dist.sample((n_row**2, args.n_components)).squeeze(1)
        samples, _ = model.inverse(u)
        # argsort by log-probability, then flip so the most likely come first
        order = model.log_prob(samples).sort(0)[1].flip(0)
        samples = samples[order]

        # convert to image space
        samples = samples.view(samples.shape[0], *args.input_dims)
        samples = (torch.sigmoid(samples) - dataset_lam) / (1 - 2 * dataset_lam)

    # `is None` (not `!= None`) so step=0 still gets its epoch suffix
    suffix = '' if step is None else '_epoch_{}'.format(step)
    filename = 'generated_samples' + suffix + '.png'
    save_image(samples, os.path.join(args.output_dir, filename), nrow=n_row, normalize=True)