In [1]:
# Import needed libraries
import pandas as pd
import numpy as np
import tensorflow as tf
from google.cloud import storage
import io
from tqdm import tqdm

2025-02-03 22:10:27.051258: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-03 22:10:27.059447: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-03 22:10:27.167545: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-03 22:10:27.209176: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-03 22:10:27.222129: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-03 22:10:27.327612: I tensorflow/core/platform/cpu_feature_gu

In [5]:
# Token vocabulary: a symbol's position in VOCAB is its integer id.
#   P (padding) -> 0
#   U (unknown) -> 1
#   A -> 2, C -> 3, G -> 4, T -> 5, N -> 6
PADDING = "P"
VOCAB = [PADDING, "U", "A", "C", "G", "T", "N"]
DNA_TOKENS = dict(zip(VOCAB, range(len(VOCAB))))


def tokenize_dna(sequence):
    """Encode a DNA string as a list of integer token ids.

    The input is upper-cased before lookup. Any character that is not a key
    of DNA_TOKENS is encoded as 1, the id of "U" (unknown).
    """
    token_ids = []
    for base in sequence.upper():
        token_ids.append(DNA_TOKENS.get(base, 1))
    return token_ids

In [7]:
# Smoke test: the four canonical bases A, C, G, T must encode to ids 2..5.
test_seq = "ACGT"
tokens = tokenize_dna(test_seq)
print(f"Test sequence: {test_seq}", f"Tokens: {tokens}", sep="\n")
assert tokens == [2, 3, 4, 5]

Test sequence: ACGT
Tokens: [2, 3, 4, 5]


In [8]:
def process_csv_to_tfrecord(species_name, max_padding=100, seq_len=8192, batch_size=128):
    """Convert one species' sequence CSV in GCS into tokenized, padded TFRecord files.

    Downloads ``eukaryote_pands/{species_name}.csv`` from the ``minformer_data``
    bucket, filters the values of its ``Sequence`` column, and writes the
    surviving sequences to
    ``gs://minformer_data/diverse_genomes_tf/{species_name}/tfrecords/record_{k}.tfrecord``
    with up to ``batch_size`` examples per file. Every example carries two int64
    features of length ``seq_len``:

    * ``x``: token ids from ``tokenize_dna``, padded at the end with 0.
    * ``segment_ids``: all ones.

    A sequence is skipped (and counted in the printed summary) when it

    * is not a string (e.g. an empty CSV cell read as NaN),
    * starts with ``"NNNNN"``,
    * is longer than ``seq_len`` (it cannot be padded to the fixed length), or
    * would need more than ``max_padding`` padding tokens.

    Args:
        species_name: Base name of the CSV; also used as the output sub-directory.
        max_padding: Largest number of padding tokens tolerated per sequence.
        seq_len: Fixed token length of every written example.
        batch_size: Number of examples written to each TFRecord file.

    Returns:
        None. Progress and filter statistics are printed.
    """
    bucket_name = "minformer_data"

    # Load the CSV from GCS fully into memory.
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    csv_blob = bucket.blob(f"eukaryote_pands/{species_name}.csv")
    # download_as_bytes() supersedes the deprecated download_as_string() alias.
    content = csv_blob.download_as_bytes()
    df = pd.read_csv(io.BytesIO(content))

    print(f"Processing {species_name}")
    print(f"Total sequences: {len(df)}")

    valid_sequences = []
    skipped_n = 0
    skipped_padding = 0
    skipped_too_long = 0
    skipped_missing = 0

    # Only the sequence string is used downstream, so iterate the column
    # directly rather than building a Series per row with iterrows().
    for seq in tqdm(df["Sequence"], total=len(df)):
        # Missing CSV values arrive as float NaN and cannot be sliced or measured.
        if not isinstance(seq, str):
            skipped_missing += 1
            continue

        # Skip sequences that open with a run of N's.
        if seq[0:5] == "NNNNN":
            skipped_n += 1
            continue

        padding_needed = seq_len - len(seq)
        # A negative value means the sequence exceeds seq_len; letting it
        # through would make np.pad fail later, part-way through writing.
        if padding_needed < 0:
            skipped_too_long += 1
            continue
        if padding_needed > max_padding:
            skipped_padding += 1
            continue

        valid_sequences.append(seq)

    print("Sequences after filtering:")
    print(f"Skipped due to N's: {skipped_n}")
    print(f"Skipped due to padding: {skipped_padding}")
    print(f"Skipped due to length > {seq_len}: {skipped_too_long}")
    print(f"Skipped due to missing sequence: {skipped_missing}")
    print(f"Valid sequences: {len(valid_sequences)}")

    # Int64List needs integers; a plain list of Python ints is the safest input.
    # NOTE(review): padded positions also receive segment id 1 -- confirm that
    # the consumer does not expect 0 for padding.
    segment_ids = [1] * seq_len

    # Write the valid sequences, batch_size examples per TFRecord file.
    for start in range(0, len(valid_sequences), batch_size):
        batch = valid_sequences[start:start + batch_size]
        output_file = (
            f"gs://{bucket_name}/diverse_genomes_tf/{species_name}"
            f"/tfrecords/record_{start // batch_size}.tfrecord"
        )

        with tf.io.TFRecordWriter(output_file) as writer:
            for seq in batch:
                # Tokenize, then pad the end only (np.pad fills with 0 by default).
                tokens = tokenize_dna(seq)
                tokens = np.pad(
                    np.asarray(tokens, dtype=np.int64),
                    (0, seq_len - len(tokens)),
                )

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            "x": tf.train.Feature(
                                int64_list=tf.train.Int64List(value=tokens.tolist())
                            ),
                            "segment_ids": tf.train.Feature(
                                int64_list=tf.train.Int64List(value=segment_ids)
                            ),
                        }
                    )
                )
                writer.write(example.SerializeToString())

In [9]:
# Exercise the CSV -> TFRecord pipeline on a single species file.
test_species = "human_genome_8192bp_bins_no_N"
process_csv_to_tfrecord(species_name=test_species)

Processing human_genome_8192bp_bins_no_N
Total sequences: 357933


357933it [00:36, 9923.15it/s] 


Sequences after filtering:
Skipped due to N's: 0
Skipped due to padding: 0
Valid sequences: 357933


In [11]:
import tensorflow as tf
import numpy as np
import os

# Decoding vocabulary; must stay identical to the one in diverse_genomes_tfs.py.
# NOTE(review): this rebinds PADDING and VOCAB, which the tokenizer cell above
# defines with an extra "U" entry at index 1, so ids decoded here do not line
# up with tokenize_dna() output. Confirm which vocabulary the inspected data uses.
PADDING = "P"
VOCAB = [PADDING, "A", "C", "G", "T", "N"]
VOCAB_SIZE = len(VOCAB)
stoi = dict(zip(VOCAB, range(VOCAB_SIZE)))  # symbol -> id
itos = dict(enumerate(VOCAB))               # id -> symbol

def inspect_first_species(species='Bradyrhizobium_japonicum_8192bp_bins_no_N', num_tokens=50):
    """Print a summary of the first example in a species' ``record_0.tfrecord``.

    Reads the first serialized ``tf.train.Example`` from
    ``gs://minformer_data/diverse_genomes_tf_v2/{species}/tfrecords/record_0.tfrecord``
    and prints the leading ``num_tokens`` values of its ``x`` and
    ``segment_ids`` features, the count of each token id in ``x``, and the
    leading tokens decoded through the module-level ``itos`` table.

    Args:
        species: Sub-directory name of the species to inspect.
        num_tokens: How many leading tokens to print and decode.

    Returns:
        None. Everything is printed; failures are reported as a printed
        message rather than raised.
    """
    record_path = f"gs://minformer_data/diverse_genomes_tf_v2/{species}/tfrecords/record_0.tfrecord"

    try:
        dataset = tf.data.TFRecordDataset([record_path])
        # An empty file would raise StopIteration, whose message is blank;
        # detect that case explicitly so the report is meaningful.
        record = next(iter(dataset), None)
        if record is None:
            print(f"Error occurred: no records found in {record_path}")
            return

        example = tf.train.Example()
        example.ParseFromString(record.numpy())
        tokens = list(example.features.feature['x'].int64_list.value)
        segment_ids = list(example.features.feature['segment_ids'].int64_list.value)

        print(f"\nInspecting {species}:")
        print("\nSequence 1:")
        print(f"First {num_tokens} tokens:", tokens[:num_tokens])
        print(f"First {num_tokens} segment_ids:", segment_ids[:num_tokens])

        # Convert NumPy scalars to plain ints so the dict prints as {1: 1621, ...}
        # regardless of the installed NumPy version's scalar repr.
        values, counts = np.unique(tokens, return_counts=True)
        token_counts = {int(v): int(c) for v, c in zip(values, counts)}
        print("Token counts:", token_counts)

        # Decode with the module-level vocabulary; ids outside it show as "?"
        # instead of aborting the inspection with a KeyError.
        decoded_seq = ''.join(itos.get(t, "?") for t in tokens[:num_tokens])
        print(f"Decoded sequence: {decoded_seq}...")

    except Exception as e:
        # Best-effort inspection: report the failure, including its type,
        # rather than raising.
        print(f"Error occurred: {type(e).__name__}: {e}")


inspect_first_species()


Inspecting Bradyrhizobium_japonicum_8192bp_bins_no_N:

Sequence 1:
First 50 tokens: [4, 4, 1, 1, 4, 4, 1, 1, 4, 1, 3, 4, 2, 4, 4, 4, 3, 1, 2, 4, 3, 2, 1, 1, 4, 1, 2, 4, 3, 3, 3, 2, 3, 1, 4, 1, 4, 3, 1, 4, 2, 2, 3, 3, 1, 1, 3, 2, 3, 2]
First 50 segment_ids: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Token counts: {1: 1621, 2: 2386, 3: 2561, 4: 1624}
Decoded sequence: TTAATTAATAGTCTTTGACTGCAATACTGGGCGATATGATCCGGAAGCGC...
