In [6]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter
import datetime
import os

In [7]:
# config
dtype = torch.float32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

max_epochs = 1500

local_img_dim = 10 # pixels of local map length
y_dim = local_img_dim * local_img_dim +1
x_dim = 3
dim_latent = 3
dim_hidden = 512

train_valid_ratio = 0.95 # % for train
batchSize = 256

dropout = 0.2
learning_rate = 0.0001
kl_weight = 1

# loss function
criteria = torch.nn.MSELoss() # 1/n*Sum((xi-yi)**2)

#####################################
#------------Suggestions------------#
# 1. tune KL weight --> size of latent space --> add new hidden layers
# 2. tune KL weight from small value to big

In [8]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.len = len(data)
        self.list_IDs = np.arange(0,len(data))
        self.data = data  # npy data

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, item):
        ID = self.list_IDs[item]
        x = self.data[ID]
        return x

In [9]:
# data loading
data = np.load('sample_0418.npy') # numpy.ndarray
# print("number of terrain samples:", len(data))
data_training = data[0:int(train_valid_ratio * len(data))]
data_validation = data[int(train_valid_ratio * len(data)):]
data_training_set = Dataset(data_training)
data_training_generator = torch.utils.data.DataLoader(data_training_set, batch_size=batchSize, shuffle=True, num_workers=2)
data_validation_set = Dataset(data_validation)
data_validation_generator = torch.utils.data.DataLoader(data_validation_set, batch_size=batchSize, shuffle=True, num_workers=2)

# print(len(data_training_set))
# print(type(data_training_set[50]))
# print(data_training_set[50])

In [12]:
# define CVAE model
class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__() # init nn.Module
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.linear_mu = torch.nn.Linear(H, D_out)
        self.linear_logvar = torch.nn.Linear(H, D_out)
        self.dropout = torch.nn.Dropout(p=dropout) # p – probability of an element to be zeroed
        self.H = H
    # Encoder Q netowrk, approximate the latent feature with gaussian distribution,output mu and logvar
    # Q(z|x) = N(z|mu(X),sigma(X))
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        return self.linear_mu(x), self.linear_logvar(x)
    
class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.linear3 = torch.nn.Linear(H, D_out)
        self.dropout = torch.nn.Dropout(p=dropout)
    # Decoder P network, sampling from normal distribution and build the reconstruction
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        return self.linear3(x)

class CVAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    def _reparameterize(self, mu, logvar):
        eps = torch.randn_like(mu)
        return mu + torch.exp(logvar / 2) * eps
    def forward(self, state, cond):
        x_in = torch.cat((state, cond),1)
        mu, logvar = self.encoder(x_in)
        z = self._reparameterize(mu, logvar)
        z_in = torch.cat((z, cond), 1)
        return mu, logvar, self.decoder(z_in)

In [13]:
# apply CVAE
encoder = Encoder(x_dim + y_dim, dim_hidden, dim_latent)
decoder = Decoder(dim_latent + y_dim, dim_hidden, x_dim)
model = CVAE(encoder, decoder)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
# start logging
description = "_description"
folder = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + description
writer = SummaryWriter("runs/" + folder)
os.makedirs(name="policies", exist_ok=True)
os.makedirs(name="policies/" + folder)

In [60]:
# main training process
try:
    print("========= start training ============")
    for epoch in range(max_epochs):
        epoch_loss = []
        for local_batch in data_training_generator:
            x_train = local_batch[:,  0  :x_dim].type(dtype).to(device)
            y_train = local_batch[:,x_dim:     ].type(dtype).to(device)
            optimizer.zero_grad()
            mu, logvar, recon_batch = model(x_train,y_train) #mu 256*3 ... 
            recon_loss = criteria(recon_batch, x_train)      # recon_loss 1*1
            kl_loss = 1000*kl_weight * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar, 1) # kl_loss 1*256, 一整个batch的kl loss
            
            loss = torch.mean(recon_loss + kl_loss, 0)
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
            
            KLLOSS = torch.mean(kl_loss,0)
            print(KLLOSS)
            KLLOSS_2 = loss.item() - recon_loss.item()
            print(KLLOSS_2)
        writer.add_scalar('training/loss', np.mean(np.array(epoch_loss)), epoch)
        #print("epoch: {}, Loss: {}".format(epoch,np.mean((np.array(epoch_loss))),))
     
    # validation
        if epoch % 1 == 0 and epoch > 0:
            with torch.no_grad():
                recon_loss_viz, kl_loss_viz, loss_viz = [], [], []
                for local_batch in data_validation_generator:
                    x_validate = local_batch[:, 0:x_dim].type(dtype).to(device)
                    y_validate = local_batch[:, x_dim:].type(dtype).to(device)
                    mu, logvar, recon_batch = model(x_validate, y_validate)
                    recon_loss = criteria(recon_batch, x_validate)
                    kl_loss =  kl_weight * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar, 1)
                    loss = torch.mean(recon_loss + kl_loss, 0)
                    recon_loss_viz.append(recon_loss.item())
                    loss_viz.append(loss.item())
                    kl_loss_viz.append(loss.item() - recon_loss.item())
                writer.add_scalar('validation/recon_loss', np.mean(np.array(recon_loss_viz)), epoch)
                writer.add_scalar('validation/kl_loss', np.mean((np.array(kl_loss_viz))), epoch)
                writer.add_scalar('validation/loss', np.mean((np.array(loss_viz))), epoch)
                #print("epoch: {}, Loss: {}, Recon_loss: {}, KL_loss: {}".format(epoch,np.mean((np.array(loss_viz))),
                                                                               # np.mean(np.array(recon_loss_viz)),
                                                                               # np.mean((np.array(kl_loss_viz)))))
    # save intermediate model
        #if epoch % 1000 == 0 and epoch > 0:
            # print("Saving intermediate model.")
            # save_path = "policies/" + "final.pt"
            # torch.save(model, f=save_path)

    # save final model
    # print("Finishing training, Saving final model.")
    # save_path = "policies/final.pt"
    # torch.save(model, f=save_path)
    print("========= training end ============")
    

except KeyboardInterrupt:
    print("Training interupted.")
    # save_path = "policies/final.pt"
    # torch.save(model, f=save_path)
writer.close()

# tensorboard --logdir=./

tensor(1.5679, device='cuda:0', grad_fn=<MeanBackward1>)
1.5678539276123047


In [31]:
cnt

75

19200

In [48]:
-3.7253e-08+3.9116e-08+-7.4506e-09

-5.587600000000005e-09