In [1]:
import os 
import json 

In [31]:
class DirectoryNodeInitError(Exception): 
    code = 400 
    description = "Failure to initialize direcotry node"

class FileNodeInitError(Exception): 
    code = 400 
    description = "Failure to initialize file node"

class DirectoryGenerationError(Exception): 
    code = 400 
    description = "Error when generation directory"

In [149]:
import logging 

class FileNode: 
    def __init__(self, path: str):
        if (not os.path.isfile(path)): 
            raise FileExistsError
        self.type: str = "file" 
        self.path: str = path 
        self.name: str = os.path.basename(path) 
        self.size: str = "" #TODO: use module to get 
        self.file_content_modified: str = ""
        self.past_file_content: list[str] = [] #tracks past version of unit test in case we want to revert (maybe use dict bc i think we should only use 5 so we don't have ridiculous memory)

        try: 
            with open(path, "r") as f: 
                self.file_content_original: str = f.read() 
        except Exception as e: 
            logging.error("Error when reading file: ",e) 
            self.file_content_original = ""

    def unit_test(self): 
        """ Run unit test in the file. Return the unit tested code"""
        pass 

    def refactor(self): 
        """Refactor the code in this file. Return the refactored code"""


class DirectoryNode: 
    def __init__(self, path: str):
        if (not os.path.isdir(path)): 
            raise DirectoryNodeInitError
        self.type: str = "dir"
        self.path: str = path 
        self.name: str = os.path.basename(path)
        self.directory_content: list[FileNode | DirectoryNode] = []

class DirectoryTree: 
    """
    Representing a folder using a tree structure 
    """
    def __init__(self, 
                 root_path: str): 
        self.root = DirectoryNode(root_path)
        self.root = self.recursive_generate(self.root) 

        #TODO: maybe add a cache here for fast retrieval for projects with tons of files 

    def recursive_generate(self, dir_node: DirectoryNode): 
        """ 
        Recursively populate the directory node

        If file append 

        If directory recurse and populate 
        """ 
        if (not os.path.isdir(dir_node.path)): 
            raise DirectoryGenerationError

        for node in list(os.listdir(dir_node.path)): 
            path_to_node: str = os.path.join(dir_node.path, node)

            if os.path.isfile(path_to_node): 
                dir_node.directory_content.append(FileNode(path_to_node))
                 
            if os.path.isdir(path_to_node):
                directory_node = DirectoryNode(path_to_node)
                directory_content = self.recursive_generate(directory_node)
                dir_node.directory_content.append(directory_content)

        return dir_node
        
class DirectoryParser: 
    def __init__(self, root_path: str): 
        self.directory_tree = DirectoryTree(root_path)
        #TODO: get size of directory tree and save that

    def _crawl_directory(self, dir_node: DirectoryNode, dir_level: int = 1): 
        """crawl through all file and print out file or dir name"""
        if (not os.path.isdir(dir_node.path)): 
            raise DirectoryGenerationError

        for node in dir_node.directory_content: 
            path_to_node: str = os.path.join(dir_node.path, node.name)
            
            if os.path.isdir(path_to_node):
                print(f" |{'-' * dir_level} ", node.name)
                dir_level += 1
                self._crawl_directory(node, dir_level)
                dir_level -= 1

            if os.path.isfile(path_to_node): 
                print(" \t - ", node.name)
                 
    def display_tree(self): 
        """scan through entire tree and display a structure to view"""
        
        print("Starting at directory: ", self.directory_tree.root.name)
        self._crawl_directory(self.directory_tree.root) 
    
    def retrieve_file(self, dir_node, file_name): 
        """Scan through the entire tree and if filename match return node of file otherwise return None
        
        TODO: Maybe can optimize with a faster tree ??
        """

        for node in dir_node.directory_content: 
            path_to_node: str = os.path.join(dir_node.path, node.name)
            
            if os.path.isdir(path_to_node):
                n = self.retrieve_file(node, file_name=file_name)
                if n: 
                    return n
            if os.path.isfile(path_to_node) and os.path.basename(path_to_node) == file_name: 
                return node 
            
        return None 

In [150]:
root = os.path.join(os.getcwd(), "src")

In [151]:
directory_parser = DirectoryParser(root)

In [152]:
directory_parser.display_tree()

Starting at directory:  src
 |-  routers
 	 -  router.py
 	 -  config.py
 |-  utils
 |-  schema
 	 -  serialize.py
 	 -  validate_schema.py
 	 -  output_schema.py
 	 -  input_schema.py
 	 -  app.py
 	 -  errors.py
 	 -  main.py
 |-  gpt_model
 	 -  generate_authorization_token.py
 	 -  call_retry.py
 	 -  response_handler.py
 	 -  gpt_extraction.py
 	 -  prompts.py
 	 -  gpt_confidence_score.py
 	 -  gpt.py


In [154]:
f = directory_parser.retrieve_file(directory_parser.directory_tree.root, "config.py")

In [155]:
f

<__main__.FileNode at 0x10e8b85d0>

In [158]:
f.__dict__

{'type': 'file',
 'path': '/Users/zhidongjiang/Desktop/unit-test-refactoring/src/config.py',
 'name': 'config.py',
 'size': '',
 'file_content_modified': '',
 'past_file_content': [],
 'file_content_original': 'import os \nimport logging \nfrom dotenv import load_dotenv, find_dotenv \n\nfrom pydantic_settings import baseSettings \n\nload_dotenv(override=True) \n\nROOT_PATH = os.getenv("ROOT_PATH", None) \n\nclass Settings(BaseSettings): \n    ENV: str = "sit"\n    CONNECTING_STRING: str = "" \n    logger: logging.Logger = logger \n\n    EMBEDDINGS_API_KEY : str = "" \n    OPENAI_API_KEY: str = "" \n    OPENAI_API_BASE: str = "" \n    OPENAI_API_TYPE: str = "" \n    OPENAI_API_ENDPOINT: str = "" \n    LLM_MODEL_NAME: str = "" \n    GPT_USERNAME: str = "" \n\n    # ENV: os.getenv("ENV", "dev")\n    # logger = logging.getLogger(__name__)\n    # max_yt_api_calls = 100\n    # model: str="gpt-4o-mini" \n    # youtube_developer_key: str=os.environ(youtube_developer_key, None) \n\nsetting = Se