In [3]:
!pip install numpy==1.26.4  # Latest stable NumPy 1.x version
import os
import shutil
import copy
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import cv2
import numpy as np
import matplotlib.pyplot as plt
import fitz  # PyMuPDF
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from io import BytesIO
import pandas as pd
import warnings
from tqdm import tqdm
from chemical_dataset import ChemicalDataset

# Hide UserWarning messages so they do not clutter the cell output.
warnings.filterwarnings("ignore", category=UserWarning)

# Fix the NumPy and PyTorch seeds so random draws repeat between runs.
np.random.seed(42)
torch.manual_seed(42)

# Per-channel mean / std used by the Normalize step of both pipelines.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: random 224x224 crop, random horizontal flip and mild
# brightness/contrast jitter, followed by tensor conversion + normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Validation pipeline: deterministic resize to 256 and centre crop to 224,
# followed by the same tensor conversion + normalisation.
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

def extract_images_from_pdf(pdf_path, output_folder):
    """Save every embedded image and a 2x render of each page of a PDF.

    Embedded images are written as ``page_<n>_img_<index>.<ext>`` and the
    page renders as ``page_<n>_full.png`` (page numbers are 1-based in the
    file names). All payloads are collected first and written afterwards.

    Args:
        pdf_path: Path of the PDF to open with PyMuPDF.
        output_folder: Directory that receives the files; created if missing.

    Returns:
        The number of files written (embedded images plus one render per page).

    Raises:
        Exception: Any error raised while reading or writing is printed and
            then re-raised unchanged.
    """
    os.makedirs(output_folder, exist_ok=True)
    try:
        with fitz.open(pdf_path) as doc:
            page_total = len(doc)
            print(f"Processing {page_total} pages from {pdf_path}")

            # (file name, raw bytes) pairs, in the order they will be written.
            pending = []
            for page_num in range(page_total):
                page = doc.load_page(page_num)

                # Embedded raster images referenced by this page.
                for img_index, img in enumerate(page.get_images(full=True)):
                    extracted = doc.extract_image(img[0])
                    payload = extracted["image"]
                    suffix = extracted["ext"]
                    pending.append(
                        (f"page_{page_num+1}_img_{img_index}.{suffix}", payload)
                    )

                # Whole-page render at 2x zoom.
                render = page.get_pixmap(matrix=fitz.Matrix(2, 2))
                pending.append((f"page_{page_num+1}_full.png", render.tobytes()))

            for file_name, payload in pending:
                with open(os.path.join(output_folder, file_name), "wb") as f:
                    f.write(payload)
            return len(pending)
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
        raise

class ChemicalStructureSegmenter:
    """Find candidate chemical-structure regions in a raster image.

    The image is binarised (adaptive or fixed threshold), external contours
    are extracted, and each bounding box that passes the size and aspect-ratio
    filters is kept only if enough Hough circles or line segments are found
    inside it.
    """

    def __init__(self, min_structure_size=100, threshold=180, adaptive_thresh=True):
        """Store the segmentation parameters.

        Args:
            min_structure_size: Minimum bounding-box width and height (pixels).
            threshold: Grey level for fixed thresholding; only used when
                ``adaptive_thresh`` is False.
            adaptive_thresh: Use Gaussian adaptive thresholding instead of the
                fixed ``threshold``.
        """
        self.min_size = min_structure_size
        self.threshold = threshold
        self.adaptive_thresh = adaptive_thresh

    def segment_image(self, image_path, output_folder=None):
        """Crop every region of an image that looks like a chemical structure.

        Args:
            image_path: Path of the image to read with OpenCV.
            output_folder: Optional directory; when given, each accepted crop
                is also written there as ``structure_<contour index>.png``.

        Returns:
            A list of BGR image arrays, one per accepted region.

        Raises:
            ValueError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        # Preprocessing pipeline: grayscale, de-speckle, stretch contrast.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.medianBlur(gray, 3)
        gray = cv2.equalizeHist(gray)

        # Thresholding (inverted so dark strokes become foreground).
        if self.adaptive_thresh:
            thresh = cv2.adaptiveThreshold(
                gray, 255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY_INV, 11, 2
            )
        else:
            # THRESH_OTSU is deliberately not combined here: with that flag
            # OpenCV picks its own level and ignores self.threshold.
            _, thresh = cv2.threshold(
                gray, self.threshold, 255, cv2.THRESH_BINARY_INV
            )

        # Morphological closing joins nearby strokes into a single blob.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        cleaned = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(
            cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # Create the output directory once instead of once per accepted crop.
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        structures = []
        for i, cnt in enumerate(contours):
            x, y, w, h = cv2.boundingRect(cnt)
            # Size filtering
            if w < self.min_size or h < self.min_size:
                continue
            # Aspect ratio filtering
            aspect_ratio = w / float(h)
            if not (0.2 < aspect_ratio < 5.0):
                continue
            # Structure validation on the preprocessed grayscale crop
            if not self.is_chemical_structure(gray[y:y+h, x:x+w]):
                continue
            structure = image[y:y+h, x:x+w]
            structures.append(structure)
            if output_folder:
                output_path = os.path.join(output_folder, f"structure_{i}.png")
                cv2.imwrite(output_path, structure)
        return structures

    def is_chemical_structure(self, roi):
        """Return True when a grayscale region has enough circles or lines.

        A region is accepted when more than 3 Hough circles or more than
        5 probabilistic Hough line segments are detected in it.
        """
        circles = cv2.HoughCircles(roi, cv2.HOUGH_GRADIENT, 1, 20,
                                   param1=50, param2=30, minRadius=5, maxRadius=30)
        if self._count_detections(circles) > 3:
            return True
        edges = cv2.Canny(roi, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi/180, 50, minLineLength=20, maxLineGap=10)
        return self._count_detections(lines) > 5

    @staticmethod
    def _count_detections(detections):
        """Count the primitives in a Hough result, or 0 for None / empty.

        HoughCircles returns an array shaped (1, N, 3) while HoughLinesP
        returns (N, 1, 4), so ``len()`` is only correct for the latter.
        Flattening to rows of the last dimension counts both layouts.
        """
        if detections is None:
            return 0
        arr = np.asarray(detections)
        if arr.size == 0:
            return 0
        return int(arr.reshape(-1, arr.shape[-1]).shape[0])

def generate_chemical_images(output_dir="chemical_data", num_samples=5000, augment=True):
    """Render a synthetic, class-labelled image set of small molecules.

    For every functional-group class a sub-directory of ``output_dir`` is
    created and filled with 400x400 PNG drawings of molecules picked at
    random from that class's SMILES list, with randomised drawing options.

    Args:
        output_dir: Root directory of the dataset; created if missing.
        num_samples: Total number of images requested, split as evenly as
            possible across the classes.
        augment: When True, each drawing is passed through
            ``apply_image_augmentations`` before it is saved.

    Returns:
        The class names, in the order they are defined below.
    """
    # Every entry must be a valid SMILES string that really belongs to its
    # class: entries RDKit cannot parse are skipped, which shrinks the class.
    classes = {
        'alkane': ['CCCC', 'CCCCC', 'CC(C)C', 'CCCCCC'],
        'alkene': ['C=CC', 'C=CCC', 'CC=CC', 'C=C(C)C'],
        'alcohol': ['CCO', 'CCCO', 'CC(C)O', 'CCCCO'],
        'carboxylic_acid': ['CC(=O)O', 'CCC(=O)O', 'CC(C)C(=O)O'],
        'amine': ['CN', 'CCN', 'CC(C)N', 'CCCN'],
        'benzene': ['c1ccccc1', 'c1ccc(cc1)C', 'c1cc(ccc1)OC', 'CCc1ccccc1'],
        'amide': ['CC(=O)N', 'CCC(=O)N', 'CC(C)C(=O)N'],
        'ether': ['COC', 'CCOC', 'CC(C)OC'],
        'ketone': ['CC(=O)C', 'CCC(=O)C', 'CC(C)C(=O)C'],
        'aldehyde': ['CC=O', 'CCC=O', 'CC(C)C=O'],
        'ester': ['CCOC=O', 'CCCOC=O'],
        'alkyne': ['C#C', 'C#CC', 'CC#CC'],
        'nitrile': ['C#N', 'CC#N'],
        'halide': ['CCl', 'CBr', 'CI'],
        # ... other classes ...
    }

    os.makedirs(output_dir, exist_ok=True)
    # The first `remainder` classes receive one extra sample each.
    samples_per_class, remainder = divmod(num_samples, len(classes))
    print(f"Generating {num_samples} chemical structures across {len(classes)} classes...")
    for idx, (class_name, smiles_list) in enumerate(classes.items()):
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)

        # Parse and lay out each molecule once per class, and report bad
        # SMILES loudly instead of silently producing fewer images.
        molecules = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Skipping invalid SMILES for {class_name}: {smiles}")
                continue
            AllChem.Compute2DCoords(mol)
            molecules.append(mol)
        if not molecules:
            print(f"No valid SMILES for {class_name}; no images generated")
            continue

        current_samples = samples_per_class + (1 if idx < remainder else 0)
        for i in tqdm(range(current_samples), desc=f"Creating {class_name} images"):
            try:
                mol = molecules[np.random.randint(len(molecules))]
                drawer = rdMolDraw2D.MolDraw2DCairo(400, 400)
                options = drawer.drawOptions()
                # Randomise rendering so the images are not pixel-identical.
                options.bondLineWidth = int(np.random.randint(1, 3))
                options.highlightBondWidthMultiplier = 10 + int(np.random.randint(-5, 5))
                options.atomLabelDeuteriumTritium = bool(np.random.choice([True, False]))
                drawer.DrawMolecule(mol)
                drawer.FinishDrawing()
                img = Image.open(BytesIO(drawer.GetDrawingText()))
                if augment:
                    img = apply_image_augmentations(img)
                img.save(os.path.join(class_dir, f"{idx}_{i}.png"))
            except Exception as e:
                # Best effort: one failed drawing must not abort the dataset.
                print(f"Error generating {class_name} {i}: {str(e)}")
                continue
    print(f"Synthetic dataset generated at '{output_dir}'")
    return list(classes.keys())

def apply_image_augmentations(img):
    """Return a randomly rotated, rescaled and occasionally noisy copy of an image.

    Steps, in order: rotate about the centre by an angle drawn from
    [-15, 15) degrees (exposed border filled with 255), resize by a factor
    drawn from [0.9, 1.1), and, when a uniform draw exceeds 0.7, add
    per-pixel noise in [0, 25) using OpenCV's ``cv2.add``.

    Args:
        img: Input image; converted with ``np.array``.

    Returns:
        A ``PIL.Image`` built from the augmented array.
    """
    pixels = np.array(img)

    # Rotation about the image centre, same canvas size.
    rotation_deg = np.random.uniform(-15, 15)
    height, width = pixels.shape[:2]
    rotation = cv2.getRotationMatrix2D((width // 2, height // 2), rotation_deg, 1.0)
    pixels = cv2.warpAffine(
        pixels, rotation, (width, height), borderValue=(255, 255, 255)
    )

    # Isotropic rescale; the output size changes with the factor.
    factor = np.random.uniform(0.9, 1.1)
    pixels = cv2.resize(pixels, None, fx=factor, fy=factor)

    # Noise is only applied when the draw exceeds 0.7.
    if np.random.rand() <= 0.7:
        return Image.fromarray(pixels)

    speckle = np.random.randint(0, 25, pixels.shape, dtype=np.uint8)
    return Image.fromarray(cv2.add(pixels, speckle))

class ChemicalStructureRecognizer:
    """Image classifier built on a pretrained ResNet-50 with a new head.

    The final fully connected layer of the backbone is replaced by
    Linear -> ReLU -> Dropout(0.5) -> Linear(num_classes). Training uses
    cross-entropy, AdamW and a ReduceLROnPlateau scheduler.
    """

    def __init__(self, num_classes, device=None, model_name='resnet50'):
        """Build the model, loss, optimizer and scheduler.

        Args:
            num_classes: Number of output classes.
            device: Torch device string; defaults to 'cuda' when available,
                otherwise 'cpu'.
            model_name: Stored on the instance only. The backbone is always
                ResNet-50 regardless of this value.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_name = model_name
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        self.model = self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=0.01)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=3)

    def train(self, train_loader, val_loader=None, num_epochs=20):
        """Train for ``num_epochs`` and restore the best validation weights.

        There is no early stopping: every epoch always runs. When a
        validation loader is given, the scheduler is stepped on the
        validation loss and the weights with the highest validation accuracy
        are reloaded at the end.

        Args:
            train_loader: Iterable of ``(inputs, labels)`` batches.
            val_loader: Optional iterable of validation batches.
            num_epochs: Number of passes over ``train_loader``.

        Returns:
            A dict with per-epoch lists ``train_losses``, ``val_losses`` and
            ``val_accuracies``. Without a validation loader, ``val_losses``
            repeats the training loss and ``val_accuracies`` holds zeros.
        """
        best_acc = 0.0
        best_model_wts = None
        train_losses = []
        val_losses = []
        val_accuracies = []

        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            num_batches = 0

            for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                num_batches += 1

            # Mean batch loss; an empty loader yields 0.0 instead of dividing by zero.
            epoch_loss = running_loss / num_batches if num_batches else 0.0
            train_losses.append(epoch_loss)

            if val_loader:
                val_loss, val_acc = self.evaluate(val_loader)
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
                self.scheduler.step(val_loss)

                if val_acc > best_acc:
                    best_acc = val_acc
                    # Deep copy: state_dict() returns references to live tensors.
                    best_model_wts = copy.deepcopy(self.model.state_dict())
            else:
                val_losses.append(epoch_loss)  # Use train loss as fallback
                val_accuracies.append(0)

        if best_model_wts is not None:
            self.model.load_state_dict(best_model_wts)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_accuracies': val_accuracies
        }

    def evaluate(self, data_loader):
        """Compute mean batch loss and accuracy over a loader.

        Args:
            data_loader: Iterable of ``(inputs, labels)`` batches.

        Returns:
            ``(avg_loss, accuracy)``. Both are 0.0 when the loader yields no
            batches or no samples.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                running_loss += loss.item()
                num_batches += 1
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Guard both divisions so an empty loader does not raise.
        avg_loss = running_loss / num_batches if num_batches else 0.0
        accuracy = correct / total if total else 0.0
        return avg_loss, accuracy

    def predict(self, image, top_k=3):
        """Return the ``top_k`` most probable class indices for one image.

        Args:
            image: A file path (opened and converted to RGB) or an image
                object accepted by the torchvision transforms.
            top_k: Number of predictions wanted; clamped to the number of
                model outputs.

        Returns:
            ``(class_indices, probabilities)`` as NumPy arrays sorted by
            decreasing probability, or ``(None, None)`` if a path could not
            be loaded.
        """
        self.model.eval()
        if isinstance(image, str):
            try:
                image = Image.open(image).convert('RGB')
            except Exception as e:
                print(f"Error loading image: {str(e)}")
                return None, None
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        image = transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(image)
            probs = torch.nn.functional.softmax(outputs, dim=1)
            # torch.topk raises if k exceeds the number of classes.
            k = max(1, min(int(top_k), probs.shape[1]))
            top_probs, top_classes = torch.topk(probs, k)
        return top_classes.cpu().numpy()[0], top_probs.cpu().numpy()[0]

    def save(self, path):
        """Write the model's state dict to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a state dict from ``path`` and switch the model to eval mode.

        NOTE(review): torch.load unpickles the file; only load checkpoints
        from a trusted source.
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.model.eval()
        

def run_full_pipeline(pdf_path, output_base_dir="chemical_output",
                     chembl_dir="chembl_35", data_dir="chembl_images",
                     model_path="chemical_model.pth"):
    """Train (or load) the classifier and classify structures found in a PDF.

    Stages: generate the training images if ``data_dir`` is missing, build
    the datasets, train a new model or load ``model_path``, extract images
    from the PDF, segment candidate structures, and classify each one.

    Args:
        pdf_path: PDF to analyse.
        output_base_dir: Root directory for all outputs.
        chembl_dir: Accepted for backward compatibility; the generator in
            this file takes no ChEMBL input, so the value is not used.
        data_dir: Directory holding (or receiving) the training images.
        model_path: Checkpoint to load if present, otherwise to write.

    Returns:
        A DataFrame with one row per (structure, class) pair and the columns
        ``structure_id``, ``predicted_class``, ``confidence``, ``output_path``.
        The same table is written to ``results/recognition_results.csv``.
    """
    os.makedirs(output_base_dir, exist_ok=True)
    results_dir = os.path.join(output_base_dir, "results")
    # The other three directories are not written to below; they are still
    # created so the on-disk layout matches earlier runs.
    for sub_dir in ("extracted_pages", "segmented_structures",
                    "results", "recognized_structures"):
        os.makedirs(os.path.join(output_base_dir, sub_dir), exist_ok=True)

    print("\nStep 1/5: Generating chemical structure images...")
    if not os.path.exists(data_dir):
        # generate_chemical_images takes (output_dir, num_samples, augment);
        # the previous call passed num_samples twice and could never succeed.
        generate_chemical_images(data_dir, num_samples=5000)
    else:
        print("Using existing chemical structure images")

    print("\nStep 2/5: Preparing dataset...")
    try:
        train_dataset = ChemicalDataset(data_dir, transform=train_transform, split='train')
        val_dataset = ChemicalDataset(data_dir, transform=val_transform, split='val') \
            if os.path.exists(os.path.join(data_dir, 'val')) else None
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, persistent_workers=True)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2) if val_dataset else None
        print(f"Found {len(train_dataset.classes)} classes")
        print(f"Training samples: {len(train_dataset)}")
        if val_loader:
            print(f"Validation samples: {len(val_dataset)}")
    except Exception as e:
        print(f"Error creating datasets: {e}")
        print("Generating synthetic dataset and trying again...")
        generate_chemical_images(data_dir, num_samples=5000)
        train_dataset = ChemicalDataset(data_dir, transform=train_transform, split='train')
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
        val_loader = None

    print("\nStep 3/5: Training model...")
    num_classes = len(train_dataset.classes)
    print(f"Found {num_classes} classes: {train_dataset.classes}")
    # Record the index -> name mapping used to decode predictions below.
    with open(os.path.join(results_dir, "class_mapping.txt"), "w") as f:
        for class_index, class_name in enumerate(train_dataset.classes):
            f.write(f"{class_index}: {class_name}\n")

    model = ChemicalStructureRecognizer(num_classes=num_classes)
    if os.path.exists(model_path):
        print(f"Loading existing model from {model_path}")
        model.load(model_path)
    else:
        print("Training new model...")
        history = model.train(train_loader, val_loader, num_epochs=10)
        model.save(model_path)
        # Plot training history
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(history['train_losses'], label='Train Loss')
        plt.plot(history['val_losses'], label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.subplot(1, 2, 2)
        val_accuracies = [acc for acc in history['val_accuracies'] if acc > 0]
        if val_accuracies:
            plt.plot(range(len(val_accuracies)), val_accuracies, label='Val Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.title('Validation Accuracy')
            plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, "training_history.png"))
        plt.close()

    print("\nExtracting content from PDF...")
    extracted_dir = os.path.join(output_base_dir, "extracted_content")
    os.makedirs(extracted_dir, exist_ok=True)
    extract_images_from_pdf(pdf_path, extracted_dir)

    print("\nProcessing extracted content...")
    chemical_structures_dir = os.path.join(output_base_dir, "chemical_structures")
    os.makedirs(chemical_structures_dir, exist_ok=True)
    segmenter = ChemicalStructureSegmenter(min_structure_size=100, threshold=160)
    image_suffixes = ('.png', '.jpg', '.jpeg')
    for item in sorted(os.listdir(extracted_dir)):
        if not item.lower().endswith(image_suffixes):
            continue
        img_path = os.path.join(extracted_dir, item)
        # One sub-folder per source image: the segmenter names its crops
        # structure_<i>.png, so a shared folder would let later images
        # overwrite the crops of earlier ones.
        per_image_dir = os.path.join(chemical_structures_dir, os.path.splitext(item)[0])
        structures = segmenter.segment_image(img_path, per_image_dir)
        print(f"Found {len(structures)} structures in {item}")

    print("\nStep 6/6: Making predictions...")
    results = []
    structure_count = 0
    for root, dirs, files in os.walk(chemical_structures_dir):
        dirs.sort()  # deterministic traversal, hence stable structure ids
        for file in sorted(files):
            if not file.lower().endswith(image_suffixes):
                continue
            structure_path = os.path.join(root, file)
            # A single forward pass yields the probabilities of all classes.
            pred_classes, probs = model.predict(structure_path, top_k=num_classes)
            if pred_classes is None:
                # predict() already reported why the image could not be loaded.
                continue
            for pred_class, confidence in zip(pred_classes, probs):
                results.append({
                    'structure_id': structure_count,
                    'predicted_class': train_dataset.classes[int(pred_class)],
                    'confidence': float(confidence),
                    'output_path': structure_path
                })
            structure_count += 1
            if structure_count % 10 == 0:
                print(f"Processed {structure_count} structures...")

    # Explicit columns keep the CSV header even when nothing was recognised.
    results_df = pd.DataFrame(
        results,
        columns=['structure_id', 'predicted_class', 'confidence', 'output_path']
    )
    results_csv = os.path.join(results_dir, "recognition_results.csv")
    results_df.to_csv(results_csv, index=False)
    print(f"\nProcessing complete! Results saved to {output_base_dir}")
    print(f"Total structures recognized: {structure_count}")
    print(f"Results CSV: {results_csv}")
    return results_df

if __name__ == "__main__":
    # Script entry point: run the whole pipeline on the sample PDF using the
    # machine-specific paths below.
    PDF_PATH = "/Users/johnsnow/Downloads/chemstru_recognition/data/sample.pdf"
    results = run_full_pipeline(
        PDF_PATH,
        model_path="chemical_model.pth",
        data_dir="chembl_images",
        chembl_dir="/Users/johnsnow/Downloads/chemstru_recognition/chembl_35/chembl_35_sqlite",
    )


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

Step 1/5: Generating chemical structure images...
Using existing chemical structure images

Step 2/5: Preparing dataset...
Found 14 classes
Training samples: 3721
Validation samples: 940

Step 3/5: Training model...
Found 14 classes: ['alcohol', 'aldehyde', 'alkane', 'alkene', 'alkyne', 'amide', 'amine', 'benzene', 'carboxylic_acid', 'ester', 'ether', 'halide', 'ketone', 'nitrile']
Loading existing model from chemical_model.pth

Extracting content from PDF...
Processing 1 pages from /Users/johnsnow/Downloads/chemstru_recognition/data/sample.pdf

Processing extracted content...
Found 8 structures in page_1_full.png
Found 10 structures in page_1_img_0.jpeg

Step 6/6: Making predictions...
Processed 10 structures...
Processed 20 structu