In [1]:
import sys
import os
import gc
import copy
import yaml
import pickle
import random
import joblib 
import shutil
from time import time
import typing as tp
from pathlib import Path
import psutil

import numpy as np
import pandas as pd
import scipy

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.metrics import average_precision_score as APS
import duckdb


import torch
import torchvision
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
from torch.nn import BCELoss
from torch.utils.data import Dataset


import timm
from mamba_ssm import Mamba
from transformers import AutoModel, AutoTokenizer

import albumentations as A
from albumentations.pytorch import ToTensorV2


# Expose only GPU 0 to CUDA for this kernel (single-device run).
os.environ.update(CUDA_VISIBLE_DEVICES="0")
# DuckDB connection opened with default arguments.
con = duckdb.connect()

# Lift pandas' display truncation for both columns and rows.
for _display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(_display_option, None)

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.11 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
# The encoded training split is not loaded in this cell; its two lines stay disabled:
# TRAIN_ENC_PATH = Path('../data/external/train_enc.parquet')
# train = pd.read_parquet(TRAIN_ENC_PATH)

# Path of the encoded test split, followed by the read into a DataFrame.
TEST_ENC_PATH = Path('../data/external/test_enc.parquet')
test = pd.read_parquet(path=TEST_ENC_PATH)

In [5]:
# No live cell assigns `train` (the load in the data-loading cell is commented
# out), so on a fresh kernel this cell raised NameError. Load the encoded
# training split here so the cell is self-contained.
# NOTE(review): the recorded output reports ~98M rows, so this read is
# expensive in time and memory.
TRAIN_ENC_PATH = Path('../data/external/train_enc.parquet')
train = pd.read_parquet(TRAIN_ENC_PATH)

# First rows, last rows, then the total row count.
display(train.head())
display(train.tail())
print(len(train))

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141,bind1,bind2,bind3
0,8,22,8,8,28,12,27,12,12,12,17,8,33,12,18,35,12,17,33,8,8,4,8,8,8,33,4,12,4,12,12,12,35,35,4,19,35,12,17,33,29,8,3,3,5,32,17,8,8,22,8,19,8,8,17,26,28,19,33,29,30,2,32,19,35,18,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,8,22,8,8,28,12,27,12,12,12,17,8,33,12,18,35,12,17,33,8,12,4,12,12,12,12,17,31,9,19,35,4,19,35,12,17,33,29,8,3,3,5,32,17,8,8,22,8,19,8,8,17,26,28,19,33,29,30,2,32,19,35,18,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,8,22,8,8,28,12,27,12,12,12,17,8,33,12,18,35,12,17,33,8,12,4,12,12,12,17,28,8,8,22,8,19,12,12,4,19,35,12,17,33,29,8,3,3,5,32,17,8,8,22,8,19,8,8,17,26,28,19,33,29,30,2,32,19,35,18,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,8,22,8,8,28,12,27,12,12,12,17,8,33,12,18,35,12,17,33,8,8,33,8,17,26,28,19,8,17,26,8,19,8,19,35,12,17,33,29,8,3,3,5,32,17,8,8,22,8,19,8,8,17,26,28,19,33,29,30,2,32,19,35,18,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,8,22,8,8,28,12,27,12,12,12,17,8,33,12,18,35,12,17,33,8,8,17,26,28,19,33,8,8,26,8,19,35,12,17,33,29,8,3,3,5,32,17,8,8,22,8,19,8,8,17,26,28,19,33,29,30,2,32,19,35,18,19,12,12,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141,bind1,bind2,bind3
98415605,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,12,17,33,12,18,35,12,35,12,4,12,18,35,12,35,4,8,18,8,8,8,8,28,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
98415606,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,35,35,12,17,6,19,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
98415607,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,35,35,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
98415608,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,35,35,35,29,35,5,32,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
98415609,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


98415610


In [3]:
# Show both ends of the encoded test split, then report its row count.
display(test.head(), test.tail())
print(test.shape[0])

Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141
0,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,35,12,35,12,4,12,18,35,12,35,4,8,8,17,8,19,28,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,8,22,8,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,12,18,12,12,12,17,8,26,8,19,12,12,18,19,35,12,17,33,12,18,35,12,35,12,4,12,18,35,12,35,4,8,8,17,8,19,28,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,enc0,enc1,enc2,enc3,enc4,enc5,enc6,enc7,enc8,enc9,enc10,enc11,enc12,enc13,enc14,enc15,enc16,enc17,enc18,enc19,enc20,enc21,enc22,enc23,enc24,enc25,enc26,enc27,enc28,enc29,enc30,enc31,enc32,enc33,enc34,enc35,enc36,enc37,enc38,enc39,enc40,enc41,enc42,enc43,enc44,enc45,enc46,enc47,enc48,enc49,enc50,enc51,enc52,enc53,enc54,enc55,enc56,enc57,enc58,enc59,enc60,enc61,enc62,enc63,enc64,enc65,enc66,enc67,enc68,enc69,enc70,enc71,enc72,enc73,enc74,enc75,enc76,enc77,enc78,enc79,enc80,enc81,enc82,enc83,enc84,enc85,enc86,enc87,enc88,enc89,enc90,enc91,enc92,enc93,enc94,enc95,enc96,enc97,enc98,enc99,enc100,enc101,enc102,enc103,enc104,enc105,enc106,enc107,enc108,enc109,enc110,enc111,enc112,enc113,enc114,enc115,enc116,enc117,enc118,enc119,enc120,enc121,enc122,enc123,enc124,enc125,enc126,enc127,enc128,enc129,enc130,enc131,enc132,enc133,enc134,enc135,enc136,enc137,enc138,enc139,enc140,enc141
1674891,8,28,8,27,8,8,8,17,8,8,33,12,18,35,12,17,33,12,4,35,13,12,25,12,12,12,17,7,19,12,12,4,25,19,35,12,17,33,29,8,3,3,5,32,17,8,8,8,33,26,29,33,14,32,26,29,33,36,32,19,8,17,26,28,19,33,29,30,2,32,19,35,18,19,8,8,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674892,8,28,8,27,8,8,8,17,8,8,33,12,18,35,12,17,33,12,4,35,13,12,25,12,12,12,17,7,19,12,12,4,25,19,35,12,17,33,29,8,3,3,5,32,17,8,8,8,33,26,29,33,14,32,26,29,33,36,32,19,8,17,26,28,19,33,29,30,2,32,19,35,18,19,8,8,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674893,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674894,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1674895,29,33,36,32,26,29,33,14,32,26,33,8,8,8,29,8,3,5,32,17,33,12,27,35,12,17,33,8,12,18,12,12,12,10,18,19,35,12,17,33,12,18,35,13,12,4,12,12,12,17,7,19,12,12,18,4,19,35,27,19,8,17,26,28,19,33,29,30,2,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


1674896


In [5]:
import pandas as pd

# Toy frame: three integer-labelled columns (0, 1, 2) with six rows each.
data = dict(enumerate((
    [1, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 1, 0],
    [1, 0, 2, 1, 1, 3],
)))
df = pd.DataFrame(data)

# Function to extract values based on the specified pattern
def extract_values(df, n_cols=3):
    """Pick one value per row, cycling through the first ``n_cols`` columns.

    Row ``i`` contributes the value stored at column *position* ``i % n_cols``,
    so with the default of 3 the rows read columns 0, 1, 2, 0, 1, 2, ...

    Args:
        df: DataFrame to read from; it needs at least ``n_cols`` columns once
            it has ``n_cols`` or more rows.
        n_cols: Length of the column cycle. Defaults to 3, which reproduces
            the original hard-coded behaviour.

    Returns:
        A list with one value per row of ``df``, in row order (empty for an
        empty frame).

    Raises:
        ValueError: If ``n_cols`` is smaller than 1.
        IndexError: If a row asks for a column position ``df`` does not have.
    """
    if n_cols < 1:
        raise ValueError("n_cols must be a positive integer")
    # `.iat` is the positional scalar accessor: same value as `.iloc[i, j]`,
    # without the general-purpose indexing overhead on every row.
    return [df.iat[row, row % n_cols] for row in range(len(df))]

# One value per row of the example frame, following the column cycle.
extracted_values = extract_values(df)

# Wrap those values in a single-column frame named 'binds'.
reshaped_df = pd.DataFrame(data={'binds': extracted_values})

# Plain-text dump of the result.
print(reshaped_df)


   binds
0      1
1      0
2      2
3      0
4      1
5      3


tokenizer_config.json:   0%|          | 0.00/4.79k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/50.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [22]:
import pandas as pd

# Raw (un-encoded) test split.
# NOTE(review): absolute, machine-specific path; the encoded split is read via a
# relative '../data/...' path — confirm the layout before making this relative.
test_raw = "/root/Kaggle_NeurIPS2024/data/raw/test.parquet"
test_df = pd.read_parquet(test_raw)

# A bare `test_df.head()` used to sit here. It was not the cell's last
# expression, so its result was computed and thrown away without being shown.

# Take the SMILES string stored under index label 1 and report it with its length.
# NOTE: this rebinds `test`, which the data-loading cell uses for the encoded
# test DataFrame; after this cell `test` is a plain string.
test = test_df["molecule_smiles"][1]
print(test)
print(len(test))

C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy]
64


In [5]:
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
# Hub identifier shared by the tokenizer and the causal-LM weights.
MAMBA_CHECKPOINT = "state-spaces/mamba-2.8b-hf"

# Load the tokenizer first, then the model, from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MAMBA_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(MAMBA_CHECKPOINT)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
import torch.nn as nn

# Dump the architecture so the head's dimensions can be checked by eye.
print(model)

# Input/output widths of the head that is currently attached.
hidden_size, vocab_size = model.lm_head.in_features, model.lm_head.out_features

# Replace the vocabulary-sized head with a freshly initialised 3-output linear layer.
new_output_size = 3  # or whatever size you need
model.lm_head = new_head = nn.Linear(hidden_size, new_output_size)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 2560)
    (layers): ModuleList(
      (0-63): 64 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
          (act): SiLU()
          (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
          (x_proj): Linear(in_features=5120, out_features=192, bias=False)
          (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
          (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=50280, bias=False)
)


In [None]:
# Dimensions of a random dummy input (batch, sequence length, feature dim).
batch, length, dim = 2, 64, 16
# Place the tensor on the GPU when one is visible. The previous hard-coded
# "cuda" raised on CPU-only machines; behaviour with a GPU is unchanged.
x = torch.randn(batch, length, dim).to("cuda" if torch.cuda.is_available() else "cpu")

# Raw (un-encoded) test split.
# NOTE(review): absolute, machine-specific path.
test_raw = "/root/Kaggle_NeurIPS2024/data/raw/test.parquet"
test_df = pd.read_parquet(test_raw)

In [19]:

# SMILES string stored under index label 1 of the raw test split.
test = test_df["molecule_smiles"][1]
# Character length of the SMILES string. The closing parenthesis was missing
# here, which made the whole cell a SyntaxError.
print(len(test))
# Tokenise to a PyTorch tensor of token ids and show its shape.
encoded_input = tokenizer.encode(test, return_tensors='pt')
print(encoded_input.shape)
# Disabled alternative: a standalone mamba_ssm block instead of the HF model.
# model = Mamba(
#     # This module uses roughly 3 * expand * d_model^2 parameters
#     d_model=dim, # Model dimension d_model
#     d_state=16,  # SSM state expansion factor
#     d_conv=4,    # Local convolution width
#     expand=2,    # Block expansion factor
# ).to("cuda")
# Forward pass through `model`; the result exposes `.logits`.
y = model(encoded_input)
# assert y.shape == x.shape
print(y)

64
torch.Size([1, 52])
MambaCausalLMOutput(loss=None, logits=tensor([[[0.8130, 2.9081, 0.3035],
         [1.7261, 1.6294, 0.2114],
         [0.9585, 1.3269, 0.8121],
         [1.9681, 0.5043, 0.9812],
         [1.1408, 1.9614, 0.6307],
         [0.9342, 0.6813, 0.8935],
         [1.8693, 1.5248, 1.1185],
         [1.2571, 1.0376, 0.9497],
         [0.9230, 1.5591, 0.7519],
         [1.1453, 1.7103, 0.7646],
         [2.0499, 1.4489, 1.1660],
         [2.0921, 1.6491, 0.7394],
         [2.3179, 1.5648, 0.7229],
         [1.5479, 2.0764, 0.6939],
         [1.4725, 1.6644, 0.6814],
         [2.2119, 1.2821, 0.7466],
         [1.7882, 1.1790, 0.5899],
         [2.0568, 1.1002, 0.7091],
         [1.6804, 1.7699, 0.5188],
         [0.5277, 0.6734, 0.6331],
         [1.4521, 1.7179, 0.6676],
         [1.8872, 1.2086, 0.6913],
         [1.9374, 1.5089, 0.8574],
         [1.5472, 1.5038, 1.0881],
         [1.9928, 1.4917, 0.9543],
         [1.8801, 1.2907, 0.4936],
         [2.1358, 1.2517, 0.6

In [20]:
# Detach the logits from autograd, bring them to host memory, and view them as
# a NumPy array; the trailing expression displays the array's shape.
nump = y.logits.detach().to("cpu").numpy()
nump.shape

(1, 52, 3)

In [3]:
from transformers import AutoTokenizer

# Tokenizer published with the "state-spaces/mamba-2.8b-hf" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")

# Sample sentence -> token ids. 'pt' requests a PyTorch tensor ('tf' would
# request a TensorFlow one).
text = "This is an example sentence."
encoded_input = tokenizer.encode(text, return_tensors='pt')
print("Encoded input:", encoded_input)

# Round trip: turn the first (only) row of ids back into a string.
decoded_text = tokenizer.decode(encoded_input[0])
print("Decoded text:", decoded_text)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Encoded input: tensor([[1552,  310,  271, 1650, 6197,   15]])
Decoded text: This is an example sentence.
