In [None]:
def zca_matrix(data_tensor, epsilon=1e-5):
    """
    Helper function: compute the ZCA whitening matrix of a dataset ~ (N, C, H, W).

    Every sample (first dimension) is flattened into a vector of D features.
    The features are centered independently, their covariance
    C = X^T X / N is formed, and the symmetric whitening matrix
    W = U diag(1 / sqrt(S + epsilon)) U^T is returned, where C = U diag(S) U^T.
    For centered, flattened rows X, the product X @ W has (approximately)
    identity covariance.

    Args:
        data_tensor: tensor whose first dimension indexes samples; all other
            dimensions are flattened into features.
        epsilon: constant added to the eigenvalues before taking the inverse
            square root, so near-zero-variance directions stay finite.

    Returns:
        A (D, D) symmetric whitening matrix.
    """
    # 1. flatten dataset to (N, D); reshape also accepts non-contiguous input.
    X = data_tensor.reshape(data_tensor.shape[0], -1)

    # 2. zero-center every feature (column) independently.
    X = X - X.mean(dim=0, keepdim=True)

    # 3. feature covariance, normalized by the number of samples.
    cov = X.t() @ X / X.shape[0]

    # 4. cov is symmetric PSD, so its SVD is an eigendecomposition (U == V up
    #    to numerical noise); use U on both sides to keep W exactly symmetric.
    #    Whitening needs the inverse *square root* of the eigenvalues.
    U, S, _ = torch.svd(cov)
    return U @ torch.diag(torch.rsqrt(S + epsilon)) @ U.t()

In [None]:
def get_zca_matrix():
    """
    Compute ZCA whitening statistics on the full raw training set.

    All of ``train_dataset_raw`` is loaded as a single batch. The whitening
    matrix comes from ``zca_matrix``, and the per-feature mean of the
    flattened images is computed alongside it.

    Returns:
        (zca, tr_zca_mean): the whitening matrix and a 1-D tensor holding the
        mean of each flattened feature over the training images.
    """
    # A single batch holding the whole dataset. Sample order does not affect
    # these statistics, so shuffling is unnecessary.
    train_loader = torch.utils.data.DataLoader(train_dataset_raw,
                                               batch_size=len(train_dataset_raw),
                                               shuffle=False)

    # next(iter(...)) works on every PyTorch version; recent releases no
    # longer provide a ``.next()`` method on DataLoader iterators.
    images, _ = next(iter(train_loader))
    zca = zca_matrix(images)
    tr_zca_mean = images.reshape(images.shape[0], -1).mean(dim=0)

    return zca, tr_zca_mean

In [None]:
def get_zca_matrix_tst():
    """
    Compute ZCA whitening statistics on the full raw test set.

    Same procedure as ``get_zca_matrix``, but it reads ``test_dataset_raw``.

    NOTE(review): fitting whitening statistics on the test set itself differs
    from the common practice of reusing training-set statistics for test
    data; confirm this is intended.

    Returns:
        (zca, tst_zca_mean): the whitening matrix and a 1-D tensor holding the
        mean of each flattened feature over the test images.
    """
    # A single batch holding the whole test set; order is irrelevant here.
    test_loader = torch.utils.data.DataLoader(test_dataset_raw,
                                              batch_size=len(test_dataset_raw),
                                              shuffle=False)

    # Draw from the test loader built above; the images are passed to
    # zca_matrix as-is.
    images, _ = next(iter(test_loader))
    zca = zca_matrix(images)
    tst_zca_mean = images.reshape(images.shape[0], -1).mean(dim=0)

    return zca, tst_zca_mean

In [None]:
# NOTE(review): this assignment rebinds ``zca_matrix`` from the helper function
# to the returned tensor, so calling ``get_zca_matrix()`` again in the same
# session fails until the cell defining the function is re-run.
zca_matrix, tr_zca_mean = get_zca_matrix()
print(*(t.shape for t in (zca_matrix, tr_zca_mean)), sep="\n")

In [None]:
# Show the shapes of the whitening matrix and of the per-feature mean vector.
print("\n".join(str(t.shape) for t in (zca_matrix, tr_zca_mean)))

In [None]:
def pgd_linf(model, X, y, epsilon=0.1, alpha=0.01, num_iter=20, randomize=False):
    """Construct L-infinity PGD adversarial perturbations for the examples X.

    Runs ``num_iter`` steps of projected gradient ascent on the cross-entropy
    loss. Each step moves ``delta`` by ``alpha * sign(grad)`` and then clamps
    it back into the box ``[-epsilon, epsilon]``.

    Args:
        model: callable mapping a batch shaped like ``X`` to logits accepted
            by ``nn.CrossEntropyLoss``.
        X: input batch.
        y: targets accepted by ``nn.CrossEntropyLoss`` (e.g. class indices).
        epsilon: L-infinity radius of the perturbation.
        alpha: step size of each iteration.
        num_iter: number of gradient steps.
        randomize: if True, start from a uniform random point in
            ``[-epsilon, epsilon]`` instead of from zero.

    Returns:
        The perturbation ``delta`` (same shape as ``X``), detached from the
        autograd graph.
    """
    loss_fn = nn.CrossEntropyLoss()
    if randomize:
        delta = (torch.rand_like(X) * 2 - 1) * epsilon
    else:
        delta = torch.zeros_like(X)
    delta.requires_grad_(True)

    for _ in range(num_iter):
        loss = loss_fn(model(X + delta), y)
        # Differentiate w.r.t. delta only, so the model's parameter .grad
        # buffers are not accumulated into as a side effect.
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta.add_(alpha * grad.sign()).clamp_(-epsilon, epsilon)
    return delta.detach()

In [None]:
def maxdrop(x, p):
    """Zero out the maximum value(s) of randomly selected channels, in place.

    Each channel index ``c`` along dimension 1 is selected independently with
    probability ``p``. For each selected channel, every entry of
    ``x[:, c, :, :]`` equal to that slice's maximum is set to 0. The maximum
    is taken over the whole slice, i.e. across all other dimensions at once.

    Args:
        x: 4-D tensor; dimension 1 is treated as the channel axis. Modified
            in place.
        p: per-channel selection probability in [0, 1].

    Returns:
        ``x`` itself, after modification.
    """
    num_channels = x.shape[1]
    # One uniform draw in [0, 1) per channel; ``< p`` selects with probability
    # exactly p (p=0 never selects, p=1 always does).
    selected = torch.rand(num_channels) < p
    for channel in torch.nonzero(selected).flatten().tolist():
        plane = x[:, channel, :, :]  # basic slicing -> view into x
        plane[plane == plane.max()] = 0
    return x

In [None]:
# Smoke test: a batch of ones with a single spike (value 5) at spatial
# position (0, 0) of sample 0 in every channel, so each channel's maximum is
# that one entry.
test_tensor = torch.ones((32, 3, 32, 32))
test_tensor[0, :, 0, 0] = 5
maxdrop(test_tensor, 0.5)