In [37]:
import io
import json
import os
import re
import sys
import subprocess
import tempfile as tmp
from pprint import pprint
from textwrap import dedent

from collections import defaultdict

In [38]:
# Fixture script handed to pyright in the cells below. The literal starts with a newline, so
# line 1 of the written file is blank and `import pandas as pd` is line 2.
test_py_src = dedent("""
    import pandas as pd
    from sklearn.datasets import load_iris
    
    # Load Iris dataset
    iris = load_iris()
    df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
    
    # Calculate mean and standard deviation for sepal length and width
    sepal_length_mean = df['sepal length (cm)'].mean()
    sepal_width_mean = df['sepal width (cm)'].mean()
    sepal_length_std = df['sepal length (cm)'].std()
    sepal_width_std = df['sepal width (cm)'].std()
    
    print(f"Sepal Length - Mean: {sepal_length_mean:.2f}, Std: {sepal_length_std:.2f}")
    print(f"Sepal Width - Mean: {sepal_width_mean:.2f}, Std: {sepal_width_std:.2f}")
    
    my_str = "foo"
    
    my_dict = {}
    my_dict[my_str] = 123

    if True:
        scoped_df = pd.DataFrame()
""")

# Persist the fixture in the working directory; leaving the `with` block closes the file,
# which also flushes everything that was written.
with open("test.py", "w") as test_file:
    test_file.write(test_py_src)

In [39]:
# Settings written to pyrightconfig.json in the working directory; the only override turns the
# reportMissingImports rule off.
pyright_config = {"reportMissingImports": False}

with open("pyrightconfig.json", "w") as config_file:
    config_file.write(json.dumps(pyright_config))

In [40]:
pyright_cmd_base = ['pyright', '--outputjson']

def pyright_run(source: str):
    """
    Runs pyright over a piece of source code and returns its parsed JSON report.

    The source is written to a temporary ".py" file which is passed to pyright by name.
    pyright exits with status 1 whenever it reports at least one error, but it still prints the
    complete JSON report on stdout, so status 1 is handled exactly like status 0. Any other
    status (or output that is not valid JSON) is treated as a failure of the run itself.

    :param source: Python source code to analyse.
    :return: the decoded JSON report as a dict, or None if the run failed, in which case the
        exit status and pyright's captured stdout/stderr are printed.
    """
    with tmp.NamedTemporaryFile(mode='w', suffix=".py") as file:
        file.write(source)
        # pyright opens the file by name, so the content has to be on disk before it runs.
        file.flush()

        # No check=True: a non-zero exit status is expected when the source has type errors.
        result = subprocess.run([*pyright_cmd_base, file.name], capture_output=True)

    if result.returncode in (0, 1):
        try:
            return json.loads(result.stdout)
        except json.JSONDecodeError:
            # Fall through and show whatever pyright printed instead of a report.
            pass

    print(f"pyright exited with status {result.returncode}")
    print(result.stdout.decode(errors="replace"))
    print(result.stderr.decode(errors="replace"))
    return None

In [41]:
# Read the fixture through a context manager so the file handle is closed right away instead
# of being left open until the garbage collector gets to it.
with open("./test.py") as test_file:
    test_source = test_file.read()

# Last expression of the cell, so the parsed pyright report is shown as the cell output.
pyright_run(test_source)

{'version': '1.1.361',
 'time': '1715027177806',
 'generalDiagnostics': [],
 'summary': {'filesAnalyzed': 1,
  'errorCount': 0,
  'informationCount': 0,
  'timeInSec': 0.872}}

In [42]:
def pyright_infer(source: str, targets: list[tuple[str, int | None]]):
    """
    Infers types for expressions. Each target may have a line number specified, in which case the
    type revelation is inserted after that line. If None, it's inserted at the end.

    :param source: original source to infer
    :param targets: list of target expressions to infer, and the line to do so after (or None for EOF).
    """
    str_io = io.StringIO()
    target_exprs, _ = zip(*targets)  # this is python magic to unzip...

    ln_no_to_target_exprs = defaultdict(list)
    for (target_expr, target_line) in targets:
        ln_no_to_target_exprs[target_line].append(target_expr)


    target_name_to_expr = {}
    next_id = 0
        
    # Iterate through the source line by line, inserting line-specific targets.
    for ln_no, source_line in enumerate(source.splitlines()):
        ln_no = ln_no + 1  # index from 1, not 0!

        # print(f"{source_line}\n")
        
        str_io.write(f"{source_line}\n")

        # Gotcha: make sure to use the _same indentation as the line!_
        indent_level = len(source_line) - len(source_line.lstrip())
        indent = indent_level * " "

        for target_expr in ln_no_to_target_exprs[ln_no]:
            target_name = f"__infer_target_{next_id}__"
            next_id = next_id + 1

            target_name_to_expr[target_name] = target_expr

            # print(f"\n{indent}{target_name} = {target_expr}")
            # print(f"\n{indent}reveal_type({target_name})")

            str_io.write(f"\n{indent}{target_name} = {target_expr}")
            str_io.write(f"\n{indent}reveal_type({target_name})")
    str_io.write("\n")

    # Insert the end-of-file targets.
    for target_expr in ln_no_to_target_exprs[None]:
        target_name = f"__infer_target_{next_id}__"
        next_id = next_id + 1

        target_name_to_expr[target_name] = target_expr
        
        str_io.write(f"\n{target_name} = {target_expr}")
        str_io.write(f"\nreveal_type({target_name})")
    str_io.write("\n")

    # Rewind and feed the augmented script into pyright.
    str_io.seek(0)
    
    result = pyright_run(str_io.read())

    # Pull out what we want to know from the 'information's 
    pattern = r'Type of "(?P<target>\w+)" is "(?P<type>.+?)"'
    regex = re.compile(pattern)
    
    informations = {}
    inferred_types = dict.fromkeys(target_exprs)
    
    for diagnostic in result['generalDiagnostics']:
        # print(diagnostic)
        if diagnostic['severity'] == 'information':
            message = diagnostic['message']
            if (match := regex.fullmatch(message)):
                target_name = match.group("target")
                inferred_type = match.group("type")

                inferred_types[target_name_to_expr[target_name]] = inferred_type
            else:
                continue

    return inferred_types

In [43]:
# Read the fixture through a context manager so the file handle is closed right away instead
# of being left open until the garbage collector gets to it.
with open("./test.py") as test_file:
    test_source = test_file.read()

pyright_infer(test_source, [
    ("df", None),
    ("df['foo']", None),
    ('df[df["bar"] > 0]', None),
    ("my_dict", None), 
    ("my_dict['foo']", None),
    # Line 24 of test.py is the `scoped_df = pd.DataFrame()` assignment nested under `if True:`,
    # so this reveal is inserted inside that block rather than at the end of the file.
    ("scoped_df", 24)
])

{'df': 'DataFrame',
 "df['foo']": 'Series[Unknown]',
 'df[df["bar"] > 0]': 'DataFrame',
 'my_dict': 'dict[Unknown, Unknown]',
 "my_dict['foo']": 'Unknown',
 'scoped_df': 'DataFrame'}