In [None]:
import torch
import torch.nn as nn
from models.tiny import resnet18
import os
from victim.blackbox import Blackbox
import utils.common as common
from datasets import sized_transforms
import numpy as np
import time
import copy
from torchvision import transforms

In [None]:
# Expose only physical GPU 4 to this process (the variable only takes effect if CUDA has
# not been initialised yet in this kernel).
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

# `resnet18` from models.tiny, randomly initialised. Loading the ImageNet checkpoint is
# currently disabled; to re-enable it:
#   PRETRAINED = "/mnt/ywb/checkpoints/imagenet/resnet18/resnet18-5c106cde.pth"
#   checkpoint = torch.load(PRETRAINED)
#   model.load_state_dict(checkpoint.get("state_dict", checkpoint))
model = resnet18()


In [None]:
# Device that the training loops below move every batch to (they read this global).
device = torch.device('cuda')

In [None]:
from datasets.imagenet64 import ImageNet64 as imagenet

# Stage-1 data: ImageNet64 resized to 32x32 and converted to tensors.
# Mean/std normalisation is disabled; it would be:
#   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
_resize_to_tensor = [transforms.Resize([32, 32]), transforms.ToTensor()]
transform = transforms.Compose(_resize_to_tensor)

pretrained_data = imagenet(transform=transform, train=True)
pretrained_test_data = imagenet(transform=transform, train=False)




In [None]:
# Stage-1 loaders keyed by phase name, as expected by train_pretrained_model.
# Both splits are shuffled, batch size 64, 4 worker processes.
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    for split, split_dataset in (('train', pretrained_data), ('val', pretrained_test_data))
}

In [None]:
def train_pretrained_model(model, criterion, optimizer, scheduler, num_epochs=25,
                           loaders=None, sizes=None, run_device=None):
    """Train ``model`` with a train/val loop and restore the best-validation weights.

    Each epoch runs a 'train' phase followed by a 'val' phase:

    * The train phase enables gradients and calls ``optimizer.step()`` per batch.
      It calls ``scheduler.step()`` once at the end of the phase, i.e. once per epoch.
    * The val phase runs without gradients.

    After each val phase the weights are snapshotted if validation accuracy
    improved. The best snapshot is loaded back before returning.

    Args:
        model: network to train. It is not moved here, so it must already be on
            ``run_device``.
        criterion: loss, called as ``criterion(outputs, labels)``. Accuracy
            compares ``labels`` against the arg-max prediction, so labels are
            expected to be class indices.
        optimizer: optimizer over the parameters that should be updated.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: number of train+val epochs.
        loaders: dict with 'train' and 'val' DataLoaders. Defaults to the
            notebook-level ``dataloaders``.
        sizes: dict of per-phase sample counts used to average loss/accuracy.
            Defaults to ``len(loader.dataset)`` for each phase. The function
            previously read a global ``dataset_sizes`` that is only defined
            further down the notebook.
        run_device: device batches are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``model`` (same object) with the best-validation-accuracy weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if run_device is None:
        run_device = device
    if sizes is None:
        sizes = {phase: len(loader.dataset) for phase, loader in loaders.items()}

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is averaged per batch, so weight it by batch size before summing.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()

            # Guard against an empty split instead of dividing by zero.
            n_samples = max(sizes[phase], 1)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Stage-1 loss / optimiser / LR schedule. `criterion` is reused as the stage-2 test loss.
# NOTE(review): only model.fc is optimised here, and no checkpoint is loaded when `model`
# is built, so the convolutional backbone stays at its random initialisation. Confirm this
# is intended before relying on the "pretrained" features.
# NOTE(review): train_pretrained_model steps the scheduler once per epoch, so with
# step_size=5000 the learning rate never decays within 100 epochs.
criterion = nn.CrossEntropyLoss()
head_params = model.fc.parameters()
optimizer = torch.optim.SGD(head_params, momentum=0.9, lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=5000)

In [None]:
# Stage 1: 100 epochs on ImageNet64. The returned model carries the best-val-accuracy weights.
model = model.cuda()
model = train_pretrained_model(model, criterion, optimizer, lr_scheduler, num_epochs=100)

In [None]:
device = torch.device('cuda')

# Transfer step: freeze every existing parameter, then attach a fresh 43-way head.
# The new nn.Linear is created after the freeze loop, so its parameters stay trainable.
for param in model.parameters():
    param.requires_grad = False

# Size the new head from the current head's input width instead of hard-coding 64,
# so it keeps matching the backbone's feature dimension.
model.fc = nn.Linear(model.fc.in_features, 43)
model = model.to(device)

In [None]:
# Victim model, loaded from its saved model directory.
VICTIM_MODEL_DIR = '/mnt/ywb/results/victim/gtsrb-blackbox'
blackbox = Blackbox.from_modeldir(VICTIM_MODEL_DIR)

In [None]:
from datasets.gtsrb import GTSRB
# GTSRB test split; the next cell passes it to common.query_dataset together with the black box.
test_data = GTSRB(train=False)

In [None]:
# Label the GTSRB test split with the black box's outputs.
test = common.query_dataset(
    blackbox,
    test_data,
    batch_size=64,
    device=torch.device('cuda'),
    transform=sized_transforms[32],
)
# Collapse each black-box output (presumably a per-class score vector) to its arg-max
# class index, so evaluation can compare against hard labels.
test.labels = [scores.argmax() for scores in test.labels]

In [None]:
# Build the transfer-set index list: keep the first index seen for each distinct label,
# and stop as soon as 1000 distinct labels have been collected.
# NOTE(review): `data` is not defined in any visible cell, so this fails on a fresh kernel.
# The 1000-label cap suggests an ImageNet-style dataset (e.g. `pretrained_data`); confirm
# and define it explicitly above.
selected_labels = set()
selected_indices = []
for index, (_, label) in enumerate(data):
    if label in selected_labels:
        continue
    selected_labels.add(label)
    selected_indices.append(index)
    if len(selected_labels) >= 1000:
        break


In [None]:
# Number of selected samples (one per distinct label, capped at 1000).
len(selected_indices)

In [None]:
# Label the selected samples with the black box's outputs. They are presumably kept as
# per-class score vectors, since later cells take `.argmax` of these labels.
trainset = common.query_dataset(
    blackbox,
    data,
    list_indices=selected_indices,
    batch_size=64,
    device=torch.device('cuda'),
    transform=sized_transforms[32],
)

In [None]:
from torchvision import transforms

# Replace the transform on the dataset two `.dataset` levels below `trainset` (presumably
# the source of its images) with a 32x32 resize + tensor conversion. Normalisation stays
# disabled, matching stage 1.
trainset.dataset.dataset.transform = transforms.Compose(
    [transforms.Resize([32, 32]), transforms.ToTensor()]
)

In [None]:
# Evaluate on the same 32x32 tensor pipeline. CenterCrop(224) and ImageNet normalisation
# are disabled.
_eval_steps = [
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
]
test.dataset.transform = transforms.Compose(_eval_steps)

In [None]:
# Stage-2 loaders: black-box-labelled transfer set for training, GTSRB test split for
# validation. Both are shuffled (shuffling does not affect the val metrics).
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )
    for split, split_dataset in (('train', trainset), ('val', test))
}

def train_model(model, train_criterion, test_criterion, optimizer, scheduler, num_epochs=25,
                loaders=None, sizes=None, run_device=None):
    """Fit ``model`` on soft-labelled training data and keep the best-validation weights.

    Each epoch runs a 'train' phase then a 'val' phase.

    * Train phase: the model is in train mode except for every ``nn.BatchNorm2d``
      layer, which is put in eval mode so its running statistics are not updated.
      The loss is ``train_criterion``. Accuracy compares predictions with
      ``labels.argmax(1)``, so train labels are expected to be per-class score
      vectors.
    * Val phase: eval mode, no gradients. The loss is ``test_criterion``.
      Accuracy compares predictions with the labels directly, so val labels are
      expected to be class indices.

    ``scheduler.step()`` is called once per epoch. The weights with the highest
    val accuracy are restored before returning.

    Args:
        model: network to train. It is not moved here, so it must already be on
            ``run_device``.
        train_criterion: loss for the train phase, ``train_criterion(outputs, labels)``.
        test_criterion: loss for the val phase, ``test_criterion(outputs, labels)``.
        optimizer: optimizer over the parameters that should be updated.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: number of train+val epochs.
        loaders: dict with 'train' and 'val' DataLoaders. Defaults to the
            notebook-level ``dataloaders``.
        sizes: dict of per-phase sample counts used to average loss/accuracy.
            Defaults to ``len(loader.dataset)`` per phase rather than a
            hand-maintained global.
        run_device: device batches are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``model`` (same object) with the best-validation-accuracy weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if run_device is None:
        run_device = device
    if sizes is None:
        sizes = {phase: len(loader.dataset) for phase, loader in loaders.items()}

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
                # Keep BatchNorm layers in eval mode so their running stats stay frozen.
                for layer in model.modules():
                    if isinstance(layer, nn.BatchNorm2d):
                        layer.eval()
                loss_fn = train_criterion
            else:
                model.eval()
                loss_fn = test_criterion

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = loss_fn(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                # Train labels are score vectors (take arg-max); val labels are class ids.
                targets = labels.argmax(1) if is_train else labels
                running_corrects += (preds == targets).sum().item()

            if is_train:
                scheduler.step()

            # Guard against an empty split instead of dividing by zero.
            n_samples = max(sizes[phase], 1)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Per-phase sample counts used to average loss/accuracy in train_model. Derived from the
# datasets instead of hard-coding 1000: the label-selection cell stops at 1000 distinct
# labels but can yield fewer if `data` has fewer distinct labels.
dataset_sizes = {'train': len(trainset), 'val': len(test)}

In [None]:
def SCE(input, target):
    """Soft-label cross-entropy, averaged over the batch.

    Args:
        input: logits with classes along dim 1.
        target: per-class target weights, broadcastable to ``input`` (normally the
            same shape, e.g. black-box score vectors).

    Returns:
        Scalar tensor ``-sum(target * log_softmax(input, dim=1)) / input.shape[0]``.
    """
    batch_size = input.shape[0]
    log_probs = torch.nn.functional.log_softmax(input, dim=1)
    total = torch.sum(target * log_probs)
    return -total / batch_size


train_criterion = SCE

In [None]:
# Stage-2 optimiser: only the freshly attached head (model.fc) receives updates.
# NOTE(review): train_model steps the scheduler once per epoch, so step_size=5000 never
# triggers a decay within 100 epochs. Confirm whether per-iteration stepping was intended.
optimizer = torch.optim.SGD(model.fc.parameters(), momentum=0.9, lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=5000)

In [None]:
# Stage 2: soft-label training of the new head on the transfer set, evaluated on the
# black-box-labelled GTSRB test split.
model = train_model(
    model,
    train_criterion=train_criterion,
    test_criterion=criterion,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    num_epochs=100,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Per-class counts of the arg-max black-box label over the transfer set (43 classes).
# Computed here because this cell previously read `y` before the counting cell below
# had defined it, which is a NameError on a fresh kernel.
y = np.zeros(43)
for _, label in trainset:
    y[label.argmax()] += 1

fig, ax = plt.subplots()
ax.barh(range(43), y)
ax.set(
    xlabel='number of samples',
    ylabel='class (black-box arg-max)',
    title='Transfer-set label distribution',
);

In [None]:
x = list(range(1000))  # NOTE(review): not referenced by any visible cell

# Count transfer-set samples per arg-max black-box class (43 classes).
y = np.zeros(43)
for _img, soft_label in trainset:
    y[soft_label.argmax()] += 1

In [None]:
# Debug check: arg-max class of the labels of the first two transfer-set samples.
torch.stack([trainset[0][1], trainset[1][1]]).argmax(1)

In [None]:
# Debug check: concrete type of the dataset wrapped by `trainset`.
type(trainset.dataset)

In [None]:
# NOTE(review): `temp` is not defined in any visible cell, so this raises NameError on a
# fresh kernel. Define it explicitly or delete this leftover debug cell.
temp

In [None]:
# NOTE(review): relies on an undefined `temp`; leftover debug cell.
temp[0].shape

In [None]:
# NOTE(review): relies on an undefined `temp`; leftover debug cell.
temp[1].shape

In [None]:
# Debug check: first sample of the black-box-labelled GTSRB test set.
test[0]