# C-RNN-GAN
http://mogren.one/publications/2016/c-rnn-gan/mogren2016crnngan.pdf

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# https://github.com/cjbayron/c-rnn-gan.pytorch/blob/master/train_simple.py

In [2]:
# torch.cuda.is_available() checks and returns a Boolean True if a GPU is available, else it'll return False
is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU not available, CPU used


In [3]:
class Generator(nn.Module):
    def __init__(self, features, hidden_size):
        super(Generator, self).__init__()
        
        self.hidden_size = hidden_size
        self.features = features
        
        self.fc1 = nn.Linear(in_features=(features*2), out_features=hidden_size)
        self.lstm1 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.dropout = nn.Dropout(p=0.6)
        self.lstm2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=features)
        
    def forward(self, z, states):
        z = z.to(device)
        batch_size, seq_len, num_feats = z.shape
        z = torch.split(z, 1, dim=1)
        z = [z_step.squeeze(dim=1) for z_step in z]
        
        prev_gen = torch.empty([batch_size, num_feats]).uniform_()
        prev_gen = prev_gen.to(device)
        
        state1, state2 = states
        gen_feats = []
        for z_step in z:
            concat_in = torch.cat((z_step, prev_gen), dim=-1)
            out = F.relu(self.fc1(concat_in))
            h1, c1 = self.lstm1(out, state1)
            h1 = self.dropout(h1)
            h2, c2 = self.lstm2(h1, state2)
            prev_gen = self.fc2(h2)
            gen_feats.append(prev_gen)
            state1 = (h1, c1)
            state2 = (h2, c2)
        
        # seq_len * (batch_size * num_feats) -> (batch_size * seq_len * num_feats)
        gen_feats = torch.stack(gen_feats, dim=1)
        
        states = (state1, state2)
        return gen_feats, states

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        
        hidden = ( (weight.new(batch_size, self.hidden_size).zero_().to(device),
                   weight.new(batch_size, self.hidden_size).zero_().to(device)),
                   (weight.new(batch_size, self.hidden_size).zero_().to(device),
                   weight.new(batch_size, self.hidden_size).zero_().to(device)) )

        return hidden

In [20]:
class Discriminator(nn.Module):
    def __init__(self, features, hidden_size):
        super(Discriminator, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = 2
        self.dropout = nn.Dropout(p=.5)
        self.lstm = nn.LSTM(input_size=features, hidden_size=hidden_size,
                           num_layers=self.num_layers, batch_first=True, dropout=0.5,
                           bidirectional=True)
    
        self.fc = nn.Linear(in_features=(2*hidden_size), out_features=1)
        
    def forward(self, sequence, state):
        sequence = sequence.to(device)
        drop_in = self.dropout(sequence)
        
        lstm_out, state = self.lstm(drop_in, state)
        out = self.fc(lstm_out)
        out = torch.sigmoid(out)
        
        num_dims = len(out.shape)
        reduction_dims = tuple(range(1, num_dims))
        out = torch.mean(out, dim=reduction_dims)
        
        return out, lstm_out, state
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        layer_mult = 2
        
        hidden = (weight.new(self.num_layers * layer_mult, batch_size, self.hidden_size).zero_().to(device),
                 weight.new(self.num_layers * layer_mult, batch_size, self.hidden_size).zero_().to(device))
        
        return hidden

class DLoss(nn.Module):
    ''' C-RNN-GAN discriminator loss
    '''
    def __init__(self):
        super(DLoss, self).__init__()

    def forward(self, logits_real, logits_gen):
        ''' Discriminator loss
        logits_real: logits from D, when input is real
        logits_gen: logits from D, when input is from Generator
        '''
        logits_real = torch.clamp(logits_real, EPSILON, 1.0)
        d_loss_real = -torch.log(logits_real)

        logits_gen = torch.clamp((1 - logits_gen), EPSILON, 1.0)
        d_loss_gen = -torch.log(logits_gen)

        batch_loss = d_loss_real + d_loss_gen
        return torch.mean(batch_loss)

def control_grad(model, freeze=True):
    ''' Freeze/unfreeze optimization of model
    '''
    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    else: # unfreeze
        for param in model.parameters():
            param.requires_grad = True
    
def check_loss(model, loss):
    ''' Check loss and control gradients if necessary
    '''
    control_grad(model['g'], freeze=False)
    control_grad(model['d'], freeze=False)

    if loss['d'] == 0.0 and loss['g'] == 0.0:
        print('Both G and D train loss are zero. Exiting.')
        return False
    elif loss['d'] == 0.0: # freeze D
        control_grad(model['d'], freeze=True)
    elif loss['g'] == 0.0: # freeze G
        control_grad(model['g'], freeze=True)
    elif loss['g'] < 2.0 or loss['d'] < 2.0:
        control_grad(model['d'], freeze=True)
        if loss['g']*0.7 > loss['d']:
            control_grad(model['g'], freeze=True)


# Training Examples

Training, trying to output numbers with the function `f(n) = 2*f(n-1)`

In [57]:
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
npdata = np.stack([2 ** np.arange(10)[:, np.newaxis] * np.random.rand() for i in range(4)])

In [58]:
data = TensorDataset(torch.from_numpy(npdata))
data[0]

(tensor([[  0.2647],
         [  0.5293],
         [  1.0587],
         [  2.1174],
         [  4.2348],
         [  8.4695],
         [ 16.9391],
         [ 33.8782],
         [ 67.7564],
         [135.5128]], dtype=torch.float64),)

In [70]:
dataloader = DataLoader(data, shuffle=False)

In [71]:
list(enumerate(dataloader))[3]

(3,
 [tensor([[[0.0064],
           [0.0129],
           [0.0258],
           [0.0515],
           [0.1030],
           [0.2060],
           [0.4120],
           [0.8241],
           [1.6482],
           [3.2964]]], dtype=torch.float64)])

Each element of data loader is a sequence. Thats it We just put this into a function to make a train set and validation set easily

In [75]:
for batch in dataloader:
    print(batch.shape)

AttributeError: 'list' object has no attribute 'shape'

In [76]:
for batch, in dataloader:
    print(batch.shape)

ValueError: not enough values to unpack (expected 2, got 1)

In [24]:
def dummy_dataloader(seq_len, batch_size, num_sample):
    ''' Dummy data generator (for debugging purposes)
    '''
    # the following code generates random data of numbers
    # where each number is twice the prev number
    np_data = np.stack([(2 ** np.arange(seq_len))[:, np.newaxis] \
                        * np.random.rand() for i in range(num_sample)])

    data = TensorDataset(torch.from_numpy(np_data))
    return DataLoader(data, shuffle=True, batch_size=batch_size)

In [25]:
MAX_SEQ_LEN = 20
trn_dataloader = dummy_dataloader(MAX_SEQ_LEN, 13, 7)
val_dataloader = dummy_dataloader(MAX_SEQ_LEN, 13, 4)

In [28]:
model = {
    'g': Generator(features=1, hidden_size=100),
    'd': Discriminator(features=1, hidden_size=100)
}

## Test model sizes

In [38]:
g_states = model['g'].init_hidden(real_batch_sz)
d_state = model['d'].init_hidden(real_batch_sz)
z = torch.empty([100, MAX_SEQ_LEN, 1]).uniform_()
real_batch_sz

1

In [36]:
g_feats, _ = model['g'](z, g_states)

RuntimeError: Input batch size 100 doesn't match hidden[0] batch size 1

In [None]:
G_LRN_RATE = 0.001
D_LRN_RATE = 0.001
PERFORM_LOSS_CHECKING = True
MAX_GRAD_NORM = 5.0
MAX_EPOCHS = 500
L2_DECAY = 1.0

optimizer = {
    'g': optim.Adam(model['g'].parameters(), G_LRN_RATE),
    'd': optim.Adam(model['d'].parameters(), D_LRN_RATE)
}
criterion = {
    'g': nn.MSELoss(reduction='sum'),
    'd': DLoss()
}

NUM_EPOCHS = 100
freeze_g = False
freeze_d = False

In [23]:
for epoch in range(NUM_EPOCHS):
    # TRAINING
    loss = {
        'g': 10.0,
        'd': 10.0
    }
    
    features = model['g'].features
    model['g'].train()
    model['d'].train()

    g_loss_total = 0.0
    d_loss_total = 0.0
    num_corrects = 0
    num_sample = 0

    log_sum_real = 0.0
    log_sum_gen = 0.0
    
    for (batch_input, ) in dataloader:
        real_batch_sz = len(batch_input)
        batch_input = batch_input.type(torch.FloatTensor)
        
        if PERFORM_LOSS_CHECKING == True:
            if not check_loss(model, loss):
                break
        
        g_states = model['g'].init_hidden(real_batch_sz)
        d_state = model['d'].init_hidden(real_batch_sz)
        
        ### GENERATOR ###
        if not freeze_g:
            optimizer['g'].zero_grad()
            
        z = torch.empty([real_batch_sz, MAX_SEQ_LEN, features]).uniform_()
        
        g_feats, _ = model['g'](z, g_states)
        _, d_feats_real, _ = model['d'](batch_input, d_state)
        _, d_feats_gen, _ = model['d'](g_feats, d_state)
        
        loss['g'] = criterion['g'](d_feats_real, d_feats_gen)
        if not freeze_g:
            loss['g'].backward()
            # nn.utils.clip_grad_norm_(model['g'].parameters(), max_norm=MAX_GRAD_NORM)
            optimizer['g'].step()

        #### DISCRIMINATOR ####
        if not freeze_d:
            optimizer['d'].zero_grad()
        
        
        d_logits_real, _, _ = model['d'](batch_input, d_state)
        d_logits_gen, _, _ = model['d'](g_feats.detach(), d_state)
        
        loss['d'] = criterion['d'](d_logits_real, d_logits_gen)
        
        log_sum_real += d_logits_real.sum().item()
        log_sum_gen += d_logits_gen.sum().item()
        
        if not freeze_d:
            loss['d'].backward()
            nn.utils.clip_grad_norm_(model['d'].parameters(), max_norm=MAX_GRAD_NORM)
            optimizer['d'].step()
            
        g_loss_total += loss['g'].item()
        d_loss_total += loss['d'].item()
        num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item()
        num_sample += real_batch_sz
        
    g_loss_avg, d_loss_avg = 0.0, 0.0
    d_acc = 0.0
    if num_sample > 0:
        g_loss_avg = g_loss_total / num_sample
        d_loss_avg = d_loss_total / num_sample
        d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated)

        print("Trn: ", log_sum_real / num_sample, log_sum_gen / num_sample)

    print("[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f", g_loss_avg, d_loss_avg, d_acc)
    

    # VALIDATE
    model['g'].eval()
    model['d'].eval()
    
    g_loss_total = 0.0
    d_loss_total = 0.0
    num_corrects = 0
    num_sample = 0

    log_sum_real = 0.0
    log_sum_gen = 0.0

    for (batch_input, ) in dataloader:

        real_batch_sz = len(batch_input)
        batch_input = batch_input.type(torch.FloatTensor)

        # initial states
        g_states = model['g'].init_hidden(real_batch_sz)
        d_state = model['d'].init_hidden(real_batch_sz)

        #### GENERATOR ####
        # prepare inputs
        z = torch.empty([real_batch_sz, MAX_SEQ_LEN, features]).uniform_() # random vector

        # feed inputs to generator
        g_feats, _ = model['g'](z, g_states)
        # feed real and generated input to discriminator
        d_logits_real, d_feats_real, _ = model['d'](batch_input, d_state)
        d_logits_gen, d_feats_gen, _ = model['d'](g_feats, d_state)
        # print("Val: ", d_logits_real.mean(), d_logits_gen.mean())
        log_sum_real += d_logits_real.sum().item()
        log_sum_gen += d_logits_gen.sum().item()

        # calculate loss
        g_loss = criterion['g'](d_feats_real, d_feats_gen)
        d_loss = criterion['d'](d_logits_real, d_logits_gen)

        g_loss_total += g_loss.item()
        d_loss_total += d_loss.item()
        num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item()
        num_sample += real_batch_sz

    g_loss_avg, d_loss_avg = 0.0, 0.0
    d_acc = 0.0
    if num_sample > 0:
        g_loss_avg = g_loss_total / num_sample
        d_loss_avg = d_loss_total / num_sample
        d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated)

        print("Val: ", log_sum_real / num_sample, log_sum_gen / num_sample)

    

[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f 0.0 0.0 0.0


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 1