In [1]:
import argparse
import copy
import joblib
import numpy as np
import os

import logging
import torch
import scipy.special
from pathlib import Path
from pytorch_lightning import seed_everything

import cem.data.CUB200.cub_loader as cub_data_module
import cem.data.celeba_loader as celeba_data_module

import cem.train.training as training
import cem.train.utils as utils
import cem.interventions.utils as intervention_utils
from cem.interventions.random import IndependentRandomMaskIntPolicy
from cem.interventions.uncertainty import UncertaintyMaximizerPolicy
from cem.interventions.coop import CooPEntropy, CooP,CompetenceCooPEntropy
from cem.interventions.optimal import GreedyOptimal, TrueOptimal
from cem.interventions.global_policies import GlobalValidationPolicy
from cem.interventions.arbitrary_conditionals import (
    LearntExpectedInfoGainPolicy,
    ExpectedLossImprovement,
    NewExpectedLossImprovement,
)
from experiments.run_experiments import CUB_CONFIG, CELEBA_CONFIG
import cem.data.CUB200.cub_loader as data_module


In [2]:
# Work on a private deep copy so the imported CUB_CONFIG template is never
# mutated by the overrides applied in this notebook.
og_config = copy.deepcopy(CUB_CONFIG)
num_workers = 6
og_config['num_workers'] = num_workers

# Build the data loaders; output_dataset_vars=True additionally returns the
# (n_concepts, n_tasks, concept_map) triple unpacked below.
(
    train_dl,
    val_dl,
    test_dl,
    imbalance,
    (n_concepts, n_tasks, concept_map),
) = data_module.generate_data(
    config=og_config,
    seed=42,
    output_dataset_vars=True,
    root_dir="/homes/me466/UncertaintyIntervention/cem/data/CUB200/",
)

# For now, we assume that all concepts have the same acquisition cost.
acquisition_costs = None

# Values 0..N (inclusive) in steps of `intervention_freq`, where N is the
# number of entries in `concept_map` when one is provided and `n_concepts`
# otherwise.
n_intervention_units = (
    len(concept_map) if concept_map is not None else n_concepts
)
intervened_groups = list(
    range(
        0,
        n_intervention_units + 1,
        og_config.get('intervention_freq', 1),
    )
)

# Take one training batch and flatten it by a single level: elements that are
# lists are spliced in, every other element is kept as-is.
sample = next(iter(train_dl))
real_sample = []
for element in sample:
    real_sample.extend(element if isinstance(element, list) else [element])
sample = real_sample

Global seed set to 42
Global seed set to 42


In [3]:
# Zero-based fold index; the run name on disk uses the one-based fold number.
split = 0
name = f"IntAwareConceptEmbeddingModelintervention_weight_0.1_horizon_rate_1.01_intervention_discount_0.99_average_trajectory_True_resnet34_fold_{split + 1}"

# Load the experiment configuration that was saved next to the results of
# this run.
config_path = f"results/cub_interventions/{name}_experiment_config.joblib"
config = joblib.load(config_path)

# Arguments for rebuilding the trained model from `result_dir`; gathered in
# one place so they can be inspected before the call.
load_kwargs = dict(
    config=config,
    n_tasks=200,
    n_concepts=112,
    result_dir="results/cub_interventions/",
    split=split,
    imbalance=imbalance,
    intervene=True,
    train_dl=train_dl,
    output_latent=True,
    output_interventions=True,
)
model = training.load_trained_model(**load_kwargs)

In [4]:
# Materialise the whole test split as three tensors by stacking every batch
# yielded by `test_dl` along the first dimension. Each batch is unpacked as
# an (x, y, c) triple.
x_batches, y_batches, c_batches = [], [], []
for x_batch, y_batch, c_batch in test_dl:
    x_batches.append(x_batch)
    y_batches.append(y_batch)
    c_batches.append(c_batch)

# Fail with a clear message instead of an opaque concatenation error when the
# loader is empty.
if not x_batches:
    raise ValueError("test_dl yielded no batches; cannot build the test tensors")

# torch.cat keeps each tensor's own dtype. The previous
# NumPy -> torch.FloatTensor -> .type(...) round trip pushed integer tensors
# through float32 (lossy above 2**24) and made extra copies.
x_test = torch.cat(x_batches, dim=0)
y_test = torch.cat(y_batches, dim=0)
c_test = torch.cat(c_batches, dim=0)

# Tensor type strings; this cell has always defined these names, so they are
# kept in case a later cell reads them.
x_type, y_type, c_type = x_test.type(), y_test.type(), c_test.type()

In [15]:
# Draw one random ordering of the test rows from NumPy's global RNG and apply
# it to all three tensors so that rows stay aligned with each other.
# NOTE: this overwrites x_test / y_test / c_test, so re-running the cell
# shuffles the already-shuffled tensors again.
indices = np.random.permutation(x_test.shape[0])
x_test, y_test, c_test = (
    tensor[indices] for tensor in (x_test, y_test, c_test)
)

In [26]:
# Query the model's concept-intervention distribution for the first
# `selected` test samples and tally, per sample, which index gets the
# highest score.
selected = 2   # number of leading test samples to query
n_trials = 1   # how many times the query is repeated

# Slices of existing tensors; .float() replaces the legacy
# torch.FloatTensor(...) wrapper around an already-built tensor.
x_batch = x_test[:selected].float()
c_batch = c_test[:selected].float()

# One row per queried sample, one column per entry in the last dimension of
# the concept tensor (previously hard-coded as (2, 112)).
# NOTE(review): assumes the returned scores have at most this many columns;
# confirm against the model.
counts = np.zeros((selected, c_batch.shape[-1]))

for trial in range(n_trials):
    scores = model.get_concept_int_distribution(
        x=x_batch,
        c=c_batch,
        # All-zero mask with the same shape as the concept slice, rebuilt on
        # every trial.
        prev_interventions=torch.zeros(c_batch.shape, dtype=torch.float32),
        competencies=None,
        horizon=1,
    ).detach().cpu().numpy()
    print("scores =", scores)
    best = np.argmax(scores, axis=-1)
    print("for loop", trial, "we got", best, "with", np.max(scores, axis=-1))
    print("\tmin is", np.min(scores, axis=-1))
    # Row indices are unique, so fancy-index "+=" adds exactly one per sample.
    # Sizing by len(best) avoids the previous hard-coded range(2).
    counts[np.arange(best.shape[0]), best] += 1

scores = [[0.01115472 0.01045404 0.0103881  0.00747091 0.01029231 0.00649189
  0.01020726 0.00888513 0.00748324 0.01167441 0.00841603 0.00957866
  0.01080647 0.00713251 0.00716626 0.00924519 0.01179591 0.00879309
  0.00799654 0.00956503 0.00887179 0.00744504 0.01028306 0.00787733
  0.00701075 0.00718899 0.00943927 0.00926382 0.00918716 0.00831624
  0.00938605 0.00857863 0.00738902 0.00842408 0.01023571 0.00822042
  0.00922035 0.01274194 0.00964859 0.00815966 0.00853932 0.01353795
  0.00770008 0.01047492 0.00785212 0.00814948 0.01142019 0.01206741
  0.00834754 0.00545466 0.00857855 0.01118313 0.00914041 0.00866138
  0.00718079 0.00788127 0.00806504 0.00836623 0.01308645 0.01135049
  0.01053949 0.00827182 0.00551433 0.00841097 0.0085021  0.00919309
  0.00920223 0.00720836 0.00758835 0.00831876 0.01164926 0.00884562
  0.00844856 0.00693591 0.00882869 0.00915544 0.00811886 0.00815915
  0.00809512 0.00794629 0.00995239 0.00834143 0.01057536 0.00934481
  0.00699568 0.01075509 0.00767062 0.00

In [6]:
# Display the tally array: counts[i, j] was incremented each time index j was
# the argmax of the scores for queried sample i.
counts

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,