In [1]:
import pyarrow as pa
import pyarrow.csv as csv
import polars as pl
import numpy as np


In [2]:
def read_tsv_to_table(file_path, column_names=None):
    """Load a tab-separated file into a ``pyarrow.Table``.

    Args:
        file_path: Path of the TSV file to read.
        column_names: Optional list of column names. When given, pyarrow
            uses these names instead of reading them from the file, so the
            first line is parsed as data. When ``None``, the first line of
            the file supplies the column names.

    Returns:
        The parsed ``pyarrow.Table``.
    """
    return csv.read_csv(
        file_path,
        read_options=csv.ReadOptions(column_names=column_names),
        parse_options=csv.ParseOptions(delimiter='\t'),
    )

def write_table_to_arrowwriter(table, output_path):
    """Write ``table`` to ``output_path`` in the Arrow IPC *streaming* format.

    The output file is opened in ``'wb'`` mode and the stream writer uses the
    table's own schema. Both the writer and the file are closed when the
    ``with`` block exits (the writer first, then the file).

    Args:
        table: The ``pyarrow.Table`` to serialize.
        output_path: Destination file path.
    """
    with pa.OSFile(output_path, 'wb') as sink, \
            pa.RecordBatchStreamWriter(sink, table.schema) as writer:
        writer.write_table(table)

def pa_type_to_hf_type(pa_type):
    """Map a PyArrow data type to a Hugging Face ``datasets`` dtype string.

    This is a deliberately coarse mapping: every integer width collapses to
    ``"int32"`` and every floating-point width to ``"float32"``.
    NOTE(review): int64/float64 values outside the int32/float32 range would
    not round-trip through this mapping; confirm the narrowing is intended.

    Args:
        pa_type: A ``pyarrow.DataType`` (e.g. ``table.schema.field(i).type``).

    Returns:
        str: One of ``"bool"``, ``"int32"``, ``"float32"``, ``"string"`` or
        ``"large_string"``.

    Raises:
        ValueError: If ``pa_type`` is not one of the supported kinds.
    """
    if pa.types.is_boolean(pa_type):
        return "bool"
    if pa.types.is_integer(pa_type):
        return "int32"
    if pa.types.is_floating(pa_type):
        return "float32"
    if pa.types.is_string(pa_type):
        return "string"
    # Text columns exported from Polars (``DataFrame.to_arrow()``) are
    # commonly ``large_string``, which ``is_string`` does not match.
    if pa.types.is_large_string(pa_type):
        return "large_string"
    raise ValueError(f"Unsupported PyArrow type: {pa_type}")

In [3]:
# Input dataset: '<file_prefix>.txt' is read below.
# 'cluster_100k' is an alternative prefix (presumably a smaller sample).
file_prefix = 'cluster_6M'

# column_names is passed explicitly, so pyarrow parses the first line of the
# file as data (assumes the TSV has no header row — confirm).
table = read_tsv_to_table(
    f'{file_prefix}.txt',
    column_names=['cid', 'smiles', 'iupac', 'formula', 'num_atoms'],
)
# Optional: persist the raw table via write_table_to_arrowwriter(table, f'{file_prefix}.arrow')

In [4]:
def partition_dataframe(
    df: pl.DataFrame,
    probs: tuple[float, float, float],
    seed: "int | None" = None,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Randomly assign each row of ``df`` to one of three disjoint partitions.

    Every row is independently drawn into partition 0, 1 or 2 with the
    probabilities in ``probs``. Partition sizes are therefore only
    approximately proportional to ``probs``. Row order within each partition
    follows ``df``.

    Args:
        df: Frame to split.
        probs: Probabilities for partitions 0, 1 and 2. They must sum to 1.0
            within floating-point tolerance.
        seed: Optional seed for a private ``numpy.random.Generator``. When
            ``None`` (the default), the global ``np.random`` state is used,
            so an earlier ``np.random.seed(...)`` still controls the split.

    Returns:
        ``(df0, df1, df2)``: the rows drawn into partitions 0, 1 and 2.

    Raises:
        AssertionError: If ``probs`` does not sum to 1.0.
    """
    # Exact float equality can reject valid probabilities because of
    # rounding in the sum; compare with a tolerance instead.
    assert np.isclose(sum(probs), 1.0), "Probabilities must sum to 1.0"

    rng = np.random if seed is None else np.random.default_rng(seed)
    choices = rng.choice([0, 1, 2], size=len(df), p=probs)

    # One boolean mask per partition label; every row matches exactly one.
    df0, df1, df2 = (df.filter(pl.Series(choices == label)) for label in range(3))
    return df0, df1, df2

In [5]:

# Convert the PyArrow Table to a Polars DataFrame.
pl_df = pl.from_arrow(table)

# Drop unwanted rows:
#   - molecules with fewer than 3 atoms,
#   - SMILES containing '.' (presumably multi-component entries),
#   - rows whose IUPAC name contains 'None',
#   - SMILES containing '[2H]' (presumably deuterium-labelled entries).
# A row is kept only if every check is True; rows where a check evaluates
# to null are dropped, exactly as with a chain of separate filters.
pl_df = pl_df.filter(
    (pl.col('num_atoms') >= 3)
    & ~pl.col('smiles').str.contains(r'\.')
    & ~pl.col('iupac').str.contains(r'None')
    & ~pl.col('smiles').str.contains(r'\[2H\]')
)

# Sort by 'num_atoms' and show the first 40 rows.
sorted_df = pl_df.sort('num_atoms')
print(sorted_df.head(40))

# Randomly split ~90/5/5 into train/eval/test.
# Seeding the global NumPy RNG (used by partition_dataframe) makes the split
# reproducible across kernel restarts.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
train_df, eval_df, test_df = partition_dataframe(sorted_df, [0.9, 0.05, 0.05])

# Write each split to its own Arrow IPC streaming-format file.
train_df.write_ipc_stream(f'{file_prefix}_train.arrow')
eval_df.write_ipc_stream(f'{file_prefix}_eval.arrow')
test_df.write_ipc_stream(f'{file_prefix}_test.arrow')


shape: (40, 5)
┌───────┬───────────┬────────────────────────┬─────────┬───────────┐
│ cid   ┆ smiles    ┆ iupac                  ┆ formula ┆ num_atoms │
│ ---   ┆ ---       ┆ ---                    ┆ ---     ┆ ---       │
│ i64   ┆ str       ┆ str                    ┆ str     ┆ i64       │
╞═══════╪═══════════╪════════════════════════╪═════════╪═══════════╡
│ 177   ┆ CC=O      ┆ acetaldehyde           ┆ C2H4O   ┆ 3         │
│ 283   ┆ C(=O)[O-] ┆ formate                ┆ CHO2-   ┆ 3         │
│ 284   ┆ C(=O)O    ┆ formic acid            ┆ CH2O2   ┆ 3         │
│ 540   ┆ C(#N)O    ┆ cyanic acid            ┆ CHNO    ┆ 3         │
│ 674   ┆ CNC       ┆ N-methylmethanamine    ┆ C2H7N   ┆ 3         │
│ …     ┆ …         ┆ …                      ┆ …       ┆ …         │
│ 11648 ┆ C[Se]C    ┆ methylselanylmethane   ┆ C2H6Se  ┆ 3         │
│ 14788 ┆ S=[Fe]=S  ┆ bis(sulfanylidene)iron ┆ FeS2    ┆ 3         │
│ 14793 ┆ O=[Pb]=O  ┆ dioxolead              ┆ O2Pb    ┆ 3         │
│ 14796 ┆ O=[Ge]=O 

In [6]:
# Sizes of the train / eval / test splits.
print(len(train_df), len(eval_df), len(test_df))

# Formulas that occur on more than one row.
# GroupBy.count() is deprecated (this line previously triggered a warning);
# GroupBy.len() is the replacement and names its output column "len".
duplicates = sorted_df.group_by("formula").len().filter(pl.col("len") > 1)['formula']

# Show the first 40 rows whose formula is shared, grouped by sorting on formula.
duplicate_df = sorted_df.filter(pl.col("formula").is_in(duplicates))
print(duplicate_df.sort('formula').head(40))

130448 7310 7356
shape: (40, 5)
┌──────────┬─────────────────────────────┬─────────────────────────────┬───────────────┬───────────┐
│ cid      ┆ smiles                      ┆ iupac                       ┆ formula       ┆ num_atoms │
│ ---      ┆ ---                         ┆ ---                         ┆ ---           ┆ ---       │
│ i64      ┆ str                         ┆ str                         ┆ str           ┆ i64       │
╞══════════╪═════════════════════════════╪═════════════════════════════╪═══════════════╪═══════════╡
│ 24577    ┆ O[As]=O                     ┆ arsenous acid               ┆ AsHO2         ┆ 3         │
│ 104785   ┆ O=[AsH]=O                   ┆ dioxo-lambda5-arsane        ┆ AsHO2         ┆ 3         │
│ 13025    ┆ CC1=CCC2=CC=CC=C12          ┆ 3-methyl-1H-indene          ┆ C10H10        ┆ 10        │
│ 142329   ┆ C=C1C2C=CC=CC1C=C2          ┆ 9-methylidenebicyclo[4.2.1] ┆ C10H10        ┆ 10        │
│          ┆                             ┆ non…            

  duplicates = sorted_df.group_by("formula").count().filter(pl.col("count") > 1)['formula']
