In [1]:
## Make synthetic data
import numpy as np
import torch
from torch.utils.data import DataLoader
from snorkel.model.utils import MetalDataset

# Seed both RNGs used in this notebook so the synthetic data is identical on
# every Restart & Run All.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# Draw n points uniformly from the square [-1, 1) x [-1, 1).
n = 1200
X = torch.FloatTensor(np.random.random((n, 2)) * 2 - 1)
# Labels take the values 1 and 2: 2 where x0 > x1 + 0.5, otherwise 1.
# `.long()` already yields an int64 tensor, so no extra LongTensor wrapper.
Y = (X[:, 0] > X[:, 1] + 0.5).long() + 1

# Half-open [start, stop) row ranges for each split, in the order the
# dataloaders list is built: 1000 train / 100 valid / remainder test.
split_bounds = {"train": (0, 1000), "valid": (1000, 1100), "test": (1100, n)}

datasets = [
    MetalDataset(X[start:stop], Y[start:stop])
    for start, stop in split_bounds.values()
]

dataloaders = []
for dataset, split in zip(datasets, split_bounds):
    dataloader = DataLoader(dataset, batch_size=4)
    # Tag each vanilla DataLoader with its split name; later cells read it.
    dataloader.split = split
    dataloaders.append(dataloader)

In [2]:
## Add the fields the trainer needs to the DataLoaders
# (Design note: do this inside the Trainer when it is passed vanilla DataLoaders.)

from typing import List
from snorkel.mtl.data import MultitaskDataset, MultitaskDataLoader

def upgrade_dataloaders(dataloaders: List[DataLoader]):
    """Wrap each DataLoader's data in a MultitaskDataset / MultitaskDataLoader.

    Every input loader must carry a ``split`` attribute, and its ``dataset``
    must expose ``X`` and ``Y`` attributes; datasets without them are not
    supported by this helper.

    Args:
        dataloaders: Loaders to convert, each tagged with ``split``.

    Returns:
        A new list with one MultitaskDataLoader per input loader, in the same
        order. Each keeps the original batch size and split; only the loader
        whose split is ``"train"`` is shuffled.
    """

    def _upgrade(loader):
        # Re-package the features/labels under the dict keys the multitask
        # classes are given here: "data" for inputs, "labels" for targets.
        source = loader.dataset
        wrapped = MultitaskDataset(
            name=f"data_{loader.split}",
            X_dict={"data": source.X},
            Y_dict={"labels": source.Y},
        )
        return MultitaskDataLoader(
            task_to_label_dict={"task": "labels"},
            dataset=wrapped,
            split=loader.split,
            batch_size=loader.batch_size,
            shuffle=(loader.split == "train"),
        )

    return [_upgrade(loader) for loader in dataloaders]

# Rebind `dataloaders` to the upgraded loaders; the training cell below uses them.
# NOTE(review): this overwrites the vanilla loaders, so re-running this cell on
# its own passes already-upgraded loaders back in. Presumably that fails on
# `dataset.X` -- re-run the data cell first. TODO confirm.
dataloaders = upgrade_dataloaders(dataloaders)

In [3]:
## Build SimpleModel
import torch.nn as nn
from snorkel.mtl.simple_model import SimpleModel

# Layer widths: 2 input features -> 10 hidden units -> 2 outputs.
# NOTE(review): no activation sits between the two linear layers in this list;
# whether SimpleModel inserts one is not visible here -- confirm.
layer_widths = [2, 10, 2]
modules = [
    nn.Linear(fan_in, fan_out)
    for fan_in, fan_out in zip(layer_widths, layer_widths[1:])
]
model = SimpleModel(modules)
print(model)

SimpleModel(name=SimpleModel)


In [4]:
# Train SimpleModel
from snorkel.mtl.trainer import Trainer
# Number of epochs requested from the Trainer.
N_EPOCHS = 5

# Train on the upgraded dataloaders with the progress bar turned off.
trainer = Trainer(n_epochs=N_EPOCHS, progress_bar=False)
trainer.train_model(model, dataloaders)

# Score the model on the same dataloaders and print the result. The recorded
# output is a dict keyed like 'task/data_train/train/accuracy'.
scores = model.score(dataloaders)
print(scores)


{'task/data_train/train/accuracy': 0.994, 'task/data_valid/valid/accuracy': 0.99, 'task/data_test/test/accuracy': 1.0}
