In [1]:
%config Completer.use_jedi = False

In [None]:
import collections

import torch
from ortools.sat.python import cp_model

Options:

1) function that transforms a path into job data structure
    * this will only evaluate the scheduling effort of our method
2) change our method to use the added constraints of machine assignment
    * does this make the problem harder or easier?
    * link prediction on the "next" arrows

#### Minimal jobshop problem

In [None]:
def g2jobdata(g0, a):
    """
    Reformat data such that CPModel can solve for the optimal MakeSpan.

    a : list of action tuples as received from MDP path sampler
    g0: initial g as returned from env.reset()

    From action log get sequence of assignments and get times from node features.
    Then deduce jobs (sequence of operations) and store in data log.

    Returns a list of jobs; every job is a list of (machine_id, processing_time)
    tuples of plain Python ints, in execution order.

    Example format:
    jobs_data = [  # task = (machine_id, processing_time).
        [(0, 3), (1, 2), (2, 2)],  # Job0
        [(0, 2), (2, 1), (1, 4)],  # Job1
        [(1, 4), (2, 3)]  # Job2
    ]
    """

    actions = torch.tensor(a)
    # Sort the actions by their second entry and keep the first entry, so that
    # assignments[i] is the machine id used for node i below.
    # Assumes the second entries are exactly the node indices 0..n-1, each
    # appearing once -- TODO confirm against the path sampler.
    # Converted to plain ints: tensors hash by identity and would not work as
    # machine keys downstream.
    assignments = actions[actions[:, 1].argsort()][:, 0].tolist()

    # Processing time = first 'job' feature column divided by 0.1, truncated.
    # NOTE(review): truncation after a float division can undershoot when the
    # feature is not exactly representable -- confirm that 'trunc' is intended.
    times = g0.ndata['hv']['job'][:, 0].div(0.1, rounding_mode='trunc').int().tolist()

    # (src, dst) pairs of the "precede" edges, in the order the graph returns them.
    edge_pairs = torch.stack(g0.edges(etype="precede")).T.tolist()

    jobs_data = []
    prev_dst = -1
    for src, dst in edge_pairs:
        # An edge that does not continue from the previous edge's destination
        # starts a new job (relies on the edges being listed chain by chain).
        if prev_dst != src:
            jobs_data.append([(assignments[src], times[src])])
        jobs_data[-1].append((assignments[dst], times[dst]))
        prev_dst = dst

    # Nodes that touch no "precede" edge are jobs made of a single operation.
    chained_nodes = {node for pair in edge_pairs for node in pair}
    for node in range(g0.num_nodes('job')):
        if node not in chained_nodes:
            jobs_data.append([(assignments[node], times[node])])

    return jobs_data

In [None]:
# Record holding the CP-SAT variables created for one task.
task_type = collections.namedtuple(typename='task_type',
                                   field_names='start end interval')
# Record describing one scheduled task of a solution; `start` comes first so
# that sorting these records orders them by start time.
assigned_task_type = collections.namedtuple(
    typename='assigned_task_type',
    field_names='start job index duration')

In [None]:
def _format_schedule(solver, jobs_data, all_tasks, all_machines):
    """Build the per-machine text report printed by get_makespan(verbose=True)."""
    # Create one list of assigned tasks per machine.
    assigned_jobs = collections.defaultdict(list)
    for job_id, job in enumerate(jobs_data):
        for task_id, task in enumerate(job):
            assigned_jobs[task[0]].append(
                assigned_task_type(
                    start=solver.Value(all_tasks[job_id, task_id].start),
                    job=job_id,
                    index=task_id,
                    duration=task[1]))

    # Create per machine output lines.
    output = ''
    for machine in all_machines:
        # Records compare field by field, so this sorts by starting time first.
        assigned_jobs[machine].sort()
        sol_line_tasks = 'Machine ' + str(machine) + ': '
        sol_line = '           '

        for assigned_task in assigned_jobs[machine]:
            name = 'job_%i_task_%i' % (assigned_task.job, assigned_task.index)
            # Add spaces to output to align columns.
            sol_line_tasks += '%-15s' % name

            start = assigned_task.start
            sol_tmp = '[%i,%i]' % (start, start + assigned_task.duration)
            # Add spaces to output to align columns.
            sol_line += '%-15s' % sol_tmp

        output += sol_line_tasks + '\n'
        output += sol_line + '\n'

    return output


def get_makespan(jobs_data, verbose=False):
    """
    Solve a job-shop instance with CP-SAT and return its makespan.

    jobs_data: list of jobs; every job is an ordered list of
               (machine_id, processing_time) tasks. Tasks of one job run in
               list order and a machine runs at most one task at a time.
               Machine ids are used as 0-based indices.
    verbose  : if True, print the schedule that was found, or the solver
               statistics when no solution was found.

    Returns a tuple (makespan, status):
      * (objective value, "OPTIMAL") or (objective value, "FEASIBLE") when the
        solver reports a solution,
      * (0.0, "OPTIMAL") when jobs_data contains no task at all,
      * (-1, "INFEASIBLE") for every other solver status (not only for a
        proven infeasibility).
    """

    # Nothing to schedule: avoid max() on an empty sequence and an objective
    # over zero end variables. The empty schedule has length 0.
    if all(len(job) == 0 for job in jobs_data):
        if verbose:
            print('No tasks to schedule.')
        return 0.0, "OPTIMAL"

    machines_count = 1 + max(task[0] for job in jobs_data for task in job)
    all_machines = range(machines_count)
    # Running all tasks back to back is always possible, so the total duration
    # is a valid upper bound for every start and end time.
    horizon = sum(task[1] for job in jobs_data for task in job)

    model = cp_model.CpModel()

    # Creates job intervals and add to the corresponding machine lists.
    all_tasks = {}
    machine_to_intervals = collections.defaultdict(list)
    for job_id, job in enumerate(jobs_data):
        for task_id, task in enumerate(job):
            machine = task[0]
            duration = task[1]
            suffix = '_%i_%i' % (job_id, task_id)
            start_var = model.NewIntVar(0, horizon, 'start' + suffix)
            end_var = model.NewIntVar(0, horizon, 'end' + suffix)
            interval_var = model.NewIntervalVar(start_var, duration, end_var,
                                                'interval' + suffix)
            all_tasks[job_id, task_id] = task_type(start=start_var,
                                                   end=end_var,
                                                   interval=interval_var)
            machine_to_intervals[machine].append(interval_var)

    # Create and add disjunctive constraints.
    for machine in all_machines:
        model.AddNoOverlap(machine_to_intervals[machine])

    # Precedences inside a job.
    for job_id, job in enumerate(jobs_data):
        for task_id in range(len(job) - 1):
            model.Add(all_tasks[job_id, task_id +
                                1].start >= all_tasks[job_id, task_id].end)

    # Makespan objective: the latest end among the last tasks of all jobs.
    # Jobs without tasks have no last task and are skipped.
    obj_var = model.NewIntVar(0, horizon, 'makespan')
    model.AddMaxEquality(obj_var, [
        all_tasks[job_id, len(job) - 1].end
        for job_id, job in enumerate(jobs_data)
        if len(job) > 0
    ])
    model.Minimize(obj_var)

    # Creates the solver and solve.
    solver = cp_model.CpSolver()
    status = solver.Solve(model)

    if status == cp_model.OPTIMAL or status == cp_model.FEASIBLE:
        if verbose:
            print('Solution:')
            output = _format_schedule(solver, jobs_data, all_tasks, all_machines)

            # Finally print the solution found.
            print(f'Optimal Schedule Length: {solver.ObjectiveValue()}')
            print(output)

        return solver.ObjectiveValue(), "OPTIMAL" if status==cp_model.OPTIMAL else "FEASIBLE"

    if verbose:
        print('No solution found.')
        print('\nStatistics')
        print('  - conflicts: %i' % solver.NumConflicts())
        print('  - branches : %i' % solver.NumBranches())
        print('  - wall time: %f s' % solver.WallTime())

    return -1, "INFEASIBLE"

In [None]:
# Minimal example instance in the format expected by get_makespan:
# one list per job, each task = (machine_id, processing_time).
# Use `jobs_data = g2jobdata(g0, a)` instead to evaluate a sampled path.
jobs_data = [
    [(0, 3), (1, 2), (2, 2)],  # Job0
    [(0, 2), (2, 1), (1, 4)],  # Job1
    [(1, 4), (2, 3)],  # Job2
]
out, status = get_makespan(jobs_data)