In [2]:
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)

tensor([[[ 0.0046,  0.0108, -0.0157,  ..., -0.0179, -0.0226,  0.0497],
         [ 0.0063,  0.0288, -0.0023,  ...,  0.0114,  0.0188,  0.0328],
         [ 0.0366, -0.0113, -0.1003,  ...,  0.0279, -0.0380,  0.0105],
         ...,
         [-0.0461,  0.0007,  0.0366,  ...,  0.0532, -0.0284, -0.0347],
         [-0.0073,  0.0311, -0.0177,  ...,  0.0332,  0.0037,  0.0007],
         [-0.0008, -0.0456,  0.0199,  ..., -0.0099, -0.0158, -0.0012]],

        [[ 0.0254,  0.0037, -0.0638,  ...,  0.0042,  0.0123, -0.0212],
         [-0.0255, -0.0134, -0.0106,  ..., -0.0006, -0.0259, -0.0111],
         [-0.0008,  0.0295,  0.0479,  ..., -0.0059,  0.0114, -0.0043],
         ...,
         [ 0.0026, -0.0040, -0.0070,  ..., -0.0498, -0.0085,  0.0358],
         [ 0.0534, -0.0341, -0.0274,  ..., -0.0974, -0.0472,  0.0185],
         [-0.0236, -0.0280,  0.0078,  ...,  0.0251,  0.0085, -0.0092]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)


In [None]:
# tensor([[[ 0.0046,  0.0108, -0.0157,  ..., -0.0179, -0.0226,  0.0497],
#          [ 0.0063,  0.0288, -0.0023,  ...,  0.0114,  0.0188,  0.0328],
#          [ 0.0366, -0.0113, -0.1003,  ...,  0.0279, -0.0380,  0.0105],
#          ...,
#          [-0.0461,  0.0007,  0.0366,  ...,  0.0532, -0.0284, -0.0347],
#          [-0.0073,  0.0311, -0.0177,  ...,  0.0332,  0.0037,  0.0007],
#          [-0.0008, -0.0456,  0.0199,  ..., -0.0099, -0.0158, -0.0012]],

#         [[ 0.0254,  0.0037, -0.0638,  ...,  0.0042,  0.0123, -0.0212],
#          [-0.0255, -0.0134, -0.0106,  ..., -0.0006, -0.0259, -0.0111],
#          [-0.0008,  0.0295,  0.0479,  ..., -0.0059,  0.0114, -0.0043],
#          ...,
#          [ 0.0026, -0.0040, -0.0070,  ..., -0.0498, -0.0085,  0.0358],
#          [ 0.0534, -0.0341, -0.0274,  ..., -0.0974, -0.0472,  0.0185],
#          [-0.0236, -0.0280,  0.0078,  ...,  0.0251,  0.0085, -0.0092]]],
#        device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [3]:
import sys
import os
import gc
import copy
import yaml
import pickle
import random
import joblib 
import shutil
from time import time
import typing as tp
from pathlib import Path
import psutil

import numpy as np
import pandas as pd
import scipy

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.metrics import average_precision_score as APS
import duckdb


import torch
import torchvision
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
from torch.nn import BCELoss
from torch.utils.data import Dataset


import timm
from mamba_ssm import Mamba
from transformers import AutoModel, AutoTokenizer

import albumentations as A
from albumentations.pytorch import ToTensorV2


# use one device only
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
con = duckdb.connect()

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.11 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [4]:


class CFG:
    TEST_No = 1
    TEST_NUM = int(1674896/32 * TEST_No)
    TEST_OFFSET = int(TEST_NUM * (TEST_No-1))
    TRAIN_CLM_PATH = Path('/root/Kaggle_NeurIPS2024/data/processed/20000_50per_CLM.parquet')
    TEST_CLM_PATH = Path(f'/root/Kaggle_NeurIPS2024/data/processed/test_CLM_{TEST_OFFSET}_to_{TEST_NUM}.parquet')
    TRAIN_ENC_PATH = Path('/root/Kaggle_NeurIPS2024/data/external/train_enc.parquet')
    TEST_ENC_PATH = Path('/root/Kaggle_NeurIPS2024/data/external/test_enc.parquet')
    TRAIN_PATH = Path('/root/Kaggle_NeurIPS2024/data/raw/train.parquet')
    TEST_PATH = Path('/root/Kaggle_NeurIPS2024/data/raw/test.parquet')

In [11]:
test = con.query(f"""(SELECT *
                        FROM parquet_scan('{CFG.TEST_PATH}')
                        )""").df()
test.head()

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name
0,295246830,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,BRD4
1,295246831,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,HSA
2,295246832,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,sEH
3,295246833,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,BRD4
4,295246834,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,HSA


In [6]:
test["buildingblock1_smiles"].describe()


count                                               10000
unique                                                  3
top       C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21
freq                                                 4227
Name: buildingblock1_smiles, dtype: object

In [12]:
test['length1'] = test['buildingblock1_smiles'].apply(len)
test['length2'] = test['buildingblock2_smiles'].apply(len)
test['length3'] = test['buildingblock3_smiles'].apply(len)
display(test['length1'].describe())
display(test['length2'].describe())
display(test['length3'].describe())

count    1.674896e+06
mean     5.110863e+01
std      6.031255e+00
min      3.500000e+01
25%      4.600000e+01
50%      5.000000e+01
75%      5.600000e+01
max      6.900000e+01
Name: length1, dtype: float64

count    1.674896e+06
mean     2.158281e+01
std      7.422572e+00
min      7.000000e+00
25%      1.600000e+01
50%      2.000000e+01
75%      2.600000e+01
max      5.400000e+01
Name: length2, dtype: float64

count    1.674896e+06
mean     2.003411e+01
std      5.990853e+00
min      7.000000e+00
25%      1.500000e+01
50%      2.000000e+01
75%      2.400000e+01
max      4.300000e+01
Name: length3, dtype: float64