In [2]:
import torch
import sys
import os
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from umap.umap_ import find_ab_params
import time

from singleVis.SingleVisualizationModel import SingleVisualizationModel
from singleVis.losses import SingleVisLoss, UmapLoss, ReconstructionLoss
from singleVis.edge_dataset import DataHandler
from singleVis.trainer import SingleVisTrainer
from singleVis.data import DataProvider
from singleVis.visualizer import visualizer

from singleVis.backend import fuzzy_complex, boundary_wise_complex, construct_step_edge_dataset, \
    construct_temporal_edge_dataset, get_attention, construct_temporal_edge_dataset2
import singleVis.config as config

In [3]:
DATASET = "cifar10"
# Root of the DVI experiment data. It can be overridden with the DVI_DATA_ROOT environment
# variable so the notebook runs on other machines; the author's original path is the default.
DATA_ROOT = os.environ.get("DVI_DATA_ROOT", "/home/xianglin/projects/DVI_data/TemporalExp")
CONTENT_PATH = os.path.join(DATA_ROOT, "resnet18_{}".format(DATASET))

In [4]:

# Dataset-specific settings from singleVis.config.
LEN = config.dataset_config[DATASET]["TRAINING_LEN"]
LAMBDA = config.dataset_config[DATASET]["LAMBDA"]

# Training hyperparameters from singleVis.config.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPOCH_NUMS = config.training_config["EPOCH_NUM"]
TIME_STEPS = config.training_config["TIME_STEPS"]
TEMPORAL_PERSISTENT = config.training_config["TEMPORAL_PERSISTENT"]
NUMS = config.training_config["NUMS"]    # how many epoch should we go through for one pass
PATIENT = config.training_config["PATIENT"]

content_path = CONTENT_PATH
# Make the experiment's `Model` package importable. The guard keeps re-runs of this cell from
# appending duplicate entries to sys.path.
if content_path not in sys.path:
    sys.path.append(content_path)

# NOTE(review): star import kept because other cells may rely on further names from
# Model.model; only `resnet18` is used in this cell.
from Model.model import *
net = resnet18()
classes = ("airplane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

data_provider = DataProvider(content_path, net, 1, TIME_STEPS, 1, split=-1, device=DEVICE, verbose=1)

# Visualization model mapping 512-d inputs to a 2-D embedding.
model = SingleVisualizationModel(input_dims=512, output_dims=2, units=256)
negative_sample_rate = 5
min_dist = .1
# find_ab_params(spread, min_dist) fits the a, b parameters of UMAP's low-dimensional
# similarity curve.
_a, _b = find_ab_params(1.0, min_dist)
umap_loss_fn = UmapLoss(negative_sample_rate, _a, _b, repulsion_strength=1.0)
recon_loss_fn = ReconstructionLoss(beta=1.0)
criterion = SingleVisLoss(umap_loss_fn, recon_loss_fn, lambd=LAMBDA)
optimizer = torch.optim.Adam(model.parameters(), lr=.01, weight_decay=1e-5)

# No edge loader is passed: in this notebook the trainer is only used to load a pretrained
# visualization model and access `trainer.model`.
trainer = SingleVisTrainer(model, criterion, optimizer, edge_loader=None, DEVICE=DEVICE)
trainer.load(file_path=os.path.join(data_provider.model_path, "temporal_SV.pth"))

Successfully load visualization mdoel


In [5]:
# Training-set representations at checkpoints 6 and 7 (loaded in that order).
prev_data, curr_data = (data_provider.train_representation(step) for step in (6, 7))

In [6]:
# Per-sample Euclidean (L2) distance between the two checkpoints' representations.
dists = np.linalg.norm(prev_data - curr_data, ord=2, axis=1)
dists.shape

(50000,)

In [7]:
# Embed both checkpoints' representations with the trained encoder, then measure how far each
# sample moves in the 2-D embedding space.
trainer.model.eval()

# torch.as_tensor accepts both the original numpy arrays and the tensors left over from a
# previous run of this cell, so re-executing the cell no longer fails (torch.from_numpy would).
prev_data = torch.as_tensor(prev_data, dtype=torch.float, device=data_provider.DEVICE)
curr_data = torch.as_tensor(curr_data, dtype=torch.float, device=data_provider.DEVICE)

# Inference only: no_grad avoids building an autograd graph over the whole training set.
with torch.no_grad():
    prev_embedding = trainer.model.encoder(prev_data).cpu().numpy()
    curr_embedding = trainer.model.encoder(curr_data).cpu().numpy()

embedding_dists = np.linalg.norm(prev_embedding - curr_embedding, axis=1)

In [8]:
# Pearson correlation between per-sample movement in representation space and in embedding
# space; returns the correlation coefficient and its p-value.
# `scipy.stats.stats` is a deprecated private module; import from the public namespace instead.
from scipy.stats import pearsonr

corr = pearsonr(dists, embedding_dists)
corr

(0.4304724884634175, 0.0)

In [9]:
def is_B(preds, threshold):
    r"""
    Given N points' predictions (N, class_num), evaluate whether each is a \delta-boundary point.

    Please check the formal definition of \delta-boundary from our paper DVI.
    A point is flagged when the gap between its two largest scores, normalised by the spread
    between its largest and smallest score, is below ``threshold``.

    :param preds: ndarray, (N, class_num), the output of model prediction before softmax layer
    :param threshold: float, points whose normalised top-2 gap is < threshold are boundary points
    :return: ndarray, (N,) of bool, True for boundary points
    """
    sort_preds = np.sort(preds)  # ascending along the last (class) axis
    top_gap = sort_preds[:, -1] - sort_preds[:, -2]
    spread = sort_preds[:, -1] - sort_preds[:, 0]
    # The epsilon must guard the denominator. Adding it to `preds` cancels out in both differences,
    # leaving rows with identical scores as 0/0 = NaN, which never compares below the threshold.
    diff = top_gap / (spread + 1e-8)

    # A bool-dtype mask directly; np.bool (removed in NumPy 1.24) is no longer needed.
    return diff < threshold

In [11]:
# get_pred is given a numpy array here (presumably what it expects). Convert only when
# `curr_data` is still a tensor, so re-running this cell does not crash on `.cpu()` of an
# ndarray.
if torch.is_tensor(curr_data):
    curr_data = curr_data.detach().cpu().numpy()
pred = data_provider.get_pred(7, curr_data)
preds = np.argmax(pred, axis=1)  # predicted class index per sample
labels = data_provider.train_labels(7)

100%|██████████| 250/250 [00:00<00:00, 12286.61it/s]
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  


In [15]:
# Flag delta-boundary samples (threshold 0.2), then select samples predicted as "cat" whose
# training label is "dog". Class indices are looked up by name instead of hard-coding 3 and 5.
isB = is_B(pred, 0.2)
l = np.logical_and(preds == classes.index("cat"), labels == classes.index("dog"))
np.sum(l)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  


29

In [14]:
# Number of samples selected by mask `l` that are also delta-boundary points.
isB[l].sum()

9