# C-RNN-GAN
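Paper: Olof Mogren, "C-RNN-GAN: Continuous recurrent neural networks with adversarial training" (2016) — http://mogren.one/publications/2016/c-rnn-gan/mogren2016crnngan.pdf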

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
# `is_cuda` records the outcome; `device` is the global target that later cells
# move models and tensors to.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")

GPU not available, CPU used


In [3]:
class Generator(nn.Module):
    """Recurrent generator of C-RNN-GAN.

    At every time step the input noise vector is concatenated with the
    generator's own output from the previous step, projected by ``fc1`` + ReLU,
    passed through two stacked ``LSTMCell`` layers (with dropout between them),
    and mapped back to ``features`` dimensions by ``fc2``.

    Args:
        features: number of features per time step, both in the noise input
            ``z`` and in the generated output.
        hidden_size: width of ``fc1`` and of both LSTM cells.
    """

    def __init__(self, features, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.features = features

        # Each step's input is [noise, previous output], hence 2 * features.
        self.fc1 = nn.Linear(in_features=(features * 2), out_features=hidden_size)
        self.lstm1 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.dropout = nn.Dropout(p=0.6)
        self.lstm2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=features)

    def forward(self, z, states):
        """Generate a sequence one time step at a time.

        Args:
            z: noise tensor of shape (batch_size, seq_len, features). It is
                moved/cast to the device and dtype of this module's parameters.
            states: ((h1, c1), (h2, c2)) initial states of the two LSTM cells,
                each of shape (batch_size, hidden_size), e.g. from
                ``init_hidden``.

        Returns:
            gen_feats: generated tensor of shape (batch_size, seq_len, features).
            states: the final ((h1, c1), (h2, c2)) LSTM-cell states.
        """
        weight = self.fc1.weight
        z = z.to(device=weight.device, dtype=weight.dtype)
        batch_size, seq_len, num_feats = z.shape

        # There is no "previous output" before the first step, so seed it with
        # uniform noise in [0, 1), created directly on the parameters' device.
        prev_gen = torch.rand(batch_size, num_feats,
                              device=weight.device, dtype=weight.dtype)

        state1, state2 = states
        gen_feats = []
        # unbind(dim=1) yields seq_len tensors of shape (batch_size, num_feats).
        for z_step in torch.unbind(z, dim=1):
            concat_in = torch.cat((z_step, prev_gen), dim=-1)
            out = F.relu(self.fc1(concat_in))
            h1, c1 = self.lstm1(out, state1)
            # Dropout only on the connection into the second layer; lstm1's
            # recurrent state keeps the un-dropped h1 (the same placement
            # nn.LSTM uses for its inter-layer dropout).
            h2, c2 = self.lstm2(self.dropout(h1), state2)
            prev_gen = self.fc2(h2)
            gen_feats.append(prev_gen)
            state1 = (h1, c1)
            state2 = (h2, c2)

        if not gen_feats:
            # seq_len == 0: torch.stack would reject an empty list.
            return z.new_empty(batch_size, 0, self.features), (state1, state2)

        # seq_len x (batch_size, num_feats) -> (batch_size, seq_len, num_feats)
        gen_feats = torch.stack(gen_feats, dim=1)

        return gen_feats, (state1, state2)

    def init_hidden(self, batch_size):
        """Return zero initial states ((h1, c1), (h2, c2)) for both LSTM cells.

        Each tensor has shape (batch_size, hidden_size) and the same device and
        dtype as this module's parameters; none of them requires grad.
        """
        weight = next(self.parameters())
        return tuple(
            (weight.new_zeros(batch_size, self.hidden_size),
             weight.new_zeros(batch_size, self.hidden_size))
            for _ in range(2)
        )

In [16]:
# Build both networks for 13 features per time step and 100 hidden units, on `device`.
# NOTE(review): `Discriminator` is defined in a later cell of this notebook; on a
# fresh top-to-bottom run the second line raises NameError unless that cell ran first.
gmodel = Generator(features=13, hidden_size=100).to(device)
dmodel = Discriminator(features=13, hidden_size=100).to(device)

In [22]:
# One sequence per batch; fetch the zero-initialised recurrent states of both models.
BATCH_SIZE = 1
g_states = gmodel.init_hidden(batch_size=BATCH_SIZE)
d_state = dmodel.init_hidden(batch_size=BATCH_SIZE)

In [18]:
# Deterministic dummy input: the values 0..51 laid out as (batch=1, seq_len=4, features=13).
input_seq = torch.arange(13 * 4, dtype=torch.get_default_dtype()).reshape(1, 4, 13)
input_seq = input_seq.to(device)

In [19]:
# Put the generator in eval mode (turns off its dropout) and run it once on the dummy input;
# the returned final LSTM states are discarded.
gmodel.eval()
g_feats, _ = gmodel(z=input_seq, states=g_states)

In [20]:
class Discriminator(nn.Module):
    """Bidirectional-LSTM discriminator of C-RNN-GAN.

    Applies dropout to the input sequence, runs a 2-layer bidirectional,
    batch-first ``nn.LSTM``, scores every time step with a linear layer plus
    sigmoid, and averages those per-step scores into one value per sequence.

    Args:
        features: number of features per time step of the input sequences.
        hidden_size: hidden size of each LSTM direction.
    """

    def __init__(self, features, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = 2
        self.dropout = nn.Dropout(p=.5)
        self.lstm = nn.LSTM(input_size=features, hidden_size=hidden_size,
                            num_layers=self.num_layers, batch_first=True, dropout=0.5,
                            bidirectional=True)

        # Forward and backward LSTM outputs are concatenated: 2 * hidden_size.
        self.fc = nn.Linear(in_features=(2 * hidden_size), out_features=1)

    def forward(self, sequence, state):
        """Score a batch of sequences.

        Args:
            sequence: tensor of shape (batch_size, seq_len, features). It is
                moved/cast to the device and dtype of this module's parameters.
            state: (h0, c0) initial LSTM state, each of shape
                (num_layers * num_directions, batch_size, hidden_size), e.g.
                from ``init_hidden``.

        Returns:
            out: mean of the per-step sigmoid scores over every non-batch
                dimension, shape (batch_size,) for batched input.
            lstm_out: LSTM outputs, shape (batch_size, seq_len, 2 * hidden_size).
            state: final (h_n, c_n) of the LSTM.
        """
        weight = self.fc.weight
        sequence = sequence.to(device=weight.device, dtype=weight.dtype)
        drop_in = self.dropout(sequence)

        lstm_out, state = self.lstm(drop_in, state)
        out = torch.sigmoid(self.fc(lstm_out))

        # Average over every dimension except the first (batch) dimension.
        out = out.mean(dim=tuple(range(1, out.dim())))

        return out, lstm_out, state

    def init_hidden(self, batch_size):
        """Return a zero (h0, c0) state for the LSTM.

        Each tensor has shape (num_layers * num_directions, batch_size,
        hidden_size) and the same device and dtype as this module's parameters.
        """
        weight = next(self.parameters())
        num_directions = 2 if self.lstm.bidirectional else 1
        shape = (self.num_layers * num_directions, batch_size, self.hidden_size)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [25]:
# Score the generated sequence with the discriminator, keeping only the averaged score.
# NOTE(review): despite the name, these values have already passed through a sigmoid
# inside Discriminator.forward, so they are probabilities rather than logits.
d_logits_gen, _, _ = dmodel(sequence=g_feats, state=d_state)