## Test the loss functions implemented in PyTorch and TensorFlow

In [9]:
import numpy as np
import tensorflow.compat.v1 as tf
import cellbox
import torch
# Switch TensorFlow to its v1 behavior: the TensorFlow cells below build
# tensors first and then evaluate them through a Session.
tf.disable_v2_behavior()

### Test cases

In [24]:
# Shared inputs for both loss implementations.
# A fixed seed makes the random inputs (and therefore every loss value printed
# below) reproducible under Restart Kernel -> Run All.
SEED = 0
rng = np.random.default_rng(SEED)

x_gold = rng.standard_normal(50)       # target vector
x_hat = rng.standard_normal(50)        # predicted vector
W = rng.standard_normal((50, 50))      # matrix penalized by the L1/L2 terms
l1 = 2.0      # coefficient of the L1 penalty
l2 = 3.0      # coefficient of the L2 penalty
weight = 2    # multiplier applied to the squared error

### PyTorch version

In [4]:
def torch_loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.):
    """
    Evaluate the weighted MSE loss plus L1 and L2 penalties on W.

    Args:
        - x_gold, x_hat, W (torch.tensor): x_gold may also be a sparse
          tensor; it is converted to dense first, mirroring the
          tf.SparseTensor handling in tensorflow_loss
        - l1, l2, weight (float): penalty coefficients and the multiplier
          of the squared error (only the absolute value of weight is used)
    Returns:
        - A tuple (loss_full, loss_mse) of single-value tensors, where
          loss_mse  = mean((x_gold - x_hat)^2 * |weight|)
          loss_full = loss_mse + l1 * sum(|W|) + l2 * sum(|W|^2)
    """
    # Densify a sparse target up front so the elementwise arithmetic below
    # always runs on dense tensors. getattr keeps non-tensor inputs (e.g. a
    # plain Python float) working exactly as before.
    if getattr(x_gold, "is_sparse", False):
        x_gold = x_gold.to_dense()

    loss_mse = torch.mean(torch.square(x_gold - x_hat) * abs(weight))
    l1_loss = l1 * torch.sum(torch.abs(W))
    # abs-then-square is kept to match the expression used in tensorflow_loss.
    l2_loss = l2 * torch.sum(torch.square(torch.abs(W)))
    loss_full = loss_mse + l1_loss + l2_loss
    return loss_full, loss_mse

In [25]:
# Wrap the shared NumPy test arrays as tensors and evaluate the PyTorch loss.
torch_loss_full, torch_loss_mse = torch_loss(
    *map(torch.from_numpy, (x_gold, x_hat, W)),
    l1=l1,
    l2=l2,
    weight=weight,
)
print(f"Torch loss full: {torch_loss_full}")
print(f"Torch loss mse: {torch_loss_mse}")

Torch loss full: 11456.501450835562
Torch loss mse: 2.949534422946719


### TensorFlow version

In [11]:
def tensorflow_loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.):
    """
    Evaluate the weighted MSE loss plus L1 and L2 penalties on W.

    Args:
        - x_gold: target values; a tf.SparseTensor is converted to dense
        - x_hat: predicted values
        - W: values penalized by the L1 and L2 terms
        - l1, l2: coefficients of the L1 and L2 penalties
        - weight: multiplier of the squared error (its absolute value is used)
    Returns:
        - A tuple (loss_full, loss_mse), where
          loss_mse  = mean((x_gold - x_hat)^2 * |weight|)
          loss_full = loss_mse + l1 * sum(|W|) + l2 * sum(|W|^2)
    """
    target = tf.sparse.to_dense(x_gold) if isinstance(x_gold, tf.SparseTensor) else x_gold

    with tf.compat.v1.variable_scope("loss", reuse=True):
        # Weighted mean squared error between target and prediction.
        squared_error = tf.square(target - x_hat)
        mse_term = tf.reduce_mean(squared_error * tf.abs(weight))

        # Penalties on W: sum of absolute values and sum of squares.
        l1_term = l1 * tf.reduce_sum(tf.abs(W))
        l2_term = l2 * tf.reduce_sum(tf.square(tf.abs(W)))

        total = mse_term + l1_term + l2_term
    return total, mse_term

In [27]:
# Build the loss tensors from float32 copies of the shared test inputs.
tf_loss_full_op, tf_loss_mse_op = tensorflow_loss(
    np.float32(x_gold),
    np.float32(x_hat),
    np.float32(W),
    l1=np.float32(l1),
    l2=np.float32(l2),
    weight=np.float32(weight),
)

# Evaluate both tensors in a single session, and close it when done.
# The tensors keep their own names (*_op) so the numeric results below do not
# overwrite them.
with tf.compat.v1.Session() as sess:
    tf_loss_full, tf_loss_mse = sess.run([tf_loss_full_op, tf_loss_mse_op])

print(f"Tensorflow loss full: {tf_loss_full}")
print(f"Tensorflow loss mse: {tf_loss_mse}")

Tensorflow loss full: 11456.5009765625
Tensorflow loss mse: 2.9495346546173096


In [28]:
# print the difference
diff_full = abs(torch_loss_full.item() - tf_loss_full)
diff_mse = abs(torch_loss_mse.item() - tf_loss_mse)
print(f"Difference between the loss full: {diff_full}")
print(f"Difference between the loss mse: {diff_mse}")

# Turn the comparison into an actual check. The TensorFlow values above are
# computed from float32 inputs, so the two implementations are only expected
# to agree up to float32 rounding; the tolerance is relative for that reason.
np.testing.assert_allclose(torch_loss_full.item(), tf_loss_full, rtol=1e-5)
np.testing.assert_allclose(torch_loss_mse.item(), tf_loss_mse, rtol=1e-5)

Difference between the loss full: 0.00047427306162717286
Difference between the loss mse: 2.3167059071127483e-07
