In [1]:
import argparse
import os
import time
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
import models.arch as models

import timm
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm import tqdm

## 
# plot.py
import functools

import matplotlib
import matplotlib.pyplot as plt
import datetime 

import scipy.stats
from sklearn.metrics import auc, roc_curve

from collections import Counter

import dataset.cifar10 as dataset
import torch.utils.data as data


##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

In [2]:
torch.cuda.empty_cache()

In [3]:
# for vit_large_patch16_224_cifar10, CIFAR-10
'''
lr=0.02
epochs=25
n_shadows = 64
shadow_id = -1 
model = "efficientnet_b7"
dataset = "cifar100"
pkeep = 0.5
savedir = f"exp/{model}_{dataset}"
debug = True
'''

# for vgg19, CIFAR-10
lr = 0.001
epochs = 25
n_shadows = 64
shadow_id = -1 
arch = "efficientnet_b7"
dataset_ = "cifar10"
pkeep = 0.5
savedir = f"exp/{arch}_{dataset_}"
debug = True

In [4]:
seed = 1583745484

In [5]:
def calculate_mean_std(dataloader):
    """
    Calculate per-channel mean and standard deviation from a DataLoader.
    """
    mean = torch.zeros(3)
    std = torch.zeros(3)
    total_samples = 0

    dataloader_iter = iter(dataloader)

    for batch_idx in range(len(dataloader)):
        images, _ = next(dataloader_iter)
        
        # Undo ToTensor scaling if normalization has already been applied
        # images = images * 255.0 if torch.max(images) <= 1.0 else images
        # print(images.shape)
        
        batch_samples = images.size(0)  # Batch size

        # Accumulate mean and std for the batch
        mean += images.mean(dim=[0, 2, 3]) * batch_samples
        std += images.std(dim=[0, 2, 3]) * batch_samples
        total_samples += batch_samples

    # Compute the overall mean and std
    mean /= total_samples
    std /= total_samples

    return mean, std


@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()

# @@@@@@@ base data exploration @@@@@@@

In [14]:
datadir = Path().home() / "dataset"

transform = transforms.Compose([
    # transforms.Resize(224),  # EfficientNet-B7의 입력 크기에 맞게 조정
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = CIFAR10(root=datadir, train=True, download=True, transform=transform)
testset = CIFAR10(root=datadir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


# @@@@@@@ MixMatch data exploration @@@@@@@

In [None]:
print(f'==> Preparing cifar10')
transform_train = transforms.Compose([
    #dataset.Resize(224), # for efficientnet_b7
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])
datadir = Path().home() / "dataset"

batch_size=64

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(datadir, 40000, transform_train=transform_train, transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
train_labeled_set

In [None]:
train_unlabeled_set

In [None]:
train_labeled_set.data.shape

In [None]:
# train_labeled_set.data

In [None]:
train_labeled_set.targets.shape

In [None]:
# train_labeled_set.targets

In [None]:
## Compute mean and std for the raw dataset
#
# Dataset Mean: tensor([-1.6901e-04,  3.4511e-05,  6.4206e-04])
# Dataset Std: tensor([0.9966, 0.9978, 0.9977])
#
# mean, std = calculate_mean_std(labeled_trainloader)
# print(f"Dataset Mean: {mean}")
# print(f"Dataset Std: {std}")

In [None]:
test_set

# let's start training! (set train_dl with a target dataset)

In [15]:
np.random.seed(seed)

In [16]:
## TODO: switch to target dataset

## default
train_ds = trainset
test_ds = testset

## MixMatch
# train_ds = train_labeled_set
# test_ds = test_set

In [17]:
# train_ds = torch.utils.data.Subset(train_ds, keep)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=4)

In [18]:
len(train_ds) # 50%(pkeep) of training dataset

50000

In [19]:
m = models.network(arch, pretrained=False, n_classes=10)
m = m.to(DEVICE)
# print(m)

# For efficient fine-tune, freeze some intermediate layers within model
# m = freeze_interdemidate_layers(m, model)

arch: efficientnet_b7, pretrained: False, n_classes: 10


In [20]:
optim = torch.optim.SGD(m.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs)

In [21]:
# wandb.init(project="lira", mode="disabled" if debug else "online")
# wandb.config.update(args)

for epoch in range(epochs):
    m.train()
    loss_total = 0
    pbar = tqdm(train_dl)
    for itr, (x, y) in enumerate(pbar):
        x, y = x.to(DEVICE), y.to(DEVICE)

        outputs = m(x)
        loss = F.cross_entropy(outputs, y)
        loss_total += loss
        
        pbar.set_postfix_str(f"loss: {loss:.2f}")
        optim.zero_grad()
        loss.backward()
        optim.step()
    sched.step()

    test_acc = get_acc(m, test_dl)
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")
    # wandb.log({"loss": loss_total / len(train_dl)})

print(f"[test] acc_test: {get_acc(m, test_dl):.4f}")
# wandb.log({"acc_test": get_acc(m, test_dl)})

savedir_victim = os.path.join(savedir, "victim")
# os.makedirs(savedir_victim, exist_ok=True)
# np.save(savedir_victim + "/keep.npy", keep_bool)
# torch.save(m.state_dict(), savedir_victim + "/model.pt")
# print('save done')

100%|██████████| 782/782 [00:59<00:00, 13.21it/s, loss: 2.52]


[Epoch 0] Test Accuracy: 0.1009


100%|██████████| 782/782 [01:00<00:00, 12.83it/s, loss: 2.27]


[Epoch 1] Test Accuracy: 0.1001


 38%|███▊      | 297/782 [00:23<00:37, 12.89it/s, loss: 2.30]


KeyboardInterrupt: 