### General structure
1. Project to bottleneck representation
```python
Z = bottleneck(h) 
```

2. Split into means and log-variances (the diagonal of the covariance, in log space)
```python
mus, logsigs = torch.chunk(Z, 2, dim=2)
```

3. $\mathrm{KL}(q(z \mid x) \,\|\, p(z)) \geq I(x, z)$ (in expectation over $x$),
   where $p(z) = \mathcal{N}(0, I_d)$
```python
kld = -0.5 * (1 + logsigs - mus.pow(2) - logsigs.exp()).sum(2).mean(1).mean(0)
``` 
Get the logits $Wz$ that parameterise $p(y \mid z)$, using a reparameterised sample $z \sim q(z \mid x)$:
```python
sample = mus + torch.randn_like(logsigs) * torch.exp(0.5 * logsigs)
logits = w_pos(sample)
```
4. Compute the cross entropy using the sample; $\mathbb{E}_{q(z \mid x)} [ - \log p(y \mid z) ] \geq H(Y) - I(Z;Y)$, so minimising it maximises a lower bound on $I(Z;Y)$
```python
loss = cross_entropy(logits, pos_ids)
```

5. get full IB loss
```python
ib_loss = loss + beta * kld
```

In [37]:
import torch

In [77]:
# Stand-in for the encoder output: pretend Z = W_enc(H), i.e. the parameters of q(z|h).
batchsize = 2
maxsentlen = 4
embeddingsize = 3
Z = torch.randint(3, (batchsize, maxsentlen, 2 * embeddingsize)).to(torch.float32)
print("Z:", Z, sep='\n')

# The first half of the last axis is read as the means, the second half as the
# log-variances (exp(0.5 * logsigs) is used as the standard deviation below).
mus, logsigs = Z.chunk(2, dim=2)

# Closed-form KL(q(z | x) || N(0, I)), summed over the latent axis only, which
# leaves one value per (batch, position); no averaging over positions or batch.
kld = (mus.pow(2) + logsigs.exp() - logsigs - 1).sum(2).mul(0.5)
print("kld:", kld, sep='\n')

# Reparameterised draw: z = mu + eps * sigma with eps ~ N(0, I).
eps = torch.randn_like(logsigs)
sample = mus + eps * logsigs.mul(0.5).exp()

Z:
tensor([[[2., 1., 0., 2., 2., 1.],
         [1., 0., 0., 1., 1., 2.],
         [2., 2., 2., 2., 2., 0.],
         [1., 2., 0., 1., 2., 2.]],

        [[2., 0., 0., 2., 1., 1.],
         [0., 0., 1., 1., 0., 2.],
         [0., 2., 0., 1., 2., 1.],
         [1., 1., 0., 1., 2., 0.]]])
kld:
tensor([[ 7.2482,  3.4128, 10.3891,  7.2482],
        [ 4.9128,  3.0537,  4.9128,  3.5537]])


In [78]:
# Sanity check: the per-position KL keeps the (batch, sentence-length) axes.
tuple(kld.shape[:2]) == (batchsize, maxsentlen)

True

In [87]:
# Masked-mean experiment: average Z over its last axis, then average each row
# over only those positions whose value is not exactly 1.
X = Z.mean(dim=2)
print(X)
mask = X.ne(1)
print(mask)
# Multiplying by the boolean mask zeroes the excluded positions; dividing by the
# number of kept positions gives the per-row mean.
# NOTE: a row whose mask is entirely False would compute 0 / 0 here.
k = X.mul(mask).sum(dim=1).div(mask.sum(dim=1))
print(k)
k.sum(dim=0)

tensor([[1.3333, 0.8333, 1.6667, 1.3333],
        [1.0000, 0.6667, 1.0000, 0.8333]])
tensor([[ True,  True,  True,  True],
        [False,  True, False,  True]])
tensor([1.2917, 0.7500])


tensor(2.0417)

In [90]:
# Scratch check: draw standard-normal noise with the same shape as logsigs —
# the same call that supplies the noise term when `sample` is built.
torch.randn_like(logsigs)

tensor([[[-0.0252,  0.0208,  0.0599],
         [-1.0442,  0.8686, -0.9354],
         [ 1.1752, -0.5884,  1.1747],
         [ 0.4924,  1.1959, -0.2932]],

        [[ 0.3948, -2.2222, -2.2528],
         [ 2.7600,  1.2516, -0.3316],
         [ 1.1209, -0.3118, -1.6234],
         [ 0.6532, -0.4413, -0.5855]]])

In [1]:
# Training options for the IB model: optimiser name plus its keyword arguments.
# NOTE(review): `momentum` is not an argument of torch.optim.Adam (it takes
# `betas`) — confirm how the consumer of these options handles that key.
IB_TRAIN_OPTS = {
    'algorithm': 'adam',
    'hyperparams': {
        'lr': 1e-3,
        'weight_decay': 1e-5,
        'momentum': 0.9,
    },
}

In [2]:
import yaml

In [3]:
# Serialise the options to a YAML string and show it.
yaml_text = yaml.dump(IB_TRAIN_OPTS)
print(yaml_text)

algorithm: adam
hyperparams:
  lr: 0.001
  momentum: 0.9
  weight_decay: 1.0e-05



In [22]:
# Persist the options as YAML in the file "test2" in the working directory.
with open("test2", mode="w") as out_file:
    out_file.write(yaml.dump(IB_TRAIN_OPTS))

In [14]:
# Read the options back from the YAML file written by the yaml.dump cell above.
# The path has to match the file that cell writes ("test2"); the previous path
# "test" is never created by this notebook, so this cell raised
# FileNotFoundError on a fresh kernel / clean working directory.
# safe_load only constructs plain Python objects (dict, str, float, ...).
with open("test2", "r") as f:
    d = yaml.safe_load(f)

In [19]:
# Inspect the weight decay in the options loaded by yaml.safe_load.
hyperparams = d['hyperparams']
hyperparams['weight_decay']

1e-05

In [20]:
# Add a top-level 'dev' entry to the loaded options (mutates d in place).
d.update({'dev': 22})

In [35]:
# Merge the entries of x into d in place.
# NOTE(review): `x` is not defined in any visible cell of this notebook —
# confirm where it comes from; unless it is defined elsewhere, this line
# raises NameError on a fresh kernel.
d.update(x)