In [None]:
!pip install transformers biopython torch

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
import numpy as np
from tqdm import tqdm
import os
import gc

# --- Configuration for the ESM-2 650M model ---
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"

# CAFA-6 competition data as mounted on Kaggle.
CAFA_DATA_DIR = "/kaggle/input/cafa-6-protein-function-prediction"
TRAIN_FASTA = f"{CAFA_DATA_DIR}/Train/train_sequences.fasta"
TEST_FASTA = f"{CAFA_DATA_DIR}/Test/testsuperset.fasta"

# Sizing guidance for the 650M model on a T4 GPU (16 GB VRAM):
# a batch size of 4-6 is the safe range; drop to 2 if the run crashes (out of memory).
BATCH_SIZE = 6
MAX_CONTEXT = 1024  # tokenizer max_length, i.e. tokens per forward pass
EMBED_DIM = 1280    # the 650M model produces 1280-dimensional vectors

In [None]:
# --- 1. Initialization: pick a device and load the ESM-2 tokenizer and model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Device: {device}")

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Pooled vectors are stored in arrays of width EMBED_DIM; fail fast here if that
# constant does not match the checkpoint (e.g. after swapping MODEL_NAME to another
# ESM-2 size) instead of hitting a shape error deep inside the embedding loop.
assert model.config.hidden_size == EMBED_DIM, (
    f"EMBED_DIM={EMBED_DIM} but {MODEL_NAME} has hidden_size={model.config.hidden_size}"
)

# FP16 on GPU is needed to fit a reasonable batch size in VRAM.
if device.type == 'cuda':
    model = model.half()

model.to(device)
model.eval()  # inference only: disables dropout

In [None]:
# --- 2. H√ÄM T·∫†O EMBEDDING (H·ªó tr·ª£ chu·ªói d√†i) ---
def generate_embeddings_650M(fasta_path, prefix_name):
    print(f"\n--- Processing {prefix_name.upper()} set with ESM2-650M ---")
    
    if not os.path.exists(fasta_path):
        print(f"‚ö†Ô∏è Warning: File {fasta_path} not found.")
        return

    # Load d·ªØ li·ªáu
    ids = []
    seqs = []
    for record in SeqIO.parse(fasta_path, "fasta"):
        ids.append(record.id)
        seqs.append(str(record.seq))
    
    print(f"Found {len(seqs)} sequences.")
    
    # S·∫Øp x·∫øp theo ƒë·ªô d√†i ƒë·ªÉ t·ªëi ∆∞u t·ªëc ƒë·ªô (gi·∫£m padding th·ª´a)
    sorted_indices = np.argsort([len(s) for s in seqs])
    seqs_sorted = [seqs[i] for i in sorted_indices]
    
    batch_embeddings_list = []
    
    # B·∫Øt ƒë·∫ßu v√≤ng l·∫∑p
    for i in tqdm(range(0, len(seqs_sorted), BATCH_SIZE), desc=f"Embedding {prefix_name}"):
        batch_seqs = seqs_sorted[i : i + BATCH_SIZE]
        
        # Ki·ªÉm tra xem c√≥ chu·ªói n√†o d√†i qu√° 1022 k√Ω t·ª± kh√¥ng
        has_long_seq = any(len(s) > (MAX_CONTEXT - 2) for s in batch_seqs)
        
        if has_long_seq:
            # --- CHI·∫æN L∆Ø·ª¢C CHUNKING (C·∫Øt nh·ªè & G·ªôp) ---
            # D√†nh cho protein d√†i ƒë·ªÉ kh√¥ng b·ªã m·∫•t th√¥ng tin
            for seq in batch_seqs:
                chunks = [seq[j : j + (MAX_CONTEXT-2)] for j in range(0, len(seq), (MAX_CONTEXT-2))]
                chunk_vectors = []
                
                for chunk in chunks:
                    inputs = tokenizer(chunk, return_tensors="pt", padding=False, truncation=True, max_length=MAX_CONTEXT).to(device)
                    with torch.no_grad():
                        out = model(**inputs).last_hidden_state
                        # Mean pooling (b·ªè CLS/EOS)
                        if out.shape[1] > 2: vec = out[0, 1:-1, :].mean(dim=0)
                        else: vec = out[0, :, :].mean(dim=0)
                        chunk_vectors.append(vec.float().cpu().numpy())
                
                # T√≠nh trung b√¨nh c√°c ƒëo·∫°n
                batch_embeddings_list.append(np.mean(chunk_vectors, axis=0))
                
        else:
            # --- CHI·∫æN L∆Ø·ª¢C BATCH NHANH ---
            inputs = tokenizer(batch_seqs, return_tensors="pt", padding=True, truncation=True, max_length=MAX_CONTEXT).to(device)
            with torch.no_grad():
                outputs = model(**inputs)
                last_hidden_state = outputs.last_hidden_state
                attention_mask = inputs['attention_mask']
                
                # Masking chu·∫©n
                mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
                if device.type == 'cuda': mask_expanded = mask_expanded.half()

                sum_embeddings = torch.sum(last_hidden_state * mask_expanded, 1)
                sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                
                # K·∫øt qu·∫£ batch
                batch_embs = (sum_embeddings / sum_mask).float().cpu().numpy()
                batch_embeddings_list.extend(batch_embs)
        
        # Clear VRAM th∆∞·ªùng xuy√™n h∆°n v√¨ model 650M kh√° n·∫∑ng
        if i % (BATCH_SIZE * 20) == 0: 
            torch.cuda.empty_cache()

    # S·∫Øp x·∫øp l·∫°i ƒë√∫ng th·ª© t·ª± ban ƒë·∫ßu
    final_embeddings = np.zeros((len(seqs), EMBED_DIM), dtype=np.float32)
    for idx, original_idx in enumerate(sorted_indices):
        final_embeddings[original_idx] = batch_embeddings_list[idx]
    
    # L∆∞u file
    # ƒê·∫∑t t√™n suffix _650M ƒë·ªÉ d·ªÖ ph√¢n bi·ªát
    np.save(f"{prefix_name}_embeddings_650M.npy", final_embeddings)
    np.save(f"{prefix_name}_ids.npy", np.array(ids))
    print(f"‚úÖ Saved: {prefix_name}_embeddings_650M.npy ({final_embeddings.shape})")
    
    # Gi·∫£i ph√≥ng RAM tri·ªát ƒë·ªÉ
    del final_embeddings, batch_embeddings_list, seqs, ids
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# --- 3. Run: embed the train and test sets in turn ---
for split_fasta, split_name in [(TRAIN_FASTA, "train"), (TEST_FASTA, "test")]:
    generate_embeddings_650M(split_fasta, split_name)

print("\nüéâ DONE 650M PIPELINE!")