In [1]:
import torch
import numpy as np 
import tensorflow as tf
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.losses import KLDivergence

In [62]:
kW = np.random.randn(256, 10)
y = np.random.randn(32, 10)

kW_pt = torch.from_numpy(kW)
y_pt = torch.from_numpy(y)


j = tf.argmax(y, axis=1)
j_pt = torch.argmax(y_pt, dim=1)

y_j = tf.reduce_max(y, axis=1, keepdims=True)
y_j_pt = torch.max(y_pt, dim=1, keepdim=True)[0]

where_not_j = tf.not_equal(y, y_j)
where_not_j_pt = torch.not_equal(y_pt, y_j_pt)

# Get the weight column of the predicted class.
kW_j = tf.gather(tf.transpose(kW), j)
kW_j_pt = kW_pt.T[j_pt]

# Get weights that predict the value y_j - y_i for all i != j.
kW_ij = kW_j[:,:,None] - kW[None]
kW_ij_pt = kW_j_pt[:,:,None] - kW_pt[None]

# We do this instead of `tf.linalg.norm(W_d)` because of an apparent bug
# in `tf.linalg.norm` that leads to NaN values.
K_ij = tf.sqrt(tf.reduce_sum(kW_ij * kW_ij, axis=1))
K_ij_pt = torch.sqrt(torch.sum(kW_ij_pt * kW_ij_pt, dim=1))

lip_con = tf.where(
    tf.equal(y, y_j), 
    tf.zeros_like(K_ij) - 1., 
    K_ij)
lip_con_pt = torch.where(torch.eq(y_pt, y_j_pt), torch.zeros_like(K_ij_pt) - 1., K_ij_pt)

y_bot_i = y + 1.58 * K_ij
y_bot_i_pt = y_pt + 1.58 * K_ij_pt

# `y_bot_i` will be zero at the position of class j. However, we don't 
# want to consider this class, so we replace the zero with negative
# infinity so that when we find the maximum component for `y_bot_i` we 
# don't get zero as a result of all of the components we care aobut 
# being negative.
y_bot_i = tf.where(
    tf.equal(y, y_j), 
    -np.infty + tf.zeros_like(y_bot_i), 
    y_bot_i)
y_bot_i_pt = torch.where(torch.eq(y_pt, y_j_pt), -np.infty + torch.zeros_like(y_bot_i_pt), y_bot_i_pt)


y_bot = tf.reduce_max(y_bot_i, axis=1, keepdims=True)
y_bot_pt = torch.max(y_bot_i_pt, dim=1, keepdim=True)[0]

y_true = torch.randn(32, 10).argmax(dim=1).numpy()

y_pred = tf.concat([y, y_bot], axis=1)

In [72]:
scc = SparseCategoricalCrossentropy(from_logits=True)
cc = CategoricalCrossentropy(from_logits=True)

In [67]:
y_bot.shape, y_bot_pt.shape, y_true.shape

(TensorShape([32, 1]), torch.Size([32, 1]), (32,))

In [65]:
tf.concat([y, y_bot], axis=1).shape

TensorShape([32, 11])

In [70]:
def add_extra_column(y):
    return np.concatenate((y, np.zeros((y.shape[0], 1))), axis=1)

In [76]:
# y_true = tf.concat((y_true, tf.zeros((tf.shape(y_true)[0], 1))), axis=1)

# Encourage predicting the correct class, even non-robustly.
standard_loss = scc(y_true, y_pred[:, :-1])

# Encourage predicting robustly, even incorrectly. We take the robust
# loss but using the model's prediction as the ground truth.
print (y_pred.shape)
y_pred_soft = tf.nn.softmax(y_pred)
print (y_pred_soft.shape)

new_ground_truth = add_extra_column(tf.nn.softmax(y_pred[:, :-1]))
print (new_ground_truth.shape)

robust_loss = cc(
    new_ground_truth, y_pred_soft / 1)
# Combine the standard and robust terms.
print (standard_loss + 1 * robust_loss)

(32, 11)
(32, 11)
(32, 11)
tf.Tensor(5.273375034332275, shape=(), dtype=float64)


In [80]:
torch.zeros((y_pt.shape[0], 1)).shape

torch.Size([32, 1])

In [83]:
torch.cat([y_pt, torch.zeros((y_pt.shape[0], 1))], dim=1).shape

torch.Size([32, 11])

In [25]:
torch.from_numpy(tf.random.truncated_normal((10, 1)).numpy())


tensor([[-0.8285],
        [-1.0093],
        [-0.8830],
        [-1.8680],
        [ 1.3652],
        [-1.8765],
        [ 0.4063],
        [-0.9555],
        [-1.4613],
        [-1.3024]])

In [12]:
torch.rand(10, 1) * (1 - np.exp(-2)) + np.exp(-2)

tensor([[0.9695],
        [0.7198],
        [0.6941],
        [0.7773],
        [0.8283],
        [0.3203],
        [0.5991],
        [0.8982],
        [0.3592],
        [0.1972]])

In [13]:
from scipy.stats import truncnorm