In [1]:
import torch
import torchvision.utils as vutils
from MNIST_GAN import Generator, Hyperparameter

In [3]:
hp = Hyperparameter()
model = Generator(hp).to("cuda")
model.load_state_dict(torch.load("generator_50.pth"))
model.eval()

Generator(
  (latent_embedding): Sequential(
    (0): Linear(in_features=32, out_features=512, bias=True)
  )
  (condition_embedding): Sequential(
    (0): Linear(in_features=10, out_features=512, bias=True)
  )
  (tcnn): Sequential(
    (0): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Tanh()
  )
)

In [4]:
next(model.parameters()).is_cuda

True

In [3]:
def generate_images(generator, hp, num_images, class_labels=None, device='cuda'):
    generator.eval()  # Set the generator to evaluation mode

    all_labels = torch.eye(hp.num_classes, dtype=torch.float32, device="cuda")
    fixed_noise = torch.randn((80, hp.latent_size), device="cuda")
    fixed_class_labels = all_labels[[i for i in list(
        range(hp.num_classes)) for idx in range(8)]]

    with torch.no_grad():
        # Generate fake images
        fake_images = generator(fixed_noise, fixed_class_labels)

    return fake_images, class_labels

In [5]:
def generate_samples(generator, hp, num_samples, class_label, device='cuda'):
    generator.eval()  # Set the generator to evaluation mode
    
    # One-hot encode the given class label
    class_labels = torch.eye(hp.num_classes, dtype=torch.float32, device=device)[class_label].unsqueeze(0).repeat(num_samples, 1)
    
    # Generate random noise
    fixed_noise = torch.randn((num_samples, hp.latent_size), device=device)
    
    with torch.no_grad():
        # Generate fake images
        fake_images = generator(fixed_noise, class_labels)
    return fake_images
    

In [12]:
# Assuming 'generator' is the trained generator model
# fake_images, class_labels = generate_images(model,hp, 5) 

fake_images = generate_samples(generator=model, hp=hp, num_samples=5, class_label=9)

print(list(fake_images), fake_images.shape)
# Save the generated images
# for i in range(len(fake_images)):
#     vutils.save_image(fake_images[i], f'generated_images/{i}.png')

# Optionally, you can also save the class labels for each image
# torch.save(class_labels, 'generated_images/class_labels.pth')
# fake_images[0]

[tensor([[[-0.9982, -0.9991, -0.9960, -0.9909, -0.9986, -0.9989, -0.9970,
          -0.9976, -0.9958, -0.9972, -0.9900, -0.9953, -0.9984, -0.9989,
          -0.9970, -0.9984, -0.9968, -0.9956, -0.9838, -0.9957, -0.9959,
          -0.9959, -0.9939, -0.9919, -0.9510, -0.9483, -0.9081, -0.7151],
         [-0.9990, -0.9996, -0.9970, -0.9970, -0.9982, -0.9997, -0.9988,
          -0.9979, -0.9959, -0.9921, -0.9813, -0.9962, -0.9965, -0.9989,
          -0.9926, -0.9965, -0.9868, -0.9967, -0.9837, -0.9942, -0.9925,
          -0.9985, -0.9974, -0.9892, -0.9467, -0.9271, -0.8089, -0.5043],
         [-0.9824, -0.9982, -0.9999, -0.9951, -0.9996, -0.9987, -0.9950,
          -0.9988, -0.9986, -0.9998, -0.9898, -0.9938, -0.9947, -0.9945,
          -0.9897, -0.9918, -0.9787, -0.9923, -0.9686, -0.9856, -0.9873,
          -0.9938, -0.9956, -0.9708, -0.9594, -0.9702, -0.9705, -0.9109],
         [-0.9911, -0.9923, -0.9946, -0.9949, -0.9996, -0.9983, -0.9966,
          -0.9993, -0.9995, -0.9997, -0.9986, -

In [7]:
fake_images[0].is_cuda

True