<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/VAE/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torch.distributions as td
from torch.distributions import Categorical, Normal, Bernoulli


Reference implementation: https://github.com/kampta/pytorch-distributions/blob/master/gaussian_vae.py

In [2]:
# Mini-batch size (assigned again, to the same value, in the data-loader cell).
batch_size = 128
# Unfinished device-selection sketch (the `if` has no condition), left commented out:
# if : device = torch.device('cuda')
# else: device = torch.device('cpu')
# print(device)

In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Pick the compute device at runtime: use the GPU when CUDA is available and
# fall back to the CPU otherwise, so the notebook also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# MNIST data pipeline: images are converted to tensors with ToTensor().
batch_size = 128
# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when CUDA is actually available.
kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
# Evaluation batches keep a fixed order: the summed test loss does not depend
# on order, and a stable first batch makes per-epoch reconstruction images
# comparable. download=True lets this loader work on its own.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, **kwargs)

In [6]:
device

'cuda'

In [7]:
class VAE(nn.Module):
    """Fully connected variational auto-encoder with a diagonal Gaussian latent.

    The encoder is ``fc1 -> ReLU -> (fc21, fc22)`` and produces the mean and
    log-variance of the latent distribution; the decoder is
    ``fc3 -> ReLU -> fc4 -> sigmoid``.

    Args:
        input_dim: Number of features per flattened input (default 784).
        hidden_dim: Width of the hidden layer in encoder and decoder (default 400).
        latent_dim: Dimensionality of the latent code (default 20).

    The defaults reproduce the original fixed 784-400-20 architecture, and the
    layer attribute names are unchanged, so existing state dicts still load.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)   # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)   # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs ``(N, input_dim)`` to ``(mu, logvar)``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)  # mu, logvar

    def decode(self, z):
        """Map latent codes ``(N, latent_dim)`` to outputs in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is computed as ``exp(0.5 * logvar)`` rather than
        ``exp(logvar) ** 0.5`` so that large log-variances do not overflow to
        infinity before the square root is taken.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Input batch; it is flattened to ``(-1, input_dim)``.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# Build the VAE, place it on the selected device, and optimise all of its
# parameters with Adam.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [8]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO of the VAE, summed over the batch.

    Args:
        recon_x: Decoder output with values in (0, 1).
        x: Original inputs; reshaped to ``recon_x``'s shape and moved to its
            device before comparison.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD`` where ``BCE`` is the summed binary
        cross-entropy reconstruction term and ``KLD`` is the closed-form KL
        divergence between ``N(mu, exp(logvar))`` and ``N(0, I)``.
    """
    # Follow the device of the model output instead of forcing CUDA, so the
    # loss also works on CPU-only machines.
    target = x.reshape(recon_x.shape).to(recon_x.device)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


In [9]:

def train(epoch, log_interval=100):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``loss_function``. A progress line is printed for every
    batch whose index is a multiple of ``log_interval``, and the per-sample
    average loss is printed once the pass is finished.

    Args:
        epoch: Epoch number, used only in the printed messages.
        log_interval: Number of batches between progress lines.
    """
    model.train()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    running_total = 0
    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        running_total += batch_loss_value
        if step % log_interval != 0:
            continue
        progress_template = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
        print(progress_template.format(
            epoch, step * len(images), n_samples,
            100. * step / n_batches,
            batch_loss_value / len(images)))

    summary_template = '====> Epoch: {} Average loss: {:.4f}'
    print(summary_template.format(epoch, running_total / n_samples))

In [10]:
# train(1000)

In [17]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction image.

    Uses the notebook-level ``model``, ``test_loader``, ``device`` and
    ``loss_function``. For the first batch, up to eight inputs and their
    reconstructions are written to ``results/reconstruction_<epoch>.png``.
    The per-sample average loss is printed at the end.

    Args:
        epoch: Epoch number, used in the output file name.
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for batch_no, (images, _labels) in enumerate(test_loader):
            images = images.to(device)
            reconstruction, mu, logvar = model(images)
            total += loss_function(reconstruction, images, mu, logvar).item()
            if batch_no != 0:
                continue
            # First batch only: originals on top, reconstructions underneath.
            n = min(images.size(0), 8)
            originals = images[:n]
            rebuilt = reconstruction.view(-1, 1, 28, 28)[:n]
            side_by_side = torch.cat([originals, rebuilt])
            out_path = 'results/reconstruction_' + str(epoch) + '.png'
            save_image(side_by_side.cpu(), out_path, nrow=n)

    mean_loss = total / len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(mean_loss))

In [18]:
import os
def create_path(path):
    """Ensure that ``path`` exists, creating the directory tree if needed.

    Prints whether the path was already present and, when it had to be
    created, a confirmation message.
    """
    already_there = os.path.exists(path)
    print(path, ' folder is in directory: ', already_there)
    if already_there:
        return
    # Nothing at this location yet: build the directory (and any parents).
    os.makedirs(path)
    print(path, " is created!")
create_path('results/')

results/  folder is in directory:  False
results/  is created!


In [None]:
# Total number of epochs, kept in its own name so the loop variable below does
# not overwrite it.
n_epochs = 10000
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test(epoch)
    # After each epoch, decode 64 latent vectors drawn from the standard
    # normal prior and save them as an image grid.
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 111.5362
====> Test set loss: 109.6628
====> Epoch: 2 Average loss: 109.7891
====> Test set loss: 108.2379
====> Epoch: 3 Average loss: 108.6332
====> Test set loss: 107.6086
====> Epoch: 4 Average loss: 107.7365
====> Test set loss: 107.2493
====> Epoch: 5 Average loss: 107.1640
====> Test set loss: 106.3840
====> Epoch: 6 Average loss: 106.5978
====> Test set loss: 105.8692
====> Epoch: 7 Average loss: 106.2497
====> Test set loss: 105.6901
====> Epoch: 8 Average loss: 105.8086
====> Test set loss: 105.3774
====> Epoch: 9 Average loss: 105.5380
====> Test set loss: 105.0633
====> Epoch: 10 Average loss: 105.2821
====> Test set loss: 104.9243
====> Epoch: 11 Average loss: 105.0265
====> Test set loss: 104.6767
====> Epoch: 12 Average loss: 104.8175
====> Test set loss: 104.4889
====> Epoch: 13 Average loss: 104.6429
====> Test set loss: 104.4698
====> Epoch: 14 Average loss: 104.4288
====> Test set loss: 104.0110
====> Epoch: 15 Average loss: 104.2953
====

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# Demo figure: a helix drawn as a line plus noisy points scattered around it.
ax = plt.axes(projection='3d')

# Helix curve: x = sin(z), y = cos(z) for 1000 evenly spaced z in [0, 15].
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# 100 random heights in [0, 15); x/y follow the helix with Gaussian noise of
# scale 0.1, and the points are coloured by height.
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens');


In [None]:
def f_ripple(x, y):
    """Radial ripple surface ``sin(sqrt(x^2 + y^2))``.

    This was previously also named ``f`` and was immediately overridden by
    the definition below; it now has its own name so it stays usable.
    """
    return np.sin(np.sqrt(x ** 2 + y ** 2))


def f(x, y):
    """Inverted cone ``-sqrt(x^2 + y^2)``; its maximum value 0 is at the origin.

    Works element-wise on scalars or NumPy arrays.
    """
    return -1 * np.sqrt(x ** 2 + y ** 2)


# Evaluate f on a 30 x 30 grid covering [-6, 6] x [-6, 6].
x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# Draw the surface as 50 contour levels in a 3-D axes.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.contour3D(X, Y, Z, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
# Two hand-picked marker points, (0, 1, 2) and (1, 1, 1), on the same axes.
ax.scatter([0, 1], [1, 1], [2, 1])

In [None]:
import numpy as np

from collections import deque

import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical, Normal, Bernoulli

# Gym
# import gym
# import gym_pygame

# Hugging Face Hub
# from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.
import imageio

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU, and
# show which one was picked.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

class Policy(nn.Module):
    """Gaussian policy network.

    Two ReLU hidden layers feed two linear heads that output the mean and the
    log-variance of a diagonal Normal distribution over actions.

    Args:
        s_size: Number of state features.
        a_size: Number of action dimensions.
        h_size: Width of the first hidden layer (the second is ``2 * h_size``).
    """

    def __init__(self, s_size, a_size, h_size):
        super(Policy, self).__init__()
        self.s_size = s_size
        self.fc1 = nn.Linear(s_size, h_size)
        self.fc2 = nn.Linear(h_size, h_size*2)
        self.fc3_1 = nn.Linear(h_size*2, a_size)   # mean head
        self.fc3_2 = nn.Linear(h_size*2, a_size)   # log-variance head

    def _module_device(self):
        """Return the device the network's parameters live on."""
        return self.fc1.weight.device

    def forward(self, x):
        """Compute ``(mu, logvar)`` of the action distribution for state ``x``.

        The input is cast to float32 and moved to the module's own device
        (rather than a notebook-global one). An input whose shape is exactly
        ``(1, s_size)`` is used as is; any other input gets a leading
        dimension added.
        NOTE(review): this means a ``(B, s_size)`` batch with ``B > 1`` becomes
        ``(1, B, s_size)`` -- confirm that is intended.
        """
        x = x.float()
        if x.shape != (1, self.s_size):
            x = x.unsqueeze(0)
        x = x.to(self._module_device())
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = self.fc3_1(x)
        logvar = self.fc3_2(x)
        return mu, logvar

    def act(self, state, eta=1e-2):
        """Sample an action for ``state`` and return it with its log-probability.

        Args:
            state: Tensor, or NumPy array (converted to a float32 tensor with a
                leading batch dimension).
            eta: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(action, log_prob)``. ``action`` is a plain sample that does
            not carry gradients; ``log_prob`` is differentiable with respect to
            the network parameters.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().unsqueeze(0)
        state = state.to(self._module_device())
        mu, logvar = self.forward(state)
        # exp(0.5 * logvar) avoids the overflow of exp(logvar) ** 0.5 for large logvar.
        std = torch.exp(0.5 * logvar)
        p = Normal(mu, std)
        # The score-function (REINFORCE) estimator needs the action to be a
        # constant with respect to the parameters. With a reparameterised
        # sample, log_prob(mu + std * eps) no longer depends on mu and its
        # gradient with respect to the mean vanishes.
        action = p.sample()
        log_action = p.log_prob(action)
        return action, log_action

In [None]:
# import torch.distributions as td
# def forward(self, x):
#         mu, logvar = self.encode(x.view(-1, 784))

#         std = logvar.exp().pow(0.5)         # logvar to std
#         q_z = td.normal.Normal(mu, std)     # create a torch distribution
#         z = q_z.rsample()                   # sample with reparameterization

#         return self.decode(z), q_z

In [None]:
# std = logvar.exp().pow(0.5) # logvar to std
# q_z = td.normal.Normal(mu, std)


In [None]:
s

In [None]:
def reinforce(policy, optimizer, env, n_training_episodes=1000,
              max_t=100, gamma=1.0, print_every=10):
    """Train ``policy`` with the REINFORCE policy-gradient algorithm.

    Args:
        policy: Object whose ``act(state)`` returns ``(action, log_prob)``.
        optimizer: Optimizer over the policy's parameters.
        env: Object with ``reset() -> state`` and
            ``step(action) -> (state, reward, done)``.
        n_training_episodes: Number of episodes to run.
        max_t: Maximum number of steps per episode.
        gamma: Discount factor.
        print_every: Print the running average score every this many episodes.

    Returns:
        List with the undiscounted total reward of every episode.
    """
    # float32 machine epsilon, added to the standard deviation of the returns
    # so the normalisation never divides by zero.
    eps = np.finfo(np.float32).eps.item()
    scores_deque = deque(maxlen=100)   # scores of the most recent 100 episodes
    scores = []
    for i in range(1, n_training_episodes+1):
        saved_log_probs = []
        rewards = []
        state = env.reset()
        for t in range(max_t):  # one episode
            action, log_prob = policy.act(state)
            saved_log_probs.append(log_prob)
            state, reward, done = env.step(action)
            rewards.append(reward)
            if done:
                break
        episode_score = sum(rewards)
        scores_deque.append(episode_score)
        scores.append(episode_score)  # for one episode, total R

        if saved_log_probs:  # nothing to learn from an episode with no steps
            # Discounted returns G_t = r_t + gamma * G_{t+1}, built back to front.
            discounted = []
            running_return = 0
            for step_reward in reversed(rewards):
                running_return = gamma * running_return + step_reward
                discounted.append(running_return)
            discounted.reverse()

            returns = torch.tensor(discounted)
            if not returns.is_floating_point():
                # Integer rewards would make mean()/std() fail.
                returns = returns.float()
            centered = returns - returns.mean()
            if returns.numel() > 1:
                returns = centered / (returns.std() + eps)
            else:
                # The sample standard deviation of one value is NaN and would
                # poison the parameters; a lone return simply centres to zero.
                returns = centered

            policy_loss = torch.cat(
                [-log_prob * disc_return
                 for log_prob, disc_return in zip(saved_log_probs, returns)]
            ).sum()
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

        if i % print_every == 0:
            print('Episode {}\tAverage Score: {:.2f}'.format(
                i, np.mean(scores_deque)))
    return scores


In [None]:
new_s = s.unsqueeze(0)
new_s.shape == (1, 1, 2)

In [None]:
class env:
    """Toy 2-D environment: the state is a point that is nudged by each action.

    The reward of a step is ``f(x, y)`` evaluated at the new point.

    Args:
        f: Callable ``f(x, y)`` returning the reward for a position.
    """

    def __init__(self, f):
        self.state = torch.from_numpy(np.random.rand(2))
        self.f = f
        self.reward = 0
        self.done = False

    def reset(self):
        """Draw a new uniformly random start point in [0, 1)^2 and return it."""
        self.state = torch.from_numpy(np.random.rand(2))
        self.done = False
        return self.state

    def step(self, action, step=1e-3, error=1e-6):
        """Move the state by ``action * step`` and score the new position.

        Args:
            action: Tensor displacement direction; the state is moved to the
                action's device before the update.
            step: Step-size multiplier applied to ``action``.
            error: The episode is flagged as done when the magnitude of the
                first displacement component falls below this value.

        Returns:
            Tuple ``(state, reward, done)``. ``reward`` stays 0 when ``f``
            cannot be evaluated (a message is printed in that case).
        """
        if isinstance(self.state, torch.Tensor):
            current_state = self.state
        else:  # numpy array
            current_state = torch.from_numpy(self.state).double().unsqueeze(0)
        # Follow the action's device instead of a notebook-global one.
        current_state = current_state.to(action.device)

        state = current_state + action * step

        reward = 0
        try:
            # Flatten so both (2,) and (1, 2) states expose x and y at 0 and 1.
            coords = state.detach().cpu().numpy().reshape(-1)
            reward = self.f(coords[0], coords[1])
        except Exception as exc:
            # Best effort: keep the default reward, but do not swallow
            # KeyboardInterrupt/SystemExit the way a bare except would.
            print('cannot calculate reward', state, exc)

        # Sum the displacement over rows, then look at its first component.
        # NOTE(review): only the x-component is checked -- confirm whether the
        # full displacement norm was intended.
        displacement = (state - current_state).reshape(-1, state.shape[-1]).sum(dim=0)
        done = bool(abs(displacement[0].item()) < error)

        self.done = done
        self.state = state
        self.reward = reward
        return state, reward, done


In [None]:
# env_1 = env(f)

# a = torch.from_numpy(np.random.rand(2)).float().unsqueeze(0).to(device)
# s, r, d = env_1.step(a)
# r
# # s.items()

In [None]:
# state.detach().cpu().numpy()[0]

In [None]:
# state = s
# f(state.detach().cpu().numpy()[0][0], state.detach().cpu().numpy()[0][1])

In [None]:
# isinstance(s.detach().cpu().numpy()[0][1], 'float')

In [None]:
# Environment whose reward is the surface function f.
env_1 = env(f)
# Policy with 2 state inputs, 2 action outputs and a 16-unit first hidden layer.
p = Policy(2, 2, 16).to(device)
# Adam optimizer for the policy (the variable name mentions cartpole, but the
# environment used here is env_1).
cartpole_optimizer = optim.Adam(p.parameters(),
                                lr=1e-2)
# Train with the default settings of reinforce(); scores holds one total
# reward per episode.
scores = reinforce(p, cartpole_optimizer, env_1)

In [None]:
s

In [None]:
p.forward(s)

In [None]:
# Step a fresh environment ten times with the same action and plot each
# resulting state/reward in its own 3-D figure.
# NOTE(review): `a` is not assigned in any active cell visible here (only in a
# commented-out cell), so this relies on leftover kernel state -- confirm.
b = env(f)
for i in range(10):
    s, r, _ = b.step(a)
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    # ax.contour3D(X, Y, Z, 50, cmap='binary')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z');
    # NOTE(review): env.step reads its state as [0][0] / [0][1], which suggests
    # a (1, 2) tensor; s[1] would then be out of range -- verify the shape of s.
    ax.scatter(s[0], s[1], r)
    # Sanity check: the returned reward equals f evaluated at the new state.
    assert r == f(s[0], s[1])


In [None]:
a + a * (1e-3)
# a