# Fine-tuning Language Models to Predict Stock Performance
This notebook is a relatively basic attempt at using language models to predict stock performance from historical stock performance and macroeconomic data.

In [None]:
# Install a library to convert between SEC CIK identifiers and ticker symbols, plus the finagg library, which aggregates useful, freely accessible financial data.
%pip install --upgrade sec-cik-mapper finagg requests python-dotenv

In [None]:
import os
from os import environ, getcwd
from dotenv import load_dotenv

# Point finagg's root path at the notebook's working directory.
environ['FINAGG_ROOT_PATH'] = getcwd()

# Prior to running this cell you'll need a .env file in the working directory with some API keys:
#
#   BEA_API_KEY=...   # get from https://apps.bea.gov/api/signup/
#   FRED_API_KEY=...  # get from https://fred.stlouisfed.org/docs/api/api_key.html
#
# A file named "env" (no leading dot) is still accepted as a fallback for existing setups.
dotenv_path = os.path.join(getcwd(), '.env')
legacy_dotenv_path = os.path.join(getcwd(), 'env')
if not os.path.isfile(dotenv_path) and os.path.isfile(legacy_dotenv_path):
    dotenv_path = legacy_dotenv_path
if not os.path.isfile(dotenv_path):
    print(f'Warning: {dotenv_path} not found; BEA_API_KEY and FRED_API_KEY must already be set in the environment.')
load_dotenv(dotenv_path)

In [None]:
!finagg install -ss economic -ts sec -ts indices --stock-data -z -r -s sec -s yfinance
!finagg fred install --raw series --series observation

In [None]:
import requests, os, csv
# Set up folder to persist financial data to.
fin_data_path = 'findata'
os.makedirs(fin_data_path, exist_ok=True)

russell_path_raw = f'{fin_data_path}/russell-3000.csv'
russell_path = f'{fin_data_path}/russell-3000-clean.csv'

# iShares Russell 3000 ETF (IWV) holdings export, pinned to a snapshot via asOfDate in the URL.
url = 'https://www.ishares.com/us/products/239714/ishares-russell-3000-etf/1467271812596.ajax?fileType=csv&fileName=IWV_holdings&dataType=fund&asOfDate=20240321'
# The timeout keeps a stalled connection from hanging the notebook; raise_for_status stops us from
# saving an HTTP error page to disk and then trying to parse it as the holdings CSV.
response = requests.get(url, timeout=60)
response.raise_for_status()

with open(russell_path_raw, 'wb') as f:
    f.write(response.content)

# newline='' lets the csv module handle line endings itself, as the csv docs require.
with open(russell_path_raw, 'r', encoding='utf-8', newline='') as f:
    rows = list(csv.reader(f))

# The holdings table sits between the first two rows that are empty or contain a bare
# non-breaking-space cell; the rows outside that range are file metadata.
empty_row_indicies = [i for i, row in enumerate(rows) if len(row) == 0 or '\xa0' in row]

print('Empty rows:', empty_row_indicies)
if len(empty_row_indicies) < 2:
    raise ValueError(
        f'Expected at least two separator rows in {russell_path_raw}, found {len(empty_row_indicies)}; '
        'the download format may have changed.'
    )

start = empty_row_indicies[0] + 1
end = empty_row_indicies[1]

# Keep the header row ('Exchange') and NASDAQ/NYSE listings only. Skip rows without a ticker
# symbol ('-') and rows too short to contain the exchange column.
exchange_col = 10
relevant_rows = [
    r for r in rows[start:end]
    if len(r) > exchange_col
    and r[0] != '-'
    and r[exchange_col].strip() in ('Exchange', 'NASDAQ', 'New York Stock Exchange Inc.')
]

# Write the cleaned CSV as UTF-8, which is what pd.read_csv assumes by default when reading it back.
with open(russell_path, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerows(relevant_rows)

In [None]:
# Install pandas so we can work with dataframes
%pip install --upgrade pandas pandas-ai


In [None]:
import pandas as pd

# Read the cleaned NASDAQ/NYSE subset of the Russell 3000 holdings and show which columns it has.
holdings = pd.read_csv(filepath_or_buffer=russell_path)
holdings.columns

In [None]:

# Drop columns that are not needed downstream. Reassign instead of using inplace=True, and ignore
# labels that are already gone, so re-running this cell does not raise a KeyError.
holdings = holdings.drop(
    columns=['Market Value', 'Weight (%)', 'Notional Value', 'Price', 'FX Rate', 'Currency', 'Market Currency', 'Accrual Date', 'Asset Class'],
    errors='ignore',
)
holdings

In [None]:
# sec_cik_mapper exposes mappings between tickers, SEC CIK numbers, exchanges and company names;
# they let us find the filings and financial data for the holdings loaded above.
import sec_cik_mapper

stock_mapper = sec_cik_mapper.StockMapper()
ticker_to_cik, cik_to_exchange, cik_to_name = (
    stock_mapper.ticker_to_cik,
    stock_mapper.cik_to_exchange,
    stock_mapper.cik_to_company_name,
)

In [None]:
# Attach each holding's SEC CIK (None when the ticker is unknown to sec_cik_mapper),
# then keep only the holdings that could be mapped.
holdings['cik'] = holdings['Ticker'].map(ticker_to_cik.get)
train_companies = holdings.dropna(subset=['cik'])
train_companies

In [None]:
# Root directory, relative to the working directory, under which generated outputs are organized.
outdir = "./out"

In [None]:
%pip install numpy dataclasses-json

In [None]:
import finagg

# Utility methods to get financial indicators:
quarter_ranges = {
    1: ('01-01', '04-01'),
    2: ('04-01', '07-01'),
    3: ('07-01', '10-01'),
    4: ('10-01', '01-01')
}

def get_fred_indicator(year, q, indicator):
    """
    Fetches the specified FRED indicator for the given year and quarter.
    Caches the results to avoid redundant API calls.
    """
    try:
        with shelve.open("fred-indicators") as indicator_cache:
            year, q = int(year), int(q)
            indicator_cache_key = "|".join([str(year), str(q), indicator])
            if indicator_cache_key in indicator_cache:
                return indicator_cache[indicator_cache_key]
            
            # Determine the end year based on the quarter
            end_year = year if q < 4 else year + 1
            
            # Fetch data from FRED API
            data = finagg.fred.api.series.observations.get(
                indicator,
                observation_start=f'{year}-{quarter_ranges[q][0]}',
                observation_end=f'{end_year}-{quarter_ranges[q][1]}',
            )
            
            # Calculate the average for the quarter
            result = data['value'].aggregate(lambda l: sum(l) / len(l))
            indicator_cache[indicator_cache_key] = result  # Cache the result
            return result
    except KeyError:
        return None

def get_share_price(ticker, start_date, end_date):
    return finagg.yfinance.api.get(ticker, start=str(start_date), end=str(end_date))



In [None]:
%pip install --upgrade retry timeout_decorator # Some of what we're about to do is unreliable. We install libraries to help us make it reliable

In [None]:
%pip install --upgrade --quiet google-genai

In [None]:
from IPython.display import HTML, Markdown, display
from google import genai
from google.genai.types import (
    FunctionDeclaration,
    GenerateContentConfig,
    GoogleSearch,
    Part,
    Retrieval,
    SafetySetting,
    Tool,
    VertexAISearch,
)
from abc import ABC, abstractmethod
from typing import List, Optional
import json
from dataclasses import dataclass

class Prompt(ABC):
    """Abstract base class for prompts sent to a language model."""

    @abstractmethod
    def text(self) -> str:
        """Return the complete prompt text to send to the model."""
        raise NotImplementedError()

@dataclass
class GeminiPrompt(Prompt):
    """
    A structured prompt made of a task plus optional bullet-point sections.

    Attributes are rendered by ``text()``; list-valued sections that are None or empty
    are omitted from the prompt.
    """

    task: str  # one-paragraph statement of what the model must do
    context_information: Optional[List[str]] = None  # facts/content the task refers to
    instructions: Optional[List[str]] = None  # rules the model should follow
    response_format_instructions: Optional[List[str]] = None  # output style/format requirements

    def text(self) -> str:
        """
        Render the prompt as labelled sections.

        The layout follows AWS's Nova prompting guidance
        (https://docs.aws.amazon.com/nova/latest/userguide/prompting-precision.html):

            Task:
            <task>

            Context information:
            - <context item>
            ...

            Model Instructions:
            - <instruction>
            ...

            Response style and format requirements:
            - <format requirement>
            ...

        Every emitted section is followed by a blank line.

        Returns:
            str: The rendered prompt.
        """
        prompt_text = f"Task:\n{self.task}\n\n"
        optional_sections = (
            ("Context information", self.context_information),
            ("Model Instructions", self.instructions),
            ("Response style and format requirements", self.response_format_instructions),
        )
        for heading, items in optional_sections:
            if items:
                bullets = "\n".join(f"- {item}" for item in items)
                prompt_text += f"{heading}:\n{bullets}\n\n"
        return prompt_text


In [None]:
# Vertex AI project from the environment. None (rather than the string "None") when it is unset,
# leaving project resolution to the genai client.
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
LOCATION = "us-central1"
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
MODEL_ID = "gemini-2.0-flash-exp"

def get_gemini_response(prompt: Prompt,  output_format="json", max_tokens=1000, model_id=MODEL_ID):
    """
    Send ``prompt`` to Gemini and return the answer with its markdown code fence removed.

    Args:
        prompt: The prompt; its ``text()`` is sent as the request contents.
        output_format: Language tag of the code block the prompt asks for (e.g. "json",
            "markdown"). When it is "json", the answer is parsed with ``json.loads``.
        max_tokens: Maximum number of output tokens.
        model_id: Gemini model to query.

    Returns:
        The parsed JSON value when ``output_format == "json"``, otherwise the answer text.

    Raises:
        Re-raises any error from reading or parsing the response, after printing the
        response and the prompt text.
    """
    prompt_text = prompt.text()
    response = client.models.generate_content(
        model=model_id,
        contents=prompt_text,
        config=GenerateContentConfig(
            # Stop at the closing fence of the code block the prompt asks for.
            stop_sequences=["\n```"],
            max_output_tokens=max_tokens
        )
    )
    try:
        text = response.text.replace(f"```{output_format}","").replace("```","")
        if output_format == "json":
            return json.loads(text)
        return text
    except Exception:
        print(f"Failed to parse response: {response}.\n Query was: {prompt_text}")
        raise
        



In [None]:
# Candidate us-gaap keys for a company's report. The example prompt below offers them to Gemini
# and asks it to pick the 20 keys that best let a financial analyst quickly judge the company's performance.

keys = ['AccountsPayableAndAccruedLiabilitiesCurrent', 'AccountsReceivableNetCurrent', 'AccruedIncomeTaxesCurrent', 'AccumulatedDepreciationDepletionAndAmortizationPropertyPlantAndEquipment', 'AccumulatedOtherComprehensiveIncomeLossNetOfTax', 'AdditionalPaidInCapitalCommonStock', 'AllowanceForDoubtfulAccountsReceivableCurrent', 'AssetImpairmentCharges', 'Assets', 'AssetsCurrent', 'AssetsOfDisposalGroupIncludingDiscontinuedOperation', 'AvailableForSaleDebtSecuritiesGrossUnrealizedLoss', 'AvailableForSaleSecurities', 'AvailableForSaleSecuritiesAmortizedCost', 'AvailableForSaleSecuritiesDebtMaturitiesAfterFiveThroughTenYearsAmortizedCost', 'AvailableForSaleSecuritiesDebtMaturitiesAfterFiveThroughTenYearsFairValue', 'AvailableForSaleSecuritiesDebtMaturitiesAfterOneThroughFiveYearsAmortizedCost', 'AvailableForSaleSecuritiesDebtMaturitiesAfterOneThroughFiveYearsFairValue', 'AvailableForSaleSecuritiesDebtMaturitiesAfterTenYearsAmortizedCost', 'AvailableForSaleSecuritiesDebtMaturitiesAfterTenYearsFairValue', 'AvailableForSaleSecuritiesDebtMaturitiesWithinOneYearAmortizedCost', 'AvailableForSaleSecuritiesDebtMaturitiesWithinOneYearFairValue', 'AvailableForSaleSecuritiesGrossRealizedGains', 'AvailableForSaleSecuritiesGrossRealizedLosses', 'AvailableForSaleSecuritiesGrossUnrealizedGains', 'AvailableForSaleSecuritiesGrossUnrealizedLoss', 'BusinessCombinationStepAcquisitionEquityInterestInAcquireeRemeasurementLoss', 'CashAndCashEquivalentsAtCarryingValue', 'CashAndCashEquivalentsPeriodIncreaseDecrease', 'CashCashEquivalentsAndShortTermInvestments', 'CashFlowHedgeGainLossToBeReclassifiedWithinTwelveMonths', 'CommonStockDividendsPerShareDeclared', 'CommonStockParOrStatedValuePerShare', 'CommonStockSharesAuthorized', 'CommonStockSharesIssued', 'CommonStockValue', 'ComprehensiveIncomeNetOfTax', 'ComprehensiveIncomeNetOfTaxAttributableToNoncontrollingInterest', 'ComprehensiveIncomeNetOfTaxIncludingPortionAttributableToNoncontrollingInterest', 'CostMethodInvestments', 
'CostOfGoodsSold', 'DeferredIncomeTaxExpenseBenefit', 'DeferredTaxLiabilitiesNoncurrent', 'DefinedBenefitPlanContributionsByEmployer', 'DefinedBenefitPlansEstimatedFutureEmployerContributionsInCurrentFiscalYear', 'DepreciationDepletionAndAmortization', 'DerivativeCollateralObligationToReturnCash', 'DisposalGroupNotDiscontinuedOperationGainLossOnDisposal', 'DividendsCommonStockCash', 'EarningsPerShareBasic', 'EarningsPerShareDiluted', 'EffectiveIncomeTaxRateContinuingOperations', 'EffectiveIncomeTaxRateReconciliationAtFederalStatutoryIncomeTaxRate', 'EffectOfExchangeRateOnCashAndCashEquivalents', 'EquityMethodInvestmentRealizedGainLossOnDisposal', 'EquityMethodInvestments', 'ExtinguishmentOfDebtAmount', 'ForeignCurrencyTransactionGainBeforeTax', 'ForeignCurrencyTransactionGainLossUnrealized', 'GainLossOnSaleOfOtherAssets', 'GainsLossesOnExtinguishmentOfDebt', 'Goodwill', 'GrossProfit', 'ImpairmentOfIntangibleAssetsExcludingGoodwill', 'ImpairmentOfIntangibleAssetsIndefinitelivedExcludingGoodwill', 'IncomeLossFromContinuingOperationsBeforeIncomeTaxesExtraordinaryItemsNoncontrollingInterest', 'IncomeLossFromEquityMethodInvestments', 'IncomeLossFromEquityMethodInvestmentsNetOfDividendsOrDistributions', 'IncomeTaxExpenseBenefit', 'IncreaseDecreaseInOperatingCapital', 'IndefiniteLivedFranchiseRights', 'IndefiniteLivedIntangibleAssetsExcludingGoodwillFairValueDisclosure', 'IndefiniteLivedTrademarks', 'InterestExpense', 'InventoryFinishedGoodsNetOfReserves', 'InventoryNet', 'InventoryRawMaterialsAndSuppliesNetOfReserves', 'InvestmentIncomeInterest', 'LiabilitiesAndStockholdersEquity', 'LiabilitiesCurrent', 'LiabilitiesOfDisposalGroupIncludingDiscontinuedOperation', 'LongTermDebt', 'LongTermDebtCurrent', 'LongTermDebtFairValue', 'LongTermDebtNoncurrent', 'MarketableSecuritiesCurrent', 'MinorityInterest', 'MinorityInterestDecreaseFromDistributionsToNoncontrollingInterestHolders', 'NetCashProvidedByUsedInFinancingActivities', 'NetCashProvidedByUsedInInvestingActivities', 
'NetCashProvidedByUsedInOperatingActivities', 'NetIncomeLoss', 'NetIncomeLossAttributableToNoncontrollingInterest', 'NoncontrollingInterestDecreaseFromDeconsolidation', 'OperatingIncomeLoss', 'OtherAssetsNoncurrent', 'OtherComprehensiveIncomeAvailableforsaleSecuritiesAdjustmentBeforeTaxPortionAttributableToParent', 'OtherComprehensiveIncomeAvailableforsaleSecuritiesAdjustmentNetOfTaxPortionAttributableToParent', 'OtherComprehensiveIncomeAvailableforsaleSecuritiesTaxPortionAttributableToParent', 'OtherComprehensiveIncomeDefinedBenefitPlansAdjustmentNetOfTaxPortionAttributableToParent', 'OtherComprehensiveIncomeDerivativesQualifyingAsHedgesNetOfTaxPortionAttributableToParent', 'OtherComprehensiveIncomeForeignCurrencyTransactionAndTranslationAdjustmentNetOfTaxPortionAttributableToParent', 'OtherComprehensiveIncomeForeignCurrencyTransactionAndTranslationGainLossArisingDuringPeriodNetOfTax', 'OtherComprehensiveIncomeForeignCurrencyTransactionAndTranslationGainLossBeforeReclassificationAndTax', 'OtherComprehensiveIncomeForeignCurrencyTranslationGainLossArisingDuringPeriodTax', 'OtherComprehensiveIncomeLossAmortizationAdjustmentFromAOCIPensionAndOtherPostretirementBenefitPlansForNetPriorServiceCostCreditNetOfTax', 'OtherComprehensiveIncomeLossBeforeTax', 'OtherComprehensiveIncomeLossDerivativesQualifyingAsHedgesBeforeTax', 'OtherComprehensiveIncomeLossDerivativesQualifyingAsHedgesNetOfTax', 'OtherComprehensiveIncomeLossDerivativesQualifyingAsHedgesTax', 'OtherComprehensiveIncomeLossForeignCurrencyTransactionAndTranslationAdjustmentBeforeTax', 'OtherComprehensiveIncomeLossForeignCurrencyTransactionAndTranslationAdjustmentNetOfTax', 'OtherComprehensiveIncomeLossForeignCurrencyTransactionAndTranslationReclassificationAdjustmentFromAOCIRealizedUponSaleOrLiquidationBeforeTax', 'OtherComprehensiveIncomeLossForeignCurrencyTransactionAndTranslationReclassificationAdjustmentFromAOCIRealizedUponSaleOrLiquidationNetOfTax', 
'OtherComprehensiveIncomeLossForeignCurrencyTransactionAndTranslationReclassificationAdjustmentFromAOCIRealizedUponSaleOrLiquidationTax', 'OtherComprehensiveIncomeLossForeignCurrencyTranslationAdjustmentTax', 'OtherComprehensiveIncomeLossNetOfTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansAdjustmentBeforeReclassificationAdjustmentsAndTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansAdjustmentBeforeReclassificationAdjustmentsNetOfTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansAdjustmentBeforeTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansAdjustmentNetOfTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansBeforeReclassificationAdjustmentsTax', 'OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIForSaleOfSecuritiesBeforeTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIForSaleOfSecuritiesNetOfTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIForSaleOfSecuritiesTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIOnDerivativesBeforeTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIOnDerivativesNetOfTax', 'OtherComprehensiveIncomeLossReclassificationAdjustmentFromAOCIOnDerivativesTax', 'OtherComprehensiveIncomeLossTax', 'OtherComprehensiveIncomeUnrealizedGainLossOnDerivativesArisingDuringPeriodBeforeTax', 'OtherComprehensiveIncomeUnrealizedGainLossOnDerivativesArisingDuringPeriodNetOfTax', 'OtherComprehensiveIncomeUnrealizedGainLossOnDerivativesArisingDuringPeriodTax', 'OtherComprehensiveIncomeUnrealizedHoldingGainLossOnSecuritiesArisingDuringPeriodBeforeTax', 'OtherComprehensiveIncomeUnrealizedHoldingGainLossOnSecuritiesArisingDuringPeriodNetOfTax', 'OtherComprehensiveIncomeUnrealizedHoldingGainLossOnSecuritiesArisingDuringPeriodTax', 'OtherInventoryNetOfReserves', 
'OtherLiabilitiesNoncurrent', 'OtherNoncashExpense', 'OtherNoncashIncomeExpense', 'OtherNonoperatingIncomeExpense', 'OtherShortTermInvestments', 'PaymentsForProceedsFromOtherInvestingActivities', 'PaymentsForRepurchaseOfCommonStock', 'PaymentsOfDividends', 'PaymentsToAcquirePropertyPlantAndEquipment', 'PrepaidExpenseAndOtherAssetsCurrent', 'ProceedsFromIssuanceOfCommonStock', 'ProceedsFromIssuanceOfDebt', 'ProceedsFromPaymentsForOtherFinancingActivities', 'ProceedsFromSaleOfAvailableForSaleSecurities', 'ProceedsFromSaleOfPropertyPlantAndEquipment', 'ProfitLoss', 'PropertyPlantAndEquipmentNet', 'RepaymentsOfDebt', 'RepaymentsOfLongTermDebt', 'RetainedEarningsAccumulatedDeficit', 'SalesRevenueGoodsNet', 'SellingGeneralAndAdministrativeExpense', 'ShareBasedCompensation', 'StockholdersEquity', 'StockholdersEquityIncludingPortionAttributableToNoncontrollingInterest', 'StockIssuedDuringPeriodValueShareBasedCompensation', 'TreasuryStockShares', 'TreasuryStockValue', 'TreasuryStockValueAcquiredCostMethod', 'WeightedAverageNumberDilutedSharesOutstandingAdjustment', 'WeightedAverageNumberOfDilutedSharesOutstanding', 'WeightedAverageNumberOfSharesOutstandingBasic']
# Render the candidate keys as a bulleted list; every key, including the first, gets a "* " bullet.
key_string = '\n* ' + '\n* '.join(keys)
context_information = [f"In this companies us-gaap report, the following keys are available: {key_string}"]
prompt = GeminiPrompt(
    task="Select 20 keys from the us-gaap keys provided in the context to enable a financial analyst to quickly make a snap-judgment of the company's performance.",
    context_information=context_information,
    instructions=["Be thorough in your analysis and ensure the keys you select are categorical, distinct, non-overlapping and represent key financial indicators like profits, revenues, assets, liabilities, and investments in company growth. Ensure the keys summarize the company's financial health effectively and are present in the context information."],
    # The example must itself be a JSON list of strings, since the answer is parsed with json.loads.
    response_format_instructions=["Respond with only a markdown code block containing a json list of strings representing the selected keys: eg. ```json\n[\"Key1\", \"Key2\", ..., \"Key20\"]\n```", "Verify the presence of each of the keys in the provided context before responding", "Preserve the original PascalCase of the keys eg. 'AssetsCurrent' rather than 'Assets Current' and 'CostOfGoodsSold' instead of 'Cost Of Goods Sold'", "Order the keys in the order of importance"]
)

response = get_gemini_response(prompt, output_format="json")

print(response)
# Report any key the model invented or mangled. A plain `if` is used because assert statements
# are stripped when Python runs with -O.
for key in response:
    if key not in keys:
        print(f"Key {key} not found in context information")
    

In [None]:
from hashlib import md5
from retry import retry
from timeout_decorator import timeout

@retry(tries=4, delay=15)
@timeout(300)
def most_relevant_keys(keys, n, prompt_cache):
    """
    Ask Gemini to pick the ``n`` most informative us-gaap keys out of ``keys``.

    Args:
        keys: Iterable of candidate us-gaap key names.
        n: Number of keys to request.
        prompt_cache: Mutable mapping (e.g. dict or shelve) used to memoize answers.

    Returns:
        The parsed JSON answer, which the prompt asks to be a list of key names ordered by
        importance. It is not validated against ``keys``.
    """
    # The cache key covers both the key set (order-insensitive) and n. Joining with a separator
    # keeps e.g. ['ab', 'c'] and ['a', 'bc'] distinct. Changing this scheme means answers cached
    # under the old key format are re-queried.
    cache_key = md5(('\n'.join(sorted(keys)) + f'\nn={n}').encode('utf-8')).hexdigest()
    if cache_key in prompt_cache:
        return prompt_cache[cache_key]

    # Every key, including the first, is rendered as a "* " bullet.
    key_string = ''.join(f'\n* {key}' for key in keys)
    context_information = [f"In this companies us-gaap report, the following keys are available:{key_string}"]
    prompt = GeminiPrompt(
        task=f"Select {n} keys from the us-gaap keys provided in the context to enable a financial analyst to quickly make a snap-judgment of the company's performance.",
        context_information=context_information,
        instructions=["Be thorough in your analysis", "Ensure the keys you select are categorical, distinct, non-overlapping and represent key financial indicators like profits, revenues, assets, liabilities, and investments in company growth.", "Ensure the keys summarize the company's financial health effectively and are present in the context information."],
        # The example must itself be a JSON list of strings, since the answer is parsed with json.loads.
        response_format_instructions=[f"Respond with only a markdown code block containing a json list of strings representing the selected keys: eg. ```json\n[\"Key1\", \"Key2\", ..., \"Key{n}\"]\n```", "Verify the presence of each of the keys in the provided context before responding", "Preserve the original PascalCase of the keys eg. 'AssetsCurrent' rather than 'Assets Current' and 'CostOfGoodsSold' instead of 'Cost Of Goods Sold'", "Order the keys in the order of importance"]
    )
    result = get_gemini_response(prompt, output_format="json")
    prompt_cache[cache_key] = result
    return result

In [None]:
from dataclasses import dataclass

@dataclass
class PointInTimeValue:
    q: int
    year: int
    value: int | float | str
    unit: str

    def __str__(self):
        return str(self.value)

In [None]:
import re
from bs4 import BeautifulSoup
from dataclasses import dataclass
import pandas as pd
from io import StringIO


# Opening / closing markers that delimit each embedded <DOCUMENT> in a filing's submission text.
doc_start_pattern, doc_end_pattern = map(re.compile, (r'<DOCUMENT>', r'</DOCUMENT>'))
# "<TYPE>" followed by one or more characters up to (not including) the newline, e.g. "<TYPE>10-K".
type_pattern = re.compile(r'<TYPE>[^\n]+')
from bs4 import BeautifulSoup
import pandas as pd
import io
from io import StringIO

@dataclass
class Form10KExtracts:
    """Plain-text sections extracted from one 10-K document by parse_10_k."""

    risk_factors: str  # text from the Item 1A heading up to Item 1B
    md_and_a: str  # text from the Item 7 heading up to Item 7A
    disclosures: str  # text from the Item 7A heading up to Item 8
    hash_digest: str  # identifier passed in by the caller of parse_10_k

@dataclass
class Form10K:
    """A filing's file path together with its year and quarter labels (all stored as strings)."""

    file: str
    year: str
    q: str
    
def clean_table_dataframe(table_html):
    """
    Tidy a table scraped from a filing's HTML into a compact DataFrame.

    Steps:
    1. Drop completely empty rows and columns.
    2. Promote the first remaining row to the header.
    3. Drop rows containing no digit at all (blank rows, pure-text notes/references).
    4. Normalize accounting-style numbers ("1,234" -> 1234.0, "(56)" -> -56.0,
       a lone "$" -> NaN) while keeping text labels such as "Total revenue".
    5. Drop columns left entirely empty.

    Args:
        table_html: The table as HTML (a string or a BeautifulSoup tag), or a DataFrame
            already produced by ``pd.read_html``. A DataFrame argument is copied, never
            modified in place.

    Returns:
        pd.DataFrame: The cleaned table with a fresh RangeIndex.

    Raises:
        ValueError: if ``pd.read_html`` finds no table in the HTML.
    """
    if isinstance(table_html, pd.DataFrame):
        df = table_html.copy()
    else:
        # read_html needs text or a file-like object; literal HTML is wrapped in StringIO
        # (passing literal HTML strings is deprecated since pandas 2.1).
        df = pd.read_html(StringIO(str(table_html)))[0]

    # Drop completely empty rows and columns
    df = df.dropna(how='all').dropna(axis=1, how='all')
    if df.empty:
        return df.reset_index(drop=True)

    # Use the first remaining row as column names. Select by position: after the dropna above,
    # the label 0 may no longer exist, so drop(index=0) could raise or drop the wrong row.
    header = df.iloc[0].fillna('')
    df = df.iloc[1:].copy()
    df.columns = header.tolist()

    # Keep rows that contain at least one digit. Blank or pure-text rows (notes, references) are
    # dropped, while data rows keep their text labels (e.g. "Revenue").
    has_digit = [
        any(re.search(r'\d', str(value)) for value in row)
        for row in df.itertuples(index=False, name=None)
    ]
    df = df.loc[has_digit]

    def clean_numeric(val):
        # Only strings need work; numbers already parsed by read_html (and NaN) pass through.
        if not isinstance(val, str):
            return val
        candidate = (
            val.strip()
            .replace(',', '')
            .replace('$', '')
            .replace('(', '-')
            .replace(')', '')
        )
        if not candidate:
            return float('nan')  # cell held only "$" and/or whitespace
        try:
            return float(candidate)
        except ValueError:
            return val  # a text label, e.g. "Total revenue"

    # DataFrame.map replaced the deprecated DataFrame.applymap in pandas 2.1. Check on the class so a
    # column literally named "map" cannot fool the test.
    if hasattr(pd.DataFrame, 'map'):
        df = df.map(clean_numeric)
    else:
        df = df.applymap(clean_numeric)

    # Drop columns that are now entirely NaN (e.g. columns that only held currency symbols)
    df = df.dropna(how='all', axis=1)

    # Reset the index to have a clean row index after modifications
    return df.reset_index(drop=True)

def get_text_only(html_fragment):
    """
    Extract readable text from an HTML fragment, keeping tables as CSV-like blocks.

    Each <table> is parsed with pandas and cleaned with clean_table_dataframe. If that fails, or
    cleaning leaves nothing, the raw parsed table is used instead. If even that fails, an error
    note is inserted. Every table is wrapped in "[Table]" / "[/Table]" lines.

    Args:
        html_fragment (str): The HTML content to parse.

    Returns:
        str: The extracted text, one stripped, non-empty line per output line.
    """
    soup = BeautifulSoup(html_fragment, "html.parser")

    # Process innermost tables first: replacing an outer table detaches any nested tables, and
    # replace_with() on a detached element raises.
    for table in reversed(soup.find_all("table")):
        # read_html needs text, not a bs4 Tag.
        table_html = str(table)
        try:
            cleaned_df = clean_table_dataframe(pd.read_html(StringIO(table_html))[0])
            if cleaned_df.empty:
                raise ValueError("table is empty after cleaning")
            table_text = cleaned_df.to_csv(index=False, date_format='%Y-%m-%d')
        except Exception:
            # Fall back to the raw (unprocessed) table as CSV.
            try:
                raw_df = pd.read_html(StringIO(table_html), header=None)[0]
                table_text = raw_df.to_csv(index=False, header=False, date_format='%Y-%m-%d')
            except Exception as inner_e:
                table_text = f"Error processing table: {str(inner_e)}"
        table.replace_with(f"\n[Table]\n{table_text}\n[/Table]\n")

    # Extract plain text from the remaining HTML content
    text = soup.get_text(separator="\n")

    # Remove empty lines and surrounding whitespace
    return "\n".join(line.strip() for line in text.splitlines() if line.strip())

def clean_string(input_text):
    """
    Tidy scraped text.

    The function:
    - drops a trailing fragment that starts with '<' and never closes with '>' (e.g. a
      truncated tag);
    - strips trailing whitespace from every line;
    - collapses each run of blank lines into a single blank line.

    Args:
        input_text (str): Raw text, possibly ending in a truncated HTML tag.

    Returns:
        str: The tidied lines joined with newlines.
    """
    text = re.sub(r'<[^\>]+?\s*$', '', input_text)
    kept_lines = []
    for raw_line in text.splitlines():
        previous_was_blank = bool(kept_lines) and not kept_lines[-1].strip()
        if previous_was_blank and not raw_line.strip():
            continue  # one blank line has already been emitted for this run
        kept_lines.append(raw_line.rstrip())
    return '\n'.join(kept_lines)


def parse_10_k(text, hash_digest) -> Form10KExtracts | None:
    """
    Extract the risk factors, MD&A and disclosures sections from the HTML of a 10-K document.

    The sections are bounded by item headings:
    - risk factors: Item 1A up to Item 1B;
    - MD&A: Item 7 up to Item 7A;
    - disclosures: Item 7A up to Item 8.

    When a heading occurs several times (e.g. once in the table of contents and once in the
    body), its last occurrence is used. Each section is converted to plain text with
    get_text_only and clean_string.

    Args:
        text: HTML/text of the 10-K document.
        hash_digest: Identifier stored on the returned Form10KExtracts.

    Returns:
        Form10KExtracts with '' for any section whose boundary headings are missing or out of
        order. Returns None when no item heading is found or none of the three sections could
        be located.
    """
    # Headings such as ">Item 1A", ">ITEM&#160;7." or ">Item&nbsp;8"; group 4 is the item number.
    sections_regex = re.compile(r'(>(Item|ITEM)(\s|&#160;|&nbsp;)(1A|1B|7A|7|8)\.{0,1})')

    # finditer yields matches in document order, so later headings overwrite earlier ones and the
    # last occurrence of each item wins. Keying on group 4 also copes with a tab or newline
    # between "Item" and the number.
    section_starts = {}
    for match in sections_regex.finditer(text):
        section_starts['item' + match.group(4).lower()] = match.start()

    if not section_starts:
        return None

    def extract(first_item, next_item):
        # Plain text between two headings, or '' when either is missing or they are out of order.
        start = section_starts.get(first_item)
        end = section_starts.get(next_item)
        if start is None or end is None or end <= start:
            return ''
        return clean_string(get_text_only(text[start:end]))

    risk_factors = extract('item1a', 'item1b')
    md_and_a = extract('item7', 'item7a')
    disclosures = extract('item7a', 'item8')

    if not (risk_factors or md_and_a or disclosures):
        return None
    return Form10KExtracts(risk_factors, md_and_a, disclosures, hash_digest)

def parse_quarterly_report(file):
    """
    Extract the 10-K / 10-Q documents from a filing's submission text file.

    The file is split on <DOCUMENT> ... </DOCUMENT> markers, and each document's form is read
    from its <TYPE> line. This assumes one <TYPE> line per document, so the three marker lists
    line up positionally.

    Args:
        file: Path of the submission text file.

    Returns:
        dict: Maps '10-K' and/or '10-Q' to that document's contents, with non-breaking spaces
        replaced by regular spaces. Other document types are ignored.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's default encoding; undecodable
    # bytes are replaced instead of aborting the whole file.
    with open(file, encoding='utf-8', errors='replace') as f:
        text = f.read()

    doc_start_indexes = [m.end() for m in doc_start_pattern.finditer(text)]
    doc_end_indexes = [m.start() for m in doc_end_pattern.finditer(text)]

    # type_pattern matches "<TYPE>" plus the rest of its line; dropping the tag (and stray
    # whitespace) leaves the form name, e.g. "10-K".
    doc_types = [m[len('<TYPE>'):].strip() for m in type_pattern.findall(text)]

    parsed_document = {}
    for doc_type, doc_start, doc_end in zip(doc_types, doc_start_indexes, doc_end_indexes):
        if doc_type in ('10-K', '10-Q'):
            parsed_document[doc_type] = text[doc_start:doc_end].replace("\xa0", " ")
    return parsed_document

def save_file(parsed_document, file):
    """Serialize ``parsed_document`` as JSON into ``file``, replacing any existing contents."""
    with open(file, mode='w') as out_file:
        json.dump(parsed_document, fp=out_file)

redaction_instructions = "Any time the company's name, or year of the report, or any revealing product name would appear in the returned markdown document, redact it using the tag [REDACTED] so that a reader would not know which company is described. Also ensure summaries and quotes do not include names of individuals associated with the company."

_MARKDOWN_RESPONSE_INSTRUCTION = "Respond only with a markdown code block containing markdown content within starting '```markdown\n' and ending: '\n```'"


def _summarize_section(section_text: str, section_name: str, task: str, heading: str) -> str:
    """
    Ask Gemini for a redacted markdown summary, with supporting quotes, of one 10-K section.

    Args:
        section_text: Plain text scraped from the section.
        section_name: How the section is named in the prompt, e.g. "risk factors".
        task: The task sentence of the prompt.
        heading: Markdown heading the model is asked to use.

    Returns:
        The model's markdown with code fences removed. Returns "" when ``section_text`` is empty
        or whitespace-only; the model is not called, since there is nothing to summarize.
    """
    if not section_text or not section_text.strip():
        return ""
    prompt = GeminiPrompt(
        task=task,
        context_information=[f"The following text was scraped from the {section_name} section of a company's 10-K report: ```\n{section_text}\n```\n "],
        response_format_instructions=[
            f"The summary should be formatted as a markdown file with a '{heading}' heading with two sections: Summary and Substantiating Quotes. {redaction_instructions}",
            _MARKDOWN_RESPONSE_INSTRUCTION,
        ],
    )
    return get_gemini_response(prompt, output_format="markdown")


@retry(tries=2, delay=15)
@timeout(300)
def summarize_risk_factors(risk_factors: str):
    """Summarize the risk factors section text into redacted markdown ("" for empty input)."""
    return _summarize_section(
        risk_factors,
        "risk factors",
        "Return a three paragraph summary of the most important information for a financial analyst to understand the company's risk factors. Also include a set of up to ten quotes from the text that support the summary.",
        "Risk Factors",
    )


@retry(tries=2, delay=15)
@timeout(300)
def summarize_md_and_a(md_and_a: str):
    """Summarize the management's discussion and analysis text into redacted markdown ("" for empty input)."""
    return _summarize_section(
        md_and_a,
        "md and a",
        "Return a three paragraph summary of the most important information for a financial analyst to understand the company's management discussion and analysis. Also include a set of up to ten quotes from the text that support the summary. ",
        "Management's Discussion and Analysis",
    )


@retry(tries=2, delay=15)
@timeout(300)
def summarize_disclosures(disclosures: str):
    """Summarize the disclosures section text into redacted markdown ("" for empty input)."""
    return _summarize_section(
        disclosures,
        "disclosures",
        "Return a three paragraph summary of the most important information for a financial analyst to understand the company's disclosures. Also include a set of up to ten quotes from the text that support the summary. ",
        "Disclosures",
    )

def summarize_10k_extracts(extracts: Form10KExtracts) -> str:
    """Combine per-section 10-K summaries into one markdown string.

    Produces the risk-factors summary followed by the MD&A summary, separated by a blank
    line and wrapped in a leading and trailing newline. Returns "" when ``extracts`` is None.

    The disclosures summary (``summarize_disclosures(extracts.disclosures)`` under a
    "# Disclosures:" heading) is currently disabled; append it to ``sections`` to include it.
    """
    if extracts is None:
        return ""
    sections = (
        summarize_risk_factors(extracts.risk_factors),
        summarize_md_and_a(extracts.md_and_a),
    )
    return "\n" + "\n\n".join(f"{section}" for section in sections) + "\n"




In [None]:
def get_macro_metrics(year, q):
    """
    Look up a fixed set of FRED macroeconomic series for the given year and quarter.

    Returns:
        A dict mapping a readable metric name (e.g. 'CPI') to a PointInTimeValue for
        (year, q) with an empty unit. Each value is whatever ``get_fred_indicator``
        returns for that series.
    """
    fred_series_by_name = {
        'CPI': 'CPIAUCSL',  # Consumer Price Index
        'UnemploymentRate': 'UNRATE',  # Unemployment Rate
        'InterestRate': 'GS1',  # 1-Year Treasury Rate
        'RetailSales': 'RSAFS',  # Retail Sales
        'IndustrialProduction': 'INDPRO',  # Industrial Production Index
        'CapacityUtilization': 'TCU',  # Capacity Utilization
        'ProducerPriceIndex': 'PPIACO',  # Producer Price Index
        'HousingStarts': 'HOUST',  # Housing Starts
        'ConsumerSentiment': 'UMCSENT',  # Consumer Sentiment Index
        'CorporateBondSpread': 'BAA10YM',  # Corporate Bond Spread (BAA - 10Y Treasury)
        'TradeBalance': 'BOPGSTB',  # Trade Balance (Goods and Services)
        'AverageHourlyEarnings': 'CES0500000003',  # Average Hourly Earnings (All Employees)
    }

    # Fetch every series first, then wrap the raw numbers as point-in-time values.
    raw_values = {
        name: get_fred_indicator(year, q, series_id)
        for name, series_id in fred_series_by_name.items()
    }
    return {
        name: PointInTimeValue(q=q, year=year, value=raw_value, unit='')
        for name, raw_value in raw_values.items()
    }



In [None]:
from dataclasses import dataclass, field
import hashlib
from dataclasses_json import dataclass_json
from typing import Generator, List, Dict, NamedTuple, Optional, Set, Tuple, Any
from enum import Enum
import traceback
from datetime import timedelta
import datetime
from collections import defaultdict
from time import sleep
import math
import shelve


@dataclass
class FinancialSnapshot:
    """A company's reported financial information for one (year, quarter).

    NOTE(review): not referenced anywhere else in this section — possibly unused.
    """
    year: int
    q: int
    company: 'Company'  # string forward reference: Company is defined just below
    financial_info: Dict[str, int | str]  # presumably metric name -> value; confirm with callers

@dataclass
class Company:
    """Identifying and descriptive data for one holding.

    Built by KeyCompanyFacts.from_company_facts_file from a row of the holdings table.
    """
    cik: str  # SEC Central Index Key; used in EDGAR archive URLs and findata/*/CIK{cik}.json paths
    ticker: str
    exchange: str
    name: str
    # Parsed from the holdings row's 'Quantity' column. NOTE(review): in an ETF holdings
    # file this is presumably the fund's share count, not shares outstanding — confirm.
    shares: int
    sector: str
    location: str

@dataclass
class Label:
    """Share prices around one future filing, used as a prediction target.

    KeyCompanyFacts.get_projection fills both fields via
    HistoricalDataStore.get_by_year_quarter_metric, which yields a PointInTimeValue or None.
    ContextualSnapshot.get_labels then reads ``.value`` from them, hence the annotations
    below rather than plain floats.
    """
    stock_price_pre_earnings: Optional['PointInTimeValue']  # close around the day before the filing
    stock_price_post_earnings: Optional['PointInTimeValue']  # close about five days after the filing

class Trend(NamedTuple):
    """One metric's values at fixed look-back offsets from a reference (year, quarter).

    Built by KeyCompanyFacts.get_trend. Each slot holds whatever the historical data store
    has for that metric and period, or None when the period is missing.
    NOTE(review): in practice the stored values are PointInTimeValue objects
    (ContextualSnapshot.get_labels reads ``.current.value``), not the bare scalars these
    annotations suggest.
    """
    two_years_ago: float | int | str | None  # 8 quarters back
    one_year_ago: float | int | str | None  # 4 quarters back
    nine_months_ago: float | int | str | None  # 3 quarters back
    six_months_ago: float | int | str | None  # 2 quarters back
    last_quarter: float | int | str | None  # 1 quarter back
    current: float | int | str | None  # the reference quarter itself

class Projection(NamedTuple):
    """Share-price labels at three horizons after a snapshot's (year, quarter).

    Populated by KeyCompanyFacts.get_projection from data that already happened (this is
    hindsight, not a forecast). Each horizon is a Label whose fields may be None.
    """
    next_quarter: Label
    next_six_months: Label
    next_year: Label

# Readability aliases for period keys.
Year = int
Quarter = int

# Values of a single metric keyed by (year, quarter).
# NOTE(review): not referenced elsewhere in this section.
HistoricalTable = dict[tuple[Year, Quarter], int | float | str | None]

# Maps the 'fp' (fiscal period) code on SEC company-facts entries to a quarter number.
# A full-year ('FY') fact is treated as Q4.
FP_TO_Q = {
    "Q1": 1,
    "Q2": 2,
    "Q3": 3,
    "FY": 4
}

# Maps calendar quarter-end 'MM-DD' suffixes to a quarter number. Any other month-day
# (e.g. '09-28') has no entry: get_report_q returns None for it, and to_point_in_time's
# lookup raises KeyError, which it catches.
DATE_TO_Q = {
    '03-31': 1,
    '06-30': 2,
    '09-30': 3,
    '12-31': 4
}


def to_point_in_time(u, vals) -> PointInTimeValue:
    """Convert one SEC company-facts entry into a PointInTimeValue.

    The year is ``vals['fy']`` when it is truthy, otherwise the year of ``vals['end']``.
    The quarter comes from ``vals['fp']`` via FP_TO_Q when present, otherwise from the
    'MM-DD' part of ``end`` via DATE_TO_Q.

    NOTE(review): fy/fp presumably describe the filing that reported the fact, so
    prior-period comparatives inside a 10-K may share the filing's fy/fp — confirm
    whether ``end`` should take precedence.

    Returns:
        The PointInTimeValue, or None if any lookup/conversion fails (the traceback is
        printed rather than raised).
    """
    try:
        end = vals["end"]
        fiscal_year = vals.get("fy")
        year = fiscal_year if fiscal_year else int(end[:4])
        fiscal_period = vals.get("fp")
        q = FP_TO_Q[fiscal_period] if fiscal_period else DATE_TO_Q[end[5:]]
        return PointInTimeValue(unit=u, q=q, year=year, value=vals["val"])
    except Exception:
        traceback.print_exc()
        return None

@retry(tries=4, backoff=2, delay=2)
def get_report_share_prices(company: Company, p):
    """Return (pre-filing, post-filing) closing prices around a filing, cached on disk.

    The price window starts one day before ``p['filed']`` and ends five days after it;
    the first close in the window is the pre-filing price and the last close is the
    post-filing price. Results are memoized in the ``share_prices`` shelve under
    'ticker|start|end'.

    Args:
        company: The company whose ticker is priced.
        p: A company-facts entry with 'filed' ('YYYY-MM-DD'), 'fp' and 'fy'.

    Returns:
        A 2-tuple of PointInTimeValue (unit 'USD') tagged with the entry's fy/fp.
    """
    filed_year, filed_month, filed_day = p["filed"].split('-')
    filing_date = datetime.date(int(filed_year), int(filed_month), int(filed_day))
    window_start = filing_date - timedelta(days=1)
    # Five days after filing smooths out same-day filing spikes.
    window_end = filing_date + timedelta(days=5)
    cache_key = f'{company.ticker}|{window_start}|{window_end}'

    with shelve.open('share_prices') as cache:
        if cache_key in cache:
            return cache[cache_key]

        prices = get_share_price(company.ticker, window_start, window_end)
        quarter = FP_TO_Q[p["fp"]]
        fiscal_year = p["fy"]
        pre_filing = PointInTimeValue(unit='USD', q=quarter, year=fiscal_year, value=prices["close"].iloc[0])
        post_filing = PointInTimeValue(unit='USD', q=quarter, year=fiscal_year, value=prices["close"].iloc[-1])
        cache[cache_key] = (pre_filing, post_filing)
        return pre_filing, post_filing


def create_report_url(cik, row):
    """Build the EDGAR archive URL of a filing's primary document.

    ``row`` must provide 'accessionNumber' (its dashes are stripped to form the archive
    folder name) and 'primaryDocument'.
    """
    folder = "".join(row["accessionNumber"].split("-"))
    return "https://www.sec.gov/Archives/edgar/data/{}/{}/{}".format(cik, folder, row["primaryDocument"])

def get_report_year(row):
    """Return the year part (as a string) of ``row['reportDate']`` ('YYYY-MM-DD').

    Uses ``str.partition`` so a value without '-' is returned whole. Slicing with
    ``str.find``'s -1 "not found" result would silently drop the last character
    (e.g. '2020' -> '202').
    """
    report_date = row['reportDate']
    return report_date.partition('-')[0]

def get_report_q(row):
    """Map ``row['reportDate']`` to a calendar quarter via its 'MM-DD' suffix.

    Returns None when the suffix is not a calendar quarter end in DATE_TO_Q.
    """
    _, _, month_day = row['reportDate'].partition('-')
    return DATE_TO_Q.get(month_day, None)

# Bookkeeping metrics left out of the anonymized historical-trends table. The report URL
# is built from the company's CIK, so including it would reveal the company.
EXCLUDE_KEYS = ['QuarterlyReportAccessionNumber', 'QuarterlyReportUrl']

@dataclass
class ContextualSnapshot:
    """A company's state at one (year, quarter) plus the share prices that followed it.

    ``historical_trends`` maps metric names to a Trend of values looking back from
    (year, q). ``future_projection`` holds later filing-date share prices and is used only
    to derive training labels. ``most_recent_10k_file`` is the local path of the latest
    10-K as of (year, q), or None when none was found or downloaded.
    """
    year: int
    q: int
    company: Company
    historical_trends: Dict[str, Trend]
    future_projection: Projection
    most_recent_10k_file: Optional[str] = None

    def get_labels(self):
        """Return the six relative share-price changes used as training labels.

        Order: next quarter (pre, post), next six months (pre, post), next year
        (pre, post). "Pre" compares future pre-filing closes to the current pre-filing
        close, and "post" does the same for post-filing closes. An entry is None when the
        future price is missing or the current price is zero.
        """
        projection = self.future_projection
        horizons = (projection.next_quarter, projection.next_six_months, projection.next_year)
        # float() rather than .astype(float): works for plain Python numbers as well as numpy scalars.
        current_pre = float(self.historical_trends['SharePricePreFiling'].current.value)
        current_post = float(self.historical_trends['SharePricePostFiling'].current.value)

        labels = []
        for label in horizons:
            labels.append(self._percent_change(label.stock_price_pre_earnings, current_pre))
            labels.append(self._percent_change(label.stock_price_post_earnings, current_post))
        return labels

    @staticmethod
    def _percent_change(future_value, base):
        """Relative change from ``base`` to ``future_value.value``; None if either is unusable."""
        # A zero base would otherwise produce inf/NaN labels (or ZeroDivisionError).
        if not future_value or not base:
            return None
        return (future_value.value - base) / base

    def _get_company_summary(self):
        """One sentence describing the company (sector, location, size) without naming it."""
        # NOTE(review): `company.shares` is parsed from the holdings row's 'Quantity' column
        # (KeyCompanyFacts.from_company_facts_file). If that is the ETF's share count rather
        # than shares outstanding, this product is the fund's position value, not market
        # cap, and it mixes a holdings-date quantity with a historical price — confirm.
        return f"The following return data was for a company in the " \
               f"{self.company.sector} sector in {self.company.location} " \
               f"with a market cap of ${self.company.shares * self.historical_trends['SharePricePostFiling'].current.value} " \
               f"at the time of the most recent quarterly report."

    def _get_historical_trends(self):
        """Render every non-excluded trend as a fenced CSV table (one row per metric)."""
        # TODO: Verify that the historical trends is being created correctly
        # NOTE(review): cells hold the stored objects themselves (PointInTimeValue in
        # practice), so the CSV contains their str() — possibly full reprs rather than
        # bare numbers. Confirm that is what the model should see.
        intro = f"The following csv shows the historical trends for the company's " \
               f"financial metrics up until the current point in time, " \
               f"including both information about the company " \
               f"and macroeconomic metrics: \n\n"
        columns = ('current', 'last_quarter', 'six_months_ago', 'nine_months_ago', 'one_year_ago', 'two_years_ago')
        trend_data = {
            key: {column: getattr(trend, column) for column in columns}
            for key, trend in self.historical_trends.items()
            if key not in EXCLUDE_KEYS
        }
        df = (
            pd.DataFrame.from_dict(trend_data, orient='index')
            .reset_index()
            .rename(columns={'index': 'metric'})
        )
        return intro + f"```csv\n{df.to_csv(index=False, date_format='%Y-%m-%d')}\n```\n\n"

    def _get_most_recent_10k_extracts(self) -> Optional[Form10KExtracts]:
        """Parse the most recent 10-K file, caching the result in the `10k_extracts` shelve.

        The cache is keyed by the MD5 of the file text. Returns None when there is no
        10-K file for this snapshot.
        """
        if not self.most_recent_10k_file:
            return None
        with open(self.most_recent_10k_file) as f:
            text = f.read()
        file_md5 = hashlib.md5(text.encode()).hexdigest()
        with shelve.open('10k_extracts') as cache:
            if file_md5 in cache:
                return cache[file_md5]
            extracts = parse_10_k(text, file_md5)
            cache[file_md5] = extracts
            return extracts

    def _get_most_recent_10k_summary(self):
        """Summarize the most recent 10-K, caching by the extracts' hash digest.

        Returns "" when there is no 10-K (or it could not be parsed).
        """
        extracts = self._get_most_recent_10k_extracts()
        if extracts is None:
            return ""
        with shelve.open('10k_summaries') as summaries:
            if extracts.hash_digest in summaries:
                return summaries[extracts.hash_digest]
            summary = summarize_10k_extracts(extracts)
            summaries[extracts.hash_digest] = summary
            return summary

    def to_anonymous_report(self):
        """
        Details that determine which company is being reported on are redacted,
        and a report is returned as a markdown file consisting of the following sections:
        - Company Summary
        - Historical Trends up until the current point in time including both information about the company and macroeconomic metrics, as a csv file
        - Key excerpts from the most recent 10-K report, including when the report was filed
        """
        print(f"Getting summary for {self.company.name}, Q{self.q}, {self.year}")
        company_summary = self._get_company_summary()
        print(f"Getting historical trends for {self.company.name}, Q{self.q}, {self.year}")
        historical_trends = self._get_historical_trends()
        print(f"Getting most recent 10-K extracts for {self.company.name}, Q{self.q}, {self.year}")
        file_10k_extracts_summary = self._get_most_recent_10k_summary()
        
        return f"""
# Company Summary
{company_summary}

# Historical Trends
{historical_trends}

#  Most Recent 10-K Summary
{file_10k_extracts_summary}
"""
        


@dataclass
class HistoricalDataStore:
    """A data structure for storing and querying historical financial data by year and quarter.

    Values live in a nested ``year -> quarter -> metric name -> value`` mapping.
    ``_keys`` maps every metric name ever added to the unit of the most recently added
    value for it (None when that value is not a PointInTimeValue).
    """
    _data: Dict[int, Dict[int, Dict[str, PointInTimeValue]]]  # year -> quarter -> metrics
    # `dict`, not `Dict[str, str]`: typing aliases cannot be instantiated, so that factory
    # would raise TypeError if the dataclass machinery ever called it.
    _keys: Dict[str, Optional[str]] = field(default_factory=dict)

    def __init__(self):
        # An explicit __init__ is kept by @dataclass (it does not generate one).
        self._data = defaultdict(lambda: defaultdict(dict))
        self._keys = dict()

    @staticmethod
    def _unit_of(value: Any) -> Optional[str]:
        """Unit of a stored value; only PointInTimeValue carries one."""
        return value.unit if isinstance(value, PointInTimeValue) else None

    @staticmethod
    def _csv_line(fields: List[str]) -> str:
        """Format one CSV row, quoting fields that contain commas, quotes or newlines."""
        import csv
        import io
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator='\n').writerow(fields)
        return buffer.getvalue()

    def add_metrics(self, year: int, quarter: int, metrics: Dict[str, Any]) -> None:
        """Add or update multiple metrics for a specific year and quarter."""
        self._data[year][quarter].update(metrics)
        # Record each key's unit; a later value for the same key overwrites the unit.
        for k, v in metrics.items():
            self._keys[k] = self._unit_of(v)

    def add_metric(self, year: int, quarter: int, key: str, value: Any) -> None:
        """Add or update a single metric for a specific year and quarter."""
        self._data[year][quarter][key] = value
        self._keys[key] = self._unit_of(value)

    def get_by_year_quarter(self, year: int, quarter: int) -> Optional[Dict[str, Any]]:
        """Get all metrics for a specific year and quarter (None if the period is absent)."""
        # .get() on the defaultdicts avoids creating empty periods as a side effect.
        return self._data.get(year, {}).get(quarter)

    def get_by_year(self, year: int) -> Dict[int, Dict[str, Any]]:
        """Get all quarters' data for a specific year."""
        return dict(self._data.get(year, {}))

    def get_by_year_quarter_metric(self, year: int, quarter: int, metric: str) -> Optional[PointInTimeValue]:
        """Get one metric for a specific year and quarter, or None if absent."""
        return self._data.get(year, {}).get(quarter, {}).get(metric)

    def get_by_quarter(self, quarter: int) -> Dict[int, Dict[str, Any]]:
        """Get data for a specific quarter across all years."""
        return {year: quarters[quarter]
                for year, quarters in self._data.items()
                if quarter in quarters}

    def get_all_years(self) -> List[int]:
        """Get all available years, ascending."""
        return sorted(self._data.keys())

    def get_all_quarters(self) -> List[int]:
        """Get all available quarters across all years, ascending."""
        return sorted(set(q for quarters in self._data.values() for q in quarters.keys()))

    def get_all_metrics(self) -> Dict[str, Optional[str]]:
        """Return the live mapping of metric name -> unit (iterating yields the names)."""
        return self._keys

    def to_dataframe(self) -> pd.DataFrame:
        """One row per (year, quarter), chronologically, with a column per metric.

        Columns are 'q', 'year', then metric names sorted alphabetically. Cells hold each
        PointInTimeValue's ``.value`` (other stored objects as-is), or None when a metric
        is absent for that period.
        """
        metric_names = sorted(self._keys)
        records = []
        for year in self.get_all_years():
            for quarter, metrics in sorted(self.get_by_year(year).items()):
                row = {'q': quarter, 'year': year}
                for name in metric_names:
                    metric = metrics.get(name)
                    row[name] = metric.value if isinstance(metric, PointInTimeValue) else metric
                records.append(row)
        # Build once from records; DataFrame.append was removed in pandas 2.0.
        return pd.DataFrame(records, columns=['q', 'year'] + metric_names)

    def generate_csv(self) -> Generator[str, None, None]:
        """Lazily yield CSV lines: a header, then one row per (year, quarter) chronologically.

        The header is 'q', 'year' and '<metric> (<unit>)' for each metric, sorted by name.
        Missing metrics are empty cells.
        """
        keys = sorted(self._keys.items(), key=lambda x: x[0])
        yield self._csv_line(['q', 'year'] + [f"{k} ({v})" for k, v in keys])
        for year in self.get_all_years():
            for q in sorted(self.get_by_year(year)):
                values = []
                for k, _ in keys:
                    metric = self.get_by_year_quarter_metric(year, q, k)
                    values.append(str(metric.value) if metric else '')
                yield self._csv_line([str(q), str(year)] + values)
    
@dataclass
class KeyCompanyFacts:
    """Historical XBRL facts for one company, plus helpers to fetch its SEC filings.

    ``historical_data`` is keyed by the (year, quarter) that ``to_point_in_time`` assigns
    to each fact. Every new period is seeded with filing-date share prices, the filing's
    accession number/URL and macroeconomic metrics. ``submissions_df`` lazily caches the
    company's EDGAR submissions loaded from ``findata/submissions``.

    NOTE(review): the filing index from ``_load_quarterly_report_index`` keys filings by
    the calendar quarter of their *filing date*. ``historical_data`` uses the facts'
    fy/fp when present. Lookups that mix the two (``_download_quarterly_reports``,
    ``get_most_recent_10k_file`` called with a snapshot period) may be offset — confirm.
    """
    company: Company
    historical_data: HistoricalDataStore
    submissions_df: pd.DataFrame = field(default_factory=pd.DataFrame)
    # Unannotated, so these are plain class attributes rather than dataclass fields.
    out_folder = 'out/'
    submissions_folder = 'findata/submissions'
    # Periodic-report form types treated as "quarterly reports" throughout this class.
    _QUARTERLY_FORMS = ['10-K', '10-Q', '10-Q/A', '10-K/A']
    # Browser-like headers sent when downloading filings.
    # NOTE(review): SEC's EDGAR fair-access guidance asks automated clients to declare a
    # User-Agent with contact details; this browser UA may get throttled — confirm.
    _REQUEST_HEADERS = {
        'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
        'accept-language': 'en-US,en;q=0.9,he-IL;q=0.8,he;q=0.7',
        'cache-control': 'max-age=0',
        'priority': 'u=0, i',
        'sec-ch-ua': '"Chromium";v="130", "Google Chrome";v="130", "Not?A_Brand";v="99"',
        'sec-ch-ua-mobile': '?0',
        'sec-ch-ua-platform': '"macOS"',
        'sec-fetch-dest': 'document',
        'sec-fetch-mode': 'navigate',
        'sec-fetch-site': 'none',
        'sec-fetch-user': '?1',
        'upgrade-insecure-requests': '1',
        'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
    }

    def __post_init__(self):
        self.download_folder = f'out/{self.company.cik}/downloads'

    def __str__(self):
        """All stored metrics as CSV text (see HistoricalDataStore.generate_csv)."""
        return ''.join(self.historical_data.generate_csv())

    def years(self):
        """All years present in the historical data store, ascending."""
        return self.historical_data.get_all_years()

    @staticmethod
    def _shift_quarter(year: int, q: int, delta: int) -> Tuple[int, int]:
        """Move ``delta`` quarters (negative = back) from (year, q), rolling over years."""
        index = year * 4 + (q - 1) + delta
        return index // 4, index % 4 + 1

    def get_trends(self, q: int, year: int) -> Dict[str, Trend]:
        """Get trends for all metrics present at a specific quarter and year."""
        current_data = self.historical_data.get_by_year_quarter(year, q) or {}
        return {
            key: self.get_trend(key, q, year)
            for key in current_data.keys()
        }

    def get_trend(self, key: str, q: int, year: int) -> Trend:
        """Get one metric's values at 0, 1, 2, 3, 4 and 8 quarters before (year, q)."""
        def value_at(offset: int) -> Optional[Any]:
            y, qtr = self._shift_quarter(year, q, offset)
            data = self.historical_data.get_by_year_quarter(y, qtr)
            return data.get(key) if data else None

        # Quarter arithmetic goes through _shift_quarter so Q2 - 2 correctly lands on
        # Q4 of the previous year (never "quarter 0").
        return Trend(
            current=value_at(0),
            two_years_ago=value_at(-8),
            one_year_ago=value_at(-4),
            nine_months_ago=value_at(-3),
            six_months_ago=value_at(-2),
            last_quarter=value_at(-1),
        )

    def _load_submissions_df(self) -> pd.DataFrame:
        """Load the company's EDGAR submissions (recent + overflow files) into ``submissions_df``."""
        # `json` is imported only in a later cell of the visible notebook; import locally
        # so this works on a fresh kernel.
        import json
        submissions = f'{getcwd()}/{self.submissions_folder}'
        os.makedirs(submissions, exist_ok=True)
        print(f"submissions: {submissions}")
        with open(f'{submissions}/CIK{self.company.cik}.json') as f:
            obj = json.load(f)
        frames = [pd.DataFrame(obj['filings']['recent'])]
        for file in obj['filings']['files']:
            with open(f"{submissions}/{file['name']}") as s:
                frames.append(pd.DataFrame(json.load(s)))
        # ignore_index: each file restarts at 0, and duplicate labels make later assignment fragile.
        dataframe = pd.concat(frames, ignore_index=True)
        self.submissions_df = dataframe
        return dataframe

    def _load_quarterly_report_index(self) -> pd.DataFrame:
        """Periodic reports with a primary document, sorted by reportDate.

        Adds 'url', 'filingDate' (datetime), and 'year'/'q' taken from the *filing* date.
        The columns are present even when there are no such filings.
        """
        if self.submissions_df.empty:
            self._load_submissions_df()
        submissions = self.submissions_df
        quarterly_reports = submissions[submissions["form"].isin(self._QUARTERLY_FORMS)].sort_values(by="reportDate")
        quarterly_reports = quarterly_reports[quarterly_reports['primaryDocument'] != '']
        # .copy(): the writes below must target an independent frame, not a slice.
        quarterly_reports = quarterly_reports.dropna(subset=['primaryDocument']).copy()
        # A comprehension (unlike DataFrame.apply) also behaves on an empty frame.
        quarterly_reports['url'] = [
            create_report_url(self.company.cik, row) for _, row in quarterly_reports.iterrows()
        ]
        quarterly_reports['filingDate'] = pd.to_datetime(quarterly_reports['filingDate'])
        quarterly_reports['year'] = quarterly_reports['filingDate'].dt.year
        quarterly_reports['q'] = quarterly_reports['filingDate'].dt.quarter
        return quarterly_reports

    def _download_quarterly_report(self, year: int, q: int, url: Optional[str] = None):
        """Download one filing to ``{download_folder}/{year}/{q}/raw.htm``.

        Args:
            year, q: Calendar year/quarter of the filing date (the index's 'year'/'q').
            url: Filing URL to fetch. When omitted, the first filing (by reportDate) filed
                in that quarter is used.

        Returns:
            The local path, or None if the HTTP request failed (the error is printed).

        Raises:
            IndexError: if ``url`` is omitted and nothing was filed in (year, q).
        """
        if url is None:
            index = self._load_quarterly_report_index()
            url = index.loc[(index['year'] == year) & (index['q'] == q), 'url'].iloc[0]
        print(url)
        q_folder = f'{self.download_folder}/{year}/{q}'
        os.makedirs(q_folder, exist_ok=True)
        path = f'{q_folder}/raw.htm'
        tmp_path = f'{path}.part'
        try:
            print(f"Downloading file to {path}")
            # timeout: without one, a stalled connection blocks forever.
            with requests.get(url, stream=True, headers=self._REQUEST_HEADERS, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024):
                        f.write(chunk)
            # Write-then-rename so an interrupted download never leaves a truncated
            # raw.htm that get_quarterly_report_file would treat as complete.
            os.replace(tmp_path, path)
            return path
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {url}: {e}")
            traceback.print_exc()
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            return None

    def _download_quarterly_reports(self):
        """Download a filing for every (year, quarter) in the historical data store."""
        for year in self.historical_data.get_all_years():
            for q in self.historical_data.get_by_year(year):
                self._download_quarterly_report(year, q)

    def get_quarterly_report_file(self, year: int, q: int) -> str:
        """Local path of the (year, q) filing, downloading it first if missing.

        The returned path may not exist if the download failed.
        """
        path = f'{self.download_folder}/{year}/{q}/raw.htm'
        if not os.path.exists(path):
            self._download_quarterly_report(year, q)
        return path

    def get_most_recent_10k_file(self, as_of_year: int, as_of_q: int) -> Optional[str]:
        """
        Gets the most recent 10-K file *as of* a given year and quarter.

        Args:
            as_of_year: The year to consider as the cutoff.
            as_of_q: The quarter to consider as the cutoff.

        Returns:
            The path to the downloaded 10-K file, or None if none exists or the
            download failed.
        """
        index = self._load_quarterly_report_index()
        ten_ks = index[index['form'] == '10-K']
        # Keep 10-Ks filed in or before the cutoff quarter.
        ten_ks_before_cutoff = ten_ks[(ten_ks['year'] < as_of_year) | ((ten_ks['year'] == as_of_year) & (ten_ks['q'] <= as_of_q))]
        if ten_ks_before_cutoff.empty:
            return None
        most_recent_10k = ten_ks_before_cutoff.sort_values(by=['year', 'q', 'filingDate'], ascending=False).iloc[0]
        # Pass this 10-K's own URL: looking up by (year, q) alone could pick another
        # filing (e.g. a 10-Q/A) filed in the same quarter.
        return self._download_quarterly_report(
            int(most_recent_10k['year']), int(most_recent_10k['q']), url=most_recent_10k['url']
        )

    def add_company_facts(self, facts: Dict[str, Any]) -> None:
        """Add company facts to the historical data store.

        Args:
            facts: us-gaap key -> {unit: [company-facts entries]} (or None for a key the
                company does not report). Only periodic-report entries are used.
        """
        quarterly_report_index = self._load_quarterly_report_index()

        for key, unit_data in facts.items():
            if not unit_data:
                continue
            for unit, points_in_time in unit_data.items():
                for val in points_in_time:
                    if val.get('form') not in self._QUARTERLY_FORMS:
                        continue
                    point = to_point_in_time(unit, val)
                    if not point:
                        print(f"Could not create point for {key} with unit {unit} and value {val}")
                        continue
                    # The first fact seen for a period also seeds its per-period context.
                    if not self.historical_data.get_by_year_quarter(point.year, point.q):
                        self._add_period_context(point, val, quarterly_report_index)
                    self.historical_data.add_metric(point.year, point.q, key, point)

    def _add_period_context(self, point, val, quarterly_report_index: pd.DataFrame) -> None:
        """Seed a new (year, quarter) with share prices, filing identifiers and macro metrics."""
        share_price_pre_filing, share_price_post_filing = get_report_share_prices(self.company, val)
        accn = val["accn"]
        quarterly_report_row = quarterly_report_index[quarterly_report_index["accessionNumber"] == accn]
        accn = accn.replace("-", "")
        primary_document = quarterly_report_row["primaryDocument"].values[0] if len(quarterly_report_row) > 0 else None
        # No primary document means no usable URL (rather than one ending in '/None').
        quarterly_report_url = (
            f"https://www.sec.gov/Archives/edgar/data/{self.company.cik}/{accn}/{primary_document}"
            if primary_document else None
        )
        macroeconomic_metrics = get_macro_metrics(point.year, point.q)
        self.historical_data.add_metrics(point.year, point.q, {
            'SharePricePreFiling': share_price_pre_filing,
            'SharePricePostFiling': share_price_post_filing,
            'QuarterlyReportAccessionNumber': PointInTimeValue(q=point.q, year=point.year, value=accn, unit=''),
            'QuarterlyReportUrl': PointInTimeValue(q=point.q, year=point.year, value=quarterly_report_url, unit='URL') if quarterly_report_url else None,
            **macroeconomic_metrics
        })

    def get_projection(self, q: int, year: int) -> Projection:
        """
        A projection in this case is not a guess about what will happen, but because we are looking at historical data
        we know about what happened in the years and quarters after the q and year we are querying.

        It contains the values of the stock price pre-earnings and post-earnings for the periods 1, 2 and 4 quarters
        after the provided quarter and year, rolling over into the next year as needed.
        """
        def share_price(offset: int, metric: str):
            y, qtr = self._shift_quarter(year, q, offset)
            return self.historical_data.get_by_year_quarter_metric(y, qtr, metric)

        def label(offset: int) -> Label:
            return Label(
                stock_price_pre_earnings=share_price(offset, 'SharePricePreFiling'),
                stock_price_post_earnings=share_price(offset, 'SharePricePostFiling'),
            )

        return Projection(next_quarter=label(1), next_six_months=label(2), next_year=label(4))

    def create_contextual_snapshot(self, year: int, q: int) -> ContextualSnapshot:
        """Bundle trends, future labels and the latest 10-K path for (year, q)."""
        return ContextualSnapshot(
            year=year,
            q=q,
            company=self.company,
            historical_trends=self.get_trends(q, year),
            future_projection=self.get_projection(q, year),
            most_recent_10k_file=self.get_most_recent_10k_file(year, q)
        )

    @classmethod
    def from_company_facts_file(cls, company, prompt_cache):
        """Build an instance from ``findata/companyfacts/CIK{cik}.json`` and a holdings row.

        Only the 20 us-gaap keys chosen by ``most_relevant_keys`` are loaded.
        """
        import json  # see _load_submissions_df: keeps this working on a fresh kernel
        company_facts = "findata/companyfacts"
        cik = company['cik']
        print(f"cik: {cik}")
        with open(f'{company_facts}/CIK{cik}.json') as f:
            print('loading facts')
            all_facts = json.load(f)
        us_gaap = all_facts['facts']['us-gaap']
        most_relevant_gaap_keys = most_relevant_keys(us_gaap.keys(), 20, prompt_cache=prompt_cache)
        print(f"Most relevant gaap keys: {most_relevant_gaap_keys}")
        company_obj = Company(
            cik=cik,
            ticker=company['Ticker'],
            exchange=company['Exchange'],
            name=company['Name'],
            shares=int(float(company['Quantity'].replace(',', ''))),
            sector=company['Sector'],
            location=company['Location']
        )
        instance = cls(company=company_obj, historical_data=HistoricalDataStore())
        relevant_facts = {
            key: us_gaap[key]['units'] if key in us_gaap else None
            for key in most_relevant_gaap_keys
        }
        instance.add_company_facts(relevant_facts)
        return instance





In [None]:
from dataclasses import dataclass
from typing import List
import hashlib

@dataclass(frozen=True, eq=True)
class PromptResponse:
    """An immutable (prompt, response) pair.

    Note: with frozen=True and eq=True, dataclasses generates ``__hash__`` from the
    fields, so hashing an instance raises TypeError when ``response`` is a list (as
    annotated). Use a tuple if instances must be hashable.
    """
    prompt: str
    response: List[str]

def hash(long_string):
    """Return the hex MD5 digest of ``long_string`` encoded as UTF-8.

    A compact, deterministic fingerprint for long text (not for security).
    NOTE: this name shadows the builtin ``hash()`` for every cell that runs after this
    one. It is kept for compatibility with existing callers; prefer a distinct name.
    """
    return hashlib.md5(bytes(long_string, 'utf-8')).hexdigest()

In [None]:
import os
import requests
import zipfile

def download_and_inflate(url, output_dir, timeout=(30, 300)):
    """Download a zip archive and extract it into a subdirectory of ``output_dir``.

    The subdirectory is named after the archive (``companyfacts.zip`` -> ``companyfacts``).
    If it already exists, nothing is downloaded. Extraction goes to a temporary
    ``<name>.partial`` directory that is renamed into place only on success. A failed
    run therefore never leaves a half-populated (or empty) subdirectory that would make
    later runs skip the download. The zip is always deleted afterwards. Errors are
    printed, not raised.

    Args:
        url (str): The URL of the file to download.
        output_dir (str): The main output directory.
        timeout: ``requests`` timeout in seconds, or a ``(connect, read)`` tuple. ``None``
            waits forever.
    """
    import shutil  # local: only needed to clean up partial extractions

    filename = os.path.basename(url)
    subdirectory_path = os.path.join(output_dir, filename.replace(".zip", ""))
    zip_filepath = os.path.join(output_dir, filename)
    partial_path = subdirectory_path + ".partial"

    if os.path.exists(subdirectory_path):
        print(f"Subdirectory already exists, skipping download: {subdirectory_path}")
        return

    print(f"Downloading file: {filename}")
    try:
        os.makedirs(output_dir, exist_ok=True)
        with requests.get(url, stream=True, timeout=timeout, headers={
            'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
            'accept-language': 'en-US,en;q=0.9,he-IL;q=0.8,he;q=0.7',
            'cache-control': 'max-age=0',
            'priority': 'u=0, i',
            'sec-ch-ua': '"Chromium";v="130", "Google Chrome";v="130", "Not?A_Brand";v="99"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'document',
            'sec-fetch-mode': 'navigate',
            'sec-fetch-site': 'none',
            'sec-fetch-user': '?1',
            'upgrade-insecure-requests': '1',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
        }) as response:
            response.raise_for_status()
            with open(zip_filepath, "wb") as f:
                # 1 MiB chunks: these bulk archives are large.
                for chunk in response.iter_content(1 << 20):
                    f.write(chunk)
        print(f"Download complete: {filename}")

        shutil.rmtree(partial_path, ignore_errors=True)  # leftovers from an earlier crash
        with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
            zip_ref.extractall(partial_path)
        os.replace(partial_path, subdirectory_path)
        print(f"File inflated to: {subdirectory_path}")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
    except zipfile.BadZipFile as e:
        print(f"Error inflating file: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        # The archive is only an intermediate. Remove it (and any partial extraction) on
        # success and on failure, so a corrupt zip is never mistaken for a good one.
        shutil.rmtree(partial_path, ignore_errors=True)
        if os.path.exists(zip_filepath):
            os.remove(zip_filepath)
            print(f"Removed zip file: {zip_filepath}")

# Fetch the SEC bulk datasets used above: XBRL company facts (read from findata/companyfacts)
# and per-company submission histories (read from findata/submissions). Both use HTTPS;
# each archive is skipped when its extracted folder already exists.
urls = [
    "https://www.sec.gov/Archives/edgar/daily-index/xbrl/companyfacts.zip",
    "https://www.sec.gov/Archives/edgar/daily-index/bulkdata/submissions.zip",
]
output_dir = f"{getcwd()}/findata"

for url in urls:
    download_and_inflate(url, output_dir)

In [None]:
# Smoke-test the pipeline on the first holding as of Q1 2020: print the character length
# of each extracted 10-K section, then the full anonymized report.
facts = KeyCompanyFacts.from_company_facts_file(prompt_cache={}, company=holdings.iloc[0])
snapshot = facts.create_contextual_snapshot(year=2020, q=1)
extracts = snapshot._get_most_recent_10k_extracts()

print(*(len(section) for section in (extracts.md_and_a, extracts.disclosures, extracts.risk_factors)))
print(snapshot.to_anonymous_report())


In [None]:
import os
import pandas as pd
import shelve
import json

def create_resilient_training_data(holdings, outdir, processed_ciks_shelf='processed_ciks', 
                                 prompt_cache_shelf='prompt_cache', 
                                 training_data_file='training_data.parquet'):
    """
    Create a resilient training data DataFrame with checkpointing at every snapshot.
    
    Args:
        holdings (pd.DataFrame): DataFrame containing company holdings
        outdir (str): Output directory for snapshots
        processed_ciks_shelf (str): Shelf file for tracking processed CIKs
        prompt_cache_shelf (str): Shelf file for prompt caching
        training_data_file (str): Parquet file to save/load training data
    
    Returns:
        pd.DataFrame: Completed training data DataFrame
    """
    # Attempt to load existing training data
    if os.path.exists(training_data_file):
        training_data = pd.read_parquet(training_data_file)
        # Get the last processed snapshots to continue from
        processed_snapshots = set(
            training_data.apply(
                lambda x: f"{x['cik']}_{x['year']}_{x['q']}", 
                axis=1
            )
        ) if not training_data.empty else set()
    else:
        training_data = pd.DataFrame()
        processed_snapshots = set()

    # Ensure output directories exist
    os.makedirs(outdir, exist_ok=True)

    # Open shelves for processed CIKs and prompt cache
    with shelve.open(processed_ciks_shelf) as processed, \
         shelve.open(prompt_cache_shelf) as prompt_cache:
        
        # Iterate through holdings
        for i, company in holdings.iterrows():
            try:
                # Process company facts
                company_facts = KeyCompanyFacts.from_company_facts_file(
                    company=holdings.iloc[i], 
                    prompt_cache=prompt_cache
                )
                
                # Create company-specific directory
                company_dir = os.path.join(os.getcwd(), outdir, str(company_facts.company.cik))
                print(f"Company facts exist for {company_facts.years()}")
                
                for year in company_facts.years():
                    for q in range(1, 5):
                        # Create snapshot identifier
                        snapshot_id = f"{company_facts.company.cik}_{year}_{q}"
                        
                        # # Skip if this snapshot was already processed
                        # if snapshot_id in processed_snapshots:
                        #     continue
                            
                        try:
                            # Check if data is available for this year and quarter
                            if company_facts.historical_data.get_by_year_quarter(year, q):
                                # Create contextual snapshot
                                snapshot = company_facts.create_contextual_snapshot(year, q)
                                
                                # Prepare file paths
                                snapshot_dir = os.path.join(company_dir, 'snapshots', str(year))
                                os.makedirs(snapshot_dir, exist_ok=True)
                                snapshot_file = os.path.join(snapshot_dir, f'{q}.md')
                                labels_file = os.path.join(snapshot_dir, f'{q}-labels.json')
                                
                                # Save snapshot
                                with open(snapshot_file, 'w') as f:
                                    f.write(snapshot.to_anonymous_report())

                                # Get and save labels
                                labels = {
                                    'cik': company_facts.company.cik,
                                    'year': year,
                                    'q': q,
                                    'snapshot_file': snapshot_file,
                                    'labels': snapshot.get_labels()
                                }
                                
                                # Save labels to JSON
                                with open(labels_file, 'w') as f:
                                    json.dump(labels, f)

                                # Add single row to training data
                                training_data = pd.concat(
                                    [training_data, pd.DataFrame([labels])], 
                                    ignore_index=True
                                )
                                
                                # Save progress after each snapshot
                                training_data.to_parquet(training_data_file, index=False)
                                
                                # Mark snapshot as processed
                                processed_snapshots.add(snapshot_id)
                                
                        except Exception as e:
                            print(f"Error processing snapshot {snapshot_id}: {e}")
                            traceback.print_exc()
                            continue

                # Mark CIK as fully processed only if all snapshots succeeded
                if all(
                    f"{company_facts.company.cik}_{year}_{q}" in processed_snapshots
                    for year in company_facts.years()
                    for q in range(1, 5)
                    if company_facts.historical_data.get_by_year_quarter(year, q)
                ):
                    processed[str(company_facts.company.cik)] = True

            except Exception as e:
                print(f"Error processing company {company['cik']}: {e}")
                traceback.print_exc()
                continue

    return training_data



In [None]:
# Display settings: show every row of a displayed DataFrame and never truncate cell text.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

# Build the snapshot training set for all holdings (checkpointed to training_data.parquet by default).
training_data = create_resilient_training_data(holdings=holdings, outdir=outdir)