---
## Discriminator and Generator Losses

Now we need to calculate the losses. 

### Discriminator Losses

> * For the discriminator, the total loss is the sum of the losses for real and fake images, `d_loss = d_real_loss + d_fake_loss`.
> * Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.

<img src='../assets/gan_pipeline.png' width=70% />

The losses will be binary cross entropy losses with logits, which we can get with [BCEWithLogitsLoss](https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss). This combines a `sigmoid` activation function **and** a binary cross entropy loss in one function.
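To make the "sigmoid plus binary cross entropy in one function" idea concrete, here is a minimal pure-Python sketch for a single logit. The helper names (`sigmoid`, `naive_bce`, `bce_with_logits`) are illustrative and not part of PyTorch, and the fused formula is a standard numerically stable rewrite of the same quantity rather than a copy of PyTorch's internals.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def naive_bce(logit, target):
    # Two-step version: apply the sigmoid, then binary cross entropy.
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_with_logits(logit, target):
    # Fused, numerically stable form of the same loss:
    # max(x, 0) - x * y + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# For moderate logits the two versions agree.
for logit, target in [(2.0, 1.0), (-1.5, 0.0), (0.3, 0.9)]:
    print(logit, target, naive_bce(logit, target), bce_with_logits(logit, target))

# For an extreme logit the fused form stays finite (close to 50 here),
# while the two-step version above would end up taking log(0).
print(bce_with_logits(50.0, 0.0))
```

Working directly on logits avoids the intermediate probability rounding to exactly 0 or 1, which is the usual motivation for preferring the fused loss.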

For the real images, we want `D(real_images) = 1`. That is, we want the discriminator to classify the real images with a label of 1, indicating that they are real. To help the discriminator generalize better, the real labels are **reduced a bit from 1.0 to 0.9**. The `smooth` parameter of `real_loss` controls this: when it is `True`, the real labels are set to 0.9 instead of 1.0.
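As a rough illustration of what label smoothing changes, the sketch below compares the gradient of the loss with respect to one very confident discriminator logit under a hard label (1.0) and a smoothed label (0.9). The helper `loss_and_grad` and the logit value 6.0 are made up for this example.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()


def loss_and_grad(logit_value, target_value):
    # Loss and its gradient with respect to a single discriminator logit.
    logit = torch.tensor([logit_value], requires_grad=True)
    target = torch.tensor([target_value])
    loss = criterion(logit, target)
    loss.backward()
    return loss.item(), logit.grad.item()


confident_logit = 6.0  # hypothetical, very confident "real" prediction

# Hard label 1.0: the gradient is negative, so a descent step raises the logit further.
hard_loss, hard_grad = loss_and_grad(confident_logit, 1.0)

# Smoothed label 0.9: the gradient is positive, so a descent step lowers the logit.
smooth_loss, smooth_grad = loss_and_grad(confident_logit, 0.9)

# With target 0.9 the loss is minimized where sigmoid(logit) = 0.9, i.e. logit = ln(9).
_, grad_at_optimum = loss_and_grad(math.log(9.0), 0.9)

print(hard_grad, smooth_grad, grad_at_optimum)
```

With the smoothed target, the discriminator is no longer rewarded for pushing its logits on real images toward infinity, which is one common explanation for why smoothing helps it generalize.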

The discriminator loss for the fake data is similar. We want `D(fake_images) = 0`, where the fake images are the _generator output_, `fake_images = G(z)`. 
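The two discriminator terms and their sum can be sketched as follows. The logits here are hypothetical stand-ins; in a real training step they would come from `D(real_images)` and `D(G(z))`.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits, one per image.
d_out_real = torch.tensor([[2.0], [1.5], [-0.5]])
d_out_fake = torch.tensor([[-1.0], [0.5], [-2.0]])

# Real images: smoothed labels of 0.9.
d_real_loss = criterion(d_out_real, torch.ones_like(d_out_real) * 0.9)

# Fake images: labels of 0.
d_fake_loss = criterion(d_out_fake, torch.zeros_like(d_out_fake))

# Total discriminator loss is the sum of the two terms.
d_loss = d_real_loss + d_fake_loss
print(d_real_loss.item(), d_fake_loss.item(), d_loss.item())
```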



In [1]:
import torch
import torch.nn as nn

import tests

In [2]:
# Calculate losses
def real_loss(D_out, smooth=False):
    """Loss for discriminator outputs that should be classified as real.

    D_out holds raw discriminator outputs (logits), before any sigmoid.
    """
    # label smoothing
    if smooth:
        # smooth, real labels = 0.9
        labels = torch.ones_like(D_out) * 0.9
    else:
        labels = torch.ones_like(D_out) # real labels = 1

    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out, labels)
    return loss

In [3]:
tests.check_real_loss(real_loss)

Congrats, you successfully implemented the real loss function
Congrats, you successfully implemented the real loss function with smoothing


### Generator Loss

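The generator loss looks similar, only with flipped labels. The generator's goal is to get `D(fake_images) = 1`. In this case, the labels are **flipped** to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real! In practice, this means the generator loss can be computed by passing the discriminator's output on fake images to `real_loss`, while `fake_loss` (label 0) is the fake-image term of the discriminator loss.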
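A small sketch of the label flip, using hypothetical discriminator logits for a batch of generated images: the same outputs are scored against label 0 for the discriminator and against label 1 for the generator. The second call is what `real_loss(D_out)` with `smooth=False` computes.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits for four generated (fake) images.
d_out_fake = torch.tensor([[-2.0], [-1.0], [0.5], [3.0]])

# Discriminator's view: fakes should be labelled 0.
d_fake_loss = criterion(d_out_fake, torch.zeros_like(d_out_fake))

# Generator's view: same logits, labels flipped to 1.
g_loss = criterion(d_out_fake, torch.ones_like(d_out_fake))

print(d_fake_loss.item(), g_loss.item())
```

The two objectives pull in opposite directions: raising the logits on fake images lowers `g_loss` and raises `d_fake_loss`.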

In [4]:
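def fake_loss(D_out):
    """Loss for discriminator outputs (logits) that should be classified as fake."""
    labels = torch.zeros_like(D_out) # fake labels = 0
    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out, labels)
    return loss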

In [5]:
tests.check_fake_loss(fake_loss)

Congrats, you successfully implemented the fake loss function
