In [1]:
from src.rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator, get_default_model_and_tokenizer
)
from transformers.models import BertModel
import torch

In [2]:

model, tokenizer = get_default_model_and_tokenizer()
rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)


In [11]:
old = model.state_dict()['encoder.layer.3.attention.self.query.weight']
old

tensor([[ 0.0552, -0.0137, -0.0009,  ...,  0.0221, -0.0173, -0.0106],
        [ 0.0032,  0.0318, -0.0359,  ..., -0.0893, -0.0229, -0.0698],
        [-0.0468,  0.0007, -0.0129,  ...,  0.0236, -0.0375,  0.0695],
        ...,
        [ 0.0510, -0.0214,  0.0567,  ..., -0.0109, -0.0440,  0.0296],
        [-0.0690,  0.0290, -0.0417,  ...,  0.0092,  0.0234,  0.0522],
        [ 0.0409,  0.0884, -0.0115,  ...,  0.0084,  0.0520,  0.0493]])

In [5]:

rxn = "CC=O>>CCO"
rxns = [
    "CC=O>>CCO",
    "CCC=O>>CCCO",
    "CCNC=O>>CCNCO",
]

In [6]:
tokenizer.encode_plus(
    rxn,
    max_length=model.config.max_position_embeddings,
    padding=True, truncation=True, return_tensors='pt'
)

{'input_ids': tensor([[12, 16, 16, 22, 19, 29, 16, 16, 19, 13]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [7]:
batched_output = tokenizer.batch_encode_plus(
    rxns,
    max_length=model.config.max_position_embeddings,
    padding=True, truncation=True, return_tensors='pt'
)
batched_output

{'input_ids': tensor([[12, 16, 16, 22, 19, 29, 16, 16, 19, 13,  0,  0,  0,  0],
        [12, 16, 16, 16, 22, 19, 29, 16, 16, 16, 19, 13,  0,  0],
        [12, 16, 16, 23, 16, 22, 19, 29, 16, 16, 23, 16, 19, 13]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [8]:
batched_output['input_ids']

tensor([[12, 16, 16, 22, 19, 29, 16, 16, 19, 13,  0,  0,  0,  0],
        [12, 16, 16, 16, 22, 19, 29, 16, 16, 16, 19, 13,  0,  0],
        [12, 16, 16, 23, 16, 22, 19, 29, 16, 16, 23, 16, 19, 13]])

In [9]:
print(batched_output['input_ids'].shape)
print(batched_output['token_type_ids'].shape)
print(batched_output['attention_mask'].shape)

torch.Size([3, 14])
torch.Size([3, 14])
torch.Size([3, 14])


In [17]:
model(**batched_output)['last_hidden_state'][:, 0, :].shape

torch.Size([3, 256])

In [14]:
fps = rxnfp_generator.convert(rxns)

In [11]:
fp = rxnfp_generator.convert(rxn)

print(rxn)
print(fp)

CC=O>>CCO
[-1.6542866230010986, -1.7284846305847168, -1.2616347074508667, -1.7588913440704346, 1.160961389541626, 1.9029163122177124, 0.9351949691772461, -1.1432665586471558, -0.4869869649410248, -0.8294594883918762, -1.6161038875579834, 1.095350742340088, -1.8773101568222046, 1.2194468975067139, -1.4543977975845337, 0.5810214877128601, 0.2639060318470001, 0.22561800479888916, -2.2959578037261963, -0.8097139596939087, -2.57768177986145, 0.6789469122886658, -0.5541850924491882, 1.8056433200836182, -0.7179451584815979, -1.9067778587341309, -1.300304889678955, 0.9869235754013062, 0.655123233795166, 0.7017709016799927, 0.11212652176618576, -0.46746811270713806, 1.1277308464050293, 0.09870129078626633, 0.7051181197166443, 1.8312512636184692, -3.2643401622772217, -0.13099075853824615, 0.1566365361213684, -0.8676517605781555, 0.29305192828178406, -0.486428439617157, -3.0421721935272217, 1.3955600261688232, -2.0653634071350098, -1.451167345046997, 1.5111931562423706, -0.8796708583831787, -0.39

In [12]:
fn = "/home/stef/quest_data/hiec/results/runs/686395449532405019/47aa7b669ede4dad8125ee671d728b9a/checkpoints/best-checkpoint-02-val_roc-0.500.ckpt"
ckpt = torch.load(fn, weights_only=False)

In [19]:
new = ckpt['state_dict']['reaction_encoder.model.encoder.layer.3.attention.self.query.weight']

In [23]:
abs(old - new).mean()

tensor(0.0033)

In [22]:
from src.nn import BertRxnEncoder
from src.model import BertTwoChannel

In [None]:
re = BertRxnEncoder(model, d_rxn=256, d_h=132)
wrap_model = BertTwoChannel(
    d_prot=1280,
    d_h=132,
    reaction_encoder=re,
    predictor
)