## Imports

In [1]:
import numpy as np
import torch

## Configuration

In [25]:
from omegaconf import OmegaConf

# Settings read by the cells below via CONFIG.<key>.
# Each key is defined exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides accidental divergence.
CONFIG = OmegaConf.create({
    # Model
    "model": "WideResNet-28-10-torchdistill",
    "num_classes": 10,
    # Dataset
    "image_size": 32,
    "per_class_count": 50,
    "valid_transform_reprs": [
        "ConvertImageDtype(float)",
        "Normalize((0.4988, 0.5040, 0.4926), (0.2498, 0.2480, 0.2718))",
    ],
    "batch_size": 128,
    "num_workers": 8,
    "enable_pin_memory": True,
    "checkpoint_url": "https://drive.google.com/uc?id=1ZEaevP97mXYdLU_Jv6S5sm2oR3TZZLeF", # no-warmup-open-0121__epoch_199.pt

    # OPeN
    # 1 / 3 evaluates to the same float as the previous long decimal literal.
    "delta": 1 / 3,
    # Same per-channel values as the Normalize(...) entry in valid_transform_reprs.
    "pure_noise_mean": [0.4988, 0.5040, 0.4926],
    "pure_noise_std": [0.2498, 0.2480, 0.2718],
    "pure_noise_image_size": 32,
    "open_start_epoch": 160,
})

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Download checkpoint

In [4]:
import gdown

import os

url = CONFIG.checkpoint_url
checkpoint_filepath = "checkpoints/for_analysis.pt"
# Make sure the target directory exists before downloading, so the cell also
# works on a fresh checkout where "checkpoints/" has not been created yet.
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
# Kept as the last expression so the cell still displays gdown's return value.
gdown.download(url, checkpoint_filepath, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1ZEaevP97mXYdLU_Jv6S5sm2oR3TZZLeF
To: /workspace/pure-noise/checkpoints/for_analysis.pt
100%|██████████| 292M/292M [00:03<00:00, 90.8MB/s] 


'checkpoints/for_analysis.pt'

## Set Random Seed

In [5]:
# Seed NumPy's and PyTorch's global random number generators with the same value.
np.random.seed(seed=0)
# Stricter determinism is available but left disabled:
# torch.use_deterministic_algorithms(True)
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7f9d5d1a5530>

## Initialize Model

In [6]:
from initializers import initialize_model

# Build the architecture named in CONFIG, with the DAR-BN flag off and a
# dropout rate of zero.
net = initialize_model(
    num_classes=CONFIG.num_classes,
    model_name=CONFIG.model,
    dropout_rate=0,
    enable_dar_bn=False,
)

In [7]:
from checkpointing import load_checkpoint

# Load the downloaded checkpoint into `net`; no optimizer is passed because
# this notebook only analyses the model.
load_checkpoint(
    net,
    checkpoint_filepath=checkpoint_filepath,
    optimizer=None,
)

In [8]:
# Move the model to the selected device and switch it to evaluation mode.
net = net.to(device).eval()

## Load Dataset

In [9]:
from torch.utils.data import DataLoader

from datasets.cifar10 import build_train_dataset, build_valid_dataset
from initializers import initialize_transforms

# TODO: Use train_dataset
# Build the validation pipeline from the transform descriptions in CONFIG.
valid_transform = initialize_transforms(CONFIG.valid_transform_reprs)
valid_dataset = build_valid_dataset(transform=valid_transform)

# No shuffle argument is passed, so DataLoader's default ordering applies.
valid_loader = DataLoader(
    dataset=valid_dataset,
    pin_memory=CONFIG.enable_pin_memory,
    num_workers=CONFIG.num_workers,
    batch_size=CONFIG.batch_size,
)

## Setup Loss Function

In [10]:
import torch.nn as nn

# Cross-entropy loss averaged over the batch; used for both backward passes
# in the gradient comparison below.
criterion = nn.CrossEntropyLoss(reduction="mean")

## Compute Gradients

In [22]:
def get_gradients(net):
    """Return a CPU snapshot of the gradient stored on ``net.fc.weight``.

    Args:
        net: Model exposing an ``fc`` attribute whose ``weight`` parameter
            has had a gradient populated by a backward pass.

    Returns:
        A detached copy of ``net.fc.weight.grad`` on the CPU. Because it is a
        clone, later ``zero_grad()`` / ``backward()`` calls on ``net`` do not
        change the returned tensor.

    Raises:
        RuntimeError: If ``net.fc.weight.grad`` is ``None``, i.e. no backward
            pass has populated it since it was last cleared.
    """
    # NOTE: Chose last linear layer in this case
    grad = net.fc.weight.grad
    if grad is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError(
            "net.fc.weight.grad is None; run loss.backward() before calling get_gradients()"
        )
    return grad.detach().cpu().clone()

In [34]:
# TODO: Setup train dataset with correct transforms
# Only `class_frequency` is read from the train dataset in this cell, so the
# validation transform is passed as a stand-in.
train_dataset = build_train_dataset(transform=valid_transform)

# Use the tensor factory functions with an explicit dtype and device instead of
# the legacy torch.Tensor(...) constructor followed by a separate .to(device).
num_samples_per_class = torch.as_tensor(
    train_dataset.class_frequency, dtype=torch.float32, device=device
)
# list(...) turns the OmegaConf list nodes into plain Python lists of floats.
pure_noise_mean = torch.tensor(
    list(CONFIG.pure_noise_mean), dtype=torch.float32, device=device
)
pure_noise_std = torch.tensor(
    list(CONFIG.pure_noise_std), dtype=torch.float32, device=device
)

In [35]:
from replace_with_pure_noise import replace_with_pure_noise

# Compare the fc-layer gradients produced by ONE validation minibatch, with and
# without OPeN pure-noise replacement. Only the first batch of `valid_loader`
# is used; an empty loader raises StopIteration here instead of leaving the
# gradient variables undefined.
inputs, labels = next(iter(valid_loader))
inputs = inputs.float().to(device)
labels = labels.to(device)

# Pass 1: the batch as loaded.
net.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
grads_without_open = get_gradients(net)

# Pass 2: the same batch after pure-noise replacement.
net.zero_grad()
noise_mask = replace_with_pure_noise(
    images=inputs,
    targets=labels,
    delta=CONFIG.delta,
    num_samples_per_class=num_samples_per_class,
    dataset_mean=pure_noise_mean,
    dataset_std=pure_noise_std,
    image_size=CONFIG.pure_noise_image_size,
)
# NOTE(review): `inputs` is fed to the model again right after this call, so
# replace_with_pure_noise presumably modifies it in place — confirm.
outputs = net(inputs, noise_mask=noise_mask)
loss = criterion(outputs, labels)
loss.backward()
grads_with_open = get_gradients(net)

# Nothing is accumulated per batch yet; these stay empty arrays so the names
# remain defined for any cell that refers to them.
valid_grads = np.array([])
valid_labels = np.array([])

In [38]:
grads_without_open.mean(), grads_without_open.var()

(tensor(2.9804e-10), tensor(2.3311e-05))

In [39]:
grads_with_open.mean(), grads_with_open.var()

(tensor(2.0023e-10), tensor(6.3590e-05))