In [1]:
import os
import parse
import pickle
import copy
import math
import argparse

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import dionysus as dion
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
from sklearn import svm
from sklearn.model_selection import cross_val_score
from sklearn.decomposition import PCA
from sklearn import manifold
from collections import OrderedDict
import sklearn
import networkx as nx
import seaborn as sns

from pt_activation.models.fff import FFF as FFFRelu
from pt_activation.models.linear import FFF
from pt_activation.models.simple_mnist import CFF as CFFRelu
from pt_activation.models.simple_mnist_sigmoid import CFF as CFFSigmoid
from pt_activation.models.ccff import CCFF as CCFFRelu

%load_ext autoreload
%autoreload 2

In [None]:
def get_adv_info(filename):
    """Extract the class labels and sample index encoded in an adversary file name.

    The name must follow the pattern ``true-{}_adv-{}_sample-{}.npy``; the three
    captured fields are converted to ``int``.

    Parameters
    ----------
    filename : str
        Bare file name (no directory part) of a saved adversary.

    Returns
    -------
    dict
        ``{'true class': int, 'adv class': int, 'sample': int}``.

    Raises
    ------
    ValueError
        If `filename` does not match the pattern, or a captured field is not
        an integer literal.
    """
    format_string = 'true-{}_adv-{}_sample-{}.npy'
    parsed = parse.parse(format_string, filename)
    if parsed is None:
        # parse.parse reports "no match" by returning None; indexing that would
        # only produce an opaque TypeError, so name the offending file instead.
        raise ValueError(
            'adversary file name {!r} does not match pattern {!r}'.format(filename, format_string)
        )
    return {'true class':int(parsed[0]), 'adv class':int(parsed[1]), 'sample':int(parsed[2])}

def read_adversaries(loc):
    """Load every ``.npy`` adversary stored directly inside directory `loc`.

    Parameters
    ----------
    loc : str
        Directory to scan (not recursive).

    Returns
    -------
    list of dict
        One entry per file, ordered by file name: the metadata returned by
        ``get_adv_info`` for the file name, plus the loaded array under the
        key ``'adversary'``.
    """
    ret = []
    # Sorted so the result order does not depend on the file system's listing order.
    for f in sorted(os.listdir(loc)):
        path = os.path.join(loc, f)
        # Require the real extension: a substring test would also accept
        # names such as 'x.npy.bak' and then fail inside np.load.
        if not (f.endswith('.npy') and os.path.isfile(path)):
            continue
        info = get_adv_info(f)
        info['adversary'] = np.load(path)
        ret.append(info)
    return ret


def create_diagram(f):
    """Compute persistence for filtration `f` and return the first diagram.

    Runs Dionysus' ``homology_persistence`` on `f`, builds the diagrams with
    ``init_diagrams`` and hands back only the element at index 0 (Dionysus
    orders diagrams by homology dimension, so this is the dimension-0 one).
    """
    return dion.init_diagrams(dion.homology_persistence(f), f)[0]


def create_lifetimes(dgms):
    """Return, per diagram, ``birth - death`` for every point with finite death.

    Parameters
    ----------
    dgms : iterable of iterables
        Each inner iterable yields points exposing ``birth`` and ``death``.

    Returns
    -------
    list of list
        One list per diagram, in input order; points whose death is not
        strictly below infinity are omitted.
    """
    lifetimes = []
    for diagram in dgms:
        finite = []
        for point in diagram:
            # Points that never die carry death == inf and are skipped.
            if point.death < np.inf:
                finite.append(point.birth - point.death)
        lifetimes.append(finite)
    return lifetimes

def get_example_images(test_loader):
    """Pick one example batch per class label from `test_loader`.

    Only the first label of each batch is inspected; the first batch seen for
    a label is stored (as a numpy array) and later batches with the same
    first label are ignored.

    Returns
    -------
    dict
        Maps the first label of a batch to that batch's data as a numpy array.
    """
    examples = {}
    for batch, labels in test_loader:
        first_label = labels.numpy()[0]
        if first_label in examples:
            # Already have an example for this class.
            continue
        examples[first_label] = batch.numpy()
    return examples

def create_diagrams(model, batch_size, up_to, test_loader, filtration=True):
    """Run `model` over `test_loader`, recording per-sample results and diagrams.

    Parameters
    ----------
    model
        Called as ``model(data, hiddens=True)`` and expected to return
        ``(output, hiddens)``; must also provide ``eval()`` and
        ``compute_dynamic_filtration(sample, sample_hiddens)``.
        ``output`` is passed to ``F.nll_loss``, so it is presumably
        log-probabilities -- confirm against the model definition.
    batch_size : int
        Batch size of the fallback loader; only used when `test_loader` is None.
    up_to : int
        Stop after the first batch that brings the number of processed
        samples to at least this value.
    test_loader : iterable of ``(data, target)`` batches, or None
        Batches to evaluate. When None, the FashionMNIST test split under
        ``../../data/fashion`` is loaded (and downloaded if missing).
    filtration : bool, default True
        When False, no filtrations/diagrams are computed and the returned
        diagram list is empty.

    Returns
    -------
    (pandas.DataFrame, list, list)
        * one row per sample with columns ``loss`` (mean NLL of the sample's
          batch), ``class`` and ``prediction``;
        * a list of subgraphs -- always empty here, kept so the return value
          has the same three-element shape as ``create_adversary_diagrams``;
        * the persistence diagram of each sample, in processing order.
    """
    device = torch.device("cpu")
    if test_loader is None:
        # Fallback data source; previously this loader unconditionally replaced
        # the caller's `test_loader` and relied on an undefined `kwargs`.
        test_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST('../../data/fashion', train=False, download=True, transform=transforms.Compose([
                               transforms.ToTensor(),
                           ])), batch_size=batch_size, shuffle=False)
    model.eval()
    t = 0  # number of samples processed in earlier batches
    res_df = []
    subgraphs = []
    diagrams = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, hiddens = model(data, hiddens=True)
            test_loss = F.nll_loss(output, target).item() # mean loss over this batch
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            # Convert once per batch instead of once per sample.
            target_np = target.cpu().numpy()
            pred_np = pred.cpu().numpy()
            for s in range(data.shape[0]):
                this_hiddens = [hiddens[i][s] for i in range(len(hiddens))]
                print('Filtration: {}'.format(s+t))
                if filtration:
                    f = model.compute_dynamic_filtration(data[s], this_hiddens)
                    dg = create_diagram(f)
                    diagrams.append(dg)
                row = {'loss':test_loss, 'class':target_np[s], 'prediction':pred_np[s][0]}
                res_df.append(row)

            # Advance by the real batch length so `up_to` counts samples
            # regardless of the loader's batch size.
            t += data.shape[0]
            if t >= up_to:
                break

    return pd.DataFrame(res_df), subgraphs, diagrams


def create_adversary_diagrams(model, batch_size, up_to, adversaries, filtration=True):
    """Run `model` over saved adversaries, recording per-sample results and diagrams.

    Parameters
    ----------
    model
        Called as ``model(data, hiddens=True)`` and expected to return
        ``(output, hiddens)``; must also provide ``eval()`` and
        ``compute_dynamic_filtration(sample, sample_hiddens)``.
        ``output`` is passed to ``F.nll_loss``, so it is presumably
        log-probabilities -- confirm against the model definition.
    batch_size : int
        Batch size used to iterate over the adversaries.
    up_to : int
        Stop after the first batch that brings the number of processed
        samples to at least this value.
    adversaries : list of dict
        Each dict needs the keys ``'adversary'`` (array), ``'true class'``
        (int) and ``'sample'`` (identifier copied into the result rows).
    filtration : bool, default True
        When False, no filtrations, subgraphs or diagrams are computed.

    Returns
    -------
    (pandas.DataFrame, list, list)
        * one row per adversary with columns ``loss`` (mean NLL of the
          sample's batch), ``class``, ``prediction`` and ``sample``;
        * the subgraph of each sample;
        * the persistence diagram of each sample.
    """
    device = torch.device("cpu")

    adv_images = torch.tensor(np.array([a['adversary'] for a in adversaries]))
    # nll_loss needs int64 targets; do not depend on the platform's default int width.
    adv_labels = torch.tensor(np.array([a['true class'] for a in adversaries]), dtype=torch.long)
    adv_samples = [a['sample'] for a in adversaries]

    print(adv_images.shape, adv_labels.shape)

    advs = torch.utils.data.TensorDataset(adv_images, adv_labels)
    test_loader = torch.utils.data.DataLoader(advs, batch_size=batch_size, shuffle=False)

    model.eval()
    t = 0  # number of samples processed in earlier batches
    res_df = []
    subgraphs = []
    diagrams = []
    with torch.no_grad():

        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, hiddens = model(data, hiddens=True)
            test_loss = F.nll_loss(output, target).item() # mean loss over this batch
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            # Convert once per batch instead of once per sample.
            target_np = target.cpu().numpy()
            pred_np = pred.cpu().numpy()
            for s in range(data.shape[0]):
                this_hiddens = [hiddens[i][s] for i in range(len(hiddens))]
                print('Filtration: {}'.format(s+t))
                if filtration:
                    f = model.compute_dynamic_filtration(data[s], this_hiddens)
                    # NOTE(review): create_sample_graph is not defined in this part of
                    # the notebook; it must be in scope and return (subgraph, diagram).
                    sg, dg = create_sample_graph(f)
                    subgraphs.append(sg)
                    diagrams.append(dg)
                # The loader is not shuffled, so t + s is this sample's position in
                # `adversaries`; using t alone labelled a whole batch with its first id.
                row = {'loss':test_loss, 'class':target_np[s], 'prediction':pred_np[s][0], 'sample':adv_samples[t + s]}
                res_df.append(row)

            # Advance by the real batch length (the last batch may be shorter).
            t += data.shape[0]
            if t >= up_to:
                break

    return pd.DataFrame(res_df), subgraphs, diagrams

In [None]:
# Directory of saved adversaries (the path names the carliniwagnerl2 attack and cff_relu.pt).
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
adv_directory_loc = '/home/tgebhart/projects/pt_activation/logdir/adversaries/fashion/carliniwagnerl2/cff_relu.pt'
# Load all adversaries, then order them by the sample index taken from each file name.
adversaries = sorted(read_adversaries(adv_directory_loc), key=lambda entry: entry['sample'])