In [1]:
"""
Protein Structure Prediction Agent - Core Implementation
IONLACE Technical Interview Assignment
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any
from datetime import datetime
import hashlib

In [2]:
"""
ESMFold Client for Protein Structure Prediction
Uses HuggingFace Transformers local inference for structure prediction
Fallback to ESM Atlas API if available
"""

import asyncio
import json
import logging
import torch
from typing import Dict, Optional, Any, Union
from dataclasses import dataclass
import warnings

# Suppress warnings for cleaner demo output
warnings.filterwarnings("ignore")

try:
    from transformers import AutoTokenizer, EsmForProteinFolding
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

try:
    import aiohttp
    AIOHTTP_AVAILABLE = True
except ImportError:
    AIOHTTP_AVAILABLE = False


@dataclass
class ESMFoldConfig:
    """Configuration for ESMFold inference"""
    # Local model settings
    model_name: str = "facebook/esmfold_v1"
    device: str = "auto"  # auto, cpu, cuda
    max_length: int = 400  # ESMFold limit
    num_recycles: int = 4  # Number of recycles for better accuracy

    # Fallback API settings (if model loading fails)
    fallback_api_url: str = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    api_timeout: int = 300
    max_retries: int = 3


@dataclass
class PredictionResult:
    """Result from ESMFold prediction"""
    success: bool
    pdb_content: Optional[str] = None
    error_message: Optional[str] = None
    prediction_time: float = 0.0
    confidence_scores: Optional[Dict[str, float]] = None
    method_used: str = "unknown"  # "local_transformers" or "api_fallback"
    plddt_scores: Optional[list] = None


class ESMFoldClient:
    """Simplified ESMFold client that follows your working example pattern"""

    def __init__(self, config: ESMFoldConfig = None):
        self.config = config or ESMFoldConfig()
        self.logger = logging.getLogger("ESMFoldClient")
        self.model = None
        self.tokenizer = None
        self.device = None
        self._model_loaded = False

    async def _ensure_model_loaded(self):
        """Simple model loading following your working example"""
        if self._model_loaded:
            return

        if not TRANSFORMERS_AVAILABLE:
            return

        try:
            # Follow your working example exactly
            self.logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)

            self.logger.info("Loading model...")
            self.model = EsmForProteinFolding.from_pretrained(
                self.config.model_name,
                low_cpu_mem_usage=True
            )

            # Move to CUDA exactly as in your working example
            self.logger.info("Moving to CUDA...")
            self.model = self.model.cuda()

            # Set device attribute
            self.device = torch.device("cuda")

            # Set half precision exactly as in your working example
            self.logger.info("Setting ESM to half precision...")
            self.model.esm = self.model.esm.half()

            # Enable TF32 exactly as in your working example
            self.logger.info("Enabling TF32...")
            torch.backends.cuda.matmul.allow_tf32 = True

            self._model_loaded = True
            self.logger.info("✅ Model loaded successfully")

        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")

    async def _predict_local(self, sequence: str) -> PredictionResult:
        """Direct prediction following your working example exactly"""

        try:
            start_time = asyncio.get_event_loop().time()

            # Use EXACTLY the same approach as your working example
            self.logger.info("Tokenizing...")
            tokenized_input = self.tokenizer(
                [sequence],
                return_tensors="pt",
                add_special_tokens=False
            )['input_ids']

            self.logger.info(f"Tokenized input shape: {tokenized_input.shape}")

            # Move to CUDA exactly as in your working example
            self.logger.info("Moving tokenized input to CUDA...")
            tokenized_input = tokenized_input.cuda()

            self.logger.info("Running inference...")
            with torch.no_grad():
                outputs = self.model(tokenized_input)

            prediction_time = asyncio.get_event_loop().time() - start_time
            self.logger.info("✅ Inference completed successfully")

            # Extract confidence scores from model outputs - FIXED VERSION
            self.logger.info("Extracting confidence scores...")
            confidence_scores = {}
            plddt_scores = None

            try:
                # Check for pLDDT directly on outputs (as shown in your working example)
                if hasattr(outputs, 'plddt'):
                    # The plddt tensor has shape [1, 50, 37] - we want the mean across all atoms for each residue
                    plddt_tensor = outputs.plddt  # Shape: [1, 50, 37]

                    # Calculate mean pLDDT per residue (average across all atoms for each residue)
                    plddt_per_residue = plddt_tensor.mean(dim=2).squeeze(0)  # Shape: [50]
                    plddt_scores = plddt_per_residue.tolist()  # Convert to list for plddt_scores field

                    # Calculate overall statistics
                    confidence_scores['mean_plddt'] = float(plddt_per_residue.mean().item())
                    confidence_scores['min_plddt'] = float(plddt_per_residue.min().item())
                    confidence_scores['max_plddt'] = float(plddt_per_residue.max().item())

                    self.logger.info(f"✅ Extracted pLDDT scores - Mean: {confidence_scores['mean_plddt']:.2f}")
                    self.logger.info(f"✅ pLDDT range: {confidence_scores['min_plddt']:.2f} - {confidence_scores['max_plddt']:.2f}")
                    self.logger.info(f"✅ pLDDT per residue: {len(plddt_scores)} values")
                else:
                    self.logger.info("❌ No plddt attribute in outputs")

                # Check for other confidence metrics
                if hasattr(outputs, 'ptm'):
                    ptm_score = float(outputs.ptm.item())  # PTM is already a scalar
                    confidence_scores['ptm_score'] = ptm_score
                    self.logger.info(f"✅ Extracted PTM score: {ptm_score:.3f}")

                # Check for other attributes
                if hasattr(outputs, 'positions'):
                    confidence_scores['has_positions'] = 1.0
                    self.logger.info("✅ 3D positions available")

                # Log what we found
                if confidence_scores:
                    self.logger.info(f"✅ Confidence scores extracted: {list(confidence_scores.keys())}")
                else:
                    self.logger.info("❌ No confidence scores found in model outputs")

            except Exception as e:
                self.logger.warning(f"Could not extract confidence scores: {e}")
                import traceback
                traceback.print_exc()  # This will show the full error
                confidence_scores = {
                    'error_note': 0.0,
                    'method': 1.0
                }

            # If no confidence scores found, provide fallback
            if not confidence_scores:
                confidence_scores = {
                    'fallback_note': 0.0,
                    'method': 1.0
                }

            # Convert to PDB and return result
            pdb_content = self._outputs_to_pdb(outputs, sequence)

            # Log what we're about to return
            self.logger.info(f"🔍 DEBUG: About to return PredictionResult with confidence_scores: {confidence_scores}")
            self.logger.info(f"🔍 DEBUG: About to return PredictionResult with plddt_scores: {plddt_scores}")

            return PredictionResult(
                success=True,
                pdb_content=pdb_content,
                prediction_time=prediction_time,
                method_used="local_transformers",
                confidence_scores=confidence_scores,
                plddt_scores=plddt_scores
            )

        except Exception as e:
            self.logger.error(f"Prediction failed: {e}")
            return PredictionResult(
                success=False,
                error_message=str(e),
                method_used="local_transformers"
            )

    def _outputs_to_pdb(self, outputs, sequence: str) -> str:
      """Convert model outputs to PDB format - fixed for actual output shape"""
      try:
          # Extract coordinates - shape is (recycles, batch, length, atoms, 3)
          if not hasattr(outputs, 'positions') or outputs.positions is None:
              raise ValueError("No positions in model output")

          positions = outputs.positions.cpu().numpy()
          self.logger.info(f"Raw positions shape: {positions.shape}")

          # Handle the actual output shape: (8, 1, length, 14, 3)
          if len(positions.shape) == 5:
              # Use the last recycle (most refined) and remove batch dimension
              positions = positions[-1, 0]  # Shape becomes (length, 14, 3)
              self.logger.info(f"Processed positions shape: {positions.shape}")
          elif len(positions.shape) == 4:
              # Remove batch dimension
              positions = positions[0]  # Shape becomes (length, atoms, 3)
          elif len(positions.shape) != 3:
              raise ValueError(f"Unexpected positions shape: {positions.shape}")

          if positions.shape[0] != len(sequence):
              self.logger.warning(f"Position length {positions.shape[0]} != sequence length {len(sequence)}")
              # Truncate to shorter length
              min_len = min(positions.shape[0], len(sequence))
              positions = positions[:min_len]
              sequence = sequence[:min_len]

          # Create PDB content
          pdb_lines = []
          pdb_lines.append("HEADER    ESMFold PREDICTION")
          pdb_lines.append("REMARK    Generated by ESMFold via HuggingFace Transformers")
          pdb_lines.append(f"REMARK    Sequence length: {len(sequence)}")
          pdb_lines.append(f"REMARK    Positions shape: {positions.shape}")

          atom_id = 1
          for res_idx in range(len(sequence)):
              if res_idx >= positions.shape[0]:
                  break

              amino_acid = sequence[res_idx]
              res_num = res_idx + 1

              # Standard atom order: N, CA, C, O (indices 0, 1, 2, 3)
              atom_names = ['N', 'CA', 'C', 'O']

              for atom_idx, atom_name in enumerate(atom_names):
                  if atom_idx < positions.shape[1]:  # Check if atom exists
                      coords = positions[res_idx, atom_idx]

                      # Check for valid coordinates
                      if not all(abs(coord) < 1000 for coord in coords):
                          continue  # Skip invalid coordinates

                      x, y, z = coords

                      pdb_line = (
                          f"ATOM  {atom_id:5d}  {atom_name:<3s} {amino_acid} A{res_num:4d}    "
                          f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00 85.00           {atom_name[0]}"
                      )
                      pdb_lines.append(pdb_line)
                      atom_id += 1

          pdb_lines.append("END")
          pdb_content = '\n'.join(pdb_lines)

          # Validate PDB content
          if atom_id == 1:  # No atoms were added
              raise ValueError("No valid atoms found in structure")

          self.logger.info(f"Generated PDB with {atom_id-1} atoms")
          return pdb_content

      except Exception as e:
          self.logger.error(f"PDB conversion failed: {str(e)}")
          # Return minimal PDB with error info
          return (
              "HEADER    ESMFold PREDICTION FAILED\n"
              f"REMARK    Error: {str(e)}\n"
              "REMARK    This is a placeholder structure\n"
              "END"
          )

    def _is_valid_amino_acid_sequence(self, sequence: str) -> bool:
        """Validate amino acid sequence contains only standard residues"""
        valid_aa = set('ACDEFGHIKLMNPQRSTVWY')
        return all(aa in valid_aa for aa in sequence)

    async def predict_structure(self, sequence: str) -> PredictionResult:
        """Main prediction method with fallback"""
        # Validate input
        if not sequence or not isinstance(sequence, str):
            return PredictionResult(
                success=False,
                error_message="Invalid sequence: must be non-empty string"
            )

        # Clean and validate sequence
        clean_sequence = sequence.strip().upper()
        if not self._is_valid_amino_acid_sequence(clean_sequence):
            return PredictionResult(
                success=False,
                error_message="Invalid amino acid sequence: contains non-standard characters"
            )

        if len(clean_sequence) > self.config.max_length:
            return PredictionResult(
                success=False,
                error_message=f"Sequence too long ({len(clean_sequence)} > {self.config.max_length} residues)"
            )

        self.logger.info(f"Predicting structure for {len(clean_sequence)} residue sequence")

        # Try local inference first
        try:
            await self._ensure_model_loaded()
            if self._model_loaded:
                self.logger.info("Using local model for prediction")
                result = await self._predict_local(clean_sequence)
                if result and result.success:
                    return result
                else:
                    self.logger.warning(f"Local prediction failed: {result.error_message if result else 'returned None'}")
            else:
                self.logger.info("Local model not available")
        except Exception as e:
            self.logger.error(f"Local prediction exception: {str(e)}")

        # Fall back to API if local fails
        if AIOHTTP_AVAILABLE:
            self.logger.info("Attempting API fallback...")
            try:
                result = await self._predict_api(clean_sequence)
                if result:
                    return result
            except Exception as e:
                self.logger.error(f"API fallback exception: {str(e)}")

        # If we get here, everything failed
        return PredictionResult(
            success=False,
            error_message="All prediction methods failed: local inference and API fallback both unavailable",
            method_used="none"
        )

    async def _predict_api(self, sequence: str) -> PredictionResult:
        """Fallback to ESM Atlas API"""
        if not AIOHTTP_AVAILABLE:
            return PredictionResult(
                success=False,
                error_message="aiohttp not available for API calls"
            )

        try:
            timeout = aiohttp.ClientTimeout(total=self.config.api_timeout)

            # SSL context that ignores certificate errors (for demo purposes)
            import ssl
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

            connector = aiohttp.TCPConnector(ssl=ssl_context)

            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                start_time = asyncio.get_event_loop().time()

                async with session.post(
                    self.config.fallback_api_url,
                    data=sequence,
                    headers={'Content-Type': 'text/plain'}
                ) as response:

                    prediction_time = asyncio.get_event_loop().time() - start_time

                    if response.status == 200:
                        pdb_content = await response.text()

                        if 'ATOM' in pdb_content or 'HEADER' in pdb_content:
                            return PredictionResult(
                                success=True,
                                pdb_content=pdb_content,
                                prediction_time=prediction_time,
                                method_used="api_fallback"
                            )
                        else:
                            return PredictionResult(
                                success=False,
                                error_message="API returned invalid PDB content",
                                method_used="api_fallback"
                            )
                    else:
                        error_text = await response.text()
                        return PredictionResult(
                            success=False,
                            error_message=f"API error {response.status}: {error_text}",
                            method_used="api_fallback"
                        )

        except Exception as e:
            return PredictionResult(
                success=False,
                error_message=f"API call failed: {str(e)}",
                method_used="api_fallback"
            )

    async def cleanup(self):
      """Clean up GPU memory"""
      if self.model is not None:
          try:
              self.model.cpu()
              del self.model
              self.model = None
              if torch.cuda.is_available():
                  torch.cuda.empty_cache()
              self._model_loaded = False
              self.logger.info("✅ GPU memory cleaned up")
          except Exception as e:
              self.logger.warning(f"Cleanup warning: {e}")

    async def __aexit__(self, exc_type, exc_val, exc_tb):
      """Async context manager exit with cleanup"""
      await self.cleanup()

In [3]:
async def test_with_memory_management():
    """Test with proper memory management"""

    test_sequence = "PQITLWQRPLVTIKIGGQLKEALLDTGADD"

    print("🧪 Testing with Memory Management")
    print(f"Sequence: {test_sequence}")
    print("="*50)

    try:
        # Test 1: Simplified client
        print("\n1. Testing Simplified Client...")
        client1 = ESMFoldClient()
        await client1._ensure_model_loaded()

        if client1._model_loaded:
            result1 = await client1._predict_local(test_sequence)
            if result1 and result1.success:
                print("✅ Simplified client SUCCESS!")
                print(f"   PDB length: {len(result1.pdb_content)}")
                print(f"   Contains ATOM: {'ATOM' in result1.pdb_content}")

                # Show first few lines
                pdb_lines = result1.pdb_content.split('\n')[:8]
                for line in pdb_lines:
                    print(f"   {line}")
            else:
                print("❌ Simplified client FAILED")
                if result1:
                    print(f"   Error: {result1.error_message}")

        # Clean up first client
        await client1.cleanup()

        # Test 2: Working example
        print("\n2. Testing Working Example...")
        try:
            from transformers import AutoTokenizer, EsmForProteinFolding
            import torch

            tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
            model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
            model = model.cuda()
            model.esm = model.esm.half()
            torch.backends.cuda.matmul.allow_tf32 = True

            tokenized_input = tokenizer([test_sequence], return_tensors="pt", add_special_tokens=False)['input_ids']
            tokenized_input = tokenized_input.cuda()

            with torch.no_grad():
                outputs = model(tokenized_input)

            print("✅ Working example SUCCESS!")
            print(f"   Output type: {type(outputs)}")
            if hasattr(outputs, 'positions'):
                print(f"   Positions shape: {outputs.positions.shape}")
            if hasattr(outputs, 'plddt'):
                print(f"   pLDDT mean: {outputs.plddt.mean().item():.2f}")

            # Clean up
            model.cpu()
            del model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"❌ Working example FAILED: {e}")

        # Test 3: Simplified client again (should work after cleanup)
        print("\n3. Testing Simplified Client Again...")
        client2 = ESMFoldClient()
        await client2._ensure_model_loaded()

        if client2._model_loaded:
            result2 = await client2._predict_local(test_sequence)
            if result2 and result2.success:
                print("✅ Simplified client SUCCESS again!")
                print(f"   PDB length: {len(result2.pdb_content)}")
                print(f"   Contains ATOM: {'ATOM' in result2.pdb_content}")
            else:
                print("❌ Simplified client FAILED on second try")
                if result2:
                    print(f"   Error: {result2.error_message}")

        await client2.cleanup()

        return True

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

In [4]:
async def test_simplified_esmfold_client():
    """Test the simplified ESMFold client with the same approach as your working example"""

    # Use the same test sequence as before
    test_sequence = "PQITLWQRPLVTIKIGGQLKEALLDTGADD"

    print(f"Testing Simplified ESMFold Client")
    print(f"Sequence: {test_sequence}")
    print("="*60)

    try:
        # Create client instance
        client = ESMFoldClient()

        # Load model (this should now work like your working example)
        print("🔄 Loading model...")
        await client._ensure_model_loaded()

        if not client._model_loaded:
            print("❌ Model failed to load")
            return None

        print("✅ Model loaded successfully")
        print(f"   Device: {client.device}")
        print(f"   Model device: {client.model.device if client.model else 'None'}")

        # Test local prediction
        print("\n🔄 Testing local prediction...")
        result = await client._predict_local(test_sequence)

        if result and result.success:
            print("✅ Local prediction successful!")
            print(f"   Method: {result.method_used}")
            print(f"   Time: {result.prediction_time:.2f}s")
            print(f"   PDB length: {len(result.pdb_content)} characters")

            # Show first few lines of PDB
            print("\n📄 First few PDB lines:")
            pdb_lines = result.pdb_content.split('\n')[:8]
            for line in pdb_lines:
                print(f"   {line}")

            # Show confidence scores if available
            if result.confidence_scores:
                print(f"\n�� Confidence scores: {result.confidence_scores}")

            return result

        else:
            print("❌ Local prediction failed")
            if result:
                print(f"   Error: {result.error_message}")
            else:
                print("   Result is None")
            return None

    except Exception as e:
        print(f"❌ Test failed with exception: {str(e)}")
        import traceback
        print("Full traceback:")
        print(traceback.format_exc())
        return None


# Alternative test that compares with your working example
async def test_comparison():
    """Compare the simplified client with your working example"""

    test_sequence = "PQITLWQRPLVTIKIGGQLKEALLDTGADD"

    print("�� Testing Simplified Client vs Working Example")
    print("="*60)

    # Test 1: Your working example
    print("\n🧪 Test 1: Your Working Example")
    try:
        from transformers import AutoTokenizer, EsmForProteinFolding
        import torch

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

        print("Loading model...")
        model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

        print("Moving to CUDA...")
        model = model.cuda()

        print("Setting ESM to half precision...")
        model.esm = model.esm.half()

        print("Enabling TF32...")
        torch.backends.cuda.matmul.allow_tf32 = True

        print("Tokenizing...")
        tokenized_input = tokenizer([test_sequence], return_tensors="pt", add_special_tokens=False)['input_ids']

        print("Moving to CUDA...")
        tokenized_input = tokenized_input.cuda()

        print("Running inference...")
        with torch.no_grad():
            output = model(output)

        print("✅ Working example SUCCESS!")
        print(f"Output type: {type(output)}")
        if hasattr(output, 'positions'):
            print(f"Positions shape: {output.positions.shape}")
        if hasattr(output, 'plddt'):
            print(f"pLDDT mean: {output.plddt.mean().item():.2f}")

    except Exception as e:
        print(f"❌ Working example FAILED: {e}")

    # Test 2: Simplified client
    print("\n�� Test 2: Simplified ESMFold Client")
    try:
        client = ESMFoldClient()
        await client._ensure_model_loaded()

        if client._model_loaded:
            result = await client._predict_local(test_sequence)
            if result and result.success:
                print("✅ Simplified client SUCCESS!")
                print(f"PDB length: {len(result.pdb_content)}")
            else:
                print("❌ Simplified client FAILED")
                if result:
                    print(f"Error: {result.error_message}")
        else:
            print("❌ Simplified client model failed to load")

    except Exception as e:
        print(f"❌ Simplified client FAILED: {e}")


# Quick validation test
async def test_quick_validation():
    """Quick test to validate the client works"""

    test_sequence = "MKLLVL"  # Very short sequence for quick testing

    print("🧪 Quick Validation Test")
    print(f"Sequence: {test_sequence}")
    print("="*40)

    try:
        client = ESMFoldClient()

        # Test model loading
        print("1. Testing model loading...")
        await client._ensure_model_loaded()
        print(f"   Model loaded: {client._model_loaded}")

        if client._model_loaded:
            # Test prediction
            print("2. Testing prediction...")
            result = await client._predict_local(test_sequence)
            print(f"   Prediction success: {result.success if result else False}")

            if result and result.success:
                print("3. Testing PDB output...")
                print(f"   PDB length: {len(result.pdb_content)}")
                print(f"   Contains ATOM: {'ATOM' in result.pdb_content}")
                print(f"   Contains END: {'END' in result.pdb_content}")

                return True
            else:
                print(f"   Error: {result.error_message if result else 'Result is None'}")
                return False
        else:
            print("   Model loading failed")
            return False

    except Exception as e:
        print(f"   Exception: {e}")
        return False

In [5]:
# Main test runner
async def run_all_tests():
    """Run all tests to validate the simplified client"""

    print("🚀 Running All Tests for Simplified ESMFold Client")
    print("="*70)

    tests = [
        ("Quick Validation", test_quick_validation),
        ("Simplified Client", test_simplified_esmfold_client),
        ("Memory Management Test", test_with_memory_management),  # New test
        # ("Comparison Test", test_comparison),  # Comment out the problematic one
    ]

    results = {}

    for test_name, test_func in tests:
        print(f"\n{'='*20} {test_name} {'='*20}")
        try:
            result = await test_func()
            results[test_name] = "PASS" if result else "FAIL"
        except Exception as e:
            print(f"❌ Test {test_name} crashed: {e}")
            results[test_name] = "CRASH"

    # Summary
    print(f"\n{'='*20} TEST SUMMARY {'='*20}")
    for test_name, result in results.items():
        status_emoji = "✅" if result == "PASS" else "❌" if result == "FAIL" else "��"
        print(f"{status_emoji} {test_name}: {result}")

    return results

# Run the test
if __name__ == "__main__":
    #import asyncio
    #asyncio.run(test_your_exact_example())
    await run_all_tests()

🚀 Running All Tests for Simplified ESMFold Client

🧪 Quick Validation Test
Sequence: MKLLVL
1. Testing model loading...


tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/8.44G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/8.44G [00:00<?, ?B/s]

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


   Model loaded: True
2. Testing prediction...
   Prediction success: True
3. Testing PDB output...
   PDB length: 2007
   Contains ATOM: True
   Contains END: True

Testing Simplified ESMFold Client
Sequence: PQITLWQRPLVTIKIGGQLKEALLDTGADD
🔄 Loading model...


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Model loaded successfully
   Device: cuda
   Model device: cuda:0

🔄 Testing local prediction...
✅ Local prediction successful!
   Method: local_transformers
   Time: 1.32s
   PDB length: 9401 characters

📄 First few PDB lines:
   HEADER    ESMFold PREDICTION
   REMARK    Generated by ESMFold via HuggingFace Transformers
   REMARK    Sequence length: 30
   REMARK    Positions shape: (30, 14, 3)
   ATOM      1  N   P A   1      -1.480   6.329 -12.393  1.00 85.00           N
   ATOM      2  CA  P A   1      -0.202   5.704 -12.743  1.00 85.00           C
   ATOM      3  C   P A   1      -0.357   4.568 -13.751  1.00 85.00           C
   ATOM      4  O   P A   1      -1.351   3.838 -13.715  1.00 85.00           O

�� Confidence scores: {'mean_plddt': 0.6586815118789673, 'min_plddt': 0.46007055044174194, 'max_plddt': 0.7873799204826355, 'ptm_score': 0.3420056998729706, 'has_positions': 1.0}

🧪 Testing with Memory Management
Sequence: PQITLWQRPLVTIKIGGQLKEALLDTGADD

1. Testing Simplified Cl

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Simplified client SUCCESS!
   PDB length: 9401
   Contains ATOM: True
   HEADER    ESMFold PREDICTION
   REMARK    Generated by ESMFold via HuggingFace Transformers
   REMARK    Sequence length: 30
   REMARK    Positions shape: (30, 14, 3)
   ATOM      1  N   P A   1      -1.480   6.329 -12.393  1.00 85.00           N
   ATOM      2  CA  P A   1      -0.202   5.704 -12.743  1.00 85.00           C
   ATOM      3  C   P A   1      -0.357   4.568 -13.751  1.00 85.00           C
   ATOM      4  O   P A   1      -1.351   3.838 -13.715  1.00 85.00           O

2. Testing Working Example...


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Working example SUCCESS!
   Output type: <class 'transformers.models.esm.modeling_esmfold.EsmForProteinFoldingOutput'>
   Positions shape: torch.Size([8, 1, 30, 14, 3])
   pLDDT mean: 0.66

3. Testing Simplified Client Again...


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Simplified client SUCCESS again!
   PDB length: 9401
   Contains ATOM: True

✅ Quick Validation: PASS
✅ Simplified Client: PASS
✅ Memory Management Test: PASS


In [6]:
!pip install biopython

Collecting biopython
  Downloading biopython-1.85-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━[0m [32m2.2/3.3 MB[0m [31m65.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [26]:
"""
Protein Structure Prediction Agent - Core Implementation
IONLACE Technical Interview Assignment
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any
from datetime import datetime
import hashlib

# Import the working ESMFold client
#from esmfold_client import ESMFoldClient, ESMFoldConfig, PredictionResult

class StepStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    RETRYING = "retrying"


class DecisionType(Enum):
    CONTINUE = "continue"
    RETRY = "retry"
    FALLBACK = "fallback"
    ABORT = "abort"
    COMPLETE = "complete"


@dataclass
class ExecutionStep:
    """Represents a single step in the agent's execution plan"""
    name: str
    description: str
    function_name: str
    dependencies: List[str] = field(default_factory=list)
    max_retries: int = 3
    timeout_seconds: int = 30


@dataclass
class StepResult:
    """Result of executing a single step"""
    step_name: str
    status: StepStatus
    data: Optional[Any] = None
    error: Optional[str] = None
    execution_time: float = 0.0
    attempt_number: int = 1


@dataclass
class Decision:
    """Agent's decision after observing a step result"""
    action: DecisionType
    reason: str
    next_step: Optional[str] = None
    retry_delay: float = 1.0


@dataclass
class AgentState:
    """Maintains agent's execution state"""
    sequence: str = ""
    sequence_hash: str = ""
    current_step_index: int = 0
    execution_plan: List[ExecutionStep] = field(default_factory=list)
    results: Dict[str, StepResult] = field(default_factory=dict)
    execution_log: List[str] = field(default_factory=list)
    start_time: datetime = field(default_factory=datetime.now)

    def is_complete(self) -> bool:
        """Check if all steps are completed successfully"""
        return self.current_step_index >= len(self.execution_plan)

    def get_current_step(self) -> Optional[ExecutionStep]:
        """Get the current step to execute"""
        if self.current_step_index < len(self.execution_plan):
            return self.execution_plan[self.current_step_index]
        return None

    def mark_step_complete(self):
        """Move to next step"""
        self.current_step_index += 1


class ProteinStructureAgent:
    """
    Autonomous agent for protein structure prediction and analysis.

    Implements think/plan/act/observe cycle to:
    1. Validate amino acid sequences
    2. Call structure prediction APIs (ESMFold)
    3. Parse and analyze results
    4. Generate comprehensive reports
    """

    def __init__(self, log_level: str = "INFO"):
        self.state = AgentState()
        self.logger = self._setup_logging(log_level)
        self.esmfold_client = None  # Will be initialized when needed

        # Updated step functions to use real implementations
        self.step_functions = {
            "validate_sequence": self._validate_sequence,
            "predict_structure": self._predict_structure,
            "parse_structure": self._parse_structure,
            "calculate_metrics": self._calculate_metrics,
            "generate_report": self._generate_report
        }

    def _setup_logging(self, level: str) -> logging.Logger:
        """Setup demo-friendly logging with emojis"""
        logger = logging.getLogger("ProteinAgent")

        # Clear any existing handlers and configuration
        logger.handlers.clear()
        logger.propagate = False  # Prevent propagation to root logger

        logger.setLevel(getattr(logging, level))

        # Add only our custom handler
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(asctime)s | %(message)s',
            datefmt='%H:%M:%S'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

        return logger

    async def execute(self, sequence: str) -> Dict[str, Any]:
        """
        Main orchestrator - executes the complete think/plan/act/observe cycle

        Args:
            sequence: Amino acid sequence to analyze

        Returns:
            Final report with structure prediction and analysis
        """
        self.logger.info("🚀 Starting Protein Structure Prediction Agent")

        # Initialize state
        self.state.sequence = sequence.strip().upper()
        self.state.sequence_hash = hashlib.md5(sequence.encode()).hexdigest()[:8]

        # THINK: Analyze the task
        thought = self.think()
        self.logger.info(f"🤔 THINK: {thought}")

        # PLAN: Create execution strategy
        execution_plan = self.plan()
        self.state.execution_plan = execution_plan
        self.logger.info(f"📋 PLAN: {len(execution_plan)} steps planned")
        for i, step in enumerate(execution_plan, 1):
            self.logger.info(f"   Step {i}: {step.description}")

        # Main execution loop
        while not self.state.is_complete():
            current_step = self.state.get_current_step()
            if not current_step:
                break

            self.logger.info(f"�� ACT: {current_step.description}...")

            # Execute step with timing
            start_time = time.time()
            result = await self.act(current_step)
            result.execution_time = time.time() - start_time

            # Store result
            self.state.results[current_step.name] = result

            # OBSERVE: Evaluate result and decide next action
            decision = self.observe(result)
            self.logger.info(f"👁️  OBSERVE: {decision.reason}")

            # Handle decision
            if decision.action == DecisionType.CONTINUE:
                self.state.mark_step_complete()
            elif decision.action == DecisionType.RETRY:
                self.logger.info(f"🔄 Retrying in {decision.retry_delay}s...")
                await asyncio.sleep(decision.retry_delay)
            elif decision.action == DecisionType.FALLBACK:
                self.logger.warning("⚠️  Attempting fallback strategy")
                self.state.mark_step_complete()  # For now, continue
            elif decision.action == DecisionType.ABORT:
                self.logger.error("❌ Execution aborted")
                break
            elif decision.action == DecisionType.COMPLETE:
                self.logger.info("✅ All steps completed successfully")
                break

        # Generate final report
        total_time = (datetime.now() - self.state.start_time).total_seconds()
        self.logger.info(f"🏁 Agent execution completed in {total_time:.2f}s")

        return self._create_final_report()

    def think(self) -> str:
        """
        Analyze the current situation and task requirements

        Returns:
            Agent's reasoning about the task
        """
        seq_len = len(self.state.sequence)

        # Basic sequence validation thoughts
        if seq_len == 0:
            return "Empty sequence provided - cannot proceed"
        elif seq_len < 20:
            return f"Short sequence ({seq_len} residues) - prediction may be unreliable"
        elif seq_len > 1000:
            return f"Long sequence ({seq_len} residues) - may require extended processing time"
        else:
            return f"Analyzing sequence of {seq_len} residues for structure prediction"

    def plan(self) -> List[ExecutionStep]:
        """
        Create the execution strategy based on current context

        Returns:
            List of steps to execute
        """
        return [
            ExecutionStep(
                name="validate_sequence",
                description="Validate amino acid sequence format",
                function_name="validate_sequence",
                max_retries=1
            ),
            ExecutionStep(
                name="predict_structure",
                description="Call ESMFold API for structure prediction",
                function_name="predict_structure",
                dependencies=["validate_sequence"],
                max_retries=3,
                timeout_seconds=180  # Increased for real API calls
            ),
            ExecutionStep(
                name="parse_structure",
                description="Parse PDB structure and extract coordinates",
                function_name="parse_structure",
                dependencies=["predict_structure"],
                max_retries=2
            ),
            ExecutionStep(
                name="calculate_metrics",
                description="Calculate pLDDT, secondary structure, and other metrics",
                function_name="calculate_metrics",
                dependencies=["parse_structure"],
                max_retries=2
            ),
            ExecutionStep(
                name="generate_report",
                description="Generate human-readable and JSON reports",
                function_name="generate_report",
                dependencies=["calculate_metrics"],
                max_retries=1
            )
        ]

    async def act(self, step: ExecutionStep) -> StepResult:
        """
        Execute a single step with error handling and retries

        Args:
            step: The step to execute

        Returns:
            Result of the step execution
        """
        step_function = self.step_functions.get(step.function_name)
        if not step_function:
            return StepResult(
                step_name=step.name,
                status=StepStatus.FAILED,
                error=f"Unknown function: {step.function_name}"
            )

        # Execute with timeout and retry logic
        for attempt in range(1, step.max_retries + 1):
            try:
                # Check dependencies
                for dep in step.dependencies:
                    if dep not in self.state.results or \
                       self.state.results[dep].status != StepStatus.SUCCESS:
                        return StepResult(
                            step_name=step.name,
                            status=StepStatus.FAILED,
                            error=f"Dependency {dep} not satisfied",
                            attempt_number=attempt
                        )

                # Execute the step function
                result_data = await asyncio.wait_for(
                    step_function(),
                    timeout=step.timeout_seconds
                )

                return StepResult(
                    step_name=step.name,
                    status=StepStatus.SUCCESS,
                    data=result_data,
                    attempt_number=attempt
                )

            except asyncio.TimeoutError:
                error = f"Step timed out after {step.timeout_seconds}s"
                if attempt == step.max_retries:
                    return StepResult(
                        step_name=step.name,
                        status=StepStatus.FAILED,
                        error=error,
                        attempt_number=attempt
                    )
                self.logger.warning(f"⏰ {error}, retrying... ({attempt}/{step.max_retries})")

            except Exception as e:
                error = f"Step failed: {str(e)}"
                if attempt == step.max_retries:
                    return StepResult(
                        step_name=step.name,
                        status=StepStatus.FAILED,
                        error=error,
                        attempt_number=attempt
                    )
                self.logger.warning(f"❌ {error}, retrying... ({attempt}/{step.max_retries})")

        # Should not reach here, but safety fallback
        return StepResult(
            step_name=step.name,
            status=StepStatus.FAILED,
            error="Max retries exceeded"
        )

    def observe(self, result: StepResult) -> Decision:
        """
        Evaluate step result and decide on next action

        Args:
            result: Result from the executed step

        Returns:
            Decision on how to proceed
        """
        if result.status == StepStatus.SUCCESS:
            if result.step_name == "generate_report":
                return Decision(
                    action=DecisionType.COMPLETE,
                    reason=f"✅ {result.step_name} completed successfully"
                )
            else:
                return Decision(
                    action=DecisionType.CONTINUE,
                    reason=f"✅ {result.step_name} completed, proceeding to next step"
                )

        elif result.status == StepStatus.FAILED:
            # Decide whether to retry, fallback, or abort
            current_step = self.state.get_current_step()

            if current_step and result.attempt_number < current_step.max_retries:
                return Decision(
                    action=DecisionType.RETRY,
                    reason=f"⚠️ {result.step_name} failed, retrying ({result.attempt_number}/{current_step.max_retries})",
                    retry_delay=min(2.0 ** result.attempt_number, 10.0)  # Exponential backoff
                )

            elif result.step_name == "predict_structure":
                return Decision(
                    action=DecisionType.FALLBACK,
                    reason="🔄 Structure prediction failed, attempting fallback predictor"
                )

            else:
                return Decision(
                    action=DecisionType.ABORT,
                    reason=f"❌ Critical step {result.step_name} failed: {result.error}"
                )

        # Default fallback
        return Decision(
            action=DecisionType.ABORT,
            reason=f"Unexpected result status: {result.status}"
        )


    async def _validate_sequence(self):
        """Validate amino acid sequence format"""
        sequence = self.state.sequence

        # Check if sequence contains only valid amino acids
        valid_aa = set('ACDEFGHIKLMNPQRSTVWY')
        invalid_chars = [aa for aa in sequence if aa not in valid_aa]

        if invalid_chars:
            return {
                "valid": False,
                "error": f"Invalid amino acids: {invalid_chars}",
                "length": len(sequence)
            }

        if len(sequence) < 20:
            return {
                "valid": True,
                "warning": "Short sequence - prediction may be unreliable",
                "length": len(sequence)
            }

        return {
            "valid": True,
            "length": len(sequence),
            "amino_acids": list(set(sequence))
        }


    async def _predict_structure(self):
        """Call ESMFold for structure prediction"""
        try:
            # Initialize ESMFold client if not already done
            if self.esmfold_client is None:
                self.logger.info("Initializing ESMFold client...")
                self.esmfold_client = ESMFoldClient()
                await self.esmfold_client._ensure_model_loaded()

            # Predict structure
            self.logger.info("Calling ESMFold for structure prediction...")
            result = await self.esmfold_client._predict_local(self.state.sequence)

            if result and result.success:
                # Use the confidence scores directly from the ESMFold result
                # The _predict_local method already extracts pLDDT correctly
                return {
                    "pdb_content": result.pdb_content,
                    "prediction_time": result.prediction_time,
                    "method_used": result.method_used,
                    "confidence_scores": result.confidence_scores,  # Use the good scores from _predict_local
                    "plddt_scores": result.plddt_scores  # Include the pLDDT scores
                }
            else:
                raise Exception(f"ESMFold prediction failed: {result.error_message if result else 'Unknown error'}")

        except Exception as e:
            self.logger.error(f"Structure prediction failed: {e}")
            raise

    async def _parse_structure_manual(self):
      """Manual PDB parsing as fallback when BioPython fails"""
      # Get the PDB content from previous step
      predict_result = self.state.results["predict_structure"]
      if not predict_result.data or "pdb_content" not in predict_result.data:
          raise Exception("No PDB content available for parsing")

      pdb_content = predict_result.data["pdb_content"]

      # Parse PDB content using proper field handling
      atoms = []
      residues = set()

      for line in pdb_content.split('\n'):
          if line.startswith('ATOM'):
              # Add debug line here - right after the ATOM check
              #self.logger.info(f"Line: '{line}'")
              #self.logger.info(f"Pos 17-20: '{line[17:20]}' | Pos 21: '{line[21]}' | Pos 22-26: '{line[22:26]}'")

              try:

                  # Parse ATOM line with proper field handling
                  residue_chain = line[17:20].strip()  # "P A" or "Q A" etc.
                  residue_name = residue_chain[0]      # "P", "Q", etc.
                  chain_id = residue_chain[2] if len(residue_chain) > 2 else ""  # "A", "B", etc.

                  atom_info = {
                      "atom_id": int(line[6:11].strip()),
                      "atom_name": line[12:16].strip(),
                      "residue_name": residue_name,        # Just the residue (P, Q, I, etc.)
                      "chain_id": chain_id,                # Just the chain (A, B, etc.)
                      "residue_id": int(line[22:26].strip()),
                  }

                  # Handle coordinates more carefully - they can be variable width
                  # Extract the coordinate section and split it properly
                  coord_section = line[30:54].strip()

                  # Split coordinates on whitespace, but be careful with negative numbers
                  coords = coord_section.split()
                  if len(coords) >= 3:
                      atom_info["x"] = float(coords[0])
                      atom_info["y"] = float(coords[1])
                      atom_info["z"] = float(coords[2])
                  else:
                      # Fallback to fixed-width if splitting fails
                      atom_info["x"] = float(line[30:38].strip())
                      atom_info["y"] = float(line[38:46].strip())
                      atom_info["z"] = float(line[46:54].strip())

                  # Handle occupancy and B-factor - check line length first
                  line_length = len(line)

                  # Occupancy (position 54-60, but check if line is long enough)
                  if line_length > 54:
                      occupancy_str = line[54:60].strip()
                      # Split on whitespace to handle cases where fields run together
                      occupancy_parts = occupancy_str.split()
                      if occupancy_parts:
                          atom_info["occupancy"] = float(occupancy_parts[0])
                      else:
                          atom_info["occupancy"] = 1.0  # Default value
                  else:
                      atom_info["occupancy"] = 1.0  # Default value

                  # B-factor (position 60-66, but check if line is long enough)
                  if line_length > 60:
                      bfactor_str = line[60:66].strip()
                      # Split on whitespace to handle cases where fields run together
                      bfactor_parts = bfactor_str.split()
                      if bfactor_parts:
                          atom_info["b_factor"] = float(bfactor_parts[0])
                      else:
                          atom_info["b_factor"] = 0.0  # Default value
                  else:
                      atom_info["b_factor"] = 0.0  # Default value

                  atoms.append(atom_info)
                  residues.add(atom_info["residue_id"])

              except (ValueError, IndexError) as e:
                  # Skip malformed lines but log them
                  self.logger.warning(f"Skipping malformed ATOM line: {line[:60]}... Error: {e}")
                  continue

      if not atoms:
          raise Exception("No valid ATOM records found in PDB")

      return {
          "total_atoms": len(atoms),
          "total_residues": len(residues),
          "atoms": atoms,
          "residue_list": sorted(list(residues)),
          "pdb_content": pdb_content
      }

    async def _parse_with_manual(self, pdb_content: str):
      """Execute manual strategy (chosen by agent)"""
      # Call your existing manual parsing method
      return await self._parse_structure_manual()

    async def _parse_structure(self):
      """Agentic structure parsing with intelligent strategy selection"""

      # Get PDB content
      predict_result = self.state.results["predict_structure"]
      pdb_content = predict_result.data["pdb_content"]

      # 🧠 AGENTIC: Analyze and decide BEFORE trying anything
      parsing_analysis = self._analyze_parsing_strategy(pdb_content)
      self.logger.info(f"🤔 Parsing Analysis: {parsing_analysis['reasoning']}")

      # 🎯 AGENTIC: Execute chosen strategy (no fallbacks, no errors)
      if parsing_analysis["recommended_strategy"] == "biopython":
          self.logger.info(f"🧠 Agent Decision: Using BioPython (Reason: {', '.join(parsing_analysis['reasoning'])})")
          return await self._parse_with_biopython(pdb_content)
      else:
          self.logger.info(f"🧠 Agent Decision: Using manual parsing (Reason: {', '.join(parsing_analysis['reasoning'])})")
          return await self._parse_with_manual(pdb_content)

    async def _parse_with_biopython(self, pdb_content: str):#renaming from _parse_structure
      """Parse PDB structure using BioPython for robustness"""
      try:
          from Bio.PDB import PDBParser
          from io import StringIO

          # Parse using BioPython
          parser = PDBParser(QUIET=True)
          structure = parser.get_structure("predicted", StringIO(pdb_content))

          # Extract information
          atoms = []
          residues = set()

          for model in structure:
              for chain in model:
                  for residue in chain:
                      residues.add(residue.id[1])  # residue number
                      for atom in residue:
                          atoms.append({
                              "atom_id": atom.serial_number,
                              "atom_name": atom.name,
                              "residue_name": residue.resname,
                              "chain_id": chain.id,
                              "residue_id": residue.id[1],
                              "x": atom.coord[0],
                              "y": atom.coord[1],
                              "z": atom.coord[2],
                              "occupancy": atom.occupancy,
                              "b_factor": atom.bfactor
                          })

          return {
              "total_atoms": len(atoms),
              "total_residues": len(residues),
              "atoms": atoms,
              "residue_list": sorted(list(residues)),
              "pdb_content": pdb_content,
              "method_used": "biopython"
          }

      except ImportError:
          # Fallback to manual parsing if BioPython not available
          self.logger.warning("BioPython not available, using manual parsing via Agent")
          #return await self._parse_structure_manual()
      #except Exception as e:
      #    self.logger.error(f"BioPython parsing failed: {e}")
      #    # Fallback to manual parsing
      #    return await self._parse_structure_manual()


    async def _analyze_prediction_strategy(self, sequence: str) -> Dict[str, Any]:
        """Agent analyzes the sequence and chooses the best prediction strategy"""

        analysis = {
            "sequence_length": len(sequence),
            "complexity": "simple",
            "recommended_strategy": "local_transformers",
            "reasoning": []
        }

        # THINK: Analyze sequence characteristics
        if len(sequence) < 20:
            analysis["complexity"] = "simple"
            analysis["recommended_strategy"] = "local_transformers"
            analysis["reasoning"].append("Short sequence - local model can handle efficiently")
        elif len(sequence) > 200:
            analysis["complexity"] = "complex"
            analysis["recommended_strategy"] = "api_fallback"
            analysis["reasoning"].append("Long sequence - API may be more reliable")
        else:
            analysis["complexity"] = "medium"
            analysis["recommended_strategy"] = "local_transformers"
            analysis["reasoning"].append("Medium sequence - local model preferred for speed")

        # Check for special amino acids that might affect prediction
        special_aa = set('XZBUO')  # Non-standard amino acids
        if any(aa in sequence for aa in special_aa):
            analysis["recommended_strategy"] = "api_fallback"
            analysis["reasoning"].append("Contains non-standard amino acids - API may handle better")

        return analysis

    def _analyze_parsing_strategy(self, pdb_content: str) -> Dict[str, Any]:
        """Agent analyzes PDB format and chooses the best parsing strategy"""

        analysis = {
            "format_type": "unknown",
            "recommended_strategy": "biopython",
            "reasoning": []
        }

        # THINK: Analyze PDB content characteristics
        if "ESMFold" in pdb_content:
            analysis["format_type"] = "esmfold"
            analysis["recommended_strategy"] = "manual"
            analysis["reasoning"].append("ESMFold format detected - manual parsing more reliable")
        elif "ATOM" in pdb_content and "HEADER" in pdb_content:
            analysis["format_type"] = "standard_pdb"
            analysis["recommended_strategy"] = "biopython"
            analysis["reasoning"].append("Standard PDB format - BioPython should work")
        else:
            analysis["format_type"] = "unknown"
            analysis["recommended_strategy"] = "manual"
            analysis["reasoning"].append("Unknown format - manual parsing as fallback")

        return analysis

    async def _calculate_metrics(self):
        """Calculate structure analysis metrics"""
        # Get parsed structure data
        parse_result = self.state.results["parse_structure"]
        if not parse_result.data:
            raise Exception("No parsed structure data available")

        structure_data = parse_result.data

        # Add debug logging
        self.logger.info(f"Structure data keys: {list(structure_data.keys())}")
        self.logger.info(f"Total atoms: {structure_data.get('total_atoms')}")
        self.logger.info(f"Total residues: {structure_data.get('total_residues')}")
        self.logger.info(f"Atoms list type: {type(structure_data.get('atoms'))}")
        self.logger.info(f"Atoms list length: {len(structure_data.get('atoms', []))}")

        # Calculate basic metrics
        total_atoms = structure_data["total_atoms"]
        total_residues = structure_data["total_residues"]

        # Debug: Check first few atoms
        self.logger.info(f"First atom: {structure_data['atoms'][0] if structure_data['atoms'] else 'None'}")
        self.logger.info(f"First atom name: {structure_data['atoms'][0]['atom_name'] if structure_data['atoms'] else 'None'}")

        # Calculate secondary structure (simplified)
        # In a real implementation, you'd use DSSP or similar
        self.logger.info("About to filter CA atoms...")
        ca_atoms = [atom for atom in structure_data["atoms"] if atom["atom_name"].strip() == "CA"]
        self.logger.info(f"Found {len(ca_atoms)} CA atoms")

        # Simple helix detection (simplified)
        helix_count = 0
        sheet_count = 0

        if len(ca_atoms) >= 4:
            # Very simplified secondary structure detection
            # This is a placeholder - real implementation would use DSSP
            helix_count = max(0, len(ca_atoms) // 10)  # Rough estimate
            sheet_count = max(0, len(ca_atoms) // 15)  # Rough estimate

        # Enhanced pLDDT and confidence extraction
        self.logger.info("About to extract confidence scores...")

        # Extract confidence metrics from prediction
        predict_result = self.state.results["predict_structure"]
        plddt_score = None
        confidence_details = {}

        # Debug: Check what's available
        self.logger.info(f"Predict result data keys: {list(predict_result.data.keys()) if predict_result.data else 'None'}")

        if predict_result.data and "confidence_scores" in predict_result.data:
            confidence = predict_result.data["confidence_scores"]
            self.logger.info(f"Confidence scores: {confidence}")

            if confidence:
                # Extract pLDDT scores (primary confidence metric)
                if "mean_plddt" in confidence:
                    plddt_score = confidence["mean_plddt"]
                    confidence_details["plddt_mean"] = plddt_score
                    confidence_details["plddt_min"] = confidence.get("min_plddt")
                    confidence_details["plddt_max"] = confidence.get("max_plddt")
                    self.logger.info(f"Extracted pLDDT: Mean={plddt_score:.2f}, Min={confidence_details['plddt_min']:.2f}, Max={confidence_details['plddt_max']:.2f}")

                # Extract PTM score (alternative confidence metric)
                elif "ptm_score" in confidence:
                    ptm_score = confidence["ptm_score"]
                    estimated_plddt = confidence.get("estimated_plddt")
                    confidence_details["ptm_score"] = ptm_score
                    confidence_details["estimated_plddt"] = estimated_plddt
                    self.logger.info(f"Extracted PTM: {ptm_score:.3f}, Estimated pLDDT: {estimated_plddt:.1f}")

                # Extract other confidence metrics
                if "plddt_per_residue" in confidence:
                    confidence_details["plddt_per_residue"] = confidence["plddt_per_residue"]
                    self.logger.info(f"Per-residue pLDDT available for {len(confidence['plddt_per_residue'])} residues")

                if "has_distogram" in confidence:
                    confidence_details["has_distogram"] = confidence["has_distogram"]
                    self.logger.info("Distance matrix (distogram) available")

                # Store method information
                confidence_details["method"] = confidence.get("method", "unknown")
                confidence_details["note"] = confidence.get("note", "")

            else:
                self.logger.warning("Confidence scores object is None or empty")
        else:
            self.logger.warning("No confidence scores found in prediction results")

        # Debug: Check all values before return
        self.logger.info(f"Debug values - helix_count: {helix_count}, sheet_count: {sheet_count}")
        self.logger.info(f"Debug values - total_residues: {total_residues}")
        self.logger.info(f"Debug values - plddt_score: {plddt_score}")

        # Get sequence length safely
        try:
            sequence_length = len(self.state.sequence) if self.state.sequence else 0
            self.logger.info(f"Sequence length: {sequence_length}")
        except Exception as e:
            self.logger.warning(f"Could not get sequence length: {e}")
            sequence_length = 0

        # Enhanced result dictionary with confidence details
        result_dict = {
            "total_atoms": total_atoms,
            "total_residues": total_residues,
            "helix_percent": (helix_count / total_residues * 100) if total_residues > 0 else 0,
            "sheet_percent": (sheet_count / total_residues * 100) if total_residues > 0 else 0,
            "plddt_score": plddt_score,
            "sequence_length": sequence_length,
            "confidence_details": confidence_details
        }

        self.logger.info(f"Result dictionary: {result_dict}")
        return result_dict

    async def _generate_report(self):
        """Generate comprehensive final reports"""
        """Generate comprehensive final reports"""
        # Collect all results - FIXED: Access .data attribute directly
        validation = self.state.results["validate_sequence"].data
        prediction = self.state.results["predict_structure"].data
        parsing = self.state.results["parse_structure"].data
        metrics = self.state.results["calculate_metrics"].data

        # Get sequence safely from validation results
        sequence = validation.get('sequence', 'Unknown')
        sequence_length = validation.get('length', 0)

        # Generate human-readable report
        human_report = f"""
    PROTEIN STRUCTURE PREDICTION REPORT
    ===================================

    Sequence Information:
    - Sequence: {sequence[:50] if sequence != 'Unknown' else 'Unknown'}{'...' if sequence != 'Unknown' and len(sequence) > 50 else ''}
    - Length: {sequence_length} residues
    - Hash: {self.state.sequence_hash if hasattr(self.state, 'sequence_hash') else 'Unknown'}

    Validation Results:
    - Valid: {validation.get('valid', 'Unknown')}
    - Length: {validation.get('length', 'Unknown')} residues

    Structure Prediction:
    - Method: {prediction.get('method_used', 'Unknown')}
    - Time: {prediction.get('prediction_time', 0):.2f} seconds
    - PDB Size: {len(prediction.get('pdb_content', ''))} characters

    Structure Analysis:
    - Total Atoms: {metrics.get('total_atoms', 'Unknown')}
    - Total Residues: {metrics.get('total_residues', 'Unknown')}
    - Helix %: {metrics.get('helix_percent', 0):.1f}%
    - Sheet %: {metrics.get('sheet_percent', 0):.1f}%
    - pLDDT Score: {metrics.get('plddt_score', 'Unknown')}

    Execution Summary:
    - Total Steps: {len(self.state.execution_plan)}
    - Successful Steps: {len([r for r in self.state.results.values() if r.status == StepStatus.SUCCESS])}
    - Total Time: {(datetime.now() - self.state.start_time).total_seconds():.2f} seconds
    """

        # Generate JSON report
        json_report = {
            "sequence_info": {
                "sequence": sequence,
                "sequence_hash": self.state.sequence_hash if hasattr(self.state, 'sequence_hash') else 'Unknown',
                "length": sequence_length
            },
            "validation": validation,
            "prediction": {
                "method": prediction.get('method_used'),
                "time": prediction.get('prediction_time'),
                "pdb_size": len(prediction.get('pdb_content', '')),
                "confidence_scores": prediction.get('confidence_scores', {})
            },
            "structure_analysis": metrics,
            "execution_summary": {
                "total_steps": len(self.state.execution_plan),
                "successful_steps": len([r for r in self.state.results.values() if r.status == StepStatus.SUCCESS]),
                "total_time": (datetime.now() - self.state.start_time).total_seconds(),
                "steps": {name: {
                    "status": result.status.value,
                    "execution_time": result.execution_time,
                    "attempts": result.attempt_number
                } for name, result in self.state.results.items()}
            }
        }

        return {
            "human_readable": human_report,
            "json_report": json_report,
            "pdb_content": prediction.get('pdb_content', '')
        }

    def _create_final_report(self) -> Dict[str, Any]:
        """Create comprehensive final report"""
        return {
            "sequence_info": {
                "sequence": self.state.sequence,
                "sequence_hash": self.state.sequence_hash,
                "length": len(self.state.sequence)
            },
            "execution_summary": {
                "total_steps": len(self.state.execution_plan),
                "successful_steps": len([r for r in self.state.results.values()
                                       if r.status == StepStatus.SUCCESS]),
                "total_time": (datetime.now() - self.state.start_time).total_seconds(),
                "steps": {name: {
                    "status": result.status.value,
                    "execution_time": result.execution_time,
                    "attempts": result.attempt_number
                } for name, result in self.state.results.items()}
            },
            "results": {name: result.data for name, result in self.state.results.items()
                       if result.data is not None}
        }

    async def cleanup(self):
        """Clean up resources"""
        if self.esmfold_client:
            await self.esmfold_client.cleanup()


In [20]:
# Demo Usage
async def demo():
    # Test sequence (4KRP chain B - first 50 residues)
    test_sequence = "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGI"

    agent = ProteinStructureAgent(log_level="INFO")

    try:
        result = await agent.execute(test_sequence)

        print("\n" + "="*60)
        print("FINAL REPORT")
        print("="*60)

        # Show human-readable report
        if "results" in result and "generate_report" in result["results"]:
            report_data = result["results"]["generate_report"]
            if "human_readable" in report_data:
                print(report_data["human_readable"])

        # Show JSON summary
        print("\n" + "="*60)
        print("JSON SUMMARY")
        print("="*60)
        print(json.dumps(result, indent=2, default=str))

        return result

    finally:
        # Clean up
        await agent.cleanup()

if __name__ == "__main__":
    # Run the demo
    # asyncio.run(demo()) # for standalone Python execution
    await demo() # Use await directly in Colab

21:01:02 | 🚀 Starting Protein Structure Prediction Agent
21:01:02 | 🤔 THINK: Analyzing sequence of 50 residues for structure prediction
21:01:02 | 📋 PLAN: 5 steps planned
21:01:02 |    Step 1: Validate amino acid sequence format
21:01:02 |    Step 2: Call ESMFold API for structure prediction
21:01:02 |    Step 3: Parse PDB structure and extract coordinates
21:01:02 |    Step 4: Calculate pLDDT, secondary structure, and other metrics
21:01:02 |    Step 5: Generate human-readable and JSON reports
21:01:02 | �� ACT: Validate amino acid sequence format...
21:01:02 | 👁️  OBSERVE: ✅ validate_sequence completed, proceeding to next step
21:01:02 | �� ACT: Call ESMFold API for structure prediction...
21:01:02 | Initializing ESMFold client...
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-s


FINAL REPORT

    PROTEIN STRUCTURE PREDICTION REPORT

    Sequence Information:
    - Sequence: Unknown
    - Length: 50 residues
    - Hash: e4b26517

    Validation Results:
    - Valid: True
    - Length: 50 residues

    Structure Prediction:
    - Method: local_transformers
    - Time: 1.31 seconds
    - PDB Size: 15561 characters

    Structure Analysis:
    - Total Atoms: 200
    - Total Residues: 50
    - Helix %: 10.0%
    - Sheet %: 6.0%
    - pLDDT Score: 0.5696514248847961

    Execution Summary:
    - Total Steps: 5
    - Successful Steps: 4
    - Total Time: 11.34 seconds
    

JSON SUMMARY
{
  "sequence_info": {
    "sequence": "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGI",
    "sequence_hash": "e4b26517",
    "length": 50
  },
  "execution_summary": {
    "total_steps": 5,
    "successful_steps": 5,
    "total_time": 11.340273,
    "steps": {
      "validate_sequence": {
        "status": "success",
        "execution_time": 8.368492126464844e-05,
        "atte

In [27]:
# test_biopython_standalone.py

async def test_biopython_directly(agent):
    """Test BioPython parsing directly to see if it works"""

    # Test with a simple PDB content (no ESMFold)
    test_pdb = """HEADER    PROTEIN
TITLE     TEST PROTEIN STRUCTURE
REMARK    This is a standard PDB format for testing
ATOM      1  N   ALA A   1      27.462  14.810  32.433  1.00 20.00           N
ATOM      2  CA  ALA A   1      26.000  15.000  32.000  1.00 20.00           C
ATOM      3  C   ALA A   1      25.000  14.000  31.000  1.00 20.00           C
ATOM      4  O   ALA A   1      24.000  14.500  30.500  1.00 20.00           O
TER
END"""

    try:
        print("🧪 Testing BioPython directly...")
        print(f"PDB Content: {test_pdb[:100]}...")

        # Test the BioPython method directly
        result = await agent._parse_with_biopython(test_pdb)

        print("✅ BioPython parsing SUCCESS!")
        print(f"   Total Atoms: {result.get('total_atoms')}")
        print(f"   Total Residues: {result.get('total_residues')}")
        print(f"   Method Used: {result.get('method_used')}")

        return result

    except Exception as e:
        print(f"❌ BioPython parsing FAILED: {e}")
        import traceback
        traceback.print_exc()
        return None

async def test_parsing_strategy_selection(agent):
    """Test the agent's parsing strategy selection"""

    # Test 1: Standard PDB (should choose BioPython)
    standard_pdb = """HEADER    PROTEIN
TITLE     TEST PROTEIN STRUCTURE
ATOM      1  N   ALA A   1      27.462  14.810  32.433  1.00 20.00           N
ATOM      2  CA  ALA A   1      26.000  15.000  32.000  1.00 20.00           C
TER
END"""

    # Test 2: ESMFold PDB (should choose manual)
    esmfold_pdb = """HEADER    ESMFold PREDICTION
REMARK    Generated by ESMFold via HuggingFace Transformers
ATOM      1  N   P A   1      -5.099   2.590 -11.945  1.00  5.00           N
TER
END"""

    print("🧪 Testing Agentic Parsing Strategy Selection")
    print("="*60)

    # Test standard PDB
    print("\n1. Testing Standard PDB (should choose BioPython):")
    analysis1 = agent._analyze_parsing_strategy(standard_pdb)
    print(f"   Format: {analysis1['format_type']}")
    print(f"   Strategy: {analysis1['recommended_strategy']}")
    print(f"   Reasoning: {', '.join(analysis1['reasoning'])}")

    # Test ESMFold PDB
    print("\n2. Testing ESMFold PDB (should choose manual):")
    analysis2 = agent._analyze_parsing_strategy(esmfold_pdb)
    print(f"   Format: {analysis2['format_type']}")
    print(f"   Strategy: {analysis2['recommended_strategy']}")
    print(f"   Reasoning: {', '.join(analysis2['reasoning'])}")

    return analysis1, analysis2

async def run_biopython_tests():
    """Run all BioPython and strategy tests"""

    print("🚀 Running BioPython and Strategy Tests")
    print("="*60)

    # Create agent instance
    agent = ProteinStructureAgent(log_level="INFO")

    try:
        # Test 1: BioPython parsing
        print("\n" + "="*40)
        print("TEST 1: BioPython Parsing")
        print("="*40)
        await test_biopython_directly(agent)

        # Test 2: Strategy selection
        print("\n" + "="*40)
        print("TEST 2: Strategy Selection")
        print("="*40)
        await test_parsing_strategy_selection(agent)

        print("\n✅ All tests completed!")

    finally:
        # Clean up
        await agent.cleanup()

# Run the tests
if __name__ == "__main__":
    await run_biopython_tests()

🚀 Running BioPython and Strategy Tests

TEST 1: BioPython Parsing
🧪 Testing BioPython directly...
PDB Content: HEADER    PROTEIN
TITLE     TEST PROTEIN STRUCTURE
REMARK    This is a standard PDB format for testi...
✅ BioPython parsing SUCCESS!
   Total Atoms: 4
   Total Residues: 1
   Method Used: biopython

TEST 2: Strategy Selection
🧪 Testing Agentic Parsing Strategy Selection

1. Testing Standard PDB (should choose BioPython):
   Format: standard_pdb
   Strategy: biopython
   Reasoning: Standard PDB format - BioPython should work

2. Testing ESMFold PDB (should choose manual):
   Format: esmfold
   Strategy: manual
   Reasoning: ESMFold format detected - manual parsing more reliable

✅ All tests completed!


In [29]:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any


In [30]:
class ProteinStructureAPI:
    """Mock API class for Jupyter notebook demonstration"""

    def __init__(self):
        self.agent = None

    async def predict_structure(self, sequence: str, log_level: str = "INFO") -> Dict[str, Any]:
        """API endpoint simulation - returns the same response format as a real API"""

        try:
            # Initialize agent (same as your existing demo)
            self.agent = ProteinStructureAgent(log_level=log_level)

            # Execute prediction (your existing workflow)
            result = await self.agent.execute(sequence)

            # Return API-like response format
            return {
                "success": True,
                "status_code": 200,
                "data": result,
                "message": "Structure prediction completed successfully",
                "timestamp": datetime.now().isoformat()
            }

        except Exception as e:
            return {
                "success": False,
                "status_code": 500,
                "error": str(e),
                "message": "Structure prediction failed",
                "timestamp": datetime.now().isoformat()
            }

    async def cleanup(self):
        """Clean up resources"""
        if self.agent:
            await self.agent.cleanup()

In [31]:
async def demo_api_style():
    """Demo that simulates API usage in Jupyter notebook"""

    # Test sequence
    test_sequence = "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGI"

    print("🚀 PROTEIN STRUCTURE PREDICTION API DEMO")
    print("="*60)
    print(f"📝 Request: POST /predict")
    print(f"📊 Payload: {{'sequence': '{test_sequence[:20]}...'}}")
    print("="*60)

    # Simulate API call
    api = ProteinStructureAPI()

    try:
        # Simulate API request
        print("🔄 Processing API request...")
        response = await api.predict_structure(test_sequence)

        # Display API response
        print(f"\n📡 API Response:")
        print(f"   Status: {response['status_code']}")
        print(f"   Success: {response['success']}")
        print(f"   Message: {response['message']}")
        print(f"   Timestamp: {response['timestamp']}")

        if response['success']:
            print(f"\n�� Response Data Keys: {list(response['data'].keys())}")

            # Show human-readable report
            if "results" in response['data'] and "generate_report" in response['data']["results"]:
                report_data = response['data']["results"]["generate_report"]
                if "human_readable" in report_data:
                    print("\n" + "="*60)
                    print("HUMAN READABLE REPORT")
                    print("="*60)
                    print(report_data["human_readable"])

            # Show JSON summary
            print("\n" + "="*60)
            print("JSON RESPONSE")
            print("="*60)
            print(json.dumps(response, indent=2, default=str))

        return response

    finally:
        await api.cleanup()

# Run the API-style demo
if __name__ == "__main__":
    await demo_api_style()

22:20:43 | 🚀 Starting Protein Structure Prediction Agent
22:20:43 | 🤔 THINK: Analyzing sequence of 50 residues for structure prediction
22:20:43 | 📋 PLAN: 5 steps planned
22:20:43 |    Step 1: Validate amino acid sequence format
22:20:43 |    Step 2: Call ESMFold API for structure prediction
22:20:43 |    Step 3: Parse PDB structure and extract coordinates
22:20:43 |    Step 4: Calculate pLDDT, secondary structure, and other metrics
22:20:43 |    Step 5: Generate human-readable and JSON reports
22:20:43 | �� ACT: Validate amino acid sequence format...
22:20:43 | 👁️  OBSERVE: ✅ validate_sequence completed, proceeding to next step
22:20:43 | �� ACT: Call ESMFold API for structure prediction...
22:20:43 | Initializing ESMFold client...


🚀 PROTEIN STRUCTURE PREDICTION API DEMO
📝 Request: POST /predict
📊 Payload: {'sequence': 'PQITLWQRPLVTIKIGGQLK...'}
🔄 Processing API request...


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
22:20:53 | Calling ESMFold for structure prediction...
22:20:54 | 👁️  OBSERVE: ✅ predict_structure completed, proceeding to next step
22:20:54 | �� ACT: Parse PDB structure and extract coordinates...
22:20:54 | 🤔 Parsing Analysis: ['ESMFold format detected - manual parsing more reliable']
22:20:54 | 🧠 Agent Decision: Using manual parsing (Reason: ESMFold format detected - manual parsing more reliable)
22:20:54 | 👁️  OBSERVE: ✅ parse_structure completed, proceeding to next step
22:20:54 | �� ACT: Calculate pLDDT, secondary structure, and other metrics...
22:20:54 | Structure data keys: ['total_atoms', 'total_residues', 'atoms', 'residue_list', 'pdb_content']
22:20:54 | Tot


📡 API Response:
   Status: 200
   Success: True
   Message: Structure prediction completed successfully
   Timestamp: 2025-09-03T22:20:54.675190

�� Response Data Keys: ['sequence_info', 'execution_summary', 'results']

HUMAN READABLE REPORT

    PROTEIN STRUCTURE PREDICTION REPORT

    Sequence Information:
    - Sequence: Unknown
    - Length: 50 residues
    - Hash: e4b26517

    Validation Results:
    - Valid: True
    - Length: 50 residues

    Structure Prediction:
    - Method: local_transformers
    - Time: 1.29 seconds
    - PDB Size: 15561 characters

    Structure Analysis:
    - Total Atoms: 200
    - Total Residues: 50
    - Helix %: 10.0%
    - Sheet %: 6.0%
    - pLDDT Score: 0.5696514248847961

    Execution Summary:
    - Total Steps: 5
    - Successful Steps: 4
    - Total Time: 11.34 seconds
    

JSON RESPONSE
{
  "success": true,
  "status_code": 200,
  "data": {
    "sequence_info": {
      "sequence": "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGI",
      

In [32]:
async def demo_multiple_sequences():
    """Demo with multiple sequences to show API robustness"""

    test_cases = [
        {
            "name": "Short Sequence",
            "sequence": "MKLLVL",
            "expected": "Should work quickly"
        },
        {
            "name": "Medium Sequence",
            "sequence": "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGI",
            "expected": "Standard case"
        },
        {
            "name": "Longer Sequence",
            "sequence": "MKLLVLFLGAAGVKNTGTAAGLMGVSHDNADKRVGRLLAVCQNKGSCGIGGDEFNLYQLLYCDGACQKSTLVHQRNC"[:50],
            "expected": "More complex prediction"
        }
    ]

    print("🧪 MULTIPLE SEQUENCE API TESTING")
    print("="*60)

    api = ProteinStructureAPI()
    results = {}

    try:
        for i, test_case in enumerate(test_cases, 1):
            print(f"\n📋 Test {i}: {test_case['name']}")
            print(f"   Sequence: {test_case['sequence'][:20]}...")
            print(f"   Expected: {test_case['expected']}")
            print(f"   Length: {len(test_case['sequence'])} residues")

            response = await api.predict_structure(test_case['sequence'])
            results[test_case['name']] = response

            status_emoji = "✅" if response['success'] else "❌"
            print(f"   Result: {status_emoji} {response['status_code']} - {response['message']}")

            if response['success']:
                data = response['data']
                if "results" in data and "calculate_metrics" in data["results"]:
                    metrics = data["results"]["calculate_metrics"]
                    print(f"   Metrics: {metrics.get('total_atoms', 'N/A')} atoms, "
                          f"{metrics.get('plddt_score', 'N/A'):.2f} pLDDT")

    finally:
        await api.cleanup()

    return results

# Run multiple sequence demo
if __name__ == "__main__":
    await demo_multiple_sequences()

22:21:20 | 🚀 Starting Protein Structure Prediction Agent
22:21:20 | 🤔 THINK: Short sequence (6 residues) - prediction may be unreliable
22:21:20 | 📋 PLAN: 5 steps planned
22:21:20 |    Step 1: Validate amino acid sequence format
22:21:20 |    Step 2: Call ESMFold API for structure prediction
22:21:20 |    Step 3: Parse PDB structure and extract coordinates
22:21:20 |    Step 4: Calculate pLDDT, secondary structure, and other metrics
22:21:20 |    Step 5: Generate human-readable and JSON reports
22:21:20 | �� ACT: Validate amino acid sequence format...
22:21:20 | 👁️  OBSERVE: ✅ validate_sequence completed, proceeding to next step
22:21:20 | �� ACT: Call ESMFold API for structure prediction...
22:21:20 | Initializing ESMFold client...


🧪 MULTIPLE SEQUENCE API TESTING

📋 Test 1: Short Sequence
   Sequence: MKLLVL...
   Expected: Should work quickly
   Length: 6 residues


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
22:21:30 | Calling ESMFold for structure prediction...
22:21:31 | 👁️  OBSERVE: ✅ predict_structure completed, proceeding to next step
22:21:31 | �� ACT: Parse PDB structure and extract coordinates...
22:21:31 | 🤔 Parsing Analysis: ['ESMFold format detected - manual parsing more reliable']
22:21:31 | 🧠 Agent Decision: Using manual parsing (Reason: ESMFold format detected - manual parsing more reliable)
22:21:31 | 👁️  OBSERVE: ✅ parse_structure completed, proceeding to next step
22:21:31 | �� ACT: Calculate pLDDT, secondary structure, and other metrics...
22:21:31 | Structure data keys: ['total_atoms', 'total_residues', 'atoms', 'residue_list', 'pdb_content']
22:21:31 | Tot

   Result: ✅ 200 - Structure prediction completed successfully
   Metrics: 24 atoms, 0.61 pLDDT

📋 Test 2: Medium Sequence
   Sequence: PQITLWQRPLVTIKIGGQLK...
   Expected: Standard case
   Length: 50 residues


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
22:21:41 | Calling ESMFold for structure prediction...
22:21:42 | 👁️  OBSERVE: ✅ predict_structure completed, proceeding to next step
22:21:42 | �� ACT: Parse PDB structure and extract coordinates...
22:21:42 | 🤔 Parsing Analysis: ['ESMFold format detected - manual parsing more reliable']
22:21:42 | 🧠 Agent Decision: Using manual parsing (Reason: ESMFold format detected - manual parsing more reliable)
22:21:42 | 👁️  OBSERVE: ✅ parse_structure completed, proceeding to next step
22:21:42 | �� ACT: Calculate pLDDT, secondary structure, and other metrics...
22:21:42 | Structure data keys: ['total_atoms', 'total_residues', 'atoms', 'residue_list', 'pdb_content']
22:21:42 | Tot

   Result: ✅ 200 - Structure prediction completed successfully
   Metrics: 200 atoms, 0.57 pLDDT

📋 Test 3: Longer Sequence
   Sequence: MKLLVLFLGAAGVKNTGTAA...
   Expected: More complex prediction
   Length: 50 residues


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
22:21:51 | Calling ESMFold for structure prediction...
22:21:52 | 👁️  OBSERVE: ✅ predict_structure completed, proceeding to next step
22:21:52 | �� ACT: Parse PDB structure and extract coordinates...
22:21:52 | 🤔 Parsing Analysis: ['ESMFold format detected - manual parsing more reliable']
22:21:52 | 🧠 Agent Decision: Using manual parsing (Reason: ESMFold format detected - manual parsing more reliable)
22:21:52 | 👁️  OBSERVE: ✅ parse_structure completed, proceeding to next step
22:21:52 | �� ACT: Calculate pLDDT, secondary structure, and other metrics...
22:21:52 | Structure data keys: ['total_atoms', 'total_residues', 'atoms', 'residue_list', 'pdb_content']
22:21:52 | Tot

   Result: ✅ 200 - Structure prediction completed successfully
   Metrics: 200 atoms, 0.47 pLDDT
