<a href="https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/GoTch_MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST CNN Training Using GoTch - Pytorch C++ APIs Go Binding

This notebook uses:

1. [GoTch - Pytorch C++ APIs Go binding](https://github.com/sugarme/gotch)
2. [GopherNotes - Jupyter Notebook Go kernel](https://github.com/gopherdata/gophernotes)
3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)

## Install Go kernel - GopherNotes

*NOTE: refresh/reload (browser) after this step.*

In [None]:
# Run this cell once, on the default Python runtime, to install Go and the GopherNotes kernel.
# Add the longsleep/golang-backports PPA and install Go from it; command output is discarded.
!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null
!apt update > /dev/null 
!apt install golang-go > /dev/null
# Set GOPATH so the `~/go/...` paths used below match where `go get` places its results.
%env GOPATH=/root/go
# Fetch and build gophernotes, then copy the resulting binary to /usr/bin.
!go get -u github.com/gopherdata/gophernotes
!cp ~/go/bin/gophernotes /usr/bin/
# Register the kernel: copy the files from the gophernotes `kernel` directory into Jupyter's kernels directory.
!mkdir /usr/local/share/jupyter/kernels/gophernotes
!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \
       /usr/local/share/jupyter/kernels/gophernotes
# Then refresh the browser page; the notebook will now use gophernotes. The remaining cells are Go.





env: GOPATH=/root/go


## Install Pytorch C++ APIs and Go binding - GoTch

NOTE: the current `ldconfig` (GLIBC 2.27) is broken when linking the Libtorch libraries.

see issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6

Google Colab default settings:
```bash
LD_LIBRARY_PATH=/usr/lib64-nvidia
LIBRARY_PATH=/usr/local/cuda/lib64/stubs
```
As a workaround, we copy the contents of `libtorch/lib` directly into those paths.

In [1]:
// Download the Libtorch 1.7.0 (cu101, cxx11 ABI, shared with deps) archive to /tmp.
$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip
// Unpack the whole archive under /usr/local; its top-level directory is `libtorch`.
$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local
// Also extract only libtorch/lib/*, with archive directories stripped (-j), into the
// two library directories Colab already has configured, so ldconfig is not needed.
$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/
$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/



In [2]:
import("os")
// Expose the Libtorch directories through CPATH (the C compiler's extra header
// search path). Every entry must be absolute: the first entry previously lacked
// its leading "/" and was therefore resolved relative to the working directory.
os.Setenv("CPATH", "/usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include")

In [4]:
// Start from a clean state: remove any go.mod left over from a previous run.
$rm -f -- go.mod
// Create a scratch module so gotch can be added as a dependency.
$go mod init github.com/sugarme/playgo
// Fetch gotch, pinned to v0.3.2.
$go get github.com/sugarme/gotch@v0.3.2

go: creating new go.mod: module github.com/sugarme/playgo
go: downloading github.com/sugarme/gotch v0.3.2


## Download MNIST dataset

In [5]:
// Download the four MNIST archives (train/test images and labels) into the working directory.
$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

// Decompress in place; the training code later loads the dataset from "./".
$gunzip train-images-idx3-ubyte.gz
$gunzip train-labels-idx1-ubyte.gz
$gunzip t10k-images-idx3-ubyte.gz
$gunzip t10k-labels-idx1-ubyte.gz



## Create Convolutional Neural Network (CNN)

In [6]:
import(
    "fmt"
    "time"

    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
    "github.com/sugarme/gotch/vision"
) 

// MnistDirCNN is the directory the MNIST files are loaded from.
const MnistDirCNN string = "./"

// Training hyperparameters.
const (
    // epochsCNN is the number of passes over the training set.
    epochsCNN = 100
    // batchCNN is the mini-batch size.
    batchCNN = 256
    // batchSize is the name the training loop uses for the mini-batch size;
    // it is defined in terms of batchCNN so the two cannot drift apart.
    batchSize = batchCNN
    // LrCNN is the learning rate handed to the optimizer.
    LrCNN = 1e-4
)

// Net bundles the layers of the MNIST classifier: two convolutions followed by
// two linear layers. Field order matters because newNet fills the struct
// positionally-compatible (conv1, conv2, fc1, fc2).
type Net struct {
    conv1, conv2 *nn.Conv2D
    fc1, fc2     *nn.Linear
}

// newNet creates the four layers under the given var-store path and returns
// them wrapped in a Net.
//
// Layer sizes: conv 1->32 and 32->64 channels with kernel size 5, then linear
// 1024->1024 and 1024->10. Layers are created in forward order.
func newNet(vs *nn.Path) Net {
    firstConv := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
    secondConv := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
    hidden := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
    output := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())

    return Net{
        conv1: firstConv,
        conv2: secondConv,
        fc1:   hidden,
        fc2:   output,
    }
}

// ForwardT runs the network on xs and returns the output of the last linear
// layer (fc2).
//
// Pipeline: view as (-1, 1, 28, 28) -> conv1 -> max-pool(2) -> conv2 ->
// max-pool(2) -> view as (-1, 1024) -> fc1 -> ReLU -> dropout(0.5) -> fc2.
// train is passed through to the dropout step.
//
// Intermediate tensors are released either by the trailing `true` argument of
// the call that consumes them (presumably gotch's "drop the receiver" flag —
// confirm against the gotch version in use) or by a deferred MustDrop. The
// input xs is viewed with `false` and is not dropped here.
func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
    images := xs.MustView([]int64{-1, 1, 28, 28}, false)
    defer images.MustDrop()
    conv1Out := images.Apply(n.conv1)
    pooled1 := conv1Out.MaxPool2DDefault(2, true)
    defer pooled1.MustDrop()
    conv2Out := pooled1.Apply(n.conv2)
    pooled2 := conv2Out.MaxPool2DDefault(2, true)
    flat := pooled2.MustView([]int64{-1, 1024}, true)
    defer flat.MustDrop()
    hidden := flat.Apply(n.fc1)
    activated := hidden.MustRelu(true)
    defer activated.MustDrop()
    dropped := ts.MustDropout(activated, 0.5, train)
    defer dropped.MustDrop()
    return dropped.Apply(n.fc2)
}

// trainCNN trains the CNN on the MNIST data found in MnistDirCNN and prints,
// after every epoch, the loss of the last mini-batch and the test accuracy.
// When all epochs are done it prints the best test accuracy and the elapsed time.
//
// The model runs on CUDA device 0 when available, otherwise on the CPU.
// Each epoch shuffles the training set and iterates over full mini-batches of
// batchSize samples; a trailing partial batch is skipped.
//
// Tensors created here are released explicitly with MustDrop once they are no
// longer needed, so memory use does not grow from batch to batch.
func trainCNN() {
    ds := vision.LoadMNISTDir(MnistDirCNN)
    testImages := ds.TestImages
    testLabels := ds.TestLabels

    cuda := gotch.CudaBuilder(0)
    vs := nn.NewVarStore(cuda.CudaIfAvailable())

    cnn := newNet(vs.Root())
    opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
    if err != nil {
        // Without an optimizer no training step can run, so stop here
        // instead of continuing with an unusable value.
        fmt.Println(err)
        return
    }

    // The training-set size does not change between epochs; compute it once.
    totalSize := ds.TrainImages.MustSize()[0]
    samples := int(totalSize)
    batches := samples / batchSize
    if batches == 0 {
        // No full batch means no loss would ever be produced to report.
        fmt.Printf("not enough training samples (%d) for a batch of %d\n", samples, batchSize)
        return
    }

    var bestAccuracy float64 = 0.0
    startTime := time.Now()

    for epoch := 0; epoch < epochsCNN; epoch++ {
        // Shuffle: draw a random permutation and reorder images and labels with it.
        index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
        imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
        labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
        index.MustDrop() // the permutation is only needed for the two selects above

        // epocLoss holds a detached copy of the most recent batch loss.
        var epocLoss *ts.Tensor
        for i := 0; i < batches; i++ {
            // i < batches guarantees start+batchSize <= samples.
            start := i * batchSize

            // Indexing
            narrowIndex := ts.NewNarrow(int64(start), int64(start+batchSize))
            bImages := imagesTs.Idx(narrowIndex)
            bLabels := labelsTs.Idx(narrowIndex)

            bImages = bImages.MustTo(vs.Device(), true)
            bLabels = bLabels.MustTo(vs.Device(), true)

            logits := cnn.ForwardT(bImages, true)
            loss := logits.CrossEntropyForLogits(bLabels)

            opt.BackwardStep(loss)

            // Keep only the latest loss; release the previous batch's copy first.
            if epocLoss != nil {
                epocLoss.MustDrop()
            }
            epocLoss = loss.MustShallowClone()
            epocLoss.Detach_()

            // Release everything created for this batch.
            loss.MustDrop()
            logits.MustDrop()
            bImages.MustDrop()
            bLabels.MustDrop()
        }

        testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)
        fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
        if testAccuracy > bestAccuracy {
            bestAccuracy = testAccuracy
        }

        epocLoss.MustDrop()
        imagesTs.MustDrop()
        labelsTs.MustDrop()
    }

    fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
    fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
}

## Run training and evaluation

In [7]:
// Train the CNN for epochsCNN epochs; loss and test accuracy are printed after each epoch.
trainCNN()

Epoch: 0	 Loss: 0.24 	 Test accuracy: 93.34%
Epoch: 1	 Loss: 0.19 	 Test accuracy: 96.02%
Epoch: 2	 Loss: 0.09 	 Test accuracy: 97.10%
Epoch: 3	 Loss: 0.10 	 Test accuracy: 97.66%
Epoch: 4	 Loss: 0.03 	 Test accuracy: 98.13%
Epoch: 5	 Loss: 0.05 	 Test accuracy: 98.43%
Epoch: 6	 Loss: 0.09 	 Test accuracy: 98.60%
Epoch: 7	 Loss: 0.05 	 Test accuracy: 98.80%
Epoch: 8	 Loss: 0.03 	 Test accuracy: 98.80%
Epoch: 9	 Loss: 0.05 	 Test accuracy: 98.89%
Epoch: 10	 Loss: 0.02 	 Test accuracy: 98.88%
Epoch: 11	 Loss: 0.03 	 Test accuracy: 98.98%
Epoch: 12	 Loss: 0.03 	 Test accuracy: 99.05%
Epoch: 13	 Loss: 0.04 	 Test accuracy: 99.06%
Epoch: 14	 Loss: 0.02 	 Test accuracy: 99.07%
Epoch: 15	 Loss: 0.02 	 Test accuracy: 98.98%
Epoch: 16	 Loss: 0.02 	 Test accuracy: 99.06%
Epoch: 17	 Loss: 0.01 	 Test accuracy: 99.09%
Epoch: 18	 Loss: 0.02 	 Test accuracy: 99.14%
Epoch: 19	 Loss: 0.01 	 Test accuracy: 99.09%
Epoch: 20	 Loss: 0.02 	 Test accuracy: 99.12%
Epoch: 21	 Loss: 0.03 	 Test accuracy: 99.13