<a href="https://colab.research.google.com/github/trangdtk-vnu/stock_price_prediction/blob/main/custom_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# custom_loss.py

import torch
import torch.nn as nn

class EnhancedSignAgreementLoss(nn.Module):
    """Squared-error loss whose weight depends on sign agreement.

    Each element's squared residual ``(y_true - y_pred) ** 2`` is multiplied
    by ``gain_reward`` when the prediction "agrees" with the target and by
    ``loss_penalty`` otherwise; the result is the mean over all elements.

    An element counts as agreeing when either

    * ``sign(y_true) == sign(y_pred)``, or
    * ``y_pred`` is exactly zero (whatever the target is).

    Args:
        loss_penalty: Multiplier applied to the squared residual of
            elements that do not agree.
        gain_reward: Multiplier applied to the squared residual of
            elements that agree.

    Note:
        ``forward`` takes ``(y_true, y_pred)`` -- the reverse of the
        ``(input, target)`` order used by the built-in ``torch.nn`` losses.
        The zero check is applied to ``y_pred`` only, so swapping the two
        arguments can change the result.
    """

    def __init__(self, loss_penalty, gain_reward):
        super().__init__()
        self.loss_penalty = loss_penalty
        self.gain_reward = gain_reward

    def forward(self, y_true, y_pred):
        """Return the mean sign-weighted squared error.

        Args:
            y_true: Tensor of target values.
            y_pred: Tensor of predicted values. It is combined elementwise
                with ``y_true`` under normal broadcasting rules; no shape
                check is performed here.

        Returns:
            The mean of the weighted squared residuals, as produced by
            ``torch.mean``.
        """
        same_sign = torch.eq(torch.sign(y_true), torch.sign(y_pred))
        pred_zero = torch.eq(y_pred, 0.0)
        # A zero prediction is always weighted with ``gain_reward``. The
        # earlier nested ``torch.where`` over positive/negative/zero target
        # masks evaluated to True for every finite target, so it is folded
        # into a single OR (NaN targets give a NaN loss either way).
        # NOTE(review): confirm that rewarding a zero prediction against a
        # non-zero target is the intended behaviour.
        agrees = torch.logical_or(pred_zero, same_sign)

        squared_error = torch.square(y_true - y_pred)
        loss = torch.where(agrees,
                           self.gain_reward * squared_error,
                           self.loss_penalty * squared_error)
        return torch.mean(loss)