In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time

!nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"]="4"

Sun Jan 14 04:40:26 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:1B:00.0 Off |                    0 |
| 37%   66C    P2   216W / 230W |  16686MiB / 23028MiB |     82%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000    Off  | 00000000:1C:00.0 Off |                  Off |
| 42%   70C    P2   219W / 230W |  16686MiB / 24564MiB |     82%      Default |
|       

In [3]:
def get_kl(M, N, k, post_sigma, elem_sigma, prior_sigma=1.0):
    # z : (M, z)
    # mu : (N, z)
    
    z = (torch.rand(M, k).cuda()*2-1) * prior_sigma
    mu = (torch.rand(N, k).cuda()*2-1) * post_sigma
            
    # (M, N) = sum((M, 1, z) - (1, N, z), dim=2)
    distance = torch.norm(z.unsqueeze(1) - mu.unsqueeze(0), dim=2) ** 2
    alpha = -1/(2*elem_sigma**2)
    loss = -torch.mean(torch.logsumexp(alpha*distance, dim=1), dim=0)
    loss = loss + 0.5*k*(2*np.log(elem_sigma)-np.log(np.e)) + np.log(N)
    
    return loss.item()

In [None]:
from tqdm import tqdm

M = 1000
N = 1024
k = 8
post_sigma = 1
elem_log_sigmas = np.arange(-10, 10, 0.1)

losses_list = []
for _ in tqdm(range(100)):
    losses = np.array([get_kl(M, N, k, post_sigma, np.exp(elem_log_sigma)) for elem_log_sigma in elem_log_sigmas])
    losses_list.append(losses)
    
losses = np.mean(np.array(losses_list), axis=0)
optimum_sigma = np.exp(elem_log_sigmas[np.argmin(losses)])
optimum_sigma

 15%|██████▎                                   | 15/100 [00:01<00:07, 10.66it/s]

In [12]:
np.log(optimum_sigma)

-0.4000000000000341

In [13]:
for i in np.arange(-2, 2, 0.4):
    print(np.exp(np.log(optimum_sigma) + i))

0.0907179532894094
0.13533528323660807
0.2018965179946485
0.30119421191219176
0.44932896411720613
0.6703200460356161
0.9999999999999654
1.4918246976412186
2.22554092849239
3.3201169227364318
