# In [None]:
import Ipynb_importer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np

class SpecGANGenerator(nn.Module):
    """Generator that maps a latent vector to a 128x128 multi-channel map.

    A linear layer projects the latent vector to ``256 * model_size``
    features, which are reshaped to ``(16 * model_size, 4, 4)``.  Five
    ReLU + stride-2 transposed-convolution stages then double the spatial
    size each time (4 -> 8 -> 16 -> 32 -> 64 -> 128) while reducing the
    channel count down to ``num_channels``.  A final ``tanh`` bounds the
    output to ``[-1, 1]``.

    Args:
        model_size: Base channel multiplier (``d``).
        ngpus: Stored on the instance; not used by this class.
        num_channels: Number of output channels (``c``).
        latent_dim: Size of the input latent vector.
        post_proc_filt_len: Stored on the instance; not used by this class.
        verbose: When true, ``forward`` prints the tensor shape after every
            stage.
        upsample: Accepted for interface compatibility; not used.
    """

    def __init__(self, model_size=64, ngpus=1, num_channels=2,
                 latent_dim=100, post_proc_filt_len=512,
                 verbose=False, upsample=True):
        super().__init__()
        self.ngpus = ngpus
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        self.latent_dim = latent_dim
        # The attribute was originally misspelled; keep the old name as an
        # alias so code that already reads ``latent_di`` keeps working.
        self.latent_di = latent_dim
        self.post_proc_filt_len = post_proc_filt_len
        self.verbose = verbose

        # Submodule names and registration order are kept stable so that
        # state_dict keys and seeded weight initialisation do not change.
        self.fc1 = nn.Linear(latent_dim, 256 * model_size)

        self.upsample1 = self._up_block(16 * model_size, 8 * model_size)
        self.upsample2 = self._up_block(8 * model_size, 4 * model_size)
        self.upsample3 = self._up_block(4 * model_size, 2 * model_size)
        self.upsample4 = self._up_block(2 * model_size, model_size)
        self.upsample5 = self._up_block(model_size, num_channels)

    @staticmethod
    def _up_block(in_channels, out_channels):
        """Build one ReLU + transposed-conv stage that doubles H and W."""
        # kernel 5, stride 2, padding 2, output_padding 1:
        # out = (in - 1) * 2 - 2 * 2 + 5 + 1 = 2 * in
        return nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels, out_channels, 5, 2, 2,
                               output_padding=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate a batch of maps from latent vectors.

        Args:
            x: Latent batch whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(N, num_channels, 128, 128)`` with values in
            ``[-1, 1]``.
        """
        x = self.fc1(x).view(-1, 16 * self.model_size, 4, 4)
        if self.verbose:
            print(x.shape)

        stages = (self.upsample1, self.upsample2, self.upsample3,
                  self.upsample4, self.upsample5)
        for stage in stages:
            x = stage(x)
            if self.verbose:
                print(x.shape)

        # torch.tanh is the canonical call; F.tanh has been flagged as
        # deprecated in several PyTorch releases.
        return torch.tanh(x)
    
    
    
class SpecGANDiscriminator(nn.Module):
    """Discriminator that scores a 128x128 multi-channel map with one value.

    Five stride-2 convolution + LeakyReLU stages halve the spatial size each
    time (128 -> 64 -> 32 -> 16 -> 8 -> 4) while growing the channel count
    to ``16 * model_size``.  The resulting ``16 * model_size * 4 * 4``
    features are flattened and passed through a linear layer that emits one
    unbounded score per sample (no sigmoid is applied).

    Args:
        model_size: Base channel multiplier (``d``).
        ngpus: Stored on the instance; not used by this class.
        num_channels: Number of input channels (``c``).
        shift_factor: Stored on the instance; not used by this class.
        alpha: Negative slope of every LeakyReLU.
        verbose: When true, ``forward`` prints the tensor shape at every
            stage.
    """

    def __init__(self, model_size=64, ngpus=1, num_channels=2, shift_factor=2,
                 alpha=0.2, verbose=False):
        super().__init__()
        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose

        # Submodule names and registration order are kept stable so that
        # state_dict keys and seeded weight initialisation do not change.
        self.fc1 = nn.Linear(256 * model_size, 1)

        # The slope is taken from ``alpha`` (default 0.2) instead of being
        # hard-coded, so the constructor argument actually has an effect.
        self.downsample1 = self._down_block(num_channels, model_size, alpha)
        self.downsample2 = self._down_block(model_size, 2 * model_size, alpha)
        self.downsample3 = self._down_block(2 * model_size, 4 * model_size, alpha)
        self.downsample4 = self._down_block(4 * model_size, 8 * model_size, alpha)
        self.downsample5 = self._down_block(8 * model_size, 16 * model_size, alpha)

    @staticmethod
    def _down_block(in_channels, out_channels, negative_slope):
        """Build one conv + LeakyReLU stage that halves H and W."""
        # kernel 5, stride 2, padding 2: out = floor((in - 1) / 2) + 1
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 5, 2, 2, bias=False),
            nn.LeakyReLU(negative_slope, inplace=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of maps.

        Args:
            x: Tensor of shape ``(N, num_channels, 128, 128)``.

        Returns:
            Tensor of shape ``(N, 1)`` holding one score per sample.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to
                4x4 after the five downsampling stages.
        """
        if self.verbose:
            print(x.shape)

        stages = (self.downsample1, self.downsample2, self.downsample3,
                  self.downsample4, self.downsample5)
        for stage in stages:
            x = stage(x)
            if self.verbose:
                print(x.shape)

        # Pin the batch dimension: with view(-1, ...) a wrongly sized input
        # would be silently folded into extra batch rows instead of failing.
        x = x.view(x.size(0), 256 * self.model_size)

        output = self.fc1(x)

        if self.verbose:
            print(x.shape)  # flattened feature shape fed to fc1

        return output