Use Diffusion-GAN in Other GAN Architecture #27

Open
RisabBiswas opened this issue Jun 22, 2023 · 2 comments

@RisabBiswas

RisabBiswas commented Jun 22, 2023

Hello @Zhendong-Wang and Team,

First of all, great work, and thank you for sharing the code! I am trying to use Diffusion-GAN in a GAN architecture for image enhancement. Could you please tell me how to apply the three steps from the "Simple Plug-in" section of your README to the code below?

import torch
from IPython import display

for epoch in range(num_epochs):
    for n_batch, (blur_batch, clean_batch) in enumerate(data_loader):

        real_data = clean_batch.float().cuda()
        noised_data = blur_batch.float().cuda()

        # 1. Train Discriminator
        # Generate fake (enhanced) images from the blurred inputs
        fake_data = generator(noised_data)

        # Reset gradients
        d_optimizer.zero_grad()

        # 1.1 Train on real data, conditioned on the blurred input
        prediction_real = discriminator(real_data, noised_data)
        real_data_target = torch.ones_like(prediction_real)
        loss_real = loss1(prediction_real, real_data_target)

        # 1.2 Train on fake data; detach so this step does not backpropagate into the generator
        prediction_fake = discriminator(fake_data.detach(), noised_data)
        fake_data_target = torch.zeros_like(prediction_fake)
        loss_fake = loss1(prediction_fake, fake_data_target)

        # Average the two losses and backpropagate (no retain_graph needed after detach)
        loss_d = (loss_real + loss_fake) / 2
        loss_d.backward()

        # 1.3 Update discriminator weights
        d_optimizer.step()

        # 2. Train Generator
        g_optimizer.zero_grad()

        # Score the fake images with the same conditioning input the discriminator was trained on
        prediction = discriminator(fake_data, noised_data)

        # Adversarial loss: the generator wants its outputs classified as real
        real_data_target = torch.ones_like(prediction)
        loss_g1 = loss1(prediction, real_data_target)
        # Content loss between the generator output and the clean target, weighted by 500
        loss_g2 = loss1(fake_data, real_data) * 500
        loss_g = loss_g1 + loss_g2

        loss_g.backward()

        # Update generator weights
        g_optimizer.step()

        # Log error
        logger.log(loss_d, loss_g, epoch, n_batch, num_batches)

        # Display progress every 100 batches
        if n_batch % 100 == 0:
            display.clear_output(True)
            # Display images (test_noise() and vectors_to_images() are helpers defined elsewhere)
            test_images = vectors_to_images(generator(test_noise())).data.cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            # Display status logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                loss_d, loss_g, prediction_real, prediction_fake
            )

    # Model checkpoint once per epoch
    logger.save_models(generator, discriminator, epoch)

Thank you so much :)
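As I understand the Diffusion-GAN approach, the plug-in amounts to (1) defining a forward diffusion process, (2) diffusing both real and generated images before they reach the discriminator, and (3) adapting the maximum diffusion step T during training. Below is a minimal PyTorch sketch of steps 1 and 2. It assumes a linear β schedule, a noise scale `sigma`, and "priority" sampling of the step `t` (larger steps drawn more often); `ForwardDiffusion` and all of its default values are hypothetical illustrations, not the API of the repository's `diffusion.py`.

```python
import torch


class ForwardDiffusion:
    """Hypothetical forward-diffusion helper (a sketch, not the repo's diffusion.py).

    Noises a batch as y = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * sigma * eps,
    with t drawn from {0, ..., T - 1}, where T is the current (adaptive) maximum step.
    All default values are illustrative assumptions.
    """

    def __init__(self, T_max=1000, beta_start=1e-4, beta_end=2e-2, sigma=0.05, T_init=5):
        self.T_max = T_max
        self.sigma = sigma
        self.T = T_init  # step 3 would raise or lower this during training
        betas = torch.linspace(beta_start, beta_end, T_max)  # linear beta schedule
        self.alpha_bars = torch.cumprod(1.0 - betas, dim=0)

    def sample_t(self, batch_size, device):
        # Clamp the (possibly fractional) adaptive T to a valid integer range
        T_cur = max(1, min(int(round(self.T)), self.T_max))
        # "Priority" weighting: step t is drawn with probability proportional to t + 1
        weights = torch.arange(1, T_cur + 1, dtype=torch.float)
        t = torch.multinomial(weights, batch_size, replacement=True)
        return t.to(device)

    def __call__(self, x, t=None):
        if t is None:
            t = self.sample_t(x.shape[0], x.device)
        # Gather alpha_bar_t per sample and reshape it to broadcast over C, H, W
        a = self.alpha_bars.to(x.device)[t].view(-1, *([1] * (x.dim() - 1)))
        eps = torch.randn_like(x)
        # Differentiable in x, so generator gradients flow through the noise injection
        y = a.sqrt() * x + (1.0 - a).sqrt() * self.sigma * eps
        return y, t


if __name__ == "__main__":
    diffusion = ForwardDiffusion()
    real_data = torch.rand(4, 3, 32, 32)  # stand-in for clean_batch
    fake_data = torch.rand(4, 3, 32, 32)  # stand-in for generator(noised_data)
    # Step 2: diffuse BOTH real and fake images before the discriminator sees them
    real_diffused, t_real = diffusion(real_data)
    fake_diffused, t_fake = diffusion(fake_data)
    print(real_diffused.shape, t_real.tolist(), t_fake.tolist())
```

In the training loop above, this would mean calling something like `discriminator(diffusion(real_data)[0], noised_data)` and `discriminator(diffusion(fake_data)[0], noised_data)` in both the discriminator and the generator steps, so the two always see equally noised images. The paper's discriminator is also conditioned on `t`; passing it in would require changing the discriminator's signature, which this sketch leaves out.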

RisabBiswas changed the title from "Use Diffusion-GAN in GAN Architecture" to "Use Diffusion-GAN in Other GAN Architecture" on Jun 22, 2023
@someonegirl

Excuse me, have you managed to use Diffusion-GAN successfully? If so, could you share your experience?

@Sarah-2021-scu

@RisabBiswas, @someonegirl, were you able to use Diffusion-GAN in other GAN architectures? Could you please share your experience?
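For step 3 (making the diffusion adaptive), here is a minimal sketch of the update rule as I understand it from the Diffusion-GAN paper: estimate how over-confident the discriminator is on real data with r_d = E[sign(D(y, t) - 0.5)], then move T up or down by a constant C depending on whether r_d is above or below a target. With a logit-output discriminator, `sign(logit)` plays the role of `sign(D - 0.5)`. The function name `update_T`, the target `d_target=0.6` (the value StyleGAN2-ADA uses for its analogous heuristic), `C`, the bounds, and `ada_interval` are all assumptions for illustration, not the repository's code.

```python
import torch


def update_T(T, d_real_logits, d_target=0.6, C=1.0, T_min=5.0, T_max=1000.0):
    """Hypothetical adaptive-T update (a sketch, not the repo's implementation).

    d_real_logits: raw discriminator outputs (logits) on diffused REAL images.
    Returns the new T as a float clamped to [T_min, T_max]; round it before
    using it as a number of diffusion steps.
    """
    # r_d in [-1, 1]: +1 when D calls every real sample real, -1 when it calls none real
    r_d = torch.sign(d_real_logits.detach()).mean().item()
    if r_d > d_target:
        T = T + C  # D is over-confident (overfitting): add more noise
    elif r_d < d_target:
        T = T - C  # D is struggling: reduce the noise
    return float(min(max(T, T_min), T_max))


if __name__ == "__main__":
    T = 5.0
    ada_interval = 4  # hypothetical: update T every 4 batches
    for n_batch in range(12):
        prediction_real = torch.randn(16) + 1.0  # stand-in for D logits on real data
        if n_batch % ada_interval == 0:
            T = update_T(T, prediction_real)
    print(T)
```

If the discriminator ends with a sigmoid (as it would with `BCELoss`), pass `prediction_real - 0.5` instead of raw logits. The returned `T` would then be fed back into whatever forward-diffusion module caps the sampled timesteps.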
