In [13]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

In [9]:
# Read the raw DrugBank table from the CSV that sits next to this notebook.
# Later cells read its drug_id, name, smile, approved, illicit and withdrawn columns.
smile_data = pd.read_csv(filepath_or_buffer="drugbank_smile_data.csv")

In [10]:
# Filter data. X-values are SMILES strings. Y-values are toxicity ratings
# (1 indicates toxic, 0 indicates non-toxic).
#
# A row counts as "labeled" when at least one of its approved / illicit /
# withdrawn flags equals 1; every other row is collected separately as
# unlabeled data. Labeled rows get toxicity 0 when approved == 1, else 1.
# NOTE(review): a row with approved == 1 that is ALSO withdrawn or illicit is
# therefore labeled non-toxic -- confirm that this is the intended labeling.
column_names = ["drug_id", "name", "smile", "toxicity"]
feature_columns = [name for name in column_names if name != "toxicity"]

# Column-wise boolean masks replace the per-row iterrows() loop: same
# element-wise comparisons, but no Python-level iteration over the frame.
is_approved = smile_data["approved"] == 1
is_labeled = is_approved | (smile_data["illicit"] == 1) | (smile_data["withdrawn"] == 1)

labeled_rows = smile_data.loc[is_labeled, feature_columns]
unlabeled_rows = smile_data.loc[~is_labeled, feature_columns]

# Keep the dict-of-lists layout and key order (drug_id, name, smile, toxicity)
# because the next cell builds DataFrames directly from these dictionaries.
filtered_data_dict = {name: labeled_rows[name].tolist() for name in feature_columns}
filtered_data_dict["toxicity"] = (
    smile_data.loc[is_labeled, "approved"].ne(1).astype(int).tolist()
)
unlabeled_data_dict = {name: unlabeled_rows[name].tolist() for name in feature_columns}

# Every approved row is labeled by construction, so counting the approved mask
# over the whole frame equals counting approved rows among the labeled ones.
num_approved = int(is_approved.sum())

num_filtered_data = len(filtered_data_dict["smile"])
print("Used {} data points out of {}".format(num_filtered_data, len(smile_data)))
print("Number of non-toxic: {}".format(num_approved))

Used 2631 data points out of 10630
Number of non-toxic: 2413


In [11]:
# Turn each dict-of-lists into a DataFrame and persist it as CSV
# (row index omitted from the files).

# Labeled examples: drug_id, name, smile, toxicity.
filtered_data_df = pd.DataFrame(filtered_data_dict)
filtered_data_df.to_csv(path_or_buf="drugbank_smile_data_filtered.csv", index=False)

# Unlabeled examples: same columns minus toxicity.
unlabeled_data_df = pd.DataFrame(unlabeled_data_dict)
unlabeled_data_df.to_csv(path_or_buf="drugbank_smile_data_unlabeled.csv", index=False)


In [17]:
# Generate train-test split (80-20).
# A fixed random_state makes the shuffle deterministic, so re-running the
# notebook regenerates exactly the same train/test CSV files instead of a
# different random partition on every run.
SPLIT_RANDOM_STATE = 42
# NOTE(review): the printed class counts above are heavily imbalanced; consider
# stratify=filtered_data_df["toxicity"] to keep the class ratio equal in both
# splits -- left unchanged here because it alters which rows land where.
train_data_df, test_data_df = train_test_split(
    filtered_data_df,
    train_size=0.8,
    test_size=0.2,
    random_state=SPLIT_RANDOM_STATE,
)

In [18]:
# Persist both halves of the split as CSV files, without the row index.
train_data_df.to_csv(path_or_buf="drugbank_smile_data_filtered_train.csv", index=False)
test_data_df.to_csv(path_or_buf="drugbank_smile_data_filtered_test.csv", index=False)