In [28]:
class Recorder:
    def __init__(self):
        self.history = []

    def log(self, timestamp, position, pnl, risk):
        self.history.append({
            'timestamp': timestamp,
            'position': position,
            'pnl': pnl,
            'risk': risk
        })

    def reset(self):
        self.history = []  # Clears history between episodes

    def get_history(self):
        return pd.DataFrame(self.history)

In [None]:
my_recorder = Recorder()
my_recorder.log(
    timestamp=minute,
    position=0,
    pnl=0,
    risk=0
)

In [4]:
from dataclasses import dataclass, field
from typing import List, Dict
@dataclass
class OptionPosition:
    underlying: str
    option_type: str
    strike: float
    expiry: pd.Timestamp
    quantity: int
    price_history: pd.Series = field(default_factory=lambda: pd.Series(dtype='float64'))

    def record_price(self, timestamp, price):
        self.price_history.loc[timestamp] = price

@dataclass
class Portfolio:
    positions: List[OptionPosition] = field(default_factory=list)

    def add_position(self, position: OptionPosition):
        self.positions.append(position)

    def record_prices(self, timestamp, prices: Dict[str, float]):
        for pos in self.positions:
            key = f"{pos.underlying}_{pos.option_type}_{pos.strike}_{pos.expiry.date()}"
            if key in prices:
                pos.record_price(timestamp, prices[key])

# Usage example:
portfolio = Portfolio()
pos = OptionPosition('SPX', 'call', 4500, pd.Timestamp('2024-04-19'), 10)
portfolio.add_position(pos)

portfolio.record_prices(pd.Timestamp('2024-03-25 09:30'), {'SPX_call_4500_2024-04-19': 45.5})
portfolio.record_prices(pd.Timestamp('2024-03-25 09:31'), {'SPX_call_4500_2024-04-19': 46.0})

print(pos.price_history)

NameError: name 'pd' is not defined

In [1]:
import pandas as pd

class Portfolio:
    def __init__(self):
       self.positions = pd.DataFrame({
            'instrument_type': pd.Series(dtype='str'),
            'underlying': pd.Series(dtype='str'),
            'quantity': pd.Series(dtype='float'),
            'option_type': pd.Series(dtype='str'),
            'strike': pd.Series(dtype='float'),
            'expiry': pd.Series(dtype='datetime64[ns]')
})

    def add_position(self, instrument_type, underlying, quantity, option_type=pd.NA, strike=pd.NA, expiry=pd.NaT):
        expiry = pd.Timestamp(expiry) if expiry else None

        mask = (
            (self.positions['instrument_type'] == instrument_type) &
            (self.positions['underlying'] == underlying) &
            (self.positions['option_type'] == option_type) &
            (self.positions['strike'] == strike) &
            (self.positions['expiry'] == expiry)
        )

        if mask.any():
            self.positions.loc[mask, 'quantity'] += quantity
        else:
            if (instrument_type == 'option'):
                new_row = {
                    'instrument_type': instrument_type,
                    'underlying': underlying,
                    'quantity': quantity,
                    'option_type': option_type,
                    'strike': strike,
                    'expiry': expiry
                }
            elif (instrument_type == 'stock'):
                new_row = {
                    'instrument_type': instrument_type,
                    'underlying': underlying,
                    'quantity': quantity
                }
            #print(f"Adding new position: {new_row}")
            self.positions = pd.concat([self.positions, pd.DataFrame([new_row])], ignore_index=True)

    def add_option(self, underlying, quantity, option_type, strike, expiry):
        self.add_position('option', underlying, quantity, option_type, strike, expiry)

    def add_stock(self, underlying, quantity):
        self.add_position('stock', underlying, quantity)

    def get_positions(self):
        return self.positions

In [4]:
import pandas as pd
import datetime

class TradeLedger:
    def __init__(self):
        self.trades_list = []

    def record_trade(self, timestamp, action, instrument_type, underlying, option_type, strike, expiry, quantity, price):
        new_trade = {
            'timestamp': timestamp,
            'action': action,  # 'buy' or 'sell'
            'instrument_type': instrument_type,
            'underlying': underlying,
            'option_type': option_type,
            'strike': strike,
            'expiry': expiry,
            'quantity': quantity,
            'price': price,
            'signed_quantity': quantity if action == 'buy' else -quantity,
            'total_cost': quantity * price
        }
        self.trades_list.append(new_trade)

    @property
    def trades(self):
        #return pd.DataFrame(self.trades_list)
        return pd.DataFrame(self.trades_list).sort_values(by='timestamp').reset_index(drop=True)
        #return self.trades_list


class Portfolio:
    def __init__(self, ledger: TradeLedger):
        self.ledger = ledger
        self.positions = pd.DataFrame({
            'instrument_type': pd.Series(dtype='str'),
            'underlying': pd.Series(dtype='str'),
            'quantity': pd.Series(dtype='float'),
            'option_type': pd.Series(dtype='str'),
            'strike': pd.Series(dtype='float'),
            'expiry': pd.Series(dtype='datetime64[ns]')
        })

    def add_position(self, timestamp, instrument_type, underlying, quantity, price, option_type=None, strike=None, expiry=None):
        expiry = pd.Timestamp(expiry) if expiry else pd.NaT
        expiry = expiry.replace(hour=16, minute=17) if not pd.isna(expiry) else expiry

        mask = (
            (self.positions['instrument_type'] == instrument_type) &
            (self.positions['underlying'] == underlying) &
            (self.positions['option_type'] == option_type) &
            (self.positions['strike'] == strike) &
            (self.positions['expiry'] == expiry)
        )

        if mask.any():
            print(f"Updating existing position: {self.positions[mask]}")
            self.positions.loc[mask, 'quantity'] += quantity
        else:
            if (instrument_type == 'option'):
                new_row = {
                    'instrument_type': instrument_type,
                    'underlying': underlying,
                    'quantity': quantity,
                    'option_type': option_type,
                    'strike': strike,
                    'expiry': expiry
                }
            elif (instrument_type == 'stock'):
                new_row = {
                    'instrument_type': instrument_type,
                    'underlying': underlying,
                    'quantity': quantity
                }
            #print(f"Adding new position: {new_row}")
            self.positions = pd.concat([self.positions, pd.DataFrame([new_row])])
            #self.positions = pd.concat([self.positions, pd.DataFrame([new_row])], ignore_index=True)

        if quantity > 0:
            action = 'buy'
        elif quantity < 0:
            action = 'sell'
        else:
            action = 'none'

        if action != 'none':
            self.ledger.record_trade(timestamp, action, instrument_type, underlying, option_type, strike, expiry, quantity, price)

    def add_option(self, timestamp, underlying, quantity, option_type, strike, expiry, price):
        self.add_position(timestamp, 'option', underlying, quantity, price, option_type, strike, expiry)

    def add_stock(self, timestamp, underlying, quantity, price):
        self.add_position(timestamp, 'stock', underlying, quantity, price)

    def get_positions(self):
        return self.positions

    def get_ledger(self):
        return self.ledger.trades

In [5]:
from zoneinfo import ZoneInfo
ledger= TradeLedger()
portfolio = Portfolio(ledger=ledger)

t = pd.Timestamp(datetime.datetime.now(tz=ZoneInfo('US/Eastern')))
# Initial positions
#portfolio.add_option(t,'SPX', 10,'call', 4500, '2024-04-19',15)
#portfolio.add_stock(t,'MSFT',15, 200)

# Add more positions (existing)
#portfolio.add_option(t,'SPX', 5,'call', 4500, '2024-04-19',25)
#portfolio.add_option(t,'SPX', 5,'call', 4400, '2024-04-19',25)
print("added the new instruments")
#portfolio.add_stock(t,'MSFT', -20,200)

# Check positions
#print(portfolio.get_positions())
#print(portfolio.get_ledger())
def add_straddle_position(portfolio, underlying, quantity, strike, expiry, price,t):
    portfolio.add_option(t,underlying, quantity, 'call', strike, expiry, price/2)
    portfolio.add_option(t,underlying, quantity, 'put', strike, expiry, price/2)

expiry = pd.Timestamp("2024-04-19 16:17", tz="US/Eastern")
add_straddle_position(portfolio, 'SPY', 10, 100, expiry, 15, t)
print(portfolio.get_positions())
print(portfolio.get_ledger())
print(f"cost={portfolio.get_ledger().total_cost.sum()}")

added the new instruments
  instrument_type underlying  quantity option_type  strike  \
0          option        SPY      10.0        call   100.0   
0          option        SPY      10.0         put   100.0   

                      expiry  
0  2024-04-19 16:17:00-04:00  
0  2024-04-19 16:17:00-04:00  
                         timestamp action instrument_type underlying  \
0 2025-04-04 16:29:40.024894-04:00    buy          option        SPY   
1 2025-04-04 16:29:40.024894-04:00    buy          option        SPY   

  option_type  strike                    expiry  quantity  price  \
0        call     100 2024-04-19 16:17:00-04:00        10    7.5   
1         put     100 2024-04-19 16:17:00-04:00        10    7.5   

   signed_quantity  total_cost  
0               10        75.0  
1               10        75.0  
cost=150.0


In [7]:
start_time= pd.Timestamp("2024-04-19 9:31", tz="US/Eastern")
end_time= pd.Timestamp("2024-04-19 16:01", tz="US/Eastern")
all_minutes= pd.date_range(start=start_time, end=end_time, freq='min')
all_minutes

all_texp= (expiry - all_minutes).total_seconds()/3600/24/252

In [8]:
def apply_quadratic_volatility_model(strikes, spot, atm_vol, slope, quadratic_term, texp_years):
    """
    Apply the quadratic volatility model to new data points.
    
    Parameters:
        strikes (array-like): Array of strike prices.
        spot (float): Spot price.
        atm_vol (float): At-the-money volatility.
        slope (float): Slope of the linear term.
        quadratic_term (float): Coefficient of the quadratic term.
        texp_years (float): Time to expiration in years.
    
    Returns:
        array-like: Fitted volatilities for the given strikes.
    """
    log_strikes = np.log(strikes) - np.log(spot)
    #fitted_vols = atm_vol + (slope / np.sqrt(texp_years)) * log_strikes + quadratic_term * log_strikes**2
    fitted_vols = atm_vol + slope * log_strikes + quadratic_term * log_strikes**2
    return fitted_vols

In [9]:
import numpy as np
import signals
def price_portfolio_old(portfolio, all_times,all_spot,atm_vols,slope_param, quadratic_param):

    options_df=portfolio.get_positions()
    options_df = options_df[options_df['instrument_type'] == 'option']
    num_options=options_df.shape[0]  # Count the number of option positions

    num_times=len(all_times)
    price_results=[]
    price_df = pd.DataFrame(index=all_times)
    for row in options_df.itertuples():
        row_expiry = row.expiry
        if (row_expiry.tz is None):
            row_expiry = row.expiry.tz_localize("US/Eastern")
        all_texp=(row_expiry-all_times).total_seconds().to_numpy() / (252 * 24 * 60 * 60) 

        instrument_label = f"{row.underlying}_{row.option_type}_{row.strike}_{row.expiry.date()}"
        all_types=np.full(num_times, row.option_type[0])
        spot_vols=apply_quadratic_volatility_model(row.strike, all_spot, atm_vols, slope_param, quadratic_param, all_texp)
        all_prices=signals.price_instrument(all_types,  all_spot, row.strike,all_texp , spot_vols)
        price_results.append(all_prices)
        price_df[instrument_label]=all_prices

        #i=i+1
    #texp_vec2=np.concatenate(all_texp)
    price_results2=np.stack(price_results, axis=1)

    return price_df, price_results2

#chatGPT update
def price_portfolio(portfolio, all_times, all_spot, atm_vols, slope_param, quadratic_param):
    """
    Prices all options in a portfolio over specified timestamps.

    Args:
        portfolio (Portfolio): Portfolio object containing options.
        all_times (pd.DatetimeIndex): Timestamps at which to price options.
        all_spot (np.array or float): Spot price(s) of underlying.
        atm_vols (np.array): ATM volatilities at each timestamp.
        slope_param (np.array): Slope parameters at each timestamp.
        quadratic_param (np.array): Quadratic parameters at each timestamp.

    Returns:
        price_df (pd.DataFrame): DataFrame of option prices indexed by time.
        price_results_array (np.array): Array of computed option prices.
    """
    options_df = portfolio.get_positions()
    options_df = options_df[options_df['instrument_type'] == 'option']

    num_times = len(all_times)
    price_results = []
    price_df = pd.DataFrame(index=all_times)

    for row in options_df.itertuples():
        row_expiry = row.expiry.tz_localize("US/Eastern") if row.expiry.tz is None else row.expiry

        all_texp = (row_expiry - all_times).total_seconds().to_numpy() / (252 * 24 * 60 * 60)

        instrument_label = f"{row.underlying}_{row.option_type}_{row.strike}_{row.expiry.date()}"
        all_types = np.full(num_times, row.option_type[0])

        spot_vols = apply_quadratic_volatility_model(
            row.strike, all_spot, atm_vols, slope_param, quadratic_param, all_texp
        )

        all_prices = signals.price_instrument(
            all_types, all_spot, row.strike, all_texp, spot_vols
        )

        price_results.append(all_prices)
        price_df[instrument_label] = all_prices

    price_results_array = np.stack(price_results, axis=1)
    return price_df, price_results_array

def compute_pnl_from_precomputed(portfolio, price_df, current_time):
    pnl = 0
    positions = portfolio.get_positions()
    for pos in positions.itertuples():
        if pos.instrument_type == 'option':
            instrument_label = f"{pos.underlying}_{pos.option_type}_{pos.strike}_{pos.expiry.date()}"
        else:
            instrument_label = pos.underlying

        current_price = price_df.loc[current_time, instrument_label]
        pnl += pos.quantity * current_price
    return pnl



    
start_time= pd.Timestamp("2024-04-19 9:31", tz="US/Eastern")
end_time= pd.Timestamp("2024-04-19 16:01", tz="US/Eastern")
all_minutes= pd.date_range(start=start_time, end=end_time, freq='min')
all_minutes

for i in range(0, 1000):
    num_times=len(all_minutes)
    atm_vols=np.full(num_times, 0.2)
    slopes=np.full(num_times, 0.1)
    quadratic_terms=np.full(num_times, 0.1)
    price_df, price_results2=price_portfolio(portfolio, all_minutes, 100, atm_vols, slopes, quadratic_terms)
price_df.sum(axis=1)*550/100
price_results2
print(f'price_df={price_df}')
print(f'portfolio={portfolio.get_positions()}')
pnl=compute_pnl_from_precomputed(portfolio, price_df, all_minutes[0])- portfolio.get_ledger().total_cost.sum()

print(f"totoal cost={portfolio.get_ledger().total_cost.sum()}")
print(f"pnl={pnl}")


price_df=                           SPY_call_100.0_2024-04-19  SPY_put_100.0_2024-04-19
2024-04-19 09:31:00-04:00                   0.266883                  0.266883
2024-04-19 09:32:00-04:00                   0.266554                  0.266554
2024-04-19 09:33:00-04:00                   0.266225                  0.266225
2024-04-19 09:34:00-04:00                   0.265895                  0.265895
2024-04-19 09:35:00-04:00                   0.265565                  0.265565
...                                              ...                       ...
2024-04-19 15:57:00-04:00                   0.059234                  0.059234
2024-04-19 15:58:00-04:00                   0.057734                  0.057734
2024-04-19 15:59:00-04:00                   0.056195                  0.056195
2024-04-19 16:00:00-04:00                   0.054611                  0.054611
2024-04-19 16:01:00-04:00                   0.052981                  0.052981

[391 rows x 2 columns]
portfolio=  instrum

In [20]:
def add_straddle(portfolio, timestamp, underlying, quantity, strike, expiry, price):
    portfolio.add_option(timestamp, underlying, quantity, 'call', strike, expiry, price/2)
    portfolio.add_option(timestamp, underlying, quantity, 'put', strike, expiry, price/2)
ledger = TradeLedger()
portfolio= Portfolio(ledger=ledger)
t = pd.Timestamp(datetime.datetime.now(tz=ZoneInfo('US/Eastern')))
# Initial positions
add_straddle(portfolio, t, 'SPX', 10, 4500, '2024-04-19', 15)

In [None]:

    def price_instrument(self, cp, strike, spot, texp, vol):
        #if self.debug:
        #    print(f"cp={cp}\n, strike={strike}\n, spot={spot}\n, texp={texp}\n, vol={vol}\n")
        #print(f"pricing_insturment sizes: cp={cp}, strike={strike.shape}, spot={spot.shape}, texp={texp.shape}, vol={vol.shape}")
        return py_vollib_vectorized.models.vectorized_black_scholes(cp, spot, strike, texp, 0, vol,return_as="numpy")

In [16]:
from signals import ZeroDTESurfaceLoader, MarketData
import datetime
from zoneinfo import ZoneInfo
#import pytz
loader = ZeroDTESurfaceLoader("./algo_data/vol_surfaces2.csv","./algo_data/spy_daily_prices.csv")
data, underlying_data = loader.load_data()

print(f'main loaded data')
#print(data.head())
metadata = loader.get_metadata(data)
print(metadata)
market_data = MarketData(data, underlying_data)
#print(market_data.get_current_state("2023-09-20 10:00:00-04:00"))
#create a date time minute NY time zone
#minute = datetime.datetime(2024, 9, 20, 10,0 , 0, tzinfo=pytz.timezone('US/Eastern'))
#minute=pytz.timezone('US/Eastern').localize(datetime.datetime(2024, 9, 20, 10, 0, 0))
minute=datetime.datetime(2024, 9, 20, 10, 0, 0, tzinfo=ZoneInfo("US/Eastern"))
#print(f"current minute: {minute}")
#print(market_data.get_current_row(minute))
market_data.set_current_minute(minute)
#market_data.set_current_minute(minute)

row=market_data.get_current_row()

data validated


  df['minute'] = pd.to_datetime(df['minute'],errors="coerce")


0         2024-01-02
1         2024-01-02
2         2024-01-02
3         2024-01-02
4         2024-01-02
             ...    
112994    2025-03-03
112995    2025-03-03
112996    2025-03-03
112997    2025-03-03
112998    2025-03-03
Name: date, Length: 112999, dtype: object
main loaded data
{'start_date': Timestamp('2024-01-02 09:31:00-0500', tz='UTC-05:00'), 'end_date': Timestamp('2025-03-03 16:01:00-0500', tz='UTC-05:00'), 'num_records': 112999}


In [11]:

def _get_state(env, row):
    yest_close=row["under_close_shifted"]
    current_spot=row["implied_spot"]
    under_open=row["under_open"]
    atm_vol=row["atm_vol"]
    scaled_slope=row["scaled_slope"]
    scaled_quadratic=row["scaled_quadratic"]




    """
    state = {
        'time_remaining': env._get_state()[0],
        'current_price': row['SPY'],
        'current_vol': row['ATM Vol'],
        'current_dte': row['DTE'],
        'current_spot': row['SPY'],
        'current_strike': row['Strike'],
        'current_option_type': row['Option Type']
    }
    
    """
    state={
        "current_spot": current_spot,
        "yest_close": yest_close,
        "under_open": under_open,
        "atm_vol": atm_vol,
        "scaled_slope": scaled_slope,
        "scaled_quadratic": scaled_quadratic
    }

    array_fields=["current_spot","yest_close","under_open","atm_vol","scaled_slope","scaled_quadratic"]
    spot_scale_fields=["current_spot","under_open"]
    for field in spot_scale_fields:
        state_array=np.array([state[field] for field in array_fields])
        state[field]=state[field]/yest_close
    state_array=np.array([state[field] for field in array_fields])
    return state,state_array,array_fields


env.reset()
row=env.market_data.get_current_row()
a,a_arr,arr_fields=_get_state(env, row)
print(a)
print(a_arr)



NameError: name 'env' is not defined

In [17]:

import gymnasium as gym  # ✅ Use gymnasium instead of gym
import numpy as np
import pandas as pd
from gymnasium import spaces
from copy import deepcopy
from datetime import datetime, timedelta
import random
#import pytz

import matplotlib.pyplot as plt
import matplotlib.animation as animation





class StraddleEnvironment(gym.Env):
    def __init__(self):
        super(StraddleEnvironment, self).__init__()
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1, shape=(1, 1), dtype=np.float32)
        self.episode_duration = timedelta(minutes=180)
        #self.current_time = self.pick_initial_datetime()
        #initial_state = self.data.loc[self.current_time]
        loader = ZeroDTESurfaceLoader("./algo_data/vol_surfaces2.csv","./algo_data/spy_daily_prices.csv")
        data, underlying_data = loader.load_data()
        market_data= MarketData(data, underlying_data)
        self.market_data=market_data
        self.metadata = loader.get_metadata(data)
        self.start_date = self.metadata["start_date"]
        self.end_date = self.metadata["end_date"]

        #state variables
        self.position_opened = False
        self.current_episode_start_time = None
        self.end_time = None
        self.current_time = None

        self.trade_ledger = None
        self.portfolio = None

        #future state variables
        self.all_prices = None
        self.all_texp = None
        self.all_pnl = None


        #define the initial state



    def reset(self):
        self.current_time = pick_random_datetime(self.start_date,self.end_date, timedelta(hours=9,minutes=31), timedelta(hours=16,minutes=1)-self.episode_duration,self.market_data)
        #print(f"current time: {self.current_time}")
        self.end_time = self.current_time + self.episode_duration
        #state = self._get_state()
        info = self._get_info()
        self.market_data.set_current_minute(self.current_time)
        #print(f"current time: {self.current_time}")
        #print(f"market_data.current_row: {self.market_data.get_current_row()}")
        self.episode_start_time = self.current_time
        self.position_opened = False
        self.trade_ledger = TradeLedger()
        self.portfolio = Portfolio(ledger=self.trade_ledger)
        self.option_expiry=self.market_data.get_current_row()["minute"].replace(hour=16, minute=17)
        #self.option_expiry = pd.Timestamp("2024-04-19 16:17", tz="US/Eastern")
        self.last_pnl=0


        #print(f"current time: {self.current_time}")
        #state= self.market_data.get_current_row()
        #print("reset about to get state")
        (state_dict,state,state_fields)= self._get_state()
        return state, info

    def add_straddle(portfolio, timestamp, underlying, quantity, strike, expiry, price):
        portfolio.add_option(timestamp, underlying, quantity, 'call', strike, expiry, price/2)
        portfolio.add_option(timestamp, underlying, quantity, 'put', strike, expiry, price/2)

    def step(self, action):
        done = False
        market_row = self.market_data.get_current_row()
        if action == 0:
            if not self.position_opened:
                # Open a position
                self.position_opened = True
                spot_price = market_row["implied_spot"]
                straddle_price = market_row["straddle_price"]
                #print(f"straddle price: {straddle_price}")
                #print(f"current row: {market_row}")
                #print(f"spot price: {spot_price}")
                self.portfolio.add_option(self.current_time, 'SPY', 1, 'call', spot_price, self.option_expiry, straddle_price/2)
                self.portfolio.add_option(self.current_time, 'SPY', 1, 'put', spot_price, self.option_expiry, straddle_price/2)
                day_atm_vols = self.market_data.df_today["atm_vol"].to_numpy()
                day_slopes = self.market_data.df_today["slope"].to_numpy()
                day_quadratic_terms = self.market_data.df_today["quadratic_term"].to_numpy()
                all_times = self.market_data.df_today["minute"].to_numpy()
                all_times=pd.to_datetime(all_times)
                all_spots = self.market_data.df_today["implied_spot"].to_numpy()
                self.portfolio_prices, self.portfolio_price_arr= price_portfolio(self.portfolio, all_times, all_spots, day_atm_vols, day_slopes, day_quadratic_terms)
                self.buy_time= self.current_time
                #print(f"portfolio prices: {self.portfolio_prices.where(self.portfolio_prices.index == self.current_time)}")
                #self.market_data.open_position(self.current_time)
        elif action == 1:
            if self.position_opened:
                # Close the position
                self.position_opened = True
                #self.market_data.close_position(self.current_time)
                done = True
        elif action == 2:
            # Do nothing
            pass
        pnl=0
        if self.position_opened:
            #print("bla bla bla")
            pnl= compute_pnl_from_precomputed(self.portfolio, self.portfolio_prices, self.current_time) - self.portfolio.get_ledger().total_cost.sum()
        reward = pnl - self.last_pnl
        self.last_pnl = pnl
        self.current_time += timedelta(minutes=30)
        self.market_data.increment_minute(30)
        if self.current_time >= self.end_time:
            self.current_time = self.end_time
            #self.market_data.increment_minute(30)
            done = True
            truncated = True
        (state_dict,state,state_fields) = self._get_state()
        info = self._get_info()
        truncated = False
        return state, reward, done, truncated, info

    """
    def _get_state(self):
        time_remaining= (self.end_time - self.current_time)/self.episode_duration
        return np.array([time_remaining])
    """

    def _get_state(self):
        row=self.market_data.get_current_row()
        yest_close = row["under_close_shifted"]
        time_remaining= (self.end_time - self.current_time)/self.episode_duration
        state = {
            "current_spot": row["implied_spot"] / yest_close,
            "under_open": row["under_open"] / yest_close,
            "atm_vol": row["atm_vol"],
            "scaled_slope": row["scaled_slope"],
            "scaled_quadratic": row["scaled_quadratic"],
            "pct_straddle_price": row["pct_straddle_price"] / yest_close,
            "texp": row["years_to_maturity"],
            "yest_close": 1.0,  # normalized reference
            "time_remaining": time_remaining,
            "has_position": 1.0 if self.position_opened else 0.0,
        }

        array_fields = ["current_spot", "under_open", "atm_vol", "scaled_slope", "scaled_quadratic", "pct_straddle_price","texp","yest_close", "time_remaining", "has_position"]
        state_array = np.array([state[field] for field in array_fields])

        return state, state_array, array_fields

    def _state_to_dict(self, state):
        return {"time_remaining": state[0]}

    def _get_info(self):
        return {}


    """

    def render(self, mode='human'):
        state_info = {
            "current_time": self.current_time.isoformat(),
            "end_time": self.end_time.isoformat()
        }
        print(state_info)

    """

    def render(self, mode='human'):
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.barh(['Time Remaining'], [self._get_state()[0]], color='blue')
        ax.set_xlim(0, 1)
        ax.set_xlabel('Normalized Time Remaining')
        ax.set_title(f'Time: {self.current_time.strftime("%H:%M")}')
        plt.tight_layout()

        if mode == 'human':
            plt.show()
            plt.close(fig)
        elif mode == 'rgb_array':
            fig.canvas.draw()
            width, height = fig.canvas.get_width_height()
            image = np.asarray(fig.canvas.renderer.buffer_rgba())
            plt.close(fig)
            return image[:, :, :3]
        plt.close(fig)

    def close(self):
        pass


    def rollout(self):
        state, _ = self.reset()
        done = False
        states, rewards, actions, infos = [], [], [], []
        while not done:
            action = self.action_space.sample()
            state1, reward, done, truncated, info = self.step(action)
            #print(f"State: {state}, Reward: {reward}, Done: {done}, Info: {info}")
            states.append(state)
            rewards.append(reward)
            actions.append(action)
            infos.append(info)
            state=state1

        states.append(state)
        return states, rewards, actions, infos, done, truncated

    def pick_initial_datetime(self):
        valid_start_times = self.data.index[self.data.index <= self.data.index.max() - self.episode_duration]
        initial_datetime = np.random.choice(valid_start_times)
        return initial_datetime

def pick_random_datetime(start_date: datetime, end_date: datetime, start_time: timedelta, end_time: timedelta,market_data,tz_str='America/New_York'):
    random_date = start_date + timedelta(days=random.randint(0, (end_date - start_date).days))
    while (not is_valid_date(random_date) or (pd.Timestamp(random_date.date()).tz_localize(ZoneInfo("US/Eastern"))in market_data.missing_dates)):
        random_date = start_date + timedelta(days=random.randint(0, (end_date - start_date).days))
        #print(f"random date: {random_date}")
        #print(f"missing dates: {market_data.missing_dates}")
    

    random_seconds = random.randint(int(start_time.total_seconds()), int(end_time.total_seconds()))//60*60
    random_datetime = datetime.combine(random_date, datetime.min.time()) + timedelta(seconds=random_seconds)
    random_datetime = random_datetime.replace(tzinfo=ZoneInfo("US/Eastern"))
    #random_datetime = pytz.timezone(tz_str).localize(random_datetime)
    return random_datetime

"""
def pick_random_datetime(start_date: datetime, end_date: datetime, start_time: timedelta, end_time: timedelta,tz_str='America/New_York'):
    random_date = start_date + timedelta(days=random.randint(0, (end_date - start_date).days))
    while (not is_valid_date(random_date)):
        random_date = start_date + timedelta(days=random.randint(0, (end_date - start_date).days))

    random_seconds = random.randint(int(start_time.total_seconds()), int(end_time.total_seconds()))//60*60
    random_datetime = datetime.combine(random_date, datetime.min.time()) + timedelta(seconds=random_seconds)
    random_datetime = random_datetime.replace(tzinfo=ZoneInfo("US/Eastern"))
    #random_datetime = pytz.timezone(tz_str).localize(random_datetime)
    return random_datetime
"""


def is_valid_date(dt_datetime: datetime) -> bool:
    if dt_datetime.weekday()>4:
        return False
    return True

def run_episode(env):
    env.reset()
    done = False
    states=[]
    rewards=[]
    actions=[]
    truncates=[]
    infos=[]

    
    while not done:
        action = env.action_space.sample()
        state, reward, done, truncated, info = env.step(action)
        print(f"State: {state}, Reward: {reward}, Done: {done}, Info: {info}")
        states.append(state)
        rewards.append(reward)
        actions.append(action)
        infos.append(info)
        if done:
            break
    return states, rewards, actions, infos


In [18]:
env=StraddleEnvironment()

data validated


  df['minute'] = pd.to_datetime(df['minute'],errors="coerce")


0         2024-01-02
1         2024-01-02
2         2024-01-02
3         2024-01-02
4         2024-01-02
             ...    
112994    2025-03-03
112995    2025-03-03
112996    2025-03-03
112997    2025-03-03
112998    2025-03-03
Name: date, Length: 112999, dtype: object


In [8]:
env.market_data.df_today

In [52]:
env.reset()

(array([ 9.99618991e-01,  1.00303431e+00,  2.52549812e-01, -6.10577406e-03,
         1.28868427e-01,  1.09066999e-05,  7.74360670e-04,  1.00000000e+00,
         1.00000000e+00]),
 {})

In [14]:
env.reset()
state, reward, done, truncated, info = env.step(1)
print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated} Info: {info}")
state,reward, done, truncated,info= env.step(1)
print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated}, Info: {info}")
state,reward, done, truncated,info= env.step(2)
print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated}, Info: {info}")



state,reward, done, truncated,info= env.step(0)

print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated}, Info: {info}")

state,reward, done, truncated,info= env.step(2)
print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated}, Info: {info}")
state,reward, done, truncated,info= env.step(1)
print(f"State: {state}, Reward: {reward}, Done: {done}, truncated: {truncated}, Info: {info}")



State: [ 1.00166788e+00  1.00000000e+00  1.57329438e-01 -8.10876030e-03
  2.57613524e-01  5.33109443e-06  6.06261023e-04  1.00000000e+00
  8.33333333e-01  0.00000000e+00], Reward: 0, Done: False, truncated: False Info: {}
State: [ 1.00286820e+00  1.00000000e+00  1.57352202e-01 -8.17349313e-03
  2.57550510e-01  4.95501194e-06  5.23589065e-04  1.00000000e+00
  6.66666667e-01  0.00000000e+00], Reward: 0, Done: False, truncated: False, Info: {}
State: [ 1.00254416e+00  1.00000000e+00  1.52901579e-01 -5.16829293e-03
  3.06212129e-01  4.41842096e-06  4.40917108e-04  1.00000000e+00
  5.00000000e-01  0.00000000e+00], Reward: 0, Done: False, truncated: False, Info: {}
State: [ 1.00307633e+00  1.00000000e+00  1.48368559e-01 -5.07125054e-03
  3.37991038e-01  3.86463716e-06  3.58245150e-04  1.00000000e+00
  3.33333333e-01  1.00000000e+00], Reward: 0.0, Done: False, truncated: False, Info: {}
State: [ 1.00470829e+00  1.00000000e+00  1.55660020e-01 -3.84719850e-03
  3.66728031e-01  3.55608618e-06  2

In [None]:
env.current_time
env.current_time= (env.current_time - timedelta(minutes=30))
env.market_data.set_current_minute(env.current_time)
#env.market_data.get_current_row()
#env.portfolio_prices.loc[env.current_time].sum()
env.market_data.df_today[env.market_data.df_today["minute"] == env.current_time]["straddle_price"]
env.

11168    1.138257
Name: straddle_price, dtype: float64

In [15]:
#for i in range(0, 1000):
# Reset the environment
#state, info = env.reset()
for i in range(0, 10000):
    env.rollout()
# Perform a rollout



In [None]:
env.portfolio.ledger.trades.total_cost.sum()
env.portfolio.get_positions()

compute_pnl_from_precomputed(env.portfolio, env.portfolio_prices, env.current_time)
env.portfolio.get_ledger().total_cost.sum()

np.float64(1.273206998291203)

In [None]:
bla=env.rollout()
bla[0][0]

array([1.])

In [None]:
action= random.sample(range(0,3), 1)[0]
action

0

In [None]:
env.reset()

(minute                 2025-01-30 10:45:00-05:00
 implied_spot                          603.241624
 atm_vol                                 0.227287
 slope                                   -1.85415
 quadratic_term                        189.117114
 scaled_slope                           -0.004658
 scaled_quadratic                        0.119376
 open_price                                603.46
 high_price                                603.47
 low_price                                  603.2
 close_price                               603.25
 volume                                   33733.0
 vwap                                    603.3708
 timestamp                          1738251840000
 transactions                                 675
 otc                                          NaN
 timestamp_utc          2025-01-30 15:44:00+00:00
 timestamp_est          2025-01-30 10:44:00-05:00
 date                   2025-01-30 00:00:00-05:00
 years_to_maturity                       0.000631


In [None]:
for i in range(10000):
    #print(f"Episode {i+1}")
    obs,_= env.reset()
print(obs)


minute                 2024-02-23 12:31:00-05:00
implied_spot                           507.83335
atm_vol                                 0.193386
slope                                  -1.858734
quadratic_term                         479.10558
scaled_slope                           -0.003853
scaled_quadratic                        0.205867
open_price                                508.26
high_price                                508.26
low_price                                 507.81
close_price                               507.85
volume                                   98517.0
vwap                                    508.0034
timestamp                          1708709400000
transactions                                1110
otc                                          NaN
timestamp_utc          2024-02-23 17:30:00+00:00
timestamp_est          2024-02-23 12:30:00-05:00
date                   2024-02-23 00:00:00-05:00
years_to_maturity                        0.00043
straddle_price      

In [None]:
env.current_time
#env.market_data.missing_dates

datetime.datetime(2024, 3, 29, 12, 36, tzinfo=zoneinfo.ZoneInfo(key='US/Eastern'))

In [None]:
env.current_time.date() in env.market_data.missing_dates

pd.Timestamp(env.current_time.date()).tz_localize(ZoneInfo("US/Eastern"))in env.market_data.missing_dates

True

In [None]:
env.current_time
any(env.market_data.df["date"]==env.current_date)
env.current_time.date() #in env.market_data.missing_dates

AttributeError: 'StraddleEnvironment' object has no attribute 'current_date'

In [None]:
env.market_data.missing_dates

[Timestamp('2024-01-15 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2024-02-19 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2024-03-29 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-05-27 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-06-19 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-07-03 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-07-04 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-09-02 00:00:00-0400', tz='US/Eastern'),
 Timestamp('2024-11-28 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2024-11-29 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2024-12-24 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2024-12-25 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2025-01-01 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2025-01-09 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2025-01-20 00:00:00-0500', tz='US/Eastern'),
 Timestamp('2025-02-17 00:00:00-0500', tz='US/Eastern')]

In [None]:
all_dates=env.market_data.df["date"].unique()
all_dates
#find missing weekdays
all_dates = pd.to_datetime(all_dates)
all_dates = sorted(all_dates)
date_range= pd.date_range(start=all_dates[0], end=all_dates[-1], freq='B')
missing_dates = []
for dt in date_range: 
    if dt not in all_dates:
        missing_dates.append(dt)
missing_dates

[Timestamp('2024-01-15 00:00:00'),
 Timestamp('2024-02-19 00:00:00'),
 Timestamp('2024-03-29 00:00:00'),
 Timestamp('2024-05-27 00:00:00'),
 Timestamp('2024-06-19 00:00:00'),
 Timestamp('2024-07-03 00:00:00'),
 Timestamp('2024-07-04 00:00:00'),
 Timestamp('2024-09-02 00:00:00'),
 Timestamp('2024-11-28 00:00:00'),
 Timestamp('2024-11-29 00:00:00'),
 Timestamp('2024-12-24 00:00:00'),
 Timestamp('2024-12-25 00:00:00'),
 Timestamp('2025-01-01 00:00:00'),
 Timestamp('2025-01-09 00:00:00'),
 Timestamp('2025-01-20 00:00:00'),
 Timestamp('2025-02-17 00:00:00')]

In [None]:

missing_minutes={}
for dt in date_range:
    if dt not in missing_dates:
        daily_df= env.market_data.df[env.market_data.df["date"]==dt]
        full_day_minutes = pd.date_range(
            start=dt.replace(hour=9, minute=31, tzinfo=ZoneInfo("US/Eastern")), 
            end=dt.replace(hour=16, minute=1, tzinfo=ZoneInfo("US/Eastern")),
            freq='min'
        )
        missing = full_day_minutes.difference(daily_df["minute"])
        if (len(missing)>0):
            missing_minutes[dt] = missing
#print(full_day_minutes)
print(missing_minutes)

{}


In [None]:
all_dfs={}
for dt in missing_dates:
    daily_df= env.market_data.df[env.market_data.df["date"]==dt]
    all_dfs[dt]=daily_df
