In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.models as models

from data_loader import get_cifar10
from model import *
from utils import *
from trainer import *
from gan import *
from DualGan import WGAN

In [2]:
# Build the CIFAR-10 train and test splits. Each get_cifar10 call is unpacked
# into a (dataset, loader) pair; `bs` is presumably the loader batch size.
CIFAR_BS = 32
trainset, trainloader = get_cifar10(bs=CIFAR_BS, train=True)
testset, testloader = get_cifar10(bs=CIFAR_BS, train=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Disabled: membership-inference classifier training. The call below relies on
# mi_classifier, m_optim, target and train_mi, all of which are only defined in
# later cells, so it cannot run at this point on a fresh kernel.
#train_mi(mi_classifier, m_optim, target, trainloader, testloader, epochs=10, model_name="mi_classifier_simple_cifar10")
# Build the WGAN from DualGan. NOTE(review): 32 presumably matches the loader
# batch size and load_cp=True presumably resumes from a saved checkpoint --
# confirm against DualGan.WGAN.
gangp = WGAN(32,load_cp=True)
# Trains on the *test* loader rather than the training loader. The second
# argument is presumably the number of epochs -- TODO confirm.
gangp.train(testloader, 100)

Generator(
  (dconv2): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (bnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dconv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (bnorm3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dconv4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (bnorm4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dconv5): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
C

In [None]:
# fully convolutional network with constant filter size and pyramid structure (i.e. double num filters every layer)
class FullyConvNet(nn.Module):
    """Stack of 3x3 convolutions whose channel width doubles on downsampling layers.

    Layer layout, for ``layer`` in ``range(depth)``:

    * layer 0 maps 3 input channels to 32 channels with stride 2;
    * the last layer (``depth - 1``, when it is not also layer 0) maps to
      ``outsize`` channels with stride 1 and is followed by no batch norm,
      activation or dropout;
    * every other layer either doubles the width with stride 2 (when
      ``all_sub`` is true or ``layer % 3 == 0``) or keeps the width with
      stride 1.

    Args:
        depth: Number of convolution layers. With ``depth == 1`` the single
            layer is built as the *first* layer, so the network emits 32
            channels rather than ``outsize``.
        activation: Zero-argument callable returning an activation module
            (e.g. ``nn.LeakyReLU``); called once per non-final layer.
        outsize: Number of output channels of the final convolution.
        bn: If true, add ``nn.BatchNorm2d`` after each non-final convolution.
        dropout: Drop probability for ``nn.Dropout2d`` after each non-final
            layer; ``None`` (or any falsy value) disables dropout.
        all_sub: If true, every intermediate layer downsamples.
    """

    def __init__(self, depth, activation, outsize, bn=False, dropout=None, all_sub=False):
        super().__init__()
        self.main = self.create_model(depth, activation, outsize, bn, dropout, all_sub)

    def create_model(self, depth, activation, outsize, bn, dropout, all_sub):
        """Assemble and return the ``nn.Sequential`` described in the class docstring.

        Sub-module names are ``Conv<i>``, ``BatchNorm<i>``, ``Activation<i>``
        and ``Dropout<i>``, where ``<i>`` is the layer index.
        """
        model = nn.Sequential()
        curr_width = 3
        for layer in range(depth):
            # Layer 0 takes precedence over "last" so depth == 1 keeps its
            # original behaviour (a single 3 -> 32 layer with its extras).
            is_last = layer != 0 and layer == depth - 1
            if layer == 0:
                curr_layer = nn.Conv2d(curr_width, 32, 3, 2, 1)
                curr_width = 32
            elif is_last:
                curr_layer = nn.Conv2d(curr_width, outsize, 3, 1, 1)
            elif all_sub or layer % 3 == 0:
                curr_layer = nn.Conv2d(curr_width, curr_width * 2, 3, 2, 1)
                curr_width *= 2
            else:
                curr_layer = nn.Conv2d(curr_width, curr_width, 3, 1, 1)
            model.add_module("Conv%d" % layer, curr_layer)
            if is_last:
                # The output layer is a bare convolution.
                continue
            if bn:
                model.add_module("BatchNorm%d" % layer, nn.BatchNorm2d(curr_width))
            model.add_module("Activation%d" % layer, activation())
            if dropout:
                # Dropout2d takes the drop probability, not a channel count.
                model.add_module("Dropout%d" % layer, nn.Dropout2d(dropout))
        return model

    def forward(self, inputs, output_act=None):
        """Run the network; apply ``output_act`` to the result when one is given."""
        outputs = self.main(inputs)
        if output_act:
            return output_act(outputs)
        return outputs
            

In [None]:
# Target classifier: 6-layer FullyConvNet with batch norm, downsampling at every
# intermediate layer, 10 output channels, moved to the GPU.
target = FullyConvNet(
    depth=6,
    activation=nn.LeakyReLU,
    outsize=10,
    bn=True,
    all_sub=True,
).cuda()

In [None]:
# Fit the target classifier on the training loader. The optimizer and loss are
# built first so the call itself stays readable.
# NOTE(review): a later cell loads the checkpoint
# "simple_target_classifier_cifar10", which differs from the name used here --
# confirm which checkpoint is intended.
target_optim = optim.Adam(target.parameters())
target_criterion = nn.CrossEntropyLoss()
train_classifier(
    target,
    target_optim,
    trainloader,
    target_criterion,
    20,
    "simple_softmax_target_cifar10",
    verbose=False,
    softmax=True,
)

In [None]:
# MIGAN components: the GAN generator/discriminator pair and the
# membership-inference (MI) classifier, which takes 10-dimensional inputs.
# An alternative configuration tried earlier used act=F.relu_ with bn=True for
# the generator and act=F.leaky_relu_ with bn=False for the discriminator.
generator = DCGAN_Generator(
    act=SmELU,
    bn=False,
).cuda()
discriminator = DCGAN_Discriminator(
    act=SmELU,
    bn=False,
).cuda()
mi_classifier = MLP_Discriminator(
    depth=2,
    width=64,
    activation=SELU,
    bn=False,
    insize=10,
    outsize=1,
).cuda()

# Restore the target classifier weights; only the third element of the
# five-element checkpoint tuple (the state dict) is used.
# NOTE(review): this name differs from "simple_softmax_target_cifar10", the
# name passed to train_classifier in the previous cell -- confirm.
_, _, tgt_state_dict, _, _ = load_checkpoint("simple_target_classifier_cifar10")
target.load_state_dict(tgt_state_dict)

In [None]:
# Initialise the GAN networks and the MI classifier, then create one optimizer
# per component.
GAN_INIT_ARGS = (0, 0.02)  # presumably (mean, std) for weight_init -- TODO confirm
GAN_ADAM_BETAS = (0.5, 0.9)

weight_init(generator, *GAN_INIT_ARGS)
weight_init(discriminator, *GAN_INIT_ARGS)

# Every nn.Linear inside the MI classifier gets the SELU-specific init.
for fc in get_modules_of_type(mi_classifier, nn.Linear):
    selu_init(fc)

# The discriminator uses a 3x larger learning rate than the generator.
# Alternatives tried earlier: plain SGD (lr=1e-4) and RMSprop (lr=5e-5) for
# both networks.
g_optim = optim.Adam(generator.parameters(), lr=1e-4, betas=GAN_ADAM_BETAS)
d_optim = optim.Adam(discriminator.parameters(), lr=3e-4, betas=GAN_ADAM_BETAS)
m_optim = optim.Adam(mi_classifier.parameters())

In [None]:
def get_test_samples(testiter, testloader):
    """Return the next batch from ``testiter``, restarting it once it is exhausted.

    Args:
        testiter: Iterator currently being drained (normally ``iter(testloader)``).
        testloader: Iterable used to create a fresh iterator when ``testiter``
            raises ``StopIteration``.

    Returns:
        ``(data, testiter)``: the batch and the iterator to pass to the next
        call, which is a new iterator if a restart happened.

    Raises:
        StopIteration: if ``testloader`` yields nothing even after a restart.
    """
    # Use the builtin next(): Python 3 iterators, including DataLoader
    # iterators in current PyTorch, have no .next() method.
    try:
        return next(testiter), testiter
    except StopIteration:
        testiter = iter(testloader)
        return next(testiter), testiter
    
def train_mi(mi, mi_opt, target, trainloader, testloader, epochs, model_name):
    """Train the membership-inference model ``mi`` on ``target``'s softmax outputs.

    Batches from ``trainloader`` are treated as positives and batches from
    ``testloader`` as negatives. The loss ``tanh(mi(neg)).mean() -
    tanh(mi(pos)).mean()`` is minimised, pushing ``mi`` towards -1 on
    negatives and +1 on positives.

    Args:
        mi: Model being trained; receives ``(N, 10)`` softmax vectors.
        mi_opt: Optimizer over ``mi``'s parameters.
        target: Frozen classifier whose outputs are reshaped to ``(-1, 10)``.
            It is only evaluated here, never updated.
        trainloader: Iterable of batches whose first element is the input tensor.
        testloader: Iterable of batches whose first element is the input
            tensor; cycled as often as needed via ``get_test_samples``.
        epochs: Number of passes over ``trainloader``.
        model_name: Name passed to ``save_checkpoint``.
    """
    iterations = 0
    testiter = iter(testloader)
    for epoch in range(epochs):
        for data in trainloader:
            mi_opt.zero_grad()
            neg_batch, testiter = get_test_samples(testiter, testloader)
            # The target is not trained here, so skip building its autograd
            # graph and avoid accumulating unused gradients on its parameters.
            # Labels are not needed either, so only the inputs go to the GPU.
            # NOTE(review): target's train/eval mode is left untouched --
            # confirm the caller sets it as intended.
            with torch.no_grad():
                pos = torch.softmax(target(data[0].cuda()).view(-1, 10), 1)
                neg = torch.softmax(target(neg_batch[0].cuda()).view(-1, 10), 1)
            # mi(neg) -> -1 and mi(pos) -> 1
            # Out-of-place tanh leaves mi's raw output intact for autograd.
            errneg = torch.tanh(mi(neg)).mean()
            errpos = torch.tanh(mi(pos)).mean()
            error = errneg - errpos
            error.backward()
            mi_opt.step()
            iterations += 1
            print("[%d/%d][%d] Error: %f Errpos: %f Errneg: %f"%(epoch, epochs,
                                                                iterations, error.item(),
                                                                errpos.item(), errneg.item()))
            # Checkpoint on iterations 999, 1999, 2999, ...
            if iterations % 1000 == 999:
                save_checkpoint(epoch=epoch, iters=iterations, model_name=model_name,net=mi, optim=mi_opt)

In [None]:
# Train the WGAN pair built above. The data comes from the *test* loader, as
# the run name "cifar10testsamplewgan" also indicates.
train_wgan(
    discriminator,
    generator,
    testloader,
    d_optim,
    g_optim,
    epochs=60,
    name="cifar10testsamplewgan",
    save_samples=True,
)