In [1]:
import os
# NOTE(review): hardcoded absolute path to one machine's checkout; the notebook
# will not run elsewhere without editing this line. The chdir presumably has to
# happen before the `config.*` imports below so that they resolve from this
# directory, and the relative "analysis/state_dicts/..." path used later is
# resolved against it as well -- confirm before moving or removing this call.
os.chdir("/Users/timhenry/Documents/mit/meng/src")
import numpy as np
import torch
import torch.nn.functional as F

# Project-local registries: `config.input.options` and `config.model.options`
# are indexed by task name and model name respectively further down.
import config.input
import config.model

In [2]:
# Alternative configuration, kept for reference (swap in to evaluate the
# varied-location task instead of the colored one):
# task = "left_out_varied_location_mnist"
# args = {
#     "test_batch_size": 1000,
#     "no_cuda": False,
#     "keep_pcts": [i / 9 for i in range(1, 10)],
#     "color_indices": np.arange(9),
# }

# Task name: the key into config.input.options and a component of the
# state-dict directory path.
task = "left_out_colored_mnist"

# Options passed to the test-loader factory and to test().
args = {
    "test_batch_size": 1000,
    "no_cuda": False,
    "keep_pcts": [step / 10 for step in range(1, 11)],
    "color_indices": np.arange(10),
}

# Argument for the model constructor.
num_classes = 10
# Key into config.model.options and a component of the state-dict directory path.
model_name = "resnet"
# Selects which saved state dict ("<keep_pct>.pt") is evaluated and keys the results.
keep_pct = 0.9


In [3]:
def test(args, model, device, test_loader, held_out, control):
    label_1_name = test_loader.dataset.class_names[0].capitalize()
    label_2_name = test_loader.dataset.class_names[1].capitalize()

    model.eval()
    num_loss = 0
    col_loss = 0
    num_correct_count = 0
    col_correct_count = 0
    correct_count = 0

    left_out_num_correct_count = 0
    left_out_col_correct_count = 0
    left_out_correct_count = 0
    left_out_count = 0

    non_left_out_num_correct_count = 0
    non_left_out_col_correct_count = 0
    non_left_out_correct_count = 0
    non_left_out_count = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            num_target, col_target = target[:, 0], target[:, 1]

            num_output, col_output = model(data)
            num_loss += F.nll_loss(num_output, num_target, reduction='sum').item()
            col_loss += F.nll_loss(col_output, col_target, reduction='sum').item()

            # Calculate accuracy
            # get the index of the max log-probability
            pred = torch.cat((num_output.argmax(dim=1, keepdim=True), col_output.argmax(dim=1, keepdim=True)), 1)
            num_correct, col_correct = pred.eq(target.view_as(pred))[:, 0], pred.eq(target.view_as(pred))[:, 1]
            correct = num_correct * col_correct  # both must be correct

            num_correct_count += num_correct.sum().item()
            col_correct_count += col_correct.sum().item()
            correct_count += correct.sum().item()

            # Calculate left-out accuracy
            mask = np.zeros(num_target.size())
            for pair in held_out:
                diff_array = np.absolute(target.cpu().numpy() - np.array(pair))
                mask = np.logical_or(mask, diff_array.sum(axis=1) == 0)

            mask = torch.Tensor(mask.astype("uint8")).byte().to(device)

            left_out_num_correct = num_correct * mask
            left_out_col_correct = col_correct * mask
            left_out_correct = left_out_num_correct * left_out_col_correct

            left_out_num_correct_count += left_out_num_correct.sum().item()
            left_out_col_correct_count += left_out_col_correct.sum().item()
            left_out_correct_count += left_out_correct.sum().item()
            left_out_count += mask.sum().item()

            # Calculate non_left-out accuracy
            mask = np.zeros(num_target.size())
            for pair in control:
                diff_array = np.absolute(target.cpu().numpy() - np.array(pair))
                mask = np.logical_or(mask, diff_array.sum(axis=1) == 0)

            mask = torch.Tensor(mask.astype("uint8")).byte().to(device)

            non_left_out_num_correct = num_correct * mask
            non_left_out_col_correct = col_correct * mask
            non_left_out_correct = non_left_out_num_correct * non_left_out_col_correct

            non_left_out_num_correct_count += non_left_out_num_correct.sum().item()
            non_left_out_col_correct_count += non_left_out_col_correct.sum().item()
            non_left_out_correct_count += non_left_out_correct.sum().item()
            non_left_out_count += mask.sum().item()

    total_loss = num_loss + col_loss

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          '({} Accuracy: {}/{} ({:.0f}%), {} Accuracy: {}/{} ({:.0f}%))\n'.format(
        total_loss,
        correct_count, len(test_loader.dataset), 100. * correct_count / len(test_loader.dataset),
        label_1_name, num_correct_count, len(test_loader.dataset), 100. * num_correct_count / len(test_loader.dataset),
        label_2_name, col_correct_count, len(test_loader.dataset), 100. * col_correct_count / len(test_loader.dataset)
    ))

    left_out_acc = None
    if left_out_count > 0:
        print('Left-Out Accuracy: {}/{} ({:.0f}%)\n'
              '(Left-Out {} Accuracy: {}/{} ({:.0f}%), Left-Out {} Accuracy: {}/{} ({:.0f}%))\n'.format(
            left_out_correct_count, left_out_count, 100. * left_out_correct_count / left_out_count,
            label_1_name, left_out_num_correct_count, left_out_count, 100. * left_out_num_correct_count / left_out_count,
            label_2_name, left_out_col_correct_count, left_out_count, 100. * left_out_col_correct_count / left_out_count
        ))
        left_out_acc = left_out_correct_count / left_out_count

    non_left_out_acc = None
    if non_left_out_count > 0:
        print('non_left-Out Accuracy: {}/{} ({:.0f}%)\n'
              '(non_left-Out {} Accuracy: {}/{} ({:.0f}%), non_left-Out {} Accuracy: {}/{} ({:.0f}%))\n'.format(
            non_left_out_correct_count, non_left_out_count, 100. * non_left_out_correct_count / non_left_out_count,
            label_1_name, non_left_out_num_correct_count, non_left_out_count,
                                                            100. * non_left_out_num_correct_count / non_left_out_count,
            label_2_name, non_left_out_col_correct_count, non_left_out_count,
                                                            100. * non_left_out_col_correct_count / non_left_out_count
        ))
        non_left_out_acc = non_left_out_correct_count / non_left_out_count

    return {
        "class_1_name": test_loader.dataset.class_names[0],
        "class_2_name": test_loader.dataset.class_names[1],
        "num_acc": num_correct_count / len(test_loader.dataset),
        "col_acc": col_correct_count / len(test_loader.dataset),
        "acc": correct_count / len(test_loader.dataset),
        "left_out_num_acc": left_out_num_correct_count / left_out_count if left_out_count != 0 else None,
        "left_out_col_acc": left_out_col_correct_count / left_out_count if left_out_count != 0 else None,
        "left_out_acc": left_out_acc if left_out_count != 0 else None,
        "non_left_out_num_acc": non_left_out_num_correct_count / non_left_out_count if non_left_out_count != 0 else None,
        "non_left_out_col_acc": non_left_out_col_correct_count / non_left_out_count if non_left_out_count != 0 else None,
        "non_left_out_acc": non_left_out_acc if non_left_out_count != 0 else None,
        "num_loss": num_loss,
        "col_loss": col_loss,
        "loss": total_loss
    }


In [4]:
# Rebuild the architecture registered under `model_name`, then restore the
# weights saved for the selected keep percentage. The checkpoint is mapped onto
# the CPU regardless of where it was saved.
state_dict_directory = "analysis/state_dicts/{}/{}/".format(task, model_name)
state_dict_path = "{}{}.pt".format(state_dict_directory, keep_pct)

model = config.model.options[model_name](num_classes)
state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [5]:
# Use CUDA only when it has not been disabled in `args` and is actually
# available; the decision is stored back into `args` under 'use_cuda'.
cuda_requested = not args['no_cuda']
args['use_cuda'] = cuda_requested and torch.cuda.is_available()
print("use_cuda? ", args['use_cuda'])

device_name = "cuda" if args['use_cuda'] else "cpu"
device = torch.device(device_name)

use_cuda?  False


In [6]:
# Build the test-set loader registered for this task; the first entry of the
# registered pair is not needed here.
_, test_loader_fn = config.input.options[task]
test_loader = test_loader_fn(args)

# test() moves every batch to `device` but leaves the model where it is, and
# the checkpoint was loaded with map_location=cpu. Move the model to the same
# device first, otherwise evaluation fails with a device mismatch whenever
# CUDA is selected.
model.to(device)

# Results keyed by keep percentage; each value is a list holding the metrics
# dict returned by test().
test_results = {}
test_results[keep_pct] = [test(
    args, model, device, test_loader, test_loader.dataset.held_out, test_loader.dataset.control
)]




Test set: Average loss: 1289.6602, Accuracy: 9652/10000 (97%)
(Shape Accuracy: 9725/10000 (97%), Color Accuracy: 9925/10000 (99%))

Left-Out Accuracy: 860/1032 (83%)
(Left-Out Shape Accuracy: 898/1032 (87%), Left-Out Color Accuracy: 992/1032 (96%))

non_left-Out Accuracy: 981/1003 (98%)
(non_left-Out Shape Accuracy: 984/1003 (98%), non_left-Out Color Accuracy: 1000/1003 (100%))

