Adapted from the Ray Train Fashion-MNIST example:
https://docs.ray.io/en/latest/train/examples/train_fashion_mnist_example.html

In [1]:
%%capture pip_outputs
# Install/refresh notebook dependencies quietly: pip's output is captured into
# `pip_outputs` instead of being printed into the notebook.
# NOTE(review): versions are not pinned, so re-runs may pick up newer releases.
%pip install --upgrade python-dotenv openshift-client
# NOTE(review): pickle5 is removed here; the reason is not recorded in this
# notebook — TODO document why (e.g. a conflict with the Ray client).
%pip uninstall --yes pickle5

In [2]:
# Keep tracebacks short (exception type and message only) so error output stays readable.
%xmode Minimal

Exception reporting mode: Minimal


In [3]:
import os
import glob
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor, Resize,Compose, ToPILImage

from sklearn.metrics import classification_report, confusion_matrix

import ray
from ray import train
from ray import tune
import ray.train.torch
from ray.train import Trainer
from ray.train import CheckpointStrategy
from ray.data.datasource import SimpleTorchDatasource

import matplotlib.pyplot as plt

from ray_cluster_control import start_ray_cluster, stop_ray_cluster

In [4]:
# Bring up the remote Ray cluster that the rest of the notebook connects to.
start_ray_cluster(cluster_name="octo-next-2022")

RayCluster octo-next-2022 has started


In [5]:
# Show the cluster name exported into the environment (used to build the Ray client address below).
os.environ.get("RAY_CLUSTER_NAME")

'octo-next-2022'

In [8]:
# Connect this notebook to the cluster's head node through the Ray client port.
ray_head = os.environ['RAY_CLUSTER_NAME']
ray.init(f'ray://{ray_head}-ray-head:10001')

2022-09-11 00:24:40,012	ERROR dataclient.py:290 -- Unrecoverable error in data channel.


ConnectionError: Request can't be sent because the Ray client has already been disconnected due to an error. Last exception: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.NOT_FOUND
	details = "Attempted to reconnect a session that has already been cleaned up"
	debug_error_string = "{"created":"@1662855880.012452796","description":"Error received from peer ipv4:172.30.148.144:10001","file":"src/core/lib/surface/call.cc","file_line":1074,"grpc_message":"Attempted to reconnect a session that has already been cleaned up","grpc_status":5}"
>

In [None]:
# The cluster is still needed by every Ray cell below (the last cell of the
# notebook stops it), so tearing it down here breaks Restart-&-Run-All.
# Set TEARDOWN_NOW = True only if you deliberately want to stop it at this point.
TEARDOWN_NOW = False
if TEARDOWN_NOW:
    stop_ray_cluster(cluster_name = 'octo-next-2022')

This dataset contains only about 3K images rather than the full ~7K, because it is loaded with the `split` parameter (`split="trainval"`), which reads only the files listed in the corresponding `.txt` file in `data/data/oxford-iiit-pet/annotations`.

In [None]:
# Report which accelerator this kernel could use (informational only) ...
available_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(available_device)
# ... but the local inference cells at the end of the notebook run on CPU.
device = "cpu"

In [None]:
# modified from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class ConvNet(nn.Module):
    """Small CNN classifier: two conv/max-pool stages followed by three linear layers.

    The first linear layer expects 6272 = 32 * 14 * 14 features, which matches a
    3x64x64 input (64 -conv-> 62 -pool-> 31 -conv-> 29 -pool-> 14). The output
    is a vector of 37 unnormalised class scores (logits).

    Attribute names are part of the saved ``state_dict`` keys, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3x3 convolutions with stride 1, each followed by 2x2 max-pooling.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)

        # Classifier head.
        self.fc1 = nn.Linear(in_features=6272, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=37)

        # NOTE(review): dropout1 is registered but never applied in forward().
        self.dropout1 = nn.Dropout(p=0.25)
        self.relu = nn.ReLU()

    def forward(self, x):
        # NOTE(review): no activation is applied after conv1/conv2 (only pooling);
        # the referenced PyTorch tutorial applies ReLU there — TODO confirm intended.
        features = self.pool(self.conv1(x))
        features = self.pool(self.conv2(features))
        flat = features.flatten(start_dim=1)  # keep the batch dimension
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Now with Ray 

In [None]:
# Preprocessing applied to every image. `transforms` was referenced here but never
# defined in this notebook (NameError on a fresh kernel). Resize to 64x64 then
# convert to a tensor: ConvNet's fc1 size corresponds to a 3x64x64 input and later
# cells reshape samples to (1, 3, 64, 64).
# NOTE(review): reconstructed pipeline — TODO confirm no normalisation was intended.
transforms = Compose([Resize((64, 64)), ToTensor()])


def dataset_factory():
    """Build the Oxford-IIIT Pet "trainval" split with category labels.

    Passed to SimpleTorchDatasource, which calls it to materialise the torch
    dataset; images are downloaded into ``data/`` on first use.
    """
    return torchvision.datasets.OxfordIIITPet(
        "data/",
        download=True,
        target_types="category",
        transform=transforms,
        split="trainval",
    )


dataset = ray.data.read_datasource(
    SimpleTorchDatasource(),
    parallelism=1,
    dataset_factory=dataset_factory,
)

In [None]:
# Shuffle with a fixed seed so the train/test split is reproducible across runs,
# then cut the data into two parts: x[0] is used for training, x[1] for evaluation.
SPLIT_SEED = 42
x = dataset.random_shuffle(seed=SPLIT_SEED).repartition(2).split(2)
x

In [None]:
# First shard trains the model, second shard is held out for validation/evaluation.
train_data, test_data = x[0], x[1]

In [None]:
def train_epoch(dataloader, model, loss_fn, optim, device="cuda", batch_size=None):
    """Run one training pass over ``dataloader`` and return the last batch's loss.

    Args:
        dataloader: object exposing ``iter_batches()`` (here a Ray Dataset split).
            Each batch is assumed to be a sequence of ``(image_tensor, label)``
            pairs — TODO confirm for the Ray version in use.
        model: the ``nn.Module`` to train; it is moved to ``device``.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optim: optimizer over ``model``'s parameters.
        device: device the model and every batch are moved to. Defaults to
            ``"cuda"``, the value that used to be hard-coded.
        batch_size: forwarded as ``iter_batches(batch_size=...)`` when given;
            ``None`` keeps ``iter_batches()``'s own default.

    Returns:
        float: loss of the final batch of the epoch (not an epoch average).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    model.to(device)
    # Only pass batch_size when requested so the default iteration is unchanged.
    batches = (
        dataloader.iter_batches()
        if batch_size is None
        else dataloader.iter_batches(batch_size=batch_size)
    )

    last_loss = None
    for batch in batches:
        X = torch.stack([sample[0] for sample in batch]).to(device)
        y = torch.stack([torch.as_tensor(sample[1]) for sample in batch]).to(device)
        loss = loss_fn(model(X), y)

        # Backprop step.
        optim.zero_grad()
        loss.backward()
        optim.step()
        last_loss = loss

    if last_loss is None:
        raise ValueError("train_epoch: dataloader produced no batches")
    return last_loss.item()

In [None]:
def validate_epoch(dataloader, model, loss_fn, device="cuda", batch_size=None):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    The mean is taken over the batches actually seen. The previous version divided
    by ``dataloader.num_blocks()``, which only equals the batch count when each
    block is delivered as exactly one batch.

    Args:
        dataloader: object exposing ``iter_batches()``; each batch is assumed to
            be a sequence of ``(image_tensor, label)`` pairs — TODO confirm.
        model: the ``nn.Module`` to evaluate; it is moved to ``device``.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: device the model and every batch are moved to (default ``"cuda"``).
        batch_size: forwarded as ``iter_batches(batch_size=...)`` when given.

    Returns:
        float: average of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    model.to(device)
    batches = (
        dataloader.iter_batches()
        if batch_size is None
        else dataloader.iter_batches(batch_size=batch_size)
    )

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in batches:
            X = torch.stack([sample[0] for sample in batch]).to(device)
            y = torch.stack([torch.as_tensor(sample[1]) for sample in batch]).to(device)
            total_loss += loss_fn(model(X), y).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate_epoch: dataloader produced no batches")
    return total_loss / num_batches

In [None]:
# Ray Tune grid: 2 learning rates x 2 momenta = 4 trials; batch_size, epochs and
# tune_run each have a single value.
# NOTE(review): train_func reads batch_size but does not pass it to the data
# iteration, so that dimension currently has no effect on training.
search_space = dict(
    lr=tune.grid_search([0.001, 0.1]),
    momentum=tune.grid_search([0.1, 0.9]),
    batch_size=tune.grid_search([4]),
    epochs=tune.grid_search([300]),
    tune_run=tune.grid_search([True]),
)

In [None]:
def train_func(config):
    """Train a fresh ``ConvNet`` with the hyper-parameters in ``config``.

    Used both as a Ray Tune trainable (``config["tune_run"]`` truthy: the final
    metrics are reported to Tune) and as the per-worker function for
    ``Trainer.run`` (``tune_run`` falsy: progress is printed and the final weights
    are saved with ``train.save_checkpoint``).

    Args:
        config: dict with keys ``"batch_size"``, ``"lr"``, ``"epochs"``,
            ``"momentum"`` and ``"tune_run"``.

    Returns:
        tuple[list[float], list[float]]: per-epoch training loss (as returned by
        ``train_epoch``) and per-epoch validation loss.

    Raises:
        ValueError: if ``config["epochs"]`` is less than 1 (this used to surface
            as an ``UnboundLocalError`` when reporting to Tune).
    """
    epochs = config["epochs"]
    if epochs < 1:
        raise ValueError(f"train_func: epochs must be >= 1, got {epochs}")
    lr = config["lr"]
    momentum = config["momentum"]
    tune_run = config["tune_run"]
    # NOTE(review): batch_size is read (and searched over by Tune) but is not
    # forwarded to the train/validate helpers, so it has no effect yet.
    # TODO: forward it to the batch iteration.
    batch_size = config["batch_size"]  # noqa: F841

    # Ray Dataset splits taken from the notebook's globals.
    train_dataloader = train_data
    test_dataloader = test_data

    model = ConvNet()
    # NOTE(review): the model is not wrapped with train.torch.prepare_model, so with
    # Trainer(num_workers > 1) each worker appears to train its own copy on the
    # same data rather than running DDP — TODO confirm this is intended.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    loss_results = []
    eval_results = []
    for epoch in range(epochs):
        loss = train_epoch(train_dataloader, model, loss_fn, optimizer)
        val_loss = validate_epoch(test_dataloader, model, loss_fn)
        loss_results.append(loss)
        eval_results.append(val_loss)
        if not tune_run and epoch % 100 == 0:
            print(f"epoch:{epoch}  loss {loss:>7f}")
            print(f"eval loss: {val_loss:>7f}")

    if tune_run:
        # "loss" (final training loss) stays the metric the notebook selects on;
        # "val_loss" is reported alongside it for inspection.
        tune.report(loss=loss_results[-1], val_loss=eval_results[-1])
    else:
        # Checkpoint the final weights once training is complete. Strip a
        # "module." prefix in case the model was wrapped (e.g. by DDP).
        state_dict = model.state_dict()
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        train.save_checkpoint(epoch=epochs, model_weights=state_dict)

    return loss_results, eval_results

### Tune! 

In [None]:
# Grid-search the hyper-parameters; every trial requests one GPU.
analysis = tune.run(
    tune.with_parameters(train_func),
    config=search_space,
    resources_per_trial={"gpu": 1},
)

best_config = analysis.get_best_config(metric="loss", mode="min")
print(best_config)

### Now that we know the best parameters, we'll train and save the final model. This avoids saving N model trials that we would need to delete later.

In [None]:
%%time
if os.environ.get('RAY_CLUSTER_NAME') is not None:
    num_workers = 2
else:
    num_workers = 1
    

trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=True)

In [None]:
%%time
# Launch the Ray Train worker processes.
# NOTE(review): the original author observed that this "keeps starting new
# workers" — TODO investigate before relying on re-runs of this cell.
trainer.start()

In [None]:
# Build the config for the final training run from the best Tune trial.
# Copy rather than mutate: get_best_config may return the trial's own config
# dict, which `analysis` still references.
best_trial_config = analysis.get_best_config(metric="loss", mode="min")
if best_trial_config is None:
    raise RuntimeError("Ray Tune found no best trial for metric 'loss'")
# tune_run=False switches train_func to its print-and-checkpoint branch.
config = {**best_trial_config, "tune_run": False}

In [None]:
%%time
loss, eval_ = trainer.run(train_func=train_func,
                          config = config)


In [None]:
# Checkpoint dict saved by train_func via train.save_checkpoint
# (keys "epoch" and "model_weights").
results = trainer.latest_checkpoint
if results is None:
    raise RuntimeError("Trainer has no checkpoint; train_func did not save one")

In [None]:
%%time
# Stop the Ray Train workers started by trainer.start().
trainer.shutdown()

In [None]:
# loss[0]: per-epoch training losses; eval_[1]: per-epoch validation losses
# (the two lists returned by train_func).
fig, ax = plt.subplots()
ax.plot(loss[0], label="train loss (last batch of each epoch)")
ax.plot(eval_[1], label="validation loss")
ax.set(xlabel="epoch", ylabel="cross-entropy loss", title="ConvNet training curves")
ax.legend()
plt.show()

In [None]:
# Rebuild the network from the checkpointed weights and export it as TorchScript,
# so it can be loaded later without the ConvNet class definition.
model = ConvNet()
model.load_state_dict(results["model_weights"])
model_scripted = torch.jit.script(model)
os.makedirs("models", exist_ok=True)  # saving fails if the target directory is missing
model_scripted.save('models/ray_model_scripted.pt')

In [None]:
# Reload the exported TorchScript model onto the CPU and switch it to inference mode.
model = torch.jit.load("models/ray_model_scripted.pt", map_location="cpu")
model = model.eval()

In [None]:
# Show one held-out image together with its label and the model's prediction.
# Samples come from `test_data` because no `data_loader` is defined in this notebook.
sample_batch = next(iter(test_data.iter_batches()))
sample_image, sample_label = sample_batch[0]
make = ToPILImage()
# Render inline; PIL's Image.show() would open an external viewer on the kernel host.
plt.imshow(make(sample_image))
plt.axis("off")
plt.show()
print(f'label:{sample_label}')
with torch.no_grad():
    logits = model(sample_image.unsqueeze(0).to(device))
print(f"prediction: {torch.argmax(logits).item()}")

In [None]:
# Per-sample predictions (one forward pass per image) over the held-out split.
# Iterates `test_data` because no `data_loader` is defined in this notebook.
y_preds = []
gt = []
with torch.no_grad():
    for batch in test_data.iter_batches():
        for image, label in batch:
            logits = model(image.unsqueeze(0).to(device))
            y_preds.append(torch.argmax(logits).item())
            gt.append(int(label))

In [None]:
# Batch-wise predictions on the held-out split. Predictions and labels from every
# batch are accumulated (previously each iteration overwrote them, so only the
# last batch reached the metrics below).
pred_chunks = []
label_chunks = []
with torch.no_grad():
    for data in test_data.iter_batches():
        X = torch.stack([sample[0] for sample in data])
        label_chunks.append(torch.stack([torch.as_tensor(sample[1]) for sample in data]))
        logits = model(X.to(device))
        pred_chunks.append(torch.argmax(logits, dim=1).to("cpu"))
y_preds = torch.cat(pred_chunks)
y = torch.cat(label_chunks)

In [None]:
# sklearn's signature is classification_report(y_true, y_pred); passing them in the
# other order swaps precision and recall in the report.
print(classification_report(y, y_preds))

In [None]:
# confusion_matrix(y_true, y_pred): rows are true classes, columns are predictions.
fig, ax = plt.subplots()
im = ax.imshow(confusion_matrix(y, y_preds), cmap='hot', interpolation='nearest')
ax.set(xlabel="predicted class", ylabel="true class", title="Confusion matrix (test split)")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Compare test accuracy with uniform random guessing over the 37 classes.
n_classes = 37
chance_rate = 1 / n_classes
hit_rate = np.sum(np.array(y_preds) == np.array(y)) / len(y)
print(f"Accuracy: {hit_rate * 100:.1f}%")
print(f"Random guess would yield {chance_rate * 100:.1f}%")
print(f"Model outperforms Random chance by a factor of {hit_rate / chance_rate:.2f}")

In [None]:
# Disconnect this notebook's Ray client first, then tear down the cluster by the
# same explicit name used when it was started.
ray.shutdown()
stop_ray_cluster(cluster_name='octo-next-2022')