In [None]:
from google.colab import drive
drive.mount("/content/drive")

In [None]:
%config InlineBackend.figure_format = 'retina'
import numpy as np 
import matplotlib.pyplot as plt
import torch
from torch import nn
from tqdm import tqdm
import sys
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import DataLoader, Dataset

In [None]:
import pickle
values_train = pickle.load(open("values_train.p", "rb"))
jacobians_train = pickle.load(open("jacobians_train.p", "rb"))

In [None]:
device = torch.device('cpu')
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = torch.device('cuda')

Tesla V100-SXM2-16GB


In [None]:
"""
Hyper-params
"""
noise_dim = 40 # noise vector dimension
hidden_dim = 256 # hidden dimension for both LSTM
epochs = 3001
lr = 1e-4
batch_size = 8
frame_num = 16
dataset_frame_num = 30
motion_dim = 60
n_critic = 5
clip_value = .01
lamb = 10 # gradient penalty lambda hyperparameter

In [None]:
class Motion(Dataset):
    def __init__(self, values, jacobians, jmax=None, jmin=None, random=True):
        if jmax == None:
            self.jmax = np.max([torch.max(j) for j in jacobians])
        else:
            self.jmax = jmax
        if jmin == None:
            self.jmin = np.min([torch.min(j) for j in jacobians])
        else:
            self.jmin = jmin
        self.data  = []
        for i, j in enumerate(jacobians):
            j = (j - self.jmin) / (self.jmax - self.jmin) * 2 - 1
            j = j.view(len(j), -1)
            v = values[i]
            v = v.view(len(v), -1)
            self.data.append(torch.cat((v, j), 1))
        self.random = random
            
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.random:
            # Draw random frame_num continuous frames
            start = np.random.randint(dataset_frame_num - frame_num) 
            end = start + frame_num
            return self.data[idx][start:end]
        else:
            return self.data[idx][:frame_num]

In [None]:
class GRUGenerator(nn.Module):
    """
    A Bi-GRU based generator. 
    It expects a sequence of noise vectors + start frame + final frame as input
    Args:
        in_dim: noise / e(start) / e(final) dimensionality
        out_dim: output dimensionality
        n_layers: number of gru layers
        hidden_dim: dimensionality of the hidden layer of grus
    Input: noise of shape (batch_size, seq_len, in_dim)
    Output: sequences of shape (batch_size, seq_len, out_dim)
    """

    def __init__(self, in_dim=noise_dim, out_dim=motion_dim, n_layers=1, hidden_dim=hidden_dim):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.motion_encoder = nn.Sequential(nn.Linear(out_dim, in_dim), nn.ReLU())
        self.gru = nn.GRU(in_dim * 3, hidden_dim, n_layers, batch_first=True, bidirectional=True)
        self.linear = nn.Sequential(nn.Linear(hidden_dim * 2, out_dim), nn.Tanh())

    def forward(self, noise, start, final):
        batch_size, seq_len = noise.size(0), noise.size(1)
        h_0 = torch.zeros(self.n_layers * 2, batch_size, self.hidden_dim).to(device)
        
        # Encode first and final images
        e_start = self.motion_encoder(start).unsqueeze(1).repeat(1, frame_num, 1)
        e_final = self.motion_encoder(final).unsqueeze(1).repeat(1, frame_num, 1)
        input = torch.cat((noise, e_start, e_final), -1)

        out, _ = self.gru(input, h_0)
        return self.linear(out)

In [None]:
class GRUDiscriminator(nn.Module):
    """
    A Bi-GRU based discriminator. 
    It expects a real/fake as input and outputs a probability for each timestep.
    Args:
        in_dim: input sequence dimensionality
        n_layers: number of gru layers
        hidden_dim: dimensionality of the hidden layer of grus
        
    Inputs: sequences of shape (batch_size, seq_len, out_dim)
    Output: prob sequence of shape (batch_size, seq_len, 1)
    """
    def __init__(self, in_dim=motion_dim, n_layers=1, hidden_dim=hidden_dim):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.gru = nn.GRU(in_dim, hidden_dim, n_layers, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(hidden_dim * 2, 1) # WGAN update 1: remove sigmoid

    def forward(self, seqs):
        batch_size = len(seqs)
        h_0 = torch.zeros(self.n_layers * 2, batch_size, self.hidden_dim).to(device)
        out, _ = self.gru(seqs, h_0)
        return self.linear(out)

In [None]:
# Setup network
netD = GRUDiscriminator().to(device)
print(netD)
netG = GRUGenerator().to(device)
print(netG)

# Setup optimizer
optimizerD = optim.RMSprop(netD.parameters(), lr=lr) # WGAN update 4: don't use momentum
optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

# Setup dataloader
train = Motion(values_train, jacobians_train)
print(train.jmax, train.jmin)
dataloader = DataLoader(train, batch_size, shuffle=True)

GRUDiscriminator(
  (gru): GRU(60, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=1, bias=True)
)
GRUGenerator(
  (motion_encoder): Sequential(
    (0): Linear(in_features=60, out_features=40, bias=True)
    (1): ReLU()
  )
  (gru): GRU(120, 256, batch_first=True, bidirectional=True)
  (linear): Sequential(
    (0): Linear(in_features=512, out_features=60, bias=True)
    (1): Tanh()
  )
)
10.143672 -10.052552


In [None]:
def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(batch_size, frame_num, 1)
    alpha = alpha.expand(real_data.size()).to(device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True)[0]
    gradient_penalty = ((gradients.reshape(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() * lamb
    return gradient_penalty

In [None]:
Dcosts, Gcosts = [], []

for epoch in range(epochs):
    Dcosts_, Gcosts_ = [], []

    for it, real in tqdm(enumerate(dataloader)):
        ############################
        # (1) Update D network
        ###########################
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True   # they are set to False below in netG update

        # Train with real data
        netD.zero_grad()
        real = real.to(device, dtype=torch.float)
        Dreal = -torch.mean(netD(real))
        Dreal.backward()

        # Train with fake data
        noise = torch.randn(len(real), frame_num, noise_dim, device=device)
        fake = netG(noise, real[:, 0], real[:, -1]).detach()
        Dfake = torch.mean(netD(fake))
        Dfake.backward()

        # Train with gradient penalty
        # gradient_penalty = calc_gradient_penalty(netD, real, fake)
        # gradient_penalty.backward()
        
        # Clip weights of discriminator
        for p in netD.parameters():
            p.data.clamp_(-clip_value, clip_value)
        
        Dcost = Dfake + Dreal
        optimizerD.step()
        
        ############################
        # (2) Update G network every n_critic iteration
        ###########################
        if it % n_critic == 0:
            for p in netD.parameters():  # to avoid computation
                p.requires_grad = False
            netG.zero_grad()

            noise = torch.randn(len(real), frame_num, noise_dim, device=device)
            fake = netG(noise, real[:, 0], real[:, -1])
            G = -torch.mean(netD(fake))
            G.backward()
            Gcost = G
            optimizerG.step()

        ###########################
        # (3) Report metrics
        ###########################
        Dcosts_.append(Dcost.item())
        Gcosts_.append(Gcost.item())

    ##### End of the epoch #####
    Dcosts.append(np.mean(Dcosts_))
    Gcosts.append(np.mean(Gcosts_))
    
    print('[%d/%d] Loss_D: %.4f   Loss_G: %.4f\n' 
          % (epoch, epochs, Dcosts[-1], Gcosts[-1]), end='', file=sys.stderr)
    
    # Checkpoint
    if epoch % 6000 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optimizerG.state_dict(),
          }, f'drive/MyDrive/save3/netG_epoch_{epoch}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': netD.state_dict(),
            'optimizer_state_dict': optimizerD.state_dict(),
          }, f'drive/MyDrive/save3/netD_epoch_{epoch}.pth')

    with open("drive/MyDrive/save3/Dcosts.txt", "w") as output:
        output.write(str(Dcosts))
    with open("drive/MyDrive/save3/Gcosts.txt", "w") as output:
        output.write(str(Gcosts))

In [None]:
plt.plot(Dcosts)
plt.title('D Loss over time')
plt.xlabel('iter')
plt.ylabel('D Loss')
plt.show()

In [None]:
plt.plot(Gcosts, color='orange')
plt.title('G Loss over time')
plt.xlabel('iter')
plt.ylabel('G Loss')
plt.show()

# Testing

In [None]:
values_test = pickle.load(open("values_test.p", "rb"))
jacobians_test = pickle.load(open("jacobians_test.p", "rb"))

test = Motion(values_test, jacobians_test, train.jmax, train.jmin, random=False)
testloader = DataLoader(test, batch_size, shuffle=False)

In [91]:
noise1 = torch.randn(frame_num, noise_dim, device=device)
noise2 = torch.randn(frame_num, noise_dim, device=device)
noise3 = torch.randn(frame_num, noise_dim, device=device)

In [100]:
netG.eval()

values = torch.zeros((0, frame_num, 10, 2)).to(device)
jacobians = torch.zeros((0, frame_num, 10, 2, 2)).to(device)

for batch in testloader:
    batch = batch.to(device)
    noise = noise3.repeat((len(batch), 1, 1))
    fake = netG(noise, batch[:, 0], batch[:, -1])
    fake[:, 0]  = batch[:, 0]
    fake[:, -1] = batch[:, -1]
    value = fake[..., :20].view(len(batch), frame_num, 10, 2)
    jacobian = fake[..., 20:].view(len(batch), frame_num, 10, 2, 2)
    jacobian = (jacobian + 1) / 2 * (train.jmax - train.jmin) + train.jmin
    
    values = torch.cat((values, value), 0)
    jacobians = torch.cat((jacobians, jacobian), 0)

k = 3
with open(f'GAN_values_{k}.npy', 'wb') as f:
    np.save(f, values.cpu().detach().numpy())
with open(f'GAN_jacobians_{k}.npy', 'wb') as f:
    np.save(f, jacobians.cpu().detach().numpy())

In [101]:
values[0, 4]

tensor([[-0.0643, -0.4482],
        [-0.4222,  0.5108],
        [ 0.3311, -0.5475],
        [-0.3349,  0.3354],
        [-0.2248,  0.0074],
        [-0.5037,  0.2340],
        [-0.0560,  0.8224],
        [-0.6673, -0.3291],
        [-0.4919,  0.3791],
        [-0.0518,  0.3584]], device='cuda:0', grad_fn=<SelectBackward>)

In [83]:
# Load WGAN
start_epoch = 48000

checkpoint = torch.load(f'drive/MyDrive/save2/netG_epoch_{start_epoch}.pth')
netG.load_state_dict(checkpoint['model_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizer_state_dict'])

checkpoint = torch.load(f'drive/MyDrive/save2/netD_epoch_{start_epoch}.pth')
netD.load_state_dict(checkpoint['model_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])

errDss = open('drive/MyDrive/save2/Dcosts.txt', 'r')
errDss = errDss.readlines()[0][1:-1].split(',')
errDss = [float(s) for s in errDss[:start_epoch]]

errGss = open('drive/MyDrive/save2/Gcosts.txt', 'r')
errGss = errGss.readlines()[0][1:-1].split(',')
errGss = [float(s) for s in errGss[:start_epoch]]

print(netG.train())
print(netD.train())

GRUGenerator(
  (motion_encoder): Sequential(
    (0): Linear(in_features=60, out_features=40, bias=True)
    (1): ReLU()
  )
  (gru): GRU(120, 256, batch_first=True, bidirectional=True)
  (linear): Sequential(
    (0): Linear(in_features=512, out_features=60, bias=True)
    (1): Tanh()
  )
)
GRUDiscriminator(
  (gru): GRU(60, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=1, bias=True)
)
