In [16]:
# ! pip install 'ray[tune]' --quiet
# ! pip install pandas --quiet

In [7]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [14]:
class NeuralNetwork(nn.Module):
    def __init__(self, l1=128, l2=64):
        super(NeuralNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3), 
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=1), 
            nn.Flatten(), 
            nn.Linear(26*26*4, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.Linear(l2, 10)
        )

    def forward(self, X):
        logits = self.cnn(X)
        return logits

In [3]:
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [28]:
def load_data(dir='/Users/yujian/Documents/personal-workspace/nn_examples/pytorch/fashion_mnist'):
    trainset = datasets.FashionMNIST(root=dir, train=True, transform=transforms.ToTensor())
    testset = datasets.FashionMNIST(root=dir, train=False, transform=transforms.ToTensor())
    return trainset, testset

In [29]:
def train_network(config, checkpoint_dir=None, data_dir=None):
    model = NeuralNetwork(config["l1"], config["l2"])
    learning_rate = 1e-3
    batchsize=64
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    device = 'cpu'
    model.to(device)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    
    training_data, testing_data = load_data()
    val_idx = int(len(training_data)*0.2)
    training_subset, validation_subset = torch.utils.data.random_split(training_data, [val_idx, len(training_data)-val_idx])
    train_dataloader = DataLoader(training_subset, batch_size=batchsize, shuffle=True)
    val_dataloader = DataLoader(validation_subset, batch_size=batchsize, shuffle=True)

    for epoch in range(10):
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(train_dataloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # forward pass + backprop + optimization step
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # printing stats
            running_loss += loss.item()
            epoch_steps += 1

            # mini batches of 100
            if i % 100 == 0:
                print(f"{epoch}, {i}, {running_loss/epoch_steps:>5f}")
                running_loss = 0.0
            
            # validation functionality
            val_loss = 0.0
            val_steps = 0
            total = 0
            correct = 0
            for i, data in enumerate(val_dataloader, 0):
                with torch.no_grad():
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)

                    outputs = model(inputs)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    loss = loss_fn(outputs, labels)
                    val_loss += loss.numpy() 
                    val_steps += 1
            
            with tune.checkpoint_dir(epoch) as checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                torch.save((model.state_dict(), optimizer.state_dict()), path)

            tune.report(loss = (val_loss/val_steps), accuracy = correct/total)

In [30]:
def test_accuracy(model, device="cpu"):
    training_data, testing_data = load_data()
    testloader = DataLoader(testing_data, batch_size=64, shuffle=False)
    
    # measure accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return correct/total

In [31]:
import numpy as np
import functools

In [32]:
config = {
    "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9))
}

In [33]:
result = tune.run(
    functools.partial(train_network),
    config=config
)

0,1
Current time:,2023-01-26 13:34:59
Running for:,00:15:43.08
Memory:,12.1/16.0 GiB

Trial name,status,loc,l1,l2,iter,total time (s),loss,accuracy
train_network_16e39_00000,RUNNING,127.0.0.1:31051,256,8,325,941.084,2.27722,0.0995417


[2m[36m(pid=31051)[0m   Referenced from: /Users/yujian/.pyenv/versions/3.10.1/lib/python3.10/site-packages/torchvision/image.so
[2m[36m(pid=31051)[0m   Expected in: /Users/yujian/.pyenv/versions/3.10.1/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
[2m[36m(pid=31051)[0m   warn(f"Failed to load image Python extension: {e}")


[2m[36m(func pid=31051)[0m 0, 0, 2.286733


Trial name,accuracy,date,done,episodes_total,experiment_id,hostname,iterations_since_restore,loss,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
train_network_16e39_00000,0.0995417,2023-01-26_13-34-59,False,,45bbd37591764a3988f40bbdf27e200d,yujians-mbp.lan,325,2.27722,127.0.0.1,31051,True,941.084,2.88751,941.084,1674768899,0,,325,16e39_00000,0.00209093


[2m[36m(func pid=31051)[0m 0, 100, 2.283761
[2m[36m(func pid=31051)[0m 1, 0, 2.273696
[2m[36m(func pid=31051)[0m 1, 100, 2.266338


2023-01-26 13:34:59,958	ERROR tune.py:758 -- Trials did not complete: [train_network_16e39_00000]
2023-01-26 13:34:59,958	INFO tune.py:762 -- Total run time: 943.29 seconds (943.06 seconds for the tuning loop).
