# Tests `nn.CrossEntropyLoss()`

- Tests `nn.CrossEntropyLoss()` on a batch where each sample `i` gets a large logit on a different class (class `i`), simulating a very confident prediction.
- The target for every sample is class 2, so only sample `i = 2` is predicted correctly; all other samples are deliberately, confidently wrong.
- The printed per-sample losses show how cross-entropy penalizes these predictions: small for the confident correct prediction, large for the confident wrong ones.

In [None]:
import torch
import torch.nn as nn

In [None]:
# Define the batch size and number of classes. They are equal here so that
# each sample i can be made confident in its own class i (the "diagonal").
batch_size = 5
num_classes = batch_size
target_class = 2  # ground-truth label shared by every sample
logit_boost = 5.0  # added to the favoured logit to make a prediction confident

if not 0 <= target_class < num_classes:
    raise ValueError(
        f"target_class must be in [0, {num_classes}), got {target_class}"
    )

# Fix the RNG so the printed logits and losses are reproducible across runs.
torch.manual_seed(0)

# Random input logits with shape (batch_size, num_classes).
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
input = torch.randn(batch_size, num_classes)

# Every sample has the same target class. Class-index targets for
# CrossEntropyLoss must be an integer (int64) tensor.
target = torch.full((batch_size,), target_class, dtype=torch.long)

# Make sample i confident in class i (wrapping if num_classes < batch_size).
# Because every target is `target_class`, only the sample whose favoured class
# equals `target_class` is predicted correctly; the rest are confidently wrong.
for i in range(batch_size):
    input[i, i % num_classes] += logit_boost

print(input.shape)
print(target.shape)

In [None]:
# Display the logits; the boosted entry in each row should stand out.
print(input)

In [None]:
# Display the targets: every sample is labelled class 2.
print(target)

In [None]:
# Default criterion: averages the loss over the batch (reduction="mean").
loss = nn.CrossEntropyLoss()

# reduction="none" returns one loss per sample in a single call, replacing a
# manual loop that fed each row through as a batch of one.
per_sample_losses = nn.CrossEntropyLoss(reduction="none")(input, target)

# Print each sample's loss (each element is a 0-d tensor, printed as tensor(x)).
for output in per_sample_losses:
    print(output)

# With no class weights, the default mean reduction equals the average of the
# per-sample losses printed above.
print("mean over batch:", loss(input, target))