In [1]:
import random
import unittest
import warnings
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from snorkel.mtl.data import MultitaskDataLoader, MultitaskDataset
from snorkel.mtl.model import MultitaskModel
from snorkel.mtl.modules.utils import ce_loss, softmax
from snorkel.mtl.scorer import Scorer
from snorkel.mtl.task import Task
from snorkel.mtl.trainer import Trainer

## Make data and tasks

In [2]:
def create_data(n):
    """Generate ``n`` random 2-D points with two threshold-based label columns.

    Points are drawn uniformly from ``[-1, 1)`` in each coordinate using NumPy's
    global RNG. Each label is 2 when ``x1`` exceeds ``x2`` by more than a margin
    (0.5 for ``y1``, 0.25 for ``y2``) and 1 otherwise.

    Args:
        n: Number of points to generate.

    Returns:
        A DataFrame with columns ``x1``, ``x2``, ``y1`` and ``y2`` (labels are
        stored as floats).
    """
    points = np.random.random((n, 2)) * 2 - 1
    first_coord, second_coord = points[:, 0], points[:, 1]

    # One label column per margin; the float array keeps labels as 1.0 / 2.0.
    labels = np.zeros((n, 2))
    for column, margin in enumerate((0.5, 0.25)):
        labels[:, column] = (first_coord > second_coord + margin).astype(int) + 1

    columns = {
        "x1": first_coord,
        "x2": second_coord,
        "y1": labels[:, 0],
        "y2": labels[:, 1],
    }
    return pd.DataFrame(columns)

# Seed every RNG the notebook draws from (Python, NumPy, torch) before any
# random data is generated, so a fresh-kernel "Restart & Run All" starts from
# the same random state each time.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Synthetic train / valid / test splits (1000 / 100 / 100 points).
df_train = create_data(1000)
df_valid = create_data(100)
df_test = create_data(100)

In [3]:
def create_dataloader(df, split):
    """Wrap a DataFrame with ``x1``/``x2``/``y1``/``y2`` columns in a ``MultitaskDataLoader``.

    Args:
        df: DataFrame holding the feature columns ``x1`` and ``x2`` and the label
            columns ``y1`` (for ``task1``) and ``y2`` (for ``task2``).
        split: Split name handed to the loader; shuffling is enabled only when
            it equals ``"train"``.

    Returns:
        A ``MultitaskDataLoader`` (batch size 4) over a ``MultitaskDataset`` named
        ``"TestData"`` with a ``"coordinates"`` feature tensor of shape ``[n, 2]``
        and one long label tensor per task.
    """
    # Build tensors from the underlying NumPy arrays rather than from the pandas
    # Series objects, so values are taken in row order regardless of df's index.
    coordinates = torch.tensor(df[["x1", "x2"]].values, dtype=torch.float)

    Y_dict = {}
    task_to_label_dict = {}
    for task_name, label_column in (("task1", "y1"), ("task2", "y2")):
        label_name = f"{task_name}_labels"
        # Labels are stored as floats in the DataFrame; cast them to long.
        Y_dict[label_name] = torch.tensor(df[label_column].values, dtype=torch.long)
        task_to_label_dict[task_name] = label_name

    dataset = MultitaskDataset(
        name="TestData",
        X_dict={"coordinates": coordinates},
        Y_dict=Y_dict,
    )

    dataloader = MultitaskDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dataset,
        split=split,
        batch_size=4,
        shuffle=(split == "train"),
    )
    return dataloader

In [4]:
def create_task(task_name, module_suffixes):
    """Assemble a two-layer linear ``Task`` for the 2-D toy data.

    Args:
        task_name: Name given to the resulting task.
        module_suffixes: Indexable with at least two entries; entry 0 is appended
            to ``"linear1"`` and entry 1 to ``"linear2"`` to form the module names.

    Returns:
        A ``Task`` whose flow sends the ``"coordinates"`` input through
        ``Linear(2, 10)`` (op ``"first_layer"``) and then ``Linear(10, 2)``
        (op ``"second_layer"``), with cross-entropy loss, softmax output and an
        accuracy scorer attached to ``"second_layer"``.
    """
    hidden_module = f"linear1{module_suffixes[0]}"
    output_module = f"linear2{module_suffixes[1]}"

    # Register the layers in the same order as they appear in the flow.
    modules = nn.ModuleDict()
    modules[hidden_module] = nn.Linear(2, 10)
    modules[output_module] = nn.Linear(10, 2)

    hidden_op = {
        "name": "first_layer",
        "module": hidden_module,
        "inputs": [("_input_", "coordinates")],
    }
    output_op = {
        "name": "second_layer",
        "module": output_module,
        "inputs": [("first_layer", 0)],
    }

    return Task(
        name=task_name,
        module_pool=modules,
        task_flow=[hidden_op, output_op],
        loss_func=partial(ce_loss, "second_layer"),
        output_func=partial(softmax, "second_layer"),
        scorer=Scorer(metrics=["accuracy"]),
    )

## Confirm it builds and trains

In [5]:
# Build one dataloader per split, in train / valid / test order.
dataloaders = []
for df, split in [(df_train, "train"), (df_valid, "valid"), (df_test, "test")]:
    dataloader = create_dataloader(df, split)
    dataloaders.append(dataloader)
# Both tasks name their first module "linear1A"; the second modules differ
# ("linear2A" vs "linear2B").
# NOTE(review): presumably MultitaskModel shares same-named modules across tasks — confirm.
task1 = create_task("task1", module_suffixes=["A", "A"])
task2 = create_task("task2", module_suffixes=["A", "B"])
model = MultitaskModel(tasks=[task1, task2])
# Score before any training has happened, as a baseline.
scores = model.score(dataloaders)
print(scores)

{'task1/TestData/train/accuracy': 0.408, 'task2/TestData/train/accuracy': 0.434, 'task1/TestData/valid/accuracy': 0.46, 'task2/TestData/valid/accuracy': 0.52, 'task1/TestData/test/accuracy': 0.35, 'task2/TestData/test/accuracy': 0.42}


In [6]:
# Trainer settings: two epochs, no progress bar.
trainer_config = {"n_epochs": 2, "progress_bar": False}
# Logger settings, counted in epochs.
# NOTE(review): evaluation_freq=0.25 presumably means evaluating every quarter epoch — confirm.
logger_config = {"counter_unit": "epochs", "evaluation_freq": 0.25}

trainer = Trainer(**trainer_config, **logger_config)
trainer.train_model(model, dataloaders)
# Re-score the same dataloaders after training for comparison with the baseline.
scores = model.score(dataloaders)
print(scores)

{'task1/TestData/train/accuracy': 0.971, 'task2/TestData/train/accuracy': 0.935, 'task1/TestData/valid/accuracy': 0.97, 'task2/TestData/valid/accuracy': 0.93, 'task1/TestData/test/accuracy': 0.97, 'task2/TestData/test/accuracy': 0.96}


## Add slicing functions

In [7]:
from snorkel.types import DataPoint
from snorkel.labeling.apply import LFApplier, PandasLFApplier
from snorkel.labeling.lf import labeling_function

@labeling_function()
def lt42(x: DataPoint) -> int:
    """Slice indicator: 1 when the point's ``x1`` is greater than 0.75, else 0."""
    # NOTE(review): the name reads like "less than 42", but the test is
    # x1 > 0.75. The name is kept because it is used as the slice name.
    if x.x1 > 0.75:
        return 1
    return 0


# All slicing functions applied in this notebook.
slicing_functions = [lt42]

In [8]:
# Build the applier from the same list that slice_names is derived from, so
# the columns of slice_labels always line up with slice_names.
applier = PandasLFApplier(slicing_functions)
# NOTE(review): slice labels are computed for df_train only; the valid and
# test splits have no slice labels of their own.
slice_labels = applier.apply(df_train)
slice_names = [sf.name for sf in slicing_functions]

100%|██████████| 1000/1000 [00:00<00:00, 36583.23it/s]


In [9]:
from typing import List
from snorkel.mtl.task import Task
from snorkel.mtl.data import MultitaskDataLoader
from scipy.sparse import csr_matrix

def _update_dataloaders(base_task, dataloaders, slice_labels, slice_names):
    # Update dataloaders
    for dataloader in dataloaders:
        label_name = dataloader.task_to_label_dict[base_task.name]
        labels = dataloader.dataset.Y_dict[label_name]
        for i, slice_name in enumerate(slice_names):
            # Convert labels
            ind_labels = torch.LongTensor(slice_labels[:,i]) # [n, 1]
            pred_labels = ind_labels * labels

            ind_task_name = f"{base_task.name}_{slice_name}_ind"
            pred_task_name = f"{base_task.name}_{slice_name}_pred"
            
            # Update dataloaders
            dataloader.dataset.Y_dict[ind_task_name] = ind_labels
            dataloader.dataset.Y_dict[pred_task_name] = pred_labels

            dataloader.task_to_label_dict[ind_task_name] = ind_labels
            dataloader.task_to_label_dict[pred_task_name] = pred_labels
    return dataloaders
    
    
def add_slices(
    base_task: Task,
    dataloaders: List[MultitaskDataLoader],
    slice_labels: csr_matrix,  # [n, m]
    slice_names: List[str],
):
    """Add slice labels (including a base slice) to the dataloaders of ``base_task``.

    An all-ones "base" slice is appended to ``slice_labels``/``slice_names`` and
    the resulting indicator and prediction labels are written into every
    dataloader in place.

    Args:
        base_task: Task being sliced; its last and second-to-last flow ops are
            treated as the head and shoulder modules.
        dataloaders: Dataloaders to update in place.
        slice_labels: Sparse ``[n, m]`` matrix with one column per slice.
        slice_names: ``m`` slice names, one per column. Not modified.

    Returns:
        A ``(tasks, dataloaders)`` tuple. Construction of the slice tasks is not
        implemented yet, so ``tasks`` is currently an empty list and a warning
        is emitted.

    Raises:
        ValueError: If the number of slice names differs from the number of
            columns in ``slice_labels``.
    """
    num_points, num_slices = slice_labels.shape
    if num_slices != len(slice_names):
        raise ValueError(
            f"slice_labels has {num_slices} columns but "
            f"{len(slice_names)} slice names were given"
        )

    # Add base slice: an all-ones column, so every point belongs to it.
    base_labels = np.ones((num_points, 1), dtype=int)
    slice_labels = np.hstack([slice_labels.toarray(), base_labels])
    # Build a new list instead of appending to the caller's list; otherwise
    # re-running the calling cell would add another "base" entry each time.
    slice_names = list(slice_names) + ["base"]

    dataloaders = _update_dataloaders(base_task, dataloaders, slice_labels, slice_names)

    # ----- Update tasks -----
    tasks = []

    # Identify base task head and shoulder modules. module_pool is keyed by
    # each op's "module" entry (see create_task), not by the op's "name".
    head_module_name = base_task.task_flow[-1]["module"]
    head_module = base_task.module_pool[head_module_name]
    if isinstance(head_module, nn.DataParallel):
        head_module = head_module.module

    shoulder_module_name = base_task.task_flow[-2]["module"]
    shoulder_module = base_task.module_pool[shoulder_module_name]
    if isinstance(shoulder_module, nn.DataParallel):
        shoulder_module = shoulder_module.module

    # DataParallel wrappers were already removed above, so the head's sizes can
    # be read directly. These values (and shoulder_module) are kept for the
    # pending slice-head construction below.
    neck_size = head_module.in_features
    cardinality = head_module.out_features

    # TODO: create the indicator/prediction tasks for every slice and append
    # them to `tasks`. Until then, make the gap explicit rather than silent.
    warnings.warn(
        "add_slices: slice task construction is not implemented; "
        "returning an empty task list"
    )
    return tasks, dataloaders

In [None]:
# Initial task list; it is reassigned below once add_slices has run.
tasks = [task1]

# NOTE: modify dataloaders in place, but replace base task that goes in
# (add_slices is expected to hand back the tasks that take task2's place).
# NOTE(review): slice_labels was computed from df_train only, yet all three
# split dataloaders are passed here — confirm the row counts are meant to match.
task2_tasks, dataloaders = add_slices(task2, dataloaders, slice_labels, slice_names)
tasks = [task1] + task2_tasks

In [None]:
# Inspect the task -> label mapping of the first (train) dataloader.
dataloaders[0].task_to_label_dict