In [1]:
import numpy as np
import pandas as pd
import random
from collections import defaultdict
from algorithms import sarsa,exp_sarsa,nstep_sarsa,q_learning

In [2]:
class TaxiData:
    def __init__(self):
        self.filepath = "final_updated_csv.csv"
        self.df, self.db_len = self.load_data()

    def load_data(self):
        df = pd.read_csv(self.filepath)
        df['tpep_pickup_datetime'] = pd.to_datetime(df['tpep_pickup_datetime'])
        df['tpep_dropoff_datetime'] = pd.to_datetime(df['tpep_dropoff_datetime'])
        df['pickup_time'] = df['tpep_pickup_datetime'].dt.hour * 60 + df['tpep_pickup_datetime'].dt.minute
        df['pickup_day'] = df['tpep_pickup_datetime'].dt.day_name()
        df['dropoff_time'] = df['tpep_dropoff_datetime'].dt.hour * 60 + df['tpep_dropoff_datetime'].dt.minute
        df['dropoff_day'] = df['tpep_dropoff_datetime'].dt.day_name()
        df['pickup_zone'] = df['PULocationID']
        df['dropoff_zone'] = df['DOLocationID']
        df['fare'] = df['total_amount']
        df['trip_distance'] = df['trip_distance']
        return df , len(pd.read_csv(self.filepath))


    def day_to_number(self,day_name):
        # Dictionary mapping day names to day numbers
        day_dict = {
            "Monday": 0,
            "Tuesday": 1,
            "Wednesday": 2,
            "Thursday": 3,
            "Friday": 4,
            "Saturday": 5,
            "Sunday": 6
        }

        # Convert the day name to title case to handle mixed case inputs
        day_name = day_name.title()

        # Return the day number or raise an error if the day name is invalid
        if day_name in day_dict:
            return day_dict[day_name]
        else:
            raise ValueError("Invalid day name. Please enter a valid day of the week.")

    def get_value_by_index(self, index, column_name):
        """Retrieve a specific value from a DataFrame based on an index and column name."""
        try:
            return self.df.at[index, column_name]
        except KeyError:
            print(f"Column {column_name} does not exist in the DataFrame.")
            return None
        except IndexError:
            print(f"Index {index} is out of bounds for the DataFrame.")
            return None



    def find_next_trips_indices(self,minute_of_day, day_of_week, num_trips):
        """
        Find the indices of the next 'num_trips' trips for a given minute of day and day of week.
        """
        # Filter the DataFrame to find entries that match the time and day criteria
        filtered_df = self.df[
            (self.df['pickup_time'] >= minute_of_day) & 
            (self.df['pickup_day'] == day_of_week)
        ]

        # Check if the filtered DataFrame has enough entries
        if not filtered_df.empty:
            if len(filtered_df) >= num_trips:
                return filtered_df.iloc[:num_trips].index.tolist()
            else:
                print("Not enough trips meet the criteria. Returning available trips.")
                return filtered_df.index.tolist()
        else:
            return []  # Return an empty list if no trips are found
    
    def get_next_trip(self,minute_of_day, day_of_week, num_trips):

        #getting drop zone, drop time, fare, drop day
        next_trips_indexes = self.find_next_trips_indices(minute_of_day, day_of_week, num_trips)
        print(minute_of_day, day_of_week, num_trips)
        next_trip_index = random.choice(next_trips_indexes)
        
        drop_zone = self.df['dropoff_zone'][next_trip_index]
        drop_time = self.df['dropoff_time'][next_trip_index]
        fare = self.df['fare'][next_trip_index]
        drop_day = self.df['dropoff_day'][next_trip_index]

        drop_day = self.day_to_number(drop_day)


        return drop_zone, drop_time, fare, drop_day

In [3]:
class CabEnvironment():
    
    def __init__(self):
        
        # State space
        # (zone,minute of the day, day of the week)
        self.max_zones = 200
        self.max_minutes = 1440
        
        self.taxiRequest = TaxiData()
        self.state_space = [[zone, minute_of_day, day_of_the_week] for zone in range(1,self.max_zones+1) for minute_of_day in range(self.max_minutes+1) for day_of_the_week in range(7)]
        
        # Action Space
        # 0 - Wait
        # 1 -  go
        self.action_space = [0,1] 
        
        # current state of the agent
        
        self.state = None
        
    
    def reset(self):
        
        start_minute = random.randint(0, 420)
        start_zone = random.randint(0,self.max_zones)
        self.state = [start_zone,start_minute,0]
        return self.state
    
    
    def reverse(self,day):
        
        
        day_number ={0: 'Monday', 
             1: 'Tuesday', 
             2: 'Wednesday', 
             3: 'Thursday', 
             4: 'Friday', 
             5: 'Saturday', 
             6: 'Sunday'}
        return day_number[day]
    
    
    def step(self,action):
        
        
        reward = 0
        done = False
        
        # Wait action adds 15 mins
        if action == 0:
            next_state = [self.state[0],self.state[1]+15,self.state[2]]


        # Driver decides to go to the next request, gets rewards
        elif action == 1:
            
            drop_zone,fare,drop_time,drop_day = self.getNextRequest(self.state[0],self.state[1],self.state[2])
            next_state = [drop_zone,drop_time,drop_day]
            reward = fare
            
        # if minutes are greater than 1440 then go to the next day
        if next_state[1]>=1440:
            next_state[1]=next_state[1]-1440
            next_state[2]=next_state[2]+1
            
        # checks if the week is over to end the episode
        if next_state[2] == 6:
            done = True
            
        self.state = next_state
            
        return next_state, reward, done
    
    
    def getNextRequest(self,zone,minute_of_day,day):
        
        day_name = self.reverse(day)
        print(day_name)
        drop_zone, drop_time, fare, drop_day = self.taxiRequest.get_next_trip(minute_of_day, day_name, 1)
        

        return drop_zone,fare,drop_time, drop_day
        
        
        
        
        
        
        
        

In [4]:
env = CabEnvironment()

In [6]:
Q_sarsa, history = sarsa(env, num_steps=8000, gamma=1, epsilon=0.1)

TypeError: unhashable type: 'list'