In [None]:
#| default_exp libs.loss.cross_entropy

## `CrossEntropyLoss`

### **Description:**
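- A thin wrapper around `torch.nn.functional.cross_entropy` for semantic segmentation tasks. It stores the loss options once and squeezes a singleton channel dimension from the targets before computing the loss.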

### **Methods:**

 #### `__init__(self, ignore_index=255, label_smoothing=0.0, reduction='none', weight=None)`
   - **Description:**
     - Initializes the CrossEntropyLoss function.
   - **Parameters:**
     - `ignore_index` (int): Specifies a target value that is ignored and does not contribute to the loss. Default is 255.
     - `label_smoothing` (float): Controls the amount of label smoothing applied to the targets. Default is 0.0 (no smoothing).
     - `reduction` (str): Specifies the reduction method for computing the loss. Options are 'none', 'mean', and 'sum'. Default is 'none'.
     - `weight` (torch.Tensor): Optional per-class weights (one entry per class) to apply to the loss. Default is None (all classes weighted equally).
   - **Returns:**
     - None
---

 #### `__call__(self, images, targets)`
   - **Description:**
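     - Computes the cross-entropy loss between the model's predictions and the target labels.
   - **Parameters:**
     - `images` (torch.Tensor): Raw (unnormalized) class scores (logits) predicted by the model, e.g. of shape (N, C, H, W).
     - `targets` (torch.Tensor): Ground-truth class-index labels, e.g. of shape (N, H, W). If the targets have more than 3 dimensions, dimension 1 is squeezed, so (N, 1, H, W) masks are also accepted.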
   - **Returns:**
     - `torch.Tensor`: The computed cross-entropy loss. It is per-element when `reduction='none'`, and a scalar for `'mean'` or `'sum'`.


In [None]:
#| export
import torch

class CrossEntropyLoss:
    """Cross-entropy loss wrapper for semantic segmentation.

    Stores the loss options once and forwards them to
    ``torch.nn.functional.cross_entropy``. Before the call, a singleton
    channel dimension is squeezed away from the targets.
    """

    def __init__(self, ignore_index=255, label_smoothing=0.0, reduction='none', weight=None):
        """Configure the loss.

        Args:
            ignore_index: Target value that is ignored and does not contribute
                to the loss (default 255).
            label_smoothing: Amount of label smoothing applied to the targets;
                0.0 disables smoothing.
            reduction: 'none', 'mean' or 'sum', forwarded to ``cross_entropy``.
            weight: Optional per-class weights, either a tensor or a sequence
                of floats. They are moved to the logits' device and dtype on
                each call.
        """
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.weight = weight

    def __call__(self, images, targets):
        """Compute the cross-entropy loss.

        Args:
            images: Raw (unnormalized) class scores from the model. Presumably
                of shape (N, C, H, W) for segmentation; verify against callers.
            targets: Ground-truth labels. If they have more than 3 dims, dim 1
                is squeezed; squeezing is a no-op unless that dim has size 1.

        Returns:
            The loss tensor. It is per-element for reduction='none', and a
            scalar for 'mean' or 'sum'.
        """
        # Masks are often stored with a singleton channel dim, e.g. (N, 1, H, W).
        # cross_entropy wants class-index targets without it.
        if targets.ndim > 3:
            targets = torch.squeeze(targets, dim=1)

        weight = self.weight
        if weight is not None:
            # cross_entropy requires the weight tensor to match the input's
            # device and dtype. as_tensor also accepts plain Python sequences
            # and avoids a copy when the weight already matches.
            weight = torch.as_tensor(weight, dtype=images.dtype, device=images.device)

        return torch.nn.functional.cross_entropy(
            images,
            targets,
            ignore_index=self.ignore_index,
            label_smoothing=self.label_smoothing,
            reduction=self.reduction,
            weight=weight,
        )
