In [3]:
import torch
import numpy as np
import pandas as pd 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class FocalLoss(nn.Module):

    """
        Binary focal loss computed on top of Binary Cross Entropy (BCE) with logits.

        For every element the loss is ``alpha * (1 - p_t) ** gamma * BCE``, where
        ``p_t = exp(-BCE)`` is the probability the model assigns to the true label.
        The ``(1 - p_t) ** gamma`` factor down-weights well-classified examples.

        Attributes:
            alpha (float): Balancing factor. A single scalar applied uniformly to every
                element (it is not split into alpha / 1 - alpha per class).
            gamma (float): Modulating factor. Larger values suppress the contribution of
                easy (high p_t) examples more strongly; gamma = 0 gives alpha-scaled BCE.
            reduction (str): How per-element losses are reduced: "mean" (default),
                "sum", or "none" (return the unreduced tensor).
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha, gamma, reduction="mean"):
        """
            Args:
                alpha (float): Balancing factor applied to every element.
                gamma (float): Focusing exponent applied to (1 - p_t).
                reduction (str): "mean" (default, original behavior), "sum", or "none".

            Raises:
                ValueError: If ``reduction`` is not one of the supported values.
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, y_pred_logits, y_true):

        """
            Derives the focal loss from Binary Cross Entropy (BCE) with logits.

            Args:
                y_pred_logits (torch.Tensor): Raw, un-sigmoided model outputs.
                y_true (torch.Tensor): Float targets with the same shape as
                    ``y_pred_logits``, as required by ``binary_cross_entropy_with_logits``.

            Returns:
                torch.Tensor: The focal loss, reduced according to ``self.reduction``
                (a scalar for "mean"/"sum", the per-element tensor for "none").
        """

        # Stateless functional form: no need to build a loss module per call.
        bce = nn.functional.binary_cross_entropy_with_logits(
            y_pred_logits, y_true, reduction="none"
        )

        # Probability of the true label for each element.
        pt = torch.exp(-bce)

        # -log(pt) == bce exactly. Using bce directly avoids log(0) = -inf when pt
        # underflows to 0 for confidently wrong predictions. That underflow would
        # otherwise make the loss inf and the gradients NaN.
        focal = self.alpha * (1 - pt) ** self.gamma * bce

        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal

## References

1- VisionWizard. *Understanding focal loss: A quick read*. Accessed on March 13, 2025, from https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7

2- PyTorch Discuss. *Is this a correct implementation for focal loss in PyTorch?* Accessed on March 14, 2025, from https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/6

3- Saturn Cloud. *How to use class weights with focal loss in PyTorch for imbalanced multiclass classification*. Accessed on May 12, 2025, from https://saturncloud.io/blog/how-to-use-class-weights-with-focal-loss-in-pytorch-for-imbalanced-multiclass-classification/

4- Geek Culture. *Everything about focal loss*. Accessed on March 19, 2025, from https://medium.com/geekculture/everything-about-focal-loss-f2d8ab294133
