In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import CLIPTextModel, CLIPTokenizer

# SMILES Tokenizer

In [None]:
import re

class SMILESTokenizer:
    """Regex-based tokenizer for SMILES strings (originally shared by Markus).

    Splits a SMILES string into tokens (bracket atoms such as ``[NH4+]``,
    ``Br``/``Cl``, single-letter and aromatic atoms, branches, bond symbols,
    ring-closure digits and ``%nn`` ring labels) and maps them to integer ids.

    After :meth:`build_vocab`, ids 0-3 are the special tokens ``<PAD>``,
    ``<UNK>``, ``<START>`` and ``<END>`` (in that order); the remaining tokens
    follow in sorted order.

    Note: characters the pattern does not match are silently skipped by
    ``re.findall`` in :meth:`tokenize` (e.g. the ``e`` of an aromatic ``se``
    written outside brackets), so such input does not round-trip.
    """

    def __init__(self):
        # In this raw string, ``\\`` is the regex for ONE literal backslash
        # (the stereo "down" bond, as in ``F/C=C\F``). Writing ``\\\\`` here
        # would only match a doubled backslash and drop real stereo bonds.
        self.pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\!|\$|\%[0-9]{2}|[0-9])"
        self.vocab = {}
        self.inv_vocab = {}
        self.pad_token = '<PAD>'
        self.unk_token = '<UNK>'
        self.start_token = '<START>'
        self.end_token = '<END>'
        self.max_len = None  # not read by any method of this class

    def tokenize(self, smiles):
        """Tokenizes a SMILES string using the predefined regular expression."""
        return re.findall(self.pattern, smiles)

    def build_vocab(self, smiles_list):
        """Builds ``vocab``/``inv_vocab`` from an iterable of SMILES strings.

        Special tokens take ids 0-3; every distinct token found in
        ``smiles_list`` follows in sorted order. Any previous vocabulary is
        replaced.
        """
        seen = set()
        for smiles in smiles_list:
            seen.update(self.tokenize(smiles))
        specials = [self.pad_token, self.unk_token, self.start_token, self.end_token]
        ordered = specials + sorted(seen)
        self.vocab = {token: idx for idx, token in enumerate(ordered)}
        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}

    def encode(self, smiles, max_len=None):
        """Encodes a SMILES string into token ids wrapped in START/END.

        Tokens missing from the vocabulary map to the UNK id. If ``max_len`` is
        truthy the ids are truncated or right-padded with the PAD id to exactly
        ``max_len`` entries (truncation can cut off the END token).
        """
        unk_id = self.vocab[self.unk_token]
        tokens = [self.start_token] + self.tokenize(smiles) + [self.end_token]
        token_ids = [self.vocab.get(token, unk_id) for token in tokens]
        if max_len:
            token_ids = self.pad_sequence(token_ids, max_len)
        return token_ids

    def decode(self, token_ids):
        """Decodes token ids back into a SMILES string.

        START/END/PAD tokens are removed; ids missing from ``inv_vocab`` are
        rendered as the literal ``<UNK>`` token.
        """
        special = {self.start_token, self.end_token, self.pad_token}
        tokens = (self.inv_vocab.get(token_id, self.unk_token) for token_id in token_ids)
        return ''.join(token for token in tokens if token not in special)

    def vocab_size(self):
        """Returns the size of the vocabulary."""
        return len(self.vocab)

    def pad_sequence(self, sequence, max_len):
        """Truncates or right-pads ``sequence`` with the PAD id to exactly ``max_len`` items."""
        return sequence[:max_len] + [self.vocab[self.pad_token]] * max(0, max_len - len(sequence))

In [None]:

class SMILESEncoder(nn.Module):
    def __init__(self, tokenizer, embed_dim, max_len=128):
        super(SMILESEncoder, self).__init__()
        self.tokenizer = tokenizer
        self.embed_dim = embed_dim
        self.max_len = max_len
        
        # Embedding layer
        self.embedding = nn.Embedding(self.tokenizer.vocab_size(), embed_dim)
        
        # Transformer encoder for SMILES
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8),
            num_layers=6
        )

    def forward(self, smiles):
        # Tokenize and encode the SMILES strings
        encoded = [self.tokenizer.encode(s, max_len=self.max_len) for s in smiles]
        padded = torch.tensor(encoded).to(next(self.parameters()).device)

        # Pass through embedding layer
        embeddings = self.embedding(padded)

        # Process embeddings through the transformer
        embeddings = embeddings.permute(1, 0, 2)  # Transformer expects (seq_len, batch_size, embed_dim)
        transformed = self.transformer(embeddings)

        # Pooling: Use the mean of the sequence outputs
        pooled_output = transformed.mean(dim=0)
        return pooled_output

class SMILESCLIP(nn.Module):
    """CLIP-style dual encoder that aligns SMILES strings with text descriptions.

    Args:
        smiles_encoder: module mapping a list of SMILES to features of width
            ``smiles_encoder.embed_dim``.
        text_encoder: text model (e.g. ``CLIPTextModel``) with ``config.hidden_size``
            whose output supports ``["pooler_output"]``.
        projection_dim: width of the shared embedding space.
        text_tokenizer: optional callable turning a list of strings into model inputs
            (e.g. a ``CLIPTokenizer``). If omitted and raw strings reach ``forward``, a
            ``CLIPTokenizer`` for ``text_encoder``'s checkpoint is loaded on first use.
    """

    def __init__(self, smiles_encoder, text_encoder, projection_dim, text_tokenizer=None):
        super(SMILESCLIP, self).__init__()
        self.smiles_encoder = smiles_encoder
        self.text_encoder = text_encoder
        self.text_tokenizer = text_tokenizer

        # Projection heads
        self.smiles_projection = nn.Linear(smiles_encoder.embed_dim, projection_dim)
        self.text_projection = nn.Linear(text_encoder.config.hidden_size, projection_dim)

        # Learnable inverse temperature kept in log space, initialised to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.tensor(1.0 / 0.07).log())

    def _encode_text(self, text):
        """Run the text encoder on raw strings, a tokenizer-output mapping, or input ids."""
        from collections.abc import Mapping

        if isinstance(text, str):
            text = [text]
        if isinstance(text, (list, tuple)):
            # The text model consumes token ids, not raw strings.
            if self.text_tokenizer is None:
                self.text_tokenizer = CLIPTokenizer.from_pretrained(
                    self.text_encoder.config.name_or_path
                )
            text = self.text_tokenizer(list(text), padding=True, truncation=True, return_tensors="pt")
        if isinstance(text, Mapping):
            device = self.text_projection.weight.device
            inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in text.items()
            }
            return self.text_encoder(**inputs)
        # Anything else (e.g. a tensor of input ids) is passed through unchanged.
        return self.text_encoder(text)

    def forward(self, smiles, text):
        """Return L2-normalised ``(smiles_embeddings, text_embeddings)``, each ``(batch, projection_dim)``.

        ``text`` may be a list of strings, a tokenizer output mapping, or input ids.
        """
        # Encode SMILES and text
        smiles_features = self.smiles_encoder(smiles)
        text_features = self._encode_text(text)["pooler_output"]

        # Project features to the same dimensional space
        smiles_embeddings = self.smiles_projection(smiles_features)
        text_embeddings = self.text_projection(text_features)

        # Normalize embeddings so their dot product is cosine similarity
        smiles_embeddings = nn.functional.normalize(smiles_embeddings, p=2, dim=1)
        text_embeddings = nn.functional.normalize(text_embeddings, p=2, dim=1)

        return smiles_embeddings, text_embeddings

# Reproducibility: seed weight initialisation and dropout before any model is built.
SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize tokenizer
smiles_tokenizer = SMILESTokenizer()

# Build vocabulary. NOTE: this tiny list stands in for the real training corpus;
# any token absent from it is encoded as <UNK>.
example_smiles_list = ["CCO", "C1=CC=CC=C1", "O=C(O)C(O)"]
smiles_tokenizer.build_vocab(example_smiles_list)

# Initialize models. The vocabulary must be built first: the SMILES encoder sizes
# its embedding table from it.
smiles_encoder = SMILESEncoder(smiles_tokenizer, embed_dim=512)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Combine into a CLIP-like model
model = SMILESCLIP(smiles_encoder, text_encoder, projection_dim=512).to(DEVICE)

# Loss function and optimizer (created after the model so it sees every parameter)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Example training loop
def train_step(batch, model, optimizer, criterion, temperature=0.07):
    """Run one optimisation step of the symmetric CLIP (InfoNCE) objective.

    Args:
        batch: ``(smiles, text, labels)``. The i-th SMILES is paired with the i-th
            text; ``labels`` is unused because pairing is given by position.
        model: expected to return L2-normalised ``(smiles_embeddings, text_embeddings)``
            of shape ``(batch, dim)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: classification loss such as ``nn.CrossEntropyLoss``.
        temperature: fixed softmax temperature, used only when ``model`` has no
            learnable ``logit_scale`` attribute.

    Returns:
        The loss as a Python float, computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    smiles, text, _labels = batch  # labels unused: pairs are matched by position

    # Forward pass
    smiles_embeddings, text_embeddings = model(smiles, text)

    # Cosine similarities lie in [-1, 1]. Without a scale the softmax is almost
    # uniform and the contrastive loss barely trains, so apply an inverse temperature.
    logit_scale = getattr(model, "logit_scale", None)
    if logit_scale is not None:
        scale = logit_scale.exp().clamp(max=100.0)  # CLIP caps the scale at 100
    else:
        scale = 1.0 / temperature
    logits_per_smiles = scale * (smiles_embeddings @ text_embeddings.T)
    logits_per_text = logits_per_smiles.T

    # Matching pairs sit on the diagonal
    ground_truth = torch.arange(len(smiles), device=logits_per_smiles.device)
    loss = (criterion(logits_per_smiles, ground_truth) + criterion(logits_per_text, ground_truth)) / 2

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()

# Smoke test: a single optimisation step on a two-pair dummy batch.
# Third slot stands in for labels, which train_step does not read.
batch = (
    ["CCO", "C1=CC=CC=C1"],  # SMILES
    ["ethanol", "benzene"],  # matching text descriptions
    None,
)
loss = train_step(batch, model, optimizer, criterion)
print(f"Training loss: {loss}")