In [None]:
!pip install dateutils huggingface_hub wandb  yfinance datasets transformers trl modal holidays dateutils throttler

Collecting dateutils
  Downloading dateutils-0.6.12-py2.py3-none-any.whl.metadata (1.3 kB)
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting trl
  Downloading trl-0.14.0-py3-none-any.whl.metadata (12 kB)
Collecting modal
  Downloading modal-0.73.7-py3-none-any.whl.metadata (2.3 kB)
Collecting throttler
  Downloading throttler-1.2.2-py3-none-any.whl.metadata (7.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2023.5.0 (from huggingface_hub)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting fastapi (from modal)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting grp

In [None]:
!huggingface-cli login
!wandb login

In [None]:
# Hugging Face user/organisation name that the notebook cells interpolate into
# the dataset repo ids they load from and push to.
HF_USER_NAME="2084Collective"
"""
Also change the user name in the below writefile script if you want to run the data collection scripts for yourself
"""

'\nAlso change the user name in the below writefile script if you want to run the data collection scripts for yourself\n'

In [None]:
from dateutil.parser import parse, ParserError
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Pull the analyst-ratings CSV from Kaggle as a Hugging Face dataset, publish the
# raw copy under HF_USER_NAME, then reload it from the Hub.
# load_dataset is imported here because this cell runs before the cell that
# imports it, so a fresh kernel would otherwise fail with a NameError.
from datasets import load_dataset

RAW_DATASET_REPO = f"{HF_USER_NAME}/deepstock-dataset-raw"
dataset = kagglehub.load_dataset(
    KaggleDatasetAdapter.HUGGING_FACE,
    "miguelaenlle/massive-stock-news-analysis-db-for-nlpbacktests",
    "analyst_ratings_processed.csv",
)
dataset.push_to_hub(RAW_DATASET_REPO)
ds = load_dataset(RAW_DATASET_REPO)
def parse_date(example):
    """Replace the example's 'date' string with a parsed datetime.

    Args:
        example: Mapping with a 'date' entry holding a date string or None.

    Returns:
        dict: {"date": <datetime>} on success, {"date": None} when the value is
        missing or cannot be parsed, so the row can be filtered out afterwards.
    """
    raw_date = example['date']
    if raw_date is None:
        return {"date": None}
    try:
        return {"date": parse(raw_date)}
    except (ParserError, OverflowError):
        # dateutil raises OverflowError (not ParserError) for dates too large for
        # the platform; treat those as unparseable instead of aborting the map.
        return {"date": None}
# Work on the train split only: parse every date, drop the rows whose date came
# back as None, then publish the filtered dataset.
ds = ds['train'].map(parse_date).filter(lambda example: example['date'] is not None)
ds.push_to_hub(f"{HF_USER_NAME}/deepstock-dataset-filtered")

In [None]:
import yfinance as yf
from datetime import datetime, timedelta
import pandas as pd
from collections import defaultdict
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
from dateutil.parser import parse, ParserError
import asyncio
import pickle

async def download_yf(stock, start_date, end_date, progress):
    """Download the price history of one ticker with yfinance.

    Args:
        stock: Ticker symbol passed to yf.download.
        start_date: First day of the requested range.
        end_date: End of the requested range.
        progress: Forwarded to yf.download's ``progress`` flag (previously this
            argument was accepted but ignored).

    Returns:
        Whatever yf.download returns for the request.

    Note:
        The yfinance call is not awaited, so it blocks the event loop while it runs.
    """
    return yf.download(stock, start=start_date, end=end_date, progress=progress)

async def download_stock_histories(dataset):
    """Download the complete price history for each unique stock in the dataset.

    The requested range spans from one day before the earliest 'date' in the
    dataset to five days after the latest one.

    Args:
        dataset: Object exposing 'stock' and 'date' columns; the dates must
            support ``.date()``.

    Returns:
        dict: ticker -> history for every ticker whose download returned a
        non-empty frame. Tickers with no data or a failed download are omitted.

    Note:
        The per-ticker coroutines call the blocking yf.download without awaiting
        anything, so asyncio.gather runs them one after another.
    """
    unique_stocks = set(dataset['stock'])
    stock_histories = {}

    dates = list(dataset['date'])
    if not dates:
        # Nothing to derive a date range from; min()/max() would raise.
        return stock_histories

    # Use date() to get just the date part for the stock data range
    start_date = min(dates).date() - timedelta(days=1)
    end_date = max(dates).date() + timedelta(days=5)

    with tqdm(total=len(unique_stocks), desc="Downloading stock histories") as pbar:
        async def download_stock_history(stock, start_date, end_date):
            try:
                history = yf.download(stock, start=start_date, end=end_date, progress=False)
                if not history.empty:
                    return stock, history
                print(f"No data found for {stock}")
            except Exception as e:
                print(f"Error downloading {stock}: {str(e)}")
            finally:
                # Advance the bar for failures too, so it always reaches 100%.
                pbar.update(1)
            return stock, None

        tasks = [
            download_stock_history(stock, start_date, end_date)
            for stock in unique_stocks
        ]
        results = await asyncio.gather(*tasks)
        for stock, history in results:
            if history is not None:
                stock_histories[stock] = history
    return stock_histories


# Use the function


# Entry point of this cell: load the filtered dataset from the Hub, download the
# price history of every ticker in it, and pickle the result to
# stock_histories.pkl, which the next cell reads back.
# NOTE: the top-level `await` below is only valid where top-level await is
# supported (e.g. an IPython/Jupyter cell), not in a plain Python script.
if __name__=="__main__":
  ds = load_dataset(f"{HF_USER_NAME}/deepstock-dataset-filtered")
  ds = ds['train']
  stock_histories = await download_stock_histories(ds)
  with open("stock_histories.pkl", "wb") as f:
    pickle.dump(stock_histories, f)

In [None]:
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from datasets import Dataset
import pickle
def process_single_stock(stock_data):
    """Flatten one stock's price history into column lists.

    Args:
        stock_data: ``(stock, history)`` pair where ``history`` is a DataFrame
            indexed by date with 'Open' and 'Close' columns. Both MultiIndex
            columns (``row['Open']`` is a one-element Series) and flat columns
            (``row['Open']`` is a scalar) are accepted.

    Returns:
        dict: 'stock', 'date' (YYYY-MM-DD strings), 'open' and 'close' lists of
        equal length. Rows that cannot be converted are reported and skipped as
        a whole, so the four lists never get out of step.
    """
    stock, history = stock_data

    def first_scalar(cell):
        # MultiIndex columns yield a Series per price field; flat columns a scalar.
        return cell.iloc[0] if hasattr(cell, 'iloc') else cell

    local_data = {
        'stock': [],
        'date': [],
        'open': [],
        'close': []
    }

    for date, row in history.iterrows():
        try:
            # Convert everything first: appending before a failing conversion
            # would leave the columns with different lengths.
            day = date.strftime('%Y-%m-%d')
            open_price = float(first_scalar(row['Open']))
            close_price = float(first_scalar(row['Close']))
        except Exception as e:
            print(f"Error processing {stock} on {date} - {e}")
            continue
        local_data['stock'].append(stock)
        local_data['date'].append(day)
        local_data['open'].append(open_price)
        local_data['close'].append(close_price)
    return local_data

def merge_dictionaries(dict_list):
    """Concatenate per-stock column dictionaries into one.

    Each input dictionary contributes its 'stock', 'date', 'open' and 'close'
    lists, in input order, to the matching list of the result.
    """
    columns = ('stock', 'date', 'open', 'close')
    combined = {column: [] for column in columns}

    for partial in dict_list:
        for column, values in combined.items():
            # In-place list extension; the input lists are left untouched.
            values += partial[column]

    return combined

def process_stock_histories(stock_histories, num_processes=None):
    """
    Turn a dictionary of stock histories into a HuggingFace Dataset, processing
    the stocks in a multiprocessing pool.

    Args:
        stock_histories: Dictionary mapping each stock to its history
        num_processes: Size of the worker pool; defaults to the CPU count when None

    Returns:
        Dataset: HuggingFace Dataset built from the merged per-stock columns
    """
    worker_count = cpu_count() if num_processes is None else num_processes

    # One (stock, history) work item per stock
    work_items = list(stock_histories.items())

    with Pool(processes=worker_count) as pool:
        # imap keeps input order and lets tqdm report progress as results arrive
        lazy_results = pool.imap(process_single_stock, work_items)
        progress = tqdm(lazy_results, total=len(work_items), desc="Processing stocks")
        per_stock_columns = list(progress)

    return Dataset.from_dict(merge_dictionaries(per_stock_columns))
# Entry point of this cell: read back the histories pickled by the download cell,
# flatten them into a Dataset and publish it under HF_USER_NAME.
# NOTE: pickle.load executes arbitrary code from the file; only load a
# stock_histories.pkl that you produced yourself.
if __name__ == "__main__":
    with open("stock_histories.pkl", "rb") as f:
        stock_histories = pickle.load(f)
    stock_ds = process_stock_histories(stock_histories)
    stock_ds.push_to_hub(f"{HF_USER_NAME}/deepstock-stock-historical-prices-dataset-processed")


In [None]:
%%writefile create_deepstock_dataset.py
from pydantic import BaseModel
from datetime import date, timedelta
from typing import List, Dict, Any, Tuple, Optional
import abc
from datasets import load_dataset, Dataset
import yfinance as yf
from datetime import datetime, timedelta
import pandas as pd
from tqdm.auto import tqdm
import pickle
import json
from pprint import pprint
import os
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field, HttpUrl
from enum import Enum
import requests
from typing import Dict, Any
import time
from datetime import date
import holidays
import modal
import pickle
import os
# Directory under which the pickled news and price caches are read and written.
CACHE_VOLUME="/cache"
# US holiday calendar.
# NOTE(review): not referenced anywhere else in the visible script.
us_holidays = holidays.US()
"""
REMEMBER TO SET NEWSAPI_SECRET in your environment variables!
"""
# Date window (YYYY-MM-DD) for which company snapshots are generated; it is
# passed to get_business_days and to CompanyInfoCreator below.
OLDEST_POSSIBLE_DATE = "2021-06-30"
NEWEST_POSSIBLE_DATE = "2023-12-31"

def get_business_days(start_date="2021-06-30", end_date=None):
    """
    List the business days from start_date through end_date, both inclusive.

    Days are produced with pandas' 'B' (business day) frequency.

    Args:
        start_date (str): First day of the range, 'YYYY-MM-DD'
        end_date (str, optional): Last day of the range, 'YYYY-MM-DD'.
            Defaults to today's date.

    Returns:
        list: Business days in ascending order, as produced by
        ``pd.date_range(...).tolist()``
    """
    last_day = datetime.now().strftime("%Y-%m-%d") if end_date is None else end_date

    weekdays = pd.date_range(
        start=pd.to_datetime(start_date),
        end=pd.to_datetime(last_day),
        freq='B',
    )
    return weekdays.tolist()



class FinancialsT(BaseModel):
    """One financial statement selected for a stock."""
    # Date of the statement column that was selected.
    year: date
    # The statement's values serialised as a JSON object string.
    financials: str


class PriceT(BaseModel):
    """Open/close prices for a target day and for an earlier reference day."""
    # Prices on price_date.
    open: float
    close: float
    price_date: date
    # Prices on previous_date, the earlier reference day.
    open_previous: float
    close_previous: float
    previous_date: date


class CompanyInfoT(BaseModel):
    """Static company profile: display name and business description."""
    name: str
    description: str


class NewsT(BaseModel):
    """News headlines collected for a stock over a look-back window."""
    # Headline titles gathered over the window.
    news_headlines: List[str]
    # First day of the look-back window the headlines were gathered from.
    news_date : date

class CompanyInfoAtDate(BaseModel):
    """Everything gathered about one ticker for one day."""
    ticker: str
    # Day this snapshot describes.
    current_date: date

    # Component records, each fetched by its own database class.
    company_info: CompanyInfoT
    news: NewsT
    financials: FinancialsT
    price: PriceT


class AbstractCompanyInfoCreator(abc.ABC):
    """Interface for objects that assemble a CompanyInfoAtDate snapshot.

    Deriving from abc.ABC is what makes @abc.abstractmethod take effect: without
    it the decorator is inert and a subclass that forgets to implement
    fetch_company_info would silently return None.
    """

    @abc.abstractmethod
    def fetch_company_info(self, ticker: str, current_date: date) -> CompanyInfoAtDate:
        """Return the snapshot for ``ticker`` on ``current_date``."""
        pass
def format_datetime(newsdate : date):
    """Render a date as the 'YYYY-MM-DD' string used for cache keys."""
    iso_day_pattern = "%Y-%m-%d"
    return newsdate.strftime(iso_day_pattern)


class NewsDatabase():
    """Per-day, per-stock news headlines backed by a pickle cache on disk.

    ``self.cache`` maps a 'YYYY-MM-DD' string to ``{stock: [headline, ...]}``.
    The cache is read from ``CACHE_VOLUME/news_cache_7.pkl`` when that file
    exists; otherwise it is built from the news dataset and saved there.
    """

    def __init__(self, start_date: date, end_date: date):
        """Load the cache, building it for [start_date, end_date] if absent."""
        self.ds = None
        self.cache = {}
        self.cache_file = os.path.join(CACHE_VOLUME,"news_cache_7.pkl")
        self._load_cache(start_date, end_date)


    def _load_cache(self, start_date: date, end_date: date):
        """Load cache from disk if it exists, otherwise build and save it."""
        try:
            # NOTE: pickle.load runs arbitrary code from the file; the cache file
            # must come from a trusted source.
            with open(self.cache_file, 'rb') as f:
                self.cache = pickle.load(f)
        except (FileNotFoundError):
            self.cache = {}
            self.preprocess_date_range(start_date, end_date)

    def _save_cache(self):
        """Save cache to disk, creating the cache directory when needed."""
        cache_dir = os.path.dirname(self.cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(self.cache_file, 'wb') as f:
            pickle.dump(self.cache, f)

    def _ensure_dataset(self):
        """Load the news dataset into ``self.ds`` on first use and return it."""
        if self.ds is None:
            self.ds = load_dataset(
                "2084Collective/FNSPID_IMPROVED", split="train"
            ).to_pandas()
        return self.ds

    def preprocess_date_range(self, start_date: date, end_date: date):
        """
        Preprocess and cache news headlines for all stocks between start_date and end_date.

        Args:
            start_date (date): Start date for preprocessing
            end_date (date): End date for preprocessing
        """
        news = self._ensure_dataset()

        # Convert dates to string format for comparison with dataset
        start_str = format_datetime(start_date)
        end_str = format_datetime(end_date)

        # Filter dataset for date range
        date_filtered = news[
            (news["date"] >= start_str) &
            (news["date"] <= end_str)
        ]

        # Group by date and stock
        grouped = date_filtered.groupby(["date", "stock"])["title"].apply(list).to_dict()

        # Cache keys keep only the first 10 characters of the date string. If the
        # dataset's dates carry a time component, several groups collapse onto the
        # same (day, stock); merge them instead of letting the last one overwrite
        # the others.
        headlines_by_day = {}
        for (date_str, stock), headlines in tqdm(grouped.items()):
            headlines_by_day.setdefault((date_str[:10], stock), []).extend(headlines)

        # Update cache
        for (day_key, stock), headlines in headlines_by_day.items():
            self.cache.setdefault(day_key, {})[stock] = headlines
        # Save cache to disk
        self._save_cache()

    def fetch_news_for_date(self, newsdate: date, stock: str, company_info: CompanyInfoT) -> NewsT:
        """
        Collect the cached headlines of the seven days before ``newsdate``.

        The day ``newsdate`` itself is not included. Days or stocks missing from
        the cache simply contribute no headlines.

        Args:
            newsdate (date): Day the headlines should lead up to
            stock (str): Stock symbol
            company_info (CompanyInfoT): Company information (currently unused)

        Returns:
            NewsT: The headlines, with ``news_date`` set to the first day of the window
        """
        seven_days_ago = newsdate - timedelta(days=7)
        headlines = []

        for offset in range(7):
            date_str = format_datetime(seven_days_ago + timedelta(days=offset))
            headlines.extend(self.cache.get(date_str, {}).get(stock, []))

        return NewsT(news_headlines=headlines, news_date=seven_days_ago)

    def summary(self) -> dict:
        """Return the date range and stock count of the underlying dataset.

        The dataset is loaded on demand, because ``self.ds`` is still None when
        the cache was read from disk.
        """
        news = self._ensure_dataset()
        return {
            "max_date": news["date"].max(),
            "min_date": news["date"].min(),
            "stock_count": news["stock"].nunique(),
        }



class PriceOpenPriceCloseDatabase:
    """Daily open/close prices backed by a pickle cache, with a yfinance fallback.

    ``self.cache`` maps a 'YYYY-MM-DD' string to
    ``{stock: {'open': ..., 'close': ...}}``.
    """
    CACHE_FILE = os.path.join(CACHE_VOLUME,"price_cache.pkl")
    # How many extra calendar days fetch_open_close_for_date may step back beyond
    # the initial 7 while looking for a day with price data.
    MAX_LOOKBACK_DAYS = 30

    def __init__(self):
        self.ds = None
        self.cache = self._load_or_create_cache()

    def _load_or_create_cache(self):
        """Return the cache from CACHE_FILE, building and saving it if absent."""
        if os.path.exists(self.CACHE_FILE):
            # Load existing cache
            # NOTE: pickle.load runs arbitrary code from the file; the cache file
            # must come from a trusted source.
            with open(self.CACHE_FILE, 'rb') as f:
                return pickle.load(f)

        # Create and save new cache
        cache = self._preprocess_data()
        cache_dir = os.path.dirname(self.CACHE_FILE)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(self.CACHE_FILE, 'wb') as f:
            pickle.dump(cache, f)
        return cache

    def _preprocess_data(self):
        """Build the date -> stock -> prices mapping from the processed price dataset."""
        from collections import defaultdict
        self.ds = load_dataset(
            "2084Collective/deepstock-stock-historical-prices-dataset-processed",
            split="train",
        ).to_pandas()
        print("Creating new cache...")
        cache = defaultdict(dict)

        frame = self.ds
        # Walk the four columns in lockstep instead of materialising a Series per
        # row, and size the progress bar from the data rather than a fixed guess.
        rows = zip(frame['date'], frame['stock'], frame['open'], frame['close'])
        for date_str, stock, open_price, close_price in tqdm(rows, total=len(frame)):
            cache[date_str][stock] = {
                'open': open_price,
                'close': close_price
            }

        print("Cache creation complete")
        return cache

    def fetch_open_close_for_date(self, price_date: date, stock: str) -> PriceT:
        """Return prices for ``price_date`` plus a reference day about a week earlier.

        The reference day starts 7 days before ``price_date`` and steps back one
        day at a time until a day with data is found.

        Raises:
            AssertionError: No price data for ``stock`` on ``price_date``.
            LookupError: No reference day with data within MAX_LOOKBACK_DAYS
                additional days (previously this searched backwards forever).
        """
        seven_days_ago = price_date - timedelta(days=7)
        current_data = self.get_stock_price(stock, price_date)
        assert current_data is not None, f"Could not fetch data for {stock} on {price_date}"

        for _ in range(self.MAX_LOOKBACK_DAYS + 1):
            seven_days_ago_data = self.get_stock_price(stock, seven_days_ago)
            if seven_days_ago_data is not None:
                break
            seven_days_ago -= timedelta(days=1)
        else:
            raise LookupError(
                f"No earlier price data for {stock} within "
                f"{7 + self.MAX_LOOKBACK_DAYS} days before {price_date}"
            )

        return PriceT(
            open=current_data['open'],
            close=current_data['close'],
            price_date=price_date,
            open_previous=seven_days_ago_data['open'],
            close_previous=seven_days_ago_data['close'],
            previous_date=seven_days_ago
        )

    def get_stock_price(self, stock: str, pricedate: date) -> Optional[Dict[str, float]]:
        """Return ``{'open', 'close'}`` for the stock on that day, or None.

        Looks in the cache first; on a miss it queries yfinance for that single
        day and caches a successful result. Any failure of the fallback is
        printed and reported as None.
        """
        date_str = format_datetime(pricedate)
        cached = self.cache.get(date_str, {}).get(stock)
        if cached is not None:
            return cached

        try:
            stock_data = yf.Ticker(stock).history(start=pricedate, end=pricedate+timedelta(days=1))
            print(f"Fetching {stock} data for {date_str}", stock_data)

            quote = {
                'open': stock_data['Open'].iloc[0],
                'close': stock_data['Close'].iloc[0]
            }
            # setdefault works whether the loaded cache is a defaultdict or a plain dict.
            self.cache.setdefault(date_str, {})[stock] = quote
            return quote
        except Exception as e:
            print(f"Error fetching {stock} data for {date_str}: {e}")
            return None



class FinancialsDatabase:
    """Financial statements per stock, fetched from yfinance once and memoised."""

    def __init__(self):
        # stock -> statements frame as returned by yf.Ticker(stock).financials
        self.financials_cache = {}

    def fetch_financials_for_date(self, stock_date: date, stock: str) -> FinancialsT:
        """Pick the statement to associate with ``stock_date``.

        The earliest statement dated strictly after ``stock_date`` is used; when
        there is none, the most recent statement is used instead. Previously the
        earliest statement was never considered, and a ``stock_date`` equal to
        the latest statement date (or a single available statement) left no
        statement selected and crashed.

        NOTE(review): the selected statement is dated after ``stock_date``, so it
        may contain figures that were not yet public on that day — confirm this
        is intended.

        Raises:
            IndexError: The stock has no financial statements at all.
        """
        if stock not in self.financials_cache:
            self.financials_cache[stock] = yf.Ticker(stock).financials
        statements = self.financials_cache[stock]

        sorted_dates = sorted(column.date() for column in statements.columns)
        if not sorted_dates:
            raise IndexError(f"No financial statements available for {stock}")

        right_date = next(
            (statement_date for statement_date in sorted_dates if statement_date > stock_date),
            sorted_dates[-1],
        )
        return FinancialsT(
            financials=json.dumps(statements[
                right_date.strftime("%Y-%m-%d")
            ].to_dict()),
            year=right_date,
        )


class CompanyInfoDatabase:
    """Company profiles from yfinance, fetched once per stock and memoised."""

    def __init__(self):
        # stock -> the info mapping returned by yf.Ticker(stock).info
        self.company_info_cache = {}

    def fetch_company_info(self, stock: str) -> CompanyInfoT:
        """Return the stock's short name and long business summary."""
        if stock in self.company_info_cache:
            profile = self.company_info_cache[stock]
        else:
            profile = yf.Ticker(stock).info
            self.company_info_cache[stock] = profile
        return CompanyInfoT(
            name=profile["shortName"],
            description=profile["longBusinessSummary"],
        )


class CompanyInfoCreator(AbstractCompanyInfoCreator):
    """Assembles CompanyInfoAtDate snapshots from the four backing databases."""

    def __init__(self, earliest_date: date, latest_date: date):
        # The news cache is built for the requested window widened by ten days
        # on each side.
        padding = timedelta(days=10)
        self.news_db = NewsDatabase(earliest_date - padding, latest_date + padding)
        self.price_db = PriceOpenPriceCloseDatabase()
        self.financials_db = FinancialsDatabase()
        self.company_info_db = CompanyInfoDatabase()

    def fetch_company_info(self, ticker: str, current_date: date) -> CompanyInfoAtDate:
        """Fetch profile, news, financials and prices for the ticker on that day.

        The lookups run in that order; an exception from any of them propagates
        to the caller.
        """
        profile = self.company_info_db.fetch_company_info(ticker)
        headlines = self.news_db.fetch_news_for_date(current_date, ticker, profile)
        statements = self.financials_db.fetch_financials_for_date(current_date, ticker)
        quotes = self.price_db.fetch_open_close_for_date(current_date, ticker)
        return CompanyInfoAtDate(
            ticker=ticker,
            current_date=current_date,
            company_info=profile,
            news=headlines,
            financials=statements,
            price=quotes,
        )
def get_sp500_tickers() -> List[str]:
    """Return the 'Symbol' column of the first table on Wikipedia's S&P 500 list page."""
    tables = pd.read_html("https://en.wikipedia.org/wiki/List_of_S%26P_500_companies")
    constituents = tables[0]
    return constituents["Symbol"].tolist()
def dump_company(company_info: Optional[CompanyInfoAtDate]) -> dict:
    """Convert a snapshot to a plain dict via model_dump(); None stays None."""
    return None if company_info is None else company_info.model_dump()

def process_single_stock(data):
    """Build snapshots for parallel 'ticker' / 'day' columns.

    Args:
        data: Mapping with equally long 'ticker' and 'day' sequences; each day
            must support ``.date()``.

    Returns:
        dict: {"company_info": [...]} with one dumped snapshot per input pair,
        or None where fetching that pair raised (the error is printed).
    """
    window_start = datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date()
    window_end = datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date()
    creator = CompanyInfoCreator(window_start, window_end)

    snapshots : List[CompanyInfoAtDate] = []
    for ticker, day in zip(data['ticker'], data['day']):
        try:
            snapshot = creator.fetch_company_info(ticker, day.date())
        except Exception as e:
            print(e)
            print(ticker, day)
            snapshot = None
        snapshots.append(snapshot)
    return {"company_info": [dump_company(snapshot) for snapshot in snapshots]}

# When this file is executed directly and either pickle is missing from the
# current working directory, instantiate both databases so that they build
# their caches.
# NOTE(review): this guard checks for the pickles in the working directory, while
# NewsDatabase and PriceOpenPriceCloseDatabase read and write them under
# CACHE_VOLUME — confirm the two locations are meant to coincide.
if __name__ == "__main__" and not (os.path.exists("price_cache.pkl") and os.path.exists("news_cache_7.pkl")):
    NewsDatabase(datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date(), datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date())
    PriceOpenPriceCloseDatabase()

# Container image for the Modal functions: Debian slim with Python 3.11, git,
# the Python dependencies of this script, and a /cache directory holding the two
# locally built cache pickles.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git")
    .pip_install("torch==2.2.1")
    .pip_install([
        "datasets",
        "yfinance",
        "pandas",
        "requests",
        "pydantic",
        "tqdm",
        "holidays",
        "modal",
        "numpy",
        "transformers",
        "huggingface_hub",
    ])
    .run_commands("mkdir /cache")
    # Both pickles are taken from paths relative to where the script is run.
    .add_local_file("price_cache.pkl", remote_path="/cache/price_cache.pkl")
    .add_local_file("news_cache_7.pkl", remote_path="/cache/news_cache_7.pkl")
)

# Modal app whose functions all run in the image defined above.
app = modal.App(name="deepstock", image=image)

@app.function(timeout=2000)
def get_company_info(ticker: str) -> Tuple[List[CompanyInfoAtDate], str]:
    """Build one snapshot per business day of the configured window for a ticker.

    Returns:
        The list of snapshots (None for days whose fetch raised; the error is
        printed) together with the ticker itself.
    """
    window_start = datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date()
    window_end = datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date()
    creator = CompanyInfoCreator(window_start, window_end)

    snapshots : List[CompanyInfoAtDate] = []
    for day in get_business_days(OLDEST_POSSIBLE_DATE, NEWEST_POSSIBLE_DATE):
        try:
            snapshot = creator.fetch_company_info(ticker, day.date())
        except Exception as e:
            print(e)
            print(ticker, day.date())
            snapshot = None
        snapshots.append(snapshot)
    return snapshots, ticker

@app.local_entrypoint()
def main():
    """Collect snapshots for the S&P 500 tickers and publish them as a dataset.

    Runs get_company_info remotely for every ticker, pickles the raw results to
    company_info.pkl, drops the missing (None) data points, and pushes the rest
    to the Hub.
    """
    # Tickers the original script leaves out. Filtering (rather than
    # list.remove) keeps this working when one of them is no longer in the
    # scraped table, where remove() would raise ValueError.
    excluded_tickers = {"KVUE", "CEG", "VLTO", "GEHC"}
    tickers = [
        ticker for ticker in get_sp500_tickers()
        if ticker not in excluded_tickers
    ]

    company_info_info = {}
    for result in get_company_info.map(tickers):
        company_info_dates, ticker = result
        company_info_info[ticker] = company_info_dates
    with open("company_info.pkl", "wb") as f:
        pickle.dump(company_info_info, f)

    dataset = []
    count_none = 0
    count_total = 0
    for ticker, company_info_dates in company_info_info.items():
        for company_info in company_info_dates:
            count_total += 1
            if company_info is None:
                count_none += 1
                continue
            dataset.append({
                "ticker": ticker,
                "company_info": company_info.model_dump()
            })
    print(f"Total number of data points: {count_total}")
    print(f"Number of missing data points: {count_none}")
    dataset = Dataset.from_list(dataset)
    dataset.push_to_hub("2084Collective/deepstock-sp500-companies-with-info")


In [None]:
!modal setup
!modal run create_deepstock_dataset.py

In [None]:
import re
import json
import math
from datasets import load_dataset
EXAMPLE_COMPANY_INFO = {
    "company_info": {
        "description": "3M Company provides diversified technology services in the United States and internationally. The company's Safety and Industrial segment offers industrial abrasives and finishing for metalworking applications; autobody repair solutions; closure systems for personal hygiene products, masking, and packaging materials; electrical products and materials for construction and maintenance, power distribution, and electrical original equipment manufacturers; structural adhesives and tapes; respiratory, hearing, eye, and fall protection solutions; and natural and color-coated mineral granules for shingles. Its Transportation and Electronics segment provides ceramic solutions; attachment/bonding products, films, sound, and temperature management for transportation vehicles; premium large format graphic films for advertising and fleet signage; light management films and electronics assembly solutions; packaging and interconnection solutions; semiconductor production materials; data centers solutions; and reflective signage for highway, and vehicle safety. The company's Consumer segment provides consumer bandages, braces, supports, and consumer respirators; home cleaning products; retail abrasives, paint accessories, car care DIY products, picture hanging, and consumer air quality solutions; and stationery products. It offers its products through e-commerce and traditional wholesalers, retailers, jobbers, distributors, and dealers. 3M Company was founded in 1902 and is headquartered in Saint Paul, Minnesota.",
        "name": "3M Company"
    },
    "current_date": "2021-06-30",
    "financials": {
        "financials": "{\"Tax Effect Of Unusual Items\": 0.0, \"Tax Rate For Calcs\": 0.178, \"Normalized EBITDA\": 9607000000.0, \"Total Unusual Items\": 0.0, \"Total Unusual Items Excluding Goodwill\": 0.0, \"Net Income From Continuing Operation Net Minority Interest\": 5921000000.0, \"Reconciled Depreciation\": 1915000000.0, \"Reconciled Cost Of Revenue\": 18795000000.0, \"EBITDA\": 9607000000.0, \"EBIT\": 7692000000.0, \"Net Interest Income\": -462000000.0, \"Interest Expense\": 488000000.0, \"Interest Income\": 26000000.0, \"Normalized Income\": 5921000000.0, \"Net Income From Continuing And Discontinued Operation\": 5921000000.0, \"Total Expenses\": 27986000000.0, \"Total Operating Income As Reported\": 7369000000.0, \"Diluted Average Shares\": 585300000.0, \"Basic Average Shares\": 579000000.0, \"Diluted EPS\": 10.12, \"Basic EPS\": 10.23, \"Diluted NI Availto Com Stockholders\": 5921000000.0, \"Net Income Common Stockholders\": 5921000000.0, \"Net Income\": 5921000000.0, \"Minority Interests\": -8000000.0, \"Net Income Including Noncontrolling Interests\": 5929000000.0, \"Net Income Continuous Operations\": 5929000000.0, \"Earnings From Equity Interest Net Of Tax\": 10000000.0, \"Tax Provision\": 1285000000.0, \"Pretax Income\": 7204000000.0, \"Other Income Expense\": 297000000.0, \"Other Non Operating Income Expenses\": 297000000.0, \"Special Income Charges\": 0.0, \"Gain On Sale Of Business\": 0.0, \"Impairment Of Capital Assets\": 0.0, \"Net Non Operating Interest Income Expense\": -462000000.0, \"Interest Expense Non Operating\": 488000000.0, \"Interest Income Non Operating\": 26000000.0, \"Operating Income\": 7369000000.0, \"Operating Expense\": 9191000000.0, \"Research And Development\": 1994000000.0, \"Selling General And Administration\": 7197000000.0, \"General And Administrative Expense\": 7197000000.0, \"Other Gand A\": 7197000000.0, \"Salaries And Wages\": -297000000.0, \"Gross Profit\": 16560000000.0, \"Cost Of Revenue\": 18795000000.0, \"Total 
Revenue\": 35355000000.0, \"Operating Revenue\": 35355000000.0}",
        "year": "2021-12-31"
    },
    "news": {
        "news_date": "2021-06-23",
        "news_headlines": [
            "Dow Movers: MMM, CVX",
            "2 Stocks I'm Never Selling",
            "Better Buy: GE vs. 3M",
            "C3.ai Is Down More Than 60% From Its Peak. Here's What Happened",
            "Which Industrial Stocks Are Better Bets Compared To Johnson Controls?",
            "Have Insiders Been Selling 3M Company (NYSE:MMM) Shares?"
        ]
    },
    "price": {
        "close": 142.9438018798828,
        "close_previous": 138.43162536621094,
        "open": 140.97916179854076,
        "open_previous": 140.18756548689464,
        "previous_date": "2021-06-23",
        "price_date": "2021-06-30"
    },
    "ticker": "MMM"
}


def company_info_to_user_message(company_info):
    """Render one company snapshot as the user prompt for the buy/sell task.

    Args:
        company_info: Snapshot dict with 'company_info' (name, description),
            'price' (open, close_previous, previous_date, price_date), 'news'
            (news_date, news_headlines) and 'financials' (a JSON object string
            under 'financials').

    Returns:
        str: The prompt text. The [Financials] section lists whichever of
        Basic EPS, Normalized EBITDA and Net Income have a usable numeric value,
        and is omitted when none does.
    """
    price = company_info['price']
    prompt = ""
    prompt += f"""
You are a seasoned stock market analyst who is trying to predict whether the prices will go down or up over the day, {price['price_date']},  for a specific stock, by offering a buy or sell rating.
"""
    prompt += f"""
[Company Name]
{company_info['company_info']['name']}
"""
    prompt += f"""
[Company Description]
{company_info['company_info']['description']}
"""
    # The price quoted for previous_date must be that day's close. Using the
    # price_date close here would put the value being predicted into the prompt.
    prompt += f"""
[Price Movement]
It was {price['close_previous']} on {price['previous_date']}.
The price of the stock on {price['price_date']} started at {price['open']}.
"""
    news = '\n'.join(company_info['news']['news_headlines'])
    prompt += f"""
[News since {company_info['news']['news_date']}]
{news}
"""
    financials = json.loads(company_info['financials']['financials'])
    financials_keys = ["Basic EPS", "Normalized EBITDA", "Net Income"]
    financial_lines = []
    for key in financials_keys:
        value = financials.get(key)
        # Skip entries that are absent, null or NaN instead of crashing on them.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        if math.isnan(value):
            continue
        # Very large or very small magnitudes are shown in scientific notation.
        # Zero is checked first because log(0) is undefined.
        if value != 0 and abs(math.log(abs(value))) > 3:
            financial_lines.append(f"""{key}: ${value:e}\n""")
        else:
            financial_lines.append(f"""{key}: ${value}\n""")
    if financial_lines:
        prompt += """\n[Financials]\n"""
        prompt += "".join(financial_lines)
    prompt += f"""
Your answer should look like the following
<think>reasoning about why the stock would go up or down here for example
- Recent news highlights insider selling, which could signal low confidence.
- EPS is strong, but EBITDA has dipped slightly.
- The stock has been trending downward for the past week.
</think><answer>sell</answer>
Please reason about and provide several reasons for why you think the stock would go up or down in the <think></think> tags. Please provide your answer as a single rating, 'buy' or 'sell', in the <answer></answer> tags, with buy meaning that the stock price will go up,
and sell meaning that the stock price will go down.
"""
    return prompt



# Run the tests
def format_prompt(example):
  """Produce the `user_prompt` column for one dataset row.

  Args:
    example: A row exposing a 'company_info' entry, which is handed to
      `company_info_to_user_message`.

  Returns:
    dict: A single-key mapping, 'user_prompt' -> the rendered prompt text.
  """
  info = example['company_info']
  rendered = company_info_to_user_message(info)
  return dict(user_prompt=rendered)

# Preview the prompt rendered for EXAMPLE_COMPANY_INFO before processing the dataset.
print(company_info_to_user_message(EXAMPLE_COMPANY_INFO))
# Load the company-info dataset, add a `user_prompt` column to every row via
# format_prompt, and push the result to the Hub under the *_buy_sell name.
dataset = load_dataset("2084Collective/deepstock-sp500-companies-with-info", split="train")
formatted_dataset = dataset.map(format_prompt)
formatted_dataset.push_to_hub("2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell")



You are a seasoned stock market analyst who is trying to predict whether the prices will go down or up over the day, 2021-06-30,  for a specific stock, by offering a buy or sell rating.

[Company Name]
3M Company

[Company Description]
3M Company provides diversified technology services in the United States and internationally. The company's Safety and Industrial segment offers industrial abrasives and finishing for metalworking applications; autobody repair solutions; closure systems for personal hygiene products, masking, and packaging materials; electrical products and materials for construction and maintenance, power distribution, and electrical original equipment manufacturers; structural adhesives and tapes; respiratory, hearing, eye, and fall protection solutions; and natural and color-coated mineral granules for shingles. Its Transportation and Electronics segment provides ceramic solutions; attachment/bonding products, films, sound, and temperature management for transportati

Map:   0%|          | 0/305860 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/5 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell/commit/a532f3239eed7c3149893df3d1d6ce0a7679e8e4', commit_message='Upload dataset', commit_description='', oid='a532f3239eed7c3149893df3d1d6ce0a7679e8e4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell', endpoint='https://huggingface.co', repo_type='dataset', repo_id='2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell'), pr_revision=None, pr_num=None)

In [None]:
"""
This is where the DeepSeek RL part starts. Everything after this point can be run separately; its only dependency on the cells above is the dataset 2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell,
which is produced by running them.
"""

In [None]:
!pip install vllm==0.6.6.post1
!git clone https://github.com/huggingface/open-r1.git && cd open-r1/ && pip install -e ".[dev]"

Collecting vllm==0.6.6.post1
  Downloading vllm-0.6.6.post1-cp38-abi3-manylinux1_x86_64.whl.metadata (11 kB)
Collecting blake3 (from vllm==0.6.6.post1)
  Downloading blake3-1.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting fastapi!=0.113.*,!=0.114.0,>=0.107.0 (from vllm==0.6.6.post1)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn[standard] (from vllm==0.6.6.post1)
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting prometheus-fastapi-instrumentator>=7.0.0 (from vllm==0.6.6.post1)
  Downloading prometheus_fastapi_instrumentator-7.0.2-py3-none-any.whl.metadata (13 kB)
Collecting tiktoken>=0.6.0 (from vllm==0.6.6.post1)
  Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting lm-format-enforcer<0.11,>=0.10.9 (from vllm==0.6.6.post1)
  Downloading lm_format_enforcer-0.10.9-py3-none-any.whl.metadata (17 kB)
Collecting outlines==0.

Cloning into 'open-r1'...
remote: Enumerating objects: 442, done.[K
remote: Counting objects: 100% (177/177), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 442 (delta 148), reused 86 (delta 86), pack-reused 265 (from 1)[K
Receiving objects: 100% (442/442), 484.57 KiB | 2.20 MiB/s, done.
Resolving deltas: 100% (228/228), done.
Obtaining file:///content/open-r1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers@ git+https://github.com/huggingface/transformers.git@main (from open-r1==0.1.0.dev0)
  Cloning https://github.com/huggingface/transformers.git (to revision main) to /tmp/pip-install-rn090bx4/transformers_91259089fe664f5aa19c4f694c6274b3
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-install-rn090bx4/transformers_91259089fe664f5aa19c4f694c6274b3
  Resolved https://github.com/huggingface/transformers.git to commit 7eecdf2a8650306ed5fbb6150c64f99f587e004d
  Instal

In [None]:
import re
import json
import math
from datasets import load_dataset
def accuracy_reward(completions, company_info, **kwargs):
    """
    Reward function that checks if the completion correctly predicted price movement.

    A completion earns 1.0 when it follows the
    `<think>...</think><answer>buy|sell</answer>` format and its answer agrees
    with the actual move: "buy" when close > open, otherwise "sell" (an
    unchanged price counts as "sell"). Anything else earns 0.0. Leading and
    trailing whitespace, whitespace around the answer, and letter case are
    ignored.

    Args:
        completions: One entry per sample; each entry is a list whose first
            element is a dict with a "content" string (the model output).
        company_info: One dict per sample, aligned with `completions`. Only
            `["price"]["open"]` and `["price"]["close"]` are read.
        **kwargs: Unused; accepted so callers may pass additional keyword
            arguments.

    Returns:
        list[float]: One reward (1.0 or 0.0) per completion.

    Raises:
        Exception: Any error raised while scoring is re-raised after the
            inputs have been printed for debugging.
    """
    # DOTALL is required: the prompt asks for several bullet-point reasons, so
    # the text between <think> and </think> normally spans multiple lines and
    # "." must be allowed to match newlines. Compiled once, outside the loop.
    pattern = re.compile(
        r"^<think>.*?</think><answer>\s*(buy|sell)\s*</answer>$",
        re.IGNORECASE | re.DOTALL,
    )

    try:
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        # `info` (not `company_info`) so the parameter is not shadowed and the
        # debug print below still shows the full batch.
        for content, info in zip(contents, company_info):
            # Prices are parsed before the format check so malformed price
            # data is reported even when the completion itself is invalid.
            close_price = float(info["price"]["close"])
            open_price = float(info["price"]["open"])
            actual_movement = "buy" if close_price > open_price else "sell"

            match = pattern.match(content.strip())
            if not match:
                rewards.append(0.0)
                continue

            prediction = match.group(1).lower()
            rewards.append(1.0 if prediction == actual_movement else 0.0)

        return rewards
    except Exception:
        print(company_info, completions)
        raise


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format and answer content.

    The expected shape is `<think>...</think><answer>buy|sell</answer>`, with
    optional whitespace around the answer and any letter case.

    Args:
        completions: One entry per sample; each entry is a list whose first
            element is a dict with a "content" string (the model output).
        **kwargs: Unused; accepted so callers may pass additional keyword
            arguments.

    Returns:
        list[float]: 1.0 for every completion that matches the format, else 0.0.
    """
    # DOTALL lets the reasoning inside <think> span several lines, which is
    # what the prompt asks the model to produce.
    pattern = re.compile(
        r"^<think>.*?</think><answer>\s*(buy|sell)\s*</answer>$",
        re.IGNORECASE | re.DOTALL,
    )
    rewards = []
    for completion in completions:
        # Strip surrounding whitespace so this check agrees with accuracy_reward,
        # which scores the stripped content.
        content = completion[0]["content"].strip()
        rewards.append(1.0 if pattern.match(content) else 0.0)
    return rewards
def test_reward_functions():
    """Sanity-check `format_reward` and `accuracy_reward` on hand-written completions.

    The fixtures use the 'buy' / 'sell' labels that the reward functions'
    regular expression accepts.

    Raises:
        AssertionError: If either reward function returns unexpected values.
    """
    # Sample completions in various formats
    completions = [
        [{"content": "<think>price went upprice went upprice went upprice went upprice went upprice went upprice went upprice went upprice went upprice went upprice went upprice went up</think><answer>buy</answer>"}],  # Correct format, "buy"
        [{"content": "<think>price dropped</think><answer>  SELL  </answer>"}],  # Correct format with whitespace
        [{"content": "<think>analysis</think><answer>sideways</answer>"}],  # Wrong answer
        [{"content": "just saying buy"}],  # Wrong format
        [{"content": "<think>going up</think><answer> Buy </answer>"}],  # Correct format with mixed case
    ]

    # Sample company data with price going up (close > open => "buy")
    company_info_up = [{
        "price": {
            "open": 100.0,
            "close": 110.0
        }
    }]*len(completions)

    # Sample company data with price going down (close < open => "sell")
    company_info_down = [{
        "price": {
            "open": 110.0,
            "close": 100.0
        }
    }]*len(completions)

    # Test format reward
    format_results = format_reward(completions)
    expected_format = [1.0, 1.0, 0.0, 0.0, 1.0]
    assert format_results == expected_format, f"Format reward test failed. Got {format_results}, expected {expected_format}"

    # Test accuracy reward with upward movement
    accuracy_results_up = accuracy_reward(completions, company_info_up)
    expected_accuracy_up = [1.0, 0.0, 0.0, 0.0, 1.0]
    assert accuracy_results_up == expected_accuracy_up, f"Accuracy reward (up) test failed. Got {accuracy_results_up}, expected {expected_accuracy_up}"

    # Test accuracy reward with downward movement
    accuracy_results_down = accuracy_reward(completions, company_info_down)
    expected_accuracy_down = [0.0, 1.0, 0.0, 0.0, 0.0]
    assert accuracy_results_down == expected_accuracy_down, f"Accuracy reward (down) test failed. Got {accuracy_results_down}, expected {expected_accuracy_down}"

    print("All tests passed!")
test_reward_functions()



AssertionError: Format reward test failed. Got [0.0, 0.0, 0.0, 0.0, 0.0], expected [1.0, 1.0, 0.0, 0.0, 1.0]

In [None]:
from datasets import load_dataset
import random
import numpy as np
from tqdm.auto import tqdm
def generate_random_completion():
    """Generate a random completion in the correct format.

    The answer is drawn uniformly from 'buy' / 'sell', the only labels the
    reward functions' pattern accepts.

    Returns:
        list[dict]: A single-element list holding a dict with a "content"
        string, i.e. the shape of one entry of `completions` as read by the
        reward functions.
    """
    answer = random.choice(["buy", "sell"])
    return [{
        "content": f"<think>Random guess</think><answer>{answer}</answer>"
    }]

def evaluate_random_chance(num_iterations=1):
    """
    Evaluate how well random chance performs using the accuracy_reward function.

    Args:
        num_iterations: Number of times to run the evaluation

    Returns:
        tuple: `(final_accuracy, std_accuracy)` - the mean and the standard
        deviation of the per-iteration average accuracies.
    """
    # Load the dataset
    dataset = load_dataset("2084Collective/deepstock-sp500-companies-with-info", split="train")

    # The ground truth does not change between iterations, so read it once.
    # Column access fetches the whole column in one call instead of decoding
    # every row (with all its other fields) one at a time.
    company_infos = dataset["company_info"]

    # Initialize list to store accuracies for each iteration
    iteration_accuracies = []

    # Run multiple iterations to get a stable estimate
    for iteration in range(num_iterations):
        # Generate random completions for all examples
        completions = [generate_random_completion() for _ in tqdm(range(len(dataset)))]

        # Calculate accuracy for this iteration
        rewards = accuracy_reward(completions, company_infos)
        iteration_accuracies.append(np.mean(rewards))

    # Calculate overall statistics
    final_accuracy = np.mean(iteration_accuracies)
    std_accuracy = np.std(iteration_accuracies)

    # NOTE: the interval below is built from the spread across iterations, so
    # it collapses to a single point when num_iterations == 1.
    print("Random Chance Performance:")
    print(f"Average Accuracy: {final_accuracy:.4f}")
    print(f"Standard Deviation: {std_accuracy:.4f}")
    print(f"95% Confidence Interval: [{final_accuracy - 1.96*std_accuracy:.4f}, {final_accuracy + 1.96*std_accuracy:.4f}]")

    return final_accuracy, std_accuracy

evaluate_random_chance()

  0%|          | 0/305860 [00:00<?, ?it/s]

Exception ignored in: <function tqdm.__del__ at 0x7bcf8483f6a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/tqdm/std.py", line 1147, in __del__
    def __del__(self):

KeyboardInterrupt: 


  0%|          | 0/305860 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import re
from dataclasses import dataclass, field
import json
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

# Model identifier passed to ModelConfig(model_name_or_path=...) further down.
MODEL_ID="HuggingFaceTB/SmolLM2-1.7B-Instruct"
# Alternative model, currently disabled:
# MODEL_ID="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
# Hub dataset passed as `dataset_name` to the script arguments and loaded inside main().
DATASET_ID="2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell"

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Options for the GRPO training script, layered on top of `ScriptArguments`.

    Attributes:
        reward_funcs (`list[str]`):
            Names of the reward functions to use. Possible values: 'accuracy', 'format'.
            Both are selected when nothing is specified.
    """

    reward_funcs: list[str] = field(
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
        default_factory=lambda: ["accuracy", "format"],
    )





# Lookup table from reward name to implementation; main() resolves the names
# listed in `script_args.reward_funcs` through this mapping.
reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

# System message that main() places ahead of each example's user prompt. It
# describes the <think>...</think><answer>...</answer> layout.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args : GRPOConfig, model_args):
    """
    Run GRPO fine-tuning end to end: load data, build chat prompts, train, save.

    Args:
        script_args: Supplies `reward_funcs` (keys of `reward_funcs_registry`),
            `dataset_name` and `dataset_config`.
        training_args: Training configuration handed to the trainer; also read
            here for `eval_strategy`, `output_dir` and `push_to_hub`.
        model_args: Supplies `model_name_or_path` and is passed to
            `get_peft_config`.
    """
    # Resolve the requested reward names against the registry.
    selected_rewards = []
    for reward_name in script_args.reward_funcs:
        selected_rewards.append(reward_funcs_registry[reward_name])

    raw_splits = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    def as_chat_prompt(row):
        # Two-turn conversation: the fixed system message, then the row's prompt.
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": row['user_prompt']}
        return {"prompt": [system_turn, user_turn]}

    chat_splits = raw_splits.map(as_chat_prompt)

    train_split = chat_splits['train']
    # The 'test' split is only touched when evaluation is switched on.
    if training_args.eval_strategy != "no":
        eval_split = chat_splits['test']
    else:
        eval_split = None

    grpo_trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=selected_rewards,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        peft_config=get_peft_config(model_args),
    )

    grpo_trainer.train()

    # Persist locally first; publish to the Hub only when requested.
    grpo_trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        grpo_trainer.push_to_hub(dataset_name=script_args.dataset_name)


# Training configuration for the GRPO run launched at the bottom of this cell.
# Batch size 1 with 16 accumulation steps gives a total train batch size of 16.
# NOTE(review): output_dir is named after DeepSeek-R1-Distill-Qwen-7B while
# MODEL_ID currently points at SmolLM2-1.7B-Instruct - confirm the name is intended.
# NOTE(review): max_prompt_length=256 looks small for a prompt that embeds the
# company description and news headlines - confirm how the trainer truncates.
config = GRPOConfig(
    log_level="debug",
    max_completion_length=256,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    logging_steps=1,
    max_prompt_length=256,
    output_dir="DeepSeek-R1-Distill-Qwen-7B-GRPO",
    run_name="deepstock-check",
    num_train_epochs=1,
    learning_rate=1e-4
)
# Show which device the configuration resolved to before training starts.
print(config.device)
# Assemble the remaining arguments and start training with PEFT enabled.
script_args = GRPOScriptArguments(dataset_name=DATASET_ID)
model_args = ModelConfig(model_name_or_path=MODEL_ID, use_peft=True,lora_r=22)
main(script_args, config, model_args)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

[2025-02-03 13:39:56,678] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
cuda:0


README.md:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

train-00000-of-00005.parquet:   0%|          | 0.00/28.7M [00:00<?, ?B/s]

train-00001-of-00005.parquet:   0%|          | 0.00/26.2M [00:00<?, ?B/s]

train-00002-of-00005.parquet:   0%|          | 0.00/25.0M [00:00<?, ?B/s]

train-00003-of-00005.parquet:   0%|          | 0.00/28.2M [00:00<?, ?B/s]

train-00004-of-00005.parquet:   0%|          | 0.00/27.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/305860 [00:00<?, ? examples/s]

Map:   0%|          | 0/305860 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/792 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.42G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.76k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

Using auto half precision backend
Currently training with a batch size of: 1
***** Running training *****
  Num examples = 305,860
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 16
  Total optimization steps = 19,116
  Number of trainable parameters = 4,325,376
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mlukasnel2084[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss


KeyboardInterrupt: 