# Import Package

In [1]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F

# Model Class

In [2]:
class Encoder(nn.Module):
    """Variational encoder mapping 12 input features to a 12-dimensional latent.

    A shared MLP (12 -> 20 -> 24, ReLU) feeds two linear heads that produce the
    mean and the log-variance of a diagonal Gaussian over the latent code.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # Shared feature extractor.
        self.encode = nn.Sequential(
            nn.Linear(12, 20),
            nn.ReLU(inplace=False),
            nn.Linear(20, 24),
            nn.ReLU(inplace=False)
        )

        # Separate heads for the Gaussian parameters of the latent code.
        self.l_mu = nn.Linear(in_features=24, out_features=12)
        self.l_var = nn.Linear(in_features=24, out_features=12)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for ``x`` after flattening it to rows of 12 features."""
        x = x.view(-1, 12)
        h = self.encode(x)
        mu = self.l_mu(h)
        log_var = self.l_var(h)

        return mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def auxilary(self, z):
        """Return standard-normal noise with the same shape, dtype and device as ``z``.

        The result is a fresh tensor that is not connected to the autograd graph.
        ``torch.randn_like`` replaces the legacy ``Variable(z.data.new(...).normal_())``
        construction, so this class no longer depends on ``Variable`` being imported
        by a later notebook cell.
        """
        return torch.randn_like(z)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, log_var, z, z_p)``.

        ``z`` is sampled from the encoded Gaussian; ``z_p`` is independent
        standard-normal noise of the same shape as ``z``.
        """
        mu, log_var = self.encoder(x)
        z = self.sampling(mu, log_var)
        z_p = self.auxilary(z)
        return mu, log_var, z, z_p

class Decoder(nn.Module):
    """MLP decoder: 12-dim latent code -> 12 output features (12 -> 16 -> 14 -> 12)."""

    def __init__(self):
        super(Decoder, self).__init__()

        # Hidden stack with ReLU activations; layers are created in order so
        # that parameter names and seeded initialisation stay stable.
        hidden_layers = [
            nn.Linear(12, 16),
            nn.ReLU(inplace=False),
            nn.Linear(16, 14),
            nn.ReLU(inplace=False),
        ]
        self.decode = nn.Sequential(*hidden_layers)

        # Final linear projection, no activation applied afterwards.
        self.output = nn.Linear(in_features=14, out_features=12)

    def forward(self, z):
        """Decode latent ``z`` through the hidden stack and the output projection."""
        return self.output(self.decode(z))
    
class Discriminator(nn.Module):
    """MLP critic over 12-feature samples (12 -> 20 -> 10 -> 1).

    ``similarity`` exposes the raw scalar output of the network; ``forward``
    squashes that same output through a sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        layers = [
            nn.Linear(12, 20),
            nn.ReLU(inplace=False),
            nn.Linear(20, 10),
            nn.ReLU(inplace=False),
            nn.Linear(10, 1),
        ]
        self.discriminator = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the network output for ``x`` reshaped to rows of 12."""
        return self.sigmoid(self.similarity(x))

    def similarity(self, x):
        """Return the pre-sigmoid network output for ``x`` reshaped to rows of 12."""
        rows = x.view(-1, 12)
        return self.discriminator(rows)

In [3]:
def discriminator_loss(recon_data, sample_data, real_data, REAL_LABEL, FAKE_LABEL):
    """Sum of three binary cross-entropy terms for the discriminator.

    ``recon_data`` and ``sample_data`` are scored against ``FAKE_LABEL``;
    ``real_data`` is scored against ``REAL_LABEL``. All three inputs are passed
    to ``F.binary_cross_entropy`` as predictions.
    """
    bce = F.binary_cross_entropy
    fake_terms = bce(recon_data, FAKE_LABEL) + bce(sample_data, FAKE_LABEL)
    return fake_terms + bce(real_data, REAL_LABEL)

# Load Data

In [4]:
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm 

In [5]:
# Load the input / ground-truth array pair of each split from the
# `wireless_gan` directory (relative to the notebook's working directory).
train_input, train_gt = (
    np.load('wireless_gan/train_input.npy'),
    np.load('wireless_gan/train_gt.npy'),
)
valid_input, valid_gt = (
    np.load('wireless_gan/valid_input.npy'),
    np.load('wireless_gan/valid_gt.npy'),
)
test_input, test_gt = (
    np.load('wireless_gan/test_input.npy'),
    np.load('wireless_gan/test_gt.npy'),
)

In [6]:
class SignalDataset(Dataset):
    """Index-aligned pairs of input samples and ground-truth targets.

    Items are returned as freshly allocated ``float32`` NumPy arrays; the
    dataset length is the length of ``input_data``.
    """

    def __init__(self, input_data, ground_truth):
        self.input_data = input_data
        self.ground_truth = ground_truth

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, index):
        # Convert the input first, then the target, each into its own float32 copy.
        sample, target = (
            np.array(source[index], dtype=np.float32)
            for source in (self.input_data, self.ground_truth)
        )
        return sample, target

In [7]:
# Only the first 240000 training samples are used.
train_dataset = SignalDataset(train_input[:240000], train_gt[:240000])
valid_dataset = SignalDataset(valid_input, valid_gt)

train_loader = DataLoader(dataset=train_dataset, batch_size=40, shuffle=True, num_workers=0)
# Fix: the validation loader previously wrapped `train_dataset`, so any
# validation pass would have silently been evaluated on training data.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=24, shuffle=True, num_workers=0)

# Train Model

In [8]:
from torch.autograd import Variable

In [9]:
# Pick the float tensor type for the available device, build the three
# networks (encoder, decoder, discriminator - in that order), and move them
# to the GPU when one is present.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
encoder, decoder, discriminator = Encoder(), Decoder(), Discriminator()
if torch.cuda.is_available():
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    discriminator = discriminator.cuda()

In [10]:
# One Adam optimizer per network, all with the same learning rate.
# Naming: optimizer_E -> encoder, optimizer_D -> decoder, optimizer_Dis -> discriminator.
optimizer_E, optimizer_D, optimizer_Dis = (
    torch.optim.Adam(network.parameters(), lr=0.0001)
    for network in (encoder, decoder, discriminator)
)

In [11]:
# lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size = 1, gamma=0.9)

In [None]:
for epoch in range(0, 10):
    # Running sums of the per-batch losses for this epoch.
    train_dec_loss = 0
    train_dis_loss = 0
    train_enc_loss = 0
    # Number of batches per epoch; used to average the summed batch losses.
    num_batches = max(len(train_loader), 1)
    # tqdm wraps the loader itself (not `enumerate`) so it knows the total length.
    for i, (input_data, ground_truth) in enumerate(tqdm(train_loader)):
        # `Tensor` is the CPU or CUDA float type selected at setup, so the loop
        # runs on either device instead of hard-coding torch.cuda.FloatTensor.
        input_data = input_data.type(Tensor)
        ground_truth = ground_truth.type(Tensor)
        batch_size = len(input_data)
        # Discriminator targets: 1 for real samples, 0 for generated ones.
        valid = torch.ones(batch_size, 1, dtype=input_data.dtype, device=input_data.device)
        fake = torch.zeros(batch_size, 1, dtype=input_data.dtype, device=input_data.device)

        mu, log_var, z, z_p = encoder(input_data)
        predicts = decoder(z)
        predicts_p = decoder(z_p)

        # ---- Discriminator step ----
        # Decoder outputs are detached: only the discriminator is updated here,
        # so there is no need to backpropagate into the decoder/encoder or to
        # retain their graph. The discriminator gradients are unchanged.
        optimizer_Dis.zero_grad()
        X = discriminator(predicts.detach())
        X_p = discriminator(predicts_p.detach())
        X_real = discriminator(ground_truth)
        dis_loss = discriminator_loss(X, X_p, X_real, valid, fake)
        dis_loss.backward()
        optimizer_Dis.step()

        # ---- Decoder step ----
        # Re-score with the updated discriminator; the decoder minimises the
        # feature-matching error while maximising the discriminator loss.
        optimizer_D.zero_grad()
        X = discriminator(predicts)
        X_p = discriminator(predicts_p)
        X_real = discriminator(ground_truth)
        X_sim = discriminator.similarity(predicts)
        X_data = discriminator.similarity(ground_truth)
        dis_loss = discriminator_loss(X, X_p, X_real, valid, fake)
        rec_loss = ((X_sim - X_data) ** 2).mean()
        decoder_loss = 25 * rec_loss - dis_loss
        # The graph is rebuilt for the encoder step below, so it is not retained.
        decoder_loss.backward()
        optimizer_D.step()

        # ---- Encoder step ----
        # A single fresh forward pass through the updated decoder provides
        # both the reconstruction term and mu / log_var for the KL term.
        optimizer_E.zero_grad()
        mu, log_var, z, z_p = encoder(input_data)
        predicts = decoder(z)
        X_sim = discriminator.similarity(predicts)
        X_data = discriminator.similarity(ground_truth)
        rec_loss = ((X_sim - X_data) ** 2).mean()
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        # NOTE(review): this evaluates as (KLD / batch_size) * 12, not
        # KLD / (batch_size * 12) - confirm which scaling is intended.
        KLD = KLD / len(input_data) * 12
        encoder_loss = KLD + 5 * rec_loss
        encoder_loss.backward()
        optimizer_E.step()

        train_dec_loss = train_dec_loss + decoder_loss.item()
        train_dis_loss = train_dis_loss + dis_loss.item()
        train_enc_loss = train_enc_loss + encoder_loss.item()

    # Average over the number of batches (previously divided by the size of
    # the last batch, which is unrelated to how many losses were summed).
    print('====> Epoch: {} Average Decoder loss: {:.4f} Average Discriminator loss: {:.4f} Average Ecoder loss: {:.4f}'.format(
          epoch, train_dec_loss / num_batches, train_dis_loss / num_batches, train_enc_loss / num_batches))

4332it [00:53, 80.72it/s]