In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

%matplotlib inline
from matplotlib import pyplot as plt

from DARTS_model import *
from models import *

In [2]:
# Parameters

# Image side length and number of class labels (used to size the
# one-hot / label-plane lookup tables built further down).
image_size, label_dim = 32, 10

# Generator / discriminator channel configuration.
G_in_dim, G_out_dim = 100, 3
D_in_dim, D_out_dim = 3, 1
num_channels = [512, 256, 128]

# Optimisation settings.
GAN_lr = 2e-4
betas = (0.5, 0.999)
batch_size = 16

# Length of the two training phases, and the output directory.
pretrain_epochs, num_epochs = 100, 150
save_dir = '/model'

In [3]:
# Convert images to tensors, then normalise every channel with mean 0.5 and
# std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Split the CIFAR-10 training set 3/5 : 2/5 into train / validation subsets.
# The validation length is taken as the remainder so the two lengths always
# sum to the dataset size (random_split raises otherwise); flooring both
# fractions independently only works when the size is divisible by 5.
full_trainset = tv.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
n_train = 3 * len(full_trainset) // 5
n_val = len(full_trainset) - n_train
trainset, valset = torch.utils.data.random_split(full_trainset, [n_train, n_val])

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
valloader = torch.utils.data.DataLoader(valset, shuffle=True, **loader_kwargs)

# Held-out test split; iterated in a fixed order.
testset = tv.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

# Human-readable class names, in label-index order.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Class-conditioning lookup tables, both indexed by class id along dim 0.
#   onehot[k] has shape (label_dim, 1, 1): a one-hot vector for class k.
#   fill[k]   has shape (label_dim, image_size, image_size): plane k is all
#             ones, every other plane is all zeros.
onehot = torch.eye(label_dim, device=torch.device('cuda')).reshape(label_dim, label_dim, 1, 1)

# Tiling the one-hot entries over the spatial dims yields the label planes.
fill = onehot.repeat(1, 1, image_size, image_size)

In [5]:
# Instantiate the three networks. Generator, Discriminator and
# resnet_transfer come from the wildcard imports at the top of the notebook.
G = Generator(G_in_dim, label_dim, G_out_dim, num_channels)
D = Discriminator(16, 10, 9)  # NOTE(review): meaning of the positional args is defined in the imported module
clf = resnet_transfer()

# Move the networks to the GPU (when present) before the optimizers
# collect their parameters.
if torch.cuda.is_available():
    G, D, clf = (net.cuda() for net in (G, D, clf))

# One optimizer per parameter group. D's regular parameters and its
# architecture parameters are updated by separate optimizers.
optim_G = optim.Adam(
    G.parameters(),
    lr=GAN_lr,
    betas=betas,
)
optim_D = optim.SGD(
    D.parameters(),
    lr=GAN_lr / 2,
    momentum=0.9,
    weight_decay=3e-4,
)
optim_clf = optim.SGD(
    clf.parameters(),
    lr=0.01,
    momentum=0.9,
)
optim_arch = optim.Adam(
    D.arch_parameters(),
    lr=3e-4,
    betas=(0.5, 0.999),
    weight_decay=1e-3,
)

## Pretrain the GAN (generator and discriminator only)

In [None]:
# Start the pretraining phase from epoch 0 with both GAN networks in
# training mode. Kept separate from the loop cell so the loop can be
# re-run without resetting the counter.
epoch = 0
for net in (G, D):
    net.train()

In [6]:
# GAN pretraining: for every batch of `trainloader`, take one discriminator
# step followed by one generator step. The loop resumes from the current
# value of `epoch` and writes a checkpoint at the end of every epoch.
cuda_device = torch.device('cuda')

while epoch < pretrain_epochs:

    # Per-epoch running sums, kept on the GPU so accumulating them does not
    # force a host sync on every step.
    G_running_loss = torch.zeros((1, 1), device=cuda_device)
    D_running_real_loss = torch.zeros((1, 1), device=cuda_device)
    D_running_fake_loss = torch.zeros((1, 1), device=cuda_device)

    for i, (images, labels) in enumerate(trainloader):

        mini_batch = images.size()[0]
        x_ = images.cuda(non_blocking=True)

        # Targets for "real" (1) and "fake" (0) decisions.
        y_real_ = torch.ones(mini_batch, device=cuda_device)
        y_fake_ = torch.zeros(mini_batch, device=cuda_device)
        # Label planes matching the real images' classes.
        c_fill_ = fill[labels]

        # Train discriminator
        optim_D.zero_grad()
        # NOTE(review): a bare squeeze() also drops the batch dim when
        # mini_batch == 1; D's output shape is not visible here - confirm.
        D_real_decision = D(x_, c_fill_).squeeze()
        D_real_loss = D.loss(D_real_decision, y_real_)
        D_running_real_loss += D_real_loss.detach()

        # Sample noise and uniformly random class labels for the fake batch.
        # squeeze(1) removes only the trailing dim, so c_ keeps shape
        # (mini_batch,) even for a single-sample batch.
        z_ = torch.randn(mini_batch, G_in_dim, device=cuda_device).view(-1, G_in_dim, 1, 1)
        c_ = (torch.rand(mini_batch, 1) * label_dim).type(torch.LongTensor).squeeze(1)
        c_onehot_ = onehot[c_]
        # Detach: the discriminator step must not backpropagate into G
        # (those gradients were discarded by optim_G.zero_grad() anyway).
        gen_image = G(z_, c_onehot_).detach()

        c_fill_ = fill[c_]
        D_fake_decision = D(gen_image, c_fill_).squeeze()
        D_fake_loss = D.loss(D_fake_decision, y_fake_)
        D_running_fake_loss += D_fake_loss.detach()

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        optim_D.step()

        # Train generator
        z_ = torch.randn(mini_batch, G_in_dim, device=cuda_device).view(-1, G_in_dim, 1, 1)
        c_ = (torch.rand(mini_batch, 1) * label_dim).type(torch.LongTensor).squeeze(1)
        c_onehot_ = onehot[c_]

        optim_G.zero_grad()
        # Architecture parameters are never stepped during pretraining;
        # clearing them just stops their gradients from accumulating.
        optim_arch.zero_grad()
        gen_image = G(z_, c_onehot_)

        c_fill_ = fill[c_]
        D_fake_decision = D(gen_image, c_fill_).squeeze()
        # The generator is scored against the "real" targets.
        G_loss = G.loss(D_fake_decision, y_real_)
        G_running_loss += G_loss.detach()
        G_loss.backward()
        optim_G.step()

        # Report running means every 125 iterations.
        if i%125 == 124:
            print('({}, {}), G_loss: {}, D_real_loss: {}, D_fake_loss: {}'.format(epoch, i+1, G_running_loss.item()/(i+1), D_running_real_loss.item()/(i+1), D_running_fake_loss.item()/(i+1)))

    # End-of-epoch checkpoint. The optimizer states are bound to
    # `optim_states` rather than `optim`, so the `torch.optim` module imported
    # under that name stays usable if the optimizer cell is re-run.
    model = [G.state_dict(), D.state_dict(), D.arch_parameters()]
    optim_states = [optim_G.state_dict(), optim_D.state_dict()]
    torch.save({'model': model, 'optim': optim_states, 'epoch': epoch}, 'GAN_checkpoint.pth')
    epoch += 1

(0, 125), G_loss: 0.6949044189453125, D_real_loss: 0.7032783813476563, D_fake_loss: 0.7115841064453124


KeyboardInterrupt: 

## Train the network (GAN, classifier and discriminator architecture)

In [None]:
# Restart the epoch counter for the joint-training phase and put the
# classifier in training mode.
epoch = 0
clf.train()

In [None]:
# Joint training. Each iteration performs, in order:
#   1. a discriminator step,
#   2. a generator step,
#   3. a classifier step on one generated batch plus the real batch,
#   4. an architecture-parameter step after a validation-batch backward pass.
# The loop resumes from the current value of `epoch` and writes a checkpoint
# at the end of every epoch.
cuda_device = torch.device('cuda')

while epoch < num_epochs:

    # Per-epoch running sums, kept on the GPU.
    G_running_loss = torch.zeros((1, 1), device=cuda_device)
    D_running_loss = torch.zeros((1, 1), device=cuda_device)
    clf_train_running_loss = torch.zeros((1, 1), device=cuda_device)
    clf_val_running_loss = torch.zeros((1, 1), device=cuda_device)

    # zip() stops at the shorter of the two loaders, so an epoch only covers
    # as many batches as the smaller split provides.
    for i, ((images, labels), (val_images, val_labels)) in enumerate(zip(trainloader, valloader)):

        mini_batch = images.size()[0]
        x_ = images.cuda(non_blocking=True)
        val_images, val_labels = val_images.cuda(non_blocking=True), val_labels.cuda(non_blocking=True)

        # Targets for "real" (1) and "fake" (0) decisions.
        y_real_ = torch.ones(mini_batch, device=cuda_device)
        y_fake_ = torch.zeros(mini_batch, device=cuda_device)
        c_fill_ = fill[labels]

        # Train discriminator
        optim_D.zero_grad()
        # NOTE(review): a bare squeeze() also drops the batch dim when
        # mini_batch == 1; D's output shape is not visible here - confirm.
        D_real_decision = D(x_, c_fill_).squeeze()
        D_real_loss = D.loss(D_real_decision, y_real_)

        # Noise and uniformly random class labels for the fake batch.
        # squeeze(1) keeps c_ at shape (mini_batch,) even for one sample.
        z_ = torch.randn(mini_batch, G_in_dim, device=cuda_device).view(-1, G_in_dim, 1, 1)
        c_ = (torch.rand(mini_batch, 1) * label_dim).type(torch.LongTensor).squeeze(1)
        c_onehot_ = onehot[c_]
        # Detach: the discriminator step must not backpropagate into G
        # (those gradients were discarded by optim_G.zero_grad() anyway).
        gen_image = G(z_, c_onehot_).detach()

        c_fill_ = fill[c_]
        D_fake_decision = D(gen_image, c_fill_).squeeze()
        D_fake_loss = D.loss(D_fake_decision, y_fake_)

        D_loss = D_real_loss + D_fake_loss
        D_running_loss += D_loss.detach()
        D_loss.backward()
        optim_D.step()

        # Train generator
        z_ = torch.randn(mini_batch, G_in_dim, device=cuda_device).view(-1, G_in_dim, 1, 1)
        c_ = (torch.rand(mini_batch, 1) * label_dim).type(torch.LongTensor).squeeze(1)
        c_onehot_ = onehot[c_]

        optim_G.zero_grad()
        # Clears whatever the discriminator backward pass left on the
        # architecture parameters.
        optim_arch.zero_grad()
        gen_image = G(z_, c_onehot_)

        c_fill_ = fill[c_]
        D_fake_decision = D(gen_image, c_fill_).squeeze()
        # The generator is scored against the "real" targets.
        G_loss = G.loss(D_fake_decision, y_real_)
        G_running_loss += G_loss.detach()
        # create_graph=True makes the resulting gradients differentiable.
        # NOTE(review): nothing visible below differentiates through them -
        # confirm this is still required.
        G_loss.backward(create_graph=True)
        optim_G.step()

        # Train Resnet
        z_ = torch.randn(mini_batch, G_in_dim, device=cuda_device).view(-1, G_in_dim, 1, 1)
        c_ = (torch.rand(mini_batch, 1) * label_dim).type(torch.LongTensor).squeeze(1)
        c_onehot_ = onehot[c_]

        # The sampled and real labels are classifier targets from here on,
        # so they have to live on the GPU.
        c_ = c_.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        gen_image = G(z_, c_onehot_)

        optim_clf.zero_grad()
        clf_fake_decision = clf(gen_image)
        clf_fake_loss = clf.loss(clf_fake_decision, c_)
        clf_real_decision = clf(x_)
        clf_real_loss = clf.loss(clf_real_decision, labels)

        clf_loss = clf_fake_loss + clf_real_loss
        # Only the real-image part is reported as the training loss.
        clf_train_running_loss += clf_real_loss.detach()
        clf_loss.backward(create_graph=True)
        optim_clf.step()

        # Train architecture
        y = clf(val_images)
        loss = clf.loss(y, val_labels)
        clf_val_running_loss += loss.detach()
        loss.backward()
        # NOTE(review): none of the classifier passes above visibly goes
        # through D, so the architecture gradient applied here appears to be
        # the one left by the G_loss backward pass - confirm this is intended.
        optim_arch.step()

        # Drop every gradient (and any graph retained through
        # create_graph=True) before the next iteration.
        for param in G.parameters():
            param.grad = None
        for param in D.parameters():
            param.grad = None
        for param in clf.parameters():
            param.grad = None
        for param in D.arch_parameters():
            param.grad = None

        # Report running means and the first row of D.alphas_normal every
        # 125 iterations.
        if i%125 == 124:
            print('({}, {}), G_loss: {}, D_loss: {}, clf_train: {}, clf_val: {}'.format(epoch, i+1, G_running_loss.item()/(i+1), D_running_loss.item()/(i+1), clf_train_running_loss.item()/(i+1), clf_val_running_loss.item()/(i+1)))
            print(D.alphas_normal[0])

    # End-of-epoch checkpoint. The optimizer states are bound to
    # `optim_states` rather than `optim`, so the `torch.optim` module imported
    # under that name stays usable if the optimizer cell is re-run.
    model = [G.state_dict(), D.state_dict(), clf.state_dict(), D.arch_parameters()]
    optim_states = [optim_G.state_dict(), optim_D.state_dict(), optim_clf.state_dict(), optim_arch.state_dict()]
    torch.save({'model': model, 'optim': optim_states, 'epoch': epoch}, 'checkpoint.pth')
    epoch += 1