# Import Packages

In [1]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F

# Model Classes

In [2]:
class Encoder(nn.Module):
    """Variational encoder mapping 12 input features to a 10-d latent code.

    ``forward(x)`` returns ``(mu, log_var, z, z_p)``:
      * ``mu`` / ``log_var`` -- parameters of the diagonal Gaussian posterior,
      * ``z`` -- a reparameterised sample ``mu + eps * exp(0.5 * log_var)``,
      * ``z_p`` -- an independent standard-normal tensor shaped like ``z``
        (no gradient), which the training code decodes as a prior sample.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # Shared MLP trunk: 12 -> 32 -> 64 -> 128 -> 70 -> 35 -> 16.
        self.encode = nn.Sequential(
            nn.Linear(12,32),
            nn.LeakyReLU(),
            nn.Linear(32,64),
            nn.LeakyReLU(),
            nn.Linear(64,128),
            nn.LeakyReLU(),
            nn.Linear(128,70),
            nn.LeakyReLU(),
            nn.Linear(70,35),
            nn.LeakyReLU(),
            nn.Linear(35,16),
            nn.LeakyReLU()
        )

        # Two linear heads on the 16-d trunk output: posterior mean and log-variance.
        self.l_mu = nn.Linear(in_features=16, out_features=10)
        self.l_var = nn.Linear(in_features=16, out_features=10)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a batch ``x`` whose last dimension is 12."""
        h = self.encode(x)
        return self.l_mu(h), self.l_var(h)

    def sampling(self, mu, log_var):
        """Reparameterisation trick: return ``mu + eps * exp(0.5 * log_var)``, ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def auxilary(self, z):
        """Return a standard-normal tensor with ``z``'s shape, dtype and device.

        The result does not require grad. ``torch.randn_like`` replaces the
        deprecated ``Variable(z.data.new(z.size()).normal_())`` idiom, which also
        depended on ``Variable`` being imported in a later notebook cell.
        """
        return torch.randn_like(z)

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.sampling(mu, log_var)
        z_p = self.auxilary(z)
        return mu, log_var, z, z_p

class Decoder(nn.Module):
    """MLP decoder mapping a 10-d latent code to 12 output features.

    Hidden widths 10 -> 40 -> 80 -> 256 -> 128 -> 64 -> 32, each Linear followed
    by LeakyReLU, then an un-activated Linear 32 -> 12 output layer.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        widths = (10, 40, 80, 256, 128, 64, 32)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU()))
        # Linear layers land at Sequential indices 0, 2, ..., 10; these indices
        # are part of the state_dict keys ("decode.0.weight", ...).
        self.decode = nn.Sequential(*layers)
        self.output = nn.Linear(in_features=32, out_features=12)

    def forward(self, z):
        return self.output(self.decode(z))
    
class Discriminator(nn.Module):
    """MLP critic over 12-feature vectors.

    Widths 12 -> 64 -> 128 -> 64 -> 1 with LeakyReLU between Linear layers.
    ``similarity`` returns the raw (pre-sigmoid) output of shape ``(N, 1)``;
    ``forward`` returns the sigmoid of that output.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        sizes = (12, 64, 128, 64, 1)
        *hidden, (last_in, last_out) = zip(sizes[:-1], sizes[1:])
        layers = []
        for fan_in, fan_out in hidden:
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU()]
        # The final Linear has no activation; the sigmoid lives in self.sigmoid.
        layers.append(nn.Linear(last_in, last_out))
        self.discriminator = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.similarity(x))

    def similarity(self, x):
        return self.discriminator(x)

class Position_Discriminator(nn.Module):
    """MLP regressor from 4 input values to 2 outputs (raw, no final activation).

    Widths 4 -> 16 -> 32 -> 64 -> 32 -> 2 with LeakyReLU between Linear layers.
    """

    def __init__(self):
        super(Position_Discriminator, self).__init__()
        dims = [4, 16, 32, 64, 32, 2]
        pairs = list(zip(dims, dims[1:]))
        modules = []
        for fan_in, fan_out in pairs[:-1]:
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.LeakyReLU())
        modules.append(nn.Linear(*pairs[-1]))
        self.discriminator = nn.Sequential(*modules)

    def forward(self, x):
        return self.discriminator(x)

In [3]:
def discriminator_loss(recon_data, sample_data, real_data, REAL_LABEL, FAKE_LABEL):
    """Sum of three binary cross-entropy terms for the discriminator.

    ``recon_data`` and ``sample_data`` (probabilities for generated inputs) are
    scored against ``FAKE_LABEL``; ``real_data`` against ``REAL_LABEL``.
    """
    bce = F.binary_cross_entropy
    fake_terms = bce(recon_data, FAKE_LABEL) + bce(sample_data, FAKE_LABEL)
    return fake_terms + bce(real_data, REAL_LABEL)

In [4]:
def position_discriminator_loss(predict, target):
    """Mean squared error between ``predict`` and ``target`` (mean over all elements)."""
    return F.mse_loss(predict, target)

## Plot Losses (saved as JPG)

In [5]:
def plot_loss(train_dec_loss_li, train_dis_loss_li, train_enc_loss_li, train_Q_loss_li, tag=''):
    """Plot per-epoch training-loss curves and save them as a JPG.

    The figure is written to ``loss_png/train_loss-{tag}.jpg`` and overwritten
    on every call. The ``loss_png`` directory is created if it is missing.

    Args:
        train_dec_loss_li: decoder loss per epoch.
        train_dis_loss_li: discriminator loss per epoch.
        train_enc_loss_li: encoder loss per epoch.
        train_Q_loss_li: position-discriminator loss per epoch (legend label
            ``train_pos_loss``).
        tag: optional file-name suffix; the default ``''`` gives the original
            path ``loss_png/train_loss-.jpg``.
    """
    import os
    import matplotlib.pyplot as plt

    os.makedirs('loss_png', exist_ok=True)
    # Draw on a dedicated figure so these curves never mix with whatever figure
    # is current, and close it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.plot(train_dec_loss_li, color='b', label='train_dec_loss')
        ax.plot(train_dis_loss_li, color='g', label='train_dis_loss')
        ax.plot(train_enc_loss_li, color='r', label='train_enc_loss')
        ax.plot(train_Q_loss_li, color='deepskyblue', label='train_pos_loss')
        ax.legend()
        fig.savefig('loss_png/train_loss-{}.jpg'.format(tag))
    finally:
        plt.close(fig)

## L1 Loss Calculation

In [6]:
def l1loss_cal(data_loader):
    """Return the mean L1 error of the decoded column-1 values over ``data_loader``.

    For each ``(input_data, ground_truth)`` batch, the first 12 features are
    encoded and decoded with the module-level ``encoder`` / ``decoder``. Both
    the reconstruction and the ground truth are viewed as ``(-1, 4, 3)`` grids,
    and their column 1 is compared. The per-batch value is the mean *absolute*
    error over every sample in the batch. The result is the average of those
    per-batch values.

    The decoder input is the sampled ``z`` (not ``mu``), so the metric is
    stochastic, as in training.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Follow the models' device instead of hard-coding CUDA.
    device = next(encoder.parameters()).device
    batch_errors = []
    # Evaluation only: no autograd graph needed.
    with torch.no_grad():
        for input_data, ground_truth in data_loader:
            input_data = input_data[:, :12].float().to(device)
            ground_truth = ground_truth[:, :12].float().to(device)

            _, _, z, _ = encoder(input_data)
            predicts = decoder(z)

            diff = ground_truth.reshape(-1, 4, 3)[:, :, 1] - predicts.reshape(-1, 4, 3)[:, :, 1]
            # Absolute value *before* the mean: signed errors must not cancel.
            batch_errors.append(float(diff.abs().mean()))

    if not batch_errors:
        raise ValueError('l1loss_cal: data_loader yielded no batches')
    return sum(batch_errors) / len(batch_errors)

In [7]:
def plot_l1_loss(train_l1, tag=''):
    """Plot the per-epoch training L1 metric and save it as a JPG.

    The figure is written to ``loss_png_l1/train_l1_loss-{tag}.jpg`` and
    overwritten on every call. The ``loss_png_l1`` directory is created if it
    is missing.

    Args:
        train_l1: L1 metric value per epoch.
        tag: optional file-name suffix; the default ``''`` gives the original
            path ``loss_png_l1/train_l1_loss-.jpg``.
    """
    import os
    import matplotlib.pyplot as plt

    os.makedirs('loss_png_l1', exist_ok=True)
    # Dedicated figure, always closed, so repeated calls don't leak figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(train_l1, color='darkred', label='Train_l1_loss')
        ax.legend()
        fig.savefig('loss_png_l1/train_l1_loss-{}.jpg'.format(tag))
    finally:
        plt.close(fig)

# Load Data

In [8]:
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm 

## DataLoader Class

In [9]:
class SignalDataset(Dataset):
    """Map-style dataset pairing each input row with its ground-truth row.

    ``dataset[i]`` returns ``(input_data[i], ground_truth[i])``, each converted
    to a float32 NumPy array. The length is taken from ``input_data``.
    """

    def __init__(self, input_data, ground_truth):
        self.input_data, self.ground_truth = input_data, ground_truth

    def __getitem__(self, index):
        sources = (self.input_data, self.ground_truth)
        return tuple(np.array(src[index], dtype=np.float32) for src in sources)

    def __len__(self):
        return len(self.input_data)

### Load Data (Don't Use Transpose)

In [10]:
# Load the training arrays and flatten them to rows of 14 values.
# model_process uses columns [:12] as the signal and, for the ground truth,
# columns [12:] as the position targets.
train_input = np.load('train_data/model_pos/train_input_pos.npy').reshape(-1, 14)
train_gt = np.load('train_data/model_pos/train_gt_pos.npy').reshape(-1, 14)

In [11]:
# train_input = train_input.transpose(0,2,1)
# train_gt = train_gt.transpose(0,2,1)
# valid_input = valid_input.transpose(0,2,1)
# valid_gt = valid_gt.transpose(0,2,1)
# test_input = test_input.transpose(0,1,3,2)
# test_gt = test_gt.transpose(0,1,3,2)

In [12]:
# test_input = test_input[:54,:1200,:,:]
# test_input = test_input.reshape(-1,4,3)[:5]
# test_gt = test_gt[:54,:1200,:,:]
# test_gt = test_gt.reshape(-1,4,3)[:5]

In [13]:
# Only a training split is built here; no validation/test loaders exist.
train_dataset = SignalDataset(train_input, train_gt)
train_loader = DataLoader(
    train_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=0,
)

# Train Model

In [14]:
from torch.autograd import Variable
import matplotlib.pyplot as plt

In [15]:
# Instantiate the four networks (in the same order), then move them to the GPU
# if one is present. `Tensor` is the float tensor type for the chosen device.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
encoder = Encoder()
decoder = Decoder()
discriminator = Discriminator()
position_Discriminator = Position_Discriminator()
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()
    discriminator.cuda()
    position_Discriminator.cuda()
    # To resume from saved state dicts, e.g.:
    # encoder.load_state_dict(torch.load('wireless_gan/1.12/encoder.pt'))
    # decoder.load_state_dict(torch.load('wireless_gan/1.12/decoder.pt'))
    # discriminator.load_state_dict(torch.load('wireless_gan/1.12/discriminator.pt'))

In [16]:
# Per-epoch loss histories (one list per network and split) plus L1-metric
# histories. Every name is bound to its own, independent empty list.
(train_dec_loss_li, train_dis_loss_li, train_enc_loss_li, train_Q_loss_li,
 valid_dec_loss_li, valid_dis_loss_li, valid_enc_loss_li) = ([] for _ in range(7))
test_dec_loss_li, test_dis_loss_li, test_enc_loss_li = ([] for _ in range(3))
train_l1, valid_l1, test_l1 = ([] for _ in range(3))

In [17]:
# One Adam optimizer per network, all with learning rate 1e-4
# (created in the order encoder, decoder, discriminator, position discriminator).
optimizer_E, optimizer_D, optimizer_Dis, optimizer_Dis_Pos = (
    torch.optim.Adam(net.parameters(), lr=0.0001)
    for net in (encoder, decoder, discriminator, position_Discriminator)
)

## Per-Epoch Training / Evaluation Function

In [18]:
def model_process(mode,dataloader,epoch):
    """Run one VAE-GAN pass over ``dataloader`` and return the summed losses.

    Uses the module-level networks ``encoder``, ``decoder``, ``discriminator``,
    ``position_Discriminator`` and the optimizers ``optimizer_E``,
    ``optimizer_D``, ``optimizer_Dis`` and ``optimizer_Dis_Pos``.

    Batch layout, as sliced below:
      * ``x[:, :12]`` is the encoder input;
      * ``y[:, :12]`` is the reconstruction target;
      * ``y[:, 12:]`` is the regression target of ``position_Discriminator``.
    The 12 features are viewed as a ``(4, 3)`` grid. Column 1 feeds
    ``position_Discriminator``; columns 0 and 2 get a direct L1 term.

    When ``mode == 'train'``, each batch does the following:
      1. Update the discriminator and 2. the position discriminator, on a
         schedule (see below).
      3. Update the decoder on ``100*rec - 2*dis + 0.1*L1 + pos``.
      4. Update the encoder on ``KLD + 10*rec + pos``.
    Any other ``mode`` only evaluates: no optimizer steps and no autograd graph.

    Args:
        mode: ``'train'`` to update weights; anything else to evaluate.
        dataloader: iterable yielding ``(x, y)`` tensor batches.
        epoch: current epoch index; drives the discriminator schedule.

    Returns:
        ``(dec_loss, dis_loss, enc_loss, pos_loss, predicts)``.
        * The four losses are Python floats summed over all batches; divide by
          the number of batches for the mean.
        * ``pos_loss`` sums the position term from the encoder step.
        * ``predicts`` is the detached decoder output of the last batch, or
          ``None`` if ``dataloader`` is empty.
    """
    is_train = mode == 'train'
    # Follow the models' device instead of hard-coding CUDA, so the CPU branch
    # of the model set-up also works.
    device = next(encoder.parameters()).device
    l1_loss_fuc = nn.L1Loss()

    dec_loss, dis_loss, enc_loss, pos_loss = 0.0, 0.0, 0.0, 0.0
    predicts = None
    flag = 0

    grad_mode = torch.enable_grad() if is_train else torch.no_grad()
    with grad_mode:
        for i, (x, y) in tqdm(enumerate(dataloader)):
            valid = torch.ones(len(x), 1, device=device)
            fake = torch.zeros(len(x), 1, device=device)
            input_data = x[:, :12].float().to(device)
            ground_truth = y[:, :12].float().to(device)
            ground_truth_pos = y[:, 12:].float().to(device)
            gt_grid = ground_truth.reshape(-1, 4, 3)

            mu, log_var, z, z_p = encoder(input_data)
            predicts = decoder(z)
            predicts_p = decoder(z_p)

            # Discriminator schedule: updates happen only on even epochs. On
            # epoch 0 the flag switches them off after batch 11.
            # NOTE(review): the position discriminator checks the flag after it
            # may have just been set, so on epoch 0 it trains on batches 0-10,
            # while the discriminator trains on 0-11. This is kept from the
            # original -- confirm it is intended.
            update_dis = is_train and epoch % 2 == 0 and flag == 0
            if update_dis and epoch == 0 and i > 10:
                flag = 1

            # Train Discriminator. Generator outputs are detached: the encoder
            # and decoder grads this would produce are zeroed before their own
            # steps anyway, and detaching avoids retaining their graph.
            if update_dis:
                batch_dis_loss = discriminator_loss(
                    discriminator(predicts.detach()),
                    discriminator(predicts_p.detach()),
                    discriminator(ground_truth),
                    valid, fake)
                optimizer_Dis.zero_grad()
                batch_dis_loss.backward()
                optimizer_Dis.step()

            # Train Position_Discriminator on generated and real column 1.
            if is_train and epoch % 2 == 0 and flag == 0:
                pos_fake = position_Discriminator(predicts.detach().reshape(-1, 4, 3)[:, :, 1])
                pos_real = position_Discriminator(gt_grid[:, :, 1])
                batch_pos_dis_loss = (position_discriminator_loss(pos_fake, ground_truth_pos)
                                      + position_discriminator_loss(pos_real, ground_truth_pos))
                optimizer_Dis_Pos.zero_grad()
                batch_pos_dis_loss.backward()
                optimizer_Dis_Pos.step()

            # Train Decoder, against the (possibly just updated) discriminators.
            pred_grid = predicts.reshape(-1, 4, 3)
            batch_dis_loss = discriminator_loss(
                discriminator(predicts),
                discriminator(predicts_p),
                discriminator(ground_truth),
                valid, fake)
            rec_loss = ((discriminator.similarity(predicts)
                         - discriminator.similarity(ground_truth)) ** 2).mean()
            l1_loss_dec = (l1_loss_fuc(pred_grid[:, :, 0], gt_grid[:, :, 0])
                           + l1_loss_fuc(pred_grid[:, :, 2], gt_grid[:, :, 2]))
            batch_pos_loss = position_discriminator_loss(
                position_Discriminator(pred_grid[:, :, 1]), ground_truth_pos)
            decoder_loss = 100 * rec_loss - 2 * batch_dis_loss + 0.1 * l1_loss_dec + batch_pos_loss
            if is_train:
                optimizer_D.zero_grad()
                decoder_loss.backward()
                optimizer_D.step()

            # Train Encoder with a fresh forward pass through the updated decoder.
            mu, log_var, z, _ = encoder(input_data)
            predicts = decoder(z)
            rec_loss = ((discriminator.similarity(predicts)
                         - discriminator.similarity(ground_truth)) ** 2).mean()
            KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            # NOTE(review): this parses as (KLD / batch_size) * 12. If a
            # per-feature normalisation KLD / (batch_size * 12) was intended,
            # the weight is off by a factor of 144 -- confirm.
            KLD = KLD / len(input_data) * 12
            batch_pos_loss = position_discriminator_loss(
                position_Discriminator(predicts.reshape(-1, 4, 3)[:, :, 1]), ground_truth_pos)
            encoder_loss = KLD + 10 * rec_loss + batch_pos_loss
            if is_train:
                optimizer_E.zero_grad()
                encoder_loss.backward()
                optimizer_E.step()

            # Separate running totals; the per-batch tensors must not overwrite them.
            dec_loss += decoder_loss.item()
            dis_loss += batch_dis_loss.item()
            enc_loss += encoder_loss.item()
            pos_loss += batch_pos_loss.item()

    if predicts is not None:
        predicts = predicts.detach()
    return dec_loss, dis_loss, enc_loss, pos_loss, predicts

## Training Loop

In [None]:
import os

# Checkpoints are written under model/; create it up front so the first save
# cannot fail.
os.makedirs('model', exist_ok=True)

num_epochs = 50
n_batches = len(train_loader)
for epoch in range(0, num_epochs):
    # Training
    print('Training')
    train_dec_loss, train_dis_loss, train_enc_loss, train_pos_loss, _ = model_process('train', train_loader, epoch)
    avg_dec_loss = train_dec_loss / n_batches
    avg_dis_loss = train_dis_loss / n_batches
    avg_enc_loss = train_enc_loss / n_batches
    avg_pos_loss = train_pos_loss / n_batches
    print('====> Epoch: {} Average Decoder loss: {:.4f} Average Discriminator loss: {:.4f} Average Encoder loss: {:.4f} Average Position Discriminator loss: {:.4f}'.format(
          epoch + 1, avg_dec_loss, avg_dis_loss, avg_enc_loss, avg_pos_loss))
    train_dec_loss_li.append(avg_dec_loss)
    train_dis_loss_li.append(avg_dis_loss)
    train_enc_loss_li.append(avg_enc_loss)
    train_Q_loss_li.append(avg_pos_loss)

    # Validation/test passes are disabled because no validation/test loaders
    # are built. When re-enabling them, call model_process with a mode other
    # than 'train'; model_process updates weights whenever mode == 'train'.

    # Save every 5 epochs AND after the final epoch, so the last epochs of
    # training are not lost (49 % 5 != 0).
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        torch.save(encoder, 'model/encoder.pt')
        torch.save(decoder, 'model/decoder.pt')
        torch.save(discriminator, 'model/discriminator.pt')
        torch.save(position_Discriminator, 'model/position_discriminator.pt')

    # Plot loss
    plot_loss(train_dec_loss_li, train_dis_loss_li, train_enc_loss_li, train_Q_loss_li)

    # L1 metric on the training set, then plot its history
    tr_l1 = l1loss_cal(train_loader)
    print('L1 loss Calculation', 'Train', tr_l1)
    train_l1.append(tr_l1)
    plot_l1_loss(train_l1)

0it [00:00, ?it/s]

Training


23234it [07:38, 51.01it/s]

# Test

In [None]:
# Load the test set, keep the first 54 x 1200 samples and flatten them to a
# (4, 3) grid per sample.
test_input = np.load('test_input.npy')
test_input = test_input.transpose(0,1,3,2)
test_input = test_input[:54,:1200,:,:]
test_input = test_input.reshape(-1,4,3)

test_gt = np.load('test_gt.npy')
test_gt = test_gt.transpose(0,1,3,2)
test_gt = test_gt[:54,:1200,:,:]
test_gt = test_gt.reshape(-1,4,3)

# NOTE(review): model_process slices x[:, :12] and y[:, 12:] and feeds x to
# nn.Linear(12, ...). With samples shaped (4, 3), x[:, :12] stays (N, 4, 3),
# so the encoder's first layer sees a last dimension of 3. y[:, 12:] is empty,
# so there are no position targets. The arrays probably need to be flattened
# to rows of 12 signal values plus 2 position values, like the training data.
# Confirm the test file layout before running this cell.
test_dataset = SignalDataset(test_input,test_gt)
# 64800 = 54 * 1200, so the whole test set is a single batch.
test_loader = DataLoader(dataset=test_dataset, batch_size=64800, shuffle=False, num_workers=0)

# model_process returns five values: (dec, dis, enc, pos, predicts).
test_dec_loss, test_dis_loss, test_enc_loss, test_pos_loss, oao = model_process('test', test_loader, 0)

In [None]:
# Reshape the decoder output of the (single) test batch back to
# (54, -1, 4, 3), move it to NumPy, and undo the (0, 1, 3, 2) transpose
# applied to the test data.
oao = oao.reshape(54,-1,4,3)
oao = oao.cpu().detach().numpy()
oao = oao.transpose(0,1,3,2)
# np.int (an alias of the builtin int) was removed in NumPy 1.24.
# NOTE(review): astype truncates toward zero; round first (np.rint) if
# nearest-integer values are intended.
oao = oao.astype(int)

In [None]:
# Persist the generated test-set signals in .npy format.
np.save('test_gen.npy', oao)