In [None]:
# Install necessary packages
!pip install rdkit
!pip install duckdb
!pip install pandas networkx
!pip install torch
!pip install torch-geometric

# Import libraries
import numpy as np 
import pandas as pd 
import duckdb
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv, GINConv
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
import itertools
import networkx as nx
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found below the Kaggle input directory.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for path in (os.path.join(dirname, filename) for filename in filenames):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
train_path = '/kaggle/input/leash-BELKA/train.parquet'
test_path = '/kaggle/input/leash-BELKA/test.csv'

# Building-block SMILES columns are dropped from both frames below.
building_block_columns = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles']

# Draw a class-balanced random sample: 30,000 rows with binds = 0 and
# 30,000 rows with binds = 1, read straight from the parquet file.
con = duckdb.connect()
try:
    df = con.query(f"""(SELECT * FROM parquet_scan('{train_path}') WHERE binds = 0 ORDER BY random() LIMIT 30000) 
                   UNION ALL 
                   (SELECT * FROM parquet_scan('{train_path}') WHERE binds = 1 ORDER BY random() LIMIT 30000)""").df()
finally:
    # Release the connection even when the query raises.
    con.close()

df = df.drop(columns=building_block_columns)

test_df = pd.read_csv(test_path).drop(columns=building_block_columns)
print(test_df.head())

In [None]:
# pd.set_option('display.max_colwidth', None)
# df.sample(n=10)

# import matplotlib.pyplot as plt

# df['binds'].value_counts().plot(kind='bar')
# plt.title('Training Data Distribution')
# plt.show()

In [None]:
# Map the three protein names to integer labels; fit() returns the encoder itself.
protein_encoder = LabelEncoder().fit(['HSA', 'BRD4', 'sEH'])

def smiles_to_graph(smiles, protein):
    """Convert a SMILES string and a protein name into a PyG ``Data`` graph.

    Each atom becomes a node with two features: its atomic number and the
    integer label of ``protein`` (from the module-level ``protein_encoder``),
    the latter repeated on every node. Each bond becomes a pair of directed
    edges, one per direction, so that message passing treats the molecule as
    an undirected graph.

    Args:
        smiles: SMILES string of the molecule.
        protein: Protein name; must be one of the classes known to
            ``protein_encoder``.

    Returns:
        A ``Data`` object with ``x`` of shape ``(num_atoms, 2)`` and
        ``edge_index`` of shape ``(2, 2 * num_bonds)``, or ``None`` when RDKit
        cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    # edge_index is interpreted as a list of directed edges, so every bond is
    # stored in both directions.
    edges = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append((begin, end))
        edges.append((end, begin))

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A molecule without bonds still needs a well-formed (2, 0) edge index;
        # transposing an empty tensor would give shape (0,) instead.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    x = torch.tensor(atomic_numbers, dtype=torch.float).view(-1, 1)
    protein_encoded = protein_encoder.transform([protein])[0]
    protein_features = torch.full((x.size(0), 1), float(protein_encoded), dtype=torch.float)
    x = torch.cat([x, protein_features], dim=1)
    return Data(x=x, edge_index=edge_index)

# Build one graph per row, then keep only rows whose SMILES could be parsed.
# .copy() makes each filtered frame an independent object, so adding columns
# to it later does not trigger pandas' SettingWithCopyWarning.
df['graph'] = df.apply(lambda row: smiles_to_graph(row['molecule_smiles'], row['protein_name']), axis=1)
df = df[df['graph'].notnull()].copy()
test_df['graph'] = test_df.apply(lambda row: smiles_to_graph(row['molecule_smiles'], row['protein_name']), axis=1)
test_df = test_df[test_df['graph'].notnull()].copy()

In [None]:
class MoleculeDataset(Dataset):
    """Dataset over a dataframe that has a 'graph' and a 'binds' column.

    Indexing yields ``(graph, label)`` where ``label`` is the row's 'binds'
    value as a ``torch.long`` tensor.
    """

    def __init__(self, dataframe):
        # The frame is stored as-is; rows are looked up lazily by position.
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        return row['graph'], torch.tensor(row['binds'], dtype=torch.long)

class TestMoleculeDataset(Dataset):
    """Label-free dataset over a dataframe that has a 'graph' column.

    Indexing yields only the graph stored in the row at that position.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        return self.dataframe.iloc[idx]['graph']

# Wrap the prepared frames so they can be served by data loaders.
dataset = MoleculeDataset(df)
test_dataset = TestMoleculeDataset(test_df)

# data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Test batches are served in frame order (no shuffling).
test_data_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [None]:
class GNN(torch.nn.Module):
    """Three-layer graph network (GAT, GCN or GIN) with mean pooling.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of the two hidden layers (per head for GAT).
        output_dim: Number of classes; the model emits exactly this many
            columns per graph for every layer type.
        heads: Number of attention heads (only used when ``layer_type='GAT'``).
        layer_type: One of ``'GAT'``, ``'GCN'`` or ``'GIN'``.
        dropout: Dropout probability applied after each hidden layer.
            Defaults to 0.5, the ``F.dropout`` default.

    Raises:
        ValueError: If ``layer_type`` is not one of the supported values.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=2, layer_type='GAT', dropout=0.5):
        super().__init__()
        if layer_type == 'GAT':
            self.conv1 = GATConv(input_dim, hidden_dim, heads=heads)
            self.conv2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads)
            # concat=False averages the heads, so the last layer produces
            # output_dim columns instead of output_dim * heads.
            self.conv3 = GATConv(hidden_dim * heads, output_dim, heads=heads, concat=False)
        elif layer_type == 'GCN':
            self.conv1 = GCNConv(input_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, hidden_dim)
            self.conv3 = GCNConv(hidden_dim, output_dim)
        elif layer_type == 'GIN':
            nn1 = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
            self.conv1 = GINConv(nn1)
            nn2 = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim))
            self.conv2 = GINConv(nn2)
            nn3 = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, output_dim))
            self.conv3 = GINConv(nn3)
        else:
            # Without this, an unknown type would build a model with no layers
            # and only fail later inside forward().
            raise ValueError(f"Unsupported layer_type {layer_type!r}; expected 'GAT', 'GCN' or 'GIN'")
        self.output_dim = output_dim
        self.dropout = dropout

    def forward(self, data):
        """Return per-graph class probabilities of shape ``(num_graphs, output_dim)``.

        The result is already softmax-normalised, so it is not a tensor of
        logits; a loss that expects logits should not be applied to it
        directly.
        """
        data = data.to(device)
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv3(x, edge_index)
        # Average node embeddings within each graph of the batch.
        x = global_mean_pool(x, batch)
        return F.softmax(x, dim=1)

# Candidate values for each hyperparameter explored by the grid search.
hidden_dims = [16, 32, 64]
learning_rates = [0.01, 0.001, 0.0001]
dropout_rates = [0.3, 0.5]
weight_decays = [1e-4, 1e-5]
layer_types = ['GAT', 'GCN', 'GIN']
heads = [2, 4]

# Full Cartesian product, one tuple per combination, in the order
# (hidden_dim, learning_rate, dropout_rate, weight_decay, layer_type, heads).
hyperparameter_grid = [
    combination
    for combination in itertools.product(
        hidden_dims,
        learning_rates,
        dropout_rates,
        weight_decays,
        layer_types,
        heads,
    )
]

# Shuffled 5-fold splitter with a fixed seed.
kf = KFold(shuffle=True, random_state=42, n_splits=5)

In [None]:
class EarlyStopping:
    """Stop training when the monitored loss has not improved for a while.

    Calling the instance with the latest loss either records an improvement
    (saving the model's ``state_dict`` to ``path``) or increments a counter;
    once the counter reaches ``patience``, ``early_stop`` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: When true, print a message each time a checkpoint is saved.
        path: File the best ``state_dict`` is written to.
    """

    def __init__(self, patience=7, verbose=False, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss``; checkpoint ``model`` if it is not worse than the best so far."""
        # A loss equal to the best still counts as an improvement. A NaN loss
        # compares false here and therefore counts as "no improvement".
        if self.best_loss is None or val_loss <= self.best_loss:
            # Save before updating best_loss so the verbose message can
            # report the previous best next to the new one.
            self.save_checkpoint(val_loss, model)
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path``, optionally logging the change."""
        if self.verbose:
            # No previous best exists on the very first checkpoint.
            previous = float('inf') if self.best_loss is None else self.best_loss
            print(f'Validation loss decreased ({previous:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)

In [None]:
def evaluate(model, data_loader):
    """Return the fraction of graphs in ``data_loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, labels in data_loader:
            data, labels = data.to(device), labels.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += (pred == labels).sum().item()
    accuracy = correct / len(data_loader.dataset)
    return accuracy


def criterion(probs, labels):
    """Negative log-likelihood of ``labels`` under class probabilities ``probs``.

    The model returns softmax probabilities, whereas CrossEntropyLoss expects
    unnormalised logits and would apply a second softmax. Taking the log of
    the (clamped) probabilities and using NLL gives the intended
    cross-entropy.
    """
    return F.nll_loss(torch.log(probs.clamp_min(1e-12)), labels)


best_accuracy = 0
best_params = None

# Grid search: every hyperparameter combination is scored by its mean
# validation accuracy over the K folds.
for params in hyperparameter_grid:
    # NOTE(review): dropout_rate is part of the grid but is not passed to GNN
    # below, so it currently has no effect on the trained model.
    hidden_dim, learning_rate, dropout_rate, weight_decay, layer_type, heads = params
    print(f"Testing combination: Hidden Dim: {hidden_dim}, Learning Rate: {learning_rate}, Dropout Rate: {dropout_rate}, Weight Decay: {weight_decay}, Layer Type: {layer_type}, Heads: {heads}")

    fold_accuracies = []

    for train_idx, val_idx in kf.split(dataset):
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

        # A fresh model, optimizer, scheduler and early-stopping state per fold.
        model = GNN(input_dim=2, hidden_dim=hidden_dim, output_dim=2, heads=heads, layer_type=layer_type).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
        early_stopping = EarlyStopping(patience=10, verbose=True)

        for epoch in range(100):
            # The validation pass below switches to eval mode, so training
            # mode (and with it dropout) must be restored every epoch.
            model.train()
            total_loss = 0
            for data, labels in train_loader:
                data, labels = data.to(device), labels.to(device)
                optimizer.zero_grad()
                out = model(data)
                loss = criterion(out, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            val_loss = 0
            model.eval()
            with torch.no_grad():
                for data, labels in val_loader:
                    data, labels = data.to(device), labels.to(device)
                    out = model(data)
                    loss = criterion(out, labels)
                    val_loss += loss.item()
            val_loss /= len(val_loader)
            scheduler.step(val_loss)
            early_stopping(val_loss, model)

            if early_stopping.early_stop:
                break

        # Score the fold with the best checkpoint, not the last epoch's weights.
        model.load_state_dict(torch.load('checkpoint.pt'))

        val_accuracy = evaluate(model, val_loader)
        fold_accuracies.append(val_accuracy)

    avg_val_accuracy = np.mean(fold_accuracies)
    print(f"Avg Validation Accuracy for params {params}: {avg_val_accuracy}")

    if avg_val_accuracy > best_accuracy:
        best_accuracy = avg_val_accuracy
        best_params = params

print(f"Best Hyperparameters: {best_params} with accuracy: {best_accuracy}")

In [None]:
# Fail with a clear message instead of an opaque unpacking error when the
# search did not select any combination.
if best_params is None:
    raise RuntimeError("best_params is None: the hyperparameter search selected no combination")

hidden_dim, learning_rate, dropout_rate, weight_decay, layer_type, heads = best_params
final_model = GNN(input_dim=2, hidden_dim=hidden_dim, output_dim=2, heads=heads, layer_type=layer_type).to(device)
final_optimizer = torch.optim.AdamW(final_model.parameters(), lr=learning_rate, weight_decay=weight_decay)


def final_criterion(probs, labels):
    """Negative log-likelihood of ``labels`` under class probabilities ``probs``.

    The model returns softmax probabilities, whereas CrossEntropyLoss expects
    unnormalised logits and would apply a second softmax. Taking the log of
    the (clamped) probabilities and using NLL gives the intended
    cross-entropy.
    """
    return F.nll_loss(torch.log(probs.clamp_min(1e-12)), labels)


final_scheduler = ReduceLROnPlateau(final_optimizer, mode='min', factor=0.5, patience=5, verbose=True)
final_early_stopping = EarlyStopping(patience=10, verbose=True)

# The final model is fitted on the whole training dataset.
final_train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
# Train the final model for at most 100 epochs. There is no validation split
# in this loop, so the scheduler and early stopping both monitor the mean
# training loss of the epoch.
final_model.train()
for epoch in range(100):
    batch_losses = []
    for data, labels in final_train_loader:
        data, labels = data.to(device), labels.to(device)
        final_optimizer.zero_grad()
        out = final_model(data)
        loss = final_criterion(out, labels)
        loss.backward()
        final_optimizer.step()
        batch_losses.append(loss.item())
    total_loss = sum(batch_losses)
    final_scheduler.step(total_loss / len(final_train_loader))
    final_early_stopping(total_loss / len(final_train_loader), final_model)
    if final_early_stopping.early_stop:
        break

# Restore the weights saved at the best recorded loss.
final_model.load_state_dict(torch.load('checkpoint.pt'))

In [None]:
def make_predictions(model, test_data_loader):
    """Return the column-1 model output for every graph served by the loader.

    The model is put in eval mode and run without gradient tracking. Values
    are collected batch by batch, in loader order, as a flat Python list of
    NumPy scalars.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch in test_data_loader:
            batch = batch.to(device)
            positive_column = model(batch)[:, 1]
            collected += list(positive_column.cpu().numpy())
    return collected

output_csv_path = '/kaggle/working/test_predictions.csv'

# Score the test set and attach the scores to the test frame.
test_predictions = make_predictions(final_model, test_data_loader)
test_df['binds'] = test_predictions

# The output file contains only the row id and the predicted score.
output_df = test_df[['id', 'binds']]
output_df.to_csv(output_csv_path, index=False)

print(f'Saved predictions to {output_csv_path}')

In [None]:
def visualize_graph(data):
    """Draw a molecule graph with networkx, labelling each node with its first feature.

    Args:
        data: Graph object exposing ``x`` (node features with at least two
            columns) and ``edge_index`` tensors. The tensors may live on any
            device; they are copied to the CPU before conversion to NumPy.
    """
    G = nx.Graph()
    x = data.x.cpu().numpy()
    # reshape(2, -1) also accepts a flat empty edge_index (a graph without
    # edges), which would otherwise fail on shape[1].
    edge_index = data.edge_index.cpu().numpy().reshape(2, -1)
    for i, feature in enumerate(x):
        G.add_node(i, atom=feature[0], protein=feature[1])
    for i in range(edge_index.shape[1]):
        G.add_edge(edge_index[0, i], edge_index[1, i])
    pos = nx.spring_layout(G)
    node_labels = nx.get_node_attributes(G, 'atom')
    nx.draw(G, pos, with_labels=True, labels=node_labels, node_color='lightblue', edge_color='gray')
    plt.title('Graph Representation of a Molecule')
    plt.show()

# Draw one randomly chosen training graph ...
sample_data = df['graph'].sample(n=1).iloc[0]
visualize_graph(sample_data)

# ... and one randomly chosen test graph.
sample_test_data = test_df['graph'].sample(n=1).iloc[0]
visualize_graph(sample_test_data)