# Import 

In [1]:
import os
import math
import random
import numpy as np
import pandas as pd
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForMaskedLM

  from .autonotebook import tqdm as notebook_tqdm


# Setting

In [2]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then puts cuDNN into deterministic mode with its
    benchmark autotuner switched off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Device: {DEVICE}")

Device: cuda


# Data Load 

In [None]:
# Directory holding the CSV inputs. The trailing slash matters: file names are
# appended to this string with `+` when the CSVs are read.
data_path = './data/'

In [5]:
# Read the test set; later cells use its "ID" and "seq" columns.
test_csv = data_path + "test.csv"
df = pd.read_csv(test_csv)

In [6]:
# Check the maximum sequence length (measured in characters of the "seq" column).
seq_lengths = df["seq"].str.len()
max_seq_len = seq_lengths.max()
print(f"✅ Rows = {len(df):,}, Max sequence length = {max_seq_len}")

✅ Rows = 13,711, Max sequence length = 1024


# Model Load

In [7]:
# Number of sequences per batch passed to the DataLoader.
BATCH_SIZE = 64
# Number of DataLoader worker processes.
NUM_WORKERS = 2

In [8]:
MODEL_ID = "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species"

# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run locally — consider pinning a revision.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load the masked-LM checkpoint, move it to the selected device, then switch
# it to evaluation mode.
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(DEVICE)
model = model.eval()

In [9]:
# Maximum length advertised by the tokenizer for this checkpoint.
MODEL_CAP = tokenizer.model_max_length
# Length cap later passed to the tokenizer as `max_length`: the smaller of the
# model limit and the longest sequence found in the data.
# NOTE(review): max_seq_len is a character count while the cap is applied by
# the tokenizer — confirm the two are meant to be compared directly.
EFFECTIVE_MAX_LEN = max_seq_len if max_seq_len < MODEL_CAP else MODEL_CAP

# Dataset Define

In [10]:
class SeqDataset(Dataset):
    """Map-style dataset that serves raw (ID, sequence) pairs from a DataFrame.

    No tokenization happens here; each item is a dict holding the untouched
    values of the "ID" and "seq" columns for one row.

    Args:
        df: DataFrame with "ID" and "seq" columns.
        tokenizer: Stored as ``self.tok``; not used by this class's methods.
        max_len: Stored as ``self.max_len``; not used by this class's methods.
    """

    def __init__(self, df, tokenizer, max_len):
        self.tok = tokenizer
        self.max_len = max_len
        # Materialize both columns as Python lists for positional lookup.
        self.ids = df["ID"].tolist()
        self.seqs = df["seq"].tolist()

    def __len__(self):
        """Return the number of rows."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the ``{"ID": ..., "seq": ...}`` record at position ``idx``."""
        return dict(ID=self.ids[idx], seq=self.seqs[idx])

def collate_fn(batch, tok=tokenizer, max_len=EFFECTIVE_MAX_LEN):
    """Tokenize a list of dataset items into one padded batch.

    Args:
        batch: List of dicts with "ID" and "seq" keys, as produced by the
            dataset.
        tok: Tokenizer called on the list of sequences. Defaults to the
            module-level tokenizer (bound when this function is defined).
        max_len: ``max_length`` passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        Dict with "ids" (list of the items' IDs, in batch order) and the
        "input_ids" / "attention_mask" tensors returned by the tokenizer,
        padded to the longest sequence in the batch.
    """
    ids = [item["ID"] for item in batch]
    seqs = [item["seq"] for item in batch]
    # Call the tokenizer directly: `batch_encode_plus` is deprecated in favour
    # of `__call__`, which accepts the same arguments for a list of strings.
    enc = tok(
        seqs,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=max_len,
    )
    # attention_mask is 0 at padding positions.
    return {
        "ids": ids,
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
    }

In [11]:
dataset = SeqDataset(df, tokenizer, EFFECTIVE_MAX_LEN)

# shuffle=False keeps batches in DataFrame row order.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)
loader = DataLoader(dataset, **loader_kwargs)
print("✅ Dataloader ready.")

✅ Dataloader ready.


# Inference 

In [12]:
def masked_mean_pool(hidden, mask):
    """Average token vectors over the sequence axis, skipping masked positions.

    The computation is done in float32 regardless of the dtype of ``hidden``,
    so the result does not depend on autocast's per-op casting rules.

    Args:
        hidden: Tensor of shape (B, L, H) with per-token hidden states.
        mask: Tensor of shape (B, L); 1 marks a valid token, 0 marks padding.

    Returns:
        float32 tensor of shape (B, H): sum(hidden * mask) / sum(mask), with
        the denominator clamped to at least 1 to avoid division by zero.
    """
    weights = mask.unsqueeze(-1).to(torch.float32)               # (B, L, 1)
    summed = (hidden.to(torch.float32) * weights).sum(dim=1)     # (B, H)
    counts = weights.sum(dim=1).clamp(min=1.0)                   # (B, 1)
    return summed / counts


all_ids = []
all_embs = []
use_amp = (DEVICE == "cuda")

with torch.no_grad():
    for batch in loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attn_mask = batch["attention_mask"].to(DEVICE)

        # Only the forward pass runs under mixed precision.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            outs = model(
                input_ids,
                attention_mask=attn_mask,
                encoder_attention_mask=attn_mask,
                output_hidden_states=True,
            )

        # Last hidden state: (B, L, H)
        last_hidden = outs.hidden_states[-1]

        # Padding-excluded mean, computed outside autocast in float32.
        seq_emb = masked_mean_pool(last_hidden, attn_mask)           # (B, H)

        all_ids.extend(batch["ids"])
        all_embs.append(seq_emb.detach().cpu())

emb = torch.vstack(all_embs).float()        # (N, H)
N, H = emb.shape
# Every embedding row must have exactly one collected ID.
assert len(all_ids) == N, f"collected {len(all_ids)} IDs for {N} embeddings"
print(f"✅ Embedding shape = {N} x {H}")

✅ Embedding shape = 13711 x 1024


# Submission

In [13]:
# Read the sample submission; its "ID" column is used when building the output.
sample_submission_csv = data_path + "sample_submission.csv"
sample_submission = pd.read_csv(sample_submission_csv)

In [14]:
# Turn the embedding tensor into a DataFrame with one zero-padded column per
# embedding dimension (emb_0000, emb_0001, ...).
emb_np = emb.numpy()
n_dims = emb_np.shape[1]
emb_cols = [f"emb_{dim:04d}" for dim in range(n_dims)]
emb_df = pd.DataFrame(data=emb_np, columns=emb_cols)

In [15]:
# Label each embedding row with the ID that was collected next to it during
# inference, instead of assuming sample_submission.csv lists IDs in the same
# order as the inference input.
expected_ids = sample_submission["ID"]
submission = pd.concat([pd.Series(all_ids, name="ID"), emb_df], axis=1)

# If the two orders differ, re-align rows to the sample submission's order.
# `.loc` raises KeyError if an expected ID has no embedding.
if list(submission["ID"]) != list(expected_ids):
    assert submission["ID"].is_unique, "duplicate IDs: cannot align to sample_submission"
    submission = submission.set_index("ID").loc[expected_ids].reset_index()

submission.to_csv('baseline_submission.csv', index=False)