In [1]:

import sys 
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
sys.path.insert(1, "c:\\Users\\svwin\\New folder\\devinterp") # TODO fix path

from devinterp.slt.sampler import Sampler, SamplerConfig, estimate_rlct



In [2]:

def train_one_epoch(model, train_loader, optimizer, criterion):
    model.train()
    train_loss = 0
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    return train_loss / len(train_loader)

def evaluate(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()
    return test_loss / len(test_loader)

# Define the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
# Load MNIST data
train_data = datasets.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# Load test data
test_data = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
# Initialize model, loss, optimizer and sgld sampler
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
sampler_config = SamplerConfig(
    optimizer_config=dict(
        optimizer_type="SGLD",
        lr=0.001,
        noise_level=0.5,
        weight_decay=0.,
        elasticity=1.,
        temperature='adaptive',
        num_samples=100,
    ),
    num_chains=50,
    num_draws_per_chain=20,
    num_burnin_steps=0,
    num_steps_bw_draws=1,
    verbose=False,
    batch_size=256,        
    criterion = 'cross_entropy' # alternatives: mse
)
sampler = Sampler(model, train_data, sampler_config)

ValidationError: 1 validation error for SamplerConfig
criterion
  Input should be callable [type=callable_type, input_value='cross_entropy', input_type=str]
    For further information visit https://errors.pydantic.dev/2.1/v/callable_type

In [5]:
# train model
train_losses = []
test_losses = []
rlct_estimates = []
for epoch in range(40):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    test_loss = evaluate(model, test_loader, criterion)
    rlct_estimate = sampler.sample(summary_fn=estimate_rlct)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    rlct_estimates.append(rlct_estimate)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Test Loss: {test_loss}, RLCT estimate: {rlct_estimate}")

# Plotting
fig, ax1 = plt.subplots()

ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color='tab:blue')
ax1.plot(train_losses, label='Train Loss', color='tab:blue')
ax1.plot(test_losses, label='Test Loss', color='tab:orange')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax1.legend(loc='upper left')

ax2 = ax1.twinx()
ax2.set_ylabel('rlct_estimate', color='tab:green')
ax2.plot(rlct_estimates, label='rlct_estimate', color='tab:green')
ax2.tick_params(axis='y', labelcolor='tab:green')
ax2.legend(loc='upper right')

fig.tight_layout()
plt.show()


100%|██████████| 1875/1875 [00:08<00:00, 219.14it/s]


TypeError: forward() got an unexpected keyword argument 'reduction'