In [41]:
import numpy as np
from datasets import Dataset, concatenate_datasets
import torch
import os
import random
import pandas as pd
from rich import columns

In [42]:
def load_ds_from_dirs_flattening_timesteps(
        path: str, columns, dtype, n_shards_per_timestep: int | None = None
) -> Dataset:
    datasets = []
    for timestep_dir_name in os.listdir(path):
        timestep_dir_path = os.path.join(path, timestep_dir_name)
        ds_dir_names = os.listdir(timestep_dir_path)
        if n_shards_per_timestep:
            ds_dir_names = random.sample(ds_dir_names, n_shards_per_timestep)
        for example_dir_name in ds_dir_names:
            example_dir_path = os.path.join(timestep_dir_path, example_dir_name)
            ds = Dataset.load_from_disk(example_dir_path, keep_in_memory=False)
            ds.set_format(type="torch", columns=columns, dtype=dtype)
            structure_id = [example_dir_name] * len(ds)
            timestep = [timestep_dir_name] * len(ds)
            ds = ds.add_column("Sequence_Id", structure_id)
            ds = ds.add_column("Timestep", timestep)
            datasets.append(ds)
        print(f"processed {timestep_dir_name}")
    return concatenate_datasets(datasets)


ds = load_ds_from_dirs_flattening_timesteps(
    # "/home/wzarzecki/ds_sae_latents_1600x/activations/block4_non_pair",
    "/home/wzarzecki/ds_sae_latents_1600x/latents/non_pair",
    # "/home/wzarzecki/ds_sae_latents_1600x/tiny_debug_activations/block4_non_pair",
    # "/home/wzarzecki/ds_sae_latents_1600x/tiny_debug_latents/non_pair",
    ["values"], torch.float32)
ds, ds[0]

processed 49
processed 35
processed 34
processed 24
processed 19
processed 21
processed 13
processed 7
processed 22
processed 45
processed 3
processed 39
processed 48
processed 17
processed 27
processed 25
processed 38
processed 36
processed 47
processed 43
processed 50
processed 1
processed 20
processed 44
processed 40
processed 8
processed 4
processed 11
processed 23
processed 5
processed 37
processed 14
processed 9
processed 12
processed 31
processed 33
processed 41
processed 10
processed 18
processed 32
processed 15
processed 16
processed 29
processed 6
processed 46
processed 28
processed 42
processed 26
processed 30
processed 2


(Dataset({
     features: ['values', 'subcellular', 'solubility', 'Sequence_Id', 'Timestep'],
     num_rows: 117161
 }),
 {'values': tensor([0., 0., 0.,  ..., 0., 0., 0.]),
  'Sequence_Id': 'xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694',
  'Timestep': '49'})

In [43]:
len(ds)

117161

In [51]:
labels_df = pd.read_csv("/home/wzarzecki/ds_sae_latents_1600x/classifiers.csv")
labels_df.columns

Index(['Sequence', 'Sequence_Id', 'Subcellular Localization',
       'Solubility/Membrane-boundness'],
      dtype='object')

In [52]:
df = ds.to_pandas()
df.columns

Index(['values', 'subcellular', 'solubility', 'Sequence_Id', 'Timestep'], dtype='object')

In [53]:
merged_df = pd.merge(df, labels_df, on="Sequence_Id", how="inner")
merged_df.columns

Index(['values', 'subcellular', 'solubility', 'Sequence_Id', 'Timestep',
       'Sequence', 'Subcellular Localization',
       'Solubility/Membrane-boundness'],
      dtype='object')

In [54]:
subcellular_localization_values = merged_df["Subcellular Localization"].unique()
def ovr_label_row_cytoplasm(row):
    return row["Subcellular Localization"]=="Cytoplasm"
def ovr_label_row_nucleus(row):
    return row["Subcellular Localization"]=="Nucleus"
merged_df["Cytoplasm"] = merged_df.apply(ovr_label_row_cytoplasm, axis=1)
merged_df["Nucleus"] = merged_df.apply(ovr_label_row_nucleus, axis=1)
merged_df.head()

Unnamed: 0,values,subcellular,solubility,Sequence_Id,Timestep,Sequence,Subcellular Localization,Solubility/Membrane-boundness,Cytoplasm,Nucleus
0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,Soluble,xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694,49,MVKVKIKVTVKVESPEIDPEFRKKVEALQAEREALKARGEKDELDE...,Cytoplasm,Soluble,True,False
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,Soluble,xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694,49,MVKVKIKVTVKVESPEIDPEFRKKVEALQAEREALKARGEKDELDE...,Cytoplasm,Soluble,True,False
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,Soluble,xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694,49,MVKVKIKVTVKVESPEIDPEFRKKVEALQAEREALKARGEKDELDE...,Cytoplasm,Soluble,True,False
3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,Soluble,xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694,49,MVKVKIKVTVKVESPEIDPEFRKKVEALQAEREALKARGEKDELDE...,Cytoplasm,Soluble,True,False
4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,Soluble,xxx_58_cbf7ea1a-a825-42d4-a7f5-566693feb694,49,MVKVKIKVTVKVESPEIDPEFRKKVEALQAEREALKARGEKDELDE...,Cytoplasm,Soluble,True,False


In [55]:
found_timesteps = np.unique(merged_df["Timestep"].values)
timestep_datasets = []
for timestep in found_timesteps:
    timestep_datasets.append(merged_df[merged_df["Timestep"]==str(timestep)][["values", "Subcellular Localization", "Cytoplasm", "Nucleus"]])
timestep_datasets[1].head()

Unnamed: 0,values,Subcellular Localization,Cytoplasm,Nucleus
86674,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,True,False
86675,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,True,False
86676,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,True,False
86677,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,True,False
86678,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Cytoplasm,True,False


In [56]:
len(timestep_datasets)

50

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score

X = np.stack(merged_df['values'].apply(lambda x: x.flatten()).values)
y = merged_df['Cytoplasm'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = LogisticRegression(
    max_iter=1000,
    solver='lbfgs',
    class_weight='balanced',
    random_state=42
)

clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))
print(f"roc auc {roc_auc_score(y_test, y_pred)}")

In [24]:
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to('cuda')
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1).to('cuda')
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to('cuda')
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1).to('cuda')

# Create DataLoader
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)

In [25]:
class MLP(nn.Module):
    def __init__(self, input_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = nn.ReLU()(self.fc1(x))
        x = self.fc2(x)
        return x

model = MLP(input_dim = X_train.shape[1]).to('cuda')  # input_dim = 98304

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-3)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch_X.size(0)

    avg_loss = running_loss / len(train_loader.dataset)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 10/100, Loss: 0.6240
Epoch 20/100, Loss: 0.6208
Epoch 30/100, Loss: 0.6187
Epoch 40/100, Loss: 0.6177
Epoch 50/100, Loss: 0.6150
Epoch 60/100, Loss: 0.6111
Epoch 70/100, Loss: 0.6099
Epoch 80/100, Loss: 0.6100
Epoch 90/100, Loss: 0.6073
Epoch 100/100, Loss: 0.6069


In [27]:
model.eval()
with torch.no_grad():
    y_pred = model(X_test_tensor).squeeze()

test accuracy: 0.0


In [35]:
from sklearn.metrics import roc_auc_score
roc_auc_score(y_test_tensor.cpu(), y_pred.cpu())

np.float64(0.6082172442348608)