# Train DFR on Stylized ImageNet

In [2]:
import sys
import numpy as np
import torch
import torchvision
import einops
import json
import tqdm

from matplotlib import pyplot as plt

import sys

from sklearn.preprocessing import StandardScaler

%matplotlib inline

In [3]:
def load_embeddings(path):
    arr = np.load(path)
    x, y = arr["embeddings"], arr["labels"]
    return x, y

In [4]:
def train_logreg(
    x_train, y_train, eval_datasets,
    n_epochs=1000, weight_decay=0., lr=1.,
    batch_size=1000, verbose=0, 
    n_classes=1000
    ):
    
    x_train = torch.from_numpy(x_train).float()
    y_train = torch.from_numpy(y_train).long()
    train_ds = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=True, batch_size=batch_size)
    
    d = x_train.shape[1]
    model = torch.nn.Linear(d, n_classes).cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), weight_decay=weight_decay, lr=lr)
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)
    
    for epoch in range(n_epochs):
        correct, total = 0, 0
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            schedule.step()
            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()
            total += len(y)
        if verbose > 1 and ((n_epochs < 10) or epoch % (n_epochs // 10) == 0):
            print(epoch, correct / total)
    
    results = {}
    for key, (x_test, y_test) in eval_datasets.items():
        x_test = torch.from_numpy(x_test).float().cuda()
        pred = torch.argmax(model(x_test), axis=-1).detach().cpu().numpy()
        results[key] = (pred == y_test).mean()
    
    return model, results

In [5]:
def get_data(
    train_datasets, eval_datasets, 
    num_stylized=-1, num_original=0, preprocess=True):
    
    x_train, y_train = train_datasets["imagenet"]
    idx = np.arange(len(x_train))
    np.random.shuffle(idx)
    idx = idx[:num_original]
    x_train = x_train[idx]
    y_train = y_train[idx]

    x_train_mr, y_train_mr = train_datasets["imagenet_stylized"]
    idx = np.arange(len(x_train_mr))
    np.random.shuffle(idx)
    idx = idx[:num_stylized]
    x_train_mr = x_train_mr[idx]
    y_train_mr = y_train_mr[idx]

    x_train = np.concatenate([x_train, x_train_mr])
    y_train = np.concatenate([y_train, y_train_mr])

    if preprocess:
        mean = x_train.mean(axis=0)[None, :]
        std = x_train.std(axis=0)[None, :]
        x_train = (x_train - mean) / std
        eval_datasets_preprocessed = {
            k: ((x - mean) / std, y)
            for k, (x, y) in eval_datasets.items()
        }
    else:
        eval_datasets_preprocessed = eval_datasets
        mean, std = None, None
    return x_train, y_train, eval_datasets_preprocessed, mean, std


def run_experiment(
    train_datasets, eval_datasets,
    num_stylized=-1, num_original=0, preprocess=True,
    n_epochs=10, weight_decay=0., lr=1., batch_size=1000,
    verbose=0, num_seeds=3
):
    models, results = {}, {}
    for seed in range(num_seeds):
        x_train, y_train, eval_datasets_preprocessed, _, _ = get_data(
            train_datasets, eval_datasets,
            num_stylized, num_original, preprocess)
        model, results_seed = train_logreg(
            x_train, y_train, eval_datasets_preprocessed,
            n_epochs, weight_decay, lr, batch_size, verbose)
        results[seed] = results_seed
        models[seed] = model
    if num_seeds > 1:
        results_aggrgated = {
            key: (np.mean([results[seed][key] for seed in results.keys()]),
                  np.std([results[seed][key] for seed in results.keys()]))
            for key in results[0].keys()
        }
    else:
        results_aggrgated = results[0]
    return results, results_aggrgated, models


def print_results(results_dict):
    print("-------------------")
    for key, val in results_dict.items():
        print("{}: {:.3f}±{:.3f}".format(key, val[0], val[1]))
    print("-------------------")

## Data

Change data paths here.

In [6]:
imagenet_c_corruptions = ["brightness", "defocus_blur", "fog", "gaussian_blur", "glass_blur",
                          "jpeg_compression", "pixelate", "shot_noise", "spatter", "zoom_blur",
                          "contrast", "elastic_transform", "frost", "gaussian_noise",
                          "impulse_noise", "motion_blur", "saturate", "snow", "speckle_noise"]
intensities = [3]

eval_path_dict = {
    "imagenet_r": "/datasets/imagenet-r/imagenet-r_resnet50_val_embeddings.npz",
    "imagenet_a": "/datasets/imagenet-a/imagenet-a_resnet50_val_embeddings.npz",
    "imagenet": "/datasets/imagenet_symlink/resnet50_val_embeddings.npz",
    "imagenet_stylized": "/datasets/imagenet-stylized/imagenet_resnet50_val_embeddings.npz",
}
eval_datasets = {k: load_embeddings(p) for k, p in eval_path_dict.items()}

train_path_dict = {
    "imagenet": "/datasets/imagenet_symlink/resnet50_train_embeddings.npz",
    "imagenet_stylized": "/datasets/imagenet-stylized/imagenet_resnet50_train_embeddings.npz",
}
train_datasets = {k: load_embeddings(p) for k, p in train_path_dict.items()}

In [9]:
def get_w_b(model):
    w = model.weight.detach().cpu().numpy()
    b = model.bias.detach().cpu().numpy()
    return w, b

## IN+SIN

In [10]:
_, _, _, mean, std = get_data(
    train_datasets, eval_datasets, 
    num_stylized=-1, num_original=-1, preprocess=True)

In [15]:
n_epochs = 100
_, combo_results, models = run_experiment(train_datasets, eval_datasets,
               num_stylized=-1, num_original=-1, num_seeds=1,
               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)
w, b = get_w_b(models[0])
np.savez(f"dfr_insin_{n_epochs}_weights.npz",
         w=w, b=b)
print(combo_results)

0 0.46593634830407893
10 0.5358872353051408
20 0.5469958304703604
30 0.5548984165157099
40 0.5609985008432756
50 0.5658953557373977
60 0.5694905990380411
70 0.5725611374851646
80 0.5757015584983447
90 0.5789555874820413
{'imagenet_r': 0.27166666666666667, 'imagenet_a': 0.0024, 'imagenet': 0.74524, 'imagenet_stylized': 0.21418}


In [17]:
arr = np.load("dfr_insin_100_weights.npz")
w, b = arr["w"], arr["b"]
np.savez("dfr_insin_weights_bs10k.npz",
        w=w, b=b, preprocess_mean=mean, preprocess_std=std)

## Original

In [26]:
_, _, _, mean, std = get_data(
    train_datasets, eval_datasets, 
    num_stylized=0, num_original=-1, preprocess=True)

In [27]:
n_epochs = 100
_, original_results, models = run_experiment(train_datasets, eval_datasets,
               num_stylized=0, num_original=-1, num_seeds=1,
               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)
w, b = get_w_b(models[0])
np.savez(f"dfr_in_{n_epochs}_weights.npz",
         w=w, b=b)
print(original_results)

0 0.7988451808357798
10 0.8812745955400088
20 0.8894207945530639
30 0.8974975014054595
40 0.8996704978449622
50 0.9052111312386782
60 0.9066251795864826
70 0.9108571740895746
80 0.9123571116247111
90 0.9148221313011431
{'imagenet_r': 0.2287, 'imagenet_a': 0.0004, 'imagenet': 0.75224, 'imagenet_stylized': 0.06264}


In [28]:
arr = np.load("dfr_in_100_weights.npz")
w, b = arr["w"], arr["b"]
np.savez("dfr_in_weights_bs10k.npz",
        w=w, b=b, preprocess_mean=mean, preprocess_std=std)

## SIN

In [22]:
_, _, _, mean, std = get_data(
    train_datasets, eval_datasets, 
    num_stylized=-1, num_original=0, preprocess=True)

In [23]:
n_epochs = 100
_, stylized_results, models = run_experiment(train_datasets, eval_datasets,
               num_stylized=-1, num_original=0, num_seeds=1,
               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)
w, b = get_w_b(models[0])
np.savez(f"dfr_sin_{n_epochs}_weights.npz",
         w=w, b=b)
print(stylized_results)

0 0.12500156162158785
10 0.23137375851083766
20 0.24725857330251733
30 0.26579580236117184
40 0.2706055968517709
50 0.2846703416828034
60 0.28752186270223
70 0.2968400587169717
80 0.3009861640327316
90 0.30561246798675745
{'imagenet_r': 0.24566666666666667, 'imagenet_a': 0.004, 'imagenet': 0.65076, 'imagenet_stylized': 0.21952}


In [25]:
arr = np.load("dfr_sin_100_weights.npz")
w, b = arr["w"], arr["b"]
np.savez("dfr_sin_weights_bs10k.npz",
        w=w, b=b, preprocess_mean=mean, preprocess_std=std)