In [1]:
# from objects import *
from datetime import datetime, timedelta
from utils import *
from ortools.sat.python import cp_model
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.figure_factory as ff

In [75]:
class Availability():
    """Availability of one person, stored as a list of ``(start, end)`` datetime intervals.

    * ``a + b`` returns the union of both interval lists; overlapping or touching
      intervals are combined into one.
    * ``a - b`` returns ``a`` with every interval of ``b`` removed; intervals of
      ``a`` are split where an interval of ``b`` cuts through them.
    * ``(start, end) in a`` is True when a single interval of ``a`` covers it.
    """

    def __init__(self, intervals=None):
        """Create an availability from an iterable of ``(start, end)`` pairs.

        Args:
            intervals: iterable of ``(start, end)`` pairs. ``None`` (the default)
                means no availability. The pairs are copied into a fresh list, so
                instances never share interval storage with each other or with
                the caller.
        """
        # A mutable ``[]`` default would be one list shared by every Availability
        # created without arguments; build a new list per instance instead.
        self.intervals = list(intervals) if intervals is not None else []

    def __add__(self, other):
        """Return the union of ``self`` and ``other``, merging overlapping/adjacent intervals."""
        intervals = sorted(self.intervals + other.intervals, key=lambda x: x[0])
        if not intervals:
            # Nothing to merge (indexing intervals[0] would raise on empty input).
            return Availability([])
        merged = [intervals[0]]
        for begin, finish in intervals[1:]:
            last_begin, last_finish = merged[-1]
            # ``<=`` also joins intervals that merely touch (one ends as the next starts).
            if begin <= last_finish:
                merged[-1] = (last_begin, max(finish, last_finish))
            else:
                merged.append((begin, finish))
        return Availability(merged)

    def __sub__(self, other):
        """Return ``self`` with every interval of ``other`` cut out."""
        intervals = list(self.intervals)
        for cut_start, cut_end in other.intervals:
            remaining = []
            for start, end in intervals:
                if end <= cut_start or start >= cut_end:
                    # No overlap with the cut: keep the interval unchanged.
                    remaining.append((start, end))
                    continue
                if start < cut_start:
                    # Keep the part before the cut.
                    remaining.append((start, cut_start))
                if end > cut_end:
                    # Keep the part after the cut.
                    remaining.append((cut_end, end))
                # An interval fully inside the cut contributes nothing.
            intervals = remaining
        return Availability(intervals)

    def __contains__(self, other: tuple):
        """Return True if one interval fully covers ``other = (start, end)``."""
        return any(start <= other[0] and end >= other[1] for start, end in self.intervals)

    def get_hours(self):
        """Return the summed length of all intervals, in hours."""
        return sum((end - start).total_seconds() / 3600 for start, end in self.intervals)

    def plot(self):
        """Draw the intervals as a plotly Gantt chart (one bar per interval)."""
        # Build the frame in one go rather than growing it row by row.
        df = pd.DataFrame(
            [{'Task': 'Task', 'Start': start, 'Finish': end, 'Resource': 'Resource'}
             for start, end in self.intervals],
            columns=['Task', 'Start', 'Finish', 'Resource'],
        )
        fig = ff.create_gantt(df, index_col='Resource', show_colorbar=True, group_tasks=True)
        fig.show()

    def is_available_for_shift(self, shift):
        """Return True if a single interval covers the whole of ``shift.start``..``shift.end``."""
        # Same containment rule as ``__contains__``; delegate to keep one definition.
        return (shift.start, shift.end) in self


# Interval






In [None]:
# Sample availability on 2021-03-01: two blocks with a gap from 12:00 to 13:00.
a1 = Availability([
    (datetime(2021, 3, 1, hour=8), datetime(2021, 3, 1, hour=12)),
    (datetime(2021, 3, 1, hour=13), datetime(2021, 3, 1, hour=17)),
])

In [100]:
# Design a shift scheduler for an emergency room using object-oriented programming.
# The scheduler should take a list of doctors with their availability, and a list of shifts with their staffing requirements.
# The schedule can accept new shift types and then generate the list of shifts it will staff.


#Classes: employee, shift, schedule

#Employee class
#Attributes: name, availability, shift preferences
#Methods: add shift preference, remove shift preference, add availability, remove availability
class Employee:
    """A staff member who can be scheduled.

    Attributes:
        name: display name; also used as the repr.
        availability: an Availability-like object describing when the employee
            can work (combined with ``+`` / ``-`` by the add/remove methods).
        shift_preferences: list of preferred shifts (stored only; Schedule.solve
            as written does not read it).
    """

    def __init__(self, name, availability, shift_preferences=None):
        """Create an employee.

        Args:
            name: display name.
            availability: object supporting ``+`` and ``-`` (e.g. Availability).
            shift_preferences: optional list of preferences. When omitted, each
                employee gets its own empty list.
        """
        self.name = name
        self.availability = availability
        # A ``[]`` default would be one list shared by all employees, so
        # add_shift_preference on one of them would show up on every other.
        self.shift_preferences = [] if shift_preferences is None else shift_preferences

    def __repr__(self) -> str:
        return self.name

    def add_shift_preference(self, shift):
        """Append ``shift`` to this employee's preferences."""
        self.shift_preferences.append(shift)

    def remove_shift_preference(self, shift):
        """Remove ``shift`` from the preferences (raises ValueError if absent)."""
        self.shift_preferences.remove(shift)

    def add_availability(self, availability):
        """Extend availability with ``availability`` (union via ``+``)."""
        self.availability = self.availability + availability

    def remove_availability(self, availability):
        """Cut ``availability`` out of this employee's availability (via ``-``)."""
        self.availability = self.availability - availability



#Shift class
#Attributes: start, end, interval, minimum employees, maximum employees, list of employees, shift type, workload, hours
#Methods: set required employees, get duplicate shifts, check overlap with another shift
class Shift:
    """A single time slot to be staffed.

    Attributes:
        start, end: bounds of the shift.
        interval: the ``(start, end)`` tuple.
        minimum_employees, maximum_employees: staffing bounds.
        employees: employees assigned to this shift (starts empty).
        shift_type: label such as "morning shift".
        workload: weight of the shift (stored; Schedule.solve as written does not read it).
        hours: shift length in hours, computed once at construction.
    """

    def __init__(self, start, end, minimum_employees, maximum_employees, shift_type, workload = 1):
        self.start = start
        self.end = end
        self.interval = (start, end)
        self.minimum_employees = minimum_employees
        self.maximum_employees = maximum_employees
        self.employees = []
        self.shift_type = shift_type
        self.workload = workload
        self.hours = (self.end - self.start).total_seconds() / 3600

    def __str__(self):
        return f"{self.start} - {self.end} {self.shift_type}"

    def __repr__(self):
        return self.__str__()

    def set_required_employees(self, minimum_employees, maximum_employees):
        """Update the staffing bounds."""
        self.minimum_employees = minimum_employees
        self.maximum_employees = maximum_employees

    def get_duplicate_shifts(self, start, end, interval, work_on_weekends = False, work_on_holidays = False):
        """Repeat this shift every ``interval``, once per step that fits in ``[start, end)``.

        Copy ``i`` runs from ``self.start + interval*i`` to ``self.end + interval*i``.
        It keeps this shift's staffing bounds, type and workload. A copy is skipped
        when its start is a weekend/holiday according to ``is_weekend`` /
        ``is_holiday``, unless the matching ``work_on_*`` flag is True.

        Args:
            start, end: range that determines how many steps are taken.
            interval: step between copies; must move forward in time.
            work_on_weekends: keep copies that start on a weekend.
            work_on_holidays: keep copies that start on a holiday.

        Returns:
            list of new Shift objects; this shift is not modified.

        Raises:
            ValueError: if ``interval`` does not advance time (the loop would never end).
        """
        if not start + interval > start:
            raise ValueError("interval must be a positive time step")
        shifts = []
        current = start
        i = 0
        while current < end:
            offset = interval * i
            copy_start = self.start + offset
            # Filter on the copy's own start, not the range cursor ``current``.
            # The two fall on different days whenever ``start`` and ``self.start``
            # are on different dates.
            if (work_on_weekends or not is_weekend(copy_start)) and \
                    (work_on_holidays or not is_holiday(copy_start)):
                shifts.append(Shift(copy_start, self.end + offset, self.minimum_employees,
                                    self.maximum_employees, self.shift_type, self.workload))
            current += interval
            i += 1
        return shifts

    def is_overlap(self, shift):
        """Return True if the two shifts share time; touching endpoints do not count."""
        return self.start < shift.end and self.end > shift.start




#Schedule class
#Attributes: list of employees, list of shifts
#Methods: add employee, remove employee, add shift, remove shift, solve, display
class Schedule:
    """Employees plus shifts, with a CP-SAT solver that assigns one to the other.

    Attributes:
        employees: list of employees.
        num_employees: ``len(employees)``, kept in sync by the add/remove methods.
        shifts: list of shifts.
        num_shifts: ``len(shifts)``, kept in sync by the add/remove methods.
        start, end: earliest shift start and latest shift end, or None while the
            schedule has no shifts.
    """

    def __init__(self):
        self.employees = []
        self.num_employees = 0
        self.shifts = []
        self.num_shifts = 0
        self.start = None
        self.end = None

    def _refresh_employees(self):
        """Recompute the cached employee count."""
        self.num_employees = len(self.employees)

    def _refresh_shifts(self):
        """Recompute the cached shift count and the schedule's start/end bounds."""
        self.num_shifts = len(self.shifts)
        if self.shifts:
            self.start = min(s.start for s in self.shifts)
            self.end = max(s.end for s in self.shifts)
        else:
            # min()/max() of an empty sequence would raise; an empty schedule has no bounds.
            self.start = None
            self.end = None

    def add_employee(self, employee):
        self.employees.append(employee)
        self._refresh_employees()

    def add_employees(self, employees):
        self.employees += employees
        self._refresh_employees()

    def remove_employee(self, employee):
        self.employees.remove(employee)
        self._refresh_employees()

    def add_shift(self, shift):
        self.shifts.append(shift)
        self._refresh_shifts()

    def add_shifts(self, shifts):
        self.shifts += shifts
        self._refresh_shifts()

    def remove_shift(self, shift):
        # Previously num_shifts/start/end went stale after a removal.
        self.shifts.remove(shift)
        self._refresh_shifts()

    def solve(self):
        """Assign employees to shifts with OR-Tools CP-SAT.

        Constraints:
          * each shift gets between ``minimum_employees`` and ``maximum_employees`` people;
          * no employee works two overlapping shifts;
          * an employee is only assigned to shifts their availability fully covers.

        No objective is set, so any feasible assignment is accepted. On success each
        shift's ``employees`` list is replaced by the employees assigned to it.

        Returns:
            True if a feasible or optimal solution was found, otherwise False
            (after printing "No solution found.").
        """
        model = cp_model.CpModel()

        # One boolean per (shift, employee): 1 means the employee works that shift.
        assign = {
            (shift, employee): model.NewBoolVar(f"{shift} {employee}")
            for shift in self.shifts
            for employee in self.employees
        }

        # Staffing bounds for every shift.
        for shift in self.shifts:
            staffed = sum(assign[(shift, employee)] for employee in self.employees)
            model.Add(staffed >= shift.minimum_employees)
            model.Add(staffed <= shift.maximum_employees)

        # Overlap is symmetric and independent of the employee. Compute each
        # unordered pair once, outside the employee loop, instead of adding the
        # same constraint for both (a, b) and (b, a).
        overlapping_pairs = [
            (first, second)
            for i, first in enumerate(self.shifts)
            for second in self.shifts[i + 1:]
            if first != second and first.is_overlap(second)
        ]
        for employee in self.employees:
            for first, second in overlapping_pairs:
                model.Add(assign[(first, employee)] + assign[(second, employee)] <= 1)

        # An employee cannot work a shift their availability does not cover.
        for employee in self.employees:
            for shift in self.shifts:
                if not employee.availability.is_available_for_shift(shift):
                    model.Add(assign[(shift, employee)] == 0)

        solver = cp_model.CpSolver()
        status = solver.Solve(model)

        if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            print("No solution found.")
            return False

        # Replace (rather than extend) earlier assignments so re-solving does not
        # duplicate employees on a shift.
        for shift in self.shifts:
            shift.employees = [
                employee for employee in self.employees
                if solver.Value(assign[(shift, employee)]) == 1
            ]
        return True

    def shift_types(self):
        """Return the distinct shift types present (in arbitrary order)."""
        return list(set([shift.shift_type for shift in self.shifts]))

    def display(self):
        """Show a table with one row per calendar day and one cell per shift.

        Rows run from the first shift's start date to the last shift's start date
        (inclusive), indexed by date. Each cell is the ``employees`` list of one
        shift starting that day, ordered by shift start time.

        Raises:
            Exception: if the schedule has no shifts.
        """
        if len(self.shifts) == 0:
            raise Exception("No shifts to display")
        # Iterate over dates rather than stepping from the first start *time*.
        # Stepping from a late start time could jump past the final day before
        # reaching self.end and drop that day's shifts.
        first_day = min(shift.start for shift in self.shifts).date()
        last_day = max(shift.start for shift in self.shifts).date()
        days = []
        schedule = []
        day = first_day
        while day <= last_day:
            shifts_for_day = sorted(
                (shift for shift in self.shifts if shift.start.date() == day),
                key=lambda shift: shift.start,
            )
            days.append(day)
            schedule.append([shift.employees for shift in shifts_for_day])
            day += timedelta(days=1)

        df = pd.DataFrame(data=schedule, index=days)
        # NOTE(review): relies on the notebook's global ``display`` (IPython); not imported here.
        display(df)








       






In [89]:
# Sample: the availability window (06:00-12:00) does not cover the whole shift (05:00-11:00).
shift1 = Shift(
    start=datetime(2020, 1, 1, 5),
    end=datetime(2020, 1, 1, 11),
    minimum_employees=1,
    maximum_employees=2,
    shift_type="A",
)
a1 = Availability([(datetime(2020, 1, 1, 6), datetime(2020, 1, 1, 12))])
employee1 = Employee("A", a1)

# Expected False: the shift begins an hour before the employee becomes available.
employee1.availability.is_available_for_shift(shift1)



False

In [37]:
def is_overlap(shift1, shift2):
    """Return True when the two shifts share any stretch of time.

    Works on any objects with ``start``/``end`` attributes. Touching endpoints
    (one shift ends exactly when the other begins) are not an overlap, which
    matches ``Shift.is_overlap``.
    """
    starts_before_other_ends = shift1.start < shift2.end
    ends_after_other_starts = shift1.end > shift2.start
    return starts_before_other_ends and ends_after_other_starts

# Back-to-back shifts: the morning shift ends exactly when the afternoon one starts.
shift1 = Shift(start=datetime(2020, 1, 1, 8), end=datetime(2020, 1, 1, 12),
               minimum_employees=1, maximum_employees=2, shift_type="Morning")
shift2 = Shift(start=datetime(2020, 1, 1, 12), end=datetime(2020, 1, 1, 16),
               minimum_employees=1, maximum_employees=2, shift_type="Afternoon")

# Expected False: touching endpoints are not treated as an overlap.
shift1.is_overlap(shift2)

False

In [104]:
# Planning range: shifts are generated for every day in [start, end).
start = datetime(2023, 3, 1)
end = datetime(2023, 4, 2)
one_day = timedelta(days=1)

# First-day templates for the three daily shift types (1-3 staff each, workload 1).
morning_shift = Shift(datetime(2023, 3, 1, 7, 30), datetime(2023, 3, 1, 15, 30), 1, 3, "morning shift", 1)
evening_shift = Shift(datetime(2023, 3, 1, 15, 30), datetime(2023, 3, 1, 23, 30), 1, 3, "evening shift", 1)
night_shift = Shift(datetime(2023, 3, 1, 23, 30), datetime(2023, 3, 2, 7, 30), 1, 3, "night shift", 1)

# Repeat each template daily; the flags decide which shifts run on weekends/holidays.
morning_shifts = morning_shift.get_duplicate_shifts(start, end, one_day, work_on_weekends=True, work_on_holidays=True)
evening_shifts = evening_shift.get_duplicate_shifts(start, end, one_day, work_on_weekends=True, work_on_holidays=False)
night_shifts = night_shift.get_duplicate_shifts(start, end, one_day, work_on_weekends=False, work_on_holidays=True)
all_shifts = [*morning_shifts, *evening_shifts, *night_shifts]

# Staff: John is only available on the first day; Mary covers the whole range.
# (Append more employees to all_employee to spread the load.)
employee1 = Employee("John", availability=Availability([(datetime(2023, 3, 1), datetime(2023, 3, 2))]))
employee2 = Employee("Mary", availability=Availability([(start, end)]))
all_employee = [employee1, employee2]

# Build, solve and show the schedule.
schedule = Schedule()
schedule.add_employees(all_employee)
schedule.add_shifts(all_shifts)
schedule.solve()
schedule.display()




Unnamed: 0,0,1,2
0,[John],[John],[Mary]
1,[Mary],[Mary],[Mary]
2,[Mary],[Mary],[Mary]
3,[Mary],,
4,[Mary],,
5,[Mary],[Mary],
6,[Mary],[Mary],[Mary]
7,[Mary],[Mary],[Mary]
8,[Mary],[Mary],[Mary]
9,[Mary],[Mary],[Mary]


In [99]:
# List every (employee, shift) pair that the availability filter rules out.
for employee in all_employee:
    for shift in all_shifts:
        if employee.availability.is_available_for_shift(shift):
            continue
        print(f"{employee.name} is not available for {shift}")

In [60]:
# Sanity check: print every ordered pair of distinct morning shifts that overlap.
# Daily 07:30-15:30 copies should never overlap, so no output is expected.
for shift1 in morning_shifts:
    for shift2 in morning_shifts:
        if shift1 == shift2 or not shift1.is_overlap(shift2):
            continue
        print(shift1, shift2)
        # Always prints True here (the guard above already established the overlap).
        print(shift1.is_overlap(shift2))