In [1]:
import random
from time import perf_counter, time

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as sk_metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
import torch
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10

from nam.wrapper import NAMClassifier, MultiTaskNAMClassifier

In [2]:
def make_gender_mtl_data(X, y):
    """Split one target into two per-group task columns for multi-task learning.

    Each output column is a copy of ``y`` with the rows of one ``X['sex']`` group
    blanked out with NaN:
      * column 1 (``y_female`` in the original naming) is NaN where ``sex == 0``;
      * column 2 (``y_male`` in the original naming) is NaN where ``sex == 1``.
    NOTE(review): which numeric code means which sex is not established here.

    Args:
        X: DataFrame with a ``'sex'`` column, row-aligned with ``y``.
        y: pandas target (Series); it is copied, never modified.

    Returns:
        ``pd.concat`` of the two masked copies along ``axis=1``.
    """
    masked_columns = []
    for blanked_code in (0, 1):  # order matters: the "female" column (blanks 0) comes first
        column = y.copy()
        column[X['sex'] == blanked_code] = np.nan
        masked_columns.append(column)
    return pd.concat(masked_columns, axis=1)


def filterClasses(X, y, classes: list):
    """Keep only the samples whose label is in ``classes``.

    Args:
        X: Array of samples, indexed along axis 0 in parallel with ``y``.
        y: Label array. For a 2-D array only the first column ``y[:, 0]`` is
            checked (the original behaviour); a 1-D label vector is also accepted.
        classes: Label values to keep.

    Returns:
        ``(X, y)`` restricted to the selected rows, in their original order.
    """
    labels = y[:, 0] if y.ndim > 1 else y
    # Build the row mask once and apply it to both arrays.
    keep = np.isin(labels, classes)
    return X[keep], y[keep]


def cood_encoding(x, verbose=False):
    """Scale pixels to [0, 1] and append each pixel's (row, col) as two extra channels.

    Args:
        x: Image batch of shape ``(N, H, W, C)`` (e.g. CIFAR-10 data), or
            ``(N, H, W)`` for single-channel images, which is treated as ``C = 1``.
            Values are divided by 255.
        verbose: If True, print the pixel and coordinate array shapes. The
            original always printed them as debug output.

    Returns:
        Float array of shape ``(N, H * W, C + 2)``. For the pixel at flattened
        position ``i * W + j``, the last axis holds ``[x[n, i, j, :] / 255, i, j]``.
        Coordinates are unscaled integer positions stored as floats.
    """
    x = np.asarray(x) / 255
    if x.ndim == 3:
        # Grayscale batch (N, H, W): add a singleton channel axis.
        x = x[..., np.newaxis]
    n, h, w, c = x.shape
    # np.indices((h, w)) has shape (2, h, w): [0] is the row index, [1] the column index.
    # Move that axis last and broadcast over the batch. This replaces the Python
    # double loop and avoids the extra np.repeat copy.
    coords = np.broadcast_to(np.moveaxis(np.indices((h, w)), 0, -1), (n, h, w, 2))
    if verbose:
        print(x.shape, coords.shape)
    return np.concatenate((x, coords), axis=3).reshape(n, h * w, c + 2)

In [3]:
random_state = 2016
DATA_ROOT = 'nam/data/'
CLASSES = [0, 1]  # CIFAR-10 label ids kept for this binary task

dataset = CIFAR10(root=DATA_ROOT, download=True, train=True)
X_data = dataset.data
y_data = np.array(dataset.targets).reshape(-1, 1)

X_data, y_data = filterClasses(X_data, y_data, CLASSES)

# 20% of the filtered data for training; then keep 10% of the remaining 80% as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data, train_size=0.2, test_size=0.8,
    random_state=random_state, stratify=y_data,
)
X_test, _, y_test, _ = train_test_split(
    X_test, y_test, train_size=0.1, test_size=0.9,
    random_state=random_state, stratify=y_test,
)

# Images -> (N, H*W, C+2): scaled pixels plus (row, col) coordinate channels.
X_train = cood_encoding(X_train)
X_test = cood_encoding(X_test)
print(X_train.shape, X_test.shape)

Files already downloaded and verified
(2000, 32, 32, 3) (2000, 32, 32, 2)
(800, 32, 32, 3) (800, 32, 32, 2)
(2000, 1024, 5) (800, 1024, 5)


In [4]:
# Fall back to CPU so Restart & Run All still works on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Hyperparameters as found by an earlier tuning run (values kept verbatim).
model = NAMClassifier(
    num_epochs=40,
    num_learners=1,
    lr=0.007539419328703001,
    batch_size=64,
    metric='auroc',
    early_stop_mode='max',
    monitor_loss=False,
    n_jobs=10,
    random_state=random_state,
    device=device,
    dropout=0.03215945869651085,
    feature_dropout=0.03227589011640504,
    pos_embed=2,
)
print(X_train.shape)
print(y_train.shape)

(2000, 1024, 5)
(2000, 1)


In [None]:
# perf_counter is a monotonic clock intended for measuring durations. time() is
# wall-clock time and can jump if the system clock is adjusted mid-run.
s_time = perf_counter()
model.fit(X_train, y_train)
e_time = perf_counter()
pred, feature, feature_after_att, attn_output_weights = model.predict_proba(X_test)

print(sk_metrics.roc_auc_score(y_test, pred))
# Reported in seconds; covers model.fit only, not prediction.
print("Time cost: " + str(e_time - s_time))

Format converting and model initializing are done.


2022-12-14 18:24:52.358093: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# Recompute the test-set outputs from the current model and report AUROC.
pred, feature, feature_after_att, attn_output_weights = model.predict_proba(X_test)
test_auroc = sk_metrics.roc_auc_score(y_test, pred)
print(test_auroc)

In [None]:
# Undo cood_encoding for display: (N, H*W, C+2) -> (N, H, W, C+2), keeping only the
# pixel channels (the last two are coordinates). The ndim guard makes the cell
# safe to re-run, because X_test is already 4-D after the first run.
if X_test.ndim == 3:
    n_test, n_pixels, n_channels = X_test.shape
    side = int(round(n_pixels ** 0.5))  # assumes square images (32x32 for CIFAR-10)
    X_test = X_test.reshape((n_test, side, side, n_channels))[:, :, :, : n_channels - 2]
print(X_test.shape)
y_test = y_test.squeeze()
pred = pred.squeeze()
print(y_test.shape)
print(pred.shape)
# .cpu() is required before .numpy() if the model returned CUDA tensors (it is a
# no-op on CPU). Skip the conversion on a re-run, when feature is already NumPy.
if hasattr(feature, "detach"):
    feature = feature.detach().cpu().numpy()
feature = feature.squeeze()

In [None]:
# One 32x32 map per test sample; -1 avoids hard-coding the test-set size (800).
feature = feature.reshape(-1, 32, 32)

In [None]:
print(np.shape(feature))

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(feature[0], interpolation='nearest')
ax.set_title('feature[0] (test sample 0)')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# .cpu() is needed before .numpy() if the weights live on a CUDA device (it is a
# no-op on CPU). The guard keeps the cell re-runnable once the weights are NumPy.
if hasattr(attn_output_weights, "detach"):
    attn_output_weights = attn_output_weights.detach().cpu().numpy()
attn_output_weights = attn_output_weights.squeeze()
print(attn_output_weights.shape)

In [None]:
# Each of the 1024 rows is viewed as a 32x32 map; -1 avoids hard-coding the test-set size (800).
attn_output_weights = attn_output_weights.reshape((-1, 1024, 32, 32))
print(attn_output_weights.shape)

In [None]:
fig, ax = plt.subplots()
ax.imshow(X_test[0], interpolation='nearest')
ax.set_title('X_test[0] (test image 0)')
plt.show()

In [None]:
# NOTE(review): with the row-major i*32 + j flattening used by cood_encoding, row 75
# would correspond to pixel (2, 11). Confirm the model preserves that token order.
fig, ax = plt.subplots()
im = ax.imshow(attn_output_weights[0][75], interpolation='nearest')
ax.set_title('attn_output_weights[0][75]')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(attn_output_weights[0][0], interpolation='nearest')
ax.set_title('attn_output_weights[0][0]')
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
np.shape(feature_after_att)

In [None]:
# View as one 32x32x3 image per test sample; -1 avoids hard-coding the test-set size (800).
feature_after_att = feature_after_att.reshape((-1, 32, 32, 3))

In [None]:
# .cpu() is needed before .numpy() for CUDA tensors (a no-op on CPU). The guard
# keeps the cell re-runnable once the array is already NumPy.
if hasattr(feature_after_att, "detach"):
    feature_after_att = feature_after_att.detach().cpu().numpy()

In [None]:
fig, ax = plt.subplots()
ax.imshow(X_test[1], interpolation='nearest')
ax.set_title('X_test[1] (test image 1)')
plt.show()

In [None]:
# NOTE: imshow clips float RGB data to [0, 1], so out-of-range values saturate.
fig, ax = plt.subplots()
ax.imshow(feature_after_att[1], interpolation='nearest')
ax.set_title('feature_after_att[1] (test sample 1)')
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.imshow(X_test[51], interpolation='nearest')
ax.set_title('X_test[51] (test image 51)')
plt.show()

In [None]:
# NOTE: imshow clips float RGB data to [0, 1], so out-of-range values saturate.
fig, ax = plt.subplots()
ax.imshow(feature_after_att[51], interpolation='nearest')
ax.set_title('feature_after_att[51] (test sample 51)')
plt.show()

In [None]:
# Bundle the arrays to export.
# NOTE(review): the name `dict` shadows the builtin type. It is kept because the
# save cell reads this variable by that name; rename both together.
dict = {
    "pred": pred,
    "feature": feature,
    "feature_after_att": feature_after_att,
    "attn_output_weights": attn_output_weights,
    "X": X_test,
    "gt": y_test,
}

In [None]:
print(np.shape(attn_output_weights))

In [None]:
# csv.DictWriter writes str(value) into each field. NumPy's str() summarises any
# array above its print threshold (1000 elements by default) with "...", so the
# previous CIFAR_data.csv silently dropped almost all of the data.
# Save the arrays losslessly instead; reload with np.load('CIFAR_data.npz').
# NOTE: attn_output_weights alone has N*1024*32*32 elements, so expect a large file.
np.savez_compressed('CIFAR_data.npz', **dict)

In [None]:
# Compact overview of the model outputs, (shape, dtype) per array, instead of
# dumping four large array reprs into the notebook output.
{
    name: (np.shape(arr), getattr(arr, "dtype", None))
    for name, arr in (
        ("pred", pred),
        ("feature", feature),
        ("feature_after_att", feature_after_att),
        ("attn_output_weights", attn_output_weights),
    )
}

In [None]:
# One line per model output, in the order they are returned by predict_proba.
for _output in (pred, feature, feature_after_att, attn_output_weights):
    print(_output.shape)