In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import torch.optim as optim
import torch.autograd as autograd

from mog_eigval_dist import *
from model import *
from utils import *

In [None]:
# 2-D toy setup: two "target" modes, two "aux" modes and a generator input distribution.
# target modes -- normal([-1, 1], 0.4) and normal([1, -1], 0.4)
# aux modes    -- normal([-1, -1], 0.4) and normal([1, 1], 0.4)
# input dist   -- normal([0, 0], 0.4)
tmu1, tmu2 = [-1, 1], [1, -1]
amu1, amu2 = [-1, -1], [1, 1]
sigma = [0.4, 0.4]
batch_size = 512

# Draw order (inputs, train1, train2, aux1, aux2) is kept so the random stream is consumed identically.
inputs = np.random.normal(loc=[0, 0], scale=sigma, size=[batch_size, 2])
train1 = np.random.normal(loc=tmu1, scale=sigma, size=[batch_size, 2])
train2 = np.random.normal(loc=tmu2, scale=sigma, size=[batch_size, 2])
aux1 = np.random.normal(loc=amu1, scale=sigma, size=[batch_size, 2])
aux2 = np.random.normal(loc=amu2, scale=sigma, size=[batch_size, 2])

# Density plot of all four modes together (saved to disk), then of the input samples.
out = np.concatenate((train1, train2, aux1, aux2), axis=0)
kde(out[:, 0], out[:, 1], save_file='tgtauxmog.png')
kde(inputs[:, 0], inputs[:, 1])
plt.plot(inputs[:, 0], inputs[:, 1], 'bo')

In [None]:
def quick_load_model(net, model_name):
    """Load a saved state dict into ``net``.

    ``load_checkpoint(model_name)`` must return exactly five items; the third
    one is passed to ``net.load_state_dict``. The other four are discarded.
    """
    (_, _, state_dict, _, _) = load_checkpoint(model_name)
    net.load_state_dict(state_dict)
    
def sample_dist(tmu, tsig, batch_size=512, n=1):
    """Draw ``batch_size // n`` two-column samples from a normal distribution.

    Args:
        tmu: mean passed as ``loc`` to ``np.random.normal``.
        tsig: standard deviation passed as ``scale``.
        batch_size: total batch size shared between ``n`` components.
        n: number of components the batch is split across.

    Returns:
        ``numpy.ndarray`` of shape ``(batch_size // n, 2)``.
    """
    rows = batch_size // n
    return np.random.normal(loc=tmu, scale=tsig, size=[rows, 2])
def get_label(n, l, size=1):
    """Return ``size`` stacked copies of a one-hot row vector.

    Args:
        n: index of the element set to 1.
        l: length of the one-hot vector.
        size: number of times the row is repeated along axis 0.

    Returns:
        ``numpy.ndarray`` of shape ``(size, l)`` (float, from ``np.identity``).
    """
    # Reshape to (1, l) rather than a fixed (1, 2) so any vector length works.
    one_hot = np.identity(n=l)[n, :].reshape(1, l)
    return np.repeat(one_hot, size, axis=0)

def train_classifier(net, opt, loss_fn, tmus, tsigs, model_name, iterations=4000, batch_size=512):
    """Train ``net`` to classify which normal component a 2-D sample came from.

    Each iteration draws ``batch_size // len(tmus)`` samples per component via
    ``sample_dist``, labels them with the component index and takes one
    optimizer step. Tensors are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        net: model called as ``net(inputs)``.
        opt: optimizer stepping ``net``'s parameters.
        loss_fn: called as ``loss_fn(net(inputs), labels)`` with integer labels.
        tmus: per-component means.
        tsigs: per-component standard deviations (same length as ``tmus``).
        model_name: forwarded to ``save_checkpoint``.
        iterations: number of optimizer steps.
        batch_size: total samples per step, split evenly across components.
    """
    assert len(tmus) == len(tsigs)
    ndist = len(tmus)
    per_class = batch_size // ndist
    # Labels do not change between iterations, so build them once.
    labels = torch.LongTensor(np.repeat(np.arange(ndist), per_class)).cuda()
    for iteration in range(iterations):
        opt.zero_grad()
        # Forward batch_size so the number of samples always matches the labels
        # (sample_dist would otherwise fall back to its own default of 512).
        batch = np.vstack([sample_dist(tmus[i], tsigs[i], batch_size=batch_size, n=ndist)
                           for i in range(ndist)])
        inputs = torch.FloatTensor(batch).cuda()
        loss = loss_fn(net(inputs), labels)
        loss.backward()
        opt.step()
        print('[%d/%d] Loss: %f'%(iteration, iterations, loss.item()))
        # Checkpoints are written at iterations 0, 999, 1998, ...
        if iteration%999==0:
            save_checkpoint(epoch=0, iters=iteration, net=net, optim=opt, model_name=model_name)


In [None]:
# Build the two-output classifier, run selu_init over each nn.Linear module
# reported by get_modules_of_type, and create its Adam optimizer.
tgt_classifier = MLP_Discriminator(
    insize=2,
    outsize=2,
    depth=4,
    width=16,
    activation=SELU,
).cuda()
for fc in get_modules_of_type(module_type=nn.Linear, net=tgt_classifier):
    selu_init(fc)
tc_optim = optim.Adam(tgt_classifier.parameters())

In [None]:
# Train the classifier to separate the two target modes.
train_classifier(
    net=tgt_classifier,
    opt=tc_optim,
    loss_fn=nn.CrossEntropyLoss(),
    tmus=[tmu1, tmu2],
    tsigs=[sigma, sigma],
    model_name="MLPclass_tgt",
)

In [None]:
# train mi model
def train_mi(net, opt, tgt, tmus, amus, sigma=sigma, loss_fn=nn.BCELoss(), model_name='MLPmi', iterations=4000,batch_size=512):
    """Train ``net`` to tell ``tgt`` outputs on train samples from ``tgt`` outputs on aux samples.

    Each iteration samples one batch from the ``tmus`` mixture (label 1) and
    one from the ``amus`` mixture (label 0), feeds both through ``tgt``, and
    minimises binary cross-entropy of ``net``'s logits against those labels.
    Tensors are moved with ``.cuda()``, so a CUDA device is required.

    Args:
        net: model producing one logit per row of ``tgt``'s output.
        opt: optimizer that is stepped once per iteration.
        tgt: model whose outputs are the features ``net`` sees.
        tmus: component means of the train mixture.
        amus: component means of the aux mixture (same length as ``tmus``).
        sigma: scale passed to ``sample_dist`` for every component.
        loss_fn: unused; kept so existing callers keep working. The loss is
            always binary cross-entropy on logits.
        model_name: forwarded to ``save_checkpoint``.
        iterations: number of optimizer steps.
        batch_size: samples per mixture per step, split across components.
    """
    # sigma is the scale of a single component, so its length is unrelated to
    # the number of components; only the two mean lists have to agree.
    assert len(tmus) == len(amus)
    ndist = len(tmus)
    for iteration in range(iterations):
        opt.zero_grad()
        traindist = torch.FloatTensor(
            np.vstack([sample_dist(mu, sigma, batch_size=batch_size, n=ndist) for mu in tmus])).cuda()
        auxdist = torch.FloatTensor(
            np.vstack([sample_dist(mu, sigma, batch_size=batch_size, n=ndist) for mu in amus])).cuda()
        out = torch.cat([tgt(traindist), tgt(auxdist)], dim=0)
        # One label per actual row: 1 for train rows, 0 for aux rows. Deriving the
        # counts from the tensors keeps labels aligned for any ndist / batch_size.
        labels = torch.cat([torch.ones(traindist.size(0), 1, device=out.device),
                            torch.zeros(auxdist.size(0), 1, device=out.device)], dim=0)
        # Same quantity as -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))).mean(),
        # computed in a numerically stable way from a single forward pass.
        logits = net(out)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
        loss.backward()
        opt.step()
        print('[%d/%d] Loss: %f'%(iteration, iterations, loss.item()))
        # Checkpoints are written at iterations 0, 999, 1998, ...
        if iteration%999==0:
            save_checkpoint(epoch=0, iters=iteration, net=net, optim=opt, model_name=model_name)
# MI adversarial training
def atrain_mi(tgt, mic, m_opt, G, g_opt, rsampler, isampler=None, model_name='MIadv', iterations=8000):
    """Adversarially train generator ``G`` against critic ``mic`` applied to ``tgt`` outputs.

    Per outer iteration the critic is updated ``niters`` times on
    ``mic(tgt(real)).mean() - mic(tgt(fake)).mean() + gradient penalty``, then
    the generator takes one step minimising ``mic(tgt(fake)).mean()`` using the
    last fake batch produced in the critic loop.

    Args:
        tgt: model applied to both real and generated samples.
        mic: critic applied to ``tgt``'s outputs.
        m_opt: optimizer for the critic.
        G: generator; called as ``G(isampler())`` or, without ``isampler``, ``G()``.
        g_opt: optimizer for the generator.
        rsampler: zero-argument callable returning a batch of real samples.
        isampler: optional zero-argument callable returning generator inputs.
        model_name: prefix for checkpoints and saved plots.
        iterations: number of generator steps.
    """
    for iteration in range(iterations):
        # 100 critic steps during the first 25 iterations and on every 500th one, 5 otherwise.
        if iteration < 25 or iteration % 500 == 0:
            niters = 100
        else:
            niters = 5
        for _ in range(niters):
            real = rsampler()
            if isampler:
                fake = G(isampler())
            else:
                fake = G()
            score_real = mic(tgt(real)).mean()
            score_fake = mic(tgt(fake.detach())).mean()
            penalty = calc_gradient_penalty(mic, tgt(real), tgt(fake.detach()))
            L_D = score_real - score_fake + penalty
            m_opt.zero_grad()
            L_D.backward()
            m_opt.step()
        # Generator step on the most recent fake batch (still attached to G's graph).
        g_opt.zero_grad()
        L_G = mic(tgt(fake)).mean()
        L_G.backward()
        g_opt.step()
        print('[%d/%d] Loss Mic: %f Loss G: %f'%(iteration, iterations, L_D, L_G))
        if iteration % 1000 == 999:
            save_checkpoint(epoch=0, iters=iteration, net=G, optim=g_opt, model_name=model_name+'_G')
            save_checkpoint(epoch=0, iters=iteration, net=mic, optim=m_opt, model_name=model_name+'_Mic')
            snapshot = fake.detach().cpu().numpy()
            kde(snapshot[:,0], snapshot[:,1], show=False,
                save_file="./expt_results/{0}_{1}.png".format(model_name, iteration))

In [None]:
# Critic over the classifier's outputs, and the generator it will be trained against.
mi_classifier = MLP_Discriminator(depth=16, width=16, activation=SELU,insize=2,outsize=1).cuda()
G = MLP_Generator(depth=10, width=16, activation=SELU,bs=512,insize=2,outsize=2).cuda()

for fc in get_modules_of_type(module_type=nn.Linear, net=G):
    selu_init(fc)
g_optim = optim.Adam(G.parameters(), lr=1e-4)
# Keyword arguments, as in the other get_modules_of_type calls, so the call
# does not depend on the helper's positional parameter order.
for fc in get_modules_of_type(module_type=nn.Linear, net=mi_classifier):
    selu_init(fc)
mi_opt = optim.Adam(mi_classifier.parameters(), lr=3e-4)

In [None]:
# Rebuild the classifier and restore the weights saved under "MLPclass_tgt"
# (third item of the five returned by load_checkpoint).
tgt_classifier = MLP_Discriminator(
    insize=2,
    outsize=2,
    depth=4,
    width=16,
    activation=SELU,
).cuda()
(_, _, tgt_sd, _, _) = load_checkpoint(model_name="MLPclass_tgt")
tgt_classifier.load_state_dict(tgt_sd)

In [None]:
# Train the MI classifier on classifier outputs for target-mode vs aux-mode samples.
train_mi(
    net=mi_classifier,
    opt=mi_opt,
    tgt=tgt_classifier,
    tmus=[tmu1, tmu2],
    amus=[amu1, amu2],
)

In [None]:
# Adversarial MI training against samples from the two target modes.
# NOTE(review): `sampler` is defined in a later cell of this notebook; that cell
# has to be executed before this one.
atrain_mi(
    tgt=tgt_classifier,
    mic=mi_classifier,
    m_opt=mi_opt,
    G=G,
    g_opt=g_optim,
    rsampler=sampler([tmu1, tmu2], [sigma, sigma], 512),
    model_name='MIadv16',
    iterations=5000,
)

In [None]:
# Fresh generator / critic pair for the plain GAN experiment.
G = MLP_Generator(insize=2, outsize=2, depth=10, width=16, activation=SELU, bs=512).cuda()
D = MLP_Discriminator(insize=2, outsize=1, depth=4, width=16, activation=SELU).cuda()

# Run selu_init over every nn.Linear module of G, then of D.
for network in (G, D):
    for fc in get_modules_of_type(module_type=nn.Linear, net=network):
        selu_init(fc)

# Adam for both networks; the generator uses lr=1e-4, the critic Adam's default.
g_optim = optim.Adam(G.parameters(), lr=1e-4)
d_optim = optim.Adam(D.parameters())

In [None]:
# Push the fixed input samples through the current generator and save a density plot.
out = (
    G(torch.FloatTensor(inputs).cuda())
    .detach()
    .cpu()
    .numpy()
)
kde(out[:, 0], out[:, 1], save_file='./expt_results/WMLPG_gp_init1.png')

In [None]:
class sampler:
    """Callable batch source for a mixture of normals.

    Calling an instance draws ``bs // len(mus)`` samples from each component
    with ``sample_dist``, stacks them in component order and returns the result
    as a ``torch.FloatTensor`` moved to CUDA.
    """

    def __init__(self, mus, sigs, bs):
        # One standard deviation entry is required per mean.
        assert len(mus) == len(sigs)
        self.mus, self.sigs = mus, sigs
        self.bs = bs
        self.ndist = len(mus)

    def __call__(self):
        parts = []
        for k in range(self.ndist):
            parts.append(sample_dist(self.mus[k], self.sigs[k], batch_size=self.bs, n=self.ndist))
        return torch.FloatTensor(np.vstack(parts)).cuda()
    
class MiganToy:
    """Bundle of the networks, samplers and optimizers for the toy MI-GAN experiment.

    Args:
        tgt: target model (stored as ``self.tgt``).
        mic: MI critic (stored as ``self.mic``).
        iloader: zero-argument callable returning generator inputs.
        aloader: zero-argument callable returning a batch of real samples.
        G: generator; when ``None`` a default ``MLP_Generator`` is built and
            ``selu_init`` is applied to its linear layers.
        D: critic; when ``None`` a default ``MLP_Discriminator`` is built and
            initialised the same way.
        g_opt: optimizer for ``G``; defaults to ``optim.Adam`` over its parameters.
        d_opt: optimizer for ``D``; defaults to ``optim.Adam`` over its parameters.
        model_name: stored as ``self.model_name``.
    """

    def __init__(self, tgt, mic, iloader, aloader, G=None, D=None, g_opt = None, d_opt = None, model_name="Migan_toy"):
        self.tgt = tgt
        self.mic = mic
        self.iloader = iloader
        self.aloader = aloader
        self.model_name = model_name
        # Compare against None explicitly: a module that defines __len__ (e.g. an
        # empty nn.Sequential) is falsy and would otherwise be silently replaced.
        if G is None:
            G = MLP_Generator(depth=4, width=16, activation=SELU,bs=512,insize=2,outsize=2).cuda()
            for fc in get_modules_of_type(module_type=nn.Linear, net=G):
                selu_init(fc)
        if D is None:
            D = MLP_Discriminator(depth=4, width=16, activation=SELU,insize=2,outsize=1).cuda()
            for fc in get_modules_of_type(module_type=nn.Linear, net=D):
                selu_init(fc)
        self.G = G
        self.D = D
        self.g_opt = g_opt if g_opt is not None else optim.Adam(self.G.parameters())
        self.d_opt = d_opt if d_opt is not None else optim.Adam(self.D.parameters())

    def get_migan_loss(self):
        """Compute critic and generator losses for one freshly sampled batch.

        Returns:
            ``(L_D, L_G)`` where ``L_D`` is
            ``D(real).mean() - D(fake).mean() + gradient penalty`` (fake detached)
            and ``L_G`` is ``D(fake).mean()``. ``self.tgt`` and ``self.mic`` are
            not used by these losses.
        """
        real, fake = self.aloader(), self.G(self.iloader())
        # Use this instance's critic and the sampled real batch, not notebook globals.
        L_D = (self.D(real).mean() - self.D(fake.detach()).mean()
               + calc_gradient_penalty(self.D, real, fake.detach()))
        L_G = self.D(fake).mean()
        return L_D, L_G
        
        
def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA=10, use_cuda=True, BATCH_SIZE=512):
    """Gradient penalty on random interpolates between ``real_data`` and ``fake_data``.

    For each row a coefficient ``alpha ~ U[0, 1)`` is drawn, the point
    ``alpha * real + (1 - alpha) * fake`` is fed to ``netD``, and the squared
    deviation of the input-gradient's L2 norm from 1 is averaged over the batch
    and scaled by ``LAMBDA``.

    Args:
        netD: callable mapping a batch to critic scores.
        real_data: tensor of shape ``(batch, ...)``.
        fake_data: tensor with the same shape as ``real_data``.
        LAMBDA: penalty weight.
        use_cuda: kept for backward compatibility; the computation now runs on
            whatever device ``real_data`` lives on.
        BATCH_SIZE: kept for backward compatibility; the batch size is now read
            from ``real_data``.

    Returns:
        Scalar tensor, differentiable w.r.t. ``netD``'s parameters.
    """
    batch = real_data.size(0)
    # Draw on CPU then match real_data's dtype/device; one coefficient per row,
    # broadcast over every remaining dimension (no fixed feature width).
    alpha = torch.rand(batch, 1).to(real_data)
    alpha = alpha.view(batch, *([1] * (real_data.dim() - 1)))

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    # Cut the interpolates from any upstream graph and make them a leaf we can
    # differentiate with respect to.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # create_graph=True so the penalty itself can be back-propagated into netD.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]
    gradients = gradients.reshape(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

def train_gan(G, D, g_optim, d_optim, sampler_fn, model_name, batch_size=512, iter_offset=0, iterations=100000, clip=0.1):
    """Train generator ``G`` against critic ``D`` with a gradient-penalty critic loss.

    Per outer iteration the critic is updated ``niters`` times on
    ``D(real).mean() - D(fake).mean() + gradient penalty``, then the generator
    takes one step minimising ``D(fake).mean()`` using the last fake batch from
    the critic loop. Generator inputs are drawn from a normal with mean
    ``[0, 0]`` and scale ``[0.4, 0.4]`` and moved to CUDA.

    Args:
        G: generator, called on a ``(batch_size, 2)`` input tensor.
        D: critic.
        g_optim: optimizer for ``G``.
        d_optim: optimizer for ``D``.
        sampler_fn: zero-argument callable returning a batch of real samples;
            it must return ``batch_size`` rows.
        model_name: prefix for checkpoints and saved plots.
        batch_size: number of generator inputs per step.
        iter_offset: added to the iteration number in saved plot file names.
        iterations: number of generator steps.
        clip: unused (no weight clipping is performed); kept for compatibility.
    """
    for iteration in range(iterations):
        # 100 critic steps during the first 25 iterations and on every 500th one, 5 otherwise.
        niters = 100 if iteration < 25 or iteration % 500 == 0 else 5
        for _ in range(niters):
            inputs = torch.FloatTensor(np.random.normal(loc=[0,0], scale=[0.4,0.4], size=[batch_size, 2])).cuda()
            tgt = sampler_fn()
            fake = G(inputs)
            # Forward batch_size so the penalty is sized to this call's batch
            # instead of calc_gradient_penalty's own default.
            L_D = (D(tgt).mean() - D(fake.detach()).mean()
                   + calc_gradient_penalty(D, tgt, fake.detach(), BATCH_SIZE=batch_size))
            d_optim.zero_grad()
            L_D.backward()
            d_optim.step()
        # Generator step on the most recent fake batch (still attached to G's graph).
        g_optim.zero_grad()
        L_G = D(fake).mean()
        L_G.backward()
        g_optim.step()

        print('[%d/%d] Loss D: %f Loss G: %f'%(iteration, iterations, L_D, L_G))
        if iteration % 1000 == 999:
            save_checkpoint(epoch=0, iters=iteration, net=G, optim=g_optim, model_name=model_name+'_G')
            save_checkpoint(epoch=0, iters=iteration, net=D, optim=d_optim, model_name=model_name+'_D')
            snapshot = fake.detach().cpu().numpy()
            kde(snapshot[:,0], snapshot[:,1], show=False,
                save_file="{0}_{1}.png".format(model_name, iteration+iter_offset))

In [None]:
# Train the plain GAN on samples from the two aux modes.
train_gan(
    G=G,
    D=D,
    g_optim=g_optim,
    d_optim=d_optim,
    sampler_fn=sampler([amu1, amu2], [[0.4,0.4]]*2, 512),
    model_name='WMLP_toy_aux',
    iterations=10000,
)

In [None]:
# Push the fixed input samples through the trained generator and show a density plot.
out2 = (
    G(torch.FloatTensor(inputs).cuda())
    .detach()
    .cpu()
    .numpy()
)
kde(out2[:, 0], out2[:, 1])

In [None]:
# Build a width-32 critic and restore weights from the 'MIadv32_Mic' checkpoint.
# NOTE(review): no cell visible here writes a checkpoint under that name (the
# adversarial run above uses model_name='MIadv16') -- confirm it exists on disk.
mic = MLP_Discriminator(
    insize=2,
    outsize=1,
    depth=4,
    width=32,
    activation=SELU,
).cuda()
print(mic)
quick_load_model(mic, 'MIadv32_Mic')

In [None]:
# Print the MI classifier's raw outputs on 32 samples around three pairs of centres:
# the target modes, far-away points on the target diagonal, and far-away points
# on the other diagonal.
for centers in ([tmu1, tmu2], [[-40, 40], [40, -40]], [[-40, -40], [40, 40]]):
    probe_batch = sampler(centers, [sigma, sigma], 32)()
    print(mi_classifier(tgt_classifier(probe_batch)))

In [None]:
# Fresh, shallower generator for the direct "maximise MI-classifier output" experiment.
G = MLP_Generator(depth=4, width=16, activation=SELU,bs=512,insize=2,outsize=2).cuda()
# Keyword arguments, as in the other get_modules_of_type calls, so the call
# does not depend on the helper's positional parameter order.
for fc in get_modules_of_type(module_type=nn.Linear, net=G):
    selu_init(fc)
g_opt = optim.Adam(G.parameters())

In [None]:
# Train G directly to maximise sigmoid(mi_classifier(tgt_classifier(G(noise)))).
# Only g_opt is stepped, so the two classifiers are not updated here.
iterations = 4000
for iteration in range(iterations):
    g_opt.zero_grad()
    # Use a dedicated name: the notebook-level `inputs` NumPy array is still
    # needed by the plotting cells and must not be replaced by a CUDA tensor.
    noise = torch.FloatTensor(np.random.normal(loc=[0,0], scale=[0.4,0.4], size=[batch_size, 2])).cuda()
    outputs = G(noise)
    loss = -torch.sigmoid(mi_classifier(tgt_classifier(outputs))).mean()
    loss.backward()
    g_opt.step()
    print('[%d/%d] Loss: %f'%(iteration, iterations, loss.item()))
    if iteration%1000 == 999:
        save_checkpoint(0, iteration, G, g_opt, model_name='migansimple')
        samples = outputs.detach().cpu().numpy()
        # Per-column extremes computed once; bbox order is [xmin, xmax, ymin, ymax].
        col_min, col_max = samples.min(axis=0), samples.max(axis=0)
        kde(samples[:,0], samples[:,1],
            bbox=[col_min[0], col_max[0], col_min[1], col_max[1]],
            save_file="migansimple_{0}.png".format(iteration))
        

In [None]:
# Scratch check of np.amax / np.amin axis handling on a small 7x2 integer matrix
# whose row i is [i*i, (i+1)*i].
steps = np.arange(7).reshape(-1, 1)
mat = (np.arange(2) + steps) * steps
print(mat)
print(mat[:,0])
print(np.amax(mat[:,0], axis=0))
print(np.amax(mat[:,0], axis=0))
print(np.amin(mat[:,0], axis=0))
# A 1-D column slice has no axis 1; the row-wise minimum has to be taken on the 2-D matrix.
print(np.amin(mat, axis=1))