In [1]:
import random

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torchvision.datasets import MNIST

from generative_adversarial import Generator, Discriminator, GANSupervisor


# One shared seed for every random number generator the notebook relies on.
state = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    # Order matches the original: Python, NumPy, torch (CPU), torch (CUDA).
    _seed_rng(state)

# cuDNN is switched off altogether and its deterministic flag is set as well.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_mnist_dataset(path: str= "../../datasets/mnist/",
                      mode: str= "train") -> torch.utils.data.Dataset:
    """Build the MNIST dataset with 64x64, 3-channel, normalized image tensors.

    Args:
        path: Root directory that already contains the MNIST files
            (``download=False``, so nothing is fetched).
        mode: ``'train'`` selects the augmented pipeline (random crop and
            horizontal flip); any other value selects the plain resize
            pipeline. ``'test'`` reads the test split; every other value
            reads the training split.

    Returns:
        A ``torchvision.datasets.MNIST`` instance with the transform attached.

    Note:
        No validation hold-out is carved out of the training split, so a mode
        such as ``'val'`` sees the full training split with the evaluation
        transform.
    """
    # MNIST digits are single-channel images. The 3-value Normalize below only
    # works on 3-channel tensors, so replicate the grey channel first; without
    # this the transform fails on the first sample with a shape mismatch.
    to_three_channels = transforms.Grayscale(num_output_channels=3)
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if mode == 'train':
        # Augmented pipeline: resize, take a random padded 32x32 crop, flip,
        # then scale back up to the 64x64 size used by the models.
        steps = [
            to_three_channels,
            transforms.Resize((64, 64)),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        # Deterministic pipeline for evaluation.
        steps = [
            to_three_channels,
            transforms.Resize((64, 64)),
            transforms.CenterCrop((64, 64)),
            transforms.ToTensor(),
            normalize,
        ]

    return MNIST(root=path,
                 transform=transforms.Compose(steps),
                 train=(mode != 'test'),
                 download=False)

def get_bird_dataset(path: str= "../../datasets/birds/",
                     mode: str= "train") -> torch.utils.data.Dataset:
    """Load one split of the bird images as a torchvision ``ImageFolder``.

    Args:
        path: Directory holding one sub-directory per split.
        mode: Name of the split sub-directory to read (e.g. ``"train"``).

    Returns:
        An ``ImageFolder`` whose samples are converted to tensors, resized to
        64x64 and normalized with mean 0.5 and std 0.5.
    """
    split_dir = path + f'/{mode}/'

    # Tensor conversion comes first, so the resize operates on tensors.
    preprocessing = [
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize((0.5), (0.5)),
    ]

    return datasets.ImageFolder(split_dir,
                                transform=transforms.Compose(preprocessing))


# Test 1: train the GAN (Generator + Discriminator) on the bird dataset

In [3]:
# Release cached GPU memory left over from earlier cells before building models.
torch.cuda.empty_cache()

# Both networks use a feature size of 96 and 3 image channels;
# the generator takes a 100-dimensional input.
gen = Generator(input_size=100, feature_size=96, num_channels=3)
dis = Discriminator(feature_size=96, num_channels=3)

# Both sides use Adam with the same learning rate; data comes from the bird dataset.
supervisor = GANSupervisor(
    generator=gen,
    discriminator=dis,
    generator_optimizer=optim.Adam,
    discriminator_optimizer=optim.Adam,
    generator_learning_rate=1e-5,
    discriminator_learning_rate=1e-5,
    get_dataset=get_bird_dataset,
    get_noise_generator=None,
    epoch=1000,
    batch_size=256,
    embedding_size=100,
)

supervisor.fit()

Epoch 160: 100%|██████████| 276/276 [10:24<00:00,  2.26s/batch, D_loss=0.6438, G_loss=2.6479, fake_loss=0.383, real_loss=0.261]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for int

KeyboardInterrupt: 

# Test 2: resume training with a discriminator built on the autoencoder's encoder (ResDis)

In [3]:
import sys
# Make the sibling project directory importable. Presumably this is needed so
# that torch.load below can resolve the pickled autoencoder classes — TODO confirm.
sys.path.append('../prototypical_networks_with_autoencoders/')

class ResDis(nn.Module):
    """Discriminator head on top of an existing encoder.

    The encoder output is average-pooled to 1x1, flattened, and mapped to a
    single output value by a linear layer that expects 512 input features.
    """

    def __init__(self, encoder):
        """Store ``encoder`` and create the pooling and linear layers."""
        super().__init__()

        self.encoder = encoder
        self.avpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Return one value per sample for the input batch ``x``."""
        pooled = self.avpool(self.encoder(x))
        # Flatten everything after the batch dimension before the linear layer.
        return self.fc(torch.flatten(pooled, 1))
    
# Release cached GPU memory before loading the models for this experiment.
torch.cuda.empty_cache()
# Loads a whole pickled object (not a state_dict); the only attribute used below is `.encoder`.
# NOTE(review): torch.load unpickles arbitrary code — only load trusted files.
autoencoder = torch.load("../outputs/exported/resnet_autoencoder/autoencoder_withouthist.pt")
gen =  Generator(input_size=100, feature_size=96, num_channels=3)
# The discriminator wraps the encoder object that `autoencoder` holds at this point.
dis =  ResDis(encoder=autoencoder.encoder)

# NOTE(review): this rebinds `autoencoder.encoder` AFTER `dis` captured the old
# encoder, and `autoencoder` is not used again in this cell, so the checkpoint
# never reaches `dis`. Presumably the intent was to resume the discriminator
# from "out2/dis/dis9.pt" — confirm what that file contains (a full ResDis or
# just an encoder) and load it before/into `dis` accordingly.
autoencoder.encoder = torch.load("out2/dis/dis9.pt")

# Same optimizer class and learning rate for both networks; `epoch` is 990 here.
supervisor = GANSupervisor(generator= gen,
              discriminator= dis,
              generator_optimizer= optim.Adam,
              discriminator_optimizer= optim.Adam,
              generator_learning_rate= 1e-5,
              discriminator_learning_rate=1e-5 ,
              get_dataset= get_bird_dataset,
              get_noise_generator=None,
              epoch= 990,
              batch_size= 256,
              embedding_size= 100,
              )

# Replace the supervisor's `gen` attribute with a generator checkpoint.
# NOTE(review): this happens after construction. If GANSupervisor builds its
# optimizers in __init__, the generator optimizer may still reference the
# parameters of the freshly created `gen` above — verify against GANSupervisor,
# and confirm that `gen` is the attribute name it actually trains.
supervisor.gen = torch.load("out2/gen/gen9.pt")


supervisor.fit()

Epoch 0: 100%|██████████| 276/276 [21:46<00:00,  4.73s/batch, D_loss=2.7721, G_loss=2.7212, fake_loss=0.069, real_loss=2.703]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integ

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`