In [None]:
import os
import argparse
import math
import torch
import torchvision
from torchvision.transforms import ToTensor, Compose, Normalize,RandomHorizontalFlip,RandomCrop
from tqdm.notebook import tqdm
import torchshow as ts
import matplotlib.pyplot as plt

# from model import *
# from utils import setup_seed
from models import *
from new_poi_util import *

# Only touch the CUDA runtime when a GPU is actually present; otherwise the
# CPU fallback chosen for `device` below could never be reached because
# torch.cuda.set_device would raise first.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # The experiments pin the second GPU (index 1); fall back to GPU 0 on
    # single-GPU hosts instead of failing with an invalid device ordinal.
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
# set_seed comes from one of the star imports above and is called with seed 0.
set_seed(0)
# The bare 'cuda' string resolves to the current device selected above.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Mini-batch size used by the DataLoaders in the data-loading cell; it also
# caps how many validation samples are drawn per step (min(batch_size, valid_size)).
batch_size = 512
# Passed to the DataLoaders as num_workers.
nworkers = 8
# Total number of clean validation images to sample; split evenly over the
# 10 classes (valid_size / 10 per class).
valid_size = 1000

In [None]:
# Only tensor conversion is applied; no augmentation or normalisation.
transform = Compose([ToTensor(),])

# The raw splits are loaded without a transform because my_subset / poi_dataset
# below receive `transform` themselves.
train_dataset = torchvision.datasets.CIFAR10('/home/minzhou/data', train=True, download=False, transform=None)
val_dataset = torchvision.datasets.CIFAR10('/home/minzhou/data', train=False, download=False, transform=None)

# The test split is handed straight to a DataLoader, so it must yield tensors:
# with transform=None the samples are PIL images and the default collate
# function raises as soon as the loader is iterated.
# Note: this is the same train=False split the validation subset is drawn from.
test_dataset = torchvision.datasets.CIFAR10('/home/minzhou/data', train=False, download=False, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=nworkers, pin_memory=True)

# Class-balanced clean validation subset: valid_size // 10 images per class,
# sampled without replacement.
val_idx = []
per_class = valid_size // 10
val_targets = np.array(val_dataset.targets)  # hoisted: identical for every class
for i in range(10):
    current_label = np.where(val_targets == i)[0]
    samples_idx = np.random.choice(current_label, size=per_class, replace=False)
    val_idx.extend(samples_idx)

val_set = my_subset(val_dataset, val_idx, transform = transform)
# val_set = torchvision.datasets.FakeData(size = 1000, image_size = (3, 32, 32), num_classes = 10, transform = tensor_trans)
meta_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=nworkers, pin_memory=True)

# Build the poisoned training set (backdoor, 5% poison rate, target label 2).
# o_poi_idx is presumably the list of poisoned sample indices within
# train_poi_set -- confirm against poi_dataset.
train_poi_set, o_poi_idx = poi_dataset(train_dataset, poi_methond='backdoor', transform=transform, poi_rates=0.05,random_seed=0, tar_lab=2)
train_dataloader = torch.utils.data.DataLoader(train_poi_set, batch_size=batch_size, num_workers=nworkers, pin_memory=True, shuffle=True)

# Loader restricted to the o_poi_idx samples (used for the ASR check).
poi_set = Subset(train_poi_set, o_poi_idx)
poi_dataloader = torch.utils.data.DataLoader(poi_set, batch_size=batch_size, num_workers=nworkers, pin_memory=True)

In [None]:
# Detector network, trained from scratch and placed on the selected device.
# Optional warm start, kept for reference:
# model.load_state_dict(torch.load('./checkpoint/warmp3_cifar10.pth'))
model = ResNet18().to(device)

# Two separate Adam instances over the same parameters: `optimizer` drives the
# variance-minimisation step and `optimizer2` the loss-ascent step of the
# training loop, so each keeps its own moment estimates.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer2 = torch.optim.Adam(model.parameters(), lr=1e-4)

# Batch-mean cross-entropy.
criterion = nn.CrossEntropyLoss()

# Per-sample cross-entropy (no reduction), used for scoring.
full_ce = nn.CrossEntropyLoss(reduction='none')
# NOTE: despite its name, `bce` is a mean-squared-error loss.
bce = torch.nn.MSELoss()

In [None]:
class HiddenLayer(nn.Module):
    """Fully connected layer followed by a ReLU activation."""

    def __init__(self, input_size, output_size):
        """Create a Linear(input_size -> output_size) and its ReLU."""
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(fc(x))``."""
        projected = self.fc(x)
        return self.relu(projected)


class MLP(nn.Module):
    """Stack of ``num_layers`` HiddenLayer blocks with a sigmoid scalar head.

    The first block maps ``input_size -> hidden_size``; the remaining
    ``num_layers - 1`` blocks map ``hidden_size -> hidden_size``; a final
    linear layer maps to a single value that is squashed with a sigmoid.
    """

    def __init__(self, input_size = 10, hidden_size=100, num_layers=1):
        super().__init__()
        self.first_hidden_layer = HiddenLayer(input_size, hidden_size)
        extra_layers = [HiddenLayer(hidden_size, hidden_size) for _ in range(num_layers - 1)]
        # An empty nn.Sequential (num_layers <= 1) passes its input through unchanged.
        self.rest_hidden_layers = nn.Sequential(*extra_layers)
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the sigmoid of the scalar head, one value per input row."""
        hidden = self.rest_hidden_layers(self.first_hidden_layer(x))
        return torch.sigmoid(self.output_layer(hidden))

In [None]:
# Reference model restored from a previously trained checkpoint.
# map_location='cpu' lets the weights load regardless of which device they
# were saved from (including on CPU-only hosts); the model is moved to
# `device` afterwards.
# NOTE(review): the checkpoint name says "backdoor_0.1" while the training set
# in this notebook is built with poi_rates=0.05 -- confirm this is intended.
o_model = ResNet18()
o_model.load_state_dict(torch.load('/home/minzhou/public_html/dataeval/new_first_phase/checkpoint/cifar10_backdoor_0.1_resnet18_tar2.pth', map_location='cpu'))
o_model = o_model.to(device)

In [None]:
# Accuracy of the reference model on the poisoned subset, reported as the
# attack success rate (presumably the stored labels of these samples are the
# attack target label -- confirm against poi_dataset).
o_model.eval()
correct_clean, total_clean = 0, 0
loss_sum = 0.0
with torch.no_grad():
    for i, (images, labels, _) in enumerate(poi_dataloader):
        images, labels = images.to(device), labels.to(device)
        logits = o_model(images)
        # criterion returns a batch mean; weight it by the batch size so the
        # reported loss is the average over the whole subset rather than the
        # value of whichever mini-batch happened to come last.
        loss_sum += criterion(logits, labels).item() * labels.size(0)
        _, predicted = torch.max(logits, 1)
        total_clean += labels.size(0)
        correct_clean += (predicted == labels).sum().item()
# max(..., 1) keeps an empty loader from raising ZeroDivisionError.
acc_clean = correct_clean / max(total_clean, 1)
out_loss = loss_sum / max(total_clean, 1)
print('\nASR %.2f' % (acc_clean*100))
print('Test_loss:',out_loss)

In [None]:
# Accuracy of the reference model on one random draw of clean validation
# samples (min(batch_size, valid_size) of them, evaluated as a single batch).
o_model.eval()
idxs = random.sample(range(valid_size), min(batch_size,valid_size))
picked = [val_set[i] for i in idxs]
neg_img = torch.stack([sample[0] for sample in picked]).to(device)
neg_lab = torch.tensor([sample[1] for sample in picked]).to(device)

with torch.no_grad():
    logits = o_model(neg_img)
    out_loss = criterion(logits,neg_lab)
    predicted = torch.max(logits.data, 1).indices

total_clean = neg_lab.size(0)
correct_clean = (predicted == neg_lab).sum().item()
acc_clean = correct_clean / total_clean
print('\nClean ACC %.2f' % (acc_clean*100))
print('Test_loss:',out_loss)

In [None]:
# Detector training. For every poisoned-training batch:
#   1. `model` is pushed towards flat logits (low variance across classes) on
#      clean validation samples;
#   2. a fresh MLP (`Vnet`) is fitted on frozen reference-model features to
#      separate the training batch (target 1) from the clean batch (target 0);
#   3. training samples whose Vnet score is flagged by adjusted_outlyingness
#      get a gradient-ASCENT step on cross-entropy (note the minus sign).
for epoch in tqdm(range(1)):
    o_model2 = copy.deepcopy(o_model)
    o_model2.train()

    # Feature extractor: every child module of the copy except the last one,
    # followed by a flatten. Its output feeds Vnet (input_size=8192).
    model_hat = copy.deepcopy(o_model2)
    layer_cake = list(model_hat.children())
    model_hat = torch.nn.Sequential(*(layer_cake[:-1]), torch.nn.Flatten())
    model_hat = model_hat.to(device)
    model_hat = model_hat.train()
    model.train()

    for iters, (input_train, target_train, poi) in enumerate(train_dataloader):
        # Use `device` rather than .cuda() so the CPU fallback keeps working.
        pos_img = input_train.to(device)
        pos_lab = target_train.to(device)
        poi = poi.to(device)

        # Random clean batch; each sample is fetched once and reused for
        # both the image and the label tensor.
        idxs = random.sample(range(valid_size), min(batch_size,valid_size))
        neg_samples = [val_set[i] for i in idxs]
        neg_img = torch.stack([sample[0] for sample in neg_samples]).to(device)
        neg_lab = torch.tensor([sample[1] for sample in neg_samples]).to(device)

        # Step 1: minimise the per-sample logit variance on clean data.
        neg_outputs = model(neg_img)
        neg_loss = torch.mean(torch.var(neg_outputs,dim=1))
        optimizer.zero_grad()
        neg_loss.backward()
        optimizer.step()

        # Step 2 preparation. model_hat is never updated (no optimizer owns
        # its parameters), so the features of these two fixed batches are
        # loop-invariant: compute them once, without autograd, instead of
        # running the extractor forward and backward on all 100 inner
        # iterations.
        # Assumes the extractor has no stochastic layers such as dropout -- TODO confirm.
        with torch.no_grad():
            v_outputs = model_hat(pos_img)
            vn_outputs = model_hat(neg_img)
        v_label = torch.ones(v_outputs.shape[0], device=device)
        v_label2 = torch.zeros(vn_outputs.shape[0], device=device)

        # A new Vnet (and new optimizers) is created for every batch.
        Vnet = MLP(input_size=8192, hidden_size=128, num_layers=2).to(device)
        Vnet.train()
        optimizer_hat = torch.optim.Adam(Vnet.parameters(), lr=0.0001)
        optimizer_hat2 = torch.optim.Adam(Vnet.parameters(), lr=0.0001)
        for _ in range(100):
            # Training batch -> target 1 (`bce` is an MSE loss, see its definition).
            vneto = Vnet(v_outputs)
            rr_loss = bce(vneto.view(-1),v_label)
            Vnet.zero_grad()
            rr_loss.backward()
            optimizer_hat.step()

            # Clean batch -> target 0.
            vneto2 = Vnet(vn_outputs)
            rr_loss2 = bce(vneto2.view(-1),v_label2)
            Vnet.zero_grad()
            rr_loss2.backward()
            optimizer_hat2.step()

        # Step 3: select the training samples flagged as outliers.
        with torch.no_grad():
            res = Vnet(v_outputs)
        pidx = torch.where(adjusted_outlyingness(res) > 2)[0]
        if pidx.numel() == 0:
            # Nothing flagged: the mean loss over an empty batch is NaN, and
            # stepping Adam anyway would still move the weights via its
            # running moments, so skip the ascent step for this batch.
            print(neg_loss.item(), 'no samples flagged; ascent step skipped')
            continue

        pos_outputs = model(pos_img[pidx])
        real_loss = -criterion(pos_outputs, pos_lab[pidx])
        optimizer2.zero_grad()
        real_loss.backward()
        optimizer2.step()
        # Print plain floats instead of tensors carrying their autograd graph.
        print(neg_loss.item(), real_loss.item())


In [None]:
def get_result(model, dataset, poi_idx, batch_size=512, num_workers=4):
    """Score every sample of ``dataset`` with its per-sample cross-entropy.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        dataset: dataset whose items unpack as ``(data, target, extra)``.
        poi_idx: indices of the poisoned samples inside ``dataset``.
        batch_size: evaluation batch size (default matches the previous
            hard-coded value).
        num_workers: DataLoader worker count (default matches the previous
            hard-coded value).

    Returns:
        ``(poi_res, clean_res)``: two lists of per-sample losses, for the
        samples in ``poi_idx`` and for all remaining samples respectively.
    """
    def _per_sample_losses(subset):
        # Collect full_ce (reduction='none') for every item of `subset`.
        loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        losses = []
        with torch.no_grad():
            for data, target, _ in tqdm(loader):
                data, target = data.to(device), target.to(device)
                batch_loss = full_ce(model(data), target)
                losses.extend(batch_loss.cpu().numpy())
        return losses

    # Switch to eval mode BEFORE scoring anything. Previously this happened
    # only between the two passes, so the poisoned samples were scored in
    # whatever mode the model was left in (train mode after the training
    # loop) while the clean samples were scored in eval mode.
    model.eval()

    poi_set = Subset(dataset, poi_idx)
    # Sorted complement of poi_idx; deterministic, unlike iterating a set.
    clean_idx = list(np.setdiff1d(np.arange(len(dataset)), poi_idx))
    clean_set = Subset(dataset, clean_idx)

    poi_res = _per_sample_losses(poi_set)
    clean_res = _per_sample_losses(clean_set)
    return poi_res, clean_res
    

In [None]:
# Per-sample cross-entropy of the detector `model` on the poisoned training
# set, split into the o_poi_idx samples (poi_res) and all others (clean_res).
poi_res, clean_res = get_result(model, train_poi_set, o_poi_idx)

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, auc
import matplotlib.pyplot as plt

# Ground truth: 1 for the poisoned samples, 0 for the clean ones, in the same
# order as the concatenated scores. The score is the per-sample loss.
poi_true = [1] * len(poi_res)
nor_true = [0] * len(clean_res)
true_label = poi_true + nor_true
pred_label = poi_res + clean_res

fpr, tpr, thersholds = roc_curve(true_label, pred_label)
roc_auc = auc(fpr, tpr)
print(roc_auc_score(true_label, pred_label))

plt.plot(fpr, tpr, lw=2, label=f'ROC (area = {roc_auc:.2f})')

# Small margin around the unit square so the curve does not touch the frame.
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

In [None]:
def plot_res(clean_res, poi_res):
    """Draw overlaid 200-bin histograms of the clean and poison scores.

    Args:
        clean_res: sequence of scores for clean samples.
        poi_res: sequence of scores for poisoned samples.

    The x tick marks are hidden, the y axis is fixed to [0, 500], and the
    figure is shown immediately.
    """
    fig, ax = plt.subplots(figsize=(3,1.5), dpi=300)

    # Clean first, poison second, so the poison bars are drawn on top.
    series = (
        (clean_res, 'Clean', "#5da1f0"),
        (poi_res, 'Poison', "#f7d145"),
    )
    for scores, legend_name, colour in series:
        ax.hist(np.array(scores), bins=200, label=legend_name, color=colour)

    ax.set_ylabel("Number of samples")
    ax.set_xticks([])
    ax.set_ylim(0, 500)
    ax.ticklabel_format(style='sci', scilimits=(0,0), axis='both')
    ax.legend(prop={'size': 6})
    plt.show()

In [None]:
# Histogram of the per-sample losses: clean vs. poisoned training samples.
plot_res(clean_res, poi_res)

In [None]:
# Detection counts at the cut-off returned by get_t (an external helper; its
# second argument 1e-6 is passed through unchanged -- semantics not visible here).
# A sample counts as detected when its score is NOT below the cut-off.
total = poi_res + clean_res
t = get_t(total, 1e-6)

poi_below = int(np.count_nonzero(np.array(poi_res) < t))
clean_below = int(np.count_nonzero(np.array(clean_res) < t))

print("tp:", len(o_poi_idx) - poi_below)
print("fp:", len(clean_res) - clean_below)
print("fn:", poi_below)