In [2]:
from train import ParameterizedShapleyEstimator
from datasets import load_FashionMNIST, get_split_dataset, extract_features
from cross_attention_model import EstimatorNetwork
import torch
import os
import numpy as np

# cuDNN is disabled as a workaround: with it enabled I was getting
# "GET was unable to find a computation engine".
torch.backends.cudnn.enabled = False
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Download the raw FashionMNIST train and test splits into ./datasets/.
import torchvision

# Create the target directory; exist_ok makes re-running this cell harmless.
os.makedirs("./datasets", exist_ok=True)

# The downloads are deliberately NOT guarded by the directory check: an existing
# but empty/partial ./datasets/ (e.g. after an interrupted download) would
# otherwise never be populated, and train_ds/test_ds would be undefined on every
# run after the first. With download=True torchvision does not re-download a
# dataset that is already present, so running this unconditionally is cheap.
train_ds = torchvision.datasets.FashionMNIST("./datasets/", train=True, download=True)
test_ds = torchvision.datasets.FashionMNIST("./datasets/", train=False, download=True)

In [4]:
# Load the cached (features, labels) pairs for the train/test/valid splits, or
# build them from the raw FashionMNIST files and write the cache.
#
# The cache is only trusted when ALL three files exist. train.npy is written
# first below, so an interrupted save can leave train.npy without the other two;
# checking train.npy alone would then crash in torch.load instead of rebuilding.
#
# NOTE: despite the .npy suffix these files are written and read with
# torch.save / torch.load, not numpy.
cache_files = ("./datasets/train.npy", "./datasets/test.npy", "./datasets/valid.npy")
if all(os.path.exists(cache_file) for cache_file in cache_files):
    print("Loading existing dataset")
    train = torch.load("./datasets/train.npy")
    test = torch.load("./datasets/test.npy")
    valid = torch.load("./datasets/valid.npy")
else:
    # (directory, split-prefix) pairs for the raw files; "t10k" is the test split.
    data_path = [("./datasets/FashionMNIST/raw/", "train"),
                 ("./datasets/FashionMNIST/raw/", "t10k")]
    data = load_FashionMNIST(data_path)
    train, test = data["train"], data["t10k"]
    # Carve the validation split out of the training data.
    train, valid = get_split_dataset(train[0], train[1])

    # Pretrained ResNet-18 from torch.hub, put in eval mode and used only as a
    # fixed feature extractor.
    feature_extractor = torch.hub.load('pytorch/vision:v0.8.2', 'resnet18', pretrained=True).to(device)
    feature_extractor.eval()

    # Replace each split's inputs (element 0) with extracted features and wrap
    # the labels (element 1) in a tensor.
    train = extract_features(feature_extractor, train[0]), torch.Tensor(train[1])
    valid = extract_features(feature_extractor, valid[0]), torch.Tensor(valid[1])
    test = extract_features(feature_extractor, test[0]), torch.Tensor(test[1])

    # Persist the splits so later runs can skip feature extraction.
    torch.save(train, "./datasets/train.npy")
    torch.save(test, "./datasets/test.npy")
    torch.save(valid, "./datasets/valid.npy")

Loading existing dataset


In [7]:
# Sanity check: feature shape of each split (train, test, valid), followed by
# the container type of the training features.
(*(split[0].shape for split in (train, test, valid)), type(train[0]))

(torch.Size([54000, 1000]),
 torch.Size([10000, 1000]),
 torch.Size([6000, 1000]),
 torch.Tensor)

In [4]:
# NOTE(review): this import repeats the one in the first cell; kept so the cell
# still runs on its own.
from train import ParameterizedShapleyEstimator

# Build the explainer network on the target device. input_dim=1000 matches the
# width of the extracted feature vectors.
explainer = EstimatorNetwork(
    input_dim=1000,
    query_dim=64,
    value_dim=64,
).to(device)

# Wrap the explainer in the estimator that drives training.
estimator = ParameterizedShapleyEstimator(
    alpha=16,
    beta=1,
    explainer=explainer,
    dataset_name="fmnist",
    normalization=None,
)

# Train on the three (features, labels) splits.
# NOTE(review): lookback is presumably the early-stopping patience - confirm in train.py.
estimator.train(
    train,
    valid,
    test,
    max_epochs=100,
    lookback=10,
    K=5,
    verbose=True,
)

# Save the weights from the CPU copy of the model, then move the model back to
# the working device. The trailing expression also displays the module.
torch.save(explainer.cpu().state_dict(), 'explainer.pt')
explainer.to(device)

----- Epoch = 1 -----
Val loss = 96.903641

New best epoch: 0, loss = 96.90364077687263

----- Epoch = 2 -----
Val loss = 550.457507

----- Epoch = 3 -----
Val loss = 232.929021

----- Epoch = 4 -----
Val loss = 203.591757

----- Epoch = 5 -----
Val loss = 1727.032006

----- Epoch = 6 -----
Val loss = 31.433713

New best epoch: 5, loss = 31.43371269106865

----- Epoch = 7 -----
Val loss = 305.281274

----- Epoch = 8 -----
Val loss = 252.471037

----- Epoch = 9 -----
Val loss = 524.364129

----- Epoch = 10 -----
Val loss = 177.049264

----- Epoch = 11 -----
Val loss = 69.129014

----- Epoch = 12 -----
Val loss = 1529.137373

Epoch    12: reducing learning rate of group 0 to 1.0000e-04.
----- Epoch = 13 -----
Val loss = 561.532617

----- Epoch = 14 -----
Val loss = 382.282600

----- Epoch = 15 -----
Val loss = 820.807368

----- Epoch = 16 -----
Val loss = 601.900533

Stopping early at epoch = 15


EstimatorNetwork(
  (Wq): Linear(in_features=1000, out_features=64, bias=True)
  (Wk): Linear(in_features=1000, out_features=64, bias=True)
  (Wv): Linear(in_features=1000, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=True)
)

In [10]:
# Run the trained explainer on the training features with the test features as
# the second input.
# - The module is called directly rather than via .forward(), so that
#   nn.Module.__call__ (and any registered hooks) is used.
# - torch.no_grad() skips building the autograd graph: this is inference only,
#   and tracking gradients here just costs GPU memory.
with torch.no_grad():
    est = estimator.explainer(train[0].to(device), test[0].to(device))

In [14]:
# Summary of the estimates: largest value, smallest value, and the total
# divided by the length of the first dimension.
est.max(), est.min(), est.sum() / len(est)

(tensor(0.3320, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(0.3153, device='cuda:0', grad_fn=<MinBackward1>),
 tensor(0.3239, device='cuda:0', grad_fn=<DivBackward0>))

In [16]:
# Run the trained explainer on the training features with the first 1000
# training rows as the second input.
# - The module is called directly rather than via .forward(), so that
#   nn.Module.__call__ (and any registered hooks) is used.
# - torch.no_grad() skips building the autograd graph: this is inference only,
#   and tracking gradients here just costs GPU memory.
with torch.no_grad():
    est = estimator.explainer(train[0].to(device), train[0][:1000].to(device))

In [18]:
# Summary of the estimates: largest value, smallest value, and the total
# divided by the length of the first dimension.
est.max(), est.min(), est.sum() / len(est)

(tensor(0.3382, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(0.3182, device='cuda:0', grad_fn=<MinBackward1>),
 tensor(0.3287, device='cuda:0', grad_fn=<DivBackward0>))