# In [None]:
import os
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm.auto import tqdm
import sys


# Kaggle input mounts (read-only) and output location.
DATA_ROOT = "/kaggle/input/cafa-6-protein-function-prediction"
_TRAIN_DIR = os.path.join(DATA_ROOT, "Train")
GO_OBO = os.path.join(_TRAIN_DIR, "go-basic.obo")
IA_TSV = os.path.join(DATA_ROOT, "IA.tsv")
TRAIN_TAXONOMY = os.path.join(_TRAIN_DIR, "train_taxonomy.tsv")
TEST_FASTA = os.path.join(DATA_ROOT, "Test", "testsuperset.fasta")

GOA_PRED_FILE = os.path.join("/kaggle/input/protein-go-annotations", "goa_uniprot_ver228.tsv")
# ESM-based predictions produced by an earlier submission of this notebook.
ESM_PRED_FILE = os.path.join("/kaggle/input/submission", "submission.tsv")
OUTPUT_FILE = "submission.tsv"


# Tunable knobs for the ensemble / post-processing pipeline.
CONFIG = dict(
    TOP_K=200,              # max GO terms written per protein
    MIN_SCORE=0.001,        # terms scoring below this are dropped
    MAX_SCORE=0.97,         # cap applied to boosted / rescaled scores
    NEG_PROP_ALPHA=0.7,     # weight of the weakest ancestor when pulling a child down
    SCALING_POWER=0.8,      # exponent of the final power rescaling
    USE_IA=True,            # enable information-accretion (IA) boosting
    USE_TAXONOMY=True,      # enable same-taxon boosting
    IA_BOOST_FACTOR=1.2,    # multiplier reached at normalised IA weight 1.0
    TAXONOMY_BOOST=1.15,    # multiplier for terms seen in training proteins of the same taxon
    WEIGHT_GOA=0.55,        # blend weight of GOA when both sources score a term
    WEIGHT_ESM=0.45,        # blend weight of ESM when both sources score a term
)


print("Parsing GO ontology...")
# GO id -> set of direct parent ids, following both is_a and part_of edges.
term_parents = defaultdict(set)

with open(GO_OBO, 'r', encoding='utf-8') as f:
    cur_id = None
    in_term = False  # True only while inside a [Term] stanza
    for line in f:
        line = line.strip()
        if line.startswith('['):
            # Stanza header. [Typedef] stanzas (relation definitions) may also
            # carry "id:" / "is_a:" lines; only [Term] stanzas describe GO terms,
            # so relations are recorded only there.
            in_term = (line == '[Term]')
            cur_id = None
        elif not in_term:
            continue
        elif line.startswith('id: '):
            cur_id = line.split('id: ')[1].strip()
        elif line.startswith('is_a: ') and cur_id:
            # "is_a: GO:xxxxxxx ! name" -> second token is the parent id
            term_parents[cur_id].add(line.split()[1].strip())
        elif line.startswith('relationship: part_of ') and cur_id:
            # "relationship: part_of GO:xxxxxxx ! name" -> third token is the parent id
            term_parents[cur_id].add(line.split()[2].strip())

# GO root terms; they are skipped by negative propagation and forced to 1.0.
ROOTS = {'GO:0003674', 'GO:0008150', 'GO:0005575'}

# Memo table: GO id -> set of all (transitive) ancestors.
ancestors_map = {}


def get_ancestors(term):
    """Return every ancestor of ``term`` reachable via ``term_parents``.

    Results are memoised in ``ancestors_map``; the returned set is the cached
    object itself, so callers must not mutate it.
    """
    cached = ancestors_map.get(term)
    if cached is not None:
        return cached
    direct = term_parents.get(term, set())
    closure = set(direct)
    for parent in direct:
        closure.update(get_ancestors(parent))
    ancestors_map[term] = closure
    return closure


# Warm the memo table for every term that has at least one parent.
for go_id in tqdm(term_parents.keys(), desc="Precomputing ancestors"):
    get_ancestors(go_id)

print(f"Parsed GO ontology with {len(ancestors_map)} terms")

# GO id -> IA weight min-max rescaled into [0, 1] (empty when IA is disabled).
ia_weights = {}
if CONFIG['USE_IA']:
    print("\nLoading IA weights...")
    ia_df = pd.read_csv(
        IA_TSV,
        sep='\t',
        header=None,
        names=['go_term', 'ia_weight'],
    )

    # Min-max rescale; a constant column degenerates to 1.0 everywhere.
    raw_ia = ia_df['ia_weight']
    min_ia, max_ia = raw_ia.min(), raw_ia.max()
    ia_df['norm_weight'] = (
        (raw_ia - min_ia) / (max_ia - min_ia) if max_ia > min_ia else 1.0
    )

    ia_weights = {go: w for go, w in zip(ia_df['go_term'], ia_df['norm_weight'])}
    print(f"Loaded IA weights for {len(ia_weights)} terms")

# train protein id -> taxon id (as read from train_taxonomy.tsv)
train_tax_dict = {}
# test protein id -> int taxon id (parsed from the test FASTA headers)
test_tax_dict = {}
# int taxon id -> set of GO terms annotated on training proteins of that taxon
taxon_to_go = {}

if CONFIG['USE_TAXONOMY']:
    print("\nLoading taxonomy information...")

    train_tax = pd.read_csv(TRAIN_TAXONOMY, sep='\t',
                            names=['protein_id', 'taxon_id'])
    train_tax_dict = dict(zip(train_tax['protein_id'], train_tax['taxon_id']))

    if os.path.exists(TEST_FASTA):
        with open(TEST_FASTA, 'r') as f:
            for line in f:
                if not line.startswith('>'):
                    continue
                parts = line.split()
                if len(parts) < 2:
                    continue
                protein_id = parts[0][1:]
                # The first all-digit token after the id is taken as the taxon.
                # Taxon ids have no fixed length, so no length filter is applied
                # (the old 4..10-digit filter silently dropped short ids).
                taxon = next((p for p in parts[1:] if p.isdigit()), None)
                if taxon is not None:
                    test_tax_dict[protein_id] = int(taxon)

        print(f"Extracted taxonomy for {len(test_tax_dict)} test proteins")

    train_terms_path = os.path.join(DATA_ROOT, "Train", "train_terms.tsv")
    if os.path.exists(train_terms_path):
        train_terms = pd.read_csv(train_terms_path, sep='\t',
                                  names=['protein_id', 'go_term', 'aspect'])

        train_merged = train_terms.merge(
            pd.DataFrame(list(train_tax_dict.items()), columns=['protein_id', 'taxon_id']),
            on='protein_id',
            how='left',
        )

        # Vectorised replacement for a per-row iterrows() loop. Non-numeric
        # taxon ids (e.g. a header line read as data) become NaN and are
        # dropped together with proteins that have no taxon.
        train_merged['taxon_id'] = pd.to_numeric(train_merged['taxon_id'], errors='coerce')
        train_merged = train_merged.dropna(subset=['taxon_id'])
        taxon_to_go = {
            int(taxon_id): set(terms)
            for taxon_id, terms in train_merged.groupby('taxon_id')['go_term']
        }

        print(f"Mapped {len(taxon_to_go)} taxa to GO terms")


def _keep_max(data, pid, go, score):
    """Record ``score`` for ``(pid, go)`` unless a higher score is already stored."""
    data[pid][go] = max(data[pid].get(go, 0), score)


def _goa_score(raw):
    """Parse an optional GOA score cell; non-numeric values (e.g. qualifiers) count as 1.0."""
    try:
        return float(raw)
    except (ValueError, TypeError):
        return 1.0


def _load_goa_pandas(filepath):
    """Chunked pandas reader for GOA files; the first line is treated as a header."""
    data = defaultdict(dict)
    for chunk in pd.read_csv(filepath, sep='\t', header=0, chunksize=100000):
        n_cols = chunk.shape[1]
        if n_cols < 2:
            continue  # no GO-term column, so no usable rows
        # Column-wise zip is far faster than DataFrame.iterrows().
        raw_scores = chunk.iloc[:, 2] if n_cols >= 3 else [1.0] * len(chunk)
        for pid, go, raw in zip(chunk.iloc[:, 0], chunk.iloc[:, 1], raw_scores):
            _keep_max(data, str(pid).strip(), str(go).strip(), _goa_score(raw))
    return data


def _load_goa_manual(filepath):
    """Streaming GOA fallback; the first line is skipped only if it mentions 'qualifier'."""
    data = defaultdict(dict)
    with open(filepath, 'r') as f:
        first = f.readline()
        if not first:
            return data  # empty file: nothing to load

        def _lines():
            if 'qualifier' not in first.lower():
                yield first
            yield from f

        for line in tqdm(_lines(), desc="Loading GOA"):
            parts = line.strip().split('\t')
            if len(parts) < 2:
                continue
            score = _goa_score(parts[2]) if len(parts) >= 3 else 1.0
            _keep_max(data, parts[0], parts[1], score)
    return data


def _load_esm_pandas(filepath):
    """Chunked pandas reader for headerless ``protein<TAB>go<TAB>score`` files."""
    data = defaultdict(dict)
    for chunk in pd.read_csv(filepath, sep='\t', header=None,
                             names=['protein', 'go', 'score'],
                             chunksize=100000):
        for pid, go, score in zip(chunk['protein'], chunk['go'], chunk['score']):
            _keep_max(data, str(pid), str(go), float(score))
    return data


def _load_esm_manual(filepath):
    """Streaming ESM fallback; lines whose score is not numeric are counted and skipped."""
    data = defaultdict(dict)
    skipped = 0
    with open(filepath, 'r') as f:
        for line in tqdm(f, desc="Loading ESM"):
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            try:
                score = float(parts[2])
            except ValueError:
                skipped += 1  # e.g. a header line
                continue
            _keep_max(data, parts[0], parts[1], score)
    if skipped:
        print(f"Skipped {skipped:,} ESM lines with a non-numeric score")
    return data


def load_predictions(filepath, is_goa=False):
    """Load a TSV of predictions into ``{protein_id: {go_term: score}}``.

    When the same (protein, term) pair appears more than once, the highest
    score is kept.

    Args:
        filepath: path to a tab-separated file.
        is_goa: if True, the file is read as a GOA annotation table whose first
            line is a header and whose optional third column is a score
            (non-numeric third-column values count as 1.0). Otherwise, the file
            is read as a headerless ``protein, go, score`` submission.

    Returns:
        ``defaultdict(dict)`` mapping protein id -> {GO term: score}.

    Raises:
        OSError: if the file cannot be opened.

    The chunked pandas reader is tried first. If it cannot parse the file
    (pandas ParserError / EmptyDataError / UnicodeDecodeError are all
    ValueError subclasses), a plain line-by-line reader is used, starting
    from an empty result so partially parsed data is not mixed in.
    """
    if is_goa:
        try:
            return _load_goa_pandas(filepath)
        except ValueError:
            return _load_goa_manual(filepath)
    try:
        return _load_esm_pandas(filepath)
    except (ValueError, TypeError):
        return _load_esm_manual(filepath)

print("\nLoading predictions...")
goa_preds = load_predictions(GOA_PRED_FILE, is_goa=True)
esm_preds = load_predictions(ESM_PRED_FILE, is_goa=False)

# Report how many proteins each source covers.
for source_label, source_preds in (("GOA proteins", goa_preds), ("ESM2 proteins", esm_preds)):
    print(f"{source_label}: {len(source_preds):,}")


print("\nCreating enhanced ensemble...")
esm_proteins = set(esm_preds.keys())
goa_proteins = set(goa_preds.keys())
# NOTE(review): only proteins scored by BOTH sources are kept, so a protein
# with ESM predictions but no GOA entry gets no rows at all. The summary
# labels this "intersection" — confirm it is intended.
all_proteins = esm_proteins & goa_proteins

# protein id -> {GO term: blended (and optionally IA-boosted) score}
ensemble = defaultdict(dict)

for pid in tqdm(all_proteins, desc="Ensemble with IA"):
    goa = goa_preds.get(pid, {})
    esm = esm_preds.get(pid, {})
    all_terms = set(goa.keys()) | set(esm.keys())

    for term in all_terms:
        s_goa = goa.get(term, 0)
        s_esm = esm.get(term, 0)

        if s_goa > 0 and s_esm > 0:
            # Both sources score the term: weighted blend.
            base_score = CONFIG['WEIGHT_GOA'] * s_goa + CONFIG['WEIGHT_ESM'] * s_esm
        else:
            # Single-source term: GOA if positive, otherwise ESM.
            base_score = s_goa if s_goa > 0 else s_esm

        if CONFIG['USE_IA']:
            # Boost grows linearly from 1.0 (IA 0) to IA_BOOST_FACTOR (IA 1);
            # terms without an IA entry use 0.5.
            boost = 1.0 + (CONFIG['IA_BOOST_FACTOR'] - 1.0) * ia_weights.get(term, 0.5)
            base_score = min(base_score * boost, CONFIG['MAX_SCORE'])

        ensemble[pid][term] = base_score

print(f"Ensemble created: {len(ensemble):,} proteins")


def _apply_taxonomy_boost(scores, common_terms, boost, max_score):
    """Multiply, in place, scores of terms in ``common_terms`` by ``boost``, capped at ``max_score``."""
    for term in list(scores):
        if term in common_terms:
            scores[term] = min(scores[term] * boost, max_score)


def _propagate_up(scores):
    """Raise, in place, every ancestor to at least the score of each scored descendant.

    Works from a snapshot of the incoming scores, so the values propagated are
    the current ones (i.e. after any taxonomy boost).
    """
    for term, score in list(scores.items()):
        for anc in get_ancestors(term):
            if score > scores.get(anc, 0):
                scores[anc] = score


def _propagate_down(scores, alpha):
    """Pull, in place, non-root terms that score above their weakest ancestor toward it.

    new = alpha * weakest_ancestor + (1 - alpha) * own_score
    """
    for term in list(scores):
        if term in ROOTS:
            continue
        anc_scores = [scores[a] for a in get_ancestors(term) if a in scores]
        if not anc_scores:
            continue
        weakest = min(anc_scores)
        if weakest < scores[term]:
            scores[term] = alpha * weakest + (1 - alpha) * scores[term]


def _power_scale(scores, power, max_score):
    """Rescale non-root scores, in place, as (s / max) ** power * max_score when 0 < max < max_score."""
    non_root = [s for t, s in scores.items() if t not in ROOTS]
    if not non_root:
        return
    max_val = max(non_root)
    if not 0 < max_val < max_score:
        return
    for t in scores:
        if t not in ROOTS:
            new_score = np.power(scores[t] / max_val, power) * max_score
            scores[t] = min(max_score, new_score)


def enhanced_process_protein(protein_id, scores_dict):
    """Post-process one protein's {GO term: score} map.

    Steps, all applied to a copy of ``scores_dict``:
      1. taxonomy boost (if enabled and the protein's taxon is known);
      2. positive propagation so every ancestor scores >= its descendants;
      3. negative propagation (child pulled toward its weakest ancestor);
      4. power rescaling of non-root scores;
      5. root terms forced to 1.0.

    Args:
        protein_id: id looked up in ``test_tax_dict`` for the taxonomy boost.
        scores_dict: GO term -> score; not modified.

    Returns:
        A new dict of GO term -> processed score.
    """
    updated = scores_dict.copy()

    if CONFIG['USE_TAXONOMY']:
        test_taxon = test_tax_dict.get(protein_id)
        if test_taxon and test_taxon in taxon_to_go:
            _apply_taxonomy_boost(updated, taxon_to_go[test_taxon],
                                  CONFIG['TAXONOMY_BOOST'], CONFIG['MAX_SCORE'])

    # Propagate the boosted scores (previously the pre-boost scores were
    # propagated, leaving boosted children above their ancestors).
    _propagate_up(updated)

    # Negative propagation: child ≤ parent.
    # NOTE(review): after _propagate_up every ancestor already scores >= its
    # descendants, so this pass should not change anything; kept as a guard.
    _propagate_down(updated, CONFIG['NEG_PROP_ALPHA'])

    _power_scale(updated, CONFIG['SCALING_POWER'], CONFIG['MAX_SCORE'])

    for r in ROOTS:
        updated[r] = 1.0

    return updated


print(f"\nApplying enhanced propagation (TOP_K={CONFIG['TOP_K']})...")
# "protein<TAB>term<TAB>score" submission lines
final_rows = []

# Proteins are visited in sorted order so the output file is reproducible
# (ensemble keys were inserted in set-iteration order, which varies per run).
for pid in tqdm(sorted(ensemble), desc="Processing proteins"):
    updated = enhanced_process_protein(pid, ensemble[pid])

    # Highest score first; ties broken by GO id so the TOP_K cut-off keeps the
    # same terms on every run instead of depending on set/dict iteration order.
    sorted_terms = sorted(updated.items(), key=lambda item: (-item[1], item[0]))

    kept_terms = [
        (term, score) for term, score in sorted_terms
        if score >= CONFIG['MIN_SCORE']
    ][:CONFIG['TOP_K']]

    final_rows.extend(f"{pid}\t{term}\t{score:.6f}" for term, score in kept_terms)


print(f"\nSaving {len(final_rows):,} predictions...")
with open(OUTPUT_FILE, 'w') as f:
    # One batched call instead of a write per row.
    f.writelines(f"{row}\n" for row in final_rows)

size_mb = os.path.getsize(OUTPUT_FILE) / (1024 * 1024)
print(f"Submission file saved: {OUTPUT_FILE}")
print(f"Size: {size_mb:.1f} MB, Predictions: {len(final_rows):,}")
# Guard against ZeroDivisionError when no protein made it into the ensemble.
avg_per_protein = len(final_rows) / len(ensemble) if ensemble else 0.0
print(f"Average predictions per protein: {avg_per_protein:.2f}")

## summary
# NOTE(review): the file name still says "top200" even if CONFIG['TOP_K'] changes;
# kept unchanged so downstream consumers find it.
summary_file = "enhanced_top200_summary.txt"
# Guard against ZeroDivisionError when no protein made it into the ensemble.
avg_per_protein = len(final_rows) / len(ensemble) if ensemble else 0.0

summary_lines = [
    f"=== ENHANCED ENSEMBLE WITH TOP_K={CONFIG['TOP_K']} ===",
    f"GO ontology terms: {len(ancestors_map)}",
    f"IA weights loaded: {CONFIG['USE_IA']}",
    f"Taxonomy data loaded: {CONFIG['USE_TAXONOMY']}",
    f"GOA proteins: {len(goa_preds)}",
    f"ESM proteins: {len(esm_preds)}",
    f"Ensemble proteins (intersection): {len(ensemble)}",
    f"Final predictions: {len(final_rows)}",
    f"Average predictions per protein: {avg_per_protein:.2f}",
    f"File size: {size_mb:.1f} MB",
    "",
    "Configuration:",
    *(f"  {key}: {value}" for key, value in CONFIG.items()),
]
with open(summary_file, 'w') as f:
    f.write("\n".join(summary_lines) + "\n")

print(f"\nSummary saved to: {summary_file}")

# Show the first few submission rows as a quick sanity check.
print("\nSample predictions (first 10):")
for sample in final_rows[:10]:
    print("  " + sample.strip())