In [1]:
!pip install huggingface_hub
!pip install datasets



In [2]:
import os
import pickle 
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from argparse import Namespace
from pathlib import Path
from torch_geometric.data import Batch


# For Graph Encoding
from torch_geometric.data import Data as PyGData
from torch_geometric.loader import DataLoader as PyGDataLoader

import pickle
from datasets import load_from_disk


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
from huggingface_hub import hf_hub_download


# Get the token from the environment variable
hf_token = os.getenv("HF_TOKEN")

# Set repository details
repo_id = "OpenMol/PubChemSFT"  # Repository ID
filename = "train.pkl"          # Path to the file in the repository

# Optional: Authenticate if the repository is private
from huggingface_hub import login
login(token=hf_token)  # Replace with your actual token
local_dir = "/Users/smsultanmahmudrahat/Downloads/open_source/code"

# Download the file
try:
    hf_file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", local_dir= local_dir)
    print(f"File downloaded to: {hf_file_path}")
except Exception as e:
    print(f"Error: {e}")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


File downloaded to: /Users/smsultanmahmudrahat/Downloads/open_source/code/train.pkl


In [4]:
from huggingface_hub import hf_hub_download

# Get the token from the environment variable
hf_token = os.getenv("HF_TOKEN")

# Repository and file details
repo_id = "chao1224/MoleculeSTM"  # Repository name
filename = "demo/demo_checkpoints_SMILES/molecule_model_final.pth"  # Path to the file in the repo

# Optional: Authenticate if the repository is private
from huggingface_hub import login
login(token=hf_token)  
local_dir = "/Users/smsultanmahmudrahat/Downloads/open_source/code"


# Download file to memory
file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model", local_dir= local_dir)

# Load the PyTorch model directly from the file
import torch
# if GPU is available
# model_state_dict = torch.load(file_path) 
model_state_dict = torch.load(file_path, map_location=torch.device('cpu'))
print("Model loaded successfully!")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Model loaded successfully!


  model_state_dict = torch.load(file_path, map_location=torch.device('cpu'))


In [5]:
model_state_dict.keys()

odict_keys(['pos_emb', 'emb.weight', 'encoder.layers.0.self_attn.query_key_value.weight', 'encoder.layers.0.self_attn.query_key_value.bias', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.q_proj.bias', 'encoder.layers.0.self_attn.key_value.weight', 'encoder.layers.0.self_attn.key_value.bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.self_attn_layer_norm.bias', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc1.bias', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.fc2.bias', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.0.final_layer_norm.bias', 'encoder.layers.1.self_attn.query_key_value.weight', 'encoder.layers.1.self_attn.query_key_value.bias', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.q_proj.bias', 'encoder.layers.1.self_attn.key_value.weight', 'encoder.layers.1.self_attn.key_value.bias', 'encoder.

In [6]:
type(model_state_dict)

collections.OrderedDict

In [7]:
from torch.utils.data import Dataset
import pickle
from torch_geometric.data import Data as PyGData
from datasets import load_from_disk
import torch

class GraphTextDataset(Dataset):
    def __init__(self, hf_dataset):
        # Load graph data
        with open(hf_dataset, "rb") as f:
            self.graph_data = pickle.load(f)

        # Load text data from HuggingFace's dataset
        # self.text_data = load_from_disk(graph_path)
        print(f"Length of graph_data: {len(self.graph_data)}")

        # Optionally, ensure lengths match
        # assert len(self.graph_data) == len(self.text_data), "Graph and text data lengths do not match!"

    def __len__(self):
        # Return the length of the dataset
        return len(self.graph_data)

    def __getitem__(self, idx):
        # Fetch graph data
        item = self.graph_data[idx]
        gdict = item["graph"]

        edge_index = torch.tensor(gdict["edge_index"], dtype=torch.long)
        node_feat = torch.tensor(gdict["node_feat"], dtype=torch.long)

        # Handle edge attributes
        edge_attr = gdict.get("edge_attr", None)
        if edge_attr is not None:
            edge_attr = torch.tensor(edge_attr, dtype=torch.long)
        else:
            num_edges = edge_index.size(1)
            edge_attr = torch.zeros((num_edges, 2), dtype=torch.long)

        # Batch information (dummy for now)
        batch = torch.zeros((node_feat.size(0),), dtype=torch.long)

        # Create PyGData object
        pyg_graph = PyGData(
            x=node_feat,
            edge_attr=edge_attr,
            edge_index=edge_index,
            batch=batch,
        )

        # Fetch text data
        text = self.graph_data[idx]["answer"]

        return pyg_graph, text

In [8]:
dataset = GraphTextDataset(hf_file_path)
len(dataset)

Length of graph_data: 264391


264391

In [9]:
args = Namespace(
    model_name_or_path="lmsys/vicuna-7b-v1.3",
    # graph_data_path= graph_path,  # Change to your dataset path
    hf_data_path = hf_file_path,
    output_dir="/Users/smsultanmahmudrahat/Downloads/open_source/code",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    learning_rate=2e-3,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
device = args.device

########################################
# Replace with correct MoleculeSTM parameters as needed
########################################
NUM_LAYER = 4       # Number of GNN layers (adjust if known)
EMB_DIM = 300        # Embedding dimension used by MoleculeSTM (adjust if known)
JK = "last"          # The JK-connection mode used (adjust if known)
GRAPH_POOLING = "mean"  # Pooling mode: sum/mean/max (adjust if known)
# check smile-> grpahPath
INIT_CHECKPOINT = Path("/Users/smsultanmahmudrahat/Downloads/open_source/code/pickel_files/molecule_model.pth") # Path to molecule_model.pth


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (MessagePassing, global_add_pool,
                                global_max_pool, global_mean_pool)
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import add_self_loops, softmax, degree
from torch_scatter import scatter_add
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from collections import OrderedDict


        
        
class GINConv(MessagePassing):
    def __init__(self, emb_dim, aggr="add"):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr=aggr)

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        # WARN: some weird thing happend if excute in bfloat16, so we force to cast to float32
        dtype = x.dtype
        inter = (1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        if dtype == torch.bfloat16:
            inter = inter.float()
            out = self.mlp.float()(inter)
            out = out.to(dtype)
        else:
            out = self.mlp(inter)
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out



class GCNConv(MessagePassing):
    def __init__(self, emb_dim, aggr="add"):
        super(GCNConv, self).__init__(aggr=aggr)

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


class GNN(nn.Module):
    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"):

        if num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        super(GNN, self).__init__()
        self.drop_ratio = drop_ratio
        self.num_layer = num_layer
        self.JK = JK

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of MLPs
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr="add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    # def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.atom_encoder(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]
        else:
            raise ValueError("not implemented.")
        return node_representation


class GNN_graphpred(nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        arg.emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536 """

    def __init__(
        self, 
        emb_dim,  
        graph_pooling, 
        projection_dim:int=None,
        molecule_node_model=None,
        init_checkpoint=None,
    ):
        super(GNN_graphpred, self).__init__()

        self.molecule_node_model = molecule_node_model
        self.emb_dim = emb_dim

        # Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise ValueError("Invalid graph pooling type.")
        
        if projection_dim is not None:
            self.projector = nn.Linear(emb_dim, projection_dim)
            self.output_dim = projection_dim
        else:
            self.projector = None
            self.output_dim = emb_dim
        
        if init_checkpoint is not None:
            self._load_state_dict(init_checkpoint, strict=False)

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.molecule_node_model(x, edge_index, edge_attr)
        graph_representation = self.pool(node_representation, batch)
        return graph_representation, node_representation
    
    def encode_mol(self, mol, proj=False, return_node_feats=False, eval=True):
        if eval:
            self.molecule_node_model.eval() # hard code: set to eval mode
            with torch.no_grad():
                h_graph, h_node = self.forward(mol)
        else:
            self.molecule_node_model.train() # set to train mode
            h_graph, h_node = self.forward(mol)
        if proj and self.projector is not None:
            h_graph = self.projector(h_graph)
            h_node = self.projector(h_node)
        if return_node_feats:
            return h_graph, h_node
        else:
            return h_graph
    
    # def _load_state_dict(self, model_file, strict=False):
    #     print("Loading from {} ...".format(model_file))
    #     state_dict = torch.load(model_file, map_location=torch.device('cpu'))
    #     self.load_state_dict(state_dict, strict=strict)
    #     return
    
    def _load_state_dict(self, model_file, strict=False):
        print(f"Loading from {model_file} ...")
        state_dict = torch.load(model_file, map_location=torch.device('cpu'))
        incompatible_keys = self.load_state_dict(state_dict, strict=strict)
        print(f"Missing keys: {incompatible_keys.missing_keys}")
        print(f"Unexpected keys: {incompatible_keys.unexpected_keys}")
        return
    
    @property
    def dummy_feature(self):
        return self.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
    
    @property
    def hidden_size(self):
        return self.output_dim

In [11]:
from transformers import AutoModel, AutoTokenizer

model_name_or_path = "lmsys/vicuna-7b-v1.3"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
llm_model = AutoModel.from_pretrained(model_name_or_path) 

# now we need to freeze the llm model 
llm_model.eval().requires_grad_(False)

llm_hidden = llm_model.config.hidden_size
print(f"The hidden size of the model is: {llm_hidden}")

# it will be used to predict next node 
# it processes molecules and produces graph embedding 
molecule_node_model = GNN(
	num_layer= NUM_LAYER, emb_dim= EMB_DIM, JK= JK, drop_ratio = 0, gnn_type="gin")

# now predict the graph embedding because it will use pooling for it. 
# two GNN architecture is design independantly for better design flexibility. 

graph_encoder = GNN_graphpred(
	emb_dim= EMB_DIM,
	graph_pooling = "mean",
	# if we want to match it with the Text emb hidden layer, we could use it. 
	# but we have used nn.Linear() for projecting the emb to better alignment. 
	# so, we will not project anything.So, emb_dim and output er dimension same thakbe.
	projection_dim = None,
	# node level pred k ei layer e integrate korbo.
	molecule_node_model = molecule_node_model,
	# load the weights in same architecture
	init_checkpoint = INIT_CHECKPOINT)

graph_encoder.to(device)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
Loading checkpoint shards: 100%|██████████| 2/2 [00:41<00:00, 20.65s/it]

The hidden size of the model is: 4096
Loading from /Users/smsultanmahmudrahat/Downloads/open_source/code/pickel_files/molecule_model.pth ...
Missing keys: ['molecule_node_model.atom_encoder.atom_embedding_list.0.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.1.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.2.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.3.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.4.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.5.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.6.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.7.weight', 'molecule_node_model.atom_encoder.atom_embedding_list.8.weight', 'molecule_node_model.gnns.0.eps', 'molecule_node_model.gnns.0.mlp.0.weight', 'molecule_node_model.gnns.0.mlp.0.bias', 'molecule_node_model.gnns.0.mlp.1.weight', 'molecule_node_model.gnns.0.mlp.1.bias', 'molecule_node_model.gnns.0.mlp.1.running_mean', 'molecul


  state_dict = torch.load(model_file, map_location=torch.device('cpu'))


GNN_graphpred(
  (molecule_node_model): GNN(
    (atom_encoder): AtomEncoder(
      (atom_embedding_list): ModuleList(
        (0): Embedding(119, 300)
        (1): Embedding(5, 300)
        (2-3): 2 x Embedding(12, 300)
        (4): Embedding(10, 300)
        (5-6): 2 x Embedding(6, 300)
        (7-8): 2 x Embedding(2, 300)
      )
    )
    (gnns): ModuleList(
      (0-3): 4 x GINConv()
    )
    (batch_norms): ModuleList(
      (0-3): 4 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

In [12]:
def collate_fn (batch):
	# 1st element of the graph 
	graphs = [b[0] for b in batch]
	# 2nd element of the batch is texts 
	texts = [b[1] for b in batch]
	return Batch.from_data_list(graphs), texts 

def alignment_loss(graph_emb, text_emb):
	
	# before cosine similarity both emb should be normalized 
	graph = F.normalize(graph_emb, p= 2, dim=-1)
	text = F.normalize (text_emb,p=2, dim =-1)
	
	# similarity between graph and text 
	sim = graph @ text.T 
	batch_size = sim.size(0)
	
	target = torch.arange(batch_size, dtype= torch.long, device= sim.device)
	
	# now find out the entropy or loss function for text and graph 
	
	loss_g2t = F.cross_entropy(F.log_softmax(sim,dim= -1), target) 
	loss_t2g = F.cross_entropy(F.log_softmax(sim.T,dim= -1), target)
	
	loss = (loss_g2t+loss_t2g)/2.0
	
	return loss 

# now it will help 
data_loader = PyGDataLoader(
	dataset = dataset, 
	batch_size = 8, 
	shuffle = True, collate_fn = collate_fn, num_workers=0)


projector = nn.Linear(graph_encoder.output_dim, llm_hidden).to(device)
# optimize it
optimizer = torch.optim.Adam(params = projector.parameters(), lr= args.learning_rate, weight_decay= 0.)


In [None]:
def get_text_embedding(text_list):
    encoding = tokenizer(text_list, return_tensors = "pt", padding = True, truncation= True,max_length = 128)
    # return of encoding is a dict. {"input_ids" :... ,"attention_maskk" : 1,0,1... }
    token_id = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    
    # torch.no_grad() means no backprop, no weight update
    # gradient update is important in training.
    # but torch.no_grad() is used in inference cause
    # no gradient update is needed in that stage. 
    # and that's important not to gradient update.
    with torch.no_grad():
        # we input token_ID as text as a token value like 12,23 etc. 
        # and attention_mask inside the model
        # Input sequence: "Hello, how are you?" (length 5)
        # Padded sequence: "Hello, how are you? [PAD] [PAD]" (length 7)
        # [1, 1, 1, 1, 1, 0, 0] is the attention mask to padding. 
        # padding helps us to eliminate data leakage? NO! it prevents model not to calculate 
        # [PAD] token otherwise it will be used in query attention generation.

        output = llm_model(token_id, attention_mask)
        # shape of output is: [batch_size, seq_length, hidden_state]
        
        # [2,5,128] means 2 sentences and each sentences have 5 token and each token has 128 dim vector. 
        # mean(dim=1) we are geting one embedding for each sentences of token 
        # [batch_size, seq_length, hidden_state] converts into [batch_size, hidden_state] 
        text_emb = output.last_hidden_state.mean(dim= 1)
    return text_emb

#####################################
# training method 
############################
for epoch in range(args.num_train_epochs):
    projector.train()
    
    for step, (batch_graph, text) in enumerate (data_loader):
        
        # all graphs are distributed to the devices
        batch_graphs = batch_graph.to(device) 

        # print the size of the batch_graph.x 
        print(f" Size of the batch_graph.x: {batch_graph.x.size()}")
        # number of total graph in a batch 
        # print(f"Number of graphs in the batch: {batch_graph.num_graphs}")
        # we are making a graph_encoder so we are not update gradient for it. 
        # LLM and graph encoder will be assigned with their pretrained weight and they will be frozen 
        # during the training process. so we will not update the gradient of LLM and GNN. We will only 
        # update the model with projector.train() 
        with torch.no_grad():
            # make a dummpy input so that we can text whether the 
            # graph archi is working or not. 
            # graph_emb, _ = graph_encoder(*[torch.zeros(1, dtypes= torch.long).to(device)])
            
            # if the dummy works, then we will generate actual graph_encoder 
            
            graph_emb, _ = graph_encoder(batch_graph)
            
            
        # now we will load LLM mode to generate text. we will not update their trainign weight 
        with torch.no_grad():
            text_emb = get_text_embedding(text)

        proj_g = projector(graph_emb)
        
        loss = alignment_loss(proj_g,text_emb)
        
        loss.backward()
        
        optimizer.step()
        
        # Print loss every few steps
        
        if step %10 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
            
    
    # save projector checkpoint each epoch
    os.makedirs (args.output_dir, exist_ok=True)
    torch.save(projector.state_dict(), os.path.join(args.output_dir, f"projector_epoch{epoch}.pth"))
