<a href="https://colab.research.google.com/github/omarsar/pytorch_notebooks/blob/master/pytorch_quick_start.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
BATCH_SIZE = 32

## transformations: turn each image into a tensor
transform = transforms.Compose([transforms.ToTensor()])

## download and load both MNIST splits (training first, then testing)
trainset, testset = (
    torchvision.datasets.MNIST(
        root="./data", train=split_is_train, download=True, transform=transform
    )
    for split_is_train in (True, False)
)

## wrap each split in a DataLoader; only the training split is shuffled
trainloader, testloader = (
    torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=shuffle_it, num_workers=2
    )
    for dataset, shuffle_it in ((trainset, True), (testset, False))
)

In [None]:
class MyModel(nn.Module):
    """Small CNN classifier for 28x28 single-channel images.

    Architecture: Conv2d(1 -> 32, 3x3) -> ReLU -> flatten
    -> Linear(32*26*26 -> 128) -> ReLU -> Linear(128 -> 10).

    ``forward`` returns raw, unnormalised class scores (logits).
    """

    def __init__(self):
        super().__init__()

        # (B, 1, 28, 28) => (B, 32, 26, 26): 3x3 kernel, stride 1, no padding
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        # flattened conv features => 128 hidden units
        self.d1 = nn.Linear(26 * 26 * 32, 128)
        # 128 hidden units => 10 class scores
        self.d2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return logits of shape (B, 10) for ``x`` of shape (B, 1, 28, 28).

        No softmax is applied here. ``nn.CrossEntropyLoss`` expects raw
        logits and applies log-softmax itself, so softmaxing first would
        double-normalise and flatten the gradients. Use
        ``torch.softmax(logits, dim=1)`` explicitly when probabilities are
        needed; ``argmax`` over logits yields the same predictions.
        """
        # (B, 1, 28, 28) => (B, 32, 26, 26)
        x = self.conv1(x)
        x = F.relu(x)

        # flatten => (B, 32*26*26)
        x = x.flatten(start_dim=1)

        # (B, 32*26*26) => (B, 128)
        x = self.d1(x)
        x = F.relu(x)

        # logits => (B, 10)
        logits = self.d2(x)
        return logits

In [None]:
## test the model with 1 batch
# Forward pass only, so skip building an autograd graph.
model = MyModel()
with torch.no_grad():
    for images, labels in trainloader:
        print("batch size:", images.shape)
        out = model(images)
        print(out.shape)
        break

In [None]:
def get_accuracy(logit, target, batch_size=None):
    """Return the percentage of rows of ``logit`` whose argmax equals ``target``.

    Args:
        logit: Score tensor of shape (N, C); only the argmax per row is used.
        target: Integer class labels with N elements.
        batch_size: Denominator for the percentage. Defaults to the actual
            number of samples, ``target.size(0)``. Passing a fixed batch size
            mis-scores a smaller final batch.

    Returns:
        Accuracy in percent as a Python float.
    """
    if batch_size is None:
        batch_size = target.size(0)
    preds = logit.detach().argmax(dim=1).view(target.size())
    corrects = (preds == target).sum()
    accuracy = 100.0 * corrects / batch_size
    return accuracy.item()

In [None]:
# Optimiser / schedule settings.
num_epochs = 2
learning_rate = 0.001

# Use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fresh, untrained model instance.
model = MyModel()

In [None]:
import types


def get_user_vars():
    """Return the set of user-defined global names in this namespace.

    Excluded:
      * names bound to modules (imports such as ``torch`` or ``nn``),
      * names starting with ``_`` (interpreter/IPython internals),
      * names that shadow Python builtins,
      * a fixed set of Jupyter/helper names (``In``, ``Out``, ``quit``, ...).
    """
    # Use the ``builtins`` module rather than ``__builtins__``: the latter is
    # the module only in ``__main__`` and a plain dict everywhere else, in
    # which case ``dir()`` would list dict methods instead of builtin names.
    import builtins as _builtins_module

    namespace = globals()

    jupyter_builtins = {
        "quit",
        "exit",
        "In",
        "Out",
        "imports",
        "get_user_vars",
        "jupyter_builtins",
        "builtins",
        "internals",
    }
    builtin_names = set(dir(_builtins_module))
    internals = {k for k in namespace if k.startswith("_")}
    module_names = {
        name for name, val in namespace.items() if isinstance(val, types.ModuleType)
    }
    return set(namespace) - builtin_names - internals - jupyter_builtins - module_names

In [None]:
from datetime import datetime


def send_training_job(job: callable, model):
    """Run ``job(model)`` on a training server and return the trained model.

    Behaviour depends on the global ``_remote_``:

    * If ``_remote_`` is set and truthy (presumably by the server when it
      re-executes the session), call ``job(model)`` directly. Save the result's
      ``state_dict`` to the path stored in ``_remote_`` and return the model.
    * Otherwise act as the client:
      1. Connect to ``localhost:6001`` and send a timestamped state filename.
      2. Dump the session with ``dill.dump_module`` to that file.
      3. Send the source of the latest notebook cell (``In[-1]``).
      4. Wait for the server to reply with a weights path, load those
         weights into ``model`` on the CPU and return it.

    Args:
        job: Training function that takes and returns the model.
        model: Model instance; on the client its weights are replaced.

    Returns:
        The trained model.
    """
    from multiprocessing.connection import Client

    import dill

    remote_output = globals().get("_remote_")
    if remote_output:
        model = job(model)
        torch.save(model.state_dict(), remote_output)
        return model

    address = ("localhost", 6001)
    # NOTE(review): hard-coded authkey; move it to config/env before using this
    # outside a trusted local machine.
    with Client(address, authkey=b"secret password") as conn:
        filename = f"state-{datetime.now().isoformat()}"
        conn.send(filename)
        # With no module argument, dill dumps __main__ (the notebook session).
        dill.dump_module(filename)
        # ``In`` is IPython's input history; presumably ``In[-1]`` is the source
        # of the cell currently executing — confirm.
        conn.send(globals()["In"][-1])
        weights_path = conn.recv()

    device = torch.device("cpu")
    # NOTE(review): torch.load unpickles the file; only load weights written by
    # a trusted server.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model

In [None]:
def remote(trainfn: callable):
    """Decorator that routes calls of ``trainfn(model)`` through ``send_training_job``.

    The returned wrapper copies ``trainfn``'s metadata (name, docstring, ...)
    and exposes the undecorated function as ``.original``.
    """
    import functools

    def decorated(model):
        return send_training_job(trainfn, model)

    decorated = functools.wraps(trainfn)(decorated)
    decorated.original = trainfn
    return decorated

In [None]:
@remote
def train(model):
    """Train ``model`` on ``trainloader`` with Adam and cross-entropy loss.

    Reads the globals ``device``, ``learning_rate``, ``num_epochs``,
    ``trainloader`` and ``get_accuracy``. After each epoch it prints the mean
    per-batch loss and accuracy. The model is returned in eval mode.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        train_running_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        model = model.train()

        ## training step
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)

            ## forward + backprop + loss
            logits = model(images)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()

            ## update model params
            optimizer.step()

            train_running_loss += loss.detach().item()
            # Score against the real batch size so a short final batch is
            # not under-counted.
            train_acc += get_accuracy(logits, labels, labels.size(0))
            num_batches += 1

        model.eval()
        # Average over the number of batches actually processed (not the
        # last loop index); max() guards against an empty loader.
        denom = max(num_batches, 1)
        print(
            "Epoch: %d | Loss: %.4f | Train Accuracy: %.2f"
            % (epoch, train_running_loss / denom, train_acc / denom)
        )
    return model

In [None]:
# Train the model; ``train`` is wrapped by ``@remote``, so this call goes
# through ``send_training_job``. Keep the returned (trained) model.
model = train(model)

In [None]:
## evaluate on the test split
model.eval()
# Feed inputs to whatever device the model's weights live on (CPU after a
# remote job, ``device`` after local training).
eval_device = next(model.parameters()).device

# Sum of (batch accuracy % * batch size); divided by the sample count below so
# the smaller final batch is weighted correctly.
test_acc = 0.0
num_samples = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(eval_device)
        labels = labels.to(eval_device)
        outputs = model(images)
        batch_len = labels.size(0)
        test_acc += get_accuracy(outputs, labels, batch_len) * batch_len
        num_samples += batch_len

print("Test Accuracy: %.2f" % (test_acc / max(num_samples, 1)))