In [1]:
# Specific libraries for distributed training

import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from torch.nn.parallel import DistributedDataParallel as DDP
from dask_pytorch_ddp import data, dispatch
import torch.distributed as dist
from dask.distributed import Client, progress

In [2]:
## Import helper functions and some additional libraries
%run -i fns.py

In [3]:
### Setup ###
# Use the GPU when this kernel can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Anonymous (public-bucket) S3 access for the ImageNet label listing.
# NOTE(review): s3fs opens files in binary mode by default, so each entry is
# presumably a stripped `bytes` line — confirm this is what the fns.py helpers expect.
s3 = s3fs.S3FileSystem(anon=True)
with s3.open('s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as label_file:
    imagenetclasses = [raw_line.strip() for raw_line in label_file]

In [4]:
### ============== Constants ============== ###
# Hyperparameters and data location, passed as keyword arguments to cluster_transfer_learn.
model_params = dict(
    n_epochs=6,
    batch_size=100,
    base_lr=.01,
    train_pct=.7,
    downsample_to=1,
    subset=True,    # whether to break the data into N pieces for training
    worker_ct=6,    # N: how many pieces to break the data into
    bucket="saturn-public-data",
    prefix="dogs/Images",
    pretrained_classes=imagenetclasses,
)

# Weights & Biases run config: every model parameter plus dataset/architecture tags.
wbargs = dict(model_params, classes=120, dataset="StanfordDogs", architecture="ResNet")

# Saturn project id handed to ExternalConnection.
project_id = 'a2ae799b6f234f09bd0341aa9769971f'
# NOTE(review): not referenced by any visible cell; presumably read by a fns.py
# helper through the shared namespace (`%run -i`) — TODO confirm or remove.
num_workers = 40

In [5]:
def _top_class_confidence(outputs, preds):
    """Return the softmax probability of each row's predicted class as a list of floats.

    Vectorized form of
    ``[softmax(row, dim=0)[p].item() for p, row in zip(preds, outputs)]``:
    one device-to-host copy instead of one ``.item()`` sync per sample.

    Parameters
    ----------
    outputs : torch.Tensor
        2-D logits, one row per sample (class scores along dim 1).
    preds : torch.Tensor
        Integer (int64) class index per row, e.g. the indices from ``torch.max(outputs, 1)``.
    """
    probs = torch.softmax(outputs.detach(), dim=1)  # detach: never part of the autograd graph
    return probs.gather(1, preds.unsqueeze(1)).squeeze(1).tolist()


def cluster_transfer_learn(bucket, prefix, train_pct, batch_size, downsample_to,
                          n_epochs, base_lr, pretrained_classes, subset, worker_ct,
                          pretrained=False, plot_every=5):
    """Train and evaluate a DDP-wrapped ResNet-50 on one Dask worker.

    Meant to be launched on every worker through ``dispatch.run``. That call must
    already have initialized the ``torch.distributed`` process group, because the
    rank is read from it here. Only rank 0 talks to Weights & Biases, so metrics
    and plots are logged once per run rather than once per worker.

    Parameters
    ----------
    bucket, prefix : str
        S3 location of the image folder, forwarded to ``preprocess``.
    train_pct : float
        Fraction of the data used for training; forwarded to ``train_test_split``.
    batch_size, downsample_to, subset, worker_ct :
        Forwarded to ``train_test_split`` (see fns.py for their exact semantics).
    n_epochs : int
        Number of train + eval passes.
    base_lr : float
        SGD learning rate (momentum 0.9).
    pretrained_classes : list
        ImageNet label lines. Used to map dataset class names onto ImageNet class
        indices and to label the prediction plots.
    pretrained : bool, default False
        True starts from the ImageNet weights; False (the notebook's original
        behavior) starts from random initialization.
    plot_every : int, default 5
        On rank 0, log a predictions-vs-actuals plot every ``plot_every`` batches
        of both the training and the validation loop.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``plot_every`` is not a positive integer.
    """
    if plot_every < 1:
        raise ValueError("plot_every must be a positive integer")

    worker_rank = int(dist.get_rank())
    is_logger = worker_rank == 0  # only rank 0 reports to wandb

    # --------- Format model and params --------- #
    device = torch.device("cuda")
    net = models.resnet50(pretrained=pretrained)  # True means we start with the imagenet version
    model = DDP(net.to(device))

    if is_logger:
        wandb.init(config=wbargs, reinit=True, project='cdl-demo')
        wandb.watch(model)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)

    # --------- Retrieve data for training and eval --------- #
    whole_dataset = preprocess(bucket, prefix)
    # Re-index the dataset's classes onto the ImageNet class ids of the ResNet head.
    whole_dataset.class_to_idx = {
        x: int(replace_label(x, pretrained_classes)[1]) for x in whole_dataset.classes
    }

    train, val = train_test_split(
        train_pct,
        whole_dataset,
        batch_size=batch_size,
        downsample_to=downsample_to,
        subset=subset,
        workers=worker_ct,
    )
    dataloaders = {'train': train, 'val': val}

    # --------- Start iterations --------- #
    for epoch in range(n_epochs):

        # --------- Training section --------- #
        model.train()
        count = 0
        for inputs, labels in dataloaders["train"]:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            loss = criterion(outputs, labels)
            correct = (preds == labels).sum().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1

            # Record the results of this training batch for later review.
            if is_logger:
                wandb.log({
                    'loss': loss.item(),
                    # lr of the last param group (the single SGD group here)
                    'learning_rate': optimizer.param_groups[-1]['lr'],
                    'correct': correct,
                    'epoch': epoch,
                    'count': count,
                    'worker': worker_rank,
                })
                if count % plot_every == 0:
                    # Confidences are only needed for the plot, so compute them lazily.
                    perct = _top_class_confidence(outputs, preds)
                    wandb.log({f'predictions vs. actuals, training, epoch {epoch}, count {count}': plot_model_performance(
                        model, inputs, labels, preds, perct, pretrained_classes)})

        # --------- Evaluation section --------- #
        model.eval()
        t_count = 0
        with torch.no_grad():
            for inputs_t, labels_t in dataloaders["val"]:
                inputs_t, labels_t = inputs_t.to(device), labels_t.to(device)
                outputs_t = model(inputs_t)
                _, pred_t = torch.max(outputs_t, 1)

                loss_t = criterion(outputs_t, labels_t)
                correct_t = (pred_t == labels_t).sum().item()
                t_count += 1

                # Record the results of this evaluation batch for later review.
                if is_logger:
                    wandb.log({
                        'val_loss': loss_t.item(),
                        'val_correct': correct_t,
                        'epoch': epoch,
                        'count': t_count,
                        'worker': worker_rank,
                    })
                    # Gate on the *validation* counter; the training counter is
                    # constant throughout this loop.
                    if t_count % plot_every == 0:
                        perct_t = _top_class_confidence(outputs_t, pred_t)
                        wandb.log({f'predictions vs. actuals, eval, epoch {epoch}, count {t_count}': plot_model_performance(
                            model, inputs_t, labels_t, pred_t, perct_t, pretrained_classes)})

    if is_logger:
        # The Dask worker process outlives this call, so close the run explicitly.
        wandb.finish()


In [6]:
# Authenticate this kernel with Weights & Biases; rank 0 of the training job logs runs there.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mskirmer[0m (use `wandb login --relogin` to force relogin)


True

In [7]:
### Saturn Connection Setup ###
# The Saturn API token is read from a local config file instead of being hard-coded.
with open('config.json') as config_file:
    tokens = json.load(config_file)

conn = ExternalConnection(
    saturn_token=tokens['api_token'],
    base_url='https://app.internal.saturnenterprise.io',
    project_id=project_id,
)
conn

<dask_saturn.external.ExternalConnection at 0x7fbb61899e50>

In [8]:
# Size the cluster from the same worker count used to split the data, so the number
# of data pieces cannot drift away from the number of workers (one DDP rank each).
n_dask_workers = model_params['worker_ct']

cluster = SaturnCluster(
    external_connection=conn,
    n_workers=n_dask_workers,
    worker_size='g4dn4xlarge',
    scheduler_size='2xlarge',
    nthreads=16)

client = Client(cluster)
# Block until every worker has joined, so training starts on the full cluster.
client.wait_for_workers(n_dask_workers)
client

INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins

+---------+----------------+---------------+---------------+
| Package | client         | scheduler     | workers       |
+---------+----------------+---------------+---------------+
| numpy   | 1.20.1         | 1.19.2        | 1.19.2        |
| python  | 3.7.10.final.0 | 3.7.9.final.0 | 3.7.9.final.0 |
+---------+----------------+---------------+---------------+


0,1
Client  Scheduler: tls://d-steph-cdl-demo-fa90a721acb8498caea5f7a29a297b25.internal.saturnenterprise.io:8786  Dashboard: https://d-steph-cdl-demo-fa90a721acb8498caea5f7a29a297b25.internal.saturnenterprise.io,Cluster  Workers: 6  Cores: 96  Memory: 381.00 GB


In [10]:
### Run Model ###
# Launch cluster_transfer_learn on every worker. The resulting keys
# (`dispatch_with_ddp-…`) show dispatch wraps each call with DDP setup, and one
# future comes back per worker.
futures = dispatch.run(client, cluster_transfer_learn, **model_params)
futures

[<Future: pending, key: dispatch_with_ddp-adbec36c39e45df89b0ec3024ba9c999>,
 <Future: pending, key: dispatch_with_ddp-fd5fd88c251071c12dbf595cf67709b7>,
 <Future: pending, key: dispatch_with_ddp-262e88e33a012c20f3a00ac6fac4f193>,
 <Future: pending, key: dispatch_with_ddp-4ea41f5338ca99e56434370d189c53b7>,
 <Future: pending, key: dispatch_with_ddp-e3d1c8f3c951ceb784f977a311ff8170>,
 <Future: pending, key: dispatch_with_ddp-920f0b7d0038c18bdd7348a6442473e7>]

In [11]:
# Wait for every rank, not just rank 0. gather re-raises a failed worker's exception,
# so an error on any rank surfaces here. Results are all None, so suppress the display.
client.gather(futures);