# Text generation neural network

This notebook trains a neural network on pet names from Seattle pet license data and then uses it to generate new pet names. Training runs very quickly because the work is distributed across multiple computers with GPUs in Saturn Cloud.

## Training the model

In [1]:
import pandas as pd
import re
import uuid
import datetime
import pickle
import json
import torch
import math
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
# additional libraries for doing the Saturn Cloud parallel work
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader
from dask_pytorch_ddp import data, dispatch, results
from dask_saturn import SaturnCluster
from dask.distributed import Client
from distributed.worker import logger

In [12]:
# Download the Seattle pet license records and keep just the name column
# as a plain Python list.
pet_names_raw = pd.read_csv(
    "https://raw.githubusercontent.com/saturncloud/saturn-cloud-examples/master/text-generation-nn/seattle_pet_licenses.csv"
)
pet_names = pet_names_raw["Animal's Name"].tolist()

# The model's vocabulary: "*" stands for blank (padding) and "+" for stop,
# followed by the lowercase letters, hyphen, period and space.
characters = [*"*+abcdefghijklmnopqrstuvwxyz-. "]

# str_len: number of characters in each input window
# num_epochs: passes over the data made by train()
# lstm_size / lstm_layers: hidden size and depth of the LSTM in Model
str_len, num_epochs = 8, 7
lstm_size, lstm_layers = 128, 4


def format_training_data(pet_names):
    """Turn raw pet names into one-hot inputs and next-character targets.

    Each usable name is lowercased, terminated with the stop character "+",
    and expanded into all of its prefixes. Every prefix is cut down to its
    last ``str_len + 1`` characters and left-padded with the blank character
    "*", then encoded as indices into the module-level ``characters`` list.

    Args:
        pet_names: iterable of raw values; anything that is not a string, is
            whitespace-only, or contains characters other than ASCII letters,
            space, "." and "-" is skipped.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a float tensor of shape
        ``(N, str_len, len(characters))`` holding the one-hot encoded first
        ``str_len`` characters of each window, and ``y`` is an integer tensor
        of shape ``(N, str_len)`` holding the same window shifted one
        character to the left (the next-character targets). ``N`` is 0 when
        no name survives filtering.
    """
    window = str_len + 1
    pattern = re.compile("^[ \\.\\-a-zA-Z]*$")
    # One dict lookup per character instead of a linear list.index scan.
    char_to_index = {char: index for index, char in enumerate(characters)}

    def get_substrings(in_str):
        in_str = in_str.lower() + "+"
        return [in_str[:end] for end in range(1, len(in_str) + 1)]

    # fullmatch (rather than match) is required here: "$" also matches just
    # before a trailing newline, so match() would accept "rex\n" and the
    # newline would then fail to encode.
    pet_names_filtered = [
        name
        for name in pet_names
        if isinstance(name, str) and not name.isspace() and pattern.fullmatch(name)
    ]

    pet_names_numeric = []
    for name in pet_names_filtered:
        for prefix in get_substrings(name):
            tail = prefix[-window:]
            padded = "*" * (window - len(tail)) + tail
            pet_names_numeric.append([char_to_index[char] for char in padded])

    # Explicit dtype and reshape keep the tensor integer-typed and 2-D even
    # when there are no rows, which one_hot needs.
    encoded = torch.tensor(pet_names_numeric, dtype=torch.long).reshape(-1, window)

    # the final x and y data
    y = encoded[:, 1:].contiguous()
    x = torch.nn.functional.one_hot(encoded[:, :-1], num_classes=len(characters)).float()
    return x, y

class Model(nn.Module):
    """Character-level LSTM followed by a linear layer over the vocabulary.

    Layer sizes come from the module-level ``characters``, ``lstm_size`` and
    ``lstm_layers`` values.
    """

    def __init__(self):
        super().__init__()
        vocab_size = len(characters)
        self.lstm = nn.LSTM(
            input_size=vocab_size,
            hidden_size=lstm_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.1,
        )
        self.fc = nn.Linear(lstm_size, vocab_size)

    def forward(self, x):
        """Return per-position logits over the vocabulary for batch ``x``."""
        # The LSTM's final hidden/cell state is not needed here.
        sequence_output, _ = self.lstm(x)
        return self.fc(sequence_output)

    def init_state(self, custom_batch_size=None):
        """Return a pair of zero tensors shaped (layers, batch, hidden).

        When ``custom_batch_size`` is None, the module-level ``batch_size``
        is used for the batch dimension.
        """
        rows = batch_size if custom_batch_size is None else custom_batch_size
        shape = (lstm_layers, rows, lstm_size)
        return torch.zeros(shape), torch.zeros(shape)

In [4]:


class OurDataset(Dataset):
    """Dataset of encoded pet-name windows, read through a random permutation.

    ``x`` and ``y`` are produced by ``format_training_data``; every lookup is
    redirected through ``self.permutation``, which ``permute()`` re-draws.
    """

    def __init__(self, pet_names):
        features, targets = format_training_data(pet_names)
        self.x = features
        self.y = targets
        self.permute()

    def __getitem__(self, idx):
        # Translate the requested position into the shuffled row index.
        shuffled_idx = self.permutation[idx]
        sample = self.x[shuffled_idx]
        target = self.y[shuffled_idx]
        return sample, target

    def __len__(self):
        return len(self.x)

    def permute(self):
        """Draw a new random ordering over all rows."""
        row_count = len(self.x)
        self.permutation = torch.randperm(row_count)
        
# pet_names_raw = pd.read_csv("https://raw.githubusercontent.com/saturncloud/saturn-cloud-examples/master/text-generation-nn/seattle_pet_licenses.csv")
# pet_names = pet_names_raw["Animal's Name"].tolist()        
# loader = DataLoader(OurDataset(pet_names), batch_size=5)



In [14]:
# Model training function.
# When this runs, it saves the model output after each epoch (overwriting the previous one).
# If multiple computers are training the model, they each save to the same place.


def train():
    """Train the pet-name model on one worker of the distributed job.

    Builds the dataset, a ``DistributedSampler``-backed loader and a
    DDP-wrapped ``Model`` on GPU 0, then runs ``num_epochs`` epochs of Adam
    optimisation. After every epoch it submits a JSON metrics record and the
    pickled model ``state_dict`` through the results handler ``rh``.

    Reads these module-level names, which must be set before dispatch:
    ``pet_names``, ``batch_size``, ``learning_rate_multiplier``,
    ``num_epochs``, ``rh`` and ``training_start_time``.

    Assumes the ``torch.distributed`` process group is already initialised
    when this is called (presumably by ``dispatch.run`` -- confirm).
    """
    worker_rank = int(dist.get_rank())
    logger.info(f"Worker {worker_rank} - beginning")

    dataset = OurDataset(pet_names)
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # Every worker uses GPU index 0 for both the model and DDP.
    device = torch.device(0)
    model = Model().to(device)
    model = DDP(model, device_ids=[0])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001 * learning_rate_multiplier)

    for epoch in range(num_epochs):
        logger.info(f"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - Beginning epoch {epoch}")
        # Lets the sampler produce a different ordering each epoch.
        sampler.set_epoch(epoch)

        for i, (batch_x, batch_y) in enumerate(loader):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            batch_y_pred = model(batch_x)
            # The model emits (batch, position, class); CrossEntropyLoss wants
            # the class dimension second, hence the transpose.
            loss = criterion(batch_y_pred.transpose(1, 2), batch_y)
            loss.backward()
            optimizer.step()
            logger.info(f"Worker {worker_rank} - {datetime.datetime.now().isoformat()} - epoch {epoch} - batch {i} - batch complete - loss {loss.item()}")

        # store metrics while the model is training; the reported loss is
        # that of the last batch of the epoch
        current_time = datetime.datetime.now().isoformat()
        rh.submit_result(
            f"logs/data_{worker_rank}_{epoch}_{current_time}.json",
            json.dumps({'loss': loss.item(),
                        'elapsed_time': (datetime.datetime.now() - training_start_time).total_seconds(),
                        'epoch': epoch,
                        'worker': worker_rank})
        )
        # NOTE(review): a possible rh concurrency problem with these two
        # back-to-back submit_result calls was suspected; no delay is
        # applied between them.
        # save the model at the end of each epoch; every worker writes to the
        # same "model.pkl" key, overwriting the previous epoch's copy
        rh.submit_result("model.pkl", pickle.dumps(model.state_dict()))
        # Re-draw the dataset's own permutation before the next epoch.
        dataset.permute()

## Train a single model

In [None]:
# Settings for a single training run. train() reads batch_size and
# learning_rate_multiplier as module-level globals.
batch_size = 16384
learning_rate_multiplier = 1  # multiplies the base learning rate of 0.001 in train()
num_workers = 3  # passed to cluster.scale() below

# A fresh random key for this run's results handler; train() submits its
# metrics and model through rh.
key = uuid.uuid4().hex
rh = results.DaskResultsHandler(key)
# Order matters: get the cluster, ask for num_workers workers, then connect.
cluster = SaturnCluster()
cluster.scale(num_workers)
client = Client(cluster)

In [None]:
# start the parallel job, and use process_results to save the output
# client.restart()
# Taken before dispatch; train() subtracts it from the current time to report elapsed_time.
training_start_time = datetime.datetime.now()
futures = dispatch.run(client, train)
# The output directory is named after the current timestamp.
# NOTE(review): raise_errors=False presumably keeps worker failures from raising here -- confirm against dask_pytorch_ddp.
rh.process_results(f"/home/jovyan/training/{datetime.datetime.now().isoformat()}/", futures, raise_errors=False)

In [None]:
# Disconnect the Dask client now that the run is finished.
# NOTE(review): only the client is closed; the cluster object is left as-is -- confirm that is intended.
client.close()

## Train models for comparison

In [17]:
# Each entry is (batch size, num_workers, learning_rate_multiplier).
# The grid covers 3 and 1 workers, multipliers 1 and 3, and the full batch
# size versus one third of it; the batch size varies fastest.
parameters = [
    (batch, workers, multiplier)
    for workers in (3, 1)
    for multiplier in (1, 3)
    for batch in (16384, 16384 // 3)
]

In [18]:
# One timestamp for the whole comparison; the comparison loop uses it to name the shared output directory.
comparison_start_time = datetime.datetime.now()

In [19]:
# Run one full training job per hyperparameter combination.
# The loop variables are deliberately module-level names: train() reads
# batch_size and learning_rate_multiplier (and rh, training_start_time) as globals.
for batch_size, num_workers, learning_rate_multiplier in parameters:
    print(f"{datetime.datetime.now().isoformat()} - Running training for batch={batch_size} num_workers={num_workers} learning_rate_multiplier={learning_rate_multiplier}")
    training_start_time = datetime.datetime.now()
    # A fresh results handler per run keeps each run's output separate.
    key = uuid.uuid4().hex
    rh = results.DaskResultsHandler(key)
    cluster = SaturnCluster()
    cluster.scale(num_workers)
    client = Client(cluster)
    try:
        futures = dispatch.run(client, train)
        rh.process_results(f"/home/jovyan/training-comparison/{comparison_start_time.isoformat()}/batch={batch_size}&num_workers={num_workers}&learning_rate_multiplier={learning_rate_multiplier}/", futures, raise_errors=False)
    finally:
        # Always release the client, even if dispatch or result processing
        # fails, so a failed run does not leak a connection into the next one.
        client.close()

2021-02-18T20:05:34.348255 - Running training for batch=16384 num_workers=3 learning_rate_multiplier=1


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}
ERROR:root:Timed out trying to connect to 'tcp://d-jnoli-neural-net-test-3a7c9e5679eb46bd834b2094f4113405.main-namespace:8786' after 10 s: Timed out trying to connect to 'tcp://d-jnoli-neural-net-test-3a7c9e5679eb46bd834b2094f4113405.main-namespace:8786' after 10 s: connect() didn't finish in time
Traceback (most recent call last):
  File "/srv/conda/envs/saturn/lib/python3.7/site-packages/distributed/comm/core.py", line 322, in connect
    _raise(error)
  File "/srv/conda/envs/saturn/lib/python3.7/site-packages/distributed/comm/core.py", line 275, in _raise
    raise IOError(msg)
OSError: Timed out trying to connect to 'tcp://d-jnoli-neural-net-test-3a7c9e5679eb46bd834b2094f4113405.main-namespace:8786' after 10 s: connect() didn't finish in time

D

2021-02-18T20:15:27.685761 - Running training for batch=5461 num_workers=3 learning_rate_multiplier=1


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}


2021-02-18T20:26:14.089030 - Running training for batch=16384 num_workers=3 learning_rate_multiplier=3


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}


2021-02-18T20:36:08.396415 - Running training for batch=5461 num_workers=3 learning_rate_multiplier=3


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}
ERROR:asyncio:Task exception was never retrieved
future: <Task finished coro=<connect.<locals>._() done, defined at /srv/conda/envs/saturn/lib/python3.7/site-packages/distributed/comm/core.py:288> exception=CommClosedError()>
Traceback (most recent call last):
  File "/srv/conda/envs/saturn/lib/python3.7/site-packages/distributed/comm/core.py", line 297, in _
    handshake = await asyncio.wait_for(comm.read(), 1)
  File "/srv/conda/envs/saturn/lib/python3.7/asyncio/tasks.py", line 435, in wait_for
    await waiter
concurrent.futures._base.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/srv/conda/envs/saturn/lib/python3.7/site-packages/distributed/comm/core.py", line 30

2021-02-18T20:45:29.938927 - Running training for batch=16384 num_workers=3 learning_rate_multiplier=0.333


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}


2021-02-18T20:50:55.012705 - Running training for batch=16384 num_workers=1 learning_rate_multiplier=1


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}, 'tcp://10.0.13.144:43293': {'status': 'repeat'}, 'tcp://10.0.3.61:33135': {'status': 'repeat'}}
INFO:dask-saturn:Cluster is ready


2021-02-18T20:54:06.893501 - Running training for batch=5461 num_workers=1 learning_rate_multiplier=1


INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}}


2021-02-18T20:57:16.530829 - Running training for batch=16384 num_workers=1 learning_rate_multiplier=3


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}}


2021-02-18T21:00:26.870718 - Running training for batch=5461 num_workers=1 learning_rate_multiplier=3


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}}


2021-02-18T21:03:37.096368 - Running training for batch=16384 num_workers=1 learning_rate_multiplier=0.333


INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:{'tcp://10.0.0.125:42207': {'status': 'repeat'}}
