# Stock Trading

The cell below defines the **abstract class** whose API you need to implement. **Do NOT modify it** - use the dedicated cell further below for your implementation instead.

In [None]:
# DO NOT MODIFY THIS CELL

from abc import ABC, abstractmethod  
      

# abstract class to represent a stock trading platform
# Every method is decorated with @abstractmethod, so a concrete subclass must
# override all of them before it can be instantiated; the bodies below are
# only placeholders.
class AbstractStockTradingPlatform(ABC):
    
    # constructor
    @abstractmethod
    def __init__(self):
        pass           
        
    # adds transactionRecord to the set of completed transactions
    # transactionRecord : presumably a list [stock name, price, quantity, timestamp],
    # the layout produced by TransactionDataGenerator.generateTransactionData - confirm
    @abstractmethod
    def logTransaction(self, transactionRecord):
        pass

    # returns a list with all transactions of a given stockName,
    # sorted by increasing trade value. 
    # stockName : str
    @abstractmethod
    def sortedTransactions(self, stockName): 
        sortedList = []
        return sortedList    
    
    # returns a list of transactions of a given stockName with minimum trade value
    # (several transactions can share the same trade value, hence a list)
    # stockName : str
    @abstractmethod
    def minTransactions(self, stockName): 
        minList = []
        return minList    
    
    # returns a list of transactions of a given stockName with maximum trade value
    # (several transactions can share the same trade value, hence a list)
    # stockName : str
    @abstractmethod
    def maxTransactions(self, stockName): 
        maxList = []
        return maxList    

    # returns a list of transactions of a given stockName, 
    # with the largest trade value below a given thresholdValue.  
    # NOTE(review): "below" does not say whether a trade value equal to
    # thresholdValue qualifies; the TradeTree implementation in this notebook
    # includes it (floor semantics).
    # stockName : str
    # thresholdValue : double (a Python float)
    @abstractmethod
    def floorTransactions(self, stockName, thresholdValue): 
        floorList = []
        return floorList    

    # returns a list of transactions of a given stockName, 
    # with the smallest trade value above a given thresholdValue.  
    # NOTE(review): "above" does not say whether a trade value equal to
    # thresholdValue qualifies; the TradeTree implementation in this notebook
    # includes it (ceiling semantics).
    # stockName : str
    # thresholdValue : double (a Python float)
    @abstractmethod
    def ceilingTransactions(self, stockName, thresholdValue): 
        ceilingList = []
        return ceilingList    

        
    # returns a list of transactions of a given stockName,  
    # whose trade value is within the range [fromValue, toValue].
    # (both bounds inclusive)
    # stockName : str
    # fromValue : double (a Python float)
    # toValue : double (a Python float)
    @abstractmethod
    def rangeTransactions(self, stockName, fromValue, toValue): 
        rangeList = []
        return rangeList    

Use the cell below to define any data structures and auxiliary Python functions you may need. Leave the implementation of the main API to the next code cell.

In [None]:
# ADD AUXILIARY DATA STRUCTURE DEFINITIONS AND HELPER CODE HERE

import datetime as d
import math
import typing as t


# This class models a transaction record.
# It stores all the relevant information associated with a single transaction record.
class Trade:
    """A single completed transaction on one stock.

    Attributes:
        name: name of the traded stock.
        price: price per share.
        quantity: number of shares traded.
        time: timestamp of the trade. NOTE(review): annotated as ``datetime``,
            but ``TransactionDataGenerator`` supplies a formatted string; the
            value is stored as-is either way.
    """

    # __slots__ removes the per-instance __dict__. The debug harness loads
    # 1,000,000 trades, so the per-object memory saving adds up.
    __slots__ = ("name", "price", "quantity", "time")

    def __init__(self, name: str, price: float, quantity: int, time: d.datetime) -> None:
        self.name = name
        self.price = price
        self.quantity = quantity
        self.time = time

    def get_trade_val(self) -> float:
        """Return the trade value of this transaction, i.e. ``price * quantity``."""
        return self.price * self.quantity

    def to_list(self) -> list:
        """Return the trade as ``[name, price, quantity, time]`` (the input record layout)."""
        return [self.name, self.price, self.quantity, self.time]

    def __repr__(self) -> str:
        # The platform API returns lists of Trade objects; give them a readable
        # form when printed or inspected while debugging.
        return (
            f"{type(self).__name__}(name={self.name!r}, price={self.price!r}, "
            f"quantity={self.quantity!r}, time={self.time!r})"
        )


# This class models the nodes used in the TradeTree class.
# The trade value is interpreted as the key of a TradeNode object.
# The array of trades is interpreted as the value of a TradeNode object.
class TradeNode:
    """Node of TradeTree's left-leaning red-black tree, keyed by trade value.

    Trades that share the same trade value are grouped in ``trades`` instead of
    creating additional nodes.
    """

    # Node colours used by TradeTree's balancing logic.
    RED = True
    BLACK = False

    # Up to one node per distinct trade value is created, so skip the
    # per-instance __dict__. The class-level RED/BLACK constants are unaffected.
    __slots__ = ("trade_val", "trades", "left", "right", "color")

    def __init__(self, trade: Trade) -> None:
        # Key of this node.
        self.trade_val = trade.get_trade_val()

        # List of all Trade objects with the same trade value
        self.trades = [trade]

        self.left = None
        self.right = None
        # New nodes start red; TradeTree rebalances after each insertion.
        self.color = TradeNode.RED


# This class models all information for a single stock name.
# All trades on a given stock will be stored here.
# This has no model of any other stocks.
# It implements a balanced search tree ADT using a left-leaning red-black binary search tree.
# Each node in the tree is a TradeNode object.
class TradeTree:
    """Ordered index of the trades of one stock, keyed by trade value.

    The tree is a left-leaning red-black BST of ``TradeNode`` objects, and all
    trades with an identical trade value share one node. Every query returns a
    newly built list, so callers may modify the result without corrupting the
    tree.
    """

    def __init__(self, stock_name: str) -> None:
        self.stock_name = stock_name
        self.root = None

    def put_trade(self, trade: Trade) -> None:
        """Insert ``trade`` and restore the red-black invariants.

        Raises:
            ValueError: if ``trade`` belongs to a different stock, or its trade
                value is NaN. NaN is unordered, so it cannot be placed in a
                search tree; as a root it would make every later insert vanish.
        """
        # Ensure that the Trade object to be inserted matches the stock name of the current TradeTree object
        if trade.name != self.stock_name:
            raise ValueError("Invalid Stock Name")

        if math.isnan(trade.get_trade_val()):
            raise ValueError("Invalid Trade Value")

        self.root = self.__insert(trade, self.root)

        # Maintain invariant of coloring root node black
        self.root.color = TradeNode.BLACK

    def __insert(self, trade: Trade, node: TradeNode) -> TradeNode:
        """Insert ``trade`` below ``node`` and return the rebalanced subtree root."""
        if node is None:
            return TradeNode(trade)

        # Use trade value as the key for inserting TradeNode objects into the TradeTree
        trade_val = trade.get_trade_val()

        if trade_val == node.trade_val:
            # Equal keys are grouped in the existing node rather than duplicated.
            node.trades.append(trade)
        elif trade_val < node.trade_val:
            node.left = self.__insert(trade, node.left)
        else:
            node.right = self.__insert(trade, node.right)

        # Recursively balance the TradeTree to maintain logarithmic height
        return TradeTree.__balance(node)

    def get_all_trades(self, node: TradeNode = None) -> t.List[Trade]:
        """Return every trade in the subtree rooted at ``node``, sorted by trade value.

        ``node`` defaults to the whole tree. The traversal is an in-order walk
        with an explicit stack that appends into a single list. Each node is
        visited once and each trade is copied once.
        """
        # Base case where the root of the TradeTree has not been initialized
        if self.root is None:
            return []

        # Optional node parameter used to display some subtree
        if node is None:
            node = self.root

        all_trades: t.List[Trade] = []
        pending: t.List[TradeNode] = []
        current = node

        while pending or current is not None:
            # Descend as far left as possible, remembering the path back up.
            while current is not None:
                pending.append(current)
                current = current.left
            current = pending.pop()
            all_trades.extend(current.trades)
            current = current.right

        return all_trades

    def get_min_trades(self) -> t.List[Trade]:
        """Return the trades with the smallest trade value ([] if the tree is empty)."""
        if self.root is None:
            return []

        node = self.root
        while node.left is not None:
            node = node.left

        # Copy so callers cannot mutate the node's internal list.
        return list(node.trades)

    def get_max_trades(self) -> t.List[Trade]:
        """Return the trades with the largest trade value ([] if the tree is empty)."""
        if self.root is None:
            return []

        node = self.root
        while node.right is not None:
            node = node.right

        # Copy so callers cannot mutate the node's internal list.
        return list(node.trades)

    def get_floor_trades(self, high: float) -> t.List[Trade]:
        """Return the trades with the largest trade value <= ``high`` ([] if none)."""
        node = self.root
        floor_node = None

        while node is not None:
            if node.trade_val == high:
                floor_node = node
                break
            if node.trade_val < high:
                # Candidate; a closer one can only be in the right subtree.
                floor_node = node
                node = node.right
            else:
                # Plain else (not "elif >") so the loop always advances, even
                # when ``high`` is NaN and every comparison is False.
                node = node.left

        return list(floor_node.trades) if floor_node is not None else []

    def get_ceil_trades(self, low: float) -> t.List[Trade]:
        """Return the trades with the smallest trade value >= ``low`` ([] if none)."""
        node = self.root
        ceil_node = None

        while node is not None:
            if node.trade_val == low:
                ceil_node = node
                break
            if node.trade_val > low:
                # Candidate; a closer one can only be in the left subtree.
                ceil_node = node
                node = node.left
            else:
                # Plain else (not "elif <") so the loop always advances, even
                # when ``low`` is NaN and every comparison is False.
                node = node.right

        return list(ceil_node.trades) if ceil_node is not None else []

    def get_trades_in_range(self, low: float, high: float, node: TradeNode = None) -> t.List[Trade]:
        """Return the trades whose value lies in ``[low, high]``, sorted by trade value.

        Args:
            low: inclusive lower bound; must be >= 0 and <= ``high``.
            high: inclusive upper bound.
            node: optional subtree root to search (default: the whole tree).

        Raises:
            ValueError: if ``low > high`` or ``low < 0``.
        """
        # Ensure that the low and high parameters are valid
        if low > high or low < 0:
            raise ValueError("Invalid Range")

        if self.root is None:
            return []

        if node is None:
            node = self.root

        trades_in_range: t.List[Trade] = []
        TradeTree.__collect_in_range(node, low, high, trades_in_range)
        return trades_in_range

    @staticmethod
    def __collect_in_range(node: TradeNode, low: float, high: float, out: t.List[Trade]) -> None:
        """Append, in order, the trades below ``node`` with value in [low, high] to ``out``."""
        if node is None:
            return
        # Only descend into subtrees that can still hold in-range keys.
        if low < node.trade_val:
            TradeTree.__collect_in_range(node.left, low, high, out)
        if low <= node.trade_val <= high:
            out.extend(node.trades)
        if node.trade_val < high:
            TradeTree.__collect_in_range(node.right, low, high, out)

    @staticmethod
    def __rotate_left(node: TradeNode) -> TradeNode:
        """Turn a right-leaning red link into a left-leaning one; return the new subtree root."""
        x = node.right
        node.right = x.left
        x.left = node
        x.color = node.color
        node.color = TradeNode.RED
        return x

    @staticmethod
    def __rotate_right(node: TradeNode) -> TradeNode:
        """Turn a left-leaning red link into a right-leaning one; return the new subtree root."""
        x = node.left
        node.left = x.right
        x.right = node
        x.color = node.color
        node.color = TradeNode.RED
        return x

    @staticmethod
    def __flip_colors(node: TradeNode) -> None:
        """Colour ``node`` red and both of its children black."""
        node.color = TradeNode.RED
        node.left.color = TradeNode.BLACK
        node.right.color = TradeNode.BLACK

    @staticmethod
    def __is_red(node: TradeNode) -> bool:
        """Return True if ``node`` is red; missing (None) children count as black."""
        return node.color == TradeNode.RED if node is not None else False

    # This method maintains the structural invariants of a left-leaning red-black binary search tree
    @staticmethod
    def __balance(node: TradeNode) -> TradeNode:
        if TradeTree.__is_red(node.right) and not TradeTree.__is_red(node.left):
            node = TradeTree.__rotate_left(node)

        if TradeTree.__is_red(node.left) and TradeTree.__is_red(node.left.left):
            node = TradeTree.__rotate_right(node)

        if TradeTree.__is_red(node.left) and TradeTree.__is_red(node.right):
            TradeTree.__flip_colors(node)

        return node

Use the cell below to implement the requested API. 

In [None]:
# IMPLEMENT HERE THE REQUESTED API

import typing as t


# This class implements a hash table ADT using a Python's built-in dictionary data structure.
# The keys are the stock names and the values are references to TradeTree objects.
# Each method in this class does thorough error checking before calling operations on the TradeTree objects.
class StockTradingPlatform(AbstractStockTradingPlatform):
    """Concrete trading platform: one ``TradeTree`` per supported stock name.

    Queries select or order trades by trade value (``price * quantity``) and
    return lists of ``Trade`` objects. The following raise ``ValueError``:
    unknown stock names, negative or NaN thresholds, and invalid range bounds.
    """

    def __init__(self) -> None:
        # Supported stock names; transactions on any other stock are rejected.
        self.STOCKS = ["Barclays", "HSBA", "Lloyds Banking Group", "NatWest Group", "Standard Chartered", "3i",
                       "Abrdn", "Hargreaves Lansdown", "London Stock Exchange Group", "Pershing Square Holdings",
                       "Schroders", "St. James's Place plc."]

        # Stock name -> TradeTree holding every trade logged for that stock.
        self.__trade_trees = {stock: TradeTree(stock) for stock in self.STOCKS}

    def logTransaction(self, transactionRecord: list) -> None:
        """Validate and store a ``[name, price, quantity, time]`` record.

        Raises:
            ValueError: for an unknown stock, a quantity < 1 or a price <= 0
                (NaN quantities and prices are rejected too).
        """
        trade = Trade(*transactionRecord)
        self.__validate_trade(trade)
        self.__trade_trees[trade.name].put_trade(trade)

    def sortedTransactions(self, stockName: str) -> t.List[Trade]:
        """Return all trades of ``stockName`` sorted by increasing trade value."""
        return self.__tree_for(stockName, "sortedTransactions").get_all_trades()

    def minTransactions(self, stockName: str) -> t.List[Trade]:
        """Return the trades of ``stockName`` with the minimum trade value."""
        return self.__tree_for(stockName, "minTransactions").get_min_trades()

    def maxTransactions(self, stockName: str) -> t.List[Trade]:
        """Return the trades of ``stockName`` with the maximum trade value."""
        return self.__tree_for(stockName, "maxTransactions").get_max_trades()

    def floorTransactions(self, stockName: str, thresholdValue: float) -> t.List[Trade]:
        """Return the trades of ``stockName`` with the largest trade value <= ``thresholdValue``."""
        tree = self.__tree_for(stockName, "floorTransactions")
        StockTradingPlatform.__check_threshold(thresholdValue, "floorTransactions")
        return tree.get_floor_trades(thresholdValue)

    def ceilingTransactions(self, stockName: str, thresholdValue: float) -> t.List[Trade]:
        """Return the trades of ``stockName`` with the smallest trade value >= ``thresholdValue``."""
        tree = self.__tree_for(stockName, "ceilingTransactions")
        StockTradingPlatform.__check_threshold(thresholdValue, "ceilingTransactions")
        return tree.get_ceil_trades(thresholdValue)

    def rangeTransactions(self, stockName: str, fromValue: float, toValue: float) -> t.List[Trade]:
        """Return the trades of ``stockName`` whose trade value lies in ``[fromValue, toValue]``."""
        tree = self.__tree_for(stockName, "rangeTransactions")

        # NaN bounds compare False against everything, so they are rejected explicitly.
        if (
            fromValue > toValue
            or fromValue < 0
            or toValue < 0
            or math.isnan(fromValue)
            or math.isnan(toValue)
        ):
            raise ValueError(
                f"rangeTransactions: Invalid Range Bounds: fromValue: {fromValue} toValue: {toValue}"
            )

        return tree.get_trades_in_range(fromValue, toValue)

    def __tree_for(self, stockName: str, caller: str) -> TradeTree:
        """Return the TradeTree of ``stockName``; raise ValueError tagged with ``caller`` if unknown."""
        # An f-string formats any input type. A non-str name therefore still
        # produces the intended ValueError instead of a TypeError raised by
        # string concatenation.
        if stockName not in self.STOCKS:
            raise ValueError(f"{caller}: Invalid Stock Name: {stockName}")
        return self.__trade_trees[stockName]

    @staticmethod
    def __check_threshold(thresholdValue: float, caller: str) -> None:
        """Reject negative and NaN thresholds; raise ValueError tagged with ``caller``."""
        # A NaN threshold fails every comparison, so it cannot select a floor/ceiling.
        if thresholdValue < 0 or math.isnan(thresholdValue):
            raise ValueError(f"{caller}: Invalid Transaction Value: {thresholdValue}")

    # This private helper method ensures that the transaction records to be inserted are valid
    def __validate_trade(self, trade: Trade) -> None:
        if trade.name not in self.STOCKS:
            raise ValueError(f"Invalid Stock Name: {trade.name}")

        # NaN passes "< 1" / "<= 0.0" checks (all comparisons are False), and a
        # NaN trade value would break the search-tree ordering.
        if trade.quantity < 1 or math.isnan(trade.quantity):
            raise ValueError(f"Invalid Stock Quantity: {trade.quantity}")

        if trade.price <= 0.0 or math.isnan(trade.price):
            raise ValueError(f"Invalid Stock Price: {trade.price}")

The cell below provides helper code that you can use within your experimental framework to generate random transaction data. **Do NOT modify it**.

In [None]:
# DO NOT MODIFY THIS CELL

import random
from datetime import timedelta
from datetime import datetime

class TransactionDataGenerator:
    # Produces reproducible random transaction data for experiments.
    def __init__(self):
        self.stockNames = ["Barclays", "HSBA", "Lloyds Banking Group", "NatWest Group", 
                      "Standard Chartered", "3i", "Abrdn", "Hargreaves Lansdown", 
                      "London Stock Exchange Group", "Pershing Square Holdings", 
                      "Schroders", "St. James's Place plc."]
        # Bounds used by getTradeValue(); they match the range of price * quantity
        # that generateTransactionData() can produce (50 * 10 to 100 * 1000).
        self.minTradeValue = 500.00
        self.maxTradeValue = 100000.00
        self.startDate = datetime.strptime('1/1/2022 1:00:00', '%d/%m/%Y %H:%M:%S')
        # Seeds the module-level `random` generator. This also fixes the sequence
        # of every later random.* call in the notebook, not just this object's.
        random.seed(20221603)
          
    # returns the name of a traded stock at random
    def getStockName(self):
        return random.choice(self.stockNames)

    # returns the trade value of a transaction at random
    # (uniform in [minTradeValue, maxTradeValue], rounded to 2 decimal places)
    def getTradeValue(self):
        return round(random.uniform(self.minTradeValue, self.maxTradeValue), 2)
    
    # returns a list of N randomly generated transactions,
    # where each transaction is represented as a list [stock name, price, quantity, timestamp]
    # price is uniform in [50, 100] rounded to 2 decimals, quantity an int in [10, 1000],
    # timestamp a string 'dd/mm/YYYY HH:MM:SS', 3 seconds apart starting at startDate.
    # N : int
    def generateTransactionData(self, N):   
        # [[]]*N repeats one shared empty list, but every slot is overwritten in
        # the loop below, so the aliasing is harmless.
        listTransactions = [[]]*N
        listDates = [self.startDate + timedelta(seconds=3*x) for x in range(0, N)]
        listDatesFormatted = [x.strftime('%d/%m/%Y %H:%M:%S') for x in listDates]
        for i in range(N):
            stockName = random.choice(self.stockNames)
            price = round(random.uniform(50.00, 100.00), 2)
            quantity = random.randint(10,1000)
            listTransactions[i] = [stockName, price, quantity, listDatesFormatted[i]]   
        return listTransactions

Use the cell below for the Python code needed to realise your **experimental framework** (i.e., to generate test data, to instantiate the `StockTradingPlatform` class, to thoroughly experiment with its API functions, and to experimentally measure their performance). You may use the previously provided ``TransactionDataGenerator`` class to generate random transaction data.

In [None]:
import random
import timeit

# ADD YOUR EXPERIMENTAL FRAMEWORK CODE HERE




The cell below exemplifies **debug** code I will invoke on your submission - it does not represent an experimental framework (which should be much more comprehensive). **Do NOT modify it**. 

In [None]:
# DO NOT MODIFY THIS CELL

import timeit

# Debug/benchmark driver: loads a large synthetic data set, then times each API call.
testPlatform = StockTradingPlatform()
testDataGen = TransactionDataGenerator()

numTransactions = 1000000
testData = testDataGen.generateTransactionData(numTransactions)

# Number of repetitions used to compute each mean query time below.
numRuns = 100

print("Examples of transactions:", testData[0], testData[numTransactions//2], testData[numTransactions-1])

#
# testing the logTransaction() API 
#
starttime = timeit.default_timer()
for i in range(numTransactions):
    testPlatform.logTransaction(testData[i])
endtime = timeit.default_timer()
print("\nExecution time to load", numTransactions, "transactions:", round(endtime-starttime,4))

#
# testing the various API functions
# (each timed loop also includes the cost of drawing the random stock name
# and, where used, the random trade values)
#
starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.sortedTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time sortedTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.minTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time minTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.maxTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time maxTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.floorTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time floorTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.ceilingTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time ceilingTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    # Sorting the two random bounds guarantees fromValue <= toValue.
    rangeValues = sorted([testDataGen.getTradeValue(), testDataGen.getTradeValue()])
    output = testPlatform.rangeTransactions(testDataGen.getStockName(), rangeValues[0], rangeValues[1])
endtime = timeit.default_timer()
print("\nMean execution time rangeTransactions:", round((endtime-starttime)/numRuns,4))