In [2]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
from ut.models.ensemble import Ensemble
import sklearn.datasets as sd
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Seed the RNGs used in this notebook so Restart & Run All is reproducible:
# `random_state` fixes the moons sample, and torch.manual_seed fixes DataLoader
# shuffling and the initial weights of torch modules created after this point.
SEED = 0
torch.manual_seed(SEED)

N_SAMPLES = 200  # total points generated
N_TRAIN = 100  # first N_TRAIN points train, the rest test
BATCH_SIZE = 32

# make a simple 2d dataset; make_moons shuffles samples by default (shuffle=True),
# so the head/tail split below is a random train/test split, not a class-ordered one
X, y = sd.make_moons(n_samples=N_SAMPLES, noise=0.1, random_state=SEED)
X_train, y_train = X[:N_TRAIN], y[:N_TRAIN]
X_test, y_test = X[N_TRAIN:], y[N_TRAIN:]

# convert to torch tensors and make dataloaders (only the training data is shuffled)
train = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)
)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long)
)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)


# small fully connected neural network
class Net(nn.Module):
    """Small fully connected classifier: two ReLU hidden layers and a linear head.

    The defaults (2 -> 10 -> 10 -> 2) match the 2-D, two-class moons data used in
    this notebook; pass other sizes to reuse the architecture on different data.

    Args:
        in_features: Size of each input sample (last dimension of the input).
        hidden: Width of both hidden layers.
        out_features: Number of output logits (one per class).
    """

    def __init__(self, in_features: int = 2, hidden: int = 10, out_features: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, out_features)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map input of shape (..., in_features) to logits of shape (..., out_features)."""
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        # No softmax: nn.CrossEntropyLoss expects raw, unnormalised logits.
        return self.fc3(x)


# Template network, passed to `Ensemble(net, 5)` in the next cell.
net = Net()

In [4]:
N_MEMBERS = 5  # number of models in the ensemble
N_EPOCHS = 10  # training epochs per member

# NOTE(review): if Ensemble deep-copies `net`, every member starts from identical
# weights and differs only by batch order, which limits ensemble diversity — confirm.
ensemble = Ensemble(net, N_MEMBERS)
criterion = nn.CrossEntropyLoss()

# Train each member independently with its own Adam optimizer.
# `xb`/`yb` avoid clobbering the dataset-level `X`/`y` from the previous cell.
for model in tqdm(ensemble.models):
    optimizer = optim.Adam(model.parameters())
    model.train()  # nothing switches the mode back between epochs, so set it once
    for _ in range(N_EPOCHS):
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

# compute accuracy on test set
ensemble.eval()  # once, not per batch
correct = 0
total = 0
with torch.no_grad():  # inference only: skip building autograd graphs
    for xb, yb in test_loader:
        # NOTE(review): assumes ensemble(xb) returns (batch, n_members, n_classes)
        # logits, so mean over dim 1 averages members before the argmax over classes
        # — confirm against ut.models.ensemble.
        y_pred = ensemble(xb).mean(dim=1).argmax(dim=1)
        correct += (y_pred == yb).sum().item()
        total += yb.size(0)
print(f"Accuracy: {correct / total}")

100%|██████████| 5/5 [00:00<00:00,  5.72it/s]

Accuracy: 0.78



