In [None]:
# @title 🧬 DNA Analysis Tool (Nucleotide Transformer Colab Edition)
# @markdown Run this cell to start the app. **Ensure you have selected T4 GPU runtime.**

import os
import sys
import threading
from datetime import datetime
import pandas as pd
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files

# --- 1. Install Dependencies ---
# Set the SKIP_PIP_INSTALL environment variable to any non-empty value to skip
# this step (e.g. on re-runs once the packages are already installed).
# NOTE: `!pip` is IPython/Colab shell-escape syntax; this cell is not plain
# Python and must be run in a notebook kernel.
# `transformers` is imported only after this step (below), so a fresh install
# is picked up; pandas/torch/ipywidgets are imported earlier and are assumed
# to be preinstalled in the runtime.
# NOTE(review): if pip ends up changing torch/pandas versions, the modules
# already imported above stay loaded until the runtime restarts — confirm.
# NOTE(review): hf_transfer is installed, but this cell never sets
# HF_HUB_ENABLE_HF_TRANSFER; confirm whether accelerated downloads are intended.
if not os.environ.get("SKIP_PIP_INSTALL"):
    print("Installing dependencies (one-time; please wait)...")
    !pip install -q transformers torch pandas huggingface_hub[hf_xet] hf_transfer sentencepiece
else:
    print("Skipping dependency install (SKIP_PIP_INSTALL=1).")
print("Dependencies step finished. UI will appear below.")

from transformers import AutoTokenizer, AutoModelForMaskedLM

# --- 2. Core Analysis Class ---
class DNAAnalyzerEngine:
    """Wrap a Nucleotide Transformer masked-LM for scoring and repairing DNA.

    The tokenizer and model are loaded by :meth:`load_model`. The scoring and
    repair methods assume that call has succeeded.
    """

    # Bases the repair step may emit; any other prediction leaves the 'N' in place.
    _NUCLEOTIDES = frozenset("ACGT")

    def __init__(self, model_name):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        # Sequence processing configuration, measured in tokens (not bases).
        self.max_length = 1000
        self.chunk_overlap = 50

    def load_model(self, use_fp16=True, status_callback=None):
        """Load the tokenizer and model for ``self.model_name``.

        Args:
            use_fp16: Load weights as float16 when running on CUDA. On CPU,
                float32 is always used.
            status_callback: Optional ``callable(str)`` that receives progress text.

        Returns:
            ``(True, description)`` on success, ``(False, error_message)`` on
            failure. Exceptions are caught and reported, not raised.
        """
        def report(message):
            if status_callback:
                status_callback(message)

        try:
            # Release any previously loaded model first, so its (GPU) memory
            # can be reclaimed before the new weights are allocated. This also
            # keeps model/tokenizer consistent with model_name if loading fails.
            self.model = None
            self.tokenizer = None
            if self.device == "cuda":
                torch.cuda.empty_cache()

            report("Downloading Tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

            report("Downloading Model Weights (This takes time)...")
            # FP16 only on CUDA to save VRAM; CPU stays in FP32.
            torch_dtype = torch.float16 if use_fp16 and self.device == "cuda" else torch.float32

            model = AutoModelForMaskedLM.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch_dtype,
            )
            model.to(self.device)
            model.eval()
            self.model = model
            # Report the precision actually used, not the one requested.
            precision = "FP16" if torch_dtype == torch.float16 else "FP32"
            return True, f"Loaded on {self.device.upper()} ({precision})"
        except Exception as e:
            return False, str(e)

    def calculate_perplexity(self, sequence):
        """Return ``exp(mean masked-LM loss)`` for ``sequence``, or ``None``.

        Sequences of up to ``self.max_length`` tokens are scored in one pass.
        Longer ones are scored with windows of ``self.max_length`` tokens that
        advance by ``max_length - chunk_overlap``. Windowing stops as soon as
        one window reaches the end, so no window lies entirely inside the
        previous one. Window losses are averaged weighted by window length, so
        a short tail window does not count as much as a full one.

        NOTE(review): labels are the *unmasked* inputs, so the model sees every
        token it is scored on. The result is a reconstruction score rather than
        a true pseudo-perplexity, and thresholds elsewhere are calibrated on it.

        Returns:
            The score as a float, or ``None`` if every window's loss was NaN.

        Raises:
            ValueError: if the sequence needs windowing and
                ``chunk_overlap >= max_length`` (non-positive stride).
        """
        inputs = self.tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        input_ids = inputs["input_ids"][0]
        n_tokens = len(input_ids)
        if n_tokens == 0:
            return None

        if n_tokens <= self.max_length:
            chunks = [input_ids]
        else:
            stride = self.max_length - self.chunk_overlap
            if stride <= 0:
                raise ValueError(
                    f"chunk_overlap ({self.chunk_overlap}) must be smaller than "
                    f"max_length ({self.max_length})"
                )
            chunks = []
            for start in range(0, n_tokens, stride):
                chunks.append(input_ids[start : start + self.max_length])
                if start + self.max_length >= n_tokens:
                    # This window already covers the tail; any later window
                    # would be a strict subset of it.
                    break

        weighted_loss = 0.0
        scored_tokens = 0
        with torch.no_grad():
            for chunk in chunks:
                batch = chunk.unsqueeze(0).to(self.device)
                # Masked LM loss where labels = inputs
                loss = self.model(batch, labels=batch).loss
                if not torch.isnan(loss):
                    weighted_loss += loss.item() * len(chunk)
                    scored_tokens += len(chunk)

        if scored_tokens == 0:
            return None
        return torch.exp(torch.tensor(weighted_loss / scored_tokens)).item()

    def repair_sequence(self, sequence):
        """Fill every 'N' in ``sequence`` with a base predicted by the model.

        Each 'N' becomes one mask token. The model's top prediction at that
        position may be a multi-base token (e.g. a 6-mer); only its first base
        is kept, so the output has the same length as the input. Predictions
        that do not start with A/C/G/T (special tokens, 'N') leave the 'N' as is.

        The output is built from the upper-cased input rather than by decoding
        tokens back to text. This keeps tokenizer special tokens (e.g. a
        leading CLS) out of the repaired sequence.

        Returns:
            ``(sequence_out, status)``. ``status`` is one of:
            "No 'N' found", "Skipped (Too Long)", "Tokenizer Error",
            "Repaired", or "Partially Repaired" (some 'N' could not be filled).
            For the first three, ``sequence`` is returned unchanged.
        """
        if not sequence or 'N' not in sequence.upper():
            return sequence, "No 'N' found"

        upper_seq = sequence.upper()
        n_count = upper_seq.count('N')

        masked_seq_str = upper_seq.replace('N', self.tokenizer.mask_token)
        inputs = self.tokenizer(masked_seq_str, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)

        # Length check
        if input_ids.shape[1] > self.max_length:
            return sequence, "Skipped (Too Long)"

        # Every 'N' must map to exactly one mask token (in order). Otherwise we
        # cannot tell which prediction belongs to which 'N'.
        mask_positions = (input_ids[0] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
        if len(mask_positions) != n_count:
            return sequence, "Tokenizer Error"

        with torch.no_grad():
            logits = self.model(input_ids).logits

        predicted_ids = logits[0, mask_positions].argmax(dim=-1).tolist()
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_ids)

        # Keep only the first base of each predicted token, enforcing 1:1 replacement.
        fill_bases = iter(
            token[0] if token and token[0] in self._NUCLEOTIDES else 'N'
            for token in predicted_tokens
        )
        repaired_seq = ''.join(next(fill_bases) if ch == 'N' else ch for ch in upper_seq)

        status = "Partially Repaired" if 'N' in repaired_seq else "Repaired"
        return repaired_seq, status


# --- 3. Colab UI Logic ---

# Model Selection Dropdown
# IMPORTANT: NT-500M Human Ref is set as default for Human Detection
model_dropdown = widgets.Dropdown(
    options=[
        ("NT-500M Human Ref (Best for Human Detection)", "InstaDeepAI/nucleotide-transformer-500m-human-ref"),
        ("NT-2.5B 1000G (Best Accuracy for Human Variants)", "InstaDeepAI/nucleotide-transformer-2.5b-1000g"),
        ("NT-2.5B Multi-species (Best for Repair/General)", "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"),
    ],
    value="InstaDeepAI/nucleotide-transformer-500m-human-ref",
    description='Model:',
    layout=widgets.Layout(width='500px')
)

# Initialize engine with default model (weights are loaded later by the
# "Initialize Model" button; the callbacks re-point engine.model_name).
engine = DNAAnalyzerEngine(model_dropdown.value)
# Name of the CSV saved from the upload widget; None until a file is uploaded.
uploaded_filename = None

# UI Widgets
# The header previously contained mojibake (UTF-8 bytes mis-decoded as cp1252).
header = widgets.HTML("<h2>🧬 DNA Analysis Tool (Colab Edition)</h2>")
status_label = widgets.Label(value="System Status: Idle (Waiting for file upload)")
upload_btn = widgets.FileUpload(accept='.csv', multiple=False)
col_name_input = widgets.Text(value='sequence_text', description='Col Name:')
mode_dropdown = widgets.Dropdown(
    options=[
        ('Classify (Is Human DNA?)', 'classify'),
        ('Deep Repair (Fill N)', 'repair')
    ],
    description='Mode:'
)
load_btn = widgets.Button(description="Initialize Model", button_style='primary', icon='download')
# Enabled only once a model has loaded successfully.
run_btn = widgets.Button(description="Start Analysis", button_style='success', icon='play', disabled=True)
output_area = widgets.Output()
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    orientation='horizontal',
    layout=widgets.Layout(width='100%'),
)

def on_upload_change(change):
    """Save the uploaded CSV to the working directory and record its name.

    Handles both upload formats:
    - ipywidgets 7: ``value`` is ``{filename: {"content": ...}}``.
    - ipywidgets 8: ``value`` is a sequence of per-file entries with
      ``name`` and ``content``, given either as dict keys or as attributes.

    On failure the reason is shown in ``status_label``, and
    ``uploaded_filename`` keeps its previous value.
    """
    global uploaded_filename
    value = upload_btn.value
    if not value:
        return

    try:
        if isinstance(value, dict):
            # ipywidgets 7.x format: {filename: {"metadata": ..., "content": bytes}}
            name, meta = next(iter(value.items()))
            content = meta["content"]
        else:
            # ipywidgets 8.x format: tuple of file entries
            entry = value[0]
            if isinstance(entry, dict):
                name, content = entry["name"], entry["content"]
            else:
                name, content = entry.name, entry.content
    except (AttributeError, IndexError, KeyError, TypeError) as e:
        status_label.value = f"Error: Could not read upload ({e})"
        return

    # The name comes from the client; keep only its base name so the write
    # cannot land outside the working directory.
    safe_name = os.path.basename(str(name))
    try:
        # content may be bytes or a memoryview; both are accepted by write().
        with open(safe_name, "wb") as f:
            f.write(content)
    except OSError as e:
        status_label.value = f"Error: Could not save upload ({e})"
        return

    # Only publish the name once the file actually exists on disk.
    uploaded_filename = safe_name
    status_label.value = f"File Uploaded: {uploaded_filename}"

upload_btn.observe(on_upload_change, names='value')

def on_load_click(b):
    """Load the model selected in ``model_dropdown`` on a background thread.

    While loading, the load/run buttons and the model selector are disabled,
    so analysis cannot start against a half-loaded model. On success the run
    button is enabled. In every case the load button and selector are
    re-enabled afterwards, so a different model can be loaded later.
    """
    load_btn.disabled = True
    run_btn.disabled = True
    model_dropdown.disabled = True
    status_label.value = "Initializing Model... (Please wait)"

    # Update engine with currently selected model
    engine.model_name = model_dropdown.value
    model_id = engine.model_name.lower()

    with output_area:
        clear_output()
        print(f"Initializing model: {engine.model_name}")
        if "human-ref" in model_id:
            print("Context: Human Reference (Suitable for detecting Human vs Contamination)")
        elif "1000g" in model_id:
            print("Context: Human 1000 Genomes (Suitable for human variant analysis)")
        else:
            print("Context: Multi-Species (Suitable for general DNA repair or embeddings)")

    def _log(line):
        # The `with output_area:` capture is not reliable from a background
        # thread (see the ipywidgets Output widget docs); append_stdout targets
        # this widget explicitly.
        output_area.append_stdout(line + "\n")

    def _load():
        try:
            success, msg = engine.load_model(status_callback=lambda s: setattr(status_label, 'value', s))
            if success:
                status_label.value = f"Ready: {msg}"
                run_btn.disabled = False
                _log(f"Model ready: {engine.model_name} | {msg}")
            else:
                status_label.value = f"Error: {msg}"
                _log(f"Error loading model: {msg}")
        finally:
            # Allow switching to another model regardless of the outcome.
            load_btn.disabled = False
            model_dropdown.disabled = False

    threading.Thread(target=_load).start()

load_btn.on_click(on_load_click)

def on_run_click(b):
    """Run the selected mode over every row of the uploaded CSV and download the result.

    Reads column ``col_name_input.value`` from ``uploaded_filename``.
    - 'classify' mode adds ``classification_result``, ``perplexity_score``
      and ``model_used`` columns.
    - 'repair' mode adds ``repaired_sequence`` and ``repair_status``.

    The result is written to a timestamped CSV in the working directory and
    offered for download via ``files.download``. Errors are printed to
    ``output_area`` and shown in ``status_label``.
    """
    if not uploaded_filename:
        status_label.value = "Error: No CSV file uploaded!"
        return
    # The run button is normally enabled only after a successful load; guard
    # anyway so a failed or in-progress reload cannot crash the analysis.
    if engine.model is None or engine.tokenizer is None:
        status_label.value = "Error: Model not loaded. Click 'Initialize Model' first."
        return

    run_btn.disabled = True
    status_label.value = "Running analysis..."
    col = col_name_input.value
    mode = mode_dropdown.value

    # Determine context for logic interpretation
    model_id = engine.model_name.lower()
    is_human_ref = "human-ref" in model_id or "500m" in model_id

    # Perplexity cut-off used by both interpretations below
    human_ppl_threshold = 45.0

    with output_area:
        clear_output()
        print(f"Reading {uploaded_filename}...")

        try:
            df = pd.read_csv(uploaded_filename)
            if col not in df.columns:
                print(f"Error: Column '{col}' not found.")
                status_label.value = f"Error: Column '{col}' not found."
                return

            results, details = [], []
            total = len(df)
            progress_bar.max = total
            progress_bar.value = 0

            print(f"Processing {total} rows. Mode: {mode}...")

            # Iterate the column itself rather than iterrows(). iterrows()
            # builds a Series per row (which can upcast dtypes) and yields index
            # labels, not 0-based positions.
            for pos, raw_value in enumerate(df[col]):
                seq = str(raw_value).strip()
                if not seq or seq.lower() == 'nan':
                    results.append("")
                    details.append("Empty")
                elif mode == 'classify':
                    ppl = engine.calculate_perplexity(seq)
                    if ppl is None:
                        results.append("Error")
                        details.append(-1)
                    else:
                        # --- INTERPRETATION LOGIC ---
                        is_low = ppl < human_ppl_threshold
                        if is_human_ref:
                            # Human Ref model: Low PPL = Human, High PPL = Alien/Contamination
                            res_str = "Likely Human DNA" if is_low else "Likely Contamination/Non-Human"
                        else:
                            # Other models: Low PPL = Valid DNA (could be bacteria)
                            res_str = "Valid DNA (Species Unknown)" if is_low else "High Perplexity / Noise"
                        results.append(res_str)
                        details.append(round(ppl, 2))
                else:
                    # Repair Mode
                    fixed, status = engine.repair_sequence(seq)
                    results.append(fixed)
                    details.append(status)

                # Advance for every row (empty ones included) so the bar completes.
                progress_bar.value = pos + 1
                if pos % 10 == 0:
                    print(f"Row {pos}/{total} processed", end='\r')

            # Save with timestamp to avoid collisions
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            out_name = f"processed_{mode}_{ts}_{os.path.basename(uploaded_filename)}"

            if mode == 'classify':
                df['classification_result'] = results
                df['perplexity_score'] = details
                df['model_used'] = engine.model_name
            else:
                df['repaired_sequence'] = results
                df['repair_status'] = details

            df.to_csv(out_name, index=False)
            print(f"\nDone! Saved to {out_name}")

            # Trigger Download
            files.download(out_name)
            status_label.value = "Analysis Complete! File downloaded."

        except Exception as e:
            print(f"Error: {e}")
            status_label.value = f"Error: {e}"
        finally:
            run_btn.disabled = False

run_btn.on_click(on_run_click)

# Layout Construction: a single vertical stack, with related controls sharing
# a horizontal row (upload + status, column + mode, load + run).
ui = widgets.VBox(
    children=[
        header,
        widgets.HBox(children=[upload_btn, status_label]),
        model_dropdown,
        widgets.HBox(children=[col_name_input, mode_dropdown]),
        widgets.HBox(children=[load_btn, run_btn]),
        progress_bar,
        output_area,
    ]
)

display(ui)
