In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import random
import torchvision.transforms.functional as TF


import pytorch_lightning as pl
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional import accuracy
from torch.optim import Adam
from torch.utils.data import Sampler

import numpy as np
from sklearn.metrics import classification_report
from collections import OrderedDict
from typing import Sized, Iterator

In [None]:
class Lenet5(nn.Module):
    """Small LeNet-5-style CNN used in this notebook as the surrogate model.

    Architecture:
        layer1: Conv2d(3 -> 6, 5x5) -> ReLU -> MaxPool(2x2, stride 2)
        layer2: Conv2d(6 -> 16, 5x5) -> ReLU -> MaxPool(2x2, stride 2)
        fc1:    Linear(400 -> 84) -> ReLU
        fc2:    Linear(84 -> 10)

    The 400 input features of ``fc1`` equal 16 * 5 * 5, which is what a 3 x 32 x 32
    input produces after the two conv/pool stages (32 -> 28 -> 14 -> 10 -> 5).
    ``forward`` returns the raw outputs of ``fc2`` (no softmax / log-softmax applied).
    """

    def __init__(self) -> None:
        super().__init__()

        # Submodule and entry names are part of the state_dict keys; keep them stable.
        # Layers are created in the same order as before so seeded init is unchanged.
        self.layer1 = nn.Sequential(OrderedDict(
            conv1=nn.Conv2d(3, 6, kernel_size=(5, 5)),
            relu1=nn.ReLU(),
            pool1=nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        ))

        self.layer2 = nn.Sequential(OrderedDict(
            conv2=nn.Conv2d(6, 16, kernel_size=(5, 5)),
            relu2=nn.ReLU(),
            pool2=nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        ))

        self.fc1 = nn.Sequential(OrderedDict(
            f4=nn.Linear(400, 84),
            relu4=nn.ReLU(),
        ))

        self.fc2 = nn.Sequential(OrderedDict(
            f5=nn.Linear(84, 10),
        ))

    def forward(self, x):
        """Map a batch of images to 10 raw class scores per sample."""
        features = self.layer2(self.layer1(x))
        # Flatten everything except the leading batch dimension.
        flat = features.view(features.size(0), -1)
        return self.fc2(self.fc1(flat))

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Surrogate model shared by the batch samplers (which score batches with it) and
# by BoilerPlate (which trains it). nn.Module.to returns the module itself.
surrogate = Lenet5().to(device)
surrogate

In [None]:
class BatchReorderSampler(Sampler[list]):
    """Batch sampler that reorders a fixed set of batches by surrogate loss.

    At construction, a random permutation of the first
    ``n_used = (len(data_source) // batch_size) * batch_size`` dataset indices is
    drawn and reshaped into ``batchOrder`` of shape ``(n_used // batch_size, batch_size)``.
    Batch membership never changes after that; dataset indices ``>= n_used`` are
    never yielded.

    * First ``__iter__`` call ("Waiting to Attack"): batches are yielded in the
      random order of ``batchOrder``.
    * Every later call ("Attacking"): batches are yielded in ascending order of
      the surrogate's mean cross-entropy on each batch.

    The used part of ``data_source`` is loaded once via ``data_source[j]`` (so any
    random transform is sampled once), stacked, and moved to ``device``. It is
    stored in *dataset-index order*, so ``self.data[idx]`` is the sample at
    dataset index ``idx``.

    Args:
        data_source: map-style dataset; assumed to return ``(tensor, int_label)``
            pairs (inferred from the ``torch.stack`` / ``torch.LongTensor`` calls —
            TODO confirm for other datasets).
        surrogate: model used to score batches.
        batch_size: number of indices in every yielded batch.

    Raises:
        ValueError: if ``data_source`` holds fewer than ``batch_size`` items.
    """

    def __init__(self, data_source: Sized, surrogate=surrogate, batch_size=32) -> None:
        self.data_source = data_source

        self.surrogate = surrogate

        self.epoch1 = True
        n_used = (len(data_source) // batch_size) * batch_size
        if n_used == 0:
            raise ValueError(
                f"data_source has {len(data_source)} items; need at least batch_size={batch_size}"
            )
        self.batchOrder = torch.randperm(n_used).reshape(-1, batch_size)

        # Load samples in dataset-index order. ``batchOrder`` holds dataset indices and
        # ``__getSurrogateloss__`` indexes ``self.data`` with them. Storing the samples in
        # permuted order would make every batch be scored on the wrong images.
        samples = [self.data_source[j] for j in range(n_used)]
        data, labels = zip(*samples)
        self.data = torch.stack(data).to(device)
        self.labels = torch.LongTensor(labels).to(device)

    def __getSurrogateloss__(self, batch):
        """Return the surrogate's mean cross-entropy on the samples at dataset indices ``batch``."""
        with torch.no_grad():
            # cross_entropy applies log_softmax itself, so it is a proper loss on the
            # Lenet5 surrogate's raw scores. It is also unchanged for a model that
            # already returns log-probabilities, because log_softmax is idempotent.
            loss = F.cross_entropy(self.surrogate(self.data[batch]), self.labels[batch])

        return loss.item()

    def __iter__(self) -> Iterator[list]:
        """Yield each batch as a list of int dataset indices (see class docstring for order)."""
        if self.epoch1:
            print('Waiting to Attack')
            for batch in self.batchOrder:
                yield batch.tolist()

            # Only reached once the epoch has been fully consumed.
            self.epoch1 = False

        else:
            print('Attacking')
            losses = torch.tensor([self.__getSurrogateloss__(batch) for batch in self.batchOrder])

            # Lowest-loss batches first.
            for i in torch.argsort(losses):
                yield self.batchOrder[i].tolist()

    def __len__(self) -> int:
        """Number of batches yielded per epoch."""
        return self.batchOrder.shape[0]

In [None]:
class BatchShuffleSampler(Sampler[list]):
    """Batch sampler that regroups samples by per-sample surrogate loss.

    At construction, a random permutation ``batchOrder`` of the first
    ``n_used = (len(data_source) // batch_size) * batch_size`` dataset indices is
    drawn. Dataset indices ``>= n_used`` are never yielded.

    * First ``__iter__`` call ("Waiting to Attack"): ``batchOrder`` is cut into
      consecutive chunks of ``batch_size`` and yielded in that order.
    * Every later call ("Attacking"): the indices are sorted by ascending
      per-sample surrogate cross-entropy and then cut into consecutive chunks, so
      each batch groups samples of similar loss.

    The used part of ``data_source`` is loaded once via ``data_source[j]`` (so any
    random transform is sampled once), stacked, and moved to ``device``. It is
    stored in *dataset-index order*, so ``self.data[idx]`` is the sample at
    dataset index ``idx``.

    Args:
        data_source: map-style dataset; assumed to return ``(tensor, int_label)``
            pairs (inferred from the ``torch.stack`` / ``torch.LongTensor`` calls —
            TODO confirm for other datasets).
        surrogate: model used to score samples.
        batch_size: number of indices in every yielded batch.

    Raises:
        ValueError: if ``data_source`` holds fewer than ``batch_size`` items.
    """

    def __init__(self, data_source: Sized, surrogate=surrogate, batch_size=32) -> None:
        self.data_source = data_source

        self.surrogate = surrogate
        self.batch_size = batch_size

        self.epoch1 = True
        n_used = (len(data_source) // batch_size) * batch_size
        if n_used == 0:
            raise ValueError(
                f"data_source has {len(data_source)} items; need at least batch_size={batch_size}"
            )
        self.batchOrder = torch.randperm(n_used)

        # Load samples in dataset-index order so losses can be looked up by dataset
        # index. Storing them in permuted order would pair each index in ``batchOrder``
        # with some other sample's loss.
        samples = [self.data_source[j] for j in range(n_used)]
        data, labels = zip(*samples)
        self.data = torch.stack(data).to(device)
        self.labels = torch.LongTensor(labels).to(device)

    def __getSurrogateloss__(self, batch):
        """Return the surrogate's cross-entropy on the single sample at dataset index ``batch``.

        Kept for compatibility; ``__iter__`` uses the batched ``_per_sample_losses``.
        """
        with torch.no_grad():
            # cross_entropy applies log_softmax itself (idempotent if the model already
            # returns log-probabilities), unlike nll_loss on the raw Lenet5 scores.
            loss = F.cross_entropy(self.surrogate(self.data[batch:batch + 1]),
                                   self.labels[batch:batch + 1])

        return loss.item()

    def _per_sample_losses(self, chunk_size=1024):
        """Per-sample surrogate cross-entropy for every cached sample, in dataset-index order (CPU)."""
        chunks = []
        with torch.no_grad():
            # Score in chunks instead of one forward per sample, bounding activation memory.
            for start in range(0, self.data.shape[0], chunk_size):
                stop = start + chunk_size
                chunks.append(F.cross_entropy(self.surrogate(self.data[start:stop]),
                                              self.labels[start:stop],
                                              reduction='none'))
        return torch.cat(chunks).cpu()

    def __iter__(self) -> Iterator[list]:
        """Yield each batch as a list of int dataset indices (see class docstring for order)."""
        if self.epoch1:
            print('Waiting to Attack')

            for batch in self.batchOrder.view(-1, self.batch_size):
                yield batch.tolist()

            # Only reached once the epoch has been fully consumed.
            self.epoch1 = False

        else:
            print('Attacking')
            # losses[k] is the loss of dataset index batchOrder[k].
            losses = self._per_sample_losses()[self.batchOrder]
            ordered = self.batchOrder[torch.argsort(losses)]

            for batch in ordered.view(-1, self.batch_size):
                yield batch.tolist()

    def __len__(self) -> int:
        """Number of batches yielded per epoch."""
        return self.batchOrder.view(-1, self.batch_size).shape[0]

In [None]:
# Evaluation preprocessing: tensor conversion and per-channel normalisation to [-1, 1].
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training preprocessing: random 32x32 crop from a 4-pixel-padded image and a random
# horizontal flip, followed by the same normalisation as the test set.
train_transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


batch_size = 64

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)

# With a custom batch_sampler, the sampler alone decides the batch size. Pass
# `batch_size` explicitly; without it the sampler's default of 32 is used, and the
# value above would only affect the test loader.
trainloader = torch.utils.data.DataLoader(
    trainset, num_workers=16,
    batch_sampler=BatchShuffleSampler(trainset, batch_size=batch_size))


testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=16)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [None]:
def resnet18CIFAR10():
    """Build a freshly initialised torchvision ResNet-18 adapted to small CIFAR-10 images.

    The classifier has 10 outputs. The first convolution is replaced with a
    3x3, stride-1, padding-1 convolution (3 -> 64 channels, no bias), and the stem
    max-pool is replaced with an identity, so the stem does not downsample the input.
    """
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)

    # Stride 1 with padding 1 keeps the spatial size of the input unchanged.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    return net

class BoilerPlate(pl.LightningModule):
    """Lightning module that trains a CIFAR-10 ResNet-18 while co-training a surrogate.

    ``self.model`` (built by ``resnet18CIFAR10``) is optimised by Lightning with the
    optimizer returned from ``configure_optimizers``. ``self.surrogate`` is updated
    by hand inside ``training_step``, on the same batches, with its own Adam
    optimizer. In this notebook the same ``surrogate`` object is also the batch
    sampler's default scoring model, so the sampler sees the updated weights.

    Args:
        train_l: DataLoader returned by ``train_dataloader``.
        val_l: DataLoader returned by ``val_dataloader``.
        surrogate: module mapping an image batch to per-class scores.
    """

    def __init__(self, train_l, val_l, surrogate) -> None:
        super(BoilerPlate, self).__init__()

        self.train_l = train_l
        self.val_l = val_l

        self.model = resnet18CIFAR10()

        self.surrogate = surrogate
        # Separate optimizer for the surrogate; stepped manually in training_step,
        # not managed by Lightning.
        self.surrogate_optim = Adam([p for p in self.surrogate.parameters() if p.requires_grad], lr=0.1)

    def forward(self, x):
        """Return log-probabilities from ``self.model`` (log_softmax over dim 1)."""
        out = self.model(x)

        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the main model's NLL loss and take one surrogate optimisation step."""
        x, y = batch
        logits = self(x)

        # self(x) already returns log-probabilities, so nll_loss is the cross-entropy.
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)

        # Surrogate update. The Lenet5 surrogate returns raw scores (no log_softmax).
        # nll_loss on raw scores only pushes the true-class score up without bound.
        # cross_entropy applies log_softmax first, and is unchanged for a model that
        # already outputs log-probabilities (log_softmax is idempotent).
        self.surrogate_optim.zero_grad()
        surrogate_logits = self.surrogate(x)
        surrogateloss = F.cross_entropy(surrogate_logits, y)
        surrogateloss.backward()
        self.surrogate_optim.step()

        return loss

    def evaluate(self, batch, stage=None):
        """Compute NLL loss and accuracy on ``batch``; log them as ``{stage}_loss``/``{stage}_acc`` if ``stage`` is given."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def configure_optimizers(self):
        """Adam (lr=0.1) over the trainable parameters of ``self.model`` only (not the surrogate)."""
        return Adam([p for p in self.model.parameters() if p.requires_grad], lr=0.1)

    def train_dataloader(self):
        return self.train_l

    def val_dataloader(self):
        return self.val_l

In [None]:
# Train for up to 100 epochs with TensorBoard logs under lightning_logs/test.
# Request a GPU only if one is available, matching the CPU fallback used for `device`;
# a hard-coded gpus=1 makes Trainer construction fail on CPU-only machines.
model = BoilerPlate(trainloader, testloader, surrogate)
trainer = pl.Trainer(
    progress_bar_refresh_rate=10,
    max_epochs=100,
    gpus=1 if torch.cuda.is_available() else 0,
    logger=pl.loggers.TensorBoardLogger("lightning_logs/", name="test"),
)
trainer.fit(model)

In [None]:
# Per-class precision / recall / F1 of the trained model on the CIFAR-10 test split.
# `testloader` and `classes` are defined in the data-loading cell, with the same
# transform and batch size, so they are reused rather than rebuilt here.
model.to(device)
model.eval()
preds, labels = [], []
# Inference only: no_grad avoids building an autograd graph for every test batch.
with torch.no_grad():
    for x, y in testloader:
        logits = model(x.to(device))
        y_pred = torch.argmax(logits, dim=1)

        preds.append(y_pred.cpu().numpy())
        labels.append(y.cpu().numpy())

preds = np.concatenate(preds)
labels = np.concatenate(labels)
print(classification_report(labels, preds, target_names=classes))