In [1]:
import sys 
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import copy

# Make a local devinterp checkout importable. Inserting at index 1 leaves the
# first sys.path entry in front of it.
sys.path.insert(1, "/home/paperspace/devinterp") # TODO fix path: hard-coded absolute, machine-specific location

from devinterp.slt.sampler import Sampler, SamplerConfig, estimate_rlct
from torch.utils.data import TensorDataset
# Device the train/eval helpers move each batch to: GPU when available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"



In [2]:
def train_one_epoch(model, train_loader, optimizer, criterion, device=None):
    """Run one full optimisation pass over ``train_loader``.

    Only the batches are moved to ``device``; the model is used as-is, so the
    caller must already have placed it on the same device.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss tensor.
        device: Device the batches are moved to. ``None`` (default) uses the
            notebook-wide ``DEVICE``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``len(train_loader)`` is 0.
    """
    if device is None:
        device = DEVICE
    model.train()
    train_loss = 0
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    return train_loss / len(train_loader)
def train_one_batch(model, train_loader, optimizer, criterion, device=None):
    """Take a single optimisation step on the first batch of ``train_loader``.

    Only the batch is moved to ``device``; the model is used as-is, so the
    caller must already have placed it on the same device.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches. Only the first
            batch is consumed.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss tensor.
        device: Device the batch is moved to. ``None`` (default) uses the
            notebook-wide ``DEVICE``.

    Returns:
        The loss of that first batch as a Python float, or ``0`` when the
        loader yields no batches (in which case no step is taken).
    """
    if device is None:
        device = DEVICE
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        batch_loss = loss.item()
        loss.backward()
        optimizer.step()
        # Deliberately stop after the first batch.
        return batch_loss
    return 0
def evaluate(model, test_loader, criterion, device=None):
    """Compute the mean per-batch loss of ``model`` over ``test_loader``.

    The model is put in eval mode and gradients are disabled for the duration
    of the pass. Afterwards the model is returned to whatever mode it was in
    before the call (rather than being forced into training mode), even if
    evaluation raises.

    Args:
        model: Module to evaluate. It is not moved; the caller must already
            have placed it on ``device``.
        test_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss tensor.
        device: Device the batches are moved to. ``None`` (default) uses the
            notebook-wide ``DEVICE``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``len(test_loader)`` is 0.
    """
    if device is None:
        device = DEVICE
    # Remember the caller's mode so evaluation has no lasting side effect on it.
    was_training = model.training
    model.eval()
    test_loss = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data.to(device))
                loss = criterion(output, target.to(device))
                test_loss += loss.item()
    finally:
        model.train(was_training)
    return test_loss / len(test_loader)


In [3]:

class PolyModel(nn.Module):
    """Two-weight toy model that scales its input by ``min(weights ** powers)``.

    Args:
        powers: Exponents applied element-wise to the two weights (a tensor or
            anything ``torch.as_tensor`` accepts).

    Attributes:
        weights: Trainable parameter of shape ``(2,)``, initialised to zeros.
        powers: The exponents, held as a non-persistent buffer.
    """
    def __init__(self, powers):
        super().__init__()
        # nn.Parameter already enables gradients, so no explicit requires_grad.
        self.weights = nn.Parameter(torch.zeros(2, dtype=torch.float32))
        # Registered as a buffer so the exponents follow the module through
        # .to()/.cuda(); a plain attribute would stay on the CPU and break
        # forward() on a GPU model. Non-persistent keeps state_dict unchanged
        # (it still contains only "weights").
        self.register_buffer("powers", torch.as_tensor(powers), persistent=False)
    def forward(self, x):
        """Return ``x`` multiplied by the smallest element of ``weights ** powers``."""
        scale = torch.min(self.weights ** self.powers)
        return x * scale

# criterion = nn.MSELoss()
powers = torch.tensor([2, 2])
model = PolyModel(powers)

# Experiment hyperparameters.
seed=0
sigma=0.5
lr=0.0001
num_steps=2000
num_train_samples = 5000
num_test_samples = 1000
batch_size=num_train_samples  # full batch: one batch holds the whole training set
with_trajectory=False
w_true = torch.zeros_like(powers)

# `seed` was previously defined but never applied. Seed torch's global RNG so
# the synthetic data drawn below is the same on every Restart & Run All.
torch.manual_seed(seed)

# Synthetic data: inputs are drawn with mean 0 and std 2; targets are pure
# noise scaled by sigma and do not depend on x.
x = torch.normal(0, 2, size=(num_train_samples,))
y = sigma * torch.normal(0, 1, size=(num_train_samples,))
train_data = TensorDataset(x, y)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

x_test = torch.normal(0, 2, size=(num_test_samples,))
y_test =  sigma * torch.normal(0, 1, size=(num_test_samples,))
test_data = TensorDataset(x_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.01)

# train model
# Bookkeeping containers and loop sizes for the experiment.
train_losses = []
test_losses = []
rlct_estimates = []
models = []
test_loss = 0.
n_epochs = 20
n_steps = 40

def print_rlcts(SGNHT_config, SGLD_config):
    """Repeatedly estimate the RLCT with two sampler configs and plot both histograms.

    Reads ``n_epochs``, ``model`` and ``train_data`` from the notebook globals.
    For each of the ``n_epochs`` repetitions, each sampler gets its own deep
    copy of ``model``; the pair of estimates is printed as it is produced.
    The training step that used to sit inside the loop is disabled, so every
    repetition starts from the same global ``model``.

    Args:
        SGNHT_config: ``SamplerConfig`` for the first sampler.
        SGLD_config: ``SamplerConfig`` for the second sampler.

    Returns:
        None. Output is the printed estimates and the displayed figure.
    """
    sgnht_history = []
    sgld_history = []
    for _ in range(n_epochs):
        # Independent copies: one model object per sampler.
        sgnht_model = copy.deepcopy(model)
        sgld_model = copy.deepcopy(model)
        sgnht_chain = Sampler(sgnht_model, train_data, SGNHT_config)
        sgld_chain = Sampler(sgld_model, train_data, SGLD_config)

        sgnht_rlct = sgnht_chain.sample(summary_fn=estimate_rlct)
        sgld_rlct = sgld_chain.sample(summary_fn=estimate_rlct)

        sgnht_history.append(sgnht_rlct)
        sgld_history.append(sgld_rlct)
        print(sgnht_rlct, sgld_rlct)

    # Overlay the two distributions on the same axes.
    for history, tag in ((sgnht_history, 'sgnht'), (sgld_history, 'sgdl')):
        plt.hist(history, alpha=0.5, label=tag)
    plt.legend()
    plt.show()



In [4]:
# SGNHT sampler settings: a single chain of 2,000 draws, no burn-in, every step kept.
# NOTE(review): batch_size is passed both inside optimizer_config and at the
# top level here, but only at the top level for SGLD below -- confirm which
# one the library actually reads.
SGNHT_config = SamplerConfig(
    optimizer_config={
        "optimizer_type": "SGNHT",
        "lr": lr,
        "diffusion_factor": 0.01,
        "bounding_box_size": 1.,
        "num_samples": len(train_data),
        "batch_size": batch_size,
    },
    num_chains=1,
    num_draws_per_chain=2_000,
    num_burnin_steps=0,
    num_steps_bw_draws=1,
    batch_size=batch_size,
    criterion='mse_loss',
)

# SGLD sampler settings: a single chain of 10,000 draws, no burn-in, every step kept.
# NOTE(review): the draw count differs from SGNHT (10,000 vs 2,000), and
# verbose=False is set only here -- confirm both are intentional.
SGLD_config = SamplerConfig(
    optimizer_config={
        "optimizer_type": "SGLD",
        "lr": lr,
        "noise_level": .9,
        "weight_decay": 0.1,
        "elasticity": 1,
        "temperature": 'adaptive',
        "num_samples": len(train_data),
    },
    num_chains=1,
    num_draws_per_chain=10_000,
    num_burnin_steps=0,
    num_steps_bw_draws=1,
    verbose=False,
    batch_size=batch_size,
    criterion='mse_loss',
)

# Run the repeated estimation and show the comparison histogram.
print_rlcts(SGNHT_config, SGLD_config)

35.70883560180664 0.15305842459201813
52.057525634765625 0.08763440698385239
41.39814376831055 0.12417352944612503
47.85388946533203 0.08077621459960938
64.03759765625 0.0320165641605854
75.64456939697266 0.0926293432712555
74.22412872314453 0.16498152911663055
52.45056915283203 0.03957457095384598
85.16056823730469 0.08606857061386108
48.357059478759766 0.07203727215528488
81.48987579345703 0.16176237165927887
65.49381256103516 0.05447189882397652
