In [None]:
%load_ext autoreload
%autoreload 2

import os
import pickle
import matplotlib.pyplot as plt

import numpy as np
import scipy
import tqdm
import torch 
import torch.nn.functional as func

import dataset
import models
import utils
import train

FAUST = "../datasets/faust"
MODEL_PATH = "../model_data/data.pt"

# Training split: in this notebook it supplies the downscaling operators and the class count.
traindata = dataset.FaustDataset(FAUST, test=False, train=True)
# Test split; its transform is replaced with the identity so meshes are used exactly as loaded.
testdata = dataset.FaustDataset(FAUST, train=False, test=True)
testdata.transform = lambda sample: sample

In [None]:
# Chebyshev-convolution mesh classifier. Downscaling operators and the number of classes
# come from the training split; weights are read from MODEL_PATH via `parameters_file`.
model = models.ChebnetClassifier(
    param_conv_layers=[128, 128, 64, 64],
    D_t=traindata.downscale_matrices,
    E_t=traindata.downscaled_edges,
    num_classes=traindata.num_classes,
    parameters_file=MODEL_PATH,
)

# Test-set accuracy (the confusion matrix is returned too but not inspected here).
accuracy, confusion_matrix = train.evaluate(
    eval_data=testdata,
    classifier=model,
    epoch_number=3,
)
print(accuracy)


In [92]:
import datetime
from torch_geometric.data.data import Data

import adversarial.carlini_wagner as cw
from adversarial.base import AdversarialExample
from utils.misc import write_off

def _todata(adv_example, perturbed: bool) -> Data:
    """Wrap the mesh of an adversarial example in a torch_geometric ``Data`` object.

    Args:
        adv_example: object exposing ``pos``, ``perturbed_pos``, ``edges`` and ``faces``.
        perturbed: when True the vertex positions are taken from ``perturbed_pos``,
            otherwise from ``pos``.

    Returns:
        ``Data`` whose ``pos`` is the selected positions, ``edge_index`` is
        ``edges.t()`` and ``face`` is ``faces.t()``.
    """
    vertices = adv_example.perturbed_pos if perturbed else adv_example.pos
    return Data(
        pos=vertices,
        edge_index=adv_example.edges.t(),
        face=adv_example.faces.t(),
    )

class MeshLogger(cw.Logger):
    """Logger that snapshots the perturbed mesh during an attack and dumps it as OFF files.

    Whenever ``log`` is called with an iteration that is a multiple of ``log_interval``,
    a ``Data`` copy of the currently perturbed mesh is appended to ``logged_meshes`` and
    the iteration number to ``logged_iterations``. ``dump_off`` writes the unperturbed
    mesh plus every snapshot into a fresh timestamped directory.
    """

    def __init__(self, adv_example, log_interval: int = 10):
        """
        Args:
            adv_example: the adversarial example being optimized (forwarded to ``cw.Logger``).
            log_interval: snapshot period in iterations; 0 disables snapshotting.
        """
        super().__init__(adv_example=adv_example, log_interval=log_interval)
        self.original_mesh = _todata(self.adv_example, perturbed=False)
        self.logged_meshes = []
        # Iteration at which each entry of ``logged_meshes`` was taken. Kept as a parallel
        # list so ``logged_meshes`` keeps holding plain ``Data`` objects.
        self.logged_iterations = []

    def reset(self):
        """Discard all recorded snapshots (the original mesh is kept)."""
        self.logged_meshes.clear()
        self.logged_iterations.clear()

    def log(self, iteration: int):
        """Snapshot the perturbed mesh if ``iteration`` is a multiple of ``log_interval``."""
        if self.log_interval != 0 and iteration % self.log_interval == 0:
            self.logged_meshes.append(_todata(self.adv_example, perturbed=True))
            self.logged_iterations.append(iteration)

    def dump_off(self, parent_directory: str) -> str:
        """Write the original mesh and every snapshot as OFF files.

        A new directory ``adversarial-meshes_<timestamp>`` is created inside
        ``parent_directory`` (missing parents are created too). It receives
        ``original-mesh.off`` and one ``adversarial-mesh_idx<i>_it<iteration>.off``
        per snapshot.

        Args:
            parent_directory: directory in which the timestamped folder is created.

        Returns:
            The path of the created directory.

        Raises:
            FileExistsError: if the timestamped directory already exists (e.g. two
                dumps within the same second).
        """
        timestamp = datetime.datetime.now().strftime("%d-%b-%Y_h%H-m%M-s%S")
        directory = os.path.join(parent_directory, "adversarial-meshes_" + timestamp)
        os.makedirs(directory)

        original = self.original_mesh
        write_off(pos=original.pos, faces=original.face.t(),
                  file=os.path.join(directory, "original-mesh.off"))

        # Use the recorded iteration rather than reconstructing it as i * log_interval,
        # which is only right if logging started at iteration 0 and never skipped.
        for i, (snapshot, iteration) in enumerate(zip(self.logged_meshes, self.logged_iterations)):
            filename = os.path.join(
                directory, "adversarial-mesh_idx{}_it{}.off".format(i, iteration))
            write_off(pos=snapshot.pos, faces=snapshot.face.t(), file=filename)
        return directory


In [93]:
# Attack configuration: which test mesh to perturb and which class to push it towards.
MESH_INDEX = 1
TARGET_CLASS = 3
EIGS_NUM = 40  # value passed as `eigs_num` to cw.LowbandPerturbation
OUTPUT_DIR = os.path.dirname(MODEL_PATH)  # same directory that holds the model weights

mesh, target = testdata[MESH_INDEX], TARGET_CLASS
builder = cw.CWBuilder(search_iterations=1).set_mesh(mesh.pos, mesh.edge_index.t(), mesh.face.t())
builder.set_classifier(model).set_logger(lambda x: MeshLogger(x, log_interval=1)).set_target(target)
builder.set_perturbation(perturbation_factory=lambda x: cw.LowbandPerturbation(x, eigs_num=EIGS_NUM))
builder.set_adversarial_loss(adv_loss_factory=cw.AdversarialLoss)
builder.set_similarity_loss(sim_loss_factory=cw.LocalEuclideanSimilarity)
# Alternative similarity loss. The bare name `L2Similarity` is not imported in this
# notebook; presumably it lives in the same module as `cw.LocalEuclideanSimilarity` (TODO confirm).
# builder.set_similarity_loss(sim_loss_factory=cw.L2Similarity)

config = {
    "usetqdm": True,
    "minimization_iterations": 50,
    "adversarial_coeff": 0.5,
}
adex = builder.build(**config)
adex.logger.dump_off(OUTPUT_DIR)


[0,0.5] ; c=0.5


















  0%|                                                                                           | 0/50 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















  2%|█▋                                                                                 | 1/50 [00:01<01:28,  1.82s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















  4%|███▎                                                                               | 2/50 [00:03<01:26,  1.80s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















  6%|████▉                                                                              | 3/50 [00:05<01:31,  1.95s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















  8%|██████▋                                                                            | 4/50 [00:07<01:28,  1.92s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A[A[A















 10%|████████▎                                                       

In [None]:
# Predicted class of the unperturbed test mesh, for comparison with the attack target.
# Pure inference: no_grad avoids building an autograd graph (argmax output is unchanged).
with torch.no_grad():
    predicted_class = model(testdata[1].pos).argmax()
predicted_class
