# Prepare Data Set

First, a data set is loaded. The function `load_data_from_df` automatically saves the calculated features to the provided data directory (unless `use_data_saving` is set to `False`). Every subsequent run will reuse the saved features.

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch

In [None]:
import numpy as np

In [None]:
from src.featurization.data_utils import load_data_from_df, construct_loader

In [None]:
# Batch size handed to construct_loader below.
batch_size = 64

# Formal charges are one-hot encoded to keep compatibility with the pre-trained weights.
# If you do not plan to use the pre-trained weights, we recommend setting one_hot_formal_charge to False.
X, y = load_data_from_df('data/freesolv/freesolv.csv', one_hot_formal_charge=True)
# Build the batch loader from the loaded X and y.
data_loader = construct_loader(X, y, batch_size)

You can use your own data, but the CSV file should contain two columns, as shown below:

In [None]:
# Read the raw CSV into a DataFrame so its column layout can be inspected.
df = pd.read_csv('data/freesolv/freesolv.csv')

# Prepare Model

In [None]:
from src.transformer import make_model

In [None]:
# Second dimension of the first entry of the first example; it depends on the
# featurization that was used.
d_atom = X[0][0].shape[1]

# Hyperparameters forwarded to make_model as keyword arguments.
model_params = dict(
    d_atom=d_atom,
    d_model=1024,
    N=8,
    h=16,
    N_dense=1,
    lambda_attention=0.33,
    lambda_distance=0.33,
    leaky_relu_slope=0.1,
    dense_output_nonlinearity='relu',
    distance_matrix_kernel='exp',
    dropout=0.0,
    aggregation_type='mean',
)

model = make_model(**model_params)

# Load Pretrained Weights (optional)

If you want to use the pre-trained weights to train your model, **you should not change model parameters in the cell above**.

In [None]:
pretrained_name = 'pretrained_weights.pt'  # This file should be downloaded first (See README.md).
# NOTE(security): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
# map_location='cpu' allows a checkpoint that was saved from a GPU to be loaded
# on a machine without CUDA.
pretrained_state_dict = torch.load(pretrained_name, map_location='cpu')

In [None]:
# Copy every pre-trained tensor into the model in place, except entries whose
# name contains 'generator', which are left at their current values.
model_state_dict = model.state_dict()
for name, param in pretrained_state_dict.items():
    if 'generator' not in name:
        # Unwrap Parameters so that the raw tensor is what gets copied.
        param = param.data if isinstance(param, torch.nn.Parameter) else param
        model_state_dict[name].copy_(param)

# Smoke Run

In [None]:
# Pull a single batch from the loader for a quick smoke run.
batch = next(iter(data_loader))

In [None]:
class Molecule:
    """Plain container for the per-batch matrices handed to the model.

    Positional order is distances first, then adjacency, then edge attributes.
    """

    def __init__(self, distances_matrix, adj_matrix, edges_att):
        # Store the three inputs unchanged under attributes of the same name.
        self.distances_matrix, self.adj_matrix, self.edges_att = (
            distances_matrix,
            adj_matrix,
            edges_att,
        )

In [None]:
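# Disabled smoke run: a single forward pass on `batch` to check that the model runs.
# Molecule's positional order is (distances_matrix, adj_matrix, edges_att), so keywords are used.
# adjacency_matrix, node_features, distance_matrix, y = batch
# batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
# output = model(node_features, node_features, batch_mask, batch_mask,
#                Molecule(distances_matrix=distance_matrix, adj_matrix=adjacency_matrix, edges_att=None))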

# Optimizer

In [None]:
class NoamOpt:
    """Optimizer wrapper that sets the learning rate from a warm-up schedule.

    The rate applied at step ``s`` is::

        factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)

    so it grows linearly for ``warmup`` steps and then decays with the inverse
    square root of the step number.

    Args:
        model_size: Model dimensionality used to scale the rate.
        factor: Constant multiplier applied to the whole schedule.
        warmup: Number of warm-up steps; must be positive.
        optimizer: Wrapped optimizer; it must expose ``param_groups`` and
            ``step()`` (and ``zero_grad()`` if ``zero_grad`` is used).
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance the schedule by one step, set the new rate and update parameters."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (defaults to the current step).

        Step 0, i.e. before the first call to ``step()``, is treated as step 1,
        because ``0 ** -0.5`` would otherwise raise ``ZeroDivisionError``.
        """
        if step is None:
            step = self._step
        if step == 0:
            step = 1
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )
        
def get_std_opt(model):
    """Return a NoamOpt (factor 2, 4000 warm-up steps) wrapping Adam for ``model``.

    The model size is read from ``model.src_embed[0].d_model``.
    """
    model_size = model.src_embed[0].d_model
    adam = torch.optim.Adam(
        model.parameters(),
        lr=0,
        betas=(0.9, 0.98),
        eps=1e-9,
    )
    return NoamOpt(model_size, 2, 4000, adam)

# TRAINING

In [None]:
# Train the model to reconstruct its input node features under an MSE objective.
criterion = torch.nn.MSELoss()
# NOTE(review): only opt.optimizer is used below, so NoamOpt's warm-up schedule
# is bypassed and Adam runs with the fixed lr given here. Confirm this is intended.
opt = NoamOpt(1024, 1, 400,
              torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), eps=1e-9))

model.train()
for epoch in range(10):
    running_loss = 0
    for i, batch in enumerate(data_loader):
        adjacency_matrix, node_features, distance_matrix, y = batch
        # True where a row of node features has any non-zero entry
        # (all-zero rows are presumably padding - TODO confirm).
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0

        opt.optimizer.zero_grad()
        # Molecule's positional order is (distances_matrix, adj_matrix, edges_att);
        # keywords keep each matrix in its matching attribute.
        output = model(node_features, node_features, batch_mask, batch_mask,
                       Molecule(distances_matrix=distance_matrix,
                                adj_matrix=adjacency_matrix,
                                edges_att=None))

        # MSELoss takes (input, target): the prediction goes first.
        loss = criterion(output, node_features)
        loss.backward()
        opt.optimizer.step()
        running_loss += loss.item()
        # Running mean of the loss over the batches seen so far in this epoch.
        mean_loss = running_loss/(i+1)
        print(mean_loss)

# Gumbel Softmax

In [None]:
import torch.nn.functional as F

def loss_function(recon_x, x, qy, categorical_dim=1024):
    """Return the reconstruction cross-entropy plus a KL term on ``qy``.

    Args:
        recon_x: Logits passed to ``F.cross_entropy`` as the input.
        x: Targets passed to ``F.cross_entropy``; its first dimension is used
            to average the summed loss.
        qy: Categorical probabilities; the KL term is summed over the last
            dimension and averaged over the remaining ones.
        categorical_dim: Number of categories of the uniform prior that ``qy``
            is compared against. Defaults to 1024, the previously hard-coded value.

    Returns:
        Scalar tensor ``CE + KLD``.
    """
    # reduction='sum' is the supported spelling of the deprecated size_average=False.
    ce = F.cross_entropy(recon_x, x, reduction='sum') / x.shape[0]

    # KL(q || uniform) = sum(q * log(q * K)); the epsilon keeps log finite at q == 0.
    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    kld = torch.sum(qy * log_ratio, dim=-1).mean()

    return ce + kld

In [None]:
# Train the model with the Gumbel-softmax path enabled, using loss_function
# (cross-entropy plus KL term) as the objective.
# NOTE(review): only opt.optimizer is used below, so NoamOpt's warm-up schedule
# is bypassed and Adam runs with the fixed lr given here. Confirm this is intended.
opt = NoamOpt(1024, 1, 400,
              torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.98), eps=1e-9))

model.train()
for epoch in range(10):
    running_loss = 0
    for i, batch in enumerate(data_loader):
        adjacency_matrix, node_features, distance_matrix, y = batch
        # True where a row of node features has any non-zero entry
        # (all-zero rows are presumably padding - TODO confirm).
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        opt.optimizer.zero_grad()

        # Molecule's positional order is (distances_matrix, adj_matrix, edges_att);
        # keywords keep each matrix in its matching attribute.
        output, qy = model(node_features, node_features, batch_mask, batch_mask,
                           Molecule(distances_matrix=distance_matrix,
                                    adj_matrix=adjacency_matrix,
                                    edges_att=None),
                           gumbel=True, hard=True)

        loss = loss_function(output, node_features, qy)
        loss.backward()

        opt.optimizer.step()
        running_loss += loss.item()
        # Running mean of the loss over the batches seen so far in this epoch.
        mean_loss = running_loss/(i+1)
        print(mean_loss)