<a href="https://colab.research.google.com/github/pr33tham7/RL_tsp_vrp/blob/main/optim_rl_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [2]:
!git clone https://github.com/pr33tham7/RL_tsp_vrp

Cloning into 'RL_tsp_vrp'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (44/44), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 44 (delta 7), reused 31 (delta 0), pack-reused 0
Unpacking objects: 100% (44/44), done.


In [3]:
"""Defines the main trainer model for combinatorial problems

Each task must define the following functions:
* mask_fn: can be None
* update_fn: can be None
* reward_fn: specifies the quality of found solutions
* render_fn: Specifies how to plot found solutions. Can be None
"""

import os
import time
import argparse
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from RL_tsp_vrp import *
from RL_tsp_vrp.model import DRL4TSP, Encoder
from RL_tsp_vrp.tasks import tsp
from RL_tsp_vrp.tasks.tsp import TSPDataset
from RL_tsp_vrp.tasks import vrp
from RL_tsp_vrp.tasks.vrp import VehicleRoutingDataset

# Select the compute device once: CUDA when PyTorch can see a GPU, else the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Critic network that estimates the reward for any problem instance from a given state
class StateCritic(nn.Module):
    """Critic that produces one scalar estimate per problem instance.

    The static and dynamic inputs are encoded by two separate ``Encoder``
    modules, the encodings are concatenated along dim 1, and the result is
    passed through three kernel-size-1 convolutions. The last feature map is
    summed over dim 2 to give the estimate.
    """

    def __init__(self, static_size, dynamic_size, hidden_size):
        """Build the two encoders and the three-layer pointwise head.

        Args:
            static_size: first argument given to the static ``Encoder``.
            dynamic_size: first argument given to the dynamic ``Encoder``.
            hidden_size: second argument given to both encoders; the head
                expects ``2 * hidden_size`` input channels.
        """
        super().__init__()

        # One encoder per input stream (static features / dynamic features)
        self.static_encoder = Encoder(static_size, hidden_size)
        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)

        # Pointwise (kernel_size=1) head: 2*hidden -> 20 -> 20 -> 1 channels
        self.fc1 = nn.Conv1d(2 * hidden_size, 20, kernel_size=1)
        self.fc2 = nn.Conv1d(20, 20, kernel_size=1)
        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)

        # Xavier-initialise every parameter with more than one dimension;
        # one-dimensional parameters (e.g. biases) keep their default init.
        for weight in self.parameters():
            if weight.dim() > 1:
                nn.init.xavier_uniform_(weight)

    def forward(self, static, dynamic):
        """Return the critic's estimate for the given static/dynamic state.

        Both inputs are encoded, concatenated on dim 1, run through the
        ReLU-activated head, and the final output is summed over dim 2.
        """
        encoded = torch.cat(
            (self.static_encoder(static), self.dynamic_encoder(dynamic)), 1)

        activations = F.relu(self.fc1(encoded))
        activations = F.relu(self.fc2(activations))
        return self.fc3(activations).sum(dim=2)


# Function to evaluate the validation set from the locations database
def validate(data_loader, actor, reward_fn, render_fn=None, save_dir='.',
             num_plot=5):
    """Run the actor over ``data_loader`` and return the mean batch reward.

    Args:
        data_loader: iterable yielding ``(static, dynamic, x0)`` batches.
        actor: module called as ``actor(static, dynamic, x0)``; the first
            element of its result is used as the tour indices.
        reward_fn: callable ``reward_fn(static, tour_indices)``; the result
            is reduced with ``.mean()`` to one number per batch.
        render_fn: optional callable ``render_fn(static, tour_indices, path)``
            used to plot a batch. ``None`` disables plotting.
        save_dir: directory that receives the plots; created if missing.
        num_plot: only batches whose index is below this value are rendered.

    Returns:
        The mean of the per-batch mean rewards, or NaN when the loader
        yields no batches.

    Side effects:
        The actor is put in eval mode while running and is always switched
        back to train mode before returning, even if an error is raised.
    """
    actor.eval()

    rewards = []
    try:
        # Create the directory that holds the plots of the generated tours
        os.makedirs(save_dir, exist_ok=True)

        for batch_idx, batch in enumerate(data_loader):

            static, dynamic, x0 = batch

            static = static.to(device)
            dynamic = dynamic.to(device)
            # An empty x0 means the task has no initial decoder input
            x0 = x0.to(device) if len(x0) > 0 else None

            with torch.no_grad():
                # Call the module itself (not .forward) so registered hooks run
                tour_indices, _ = actor(static, dynamic, x0)
                reward = reward_fn(static, tour_indices).mean().item()

            rewards.append(reward)

            if render_fn is not None and batch_idx < num_plot:
                name = 'batch%d_%2.4f.png'%(batch_idx, reward)
                path = os.path.join(save_dir, name)
                render_fn(static, tour_indices, path)
    finally:
        # Never leave the caller's model stuck in eval mode
        actor.train()

    if not rewards:
        # Avoid numpy's "mean of empty slice" warning on an empty loader
        return np.float64('nan')
    return np.mean(rewards)


def train(actor, critic, task, num_nodes, train_data, valid_data, reward_fn,
          render_fn, batch_size, actor_lr, critic_lr, max_grad_norm,
          num_epochs=6, **kwargs):
    """Train the actor and critic networks and checkpoint them every epoch.

    Args:
        actor: policy module called as ``actor(static, dynamic, x0)``; returns
            ``(tour_indices, tour_logp)``.
        critic: module called as ``critic(static, dynamic)``; its output is
            flattened and used as the baseline.
        task: task name, used as the top-level output directory.
        num_nodes: problem size, used as the second directory level.
        train_data: dataset wrapped in a shuffling ``DataLoader``.
        valid_data: dataset wrapped in a non-shuffling ``DataLoader``.
        reward_fn: callable ``reward_fn(static, tour_indices)``.
        render_fn: plotting callable forwarded to ``validate`` (may be None).
        batch_size: batch size for both loaders.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        max_grad_norm: gradient-norm clipping threshold for both networks.
        num_epochs: number of passes over ``train_data`` (default 6).
        **kwargs: accepted for call-site compatibility and ignored.

    Side effects:
        Writes ``<task>/<num_nodes>/<time>/checkpoints/<epoch>/{actor,critic}.pt``
        each epoch, validation plots under ``<task>/<num_nodes>/<time>/<epoch>``,
        and the best-so-far weights directly under ``<task>/<num_nodes>/<time>``.
        Progress is printed to stdout.
    """

    # Per-run output directory keyed by the current time of day (':' -> '_')
    now = '%s' % datetime.datetime.now().time()
    now = now.replace(':', '_')
    save_dir = os.path.join(task, '%d' % num_nodes, now)

    checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    actor_optim = optim.Adam(actor.parameters(), lr=actor_lr)
    critic_optim = optim.Adam(critic.parameters(), lr=critic_lr)

    # Third positional DataLoader argument is `shuffle`
    train_loader = DataLoader(train_data, batch_size, True, num_workers=0)
    valid_loader = DataLoader(valid_data, batch_size, False, num_workers=0)

    # Lower validation reward is treated as better (see comparison below)
    best_reward = np.inf

    for epoch in range(num_epochs):
        print(epoch)
        actor.train()
        critic.train()

        times, losses, rewards, critic_rewards = [], [], [], []

        epoch_start = time.time()
        start = epoch_start

        for batch_idx, batch in enumerate(train_loader):

            static, dynamic, x0 = batch

            static = static.to(device)
            dynamic = dynamic.to(device)
            # An empty x0 means the task has no initial decoder input
            x0 = x0.to(device) if len(x0) > 0 else None

            # Forward pass of the actor over the whole batch
            tour_indices, tour_logp = actor(static, dynamic, x0)

            # Score the tours the actor produced
            reward = reward_fn(static, tour_indices)

            # Query the critic for an estimate of the reward
            critic_est = critic(static, dynamic).view(-1)

            # Actor: detached advantage times the summed per-step log-probs.
            # Critic: mean squared error between reward and its estimate.
            advantage = (reward - critic_est)
            actor_loss = torch.mean(advantage.detach() * tour_logp.sum(dim=1))
            critic_loss = torch.mean(advantage ** 2)

            actor_optim.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)
            critic_optim.step()

            critic_rewards.append(torch.mean(critic_est.detach()).item())
            rewards.append(torch.mean(reward.detach()).item())
            losses.append(torch.mean(actor_loss.detach()).item())

            # Report a running summary every 100 batches
            if (batch_idx + 1) % 100 == 0:
                end = time.time()
                times.append(end - start)
                start = end

                mean_loss = np.mean(losses[-100:])
                mean_reward = np.mean(rewards[-100:])

                print('  Batch %d/%d, reward: %2.3f, loss: %2.4f, took: %2.4fs' %
                      (batch_idx, len(train_loader), mean_reward, mean_loss,
                       times[-1]))

        # Empty lists (no batches / fewer than 100 batches) yield NaN without
        # triggering numpy's "mean of empty slice" warning.
        mean_loss = np.mean(losses) if losses else float('nan')
        mean_reward = np.mean(rewards) if rewards else float('nan')
        mean_time = np.mean(times) if times else float('nan')

        # Save the weights for this epoch
        epoch_dir = os.path.join(checkpoint_dir, '%s' % epoch)
        os.makedirs(epoch_dir, exist_ok=True)

        save_path = os.path.join(epoch_dir, 'actor.pt')
        torch.save(actor.state_dict(), save_path)

        save_path = os.path.join(epoch_dir, 'critic.pt')
        torch.save(critic.state_dict(), save_path)

        # Save rendering of validation set tours
        valid_dir = os.path.join(save_dir, '%s' % epoch)

        mean_valid = validate(valid_loader, actor, reward_fn, render_fn,
                              valid_dir, num_plot=5)

        # Keep a copy of the weights with the lowest validation reward so far
        if mean_valid < best_reward:

            best_reward = mean_valid

            save_path = os.path.join(save_dir, 'actor.pt')
            torch.save(actor.state_dict(), save_path)

            save_path = os.path.join(save_dir, 'critic.pt')
            torch.save(critic.state_dict(), save_path)

        print('Mean epoch loss/reward: %2.4f, %2.4f, %2.4f, took: %2.4fs '
              '(%2.4fs / 100 batches)\n' %
              (mean_loss, mean_reward, mean_valid, time.time() - epoch_start,
               mean_time))



def train_tsp():
    """Train an actor/critic pair on random TSP instances, then evaluate it.

    All hyperparameters are hard-coded locals (this notebook has no
    command-line parser). Training is skipped when ``test`` is truthy, and
    weights are loaded first when ``checkpoint`` names a directory holding
    ``actor.pt`` and ``critic.pt``. The average reward of the actor on a
    freshly generated test set is printed at the end.
    """

    # Goals from paper:
    # TSP20, 3.97
    # TSP50, 6.08
    # TSP100, 8.44

    STATIC_SIZE = 2 # (x, y)
    DYNAMIC_SIZE = 1 # dummy for compatibility

    # Parameters
    num_nodes = 20
    train_size = 1000000
    seed = 12345
    valid_size = 1000
    hidden_size = 128
    # NOTE(review): the command-line default this value replaced was 1;
    # confirm that 100 layers is intended.
    num_layers = 100
    checkpoint = None
    dropout = 0.1
    batch_size = 256
    test = None  # truthy -> skip training and only evaluate
    actor_lr = 5e-4
    critic_lr = 5e-4
    max_grad_norm = 2

    # Separate seeds for the training and validation sets
    train_data = TSPDataset(num_nodes, train_size, seed)
    valid_data = TSPDataset(num_nodes, valid_size, seed + 1)

    # No update function is supplied for TSP
    update_fn = None

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    hidden_size,
                    update_fn,
                    tsp.update_mask,
                    num_layers,
                    dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, hidden_size).to(device)

    if checkpoint:
        # Second positional argument of torch.load is map_location
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, device))

        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, device))

    if not test:
        task = 'tsp'
        train(actor, critic, task, num_nodes, train_data, valid_data,
              tsp.reward, tsp.render, batch_size, actor_lr, critic_lr,
              max_grad_norm)

    # NOTE(review): the test set is sized with train_size (1,000,000
    # instances), not valid_size - confirm this is intended.
    test_data = TSPDataset(num_nodes, train_size, seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, batch_size, False, num_workers=0)
    out = validate(test_loader, actor, tsp.reward, tsp.render, test_dir, num_plot=5)

    print('Average tour length: ', out)


In [None]:
if __name__ == '__main__':

    # Only the TSP task is wired up in this notebook
    task = 'tsp'

    if task == 'tsp':
        train_tsp()
    else:
        # Report the local `task`: no `args` object is defined here, so the
        # previous `args.task` would have raised NameError instead.
        raise ValueError('Task <%s> not understood' % task)


0
  Batch 99/3907, reward: 9.781, loss: -137.4742, took: 64.1935s
  Batch 199/3907, reward: 7.863, loss: -0.4888, took: 63.5860s
  Batch 299/3907, reward: 7.618, loss: -0.5283, took: 63.7242s


In [None]:
# parser