In [1]:
import os
# Set before torch is imported; "-1" is intended to hide all CUDA devices so
# the notebook runs on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm
import torch
from torch import nn
import gin
from sparse_causal_model_learner_rl.trainable.gumbel_switch import WithInputSwitch, sample_from_logits_simple
# Switch gin to interactive mode (presumably so configurables can be
# re-registered when cells are re-run — confirm against the gin docs).
gin.enter_interactive_mode()

In [2]:
class LinearModel(nn.Module):
    """Single affine layer mapping a feature vector to one scalar output.

    Args:
        input_shape: Shape of the model input, e.g. ``(10,)``; a bare ``int``
            is also accepted. Its last entry is used as the number of input
            features of the linear layer.
    """

    def __init__(self, input_shape):
        super(LinearModel, self).__init__()
        # Size the layer from the declared input shape instead of a
        # hard-coded 10, so the argument is actually honoured.
        if isinstance(input_shape, int):
            in_features = input_shape
        else:
            in_features = int(input_shape[-1])
        self.layer = nn.Linear(in_features=in_features, out_features=1, bias=True)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last axis = input features)."""
        return self.layer(x)

In [3]:
# Gin bindings for the switch components, applied in the listed order.
# Each entry is (gin binding key, value).
for binding_key, binding_value in (
    ('LearnableSwitch.switch_neg', 0),
    ('LearnableSwitch.tau', 1.0),
    ('sample_from_logits_simple.tau', 1.0),
    ('LearnableSwitch.sample_threshold', 1 - 1e-1),
    # ('LearnableSwitch.sample_fcn', sample_from_logits_simple),  # left disabled
    ('LearnableSwitch.sample_many', True),
    ('LearnableSwitch.switch_pos', 0),
    ('WithInputSwitch.enable_switch', True),
):
    gin.bind_parameter(binding_key, binding_value)

In [4]:
# Build the switched model: LinearModel wrapped by WithInputSwitch, declared
# with a 10-feature input shape.
wsw = WithInputSwitch(
    input_shape=(10,),
    model_cls=LinearModel,
)

In [5]:
# Synthetic inputs: 5000 samples x 10 standard-normal features, as float32.
# A dedicated seeded generator makes the data identical on every
# Restart & Run All instead of depending on the global NumPy RNG state.
X = torch.from_numpy(np.random.RandomState(0).randn(5000, 10)).to(torch.float32)

In [6]:
# Regression target: the first feature column of X, kept as an (N, 1) column.
y = X[:, 0].unsqueeze(1)

In [7]:
# Call the switch on the data with return_mask=True and display the result
# (the recorded output is a 0/1-valued tensor — presumably the sampled mask;
# confirm against the switch implementation).
wsw.switch(X, return_mask=True)

tensor([[0., 0., 0.,  ..., 0., 1., 1.],
        [1., 0., 0.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 1., 0., 1.],
        [1., 0., 0.,  ..., 1., 0., 1.]], grad_fn=<SelectBackward>)

In [8]:
# Mean-squared-error criterion, averaged over all elements (the default
# reduction, spelled out explicitly).
loss_obj = nn.MSELoss(reduction='mean')

In [9]:
# Loss history (one scalar appended per optimisation step) and an Adam
# optimiser over the parameters returned by wsw.parameters().
losses = list()
opt = torch.optim.Adam(params=wsw.parameters(), lr=0.001)

In [None]:
# Full-batch training: 10000 optimiser steps on the MSE between the switched
# model's output and y, recording the scalar loss of every step.
n_steps = 10000
for step in tqdm(range(n_steps)):
    opt.zero_grad()
    prediction = wsw(X)
    loss = loss_obj(prediction, y)
    loss.backward()
    opt.step()
    losses.append(loss.item())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000.0), HTML(value='')))

In [None]:
# Training curve on a log-scaled y axis. Drawn on an explicit Axes with a
# title and axis labels so the figure is readable on its own.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('optimisation step')
ax.set_ylabel('MSE loss')
ax.set_title('Training loss')
ax.set_yscale('log')

In [None]:
# Inspect the switch after training through its softmaxed() accessor
# (presumably probabilities derived from the switch logits — verify against
# the switch implementation).
wsw.switch.softmaxed()

In [None]:
# Call the switch again after training and display the result. The module is
# invoked through __call__ rather than .forward() so any registered hooks
# run, matching how the switch is called before training.
wsw.switch(X, return_mask=True)

In [None]:
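# Reference sketch only (kept disabled; the function actually used below is
# the sample_from_logits_simple imported from gumbel_switch). It draws a
# Bernoulli sample from sigmoid(logits) in the forward pass and routes the
# gradient through the probabilities (straight-through estimator).
# def sample_from_logits_simple(logits_plus):
#     probas = torch.nn.Sigmoid()(logits_plus)
#     sampled = torch.bernoulli(probas)
#     return sampled + probas - probas.detach()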

In [None]:
# Clear any gradient already stored on the logits (the training loop leaves
# the last step's gradient in place). backward() accumulates into .grad, so
# without this reset the value inspected afterwards would be the sum of the
# training-loss gradient and the one measured here.
wsw.switch.logits.grad = None
sample_from_logits_simple(wsw.switch.logits).mean().backward()

In [None]:
# Display the gradient currently stored on the switch logits (populated by
# backward() calls; it accumulates unless reset).
wsw.switch.logits.grad