In [32]:
"""
Building blocks for a WGAN with gradient penalty (WGAN-GP): residual block, generator, and critic.
"""
import pytorch_lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F


# Construct residual block
class ResBlock(nn.Module):
    """1-D residual block computing ``x + res_rate * relu(conv2(relu(conv1(x))))``.

    Both convolutions keep the channel count and use ``padding='same'``, so the
    output has the same shape as the input.

    Args:
        num_channels: number of input and output channels of both convolutions.
        res_rate: scale applied to the residual branch before it is added to ``x``.
        kernel_size: width of both convolutions (default 5, the value that was
            previously hard-coded).
    """

    def __init__(self, num_channels, res_rate=0.3, kernel_size=5):
        super(ResBlock, self).__init__()
        self.num_channels = num_channels
        self.res_rate = res_rate
        self.kernel_size = kernel_size

        self.conv_block1 = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding='same'),
            # nn.BatchNorm1d(num_channels),
            nn.ReLU()
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv1d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding='same'),
            # nn.BatchNorm1d(num_channels),
            nn.ReLU()
        )

        self.init_weights()

    def init_weights(self):
        """Initialize every Conv1d: weights ~ U(-b, b) with b = sqrt(12 / (fan_in + fan_out)), bias = 0.

        Fans are read from the weight tensor (out_ch, in_ch, width), so the
        bound stays correct for any kernel size or channel count.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                out_ch, in_ch, width = m.weight.shape
                fan_in, fan_out = in_ch * width, out_ch * width
                # Target variance 4 / (fan_in + fan_out); U(-b, b) has variance b^2 / 3,
                # hence b = sqrt(3) * sqrt(4 / (fan_in + fan_out)).
                bound = (12.0 / (fan_in + fan_out)) ** 0.5
                nn.init.uniform_(m.weight, -bound, bound)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``x + res_rate * branch(x)``; the output shape equals the input shape."""
        res = self.conv_block1(x)
        res = self.conv_block2(res)
        # NOTE(review): the branch ends in ReLU, so ``res`` is always >= 0. Many
        # WGAN-GP sequence ResBlocks apply ReLU *before* each conv instead;
        # confirm this post-activation ordering is intended.
        return x + self.res_rate * res


# Construct generator with res blocks
class Generator(nn.Module):
    """Map latent vectors to per-position probability distributions over a vocabulary.

    Pipeline:
      1. Linear(latent_dim -> seq_len * num_channels),
      2. reshape to (batch, num_channels, seq_len),
      3. ``res_layers`` residual blocks,
      4. 1x1 convolution to ``vocab_size`` channels,
      5. softmax over the vocabulary axis (dim 1).

    Args:
        latent_dim: size of each input latent vector.
        num_channels: channel width of the residual stack.
        seq_len: length of the generated sequence.
        res_rate: scale applied to each residual branch.
        vocab_size: number of symbols per position (default 4).
        res_layers: number of residual blocks (default 5).
    """

    def __init__(self, latent_dim, num_channels, seq_len, res_rate, vocab_size=4, res_layers=5):
        super(Generator, self).__init__()
        self.seq_len = seq_len
        self.num_channels = num_channels

        flat_width = seq_len * num_channels
        self.linear = nn.Linear(latent_dim, flat_width)
        self.res_blocks = nn.ModuleList(
            ResBlock(num_channels, res_rate) for _ in range(res_layers)
        )
        self.conv = nn.Conv1d(num_channels, vocab_size, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        """Return probabilities of shape (batch, vocab_size, seq_len) that sum to 1 over dim 1."""
        hidden = self.linear(x).view(-1, self.num_channels, self.seq_len)
        for block in self.res_blocks:
            hidden = block(hidden)
        logits = self.conv(hidden)
        return logits.softmax(dim=1)
        
    
# Construct the critic (discriminator)
class Critic(nn.Module):
    """WGAN critic that assigns one real-valued score per input sequence.

    Pipeline:
      1. 1x1 convolution (vocab_size -> num_channels),
      2. ``res_layers`` residual blocks,
      3. per-sample flatten,
      4. Linear(seq_len * num_channels -> 1).

    Args:
        num_channels: channel width of the residual stack.
        seq_len: sequence length the critic is built for; inputs must match it.
        vocab_size: number of input channels (symbols per position).
        res_rate: scale applied to each residual branch (default 0.3).
        res_layers: number of residual blocks (default 5).
    """

    def __init__(self, num_channels, seq_len, vocab_size, res_rate=0.3, res_layers=5):
        super(Critic, self).__init__()
        self.seq_len = seq_len
        self.num_channels = num_channels

        # Initial 1x1 convolution lifting vocab_size channels to num_channels
        self.conv1 = nn.Conv1d(vocab_size, num_channels, kernel_size=1, stride=1, padding='same')

        # Residual blocks
        self.res_blocks = nn.ModuleList([ResBlock(num_channels, res_rate) for _ in range(res_layers)])

        # Final linear layer mapping the flattened features to a scalar score
        self.fc = nn.Linear(seq_len * num_channels, 1)

        self.init_weights()

    def init_weights(self):
        """Glorot-uniform init for Linear layers (zero bias); weight=1 / bias=0 for BatchNorm1d.

        Conv1d layers are not touched here, so ``conv1`` keeps PyTorch's default
        Conv1d initialization. This class builds no BatchNorm1d layer itself;
        that branch only fires if a submodule provides one.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # U(-a, a) with a = sqrt(6 / (fan_in + fan_out)), identical to the
                # previous sqrt(3) * sqrt(2 / (out + in)) bound.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: tensor of shape (batch, vocab_size, seq_len), or an unbatched
               (vocab_size, seq_len) tensor, which is treated as a batch of one.

        Returns:
            Tensor of shape (batch, 1) with one critic score per sample.

        Raises:
            ValueError: if the sequence length differs from ``self.seq_len``.
        """
        if x.dim() == 2:
            # Unbatched input (accepted by nn.Conv1d); score it as a batch of one.
            x = x.unsqueeze(0)
        if x.size(-1) != self.seq_len:
            # Without this check the flatten below could silently mix samples.
            raise ValueError(
                f"Critic expects sequences of length {self.seq_len}, got {x.size(-1)}"
            )

        x = self.conv1(x)

        for res_block in self.res_blocks:
            x = res_block(x)

        # Flatten per sample, keeping the batch axis intact.
        x = x.flatten(start_dim=1)

        return self.fc(x)


In [36]:
# Smoke-test the critic on random input: batch of 32, 4 input channels, length 10.
discriminator = Critic(vocab_size=4, seq_len=10, num_channels=100, res_layers=5)
input_tensor = torch.randn(32, 4, 10)
output_scores = discriminator(input_tensor)  # one score per sample: (32, 1)
a = output_scores.flatten()                  # drop the trailing axis: (32,)

In [38]:
a.size()  # expect torch.Size([32])

torch.Size([32])

In [75]:
# Run a freshly initialised generator on 32 random latent vectors.
generator = Generator(
    latent_dim=100,
    num_channels=100,
    seq_len=10,
    res_rate=0.3,
    vocab_size=4,
    res_layers=5,
)
latent_vector = torch.randn(32, 100)  # (batch, latent_dim)
output = generator(latent_vector)     # (batch, vocab_size, seq_len), softmax over vocab
output[0]                             # per-position probabilities of the first sample

tensor([[0.2502, 0.2420, 0.2358, 0.1445, 0.0912, 0.2917, 0.0784, 0.2197, 0.4026,
         0.2060],
        [0.1809, 0.1357, 0.1790, 0.0696, 0.2154, 0.1293, 0.1503, 0.2322, 0.1444,
         0.2056],
        [0.1902, 0.1041, 0.2549, 0.1560, 0.0709, 0.0888, 0.0871, 0.2930, 0.1521,
         0.3276],
        [0.3786, 0.5182, 0.3303, 0.6299, 0.6225, 0.4902, 0.6843, 0.2551, 0.3009,
         0.2608]], grad_fn=<SelectBackward0>)

In [78]:
from WGAN_Model import WGAN

In [79]:
# Display the module tree of a WGAN built with only the sequence shape overridden.
WGAN(vocab_size=4, seq_len=10)

WGAN(
  (generator): Generator(
    (linear): Linear(in_features=100, out_features=1000, bias=True)
    (res_blocks): ModuleList(
      (0-4): 5 x ResBlock(
        (conv_block1): Sequential(
          (0): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=same)
          (1): ReLU()
        )
        (conv_block2): Sequential(
          (0): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=same)
          (1): ReLU()
        )
      )
    )
    (conv): Conv1d(100, 4, kernel_size=(1,), stride=(1,), padding=same)
  )
  (critic): Critic(
    (conv1): Conv1d(4, 100, kernel_size=(1,), stride=(1,), padding=same)
    (res_blocks): ModuleList(
      (0-4): 5 x ResBlock(
        (conv_block1): Sequential(
          (0): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=same)
          (1): ReLU()
        )
        (conv_block2): Sequential(
          (0): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=same)
          (1): ReLU()
        )
      )
    )
    (fc): Lin