In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision.utils import save_image

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import numpy as np

import tqdm

In [10]:


cuda_available = torch.cuda.is_available()

device = torch.device('cuda' if cuda_available else 'cpu')

pin_memory = True if cuda_available else False
batch_size = 128
latent_size = 1024
lr = 0.0003


#print(np.mean(train_dataset.data.numpy())/255)
#print(np.std(train_dataset.data.numpy())/255)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5)
    ])

train_dataset = datasets.MNIST(root='./data/mnist/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data/mnist/', train=False, transform=transform, download=False)

# train_dataset.data = train_dataset.data[:1000]


train = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.seq = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, bias=True),
            nn.Tanh()
            )
        
        
    def forward(self, x):
        x = self.seq(x)
        x = x[:, :, 0:28, 0:28]
        return x
        
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.seq = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(512, 1,  kernel_size=3, stride=2, padding=1, bias=True),
            nn.Sigmoid()
            )
        
        
    def forward(self, x):
        x = self.seq(x)
        return x
        
G = Generator().to(device)

D = Discriminator().to(device)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.kaiming_normal_(m.weight.data, nonlinearity='leaky_relu')
    elif classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_normal_(m.weight.data, nonlinearity='leaky_relu')
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
G = G.apply(weights_init)
D = D.apply(weights_init)

loss = nn.BCELoss()

G_opt = optim.AdamW(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_opt = optim.AdamW(D.parameters(), lr=lr, betas=(0.5, 0.999))

def train_epoch(G_opt, D_opt, G, D, train):
    
    real_label = torch.tensor([1.0], device=device)
    fake_label = torch.tensor([0.0], device=device)
    
    G.train()
    D.train()
    
    dl = []
    gl = []
    
    for i, (x, _) in enumerate(tqdm.tqdm(train,0)):
        
        x_real = x.to(device)
        
        D_opt.zero_grad(set_to_none=True)
        
        D_out_real = D(x_real).view(-1)
        
        y_real = real_label.repeat(D_out_real.shape[0],)
        y_fake = fake_label.repeat(D_out_real.shape[0],)
        
        latent = torch.randn(D_out_real.shape[0], latent_size, 1, 1, device=device)
        
        
        with torch.no_grad():
            x_fake = G(latent)
        
        D_out_fake = D(x_fake).view(-1)
        
        D_real_loss = loss(D_out_real, y_real)
        D_fake_loss = loss(D_out_fake, y_fake)
        
        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_opt.step()
        
        G_opt.zero_grad(set_to_none=True)
        
        x_fake = G(latent)
        D_out = D(x_fake).view(-1)
        
        G_loss = loss(D_out, y_real)
        
        G_loss.backward()
        G_opt.step()
        
        dl.append(D_loss.item())
        gl.append(G_loss.item())
        
    return np.mean(dl), np.mean(gl)



In [11]:
# retrain = True
retrain = False
epochs = 100

if retrain:
    for epoch in range(epochs):
        dl, gl = train_epoch(G_opt, D_opt, G, D, train)
        print(f"Epoch: {epoch}, D_loss: {dl}, G_loss: {gl}")
        torch.save(G.state_dict(), f"./gancheckpoints/generator_{epoch}.pth")
        torch.save(D.state_dict(), f"./gancheckpoints/discriminator_{epoch}.pth")
        with torch.no_grad():
            test_z = Variable(torch.randn(batch_size, 1024, 1, 1).to(device))
            generated = G(test_z)
            save_image(generated, './samples/sample_' + str(epoch) + '.png')



else:
    G.load_state_dict(torch.load(f"./gancheckpoints/generator_{81}.pth"))
    D.load_state_dict(torch.load(f"./gancheckpoints/discriminator_{81}.pth"))





In [12]:
#create a dataset with generated images and their latent space representations as the label
class GeneratedDataset(Dataset):
    def __init__(self, G, num_samples):
        self.G = G
        self.num_samples = num_samples
        self.z = torch.randn(num_samples, latent_size, 1, 1, device=device)
        self.x = self.G(self.z)
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.x[idx], self.z[idx]

In [13]:
#create a convnet to find the latent space representation of each image
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, latent_size)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        #reshape the output to be the same shape as the latent space
        output = x.view(-1, latent_size, 1, 1)
        return output



lr = 0.001
model = Net().to(device)
model.train()
train_loader = torch.utils.data.DataLoader(GeneratedDataset(G, batch_size*10), batch_size=batch_size, shuffle=True)

for batch_idx, (data, target) in enumerate(train_loader):
    with torch.no_grad():
        print(data.cpu().numpy().shape)
        print(target.cpu().numpy().shape)
        break

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, 3):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if True:#batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    scheduler.step()


(128, 1, 28, 28)
(128, 1024, 1, 1)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [15]:
#now we can generate new training data
images = train_dataset.data
labels = train_dataset.targets
generated_images = []
generated_labels = []
for i in range(0,100):
    #select 2 images at random
    idx1 = np.random.randint(0,images.shape[0])
    image1 = images[idx1]
    label1 = labels[idx1]
    label1 = torch.nn.functional.one_hot(torch.tensor(label1), num_classes=10).float()
    idx2 = np.random.randint(0,images.shape[0])
    image2 = images[idx2]
    label2 = labels[idx2]
    label2 = torch.nn.functional.one_hot(torch.tensor(label2), num_classes=10).float()

    #find the latent space representation of each image
    with torch.no_grad():
        image1 = torch.tensor(image1).float().to(device).unsqueeze(0).unsqueeze(0)
        image2 = torch.tensor(image2).float().to(device).unsqueeze(0).unsqueeze(0)
        z1 = model(image1)
        z2 = model(image2)

    #interpolate between the two latent space representations
    z = torch.lerp(z1, z2, 0.99)

    #generate the label by combining the labels of the two images
    label = torch.lerp(label1, label2, 0.99)

    #generate the image
    with torch.no_grad():
        image = G(z)
        image = image.cpu().numpy().squeeze()
        label = label.cpu().numpy()
        generated_images.append(image)
        generated_labels.append(label)

    #plot the generated image and the two original images
    # fig, ax = plt.subplots(1,3)
    # ax[0].imshow(image1.cpu().numpy().squeeze())
    # ax[1].imshow(image2.cpu().numpy().squeeze())
    # ax[2].imshow(image)
    # print(label)

    
generated_images = np.array(generated_images)
generated_labels = np.array(generated_labels)
print(generated_images.shape)
print(generated_labels.shape)
#save the generated images and labels
np.save("./datasets/generated_images.npy", generated_images)
np.save("./datasets/generated_labels.npy", generated_labels)
        

  label1 = torch.nn.functional.one_hot(torch.tensor(label1), num_classes=10).float()
  label2 = torch.nn.functional.one_hot(torch.tensor(label2), num_classes=10).float()
  image1 = torch.tensor(image1).float().to(device).unsqueeze(0).unsqueeze(0)
  image2 = torch.tensor(image2).float().to(device).unsqueeze(0).unsqueeze(0)


(100, 28, 28)
(100, 10)
