In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
def get_kl(M, N, k, post_sigma, elem_sigma, prior_sigma=1.0,
           device=None, generator=None, max_chunk_elems=2**27):
    """Monte-Carlo log-sum-exp score of prior samples under a cloud of Gaussian centres.

    Draws ``M`` samples ``z_i ~ N(0, prior_sigma**2 I_k)`` and ``N`` centres
    ``mu_j ~ N(0, post_sigma**2 I_k)`` and returns

        -mean_i logsumexp_j( -||z_i - mu_j||**2 / (2 * elem_sigma**2) )

    i.e. the negative average log of an *unnormalized* sum of isotropic Gaussian
    kernels of width ``elem_sigma`` centred at the ``mu_j``, evaluated at the ``z_i``.
    NOTE(review): despite the name, no normalizing constants are included (``log N`` is
    not subtracted, nor the Gaussian log-normalizer), so this is not a KL divergence by
    itself — confirm how downstream analysis converts it.

    Args:
        M: number of prior samples ``z``.
        N: number of centres ``mu``.
        k: dimensionality of ``z`` and ``mu``.
        post_sigma: standard deviation used to draw the centres ``mu``.
        elem_sigma: kernel width; must be non-zero.
        prior_sigma: standard deviation used to draw the samples ``z``.
        device: torch device for the distance computation. ``None`` selects
            ``"cuda"`` when available (the original behaviour) and ``"cpu"`` otherwise.
        generator: optional CPU ``torch.Generator`` for reproducible draws;
            ``None`` uses the global torch RNG.
        max_chunk_elems: upper bound on the number of elements of the
            ``(rows, N, k)`` difference tensor materialized at once. Rows of ``z`` are
            processed in chunks so memory stays bounded for large ``N * k``
            (a single ``(1024, 2**17, 16)`` float32 tensor would be ~8.6 GB).

    Returns:
        The estimate as a Python float.

    Raises:
        ZeroDivisionError: if ``elem_sigma`` is 0.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sample on the CPU (as the original did before `.cuda()`) so seeding through
    # torch.manual_seed / `generator` yields the same draws on any device.
    z = (torch.randn(M, k, generator=generator) * prior_sigma).to(device)   # (M, k)
    mu = (torch.randn(N, k, generator=generator) * post_sigma).to(device)   # (N, k)

    scale = -1.0 / (2 * elem_sigma ** 2)
    rows_per_chunk = max(1, max_chunk_elems // max(1, N * k))

    lse_rows = []
    for start in range(0, M, rows_per_chunk):
        z_chunk = z[start:start + rows_per_chunk]                          # (rows, k)
        # (rows, 1, k) - (1, N, k) -> (rows, N, k); square in place to halve peak memory,
        # then reduce to squared distances (rows, N). Summing squares directly avoids the
        # sqrt-then-square round trip of norm(...) ** 2.
        diff = z_chunk.unsqueeze(1) - mu.unsqueeze(0)
        sq_dist = diff.pow_(2).sum(dim=2)
        lse_rows.append(torch.logsumexp(scale * sq_dist, dim=1))           # (rows,)

    loss = -torch.cat(lse_rows).mean()
    return loss.item()

In [None]:
from tqdm import tqdm

M = 1024                                        # prior samples per get_kl call
Ns = np.array([2**i for i in range(0, 18)])     # number of centres: 1, 2, 4, ..., 2**17
ks = [1, 2, 4, 8, 16]                           # dimensionalities
post_sigmas = np.arange(0.0, 2.0, 0.01)         # centre stds, step 0.01 from 0 up to (excluding) 2.0
elem_sigma = 1e-1                               # kernel width passed to get_kl

EXP_DIR = 'exps_lse'
N_RUNS = None   # number of sweeps to save; None repeats forever (interrupt the kernel to stop)

# Create the output directory up front; otherwise np.savez raises FileNotFoundError
# only after a full (long) sweep has already been computed.
os.makedirs(EXP_DIR, exist_ok=True)

run = 0
while N_RUNS is None or run < N_RUNS:
    # results[i, j, l] = get_kl(M, Ns[j], ks[l], post_sigmas[i], elem_sigma)
    results = np.array([[[get_kl(M, N, k, post_sigma, elem_sigma) for k in ks] for N in Ns]
                        for post_sigma in tqdm(post_sigmas, desc=f'run {run}')])
    exp_id = 'normal_' + str(int(time.time() * 1000))
    filename = os.path.join(EXP_DIR, exp_id + '.npz')
    # The positional `results` is still stored under 'arr_0' (what existing loaders read);
    # sweep axes and settings are saved alongside so every file is self-describing.
    np.savez(filename, results, Ns=Ns, ks=np.asarray(ks), post_sigmas=post_sigmas,
             M=M, elem_sigma=elem_sigma)
    run += 1