In [None]:
class AsymmetricLossOptimized(nn.Module):
    """Asymmetric, focal-style loss computed on sigmoid probabilities.

    Positive and negative targets get separate focusing exponents
    (``gamma_pos`` / ``gamma_neg``), and the negative-class probability
    can be shifted upwards by ``clip`` before the log is taken.

    Args:
        gamma_neg: Exponent of the weight applied to the ``(1 - y)`` term.
        gamma_pos: Exponent of the weight applied to the ``y`` term.
        clip: Margin added to ``1 - sigmoid(x)`` (result capped at 1).
            ``None`` or a non-positive value disables the shift.
        eps: Lower bound applied to the probability before ``log``.
    """

    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, x, y):
        """Return the loss summed over every element (no averaging).

        Args:
            x: Raw logits; ``torch.sigmoid`` is applied here.
            y: Targets, broadcastable against ``x``. Presumably multi-hot
                {0, 1} labels -- confirm against the caller. Boolean
                tensors are accepted.

        Returns:
            A scalar tensor holding the summed loss.
        """
        x_sigmoid = torch.sigmoid(x)

        # torch does not support `1 - y` on bool tensors, so convert boolean
        # masks to the probability dtype. Float/int targets are left alone so
        # the usual type promotion (and therefore the result) is unchanged.
        if torch.is_tensor(y) and y.dtype == torch.bool:
            y = y.to(x_sigmoid.dtype)

        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Probability shifting for negatives; `clip=None` means "disabled"
        # and must not reach the `> 0` comparison.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Probability assigned to the target side of each element.
        pt = y * xs_pos + (1 - y) * xs_neg
        # Clamp before the log so a zero probability cannot produce -inf.
        log_pt = torch.log(pt.clamp(min=self.eps))

        # Static weighting: the exponents are the fixed constructor values
        # (defaults: gamma_neg=4, gamma_pos=0).
        pos_weight = (1 - xs_pos) ** self.gamma_pos
        neg_weight = (1 - xs_neg) ** self.gamma_neg

        loss = -(pos_weight * log_pt * y + neg_weight * log_pt * (1 - y))

        return loss.sum()