# Dataset for stride structures
- found some previously saved latents and structures
- latents from pair SAE were randomly sampled, so information about token id was lost => impossible to map to index from stride sequence
- final ds columns: **values, structure_id, token_id, timestep_id, secondary_struct, helix**

In [5]:
import subprocess
import glob
import os
from datasets import Dataset
import pandas as pd

def run_stride(pdb_path: str, output_path: str, stride_path):
    """Run the STRIDE executable on one PDB file and save its stdout to a file.

    The command executed is ``<stride_path> -o <pdb_path>`` and its standard
    output is written to ``output_path``. The program is started from an
    argument list rather than a shell string, so paths containing spaces or
    shell metacharacters are passed through unchanged.

    Failures are reported with a printed warning instead of an exception so a
    batch run over many files keeps going.

    Args:
        pdb_path: Path of the input PDB file.
        output_path: File that receives the program's standard output. It is
            created (or truncated) even if the run fails.
        stride_path: Path, or name on ``PATH``, of the STRIDE executable.

    Raises:
        subprocess.TimeoutExpired: If the process runs longer than 30 seconds.
    """
    with open(output_path, "wb") as out_file:
        try:
            result = subprocess.run(
                [stride_path, "-o", pdb_path],
                stdout=out_file,
                stderr=subprocess.PIPE,
                text=True,
                timeout=30,
            )
        except OSError as err:
            # Executable missing or not runnable: warn instead of aborting the batch.
            print(f"Warning: could not run {stride_path} on {pdb_path}: {err}")
            return
    if result.returncode != 0:
        print(
            f"Warning: {stride_path} exited with code {result.returncode} "
            f"for {pdb_path}: {result.stderr.strip()}"
        )

## generate stride outputs


In [None]:



def parse_directory(dir_with_pdb: str, dir_for_stride, stride_binary: str) -> None:
    """Run STRIDE on every ``*.pdb`` file found directly in a directory.

    For each ``<name>.pdb`` in ``dir_with_pdb`` a file ``<name>.stride`` is
    written into ``dir_for_stride`` via ``run_stride``.

    Args:
        dir_with_pdb: Directory that is searched (non-recursively) for
            ``*.pdb`` files.
        dir_for_stride: Output directory; created if it does not exist.
        stride_binary: Path of the STRIDE executable handed to ``run_stride``.
    """
    os.makedirs(dir_for_stride, exist_ok=True)
    # glob returns files in arbitrary order; sort for a reproducible run.
    for pdb_file in sorted(glob.glob(os.path.join(dir_with_pdb, "*.pdb"))):
        # Swap only the extension: str.replace would also rewrite a ".pdb"
        # that happens to occur earlier in the file name.
        stem, _ = os.path.splitext(os.path.basename(pdb_file))
        stride_file = os.path.join(dir_for_stride, f"{stem}.stride")
        run_stride(pdb_file, stride_file, stride_binary)

In [3]:
# Input directory with PDB structures, output directory for the STRIDE
# reports, and the STRIDE executable used to produce them.
pdb_dir = "/home/wzarzecki/ds_secondary_struct/structures"
stride_dir = "/home/wzarzecki/ds_secondary_struct/stride"
stride_binary = "/data/wzarzecki/SAEtoRuleRFDiffusion/stride/stride"

In [8]:


# Run STRIDE on every .pdb file in pdb_dir; one .stride file per structure is written to stride_dir.
parse_directory(pdb_dir, stride_dir, stride_binary)

## merge datasets
from the per-timestep / per-structure latent datasets into a single dataset

In [None]:


# Load a single saved per-structure latent dataset to inspect its schema.
example_ds = Dataset.load_from_disk(
    "/home/wzarzecki/ds_secondary_struct/latents/non_pair/1/xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a")
# Return rows as torch tensors where possible.
example_ds.set_format("torch")
example_ds

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['values', 'subcellular', 'solubility'],
    num_rows: 187
})

In [11]:
# Show the first row to check what each column holds.
example_ds[0]

{'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 'subcellular': 'Nucleus',
 'solubility': 'Soluble'}

In [None]:



def merge_datasets(base_dir: str, save_path: str):
    """Merge per-timestep, per-structure latent datasets into one dataset.

    Walks the layout ``base_dir/<timestep>/<structure_id>/`` and loads every
    leaf directory with ``Dataset.load_from_disk``. The ``subcellular`` and
    ``solubility`` columns are dropped when present, and three columns are
    added to every row:

    * ``structure_id`` - name of the leaf directory,
    * ``timestep_id`` - the timestep directory name converted to ``int``,
    * ``token_id`` - 0-based position of the row within its source dataset.

    All rows are concatenated and saved with ``Dataset.save_to_disk``.
    Directories are visited in lexicographic order, so ``"10"`` is processed
    before ``"2"``.

    Args:
        base_dir: Root directory containing one sub-directory per timestep.
        save_path: Destination directory for the merged dataset.

    Returns:
        None. Nothing is written when no dataset could be loaded.
    """
    frames = []
    for timestep_dir in sorted(os.listdir(base_dir)):
        timestep_path = os.path.join(base_dir, timestep_dir)
        if not os.path.isdir(timestep_path):
            continue
        # Validate the timestep up front: a stray non-numeric directory would
        # otherwise abort the whole merge with a ValueError part-way through.
        try:
            timestep_id = int(timestep_dir)
        except ValueError:
            print(f"Skipping {timestep_path}: directory name is not an integer timestep")
            continue
        for struct_id in sorted(os.listdir(timestep_path)):
            struct_path = os.path.join(timestep_path, struct_id)
            if not os.path.isdir(struct_path):
                continue
            # Best effort: a directory that is not a loadable dataset is
            # reported and skipped rather than stopping the merge.
            try:
                ds = Dataset.load_from_disk(struct_path)
            except Exception as e:
                print(f"Failed to load {struct_path}: {e}")
                continue
            df = ds.to_pandas()
            df = df.drop(columns=['subcellular', 'solubility'], errors='ignore')
            df['structure_id'] = struct_id
            df['timestep_id'] = timestep_id
            # Row position inside this structure's dataset.
            df['token_id'] = range(len(df))
            frames.append(df)
        print(f"processed {timestep_dir}")
    if not frames:
        print("No datasets found to merge.")
        return
    merged_df = pd.concat(frames, ignore_index=True)
    # If exactly one data column remains and it is not already called
    # 'values', rename it so the merged schema is uniform.
    value_cols = [col for col in merged_df.columns if col not in ['structure_id', 'timestep_id', 'token_id']]
    if len(value_cols) == 1 and value_cols[0] != 'values':
        merged_df = merged_df.rename(columns={value_cols[0]: 'values'})
    merged_ds = Dataset.from_pandas(merged_df)
    merged_ds.save_to_disk(save_path)
    print(f"Merged dataset saved to {save_path}")

In [7]:
# Root of the per-timestep latent datasets and the destination of the merged dataset.
base_latents_dir = "/home/wzarzecki/ds_secondary_struct/latents/non_pair"
merged_ds_dir = "/home/wzarzecki/ds_secondary_struct/merged_latents"

In [None]:


# Build the merged dataset on disk; one "processed <timestep>" line is printed per timestep directory.
merge_datasets(base_latents_dir, merged_ds_dir)

processed 1
processed 10
processed 11
processed 12
processed 13
processed 14
processed 15
processed 16
processed 17
processed 18
processed 19
processed 2


Saving the dataset (28/28 shards): 100%|██████████| 732450/732450 [01:35<00:00, 7683.99 examples/s] 


Merged dataset saved to /home/wzarzecki/ds_secondary_struct/merged_latents


In [8]:
# Reload the merged dataset from disk and return rows as torch tensors where possible.
merged_ds = Dataset.load_from_disk(merged_ds_dir)
merged_ds.set_format("torch")
merged_ds

Loading dataset from disk:   0%|          | 0/28 [00:00<?, ?it/s]

Dataset({
    features: ['values', 'structure_id', 'timestep_id', 'token_id'],
    num_rows: 732450
})

In [9]:
# Show the first merged row, including the added id columns.
merged_ds[0]

{'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 'structure_id': 'xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a',
 'timestep_id': tensor(1),
 'token_id': tensor(0)}

## add secondary structure labels

In [10]:
import os
import re
from datasets import Dataset

def parse_stride_file(file_path: str) -> str:
    """Extract the per-residue secondary-structure string from a STRIDE file.

    Every line starting with ``SEQ`` that is immediately followed by a line
    starting with ``STR`` contributes one chunk. The column span of the
    residue letters is located in the ``SEQ`` line (first non-blank character
    after the leading residue number up to the last non-blank character before
    the trailing residue number), and the same columns are cut out of the
    ``STR`` line. Chunks are concatenated in file order.

    Each chunk always has exactly as many characters as the residue span of
    its ``SEQ`` line: the line terminator is never included, and a ``STR``
    line shorter than the span is padded with blanks. This keeps character
    ``i`` of the result aligned with residue ``i``.

    Args:
        file_path: Path of the STRIDE output file.

    Returns:
        The concatenated structure string, with blanks where the ``STR`` line
        is blank. An empty string if the file does not exist (a warning is
        printed) or contains no usable ``SEQ``/``STR`` pair.
    """
    try:
        with open(file_path, 'r') as f:
            # Drop the line terminator so it can never leak into a slice.
            lines = [line.rstrip("\n") for line in f]
    except FileNotFoundError:
        print(f"Warning: Stride file not found: {file_path}")
        return ""

    structure_parts = []
    # Walk consecutive line pairs; a SEQ line only counts when the very next
    # line is its STR line.
    for seq_line, str_line in zip(lines, lines[1:]):
        if not seq_line.startswith("SEQ") or not str_line.startswith("STR"):
            continue

        # First non-whitespace character after the initial residue number.
        start_match = re.search(r'SEQ\s+\d+\s+(\S)', seq_line)
        # Last non-whitespace character before the final residue number.
        end_match = re.search(r'(\S)\s+\d+\s*~*$', seq_line)
        if not start_match or not end_match:
            continue

        start_index = start_match.start(1)
        end_index = end_match.start(1) + 1

        # Pad short STR lines so the chunk length equals the residue count.
        chunk = str_line[start_index:end_index].ljust(end_index - start_index)
        structure_parts.append(chunk)

    return "".join(structure_parts)

def add_secondary_struct_column(ds: Dataset, stride_dir: str) -> Dataset:
    """Add a ``secondary_struct`` column looked up from STRIDE files.

    For each row, the file ``<stride_dir>/<structure_id>.stride`` is parsed
    with ``parse_stride_file`` (once per structure, then cached) and the
    character at position ``token_id`` of the resulting string becomes the
    label. The label is ``None`` when that character is blank or when
    ``token_id`` lies outside the string, which includes the case of a
    missing STRIDE file.

    Args:
        ds: Dataset with ``structure_id`` and ``token_id`` columns.
        stride_dir: Directory containing the ``.stride`` files.

    Returns:
        A new dataset with the additional ``secondary_struct`` column.
    """
    stride_cache = {}

    def get_secondary_structure(structure_id, token_id):
        if structure_id not in stride_cache:
            stride_file_path = os.path.join(stride_dir, f"{structure_id}.stride")
            stride_cache[structure_id] = parse_stride_file(stride_file_path)

        full_ss_string = stride_cache[structure_id]

        # token_id may arrive as a 0-dim tensor depending on the dataset
        # format; normalise it to a plain int before indexing.
        position = int(token_id)
        label = None
        # Reject negative positions explicitly: they would otherwise index
        # from the end of the string.
        if 0 <= position < len(full_ss_string):
            label = full_ss_string[position].strip() or None
        return {'secondary_struct': label}

    # Only the two id columns are passed to the function, so the large
    # 'values' vectors are not decoded for every row.
    return ds.map(get_secondary_structure, input_columns=['structure_id', 'token_id'])


In [20]:
# Sanity check on one structure: total label length plus the first five and last two characters.
stride_str = parse_stride_file("/home/wzarzecki/ds_secondary_struct/stride/xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a.stride")
len(stride_str), stride_str[:5], stride_str[-2:]

(187, '    E', 'H ')

In [16]:
# Dry run on the first 10 rows before labelling the full dataset.
mini_stride_ds = add_secondary_struct_column(merged_ds.take(10), stride_dir)
mini_stride_ds

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Dataset({
    features: ['values', 'structure_id', 'timestep_id', 'token_id', 'secondary_struct'],
    num_rows: 10
})

In [19]:
# Compare the first and last rows of the dry run.
mini_stride_ds[0], mini_stride_ds[9]

({'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
  'structure_id': 'xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a',
  'timestep_id': tensor(1),
  'token_id': tensor(0),
  'secondary_struct': None},
 {'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
  'structure_id': 'xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a',
  'timestep_id': tensor(1),
  'token_id': tensor(9),
  'secondary_struct': 'E'})

In [None]:
# Output location for the dataset with the added 'secondary_struct' column.
stride_ds_path = "/home/wzarzecki/ds_secondary_struct/stride_ds"

In [None]:
# Label every row of the merged dataset and persist the result.
stride_ds = add_secondary_struct_column(merged_ds, stride_dir)
stride_ds.save_to_disk(stride_ds_path)

Map:   0%|          | 0/732450 [00:00<?, ? examples/s]

Saving the dataset (0/28 shards):   0%|          | 0/732450 [00:00<?, ? examples/s]

In [22]:
from datasets import Dataset

def add_helix_column(ds: Dataset) -> Dataset:
    """
    Return a copy of the dataset extended with a boolean 'helix' column.

    A row is flagged True when its 'secondary_struct' value is one of the
    letters 'G', 'H' or 'I'. Rows with any other value, or with no value at
    all, are flagged False.
    """
    helical_codes = frozenset("GHI")

    def _flag_helix(row):
        code = row.get('secondary_struct')
        if code is None:
            is_helix = False
        else:
            is_helix = code in helical_codes
        return {'helix': is_helix}

    return ds.map(_flag_helix)

# Add the boolean 'helix' column to the labelled dataset (stride_ds is created in an earlier cell).
helix_ds = add_helix_column(stride_ds)
# Persist the final dataset.
helix_ds_path = "/home/wzarzecki/ds_secondary_struct/helix_ds"
helix_ds.save_to_disk(helix_ds_path)

Map:   0%|          | 0/732450 [00:00<?, ? examples/s]

Saving the dataset (0/28 shards):   0%|          | 0/732450 [00:00<?, ? examples/s]

In [30]:
# Spot-check two rows of the final dataset.
helix_ds[0], helix_ds[48]

({'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
  'structure_id': 'xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a',
  'timestep_id': tensor(1),
  'token_id': tensor(0),
  'secondary_struct': None,
  'helix': tensor(False)},
 {'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
  'structure_id': 'xxx_0_c4e8cf7d-8850-486c-8a9a-76f0b46a751a',
  'timestep_id': tensor(1),
  'token_id': tensor(48),
  'secondary_struct': 'H',
  'helix': tensor(True)})