In [26]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [68]:
import argparse
import random
import json

import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from data_loader import DataGenerator
from tqdm import tqdm
from transformers import BertConfig
from bert_utils import SmilesBertModel

k = 5
batch_size = 3

# Create Data Generator
train_iterable = DataGenerator(
    data_json_path=f'data/train.json',
    k=k,
    repr="smiles_only",
)
train_loader = iter(
    torch.utils.data.DataLoader(
        train_iterable,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=True,
    )
)

repr_to_input_dims = {"smiles_only": 2 + 767, 
                        "concat": 2 + 767 + 640,
                        "concat_smiles_vaeprot": 2 + 767 + 100}


In [84]:
model_config = BertConfig(
    max_position_embeddings = (k+1)*2,
    hidden_size = 2 + 767,
    num_hidden_layers = 4,
    num_attention_heads = 1,
    intermediate_size = 64,
    classifier_dropout = 0.3,
    attention_probs_dropout_prob = 0.3,
    hidden_dropout_prob = 0.3,
    k = 5,
    batch_size = 3,
)

# Create model
model = SmilesBertModel(model_config)
# model.to(device)

# Create optimizer
optim = torch.optim.Adam(model.parameters(), lr=1e-5)

attention_mask = torch.ones((batch_size, (k+1)*2))

for step in tqdm(range(100)):
    i, l = next(train_loader)
    loss = model(i.float(), l.float(), attention_mask)
    loss.backward()
    optim.step()
    optim.zero_grad()

    if (step+1) % 10 == 0: 
        print(loss)

 16%|█▌        | 16/100 [00:00<00:02, 31.19it/s]

tensor(0.1557, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 23%|██▎       | 23/100 [00:00<00:02, 29.37it/s]

tensor(0.0523, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 35%|███▌      | 35/100 [00:01<00:02, 31.95it/s]

tensor(0.0232, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 43%|████▎     | 43/100 [00:01<00:01, 32.39it/s]

tensor(0.0153, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 55%|█████▌    | 55/100 [00:01<00:01, 33.09it/s]

tensor(0.0101, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 63%|██████▎   | 63/100 [00:02<00:01, 33.09it/s]

tensor(0.0097, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 75%|███████▌  | 75/100 [00:02<00:00, 32.29it/s]

tensor(0.0066, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 83%|████████▎ | 83/100 [00:02<00:00, 32.38it/s]

tensor(0.0062, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


 95%|█████████▌| 95/100 [00:02<00:00, 33.49it/s]

tensor(0.0048, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


100%|██████████| 100/100 [00:03<00:00, 31.98it/s]

tensor(0.0050, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)





In [38]:
def process(input_images, input_labels):
    # 1. Zero out query set labels
    input_labels = input_labels.copy() # .clone()
    input_labels[:, -1, :, :] = 0

    # 2. Concatenate labels to images
    input_images_and_labels = torch.cat((input_images, input_labels), -1)

    # 3. Reshape
    B, K_1, N, D = input_images_and_labels.shape
    input_images_and_labels = input_images_and_labels.reshape((B, -1, D))
    return input_images_and_labels

In [43]:
for batch in test_iterable:
    smiles, labels = batch
    break

In [46]:
import numpy as np

In [47]:
smiles[-1, :, :] = 0
smiles_and_labels = np.concatenate((smiles, labels), -1)

In [None]:
# (batch, seq_len, hidden_dim) -> (batch, seq_len, hidden_dim)

In [6]:
import torch
from torch.utils.data import DataLoader
import pickle
from transformers import BertConfig
from tqdm import tqdm
from typing import Dict, Any
from utils import MoleculeDataset, SmilesBertModel

In [None]:
def run_pipeline(config: Dict[str, Any]):
    # Load training set
    train_dataset = MoleculeDataset(config["train_path"])
    train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)

    # Load validation set
    val_dataset = MoleculeDataset(config["val_path"])
    val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"])

    # Check if model was provided, in the case of loading the pretrained model
    if "load_model_path" in config:
        model = torch.load(config["load_model_path"])
    else:
        # Load vocab
        with open("preprocessed/vocab.pickle", "rb") as f:
            vocab = pickle.load(f)
        
        # Set model configurations
        model_config = BertConfig(
            pad_token_id = vocab["PAD"],
            vocab_size = len(vocab),
            max_position_embeddings = 128,
            hidden_size = 64,
            num_hidden_layers = 4,
            num_attention_heads = 4,
            intermediate_size = 64,
            classifier_dropout = config["classifier_dropout"],
            attention_probs_dropout_prob = config["attention_probs_dropout_prob"],
            hidden_dropout_prob = config["hidden_dropout_prob"],
    )

    # Create model
    model = SmilesBertModel(model_config)

# Use Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

# Training loop
best_val_loss = float("inf")
for epoch in range(config["epochs"]):
    running_train_loss = 0
    for batch in tqdm(train_dataloader):
        loss, _ = model(**batch)
        running_train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    # Log training loss
    print(f"epoch {epoch}, train MSE loss {running_train_loss / len(train_dataloader)}")

    # Evaluate on validation set
    model.eval()
    val_loss = 0
    with torch.no_grad():
        # Compute average MSE loss
        for batch in val_dataloader:
            loss, _ = model(**batch)
            val_loss += loss
        avg_val_loss = val_loss / len(val_dataloader)

        # Save best model
        if avg_val_loss < best_val_loss:
            torch.save(model, config["model_save_path"])
        print(f"epoch {epoch}, val MSE loss {avg_val_loss}")
    model.train()
        