# Discriminator

> $D$ takes $x=\left(s_{b_0}, s_{T_2}, s_{T_1}, s_{b_i}\right)$ as input, where $s_{b_i}$ is either sampled from a real dMRI or a generated image $G\left(s_{b_0}, s_{T_2}, s_{T_1}, b_i\right)$, and extracts a global representation $\boldsymbol{\phi}(\boldsymbol{x})$ via the encoding path $D_{\text{enc}}$ to assess global realism. Additionally, a decoder $D_{\text{dec}}$ expands $\boldsymbol{\phi}(\boldsymbol{x})$ to the input size and outputs per-pixel realism feedback. To incorporate $q$-space coordinates into the discriminator, a conditional projection via inner product is inserted before the last layer of both global and local branches following [22]. The final layer of $D$ is defined as $f(\boldsymbol{x}, \mathbf{b}):=\mathbf{b}^{\mathrm{T}} V \boldsymbol{\phi}(\boldsymbol{x})+\psi(\boldsymbol{\phi}(\boldsymbol{x}))$, where $V$ is a learnable embedding of the condition $\mathbf{b}$; $\boldsymbol{\phi}(\boldsymbol{x})$ is the output before conditioning, and $\psi(\cdot)$ is a scalar function of $\boldsymbol{\phi}(\boldsymbol{x})$.

Discriminator code source: https://github.com/zijianch/q-space-conditioned-dwi-synthesis/blob/master/models/networks.py (starting from line 307)

In [None]:
class Unet_Discriminator(nn.Module):
    """U-Net discriminator with an image-level (global) and a pixel-level (local) output.

    The encoder reduces the input to a bottleneck feature map. Summing that map over
    its spatial axes gives the global representation ``phi(x)``, which ``self.linear``
    (psi) turns into one score per sample. The decoder upsamples back with U-Net skip
    connections, and ``self.last`` produces a one-channel per-pixel score map.

    If a condition ``y`` (last dimension 4; in this project the q-space condition
    vector) is passed to ``forward``, projection-style conditioning is added to both
    outputs. This requires ``embed=True``.
    """

    def __init__(self, input_nc=1, ndf=64, kw=3, padw=1, n_layers=3, n_latent=2, output_nc=1, use_bias=True, embed=True, device='cuda:0'):
        """Build the encoder, bottleneck, decoder and (optionally) the condition embeddings.

        Args:
            input_nc: number of input channels (conditioning images + the DWI being judged).
            ndf: width of the first layer; encoder stage ``n`` has ``ndf * min(2**n, 8)`` channels.
            kw: convolution kernel size.
            padw: convolution padding.
            n_layers: number of encoder down-sampling stages (and matching decoder stages).
            n_latent: number of ``DResBlock`` + LeakyReLU pairs at the bottleneck.
            output_nc: output size of ``embedding_out`` (the local-branch condition embedding).
            use_bias: unused; kept for interface compatibility.
            embed: create the embeddings needed to call ``forward`` with a condition ``y``.
            device: unused; kept for interface compatibility.
        """
        super(Unet_Discriminator, self).__init__()
        self.conditional = embed
        # norm='sn' is forwarded to the project's Conv2dBlock (presumably spectral norm).
        self.head = nn.Sequential(
            Conv2dBlock(input_nc, ndf, stride=1, kernel_size=kw, norm='sn', activation='lrelu', padding=padw),
            nn.LeakyReLU(0.2, True),
        )
        self.n_downsample = n_layers

        # Encoder: the channel count grows per stage (capped at 8 * ndf); AvgPool halves H and W.
        enc_channels = []  # output width of every encoder stage, needed to size the decoder's skips
        in_ch = ndf
        enc_layers = []
        for n in range(1, n_layers + 1):
            out_ch = ndf * min(2 ** n, 8)
            enc_layers.append(nn.Sequential(
                Conv2dBlock(in_ch, out_ch, stride=1, kernel_size=kw, norm='sn', activation='lrelu',
                            pad_type='zero', padding=padw),
                nn.AvgPool2d(2),
                nn.LeakyReLU(0.2, True),
            ))
            enc_channels.append(out_ch)
            in_ch = out_ch
        self.enc_layers = nn.ModuleList(enc_layers)
        bottleneck_ch = in_ch

        latent = []
        for _ in range(n_latent):
            latent += [DResBlock(bottleneck_ch, bottleneck_ch, kw=kw, padding=padw),
                       nn.LeakyReLU(0.2, True)]
        self.latent = nn.Sequential(*latent)
        self.linear = nn.Linear(bottleneck_ch, 1)  # psi(phi(x)): global realism score

        if embed:
            # V for the global branch; its width equals phi(x)'s so <V y, phi(x)> is defined.
            self.embedding_middle = nn.Linear(4, bottleneck_ch)
            self.embedding_out = nn.Linear(4, output_nc)
            # NOTE: `fc` is not used by forward(); kept so existing state_dicts still load.
            self.fc = nn.Conv2d(1, bottleneck_ch, 4, padding=1)

        # Decoder: stage n concatenates the matching encoder output (skip) with the current
        # features and maps them to the width of the next-shallower encoder stage (ndf at the end).
        # Widths come from enc_channels rather than by halving, so n_layers > 3 (where the
        # encoder width saturates at 8 * ndf) also produces consistent shapes.
        dec_layers = []
        cur_ch = bottleneck_ch
        for n in range(n_layers):
            skip_ch = enc_channels[-1 - n]
            out_ch = enc_channels[-2 - n] if n + 1 < n_layers else ndf
            dec_layers.append(nn.Sequential(
                Conv2dBlock(skip_ch + cur_ch, out_ch, stride=1, kernel_size=kw, norm='sn',
                            activation='lrelu', pad_type='zero', padding=padw),
                nn.LeakyReLU(0.2, True),
            ))
            cur_ch = out_ch
        self.dec_layers = nn.ModuleList(dec_layers)
        self.last = Conv2dBlock(cur_ch, 1, stride=1, kernel_size=kw, norm='sn',
                                activation='none', pad_type='zero', padding=padw)

    def init_weight(self):
        """Re-initialise the weights of every Conv*/Linear* submodule from N(0, 0.02); return ``self``."""
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
                torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

        return self.apply(init_func)

    def forward(self, x, y=None):
        """Score ``x`` globally and per pixel, optionally conditioned on ``y``.

        Args:
            x: 4-D image batch (B, C, H, W). H and W must stay compatible with
               ``n_layers`` halvings/doublings, otherwise the skip concatenation fails.
            y: optional condition of shape (B, 4); requires the model to be built with ``embed=True``.

        Returns:
            ``(out, bottleneck_out)``: the one-channel per-pixel score map and the
            (B, 1) global score.
        """
        h = self.head(x)
        skips = []
        for layer in self.enc_layers:
            h = layer(h)
            skips.append(h)
        h = self.latent(h)
        phi = torch.sum(h, [2, 3])            # global representation phi(x), shape (B, C)
        bottleneck_out = self.linear(phi)     # psi(phi(x)), shape (B, 1)

        for skip, layer in zip(reversed(skips), self.dec_layers):
            h = torch.cat((skip, h), dim=1)
            h = F.interpolate(h, scale_factor=2)
            h = layer(h)

        out = self.last(h)
        if y is not None:
            # Global projection term <V y, phi(x)>, i.e. f(x, b) = b^T V phi(x) + psi(phi(x)).
            # (An earlier version took the inner product with psi(phi(x)) instead of
            # phi(x), which only rescaled the score; checkpoints trained that way will
            # score differently.)
            emb_mid = self.embedding_middle(y)
            bottleneck_out = bottleneck_out + torch.sum(emb_mid * phi, 1, keepdim=True)

            # NOTE(review): the local term is an inner product with the score map `out`
            # itself, not with the pre-`last` features `h`, so it rescales the pixel score
            # by (1 + V_out y). Matching the paper would change embedding_out's shape
            # (breaking saved checkpoints), so it is left unchanged here.
            emb_out = self.embedding_out(y)
            emb_out = emb_out.view(emb_out.size(0), emb_out.size(1), 1, 1).expand_as(out)
            out = out + torch.sum(emb_out * out, 1, keepdim=True)
        return out, bottleneck_out

    def gan_forward(self, input_real, real_img, real_bvec, input_fake, fake_img, fake_bvec):
        """Run D on a fake and a real pair; store only the per-pixel predictions on ``self``.

        ``fake_img`` is detached, so no gradient reaches the generator through this call.
        """
        fake_AB = torch.cat((input_fake, fake_img.detach()), dim=1)
        real_AB = torch.cat((input_real, real_img), dim=1)
        self.pred_fake_pix, pred_fake_img = self.forward(fake_AB, fake_bvec)
        self.pred_real_pix, pred_real_img = self.forward(real_AB, real_bvec)

    def get_D_loss(self, input_real, real_img, real_bvec, input_fake, fake_img, fake_bvec, criterionGAN):
        """Calculate the GAN losses for the discriminator.

        Args:
            input_real / input_fake: conditioning images concatenated (channel-wise) with the real / fake DWI.
            real_img: target DWI; real_bvec: its condition vector.
            fake_img: generated DWI (detached here); fake_bvec: its condition vector.
                real_bvec and fake_bvec may differ to allow unpaired sampling.
            criterionGAN: callable ``criterionGAN(prediction, target_is_real)``.

        Returns:
            ``(loss_D_global, loss_D_local)``: the mean of the fake and real losses for
            the image-level and the per-pixel outputs respectively.
        """
        # Fake: detach so the discriminator loss does not backpropagate into the generator.
        fake_AB = torch.cat((input_fake, fake_img.detach()), dim=1)
        pred_fake_pix, pred_fake_img = self.forward(fake_AB, fake_bvec)
        loss_D_fake_pix = criterionGAN(pred_fake_pix, False)
        loss_D_fake_img = criterionGAN(pred_fake_img, False)

        # Real
        real_AB = torch.cat((input_real, real_img), dim=1)
        pred_real_pix, pred_real_img = self.forward(real_AB, real_bvec)
        loss_D_real_pix = criterionGAN(pred_real_pix, True)
        loss_D_real_img = criterionGAN(pred_real_img, True)

        loss_D_global = (loss_D_fake_img + loss_D_real_img) * 0.5
        loss_D_local = (loss_D_fake_pix + loss_D_real_pix) * 0.5
        return loss_D_global, loss_D_local

    def get_G_loss(self, input, fake_img, fake_emb, criterionGAN):
        """Calculate the GAN losses for the generator (fake samples labelled as real).

        Returns:
            ``(loss_G_global, loss_G_local)``: the loss on the image-level output and
            on the per-pixel output, in the same order as ``get_D_loss``.
        """
        fake_AB = torch.cat((input, fake_img), dim=1)
        pred_fake_pix, pred_fake_img = self.forward(fake_AB, fake_emb)
        loss_G_global = criterionGAN(pred_fake_img, True)
        loss_G_local = criterionGAN(pred_fake_pix, True)
        return loss_G_global, loss_G_local

Usage example: https://github.com/zijianch/q-space-conditioned-dwi-synthesis/blob/master/mains/trainer.py

In [None]:
class dwi_Trainer(nn.Module):
    """Training wrapper (excerpt) for the DWI generator ``gen_a`` and an optional discriminator ``dis``.

    Only the discriminator-related parts are shown; elided code is marked below.
    """

    def __init__(self, hyperparameters):
        super(dwi_Trainer, self).__init__()

        ####### some parts of the code omitted  #######

        self.gan_w = hyperparameters['gan_w'] #GAN weight

        if self.gan_w > 0:
            print('GAN with {} discriminator'.format(hyperparameters['dis']['d_type']))
            self.dis_type = hyperparameters['dis']['d_type']
            # NOTE(review): a Unet_Discriminator is built regardless of d_type — confirm intended.
            self.dis = Unet_Discriminator(hyperparameters['dis']['in_dim'], ndf=hyperparameters['dis']['dim'],
                                          n_layers=hyperparameters['dis']['n_layer'], n_latent=2, embed=True, device=self.device)

            ####### some parts of the code omitted  #######

        else:
            self.dis_scheduler = None

        ####### some parts of the code omitted  #######


    def update(self, data_dict, n_dwi,  iterations):
        """Run one generator step and, on even iterations with a GAN, one discriminator step.

        Args:
            data_dict: batch dict; reads ``'cond_%d'`` / ``'dwi_%d'`` (1-based keys) plus
                whatever ``prepare_input`` consumes.
            n_dwi: number of DWIs available in ``data_dict``; the D step samples its
                real image among them.
            iterations: global iteration counter; D is updated only when it is even.

        Returns:
            The dict from ``prepare_input`` extended with numpy arrays for visualisation.
        """

        ####### some parts of the code omitted  #######

        i = 0
        if self.gan_w > 0:
            # Freeze D / unfreeze G *before* the generator forward pass. Otherwise, after a
            # D step has frozen gen_a, this step's forward would build no graph through G
            # and the generator would get no gradients.
            self.set_requires_grad([self.dis], False)
            self.set_requires_grad([self.gen_a], True)
        return_dict, in_i = self.prepare_input(data_dict, i)
        cond_i = data_dict['cond_%d'%(i+1)].to(self.device).float()
        dwi_i = data_dict['dwi_%d'%(i+1)].to(self.device).float()
        pred_i = self.gen_a.forward(in_i, cond_i)
        self.loss_dwi += self.l1_w * self.recon_criterion(pred_i, dwi_i, False)
        total_loss = self.loss_dwi
        if self.gan_w > 0:
            if self.dis_type == 'unet':
                self.loss_G_global, self.loss_G_local = self.dis.get_G_loss(in_i, pred_i, cond_i, self.criterionGAN)
                self.loss_g += 0.5 * (self.loss_G_global + self.loss_G_local) * self.gan_w
            else:
                # NOTE(review): unlike the 'unet' branch this term is not scaled by gan_w — confirm.
                self.loss_g += self.dis.get_G_loss(in_i, pred_i, cond_i, self.criterionGAN)
            # Out-of-place add: `total_loss += ...` would mutate self.loss_dwi (same tensor).
            total_loss = total_loss + self.loss_g
        total_loss.backward()
        self.gen_opt.step()
        return_dict['dwi'] = dwi_i[0,0].cpu().numpy()
        bvec_vis = cond_i[0].cpu().numpy()
        return_dict['pred%.1f,%.1f,%.1f[%.1f]'%(bvec_vis[0], bvec_vis[1], bvec_vis[2],bvec_vis[3])]= pred_i[0,0].detach().cpu().numpy()

        if self.gan_w > 0 and iterations % 2 == 0:
            self.set_requires_grad([self.dis], True)     # D is optimized in this step
            self.set_requires_grad([self.gen_a], False)  # G stays fixed (the fake is detached in get_D_loss)
            self.dis_opt.zero_grad()
            # Unpaired sampling: take the real DWI from a different condition than the fake
            # one; fall back to the same index when only one DWI is available.
            candidates = [k for k in range(n_dwi) if k != i]
            j = candidates[np.random.randint(len(candidates))] if candidates else i
            _, in_j = self.prepare_input(data_dict, j)
            dwi_j = data_dict['dwi_%d' % (j + 1)].to(self.device).float()
            cond_j = data_dict['cond_%d' % (j + 1)].to(self.device).float()
            self.loss_D_global, self.loss_D_local = self.dis.get_D_loss(in_j, dwi_j, cond_j, in_i, pred_i, cond_i, self.criterionGAN)
            print('Global D: {}, Local D: {}'.format(self.loss_D_global.item(), self.loss_D_local.item()))
            self.loss_d = 0.5 * (self.loss_D_global + self.loss_D_local) * self.gan_w

            self.loss_d.backward()
            self.dis_opt.step()

        return return_dict
    