In [1]:
# understanding JEM

In [2]:
import utils
import torch as t, torch.nn as nn, torch.nn.functional as tnnF, torch.distributions as tdist
from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr
import os
import sys
import argparse
#import ipdb
import numpy as np
import wideresnet
import json
# Sampling
from tqdm import tqdm
# cuDNN is enabled and allowed to benchmark convolution algorithms.
t.backends.cudnn.benchmark = True
t.backends.cudnn.enabled = True
# Seed handed to t.manual_seed / t.cuda.manual_seed_all before the data is built.
seed = 1
# Image side length and channel count: used for the random crop and for the
# shape of freshly initialised noise samples.
im_sz = 32
n_ch = 3

# params
# Std of the Gaussian noise added to every image in both the train and test transforms.
sigma = 0.03
# Root directory passed to the CIFAR-10 dataset (it is opened with download=False).
data_root = '../data'
# Number of shuffled training indices held out for validation (None disables the split).
n_valid = 5000
# Number of classes: size of the classification head and of the per-class buffer partition.
n_classes = 10
# > 0 keeps that many labeled examples per class; otherwise all training indices are labeled.
labels_per_class = -1
# Batch size of the training loaders and of unconditional SGLD sampling.
batch_size = 64
# False selects the class-conditional model (CCF); True selects the scalar-energy model (F).
uncond = False
# Probability that a sample drawn from the replay buffer is replaced by fresh uniform noise.
reinit_freq = 0.05
# SGLD chain: number of steps, gradient step size, and std of the injected Gaussian noise.
n_steps = 20
sgld_lr = 1.0
sgld_std = 0.01
# Forwarded to wideresnet.Wide_ResNet.
dropout_rate = 0.0
# Checkpoint to restore the model and replay buffer from; None starts from scratch.
load_path = None
# Number of images in the replay buffer (must be divisible by n_classes when uncond is False).
buffer_size = 10000
# Wide-ResNet depth, width and normalisation setting, forwarded to the backbone.
depth = 28
width = 10
norm  = None
# Adam learning rate and weight decay.
lr = 0.0001
weight_decay = 0.0

In [3]:
class DataSubset(Dataset):
    """View of ``base_dataset`` restricted to a fixed collection of indices.

    Args:
        base_dataset: Any indexable object that supports ``len``.
        inds: Indices into ``base_dataset`` to expose. When ``None``, a random
            subset is drawn without replacement from NumPy's global RNG.
        size: Number of indices to draw when ``inds`` is ``None``. A negative
            value (the default) selects every index, in random order.
    """

    def __init__(self, base_dataset, inds=None, size=-1):
        self.base_dataset = base_dataset
        if inds is None:
            n_total = len(base_dataset)
            # A negative count is not a valid sample size for
            # np.random.choice, so interpret it as "the whole dataset".
            if size < 0:
                size = n_total
            inds = np.random.choice(n_total, size, replace=False)
        self.inds = inds

    def __getitem__(self, index):
        """Return the base dataset item stored at the ``index``-th kept position."""
        base_ind = self.inds[index]
        return self.base_dataset[base_ind]

    def __len__(self):
        """Number of indices exposed by this subset."""
        return len(self.inds)
    
def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it each time it runs out."""
    while True:
        yield from loader

In [4]:
def get_data():
    """Build the CIFAR-10 loaders used for training and evaluation.

    Reads the module-level settings ``im_sz``, ``sigma``, ``data_root``,
    ``n_valid``, ``n_classes``, ``labels_per_class`` and ``batch_size``.
    The dataset is opened with ``download=False``.

    NumPy's global RNG is reseeded with a fixed value so the train/validation
    split is the same on every call.

    Returns:
        Tuple ``(dload_train, dload_train_labeled, dload_valid, dload_test)``.
        ``dload_train_labeled`` is an endless generator (the loader wrapped in
        ``cycle``); the other three are ``DataLoader`` objects.
    """
    # Training pipeline: reflect-pad + random crop + flip augmentation, scale
    # to [-1, 1], then add Gaussian noise with std ``sigma``.
    transform_train = tr.Compose(
        [tr.Pad(4, padding_mode="reflect"),
         tr.RandomCrop(im_sz),
         tr.RandomHorizontalFlip(),
         tr.ToTensor(),
         tr.Normalize((.5, .5, .5), (.5, .5, .5)),
         lambda x: x + sigma * t.randn_like(x)]
    )
    # Evaluation pipeline: no augmentation, but the same scaling and noise.
    transform_test = tr.Compose(
        [tr.ToTensor(),
         tr.Normalize((.5, .5, .5), (.5, .5, .5)),
         lambda x: x + sigma * t.randn_like(x)]
    )

    def dataset_fn(train, transform):
        return tv.datasets.CIFAR10(root=data_root, transform=transform, download=False, train=train)

    # get all training inds
    full_train = dataset_fn(True, transform_train)
    all_inds = list(range(len(full_train)))
    # fixed seed so the shuffle (and therefore the split) is reproducible
    np.random.seed(1234)
    np.random.shuffle(all_inds)
    # separate out validation set
    if n_valid is not None:
        valid_inds, train_inds = all_inds[:n_valid], all_inds[n_valid:]
    else:
        valid_inds, train_inds = [], all_inds
    train_inds = np.array(train_inds)

    if labels_per_class > 0:
        # Labels are only needed for the per-class selection. Indexing the
        # dataset runs the full transform on every image, so this pass is
        # skipped entirely when all training indices are labeled.
        train_labels = np.array([full_train[ind][1] for ind in train_inds])
        train_labeled_inds = []
        for i in range(n_classes):
            train_labeled_inds.extend(train_inds[train_labels == i][:labels_per_class])
    else:
        train_labeled_inds = train_inds

    # Both training subsets index into the same underlying dataset object.
    dset_train = DataSubset(full_train, inds=train_inds)
    dset_train_labeled = DataSubset(full_train, inds=train_labeled_inds)
    # Validation images come from the training split but use the test transform.
    dset_valid = DataSubset(
        dataset_fn(True, transform_test),
        inds=valid_inds)
    dset_test = dataset_fn(False, transform_test)

    dload_train = DataLoader(dset_train, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
    dload_train_labeled = DataLoader(dset_train_labeled, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
    # The labeled loader is consumed with next(), so make it endless.
    dload_train_labeled = cycle(dload_train_labeled)
    dload_valid = DataLoader(dset_valid, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
    dload_test = DataLoader(dset_test, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
    return dload_train, dload_train_labeled, dload_valid,dload_test

In [5]:
# Seed PyTorch's generator, and every CUDA generator when CUDA is available.
t.manual_seed(seed)
if t.cuda.is_available():
    t.cuda.manual_seed_all(seed)

# datasets

# get_data() reseeds NumPy's global RNG itself before shuffling the indices.
# dload_train_labeled is an endless generator; the others are DataLoaders.
dload_train, dload_train_labeled, dload_valid, dload_test = get_data()

In [6]:
# Use GPU 2 when the machine exposes at least three CUDA devices; otherwise
# fall back to the first GPU, or to the CPU when CUDA is unavailable.
if t.cuda.is_available():
    device = t.device('cuda:2' if t.cuda.device_count() > 2 else 'cuda:0')
else:
    device = t.device('cpu')

In [7]:
def init_random(bs):
    """Return ``bs`` float32 CPU images of shape (n_ch, im_sz, im_sz), filled in place by ``uniform_(-1, 1)``."""
    noise = t.empty(bs, n_ch, im_sz, im_sz, dtype=t.float32, device='cpu')
    return noise.uniform_(-1, 1)

def get_sample_q(device):
    """Build the SGLD sampler that draws samples from the model.

    Args:
        device: Device on which the sampling chain is run.

    Returns:
        The function ``sample_q(f, replay_buffer, y=None, n_steps=n_steps)``.
        It reads the module-level settings ``batch_size``, ``n_classes``,
        ``reinit_freq``, ``sgld_lr``, ``sgld_std`` and ``uncond``.
    """
    def sample_p_0(replay_buffer, bs, y=None):
        """Pick ``bs`` chain starting points.

        Returns ``(samples_on_device, buffer_inds)``; ``buffer_inds`` is an
        empty list when no buffer is in use.
        """
        if len(replay_buffer) == 0:
            # No buffer: start from pure noise. Move it to ``device`` so this
            # path returns tensors on the same device as the buffered path.
            return init_random(bs).to(device), []
        if y is not None:
            assert not uncond, "Can't drawn conditional samples without giving me y"
        # With labels, the buffer is treated as n_classes equally sized
        # contiguous partitions and each sample is drawn from its class's slice.
        slots_per_draw = len(replay_buffer) if y is None else len(replay_buffer) // n_classes
        inds = t.randint(0, slots_per_draw, (bs,))
        if y is not None:
            inds = y.cpu() * slots_per_draw + inds
        buffer_samples = replay_buffer[inds]
        random_samples = init_random(bs)
        # Per-sample 0/1 mask: 1 means "restart this chain from noise".
        choose_random = (t.rand(bs) < reinit_freq).float()[:, None, None, None]
        samples = choose_random * random_samples + (1 - choose_random) * buffer_samples
        return samples.to(device), inds

    def sample_q(f, replay_buffer, y=None, n_steps=n_steps):
        """Run ``n_steps`` of SGLD on ``f`` and return the final samples.

        this func takes in replay_buffer now so we have the option to sample from
        scratch (i.e. replay_buffer==[]).  See test_wrn_ebm.py for example.

        Args:
            f: Model called as ``f(x, y=y)``; the sum of its output is
                differentiated with respect to ``x``.
            replay_buffer: CPU tensor of stored samples, or an empty sequence.
                When non-empty, the slots that were drawn are overwritten in
                place with the final samples.
            y: Optional class labels; when given, the batch size is ``y.size(0)``.
            n_steps: Number of SGLD updates.

        Returns:
            Detached tensor of final samples on ``device``.
        """
        # Sample in eval mode, then put the model back in the mode it was in.
        was_training = f.training
        f.eval()
        # get batch size
        bs = batch_size if y is None else y.size(0)
        # generate initial samples and buffer inds of those samples (if buffer is used)
        init_sample, buffer_inds = sample_p_0(replay_buffer, bs=bs, y=y)
        x_k = init_sample.detach().requires_grad_(True)
        # sgld: ascend the gradient of f and add Gaussian noise at every step
        for _ in range(n_steps):
            # A fresh graph is built on every step, so it need not be retained.
            f_prime = t.autograd.grad(f(x_k, y=y).sum(), [x_k])[0]
            with t.no_grad():
                x_k.add_(sgld_lr * f_prime + sgld_std * t.randn_like(x_k))
        f.train(was_training)
        final_samples = x_k.detach()
        # update replay buffer
        if len(replay_buffer) > 0:
            replay_buffer[buffer_inds] = final_samples.cpu()
        return final_samples
    return sample_q

In [8]:
# Bind the sampler to `device`; the current value of n_steps becomes its default.
sample_q = get_sample_q(device)

In [9]:
# Display the sampler; its repr shows the signature, including the n_steps default.
sample_q

<function __main__.get_sample_q.<locals>.sample_q(f, replay_buffer, y=None, n_steps=20)>

In [10]:
class F(nn.Module):
    """Wide-ResNet backbone with a scalar energy head and a class-logit head.

    Args:
        depth, width, norm, dropout_rate: Forwarded to ``wideresnet.Wide_ResNet``.
        n_classes: Output size of the classification head.

    Both heads are linear layers on the backbone output, whose last dimension
    is ``self.f.last_dim``.
    """

    def __init__(self, depth=28, width=2, norm=None, dropout_rate=0.0, n_classes=10):
        super().__init__()
        self.f = wideresnet.Wide_ResNet(depth, width, norm=norm, dropout_rate=dropout_rate)
        self.energy_output = nn.Linear(self.f.last_dim, 1)
        self.class_output = nn.Linear(self.f.last_dim, n_classes)

    def forward(self, x, y=None):
        """Return the energy-head output with its trailing singleton axis removed.

        ``y`` is accepted so both model classes share a call signature; it is
        not used here.
        """
        penult_z = self.f(x)
        # Squeeze only the last axis: a bare squeeze() would also drop the
        # batch axis when the batch holds a single example.
        return self.energy_output(penult_z).squeeze(-1)

    def classify(self, x):
        """Return the class logits, one row of ``n_classes`` values per input."""
        penult_z = self.f(x)
        # No squeeze here: the logits must keep their batch axis even for a
        # batch of one so that reductions over axis 1 remain valid.
        return self.class_output(penult_z)


class CCF(F):
    """Class-conditional variant of ``F`` whose score is read off the class logits."""

    def __init__(self, depth=28, width=2, norm=None, dropout_rate=0.0, n_classes=10):
        super().__init__(depth, width, norm=norm, dropout_rate=dropout_rate, n_classes=n_classes)

    def forward(self, x, y=None):
        """Return the log-sum-exp of the logits, or the logit of class ``y`` when given."""
        class_logits = self.classify(x)
        if y is not None:
            # Pick, for every row, the logit at that row's label (keeps a width-1 axis).
            return t.gather(class_logits, 1, y[:, None])
        return class_logits.logsumexp(1)

def get_model_and_buffer(device, sample_q):
    """Create the model and its replay buffer, optionally from a checkpoint.

    Reads the module-level settings ``uncond``, ``depth``, ``width``, ``norm``,
    ``dropout_rate``, ``n_classes``, ``buffer_size`` and ``load_path``.

    Args:
        device: Device the model is moved to before it is returned.
        sample_q: Unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(f, replay_buffer)``. Without a checkpoint the buffer is
        ``buffer_size`` images of uniform noise.
    """
    model_cls = F if uncond else CCF
    f = model_cls(depth, width, norm, dropout_rate=dropout_rate, n_classes=n_classes)
    if not uncond:
        # Conditional sampling splits the buffer into n_classes equal parts.
        assert buffer_size % n_classes == 0, "Buffer size must be divisible by args.n_classes"
    if load_path is None:
        # make replay buffer
        replay_buffer = init_random(buffer_size)
    else:
        print(f"loading model from {load_path}")
        # NOTE(security): t.load unpickles the file; only load trusted checkpoints.
        # Load onto the CPU regardless of where the checkpoint was saved: the
        # sampler indexes the buffer with CPU indices and mixes it with CPU
        # noise, and the model is moved to `device` below.
        ckpt_dict = t.load(load_path, map_location='cpu')
        f.load_state_dict(ckpt_dict["model_state_dict"])
        replay_buffer = ckpt_dict["replay_buffer"]
    f = f.to(device)
    return f, replay_buffer

In [11]:
# Build the model (moved to `device`) and the replay buffer from the settings above.
f, replay_buffer = get_model_and_buffer(device, sample_q)


| Wide-Resnet 28x10


In [12]:
def sqrt(x):
    """Return the float32 square root of ``x`` truncated to an int."""
    root = t.sqrt(t.Tensor([x]))
    return int(root)


def plot(p, x):
    """Save the batch ``x`` as an image grid at path ``p``.

    Values are clamped to [-1, 1] and the grid holds ``sqrt(batch size)``
    images per row.
    """
    clipped = t.clamp(x, -1, 1)
    per_row = sqrt(x.size(0))
    return tv.utils.save_image(clipped, p, normalize=True, nrow=per_row)

In [13]:
# Freshly initialised buffer: buffer_size images of shape (n_ch, im_sz, im_sz).
replay_buffer.shape

torch.Size([10000, 3, 32, 32])

In [14]:
# Adam over all model parameters, using the lr and weight_decay settings above.
optim = t.optim.Adam(f.parameters(), lr=lr, betas=[.9, .999], weight_decay=weight_decay)

In [15]:
# Display the optimizer configuration.
optim

Adam (
Parameter Group 0
    amsgrad: False
    betas: [0.9, 0.999]
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0.0
)

In [16]:
# Upper bound of the epoch loop in the training cell.
n_epochs = 1
# NOTE(review): decay_epochs and warmup_iters are not referenced by any cell
# visible in this notebook; presumably meant for a learning-rate schedule — confirm.
decay_epochs = [160, 180]
warmup_iters = 1000

In [17]:
# Single-step walkthrough of the training objective: the loop stops after the
# first batch (see the `break`) and prints intermediate values for inspection.
best_valid_acc = 0.0
cur_iter = 0
for epoch in range(n_epochs):
    # tqdm wraps the loader itself (not the enumerate object) so the progress
    # bar can read the loader's length.
    for i, (x_p_d, _) in enumerate(tqdm(dload_train)):

        # train and test samples
        x_p_d = x_p_d.to(device)
        # dload_train_labeled is an endless generator, so next() always succeeds.
        x_lab, y_lab = next(dload_train_labeled)
        x_lab, y_lab = x_lab.to(device), y_lab.to(device)

        L = 0.0

        # Draw samples with SGLD; this also overwrites the replay-buffer slots
        # that the chains started from.
        x_q = sample_q(f, replay_buffer)  # sample from log-sumexp

        # Model output on real data and on the sampled data.
        fp_all = f(x_p_d)
        fq_all = f(x_q)
        print(fp_all.shape)
        fp = fp_all.mean()
        fq = fq_all.mean()
        print(fp)

        # Minimising this raises f on data and lowers it on the samples.
        l_p_x = -(fp - fq)

        L += l_p_x

        # Cross-entropy of the classification head on the labeled batch.
        logits = f.classify(x_lab)
        l_p_y_given_x = nn.CrossEntropyLoss()(logits, y_lab)

        L += l_p_y_given_x

        optim.zero_grad()
        L.backward()
        optim.step()
        # Stop after one optimisation step; this cell only inspects the pipeline.
        break

0it [00:00, ?it/s]

torch.Size([64])
tensor(2.3067, device='cuda:2', grad_fn=<MeanBackward0>)


0it [00:03, ?it/s]


In [18]:
# The buffer keeps its shape after a training step: sample_q overwrites entries in place.
replay_buffer.shape

torch.Size([10000, 3, 32, 32])