In [5]:
import ast
import importlib.metadata as md
import nbformat
import os
import sys

# --- CONFIGURATION ---
TARGET_ROOT = "./parametric_qst"
ENV_NAME = "my_project_env"
OUTPUT_YAML = "environment.yaml"
OUTPUT_TXT = "requirements.txt"

# Directories to strictly ignore during the recursive scan
IGNORE_DIRS = {
    ".git", ".idea", ".vscode", "__pycache__",
    ".ipynb_checkpoints", "venv", ".venv", "env",
    "node_modules", ".mypy_cache", "build", "dist",
}

# Specific files to ignore
IGNORE_FILES = {
    "setup.py",
    ".DS_Store",              # <-- file, not dir
    OUTPUT_YAML,
    OUTPUT_TXT,
    # strongly recommended: ignore the generator itself
    "generate_environment_yaml.py",
    "generate_environment_yaml.ipynb",
}

print(f"Scanning target: '{os.path.abspath(TARGET_ROOT)}'")

Scanning target: '/Users/Tonni/Desktop/master-code/neural-quantum-tomo/parametric_qst'


In [6]:
def print_tree(startpath: str) -> None:
    print(f"üìÅ Project Root: {os.path.basename(os.path.abspath(startpath))}/")

    try:
        with os.scandir(startpath) as entries:
            sorted_entries = sorted(entries, key=lambda e: (not e.is_dir(), e.name.lower()))

            for entry in sorted_entries:
                ignored = (entry.name in IGNORE_DIRS) or (entry.name in IGNORE_FILES)
                status = "üö´ (Ignored)" if ignored else "‚úÖ"

                if entry.is_dir():
                    print(f"    ‚îú‚îÄ‚îÄ üìÇ {entry.name}/  {status}")
                else:
                    print(f"    ‚îú‚îÄ‚îÄ üìÑ {entry.name}  {status}")
    except FileNotFoundError:
        print(f"‚ùå Error: Directory '{startpath}' not found.")

print_tree(TARGET_ROOT)

üìÅ Project Root: parametric_qst/
    ‚îú‚îÄ‚îÄ üìÇ __pycache__/  üö´ (Ignored)
    ‚îú‚îÄ‚îÄ üìÇ data_handling/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ hyper_rbm/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_16_err/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_16_final/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_3x3_error/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_3x3_final/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_4x4/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_4x4_error/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_4x4_final/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ tfim_4x4_final_again/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÇ visualization_err/  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÑ .DS_Store  üö´ (Ignored)
    ‚îú‚îÄ‚îÄ üìÑ tfim_16_mod.zip  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÑ tfim_3x3.zip  ‚úÖ
    ‚îú‚îÄ‚îÄ üìÑ wavefunction_overlap.py  ‚úÖ


In [7]:
def extract_imports_from_code(code: str) -> set[str]:
    """Return imported top-level module names. Skips relative imports."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return set()

    imports: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports.add(alias.name.split(".")[0])

        elif isinstance(node, ast.ImportFrom):
            # Skip relative imports like: from .foo import bar
            if getattr(node, "level", 0) and node.level > 0:
                continue
            if node.module:
                imports.add(node.module.split(".")[0])

    return imports


def get_imports_from_file(filepath: str) -> set[str]:
    if filepath.endswith(".py"):
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                return extract_imports_from_code(f.read())
        except Exception as e:
            print(f"‚ö†Ô∏è  Error reading {filepath}: {e}")
            return set()

    if filepath.endswith(".ipynb"):
        try:
            nb = nbformat.read(filepath, as_version=4)
            out: set[str] = set()
            for cell in nb.cells:
                if cell.cell_type == "code":
                    out |= extract_imports_from_code(cell.source)
            return out
        except Exception as e:
            print(f"‚ö†Ô∏è  Error reading notebook {filepath}: {e}")
            return set()

    return set()


def collect_local_names(target_root: str, ignore_dirs: set[str], ignore_files: set[str]) -> set[str]:
    """
    Collect local modules AND packages:
    - module: foo.py -> "foo"
    - package: foo/__init__.py -> "foo"
    """
    local: set[str] = set()

    for root, dirs, files in os.walk(target_root):
        # prune ignored dirs
        dirs[:] = [d for d in dirs if (d not in ignore_dirs and not d.startswith("."))]

        # package dirs (presence of __init__.py)
        if "__init__.py" in files:
            pkg_name = os.path.basename(root)
            if pkg_name not in ignore_dirs:
                local.add(pkg_name)

        for fname in files:
            if fname in ignore_files:
                continue
            if fname.endswith(".py"):
                local.add(os.path.splitext(fname)[0])

    return local


# --- EXECUTION ---
all_raw_imports: set[str] = set()

print(f"üîé Starting recursive scan of '{TARGET_ROOT}'...")

local_names = collect_local_names(TARGET_ROOT, IGNORE_DIRS, IGNORE_FILES)

for root, dirs, files in os.walk(TARGET_ROOT):
    dirs[:] = [d for d in dirs if (d not in IGNORE_DIRS and not d.startswith("."))]

    for fname in files:
        if fname in IGNORE_FILES:
            continue
        if fname.startswith("."):
            continue

        full_path = os.path.join(root, fname)

        if fname.endswith((".py", ".ipynb")):
            all_raw_imports |= get_imports_from_file(full_path)

print("\n‚úÖ Scan Complete.")
print(f"   Found {len(all_raw_imports)} unique raw imports.")
print(f"   Detected {len(local_names)} local modules/packages (excluded from reqs).")

print("\nLocal examples:", ", ".join(sorted(local_names)[:15]) if local_names else "(none)")

üîé Starting recursive scan of './parametric_qst'...

‚úÖ Scan Complete.
   Found 27 unique raw imports.
   Detected 17 local modules/packages (excluded from reqs).

Local examples: __init__, data_gen, data_handling, dataloader, hyper_rbm, io, io_npz, io_txt, measurement, single_point_rbm, single_point_rbm_v2, single_point_rbm_v3, symmetric_hyper_rbm, test_file, training


In [8]:
# --- STDLIB FILTERING ---
STDLIB_LIKE = {
    "__future__", "ast", "importlib", "os", "pathlib", "sys", "time", "typing",
    "builtins", "types", "json", "math", "re", "subprocess", "datetime",
    "shutil", "random", "collections", "itertools", "functools", "pickle",
    "logging", "platform", "io", "contextlib", "copy", "csv",
    "dataclasses",
}

official_stdlib = set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()

resolved_reqs: list[str] = []
missing_packages: list[str] = []
skipped_local: list[str] = []
skipped_stdlib: list[str] = []

print("üßê Analyzing imports...")

for pkg in sorted(all_raw_imports):
    if pkg.startswith("_"):
        continue

    # local modules/packages
    if pkg in local_names:
        skipped_local.append(pkg)
        continue

    # stdlib
    if pkg in official_stdlib or pkg in STDLIB_LIKE:
        skipped_stdlib.append(pkg)
        continue

    # installed pip distributions
    try:
        version = md.version(pkg)
        resolved_reqs.append(f"{pkg}=={version}")
    except md.PackageNotFoundError:
        missing_packages.append(pkg)

print("\n--- RESULTS ---")
print(f"üìö Standard Lib Skipped: {len(skipped_stdlib)} (e.g. {', '.join(sorted(skipped_stdlib)[:8])})")
print(f"üè† Local Skipped:        {len(skipped_local)} (e.g. {', '.join(sorted(skipped_local)[:8])})")
print(f"‚ùå Not Installed/Found:  {len(missing_packages)} (CHECK: {', '.join(sorted(missing_packages))})")
print(f"‚úÖ Ready to write:       {len(resolved_reqs)} packages")

# --- PREVIEW WHAT WILL BE WRITTEN ---
py_ver = f"{sys.version_info.major}.{sys.version_info.minor}"

yaml_lines = [
    f"name: {ENV_NAME}",
    "dependencies:",
    f"  - python={py_ver}",
    "  - pip",
    "  - pip:",
]
yaml_lines += [f"      - {r}" for r in resolved_reqs]

yaml_preview = "\n".join(yaml_lines) + "\n"
reqs_preview = "\n".join(resolved_reqs) + ("\n" if resolved_reqs else "")

print("\n--- PREVIEW environment.yaml ---\n")
print(yaml_preview)

print("\n--- PREVIEW requirements.txt ---\n")
print(reqs_preview if reqs_preview.strip() else "(empty)")

üßê Analyzing imports...

--- RESULTS ---
üìö Standard Lib Skipped: 12 (e.g. dataclasses, datetime, itertools, json, math, os, pathlib, random)
üè† Local Skipped:        3 (e.g. data_handling, hyper_rbm, wavefunction_overlap)
‚ùå Not Installed/Found:  2 (CHECK: argparse, glob)
‚úÖ Ready to write:       9 packages

--- PREVIEW environment.yaml ---

name: my_project_env
dependencies:
  - python=3.9
  - pip
  - pip:
      - IPython==8.11.0
      - joblib==1.4.2
      - matplotlib==3.7.1
      - netket==3.13.0
      - numpy==1.24.2
      - pandas==2.2.0
      - scipy==1.13.1
      - torch==1.13.1
      - tqdm==4.65.0


--- PREVIEW requirements.txt ---

IPython==8.11.0
joblib==1.4.2
matplotlib==3.7.1
netket==3.13.0
numpy==1.24.2
pandas==2.2.0
scipy==1.13.1
torch==1.13.1
tqdm==4.65.0

