In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from metal.utils import split_data
from visualization_utils import visualize_data
from data_generators import generate_pacman_data

config = {
    'N': 10000,
    'mus': np.array([[0,  0], [5, 0]]),
    'variances': [1.5, 5],
    'labels': [2, 1],
    'lf_metrics': [('recall', 1.0), ('recall', 1.0)],
}

X, Y, Z, L = generate_pacman_data(config)
Ls, Xs, Ys, Zs = split_data(L, X, Y, Z, splits=[0.5, 0.25, 0.25], shuffle=True)

# visualize_data(X, Y, Z, L)

In [3]:
from metal.contrib.backends.snorkel_gm_wrapper import SnorkelLabelModel

# replace GT with weak labels
label_model = SnorkelLabelModel()
label_model.train_model(Ls[0])
Y_train = label_model.predict_proba(Ls[0])
Ys[0] = Y_train

# for testing LFs
X_test = torch.Tensor(Xs[2])
L_test = Ls[2]

## Confirm that we can recover the L matrix

Train slice-aware model with `slice_weight=1.0`. 
Then, try to repredict the `L_test` values using the `L_head`.

In [4]:
from metal.contrib.slicing.experiment_utils import (
    create_data_loader,
    train_model,
    train_slice_dp,
    eval_model
)
from metal.contrib.slicing.utils import get_L_weights_from_targeting_lfs_idx

### SliceDP

In [5]:
import torch.nn.functional as F

from metal.contrib.slicing.online_dp import MLPModule
from metal.metrics import accuracy_score


train_kwargs = {
    "disable_prog_bar": True,
    "verbose": True,
    "n_epochs": 20,
    "lr": 0.005,
    "l2": 1e-7,
}
sm_dp_config = {
    'slice_kwargs': {
        'r': 5,
        'slice_weight': 1.0,
        'reweight': True
    },
    'train_kwargs': train_kwargs,
    'input_module_class': MLPModule,
    'input_module_init_kwargs': {
        'input_dim': 2,
        'output_dim': 5,
        'middle_dims': [5],
        'bias': True
    }
}
model = train_slice_dp(sm_dp_config, Ls, Xs, Ys, Zs)
print(model)

# L_preds = model.predict_L_proba(X_test)
L_preds = F.sigmoid(model.forward_L(X_test)).data.cpu().numpy()
preds = (L_preds > 0.5) * 1

L_gt = L_test.copy()
L_gt[L_gt != 0] = 1
print ('predicted L distribution:', np.sum(preds, axis=0))
print ('accuracy over LF0:', accuracy_score(L_gt[:, 0], preds[:, 0]))
print ('accuracy over LF1:', accuracy_score(L_gt[:, 1], preds[:, 1]))
print ('accuracy over LF2:', accuracy_score(L_gt[:, 2], preds[:, 2]))

  A = F.softmax(self.forward_L(x)).unsqueeze(1)


Slice Heads:
Reweighting: True
L_weights: tensor([[1.],
        [1.],
        [1.]])
Slice Weight: 1.0
Input Network: Sequential(
  (0): MLPModule(
    (input_layer): Sequential(
      (0): Linear(in_features=2, out_features=5, bias=True)
      (1): ReLU()
      (2): Linear(in_features=5, out_features=5, bias=True)
      (3): ReLU()
    )
  )
)
L_head: Linear(in_features=5, out_features=3, bias=False)
Y_head: Linear(in_features=10, out_features=2, bias=True)
Criteria: BCEWithLogitsLoss() SoftCrossEntropyLoss()


  return F.softmax(self.forward_Y(x)).data.cpu().numpy()


[1 epo]: train/loss=0.202, valid/accuracy=0.967
[2 epo]: train/loss=0.118, valid/accuracy=0.973
[3 epo]: train/loss=0.102, valid/accuracy=0.967
[4 epo]: train/loss=0.087, valid/accuracy=0.513
[5 epo]: train/loss=0.074, valid/accuracy=0.500
[6 epo]: train/loss=0.064, valid/accuracy=0.508
[7 epo]: train/loss=0.057, valid/accuracy=0.497
[8 epo]: train/loss=0.050, valid/accuracy=0.497
[9 epo]: train/loss=0.046, valid/accuracy=0.497
[10 epo]: train/loss=0.041, valid/accuracy=0.480
[11 epo]: train/loss=0.037, valid/accuracy=0.499
[12 epo]: train/loss=0.035, valid/accuracy=0.510
[13 epo]: train/loss=0.033, valid/accuracy=0.503
[14 epo]: train/loss=0.032, valid/accuracy=0.489
[15 epo]: train/loss=0.031, valid/accuracy=0.482
[16 epo]: train/loss=0.030, valid/accuracy=0.481
[17 epo]: train/loss=0.030, valid/accuracy=0.498
[18 epo]: train/loss=0.029, valid/accuracy=0.498
[19 epo]: train/loss=0.029, valid/accuracy=0.522
[20 epo]: train/loss=0.028, valid/accuracy=0.487
Finished Training
Accuracy: 0



### SliceHat

In [13]:
from metal.metrics import accuracy_score

def calc_L_accuracy(model, data_loader):
    X, L, Y = data_loader.dataset.data
    L_probs = model.predict_L_proba(torch.Tensor(X))
    L_preds = np.round(L_probs)
    score = accuracy_score(L.reshape(-1,1), L_preds.reshape(-1,1))
    return {"train/acc": score}

In [14]:
end_model_init_kwargs = {
    "layer_out_dims": [2, 5, 5, 2]
}
sm_hat_config = {
    "end_model_init_kwargs": end_model_init_kwargs,
    "slice_kwargs": {
        "slice_weight": 1.0,
        "reweight": True,
    },
    "train_kwargs": {
        "n_epochs": 20,
        "lr": 0.001,
        "log_unit": "epochs",
        "log_train_metrics_func": calc_L_accuracy,
        "log_train_metrics": ["train/loss", "train/acc"],
        "log_train_every": 1,
        "log_valid_every": 1,
    }
}

model = train_model(sm_hat_config, Ls, Xs, Ys, Zs, model_key="hat")
print(model)

L_preds = torch.sigmoid(abs(model.L_head(model.body(X_test))))
preds = (L_preds > 0.5) * 1

L_gt = L_test.copy()
L_gt[L_gt != 0] = 1
print('predicted L distribution:', np.unique(preds, axis=0))
print('accuracy over LF0:', accuracy_score(L_gt[:, 0], preds[:, 0]))
print('accuracy over LF1:', accuracy_score(L_gt[:, 1], preds[:, 1]))
print('accuracy over LF2:', accuracy_score(L_gt[:, 2], preds[:, 2]))


Network architecture:
Sequential(
  (0): IdentityModule()
  (1): Sequential(
    (0): Linear(in_features=2, out_features=5, bias=True)
    (1): ReLU()
  )
  (2): Sequential(
    (0): Linear(in_features=5, out_features=5, bias=True)
    (1): ReLU()
  )
  (3): Linear(in_features=5, out_features=2, bias=True)
)

Resetting base model parameters
SliceHatModel(
  (body): Sequential(
    (0): IdentityModule()
    (1): Sequential(
      (0): Linear(in_features=2, out_features=5, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=5, out_features=5, bias=True)
      (1): ReLU()
    )
  )
  (L_head): Linear(in_features=5, out_features=3, bias=False)
  (L_criteria): BCEWithLogitsLoss()
)

[1 epo]: train/loss=0.149, train/acc=0.901, valid/accuracy=0.915
[2 epo]: train/loss=0.037, train/acc=0.659, valid/accuracy=0.915
[3 epo]: train/loss=0.009, train/acc=0.659, valid/accuracy=0.915
[4 epo]: train/loss=0.003, train/acc=0.659, valid/accuracy=0.915
[5 epo]: train/los