# Stock Trading

The cell below defines the **abstract class** whose API you need to implement. **Do NOT modify it** - use the dedicated cell further below for your implementation instead.

In [1]:
# DO NOT MODIFY THIS CELL

from abc import ABC, abstractmethod  
      

# abstract class to represent a stock trading platform
class AbstractStockTradingPlatform(ABC):
    # Interface of a stock trading platform. Every method below is marked
    # @abstractmethod, so a concrete subclass has to override all of them
    # (the constructor included) before it can be instantiated. The bodies
    # given here are placeholders only.
    
    # constructor
    @abstractmethod
    def __init__(self):
        pass           
        
    # adds transactionRecord to the set of completed transactions
    # transactionRecord : presumably a list [stock name, price, quantity, timestamp],
    #                     the format produced by the data generator - TODO confirm
    @abstractmethod
    def logTransaction(self, transactionRecord):
        pass

    # returns a list with all transactions of a given stockName,
    # sorted by increasing trade value. 
    # stockName : str
    @abstractmethod
    def sortedTransactions(self, stockName): 
        # placeholder result: an empty list
        sortedList = []
        return sortedList    
    
    # returns a list of transactions of a given stockName with minimum trade value
    # stockName : str
    @abstractmethod
    def minTransactions(self, stockName): 
        # placeholder result: an empty list
        minList = []
        return minList    
    
    # returns a list of transactions of a given stockName with maximum trade value
    # stockName : str
    @abstractmethod
    def maxTransactions(self, stockName): 
        # placeholder result: an empty list
        maxList = []
        return maxList    

    # returns a list of transactions of a given stockName, 
    # with the largest trade value below a given thresholdValue.  
    # stockName : str
    # thresholdValue : double
    @abstractmethod
    def floorTransactions(self, stockName, thresholdValue): 
        # placeholder result: an empty list
        floorList = []
        return floorList    

    # returns a list of transactions of a given stockName, 
    # with the smallest trade value above a given thresholdValue.  
    # stockName : str
    # thresholdValue : double
    @abstractmethod
    def ceilingTransactions(self, stockName, thresholdValue): 
        # placeholder result: an empty list
        ceilingList = []
        return ceilingList    

        
    # returns a list of transactions of a given stockName,  
    # whose trade value is within the range [fromValue, toValue].
    # stockName : str
    # fromValue : double
    # toValue : double
    @abstractmethod
    def rangeTransactions(self, stockName, fromValue, toValue): 
        # placeholder result: an empty list
        rangeList = []
        return rangeList    

Use the cell below to define any data structure and auxiliary python function you may need. Leave the implementation of the main API to the next code cell instead.

In [2]:
# ADD AUXILIARY DATA STRUCTURE DEFINITIONS AND HELPER CODE HERE

import datetime as d
import typing as t


class Trade:
    """A single completed transaction on one stock.

    Attributes:
        name: name of the traded stock.
        price: unit price used for the trade.
        quantity: number of units traded.
        time: timestamp of the trade, stored exactly as supplied.
    """

    def __init__(self, name: str, price: float, quantity: int, time: d.datetime) -> None:
        self.name = name
        self.price = price
        self.quantity = quantity
        self.time = time

    def __repr__(self) -> str:
        # Show the fields rather than a bare object address, so that failed
        # assertions and debugger output on lists of trades are readable.
        return (f"{type(self).__name__}(name={self.name!r}, price={self.price!r}, "
                f"quantity={self.quantity!r}, time={self.time!r})")

    def get_trade_val(self) -> float:
        """Return the trade value, i.e. ``price * quantity``."""
        return self.price * self.quantity

    def to_list(self) -> list:
        """Return the trade as a ``[name, price, quantity, time]`` record.

        The field order matches the constructor arguments, so
        ``Trade(*trade.to_list())`` rebuilds an equivalent trade.
        """
        return [self.name, self.price, self.quantity, self.time]


class TradeNode:
    """Search-tree node grouping every trade that shares one trade value."""

    # Link colours used by the tree's red-black balancing
    RED = True
    BLACK = False

    def __init__(self, trade: Trade) -> None:
        # The node key is fixed by the first trade stored in it
        key = trade.get_trade_val()
        self.trade_val = key

        # All trades whose value equals the key, in insertion order
        self.trades = [trade]

        # A fresh node has no children and starts out red
        self.left = None
        self.right = None
        self.color = TradeNode.RED

    def to_dict(self) -> t.Dict[float, t.List[list]]:
        """Return ``{trade value: [trade record, ...]}`` with each trade in list form."""
        records = []
        for stored in self.trades:
            records.append(stored.to_list())

        return {self.trade_val: records}


class TradeTree:
    """Balanced search tree holding every trade logged for a single stock.

    Nodes are keyed by trade value; trades sharing a value live in the same
    node, in insertion order. Balancing follows the left-leaning red-black
    scheme: rotate left, rotate right and flip colours on the way back up
    from an insertion. The tree has no knowledge of any other stock.

    Every query returns a new list, so callers may freely modify the result
    without affecting the tree.
    """

    def __init__(self, stock_name: str) -> None:
        self.stock_name = stock_name
        self.root = None

    def put_trade(self, trade: Trade) -> None:
        """Insert ``trade`` into the tree.

        Raises:
            ValueError: if ``trade.name`` differs from the stock of this tree.
        """
        if trade.name != self.stock_name:
            raise ValueError("Invalid Stock Name")

        self.root = self.__insert(trade, self.root)
        # Rebalancing may have turned the root red; the root is kept black
        self.root.color = TradeNode.BLACK

    def __insert(self, trade: Trade, node: TradeNode) -> TradeNode:
        # Inserts below `node` and returns the (possibly different) root of that subtree
        if node is None:
            return TradeNode(trade)

        trade_val = trade.get_trade_val()

        if trade_val == node.trade_val:
            # Same value: share the existing node
            node.trades.append(trade)
        elif trade_val < node.trade_val:
            node.left = self.__insert(trade, node.left)
        elif trade_val > node.trade_val:
            node.right = self.__insert(trade, node.right)

        return TradeTree.__balance(node)

    def get_all_trades(self, node: t.Optional[TradeNode] = None) -> t.List[Trade]:
        """Return all trades sorted by increasing trade value.

        Args:
            node: optional subtree root to list instead of the whole tree.

        Returns:
            A new list; empty when the tree holds no trades.
        """
        if self.root is None:
            return []

        current = self.root if node is None else node

        # Iterative in-order walk writing into one result list, so that no
        # partial result is copied again at each level of the tree.
        all_trades = []
        pending = []

        while pending or current is not None:
            while current is not None:
                pending.append(current)
                current = current.left

            current = pending.pop()
            all_trades.extend(current.trades)
            current = current.right

        return all_trades

    def get_min_trades(self) -> t.List[Trade]:
        """Return the trades with the smallest trade value (empty list for an empty tree)."""
        if self.root is None:
            return []

        node = self.root

        while node.left is not None:
            node = node.left

        # Copy so that callers cannot alter the node's own list
        return list(node.trades)

    def get_max_trades(self) -> t.List[Trade]:
        """Return the trades with the largest trade value (empty list for an empty tree)."""
        if self.root is None:
            return []

        node = self.root

        while node.right is not None:
            node = node.right

        # Copy so that callers cannot alter the node's own list
        return list(node.trades)

    def get_floor_trades(self, high: float) -> t.List[Trade]:
        """Return the trades with the largest trade value not exceeding ``high``.

        Returns an empty list when every stored value is greater than ``high``.
        """
        node = self.root
        floor_trades = []

        while node is not None:
            if node.trade_val == high:
                floor_trades = node.trades
                break

            if node.trade_val < high:
                # Candidate found; a closer one can only be to the right
                floor_trades = node.trades
                node = node.right
            else:
                # Plain `else` guarantees progress on every iteration, even
                # when no comparison holds (e.g. a NaN threshold).
                node = node.left

        return list(floor_trades)

    def get_ceil_trades(self, low: float) -> t.List[Trade]:
        """Return the trades with the smallest trade value not below ``low``.

        Returns an empty list when every stored value is smaller than ``low``.
        """
        node = self.root
        ceil_trades = []

        while node is not None:
            if node.trade_val == low:
                ceil_trades = node.trades
                break

            if node.trade_val > low:
                # Candidate found; a closer one can only be to the left
                ceil_trades = node.trades
                node = node.left
            else:
                # Plain `else` guarantees progress on every iteration, even
                # when no comparison holds (e.g. a NaN threshold).
                node = node.right

        return list(ceil_trades)

    def get_trades_in_range(self, low: float, high: float,
                            node: t.Optional[TradeNode] = None) -> t.List[Trade]:
        """Return the trades whose value lies in ``[low, high]``, by increasing value.

        Args:
            low: inclusive lower bound; must not be negative.
            high: inclusive upper bound; must not be smaller than ``low``.
            node: optional subtree root to search instead of the whole tree.

        Raises:
            ValueError: if ``low > high`` or ``low < 0``.
        """
        # The bounds are validated once, up front, not once per visited node
        if low > high or low < 0:
            raise ValueError("Invalid Range")

        if self.root is None:
            return []

        current = self.root if node is None else node

        trades_in_range = []
        pending = []

        # In-order walk that skips subtrees which cannot contain a match
        while pending or current is not None:
            while current is not None:
                pending.append(current)
                # Smaller values are only of interest while above the lower bound
                current = current.left if current.trade_val > low else None

            current = pending.pop()

            if low <= current.trade_val <= high:
                trades_in_range.extend(current.trades)

            # Larger values are only of interest while below the upper bound
            current = current.right if current.trade_val < high else None

        return trades_in_range

    @staticmethod
    def __rotate_left(node: TradeNode) -> TradeNode:
        # Makes the right child the new subtree root; it takes over the old
        # root's colour and the old root becomes its red left child.
        x = node.right
        node.right = x.left
        x.left = node
        x.color = node.color
        node.color = TradeNode.RED
        return x

    @staticmethod
    def __rotate_right(node: TradeNode) -> TradeNode:
        # Mirror image of __rotate_left: the left child becomes the subtree root
        x = node.left
        node.left = x.right
        x.right = node
        x.color = node.color
        node.color = TradeNode.RED
        return x

    @staticmethod
    def __flip_colors(node: TradeNode) -> None:
        # Paints both children black and the node itself red
        node.color = TradeNode.RED
        node.left.color = TradeNode.BLACK
        node.right.color = TradeNode.BLACK

    @staticmethod
    def __is_red(node: TradeNode) -> bool:
        # Missing children count as black
        return node is not None and node.color == TradeNode.RED

    @staticmethod
    def __balance(node: TradeNode) -> TradeNode:
        # Restores the red-black shape of the subtree at `node` after an
        # insertion below it. The three checks must run in this order.
        if TradeTree.__is_red(node.right) and not TradeTree.__is_red(node.left):
            node = TradeTree.__rotate_left(node)

        if TradeTree.__is_red(node.left) and TradeTree.__is_red(node.left.left):
            node = TradeTree.__rotate_right(node)

        if TradeTree.__is_red(node.left) and TradeTree.__is_red(node.right):
            TradeTree.__flip_colors(node)

        return node

Use the cell below to implement the requested API. 

In [3]:
# IMPLEMENT HERE THE REQUESTED API

# noinspection PyPep8Naming
class StockTradingPlatform:
    """Trading platform API backed by one ``TradeTree`` per supported stock.

    Query methods return lists of ``Trade`` objects and raise ``ValueError``
    for an unknown stock name or an invalid value bound.
    """
    # NOTE(review): this class does not derive from AbstractStockTradingPlatform although it
    # provides every method that interface declares - confirm whether subclassing is intended.

    def __init__(self) -> None:
        # noinspection SpellCheckingInspection
        self.STOCKS = ["Barclays", "HSBA", "Lloyds Banking Group", "NatWest Group", "Standard Chartered", "3i",
                       "Abrdn", "Hargreaves Lansdown", "London Stock Exchange Group", "Pershing Square Holdings",
                       "Schroders", "St. James's Place plc."]

        # One search tree per supported stock, keyed by stock name
        self.__trade_trees = {name: TradeTree(name) for name in self.STOCKS}

    def __require_known_stock(self, stockName: str, api_name: str) -> None:
        # Shared guard of the query methods; `api_name` prefixes the error message
        if stockName in self.STOCKS:
            return

        raise ValueError(api_name + ": Invalid Stock Name: " + stockName)

    def __validate_trade(self, trade: Trade) -> None:
        # Rejects unknown stocks, quantities below one and non-positive prices, in that order
        if trade.name not in self.STOCKS:
            raise ValueError("Invalid Stock Name: " + trade.name)

        if trade.quantity < 1:
            raise ValueError("Invalid Stock Quantity: " + str(trade.quantity))

        if trade.price <= 0.0:
            raise ValueError("Invalid Stock Price: " + str(trade.price))

    def logTransaction(self, transactionRecord: list) -> None:
        """Validate and store one ``[stock name, price, quantity, timestamp]`` record.

        Raises:
            ValueError: for an unknown stock, a quantity below 1 or a price that is not positive.
        """
        new_trade = Trade(*transactionRecord)
        self.__validate_trade(new_trade)

        tree = self.__trade_trees[new_trade.name]
        tree.put_trade(new_trade)

    def sortedTransactions(self, stockName: str) -> t.List[Trade]:
        """Return every trade of ``stockName`` ordered by increasing trade value."""
        self.__require_known_stock(stockName, "sortedTransactions")

        return self.__trade_trees[stockName].get_all_trades()

    def minTransactions(self, stockName: str) -> t.List[Trade]:
        """Return the trades of ``stockName`` that have the minimum trade value."""
        self.__require_known_stock(stockName, "minTransactions")

        return self.__trade_trees[stockName].get_min_trades()

    def maxTransactions(self, stockName: str) -> t.List[Trade]:
        """Return the trades of ``stockName`` that have the maximum trade value."""
        self.__require_known_stock(stockName, "maxTransactions")

        return self.__trade_trees[stockName].get_max_trades()

    def floorTransactions(self, stockName: str, thresholdValue: float) -> t.List[Trade]:
        """Return the floor trades of ``stockName`` for ``thresholdValue``.

        Raises:
            ValueError: for an unknown stock or a negative threshold.
        """
        self.__require_known_stock(stockName, "floorTransactions")

        if thresholdValue < 0:
            raise ValueError("floorTransactions: Invalid Transaction Value: " + str(thresholdValue))

        return self.__trade_trees[stockName].get_floor_trades(thresholdValue)

    def ceilingTransactions(self, stockName: str, thresholdValue: float) -> t.List[Trade]:
        """Return the ceiling trades of ``stockName`` for ``thresholdValue``.

        Raises:
            ValueError: for an unknown stock or a negative threshold.
        """
        self.__require_known_stock(stockName, "ceilingTransactions")

        if thresholdValue < 0:
            raise ValueError("ceilingTransactions: Invalid Transaction Value: " + str(thresholdValue))

        return self.__trade_trees[stockName].get_ceil_trades(thresholdValue)

    def rangeTransactions(self, stockName: str, fromValue: float, toValue: float) -> t.List[Trade]:
        """Return the trades of ``stockName`` whose value lies in ``[fromValue, toValue]``.

        Raises:
            ValueError: for an unknown stock, a negative bound or ``fromValue > toValue``.
        """
        self.__require_known_stock(stockName, "rangeTransactions")

        if fromValue > toValue or fromValue < 0 or toValue < 0:
            bounds = "fromValue: " + str(fromValue) + " toValue: " + str(toValue)
            raise ValueError("rangeTransactions: Invalid Range Bounds: " + bounds)

        return self.__trade_trees[stockName].get_trades_in_range(fromValue, toValue)

The cell below provides helper code that you can use within your experimental framework to generate random transaction data. **Do NOT modify it**.

In [4]:
# DO NOT MODIFY THIS CELL

import random
from datetime import timedelta
from datetime import datetime

class TransactionDataGenerator:
    # Generates pseudo-random transaction data for experiments.
    def __init__(self):
        # names of the stocks a generated transaction may refer to
        self.stockNames = ["Barclays", "HSBA", "Lloyds Banking Group", "NatWest Group", 
                      "Standard Chartered", "3i", "Abrdn", "Hargreaves Lansdown", 
                      "London Stock Exchange Group", "Pershing Square Holdings", 
                      "Schroders", "St. James's Place plc."]
        # bounds of the values returned by getTradeValue()
        self.minTradeValue = 500.00
        self.maxTradeValue = 100000.00
        # timestamp given to the first generated transaction
        self.startDate = datetime.strptime('1/1/2022 1:00:00', '%d/%m/%Y %H:%M:%S')
        # seeds the module-level random generator with a fixed value each time
        # a generator object is created
        random.seed(20221603)
          
    # returns the name of a traded stock at random
    def getStockName(self):
        return random.choice(self.stockNames)

    # returns the trade value of a transaction at random
    # (uniform between minTradeValue and maxTradeValue, rounded to 2 decimals)
    def getTradeValue(self):
        return round(random.uniform(self.minTradeValue, self.maxTradeValue), 2)
    
    # returns a list of N randomly generated transactions,
    # where each transaction is represented as a list [stock name, price, quantity, timestamp]
    # N : int
    def generateTransactionData(self, N):   
        # N references to one shared empty list; harmless because every slot
        # is replaced by its own record in the loop below
        listTransactions = [[]]*N
        # one timestamp per transaction, 3 seconds apart, starting at startDate
        listDates = [self.startDate + timedelta(seconds=3*x) for x in range(0, N)]
        # timestamps are handed out as formatted strings, not datetime objects
        listDatesFormatted = [x.strftime('%d/%m/%Y %H:%M:%S') for x in listDates]
        for i in range(N):
            stockName = random.choice(self.stockNames)
            # price drawn uniformly from 50.00 to 100.00, rounded to 2 decimals
            price = round(random.uniform(50.00, 100.00), 2)
            # whole-number quantity between 10 and 1000, both ends included
            quantity = random.randint(10,1000)
            listTransactions[i] = [stockName, price, quantity, listDatesFormatted[i]]   
        return listTransactions

Use the cell below for the Python code needed to realise your **experimental framework** (i.e., to generate test data, to instantiate the `StockTradingPlatform` class, to thoroughly experiment with its API functions, and to experimentally measure their performance). You may use the previously provided ``TransactionDataGenerator`` class to generate random transaction data.

In [5]:
import random
import timeit

# ADD YOUR EXPERIMENTAL FRAMEWORK CODE HERE




The cell below exemplifies **debug** code I will invoke on your submission - it does not represent an experimental framework (which should be much more comprehensive). **Do NOT modify it**. 

In [4]:
# DO NOT MODIFY THIS CELL

import timeit

# Debug run: loads generated transactions into a fresh platform, then reports
# wall-clock timings (timeit.default_timer) for each API function.
testPlatform = StockTradingPlatform()
testDataGen = TransactionDataGenerator()

# number of transactions to generate and load
numTransactions = 1000000
testData = testDataGen.generateTransactionData(numTransactions)

# number of calls over which each query's mean time is computed
numRuns = 100

print("Examples of transactions:", testData[0], testData[numTransactions//2], testData[numTransactions-1])

#
# testing the logTransaction() API 
#
# total time for loading all transactions, one logTransaction() call each
starttime = timeit.default_timer()
for i in range(numTransactions):
    testPlatform.logTransaction(testData[i])
endtime = timeit.default_timer()
print("\nExecution time to load", numTransactions, "transactions:", round(endtime-starttime,4))

#
# testing the various API functions
#
# Each query is called numRuns times on a randomly chosen stock. The random
# argument generation happens inside the timed loop, so it is part of the
# reported mean.
starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.sortedTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time sortedTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.minTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time minTransactions:", round((endtime-starttime)/numRuns,4))

starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.maxTransactions(testDataGen.getStockName())
endtime = timeit.default_timer()
print("\nMean execution time maxTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.floorTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time floorTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    output = testPlatform.ceilingTransactions(testDataGen.getStockName(), testDataGen.getTradeValue())
endtime = timeit.default_timer()
print("\nMean execution time ceilingTransactions:", round((endtime-starttime)/numRuns,4))


starttime = timeit.default_timer()
for i in range(numRuns):
    # two random trade values, sorted so that the lower one is the range start
    rangeValues = sorted([testDataGen.getTradeValue(), testDataGen.getTradeValue()])
    output = testPlatform.rangeTransactions(testDataGen.getStockName(), rangeValues[0], rangeValues[1])
endtime = timeit.default_timer()
print("\nMean execution time rangeTransactions:", round((endtime-starttime)/numRuns,4))

NameError: name 'TransactionDataGenerator' is not defined

In [8]:
from unittest import TestCase, defaultTestLoader, TextTestRunner
from datetime import datetime, timedelta
from random import choice, seed, uniform, randint

# Annotation alias for one transaction record: [stock name, price, quantity, timestamp].
# NOTE(review): `list` takes a single type parameter, so static type checkers will likely reject
# this four-argument subscript; in the code visible here it is only used in annotations - confirm.
TradeList = list[str, float, int, datetime]

# Defaults shared by the data generators and the test cases
# timestamp used for hand-written trades
SAMPLE_DATE = datetime.strptime('1/1/2022 1:00:00', '%d/%m/%Y %H:%M:%S')
# stock used whenever a test needs a single stock
SAMPLE_STOCK = "HSBA"
# default bounds for generated trade values
SAMPLE_MIN_VAL = 100
SAMPLE_MAX_VAL = 100000
# default number of generated trades
SAMPLE_SIZE = 100


class TestSets:
    """Factory for pseudo-random trades, trees and platforms used by the unit tests.

    Trade records are lists ``[stock name, price, quantity, timestamp]``. Every
    generated record gets its own timestamp, one second after the previous one.
    """

    def __init__(self):
        # The platform is only instantiated to read its list of stock names
        _platform = StockTradingPlatform()
        self.__stocks = tuple(_platform.STOCKS)
        self.__start_date = datetime.strptime('1/1/2022 1:00:00', '%d/%m/%Y %H:%M:%S')
        # Seeds the module-level random generator with a fixed value
        seed(20221603)
        self.__time_offset = 0

    def __next_timestamp(self) -> datetime:
        # Advances the internal clock by one second and returns the new time
        self.__time_offset += 1
        return self.__start_date + timedelta(seconds=self.__time_offset)

    @staticmethod
    def __get_value(min_val, max_val) -> tuple[float, int]:
        # Returns a (price, quantity) pair whose product is meant to fall
        # within [min_val, max_val].
        if min_val == max_val:
            return min_val, 1

        # Range too narrow to leave room for a quantity of up to 5
        if max_val / 5 < min_val:
            return round(uniform(min_val, max_val)), 1

        return round(uniform(min_val, max_val / 5)), randint(1, 5)

    def gen_one_trade(self, stock, min_val, max_val) -> TradeList:
        """Return one random trade record for ``stock``."""
        timestamp = self.__next_timestamp()
        price, quantity = self.__get_value(min_val, max_val)

        return [stock, price, quantity, timestamp]

    def trade_gen_many_same_stock(self, stock: str = SAMPLE_STOCK, min_val: int = SAMPLE_MIN_VAL,
                                  max_val: int = SAMPLE_MAX_VAL, n: int = SAMPLE_SIZE) -> list[TradeList]:
        """Return trade records for ``stock`` with pinned extremes.

        The first record has value exactly ``min_val`` and the last one exactly
        ``max_val``; ``n - 2`` random records lie in between, so at least two
        records are always returned.
        """
        trade_list = [[stock, min_val, 1, self.__next_timestamp()]]

        for _ in range(n - 2):
            trade_list.append(self.gen_one_trade(stock, min_val, max_val))

        # The closing record gets its own timestamp instead of repeating the
        # one of the record generated just before it.
        trade_list.append([stock, max_val, 1, self.__next_timestamp()])

        return trade_list

    def trade_gen_many_same_value(self, value: float, n: int = SAMPLE_SIZE) -> list[TradeList]:
        """Return ``n`` records of trade value ``value``, each for a randomly chosen stock."""
        return [self.gen_one_trade(choice(self.__stocks), value, value) for _ in range(n)]

    def trade_gen_many(self, min_val: int = SAMPLE_MIN_VAL, max_val: int = SAMPLE_MAX_VAL, n: int = SAMPLE_SIZE) \
            -> list[TradeList]:
        """Return ``n`` random records, each for a randomly chosen stock."""
        return [self.gen_one_trade(choice(self.__stocks), min_val, max_val) for _ in range(n)]

    def tree_gen_many_same_val(self, value: float, stock: str = SAMPLE_STOCK, n: int = SAMPLE_SIZE) \
            -> tuple[TradeTree, list[TradeList]]:
        """Return a tree for ``stock`` holding ``n`` trades of value ``value``, plus their records."""
        tree = TradeTree(stock)
        # A tree only accepts trades of its own stock, so the records must be
        # generated for `stock` rather than for randomly chosen stocks.
        trade_list = [self.gen_one_trade(stock, value, value) for _ in range(n)]

        for trade in trade_list:
            tree.put_trade(Trade(*trade))

        return tree, trade_list

    def tree_gen_many(self, min_val: int = SAMPLE_MIN_VAL, max_val: int = SAMPLE_MAX_VAL, stock: str = SAMPLE_STOCK,
                      n: int = SAMPLE_SIZE) -> tuple[TradeTree, list[TradeList]]:
        """Return a tree for ``stock`` filled by ``trade_gen_many_same_stock``, plus the records."""
        tree = TradeTree(stock)
        trade_list = self.trade_gen_many_same_stock(stock, min_val, max_val, n=n)

        for trade in trade_list:
            tree.put_trade(Trade(*trade))

        return tree, trade_list

    def platform_gen_many_same_stock(self, stock: str, low: int = SAMPLE_MIN_VAL, high: int = SAMPLE_MAX_VAL,
                                     n: int = SAMPLE_SIZE) -> tuple[StockTradingPlatform, list[TradeList]]:
        """Return a platform loaded with single-stock records (extremes pinned), plus the records."""
        platform = StockTradingPlatform()
        trade_list = self.trade_gen_many_same_stock(stock, low, high, n=n)

        for trade in trade_list:
            platform.logTransaction(trade)

        return platform, trade_list

    def platform_gen_many_same_val(self, value: float, n: int = SAMPLE_SIZE) \
            -> tuple[StockTradingPlatform, list[TradeList]]:
        """Return a platform loaded with ``n`` same-value records on random stocks, plus the records."""
        platform = StockTradingPlatform()
        trade_list = self.trade_gen_many_same_value(value, n=n)

        for trade in trade_list:
            platform.logTransaction(trade)

        return platform, trade_list

    def platform_gen_many(self, low: int = SAMPLE_MIN_VAL, high: int = SAMPLE_MAX_VAL, n: int = SAMPLE_SIZE) \
            -> tuple[StockTradingPlatform, list[TradeList]]:
        """Return a platform loaded with ``n`` random records on random stocks, plus the records."""
        platform = StockTradingPlatform()
        trade_list = self.trade_gen_many(low, high, n=n)

        for trade in trade_list:
            platform.logTransaction(trade)

        return platform, trade_list


# Single generator instance shared by the test cases below; constructing it
# seeds the module-level random generator.
test_sets = TestSets()


class TestStockTradingPlatform(TestCase):
    # Unit tests for the public API of StockTradingPlatform.

    @staticmethod
    def __trades_equal(trade_1: Trade, trade_2: Trade) -> bool:
        # Field-by-field comparison of two trades
        return trade_1.name == trade_2.name and trade_1.price == trade_2.price \
               and trade_1.quantity == trade_2.quantity and trade_1.time == trade_2.time

    def __assert_value_error(self, expected_message: str, api_call, *args, **kwargs) -> None:
        # Fails unless api_call(*args, **kwargs) raises a ValueError carrying
        # exactly expected_message. A call that returns normally fails the
        # test as well.
        with self.assertRaises(ValueError) as caught:
            api_call(*args, **kwargs)

        self.assertEqual(expected_message, caught.exception.args[0])

    # ---- logTransaction ----

    def test_log_bad_stock(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("Invalid Stock Name: UCL Bank",
                                  sut.logTransaction, ["UCL Bank", 1, 1, SAMPLE_DATE])

    def test_log_bad_stock_value(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("Invalid Stock Price: 0",
                                  sut.logTransaction, [SAMPLE_STOCK, 0, 1, SAMPLE_DATE])

    def test_log_bad_quantity(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("Invalid Stock Quantity: 0",
                                  sut.logTransaction, [SAMPLE_STOCK, 100, 0, SAMPLE_DATE])

    def test_log_insert_first(self):
        sut = StockTradingPlatform()
        record = [SAMPLE_STOCK, 100, 2, SAMPLE_DATE]

        sut.logTransaction(record)

        result = sut.sortedTransactions(SAMPLE_STOCK)

        self.assertEqual(len(result), 1)
        self.assertTrue(self.__trades_equal(result[0], Trade(*record)))

    def test_log_many_of_one(self):
        # Generates trades for one stock whose smallest value is SAMPLE_MIN_VAL
        # and whose largest value is SAMPLE_MAX_VAL
        sut, test_trades = test_sets.platform_gen_many_same_stock(stock=SAMPLE_STOCK,
                                                                  low=SAMPLE_MIN_VAL, high=SAMPLE_MAX_VAL)

        trade_list = sut.sortedTransactions(SAMPLE_STOCK)

        # Assert right number of trades were inserted
        self.assertEqual(len(test_trades), len(trade_list))

        # Assert trade with correct minimum value was inserted
        self.assertEqual(sut.minTransactions(SAMPLE_STOCK)[0].get_trade_val(), SAMPLE_MIN_VAL)

        # Assert trade with correct maximum value was inserted
        self.assertEqual(sut.maxTransactions(SAMPLE_STOCK)[0].get_trade_val(), SAMPLE_MAX_VAL)

    def test_log_one_of_each(self):
        sut = StockTradingPlatform()

        for stock in sut.STOCKS:
            sut.logTransaction(test_sets.gen_one_trade(stock, SAMPLE_MIN_VAL, SAMPLE_MAX_VAL))

        for stock in sut.STOCKS:
            self.assertEqual(len(sut.sortedTransactions(stock)), 1)

    def test_log(self):
        sut = StockTradingPlatform()

        sut.logTransaction(["London Stock Exchange Group",
                            1000, 5,
                            datetime.strptime("2020-02-25T22:00:15", "%Y-%m-%dT%H:%M:%S")])

        # The logged trade must be retrievable afterwards
        self.assertEqual(1, len(sut.sortedTransactions("London Stock Exchange Group")))

    def test_log_all_same_trade_val(self):
        sut, trades = test_sets.platform_gen_many_same_val(250)

        # With a single trade value per stock, minimum and maximum coincide
        for stock in sut.STOCKS:
            self.assertEqual(sut.minTransactions(stock), sut.maxTransactions(stock))

    def test_log_some_conflicts(self):
        trades1 = test_sets.trade_gen_many_same_value(550)
        trades2 = test_sets.trade_gen_many(min_val=SAMPLE_MIN_VAL, max_val=SAMPLE_MAX_VAL)
        sut = StockTradingPlatform()

        for trade in trades1 + trades2:
            sut.logTransaction(trade)

        # No trade may be lost, whether or not its value is shared with others
        total_len = sum(len(sut.sortedTransactions(stock)) for stock in sut.STOCKS)

        self.assertEqual(total_len, len(trades1) + len(trades2))

    # ---- sortedTransactions ----

    def test_sorted_transactions_empty(self):
        sut = StockTradingPlatform()

        result = sut.sortedTransactions(SAMPLE_STOCK)

        self.assertEqual(result, [])

    def test_sorted_one(self):
        record = test_sets.gen_one_trade(SAMPLE_STOCK, min_val=SAMPLE_MIN_VAL, max_val=SAMPLE_MIN_VAL)
        sut = StockTradingPlatform()
        sut.logTransaction(record)

        result = sut.sortedTransactions(SAMPLE_STOCK)

        self.assertTrue(self.__trades_equal(result[0], Trade(*record)))

    def test_sorted_many(self):
        sut, trades = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK)
        # Expected order: by trade value; list.sort is stable, so trades of
        # equal value stay in the order they were logged
        trades.sort(key=lambda x: x[1] * x[2])

        sorted_trades = sut.sortedTransactions(SAMPLE_STOCK)

        self.assertEqual(len(trades), len(sorted_trades))
        for index, trade in enumerate(trades):
            self.assertTrue(self.__trades_equal(sorted_trades[index], Trade(*trade)))

    def test_sorted_all_one_val(self):
        sut, _ = test_sets.platform_gen_many_same_val(500)

        trade_list = sut.sortedTransactions(SAMPLE_STOCK)

        # Essentially, just make sure nothing breaks
        for trade in trade_list:
            self.assertEqual(trade.get_trade_val(), 500)

    # ---- minTransactions ----

    def test_min_transactions_none(self):
        sut = StockTradingPlatform()

        min_t = sut.minTransactions(SAMPLE_STOCK)

        self.assertEqual([], min_t)

    def test_min_transactions_all_same(self):
        sut, _ = test_sets.platform_gen_many_same_val(5000)

        min_trades = sut.minTransactions(SAMPLE_STOCK)

        for trade in min_trades:
            self.assertEqual(trade.get_trade_val(), 5000)

    def test_min_transactions(self):
        sut, _ = test_sets.platform_gen_many_same_stock(low=SAMPLE_MIN_VAL, high=SAMPLE_MAX_VAL, stock=SAMPLE_STOCK)

        min_set = sut.minTransactions(SAMPLE_STOCK)

        self.assertEqual(SAMPLE_MIN_VAL, min_set[0].get_trade_val())

    def test_min_transactions_one(self):
        sut = StockTradingPlatform()
        record = [SAMPLE_STOCK, 500, 2, SAMPLE_DATE]
        sut.logTransaction(record)

        min_set = sut.minTransactions(SAMPLE_STOCK)

        self.assertEqual(len(min_set), 1)
        self.assertTrue(self.__trades_equal(min_set[0], Trade(*record)))

    def test_min_bad_name(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("minTransactions: Invalid Stock Name: UCL Bank",
                                  sut.minTransactions, "UCL Bank")

    # ---- maxTransactions ----

    def test_max_transactions_none(self):
        sut = StockTradingPlatform()

        # Must query the maximum (not the minimum) of the empty stock
        max_t = sut.maxTransactions(SAMPLE_STOCK)

        self.assertEqual([], max_t)

    def test_max_transactions_all_same(self):
        sut, _ = test_sets.platform_gen_many_same_val(5000)

        max_trades = sut.maxTransactions(SAMPLE_STOCK)

        for trade in max_trades:
            self.assertEqual(trade.get_trade_val(), 5000)

    def test_max_transactions(self):
        sut, _ = test_sets.platform_gen_many_same_stock(low=SAMPLE_MIN_VAL, high=SAMPLE_MAX_VAL, stock=SAMPLE_STOCK)

        max_set = sut.maxTransactions(SAMPLE_STOCK)

        self.assertEqual(SAMPLE_MAX_VAL, max_set[0].get_trade_val())

    def test_max_transactions_one(self):
        sut = StockTradingPlatform()
        record = [SAMPLE_STOCK, 500, 2, SAMPLE_DATE]
        sut.logTransaction(record)

        # Must query the maximum (not the minimum) of the single-trade stock
        max_set = sut.maxTransactions(SAMPLE_STOCK)

        self.assertEqual(len(max_set), 1)
        self.assertTrue(self.__trades_equal(max_set[0], Trade(*record)))

    def test_max_bad_name(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("maxTransactions: Invalid Stock Name: UCL Bank",
                                  sut.maxTransactions, "UCL Bank")

    # ---- floorTransactions ----

    def test_floor_transactions_empty(self):
        sut = StockTradingPlatform()

        floor = sut.floorTransactions(SAMPLE_STOCK, 100)

        self.assertEqual(floor, [])

    def test_floor_below_min(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, low=SAMPLE_MIN_VAL)

        floor = sut.floorTransactions(SAMPLE_STOCK, SAMPLE_MIN_VAL - 1)

        self.assertEqual([], floor)

    def test_floor_above_max(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, high=SAMPLE_MAX_VAL)

        floor = sut.floorTransactions(SAMPLE_STOCK, SAMPLE_MAX_VAL + 1)

        self.assertEqual(floor[0].get_trade_val(), SAMPLE_MAX_VAL)

    def test_floor_equal_min(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, low=SAMPLE_MIN_VAL)

        floor = sut.floorTransactions(SAMPLE_STOCK, SAMPLE_MIN_VAL)

        self.assertEqual(SAMPLE_MIN_VAL, floor[0].get_trade_val())

    def test_floor_bad_name(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("floorTransactions: Invalid Stock Name: UCL Bank",
                                  sut.floorTransactions, "UCL Bank", 200)

    # ---- ceilingTransactions ----

    def test_ceiling_transactions_empty(self):
        sut = StockTradingPlatform()

        ceiling = sut.ceilingTransactions(SAMPLE_STOCK, SAMPLE_MIN_VAL)

        self.assertEqual(ceiling, [])

    def test_ceiling_above_max(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, high=SAMPLE_MAX_VAL)

        ceiling = sut.ceilingTransactions(SAMPLE_STOCK, SAMPLE_MAX_VAL + 1)

        self.assertEqual([], ceiling)

    def test_ceiling_below_min(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, low=SAMPLE_MIN_VAL)

        ceiling = sut.ceilingTransactions(SAMPLE_STOCK, SAMPLE_MIN_VAL - 1)

        self.assertEqual(ceiling[0].get_trade_val(), SAMPLE_MIN_VAL)

    def test_ceiling_equal_max(self):
        sut, _ = test_sets.platform_gen_many_same_stock(SAMPLE_STOCK, high=SAMPLE_MAX_VAL)

        # Must exercise the ceiling query (not the floor query)
        ceiling = sut.ceilingTransactions(SAMPLE_STOCK, SAMPLE_MAX_VAL)

        self.assertEqual(SAMPLE_MAX_VAL, ceiling[0].get_trade_val())

    def test_ceiling_bad_name(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("ceilingTransactions: Invalid Stock Name: UCL Bank",
                                  sut.ceilingTransactions, "UCL Bank", SAMPLE_MAX_VAL)

    # ---- rangeTransactions ----

    def test_range_bad_range(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("rangeTransactions: Invalid Range Bounds: fromValue: 101 toValue: 99",
                                  sut.rangeTransactions, SAMPLE_STOCK, fromValue=101, toValue=99)

    def test_range_inclusive_below(self):
        # Trade values are 200, 1500 and 4000; the range starts exactly at 200
        sut = StockTradingPlatform()
        sut.logTransaction([SAMPLE_STOCK, 100, 2, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 500, 3, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 1000, 4, SAMPLE_DATE])

        range_set = sut.rangeTransactions(SAMPLE_STOCK, 200, 300)

        self.assertEqual(1, len(range_set))
        self.assertTrue(self.__trades_equal(range_set[0], Trade(*[SAMPLE_STOCK, 100, 2, SAMPLE_DATE])))

    def test_range_inclusive_above(self):
        # Trade values are 200, 1500 and 4000; the range ends exactly at 4000
        sut = StockTradingPlatform()
        sut.logTransaction([SAMPLE_STOCK, 100, 2, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 500, 3, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 1000, 4, SAMPLE_DATE])

        range_set = sut.rangeTransactions(SAMPLE_STOCK, 2000, 4000)

        self.assertEqual(1, len(range_set))
        self.assertTrue(self.__trades_equal(range_set[0], Trade(*[SAMPLE_STOCK, 1000, 4, SAMPLE_DATE])))

    def test_range_equal_to_stock(self):
        # Trade values are 200, 1500 and 4000; the range is the single value 1500
        sut = StockTradingPlatform()
        sut.logTransaction([SAMPLE_STOCK, 100, 2, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 500, 3, SAMPLE_DATE])
        sut.logTransaction([SAMPLE_STOCK, 1000, 4, SAMPLE_DATE])

        range_set = sut.rangeTransactions(SAMPLE_STOCK, 1500, 1500)

        self.assertEqual(1, len(range_set))
        self.assertTrue(self.__trades_equal(range_set[0], Trade(*[SAMPLE_STOCK, 500, 3, SAMPLE_DATE])))

    def test_range_bad_name(self):
        sut = StockTradingPlatform()

        self.__assert_value_error("rangeTransactions: Invalid Stock Name: UCL Bank",
                                  sut.rangeTransactions, "UCL Bank", 200, 500)

class TestStockTradeLog(TestCase):
    """Tests for TradeTree: insertion, min/max lookup and range queries."""

    def __assert_trade_lists_contain_same_elems(self, trade_list_1: list[Trade], trade_list_2: list[Trade]):
        # Order-insensitive comparison of two trade lists.
        #
        # Both lists are ordered by the product of fields 1 and 2 of
        # to_list() (presumably price and quantity) and then compared
        # field-by-field. sorted() builds new lists, so the caller's lists are
        # left untouched, and it is stable, so trades sharing a value keep
        # their incoming relative order.
        def trade_value(trade):
            fields = trade.to_list()
            return fields[1] * fields[2]

        first = [trade.to_list() for trade in sorted(trade_list_1, key=trade_value)]
        second = [trade.to_list() for trade in sorted(trade_list_2, key=trade_value)]

        # One assertEqual covers both a length mismatch and a differing
        # trade, and its failure message shows the actual difference.
        self.assertEqual(first, second)

    def _test_trade_add(self, log: TradeTree, t: Trade):
        # Shared helper (no "test" prefix, so the loader does not collect it):
        # insert `t` into `log`, then check that a range query whose bounds
        # both equal the trade's value returns exactly [t].
        log.put_trade(t)

        value = t.get_trade_val()
        self.assertEqual(log.get_trades_in_range(value, value), [t])

    def test_add_trade_empty(self):
        # The only trade in a fresh tree is both its minimum and its maximum.
        log = TradeTree(SAMPLE_STOCK)
        t = Trade(SAMPLE_STOCK, 123.4, 3, SAMPLE_DATE)

        self._test_trade_add(log, t)
        self.assertEqual(log.get_min_trades(), [t])
        self.assertEqual(log.get_max_trades(), [t])

    def test_add_trade_busy(self):
        # Same insertion check, against a tree pre-populated by test_sets.
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK)
        t = Trade(SAMPLE_STOCK, 123.4, 3, SAMPLE_DATE)

        self._test_trade_add(log, t)

    def test_add_lots_trades(self):
        # Re-inserting the generated trade records into a fresh tree must
        # yield the same set of trades as the generated tree holds.
        final_log, trades = test_sets.tree_gen_many(stock=SAMPLE_STOCK)
        test_log = TradeTree(SAMPLE_STOCK)

        for trade in trades:
            test_log.put_trade(Trade(*trade))

        self.__assert_trade_lists_contain_same_elems(test_log.get_all_trades(), final_log.get_all_trades())

    def test_bad_stock(self):
        t = Trade("Lloyds", 23.4, 1, SAMPLE_DATE)
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK)

        # Expect failure as we use a stock not used in constructor.
        # assertRaises fails the test if nothing is raised and lets other
        # exception types surface with their own traceback.
        with self.assertRaises(ValueError) as caught:
            self._test_trade_add(log, t)

        self.assertEqual(caught.exception.args[0], "Invalid Stock Name")

    def test_single_min_trade(self):
        # The first minimum trade of a tree generated with min_val must have
        # exactly that value.
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK, min_val=SAMPLE_MIN_VAL)

        self.assertEqual(log.get_min_trades()[0].get_trade_val(), SAMPLE_MIN_VAL)

    def test_many_min_trade(self):
        # After adding one more trade worth SAMPLE_MIN_VAL, at least two
        # trades must be reported as minimum.
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK, min_val=SAMPLE_MIN_VAL)

        min_trade = Trade(SAMPLE_STOCK, SAMPLE_MIN_VAL, 1, SAMPLE_DATE)

        log.put_trade(min_trade)

        min_trades = log.get_min_trades()
        self.assertGreaterEqual(len(min_trades), 2)
        self.assertEqual(min_trades[0].get_trade_val(), SAMPLE_MIN_VAL)

    def test_single_max_trade(self):
        # A single trade worth SAMPLE_MAX_VAL in a fresh tree is its maximum.
        log = TradeTree(SAMPLE_STOCK)

        log.put_trade(Trade(SAMPLE_STOCK, SAMPLE_MAX_VAL, 1, SAMPLE_DATE))

        self.assertEqual(log.get_max_trades()[0].get_trade_val(), SAMPLE_MAX_VAL)

    def test_many_max_trade(self):
        # After adding one more trade worth SAMPLE_MAX_VAL, at least two
        # trades must be reported as maximum.
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK, max_val=SAMPLE_MAX_VAL)

        max_trade = Trade(SAMPLE_STOCK, SAMPLE_MAX_VAL, 1, SAMPLE_DATE)

        log.put_trade(max_trade)

        max_trades = log.get_max_trades()
        self.assertGreaterEqual(len(max_trades), 2)
        self.assertEqual(max_trades[0].get_trade_val(), SAMPLE_MAX_VAL)

    def test_trade_range_all(self):
        # A range spanning 0 to SAMPLE_MAX_VAL must return every generated
        # trade.
        log, trades = test_sets.tree_gen_many(stock=SAMPLE_STOCK, max_val=SAMPLE_MAX_VAL)

        trades = [Trade(*trade) for trade in trades]

        self.__assert_trade_lists_contain_same_elems(log.get_trades_in_range(0, SAMPLE_MAX_VAL), trades)

    def test_trade_range_none(self):
        # The range [0.1, 0.2] is expected to match no trade in a tree
        # generated with min_val=SAMPLE_MIN_VAL.
        log, _ = test_sets.tree_gen_many(stock=SAMPLE_STOCK, min_val=SAMPLE_MIN_VAL)

        self.assertEqual(log.get_trades_in_range(0.1, 0.2), [])

    def test_range_bad_min(self):
        # A negative lower bound must be rejected with "Invalid Range".
        log = TradeTree(SAMPLE_STOCK)

        with self.assertRaises(ValueError) as caught:
            log.get_trades_in_range(-1, 4)

        self.assertEqual(caught.exception.args[0], "Invalid Range")

    def test_range_bad_max(self):
        # An upper bound below the lower bound must be rejected with
        # "Invalid Range".
        log = TradeTree(SAMPLE_STOCK)

        with self.assertRaises(ValueError) as caught:
            log.get_trades_in_range(5, 4)

        self.assertEqual(caught.exception.args[0], "Invalid Range")

    def test_trade_range_one(self):
        # A range whose bounds both equal the single stored trade's value
        # returns exactly that trade.
        log = TradeTree(SAMPLE_STOCK)
        t = Trade(SAMPLE_STOCK, 116, 1, SAMPLE_DATE)

        log.put_trade(t)

        self.assertEqual(log.get_trades_in_range(116, 116), [t])


class TestTrade(TestCase):
    """Tests for constructing a Trade and reading back its trade value."""

    def test_good_trade_constructor(self):
        # Smoke test: building a Trade from these arguments must not raise.
        trade = Trade(SAMPLE_STOCK, 123.456, 5, SAMPLE_DATE)
        self.assertIsNotNone(trade)

    def test_get_trade_val_single(self):
        # Numeric arguments 123 and 1 must give a trade value of 123.
        trade = Trade(SAMPLE_STOCK, 123, 1, SAMPLE_DATE)
        value = trade.get_trade_val()
        self.assertEqual(value, 123)

    def test_get_trade_val_multi(self):
        # Numeric arguments 123 and 3 must give a trade value of 369,
        # i.e. their product.
        trade = Trade(SAMPLE_STOCK, 123, 3, SAMPLE_DATE)
        value = trade.get_trade_val()
        self.assertEqual(value, 369)

if __name__ == '__main__':
    # Build one suite per test class first, then run each suite with its own
    # TextTestRunner so every class gets a separate result summary.
    platform_tests, trade_tests, tree_tests = (
        defaultTestLoader.loadTestsFromTestCase(case)
        for case in (TestStockTradingPlatform, TestTrade, TestStockTradeLog)
    )

    # Execution order: platform tests, then tree tests, then trade tests.
    for suite in (platform_tests, tree_tests, trade_tests):
        TextTestRunner().run(suite)


......................................
----------------------------------------------------------------------
Ran 38 tests in 0.021s

OK
.............
----------------------------------------------------------------------
Ran 13 tests in 0.011s

OK
...
----------------------------------------------------------------------
Ran 3 tests in 0.001s

OK
