In [1]:
from datetime import datetime, timedelta
import pandas as pd
import plotly.figure_factory as ff
from ortools.sat.python import cp_model
from bisect import bisect_left, bisect_right
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import unittest
# from src.objects.timeinterval import *
# from src.utils.utils import *
import uuid

In [2]:
intervals = TimeIntervals([TimeInterval(datetime(2023, 2, 1, 10), datetime(2023, 2, 1, 11)),
                           TimeInterval(datetime(2023, 2, 1, 12), datetime(2023, 2, 1, 13)),
                           TimeInterval(datetime(2023, 2, 1, 14), datetime(2023, 2, 1, 16))])
intervals.visualize_gantt()


In [None]:
# Design classes for Shift scheduling problem (Employee, Shift, Schedule)

# Employee class
# Attributes: name: str, availability: TimeIntervals, assigned shifts: list of Shifts
# Methods: change_availability, change_name

# Shift class
# Attributes: name: str, type: str, interval: TimeInterval, min employees: int, max employees: int, assigned employees: list of Employees
# Methods: assign_employee, unassign_employee, change_interval, change_min_employees, change_max_employees, change_name, change_type

# Single Shift class, which inherits from the Shift class
# Attributes: None
# Methods: None

# Recurring Shift class, which inherits from the Shift class
# Attributes: recurrence interval: timedelta, start date: datetime, end date: datetime, recurrence shifts: list of Shifts
# Methods: get_recurrence_intervals, change_recurrence_interval, change_start_date, change_end_date, get_recurrence_shifts (returns a list of Shifts)

# Schedule class
# Attributes: name: str, employees: list of Employees, shifts: list of Shifts, shift types: list of str, start time: datetime, end time: datetime
# Methods: add_employee, add_shift, reset, change_name, show, change_start_time, change_end_time, solve (solve the shift scheduling problem with OR-Tools and return the solution)



In [35]:
class TimeInterval:
    def __init__(self, start: datetime, end: datetime):
        assert start < end, 'start must be less than end'
        self._start = start
        self._end = end

    @property
    def start(self):
        return self._start

    @start.setter
    def start(self, start: datetime):
        self._start = start

    @property
    def end(self):
        return self._end

    @end.setter
    def end(self, end: datetime):
        self._end = end

    def __repr__(self) -> str:
        return f'TimeInterval({self.start}, {self.end})'

    def __str__(self) -> str:
        return f'{self.start} - {self.end}'

    def __eq__(self, other) -> bool:
        assert isinstance(other, TimeInterval), 'compared object must be of type TimeInterval'
        return self.start == other.start and self.end == other.end

    def overlaps(self, other) -> bool:
        return self.start < other.end and self.end > other.start

    def contains(self, other) -> bool:
        return self.start <= other.start and self.end >= other.end

    def __contains__(self, other) -> bool:
        return self.contains(other)

    def __add__(self, other) -> list:
        """
        Returns a list of TimeIntervals that are the union of self and other
        """
        if self.overlaps(other):
            return [TimeInterval(min(self.start, other.start), max(self.end, other.end))]
        else:
            return [self, other]

    def __sub__(self, other) -> list:
        """
        Returns a list of TimeIntervals that are the difference of self and other
        """
        if self.overlaps(other):
            if self.start < other.start:
                return [TimeInterval(self.start, other.start)]
            elif self.end > other.end:
                return [TimeInterval(other.end, self.end)]
            else:
                return []
        else:
            return [self]

    def copy(self):
        return TimeInterval(self.start, self.end)



class TimeIntervals:
    """
    A class representing a list of time intervals, sorted by start time, and merged if necessary (no overlapping intervals)
    """

    def __init__(self, intervals: list = []):
        assert all([isinstance(interval, TimeInterval) for interval in intervals]), 'All elements in the list must be of type TimeInterval'
        self._intervals = self._merge(intervals)

    def __repr__(self) -> str:
        return f'TimeIntervals({self._intervals})'

    @staticmethod
    def _sort(intervals: list):
        intervals.sort(key=lambda x: x.start)
        return intervals
    
    @staticmethod
    def _merge(intervals: list):
        intervals = TimeIntervals._sort(intervals.copy())
        merged = []
        for interval in intervals:
            if not merged or merged[-1].end < interval.start:
                merged.append(interval)
            else:
                merged[-1].end = max(merged[-1].end, interval.end)
        return merged

    @property
    def intervals(self):
        return self._intervals

    @intervals.setter
    def intervals(self, intervals: list):
        self._intervals = TimeIntervals._merge(intervals)

    def __eq__(self, other) -> bool:
        assert isinstance(other, TimeIntervals), 'compared object must be of type TimeIntervals'
        return self._intervals == other._intervals

    def add(self, interval: TimeInterval):
        self._intervals = TimeIntervals._merge(self._intervals + [interval.copy()])

    def remove(self, interval: TimeInterval):
        self._intervals = self._subtract(self._intervals, interval)

    @staticmethod
    def _subtract(intervals: list, interval: TimeInterval):
        intervals = TimeIntervals._sort(intervals)
        subtracted = []
        for i in intervals:
            if i.overlaps(interval):
                if i.start < interval.start:
                    subtracted.append(TimeInterval(i.start, interval.start))
                if i.end > interval.end:
                    subtracted.append(TimeInterval(interval.end, i.end))
            else:
                subtracted.append(i)
        return subtracted

    def __add__(self, other) -> list: # type: ignore        
        """
        Returns TimeIntervals that are the union of self and other
        """
        if isinstance(other, TimeIntervals):
            return TimeIntervals(self.intervals + other.intervals) # type: ignore
        elif isinstance(other, TimeInterval):
            return TimeIntervals(self.intervals + [other]) # type: ignore
        else:
            raise TypeError('other must be of type TimeIntervals or TimeInterval')

    def __sub__(self, other) -> list:
        """
        Returns a list of TimeIntervals that are the difference of self and other
        """
        if isinstance(other, TimeIntervals):
            return TimeIntervals._subtract(self.intervals, other.intervals) # type: ignore 
        elif isinstance(other, TimeInterval):
            return TimeIntervals._subtract(self.intervals, other)
        else:
            raise TypeError('other must be of type TimeIntervals or TimeInterval')

    
    # overlaps with another TimeIntervals object or a TimeInterval object
    def overlaps(self, other) -> bool:
        if isinstance(other, TimeIntervals):
            for interval in self.intervals:
                for other_interval in other.intervals:
                    if interval.overlaps(other_interval):
                        return True
            return False
        elif isinstance(other, TimeInterval):
            for interval in self.intervals:
                if interval.overlaps(other):
                    return True
            return False
        else:
            raise TypeError('other must be of type TimeIntervals or TimeInterval')

    # contains another TimeIntervals object or a TimeInterval object
    def contains(self, other) -> bool:
        if isinstance(other, TimeIntervals):
            for interval in other.intervals:
                if not self.contains(interval):
                    return False
            return True
        elif isinstance(other, TimeInterval):
            for interval in self.intervals:
                if interval.contains(other):
                    return True
            return False
        else:
            raise TypeError('other must be of type TimeIntervals or TimeInterval')

    def __contains__(self, other) -> bool:
        assert isinstance(other, TimeIntervals) or isinstance(other, TimeInterval), 'other must be of type TimeIntervals or TimeInterval'
        return self.contains(other)



    

In [38]:
# Test the TimeInterval classes
# 1. test initialization
# 2. test overlaps
# 3. test contains
# 4. test subtraction
# 5. test addition

class TestTimeInterval(unittest.TestCase):
    dt1 = datetime(2023, 2, 1, 10)
    dt2 = datetime(2023, 2, 1, 11)
    dt3 = datetime(2023, 2, 1, 12)
    dt4 = datetime(2023, 2, 1, 13)

    ti1 = TimeInterval(dt1, dt2)
    ti2 = TimeInterval(dt2, dt3)
    ti3 = TimeInterval(dt3, dt4)
    ti4 = TimeInterval(dt1, dt4)
    ti5 = TimeInterval(dt1, dt3)
    ti6 = TimeInterval(dt2, dt4)


    def test_init(self):
        self.assertEqual(self.ti1.start, self.dt1)
        self.assertEqual(self.ti1.end, self.dt2)
    
    def test_overlaps(self):
        self.assertTrue(self.ti1.overlaps(self.ti4))
        self.assertTrue(self.ti4.overlaps(self.ti1))
        self.assertTrue(self.ti5.overlaps(self.ti6))
        self.assertTrue(self.ti6.overlaps(self.ti5))
        self.assertFalse(self.ti1.overlaps(self.ti2))
        self.assertFalse(self.ti2.overlaps(self.ti1))
        self.assertFalse(self.ti1.overlaps(self.ti3))
        self.assertFalse(self.ti3.overlaps(self.ti1))


    def test_contains(self):
        self.assertTrue(self.ti4.contains(self.ti1))
        self.assertTrue(self.ti4.contains(self.ti2))
        self.assertFalse(self.ti1.contains(self.ti4))
        self.assertFalse(self.ti2.contains(self.ti4))
        self.assertFalse(self.ti1.contains(self.ti2))
        self.assertFalse(self.ti2.contains(self.ti1))
        self.assertTrue(self.ti1.contains(self.ti1))

    def test_subtraction(self):
        self.assertEqual(self.ti1 - self.ti2, [self.ti1])
        self.assertEqual(self.ti2 - self.ti1, [self.ti2])
        self.assertEqual(self.ti1 - self.ti4, [])
        self.assertEqual(self.ti4 - self.ti1, [self.ti6])

unittest.main(argv=[''], verbosity=2, exit=False)



test_add_availability (__main__.TestEmployee) ... FAIL
test_init (__main__.TestEmployee) ... ok
test_remove_availability (__main__.TestEmployee) ... ERROR
test_contains (__main__.TestTimeInterval) ... ok
test_init (__main__.TestTimeInterval) ... ok
test_overlaps (__main__.TestTimeInterval) ... ok
test_subtraction (__main__.TestTimeInterval) ... ok
test_add (__main__.TestTimeIntervals) ... FAIL
test_eq (__main__.TestTimeIntervals) ... ok
test_init (__main__.TestTimeIntervals) ... ok
test_remove (__main__.TestTimeIntervals) ... ok

ERROR: test_remove_availability (__main__.TestEmployee)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/qm/hc980d_116zfqwx_9dxh3dmm0000gn/T/ipykernel_58110/2286337519.py", line 49, in test_remove_availability
    self.assertEqual(self.emp.availability, self.intervals2 + self.intervals3)
  File "/var/folders/qm/hc980d_116zfqwx_9dxh3dmm0000gn/T/ipykernel_58110/1319616489.py", line 139

2021-01-01 08:00:00 - 2021-01-01 13:00:00
TimeIntervals([])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 13:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 17:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 17:00:00)])
TimeIntervals([TimeInterval(2021-01-01 13:00:00, 2021-01-01 17:00:00)])


<unittest.main.TestProgram at 0x1372cf4c0>

In [39]:
# Test the TimeIntervals class
# 1. test initialization
# 2. test add
# 3. test remove
class TestTimeIntervals(unittest.TestCase):
    _iv1 = TimeInterval(datetime(2023, 2, 1, 10), datetime(2023, 2, 1, 12))
    _iv2 = TimeInterval(datetime(2023, 2, 1, 11), datetime(2023, 2, 1, 13))
    _iv3 = TimeInterval(datetime(2023, 2, 1, 12), datetime(2023, 2, 1, 14))

    def test_init(self):
        intervals = TimeIntervals([self._iv1, self._iv2, self._iv3])
        self.assertEqual(intervals.intervals, [TimeInterval(datetime(2023, 2, 1, 10), datetime(2023, 2, 1, 14))])

    def test_add(self):
        intervals = TimeIntervals([self._iv1])
        intervals.add(self._iv2)
        self.assertEqual(intervals.intervals, [TimeInterval(datetime(2023, 2, 1, 10), datetime(2023, 2, 1, 13))])
        intervals.add(self._iv3)
        self.assertEqual(self._iv2, TimeInterval(datetime(2023, 2, 1, 11), datetime(2023, 2, 1, 13)))

    def test_remove(self):
        intervals = TimeIntervals([self._iv1, self._iv2, self._iv3])
        intervals.remove(self._iv2)
        self.assertEqual(intervals.intervals, [TimeInterval(datetime(2023, 2, 1, 10), datetime(2023, 2, 1, 11)), TimeInterval(datetime(2023, 2, 1, 13), datetime(2023, 2, 1, 14))])
        
    def test_eq(self):
        intervals1 = TimeIntervals([self._iv1, self._iv2, self._iv3])
        intervals2 = TimeIntervals([self._iv1, self._iv3])
        self.assertTrue(intervals1 == intervals2)
        intervals1.remove(self._iv3)
        self.assertFalse(intervals1 == intervals2)

unittest.main(argv=[''], verbosity=2, exit=False)




test_add_availability (__main__.TestEmployee) ... FAIL
test_init (__main__.TestEmployee) ... ok
test_remove_availability (__main__.TestEmployee) ... ERROR
test_contains (__main__.TestTimeInterval) ... ok
test_init (__main__.TestTimeInterval) ... ok
test_overlaps (__main__.TestTimeInterval) ... ok
test_subtraction (__main__.TestTimeInterval) ... ok
test_add (__main__.TestTimeIntervals) ... ok
test_eq (__main__.TestTimeIntervals) ... ok
test_init (__main__.TestTimeIntervals) ... ok
test_remove (__main__.TestTimeIntervals) ... ok

ERROR: test_remove_availability (__main__.TestEmployee)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/qm/hc980d_116zfqwx_9dxh3dmm0000gn/T/ipykernel_58110/2286337519.py", line 49, in test_remove_availability
    self.assertEqual(self.emp.availability, self.intervals2 + self.intervals3)
  File "/var/folders/qm/hc980d_116zfqwx_9dxh3dmm0000gn/T/ipykernel_58110/1319616489.py", line 139, 

2021-01-01 08:00:00 - 2021-01-01 13:00:00
TimeIntervals([])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 13:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 17:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 17:00:00)])
TimeIntervals([TimeInterval(2021-01-01 13:00:00, 2021-01-01 17:00:00)])


<unittest.main.TestProgram at 0x13727c400>

In [32]:
# Employee class
class Employee:
    def __init__(self, name: str, role: str):
        self._name = name
        self._role = role
        self._availability = TimeIntervals([])
        self.assigned_shifts = []

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name: str):
        self._name = name

    @property
    def availability(self):
        return self._availability

    # @availability.setter
    # def availability(self, availability: TimeIntervals):
    #     self._availability = availability

    @property
    def role(self):
        return self._role
    
    @role.setter
    def role(self, role: str):
        self._role = role
    
    def _add_availability(self, interval: TimeInterval):
        # add copy of interval to availability
        self._availability.add(interval.copy())

    def _remove_availability(self, interval: TimeInterval):
        self._availability.remove(interval)

    def _clear_availability(self):
        self._availability = TimeIntervals([])



# Shift class
class Shift:
    def __init__(self, interval: TimeInterval, min_employees: int, max_employees: int, name: str = "New shift", type: str = "Uncategorized"):
        assert min_employees <= max_employees, "Minimum number of employees must be less than or equal to the maximum number of employees"
        self._name = name
        self._type = type
        self._interval = interval
        self._min_employees = min_employees
        self._max_employees = max_employees
        self._assigned_employees = []

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name: str):
        self._name = name

    @property
    def type(self):
        return self._type
    
    @type.setter
    def type(self, type: str):
        self._type = type
    
    @property
    def interval(self):
        return self._interval
    
    @interval.setter
    def interval(self, interval: TimeInterval):
        self._interval = interval

    @property
    def min_employees(self):
        return self._min_employees

    @min_employees.setter
    def min_employees(self, min_employees: int):
        self._min_employees = min_employees

    @property
    def max_employees(self):
        return self._max_employees

    @max_employees.setter
    def max_employees(self, max_employees: int):
        self._max_employees = max_employees

    @property
    def date(self):
        return self.interval.start.date()

    @property
    def weekday(self):
        return self.interval.start.weekday()

    @property
    def start(self):
        return self.interval.start.time()

    @property
    def end(self):
        return self.interval.end.time()

    @property
    def duration(self):
        return self.interval.end - self.interval.start

    def _add_employee(self, employee: Employee):
        self._assigned_employees.append(employee)

    def _remove_employee(self, employee: Employee):
        self._assigned_employees.remove(employee)



# Schedule class
class Schedule:
    def __init__(self, start_date: datetime, end_date: datetime, name: str = "New schedule"):
        self._name = name
        self._shifts = []
        self._shift_types = []
        self._employees = []
        self._roles = []
        self._start_date = start_date
        self._end_date = end_date

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name: str):
        self._name = name   

    @property
    def shifts(self):
        return self._shifts

    @property
    def shift_types(self):
        return self._shift_types

    @property
    def employees(self):
        return self._employees

    @employees.setter
    def employees(self, employees: list):
        self._employees = employees

    @property
    def roles(self):
        return self._roles

    @property
    def start_date(self):
        return self._start_date

    @start_date.setter
    def start_date(self, start_date: datetime):
        self._start_date = start_date

    @property
    def end_date(self):
        return self._end_date

    @end_date.setter
    def end_date(self, end_date: datetime):
        self._end_date = end_date

    # create_shift
    def create_shift(self, start: datetime, end: datetime, min_employees: int, max_employees: int, name: str = "New shift", type: str = "Uncategorized"):
        # check if start and end are within the schedule's start and end dates
        assert start >= self.start_date and end <= self.end_date, "Shift must be within the schedule's start and end dates"
        self._shifts.append(Shift(TimeInterval(start, end), min_employees, max_employees, name, type))
        if type not in self._shift_types:
            self._shift_types.append(type)

    # remove_shift
    def remove_shift(self, shift: Shift):
        self._shifts.remove(shift)
        if shift.type not in [s.type for s in self._shifts]:
            self._shift_types.remove(shift.type)

    # create_employee
    def create_employee(self, name: str = "New Employee", role: str = "Uncategorized"):
        self._employees.append(Employee(name = name, role = role))
        if role not in self._roles:
            self._roles.append(role)
            
    # remove_employee
    def remove_employee(self, employee: Employee):
        self._employees.remove(employee)
        if employee.role not in [e.role for e in self._employees]:
            self._roles.remove(employee.role)

    # assign_employee
    def assign_employee(self, employee: Employee, shift: Shift):
        # check if employee is available during the shift
        assert shift.interval in employee.availability, "Employee is not available during the shift"
        # check if shift is not full
        assert len(shift._assigned_employees) < shift.max_employees, "Shift is full"
        # check if employee is not already assigned to the shift
        assert employee not in shift._assigned_employees, "Employee is already assigned to the shift"

        employee.assigned_shifts.append(shift)
        shift._add_employee(employee)

        # update employee's availability
        employee._remove_availability(shift.interval)

    # unassign_employee
    def unassign_employee(self, employee: Employee, shift: Shift):
        # check if employee is assigned to the shift
        assert employee in shift._assigned_employees, "Employee is not assigned to the shift"

        employee.assigned_shifts.remove(shift)
        shift._remove_employee(employee)

        # update employee's availability
        employee._add_availability(shift.interval)

    # generate_schedule with constraints, objectives with CP-SAT
    def generate_schedule(self):
        pass

    # display
    def display(self):
        pass
    
    # save to csv
    def save(self):
        pass

    # load from csv
    def load(self):
        pass


    

In [33]:
# Employee test cases with unittest

class TestEmployee(unittest.TestCase):
    emp = Employee(name="John", role="Cashier")
    dt1 = datetime(2021, 1, 1, 8, 0)
    dt2 = datetime(2021, 1, 1, 12, 0)
    dt3 = datetime(2021, 1, 1, 13, 0)
    dt4 = datetime(2021, 1, 1, 17, 0)
    __iv1 = TimeInterval(dt1, dt2)
    __iv2 = TimeInterval(dt2, dt3)
    __iv3 = TimeInterval(dt3, dt4)
    __iv4 = TimeInterval(dt1, dt4)

    intervals1 = TimeIntervals([__iv1])
    intervals2 = TimeIntervals([__iv2])
    intervals3 = TimeIntervals([__iv3])
    intervals4 = TimeIntervals([__iv4])

    # print(intervals1)

    def test_init(self):
        self.assertEqual(self.emp.name, "John")
        self.assertEqual(self.emp.role, "Cashier")
        self.assertEqual(self.emp.assigned_shifts, [])
    
    def test_add_availability(self):
        self.emp._add_availability(self.__iv1)
        self.assertEqual(self.emp.availability, self.intervals1)
        self.emp._add_availability(self.__iv2)
        self.assertEqual(self.emp.availability, self.intervals1 + self.intervals2)
        self.emp._add_availability(self.__iv3)
        self.assertEqual(self.emp.availability, self.intervals4)
       

    def test_remove_availability(self):
        self.emp._clear_availability()
        print(self.emp.availability)
        self.emp._add_availability(self.__iv1)
        print(self.emp.availability)
        self.emp._add_availability(self.__iv2)
        print(self.emp.availability)
        self.emp._add_availability(self.__iv3)
        print(self.emp.availability)
        self.emp._remove_availability(self.__iv1)
        print(self.emp.availability)
        self.assertEqual(self.emp.availability, self.intervals2 + self.intervals3)

unittest.main(argv=[''], verbosity=2, exit=False)



test_add_availability (__main__.TestEmployee) ... ok
test_init (__main__.TestEmployee) ... ok
test_remove_availability (__main__.TestEmployee) ... FAIL
test_contains (__main__.TestTimeInterval) ... ok
test_init (__main__.TestTimeInterval) ... ok
test_overlaps (__main__.TestTimeInterval) ... ok
test_subtraction (__main__.TestTimeInterval) ... FAIL
test_add (__main__.TestTimeIntervals) ... FAIL
test_eq (__main__.TestTimeIntervals) ... FAIL
test_init (__main__.TestTimeIntervals) ... FAIL
test_remove (__main__.TestTimeIntervals) ... FAIL

FAIL: test_remove_availability (__main__.TestEmployee)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/qm/hc980d_116zfqwx_9dxh3dmm0000gn/T/ipykernel_58110/2286337519.py", line 49, in test_remove_availability
    self.assertEqual(self.emp.availability, self.intervals2 + self.intervals3)
AssertionError: TimeIntervals([TimeInterval(2021-01-01 13:00:00, 2021-01-01 17:00:00)]) != Ti

2021-01-01 08:00:00 - 2021-01-01 12:00:00
2021-01-01 08:00:00 - 2021-01-01 12:00:00
2021-01-01 08:00:00 - 2021-01-01 13:00:00
TimeIntervals([])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 13:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 13:00:00)])
TimeIntervals([TimeInterval(2021-01-01 08:00:00, 2021-01-01 17:00:00)])
TimeIntervals([TimeInterval(2021-01-01 13:00:00, 2021-01-01 17:00:00)])


<unittest.main.TestProgram at 0x1371c59d0>