# MNIST Training

In [3]:
$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$wget  -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

$gunzip train-images-idx3-ubyte.gz
$gunzip train-labels-idx1-ubyte.gz
$gunzip t10k-images-idx3-ubyte.gz
$gunzip t10k-labels-idx1-ubyte.gz



In [5]:
import (
    "fmt"

    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
    "github.com/sugarme/gotch/vision"
)

## Fully connected network (one hidden layer)

In [6]:
// Sizes and data location for the fully connected MNIST model.
const (
	// Layer widths: 28x28 input pixels, one hidden layer, 10 digit classes.
	ImageDimNN    int64 = 28 * 28
	HiddenNodesNN int64 = 128
	LabelNN       int64 = 10

	// Directory holding the unpacked MNIST idx files.
	MnistDirNN string = "./"
)

// Optimisation settings: number of full-batch epochs and the Adam learning rate.
const (
	epochsNN = 200
	LrNN     = 0.001
)

In [7]:
// netInit builds the fully connected classifier on vs. It has a linear layer
// from ImageDimNN inputs to HiddenNodesNN units, then a ReLU, then a linear
// layer from HiddenNodesNN to LabelNN outputs.
func netInit(vs *nn.Path) ts.Module {
	// Create the parameterised layers in input-to-output order.
	hidden := nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
	relu := nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor { return xs.MustRelu(false) })
	output := nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig())

	seq := nn.Seq()
	seq.Add(hidden)
	seq.AddFn(relu)
	seq.Add(output)
	return seq
}

In [8]:
// train performs one optimizer step of m on the whole training set
// (trainX, trainY). It then evaluates m on (testX, testY) and prints the
// training loss and test accuracy for the given epoch.
//
// Every tensor created here is dropped before returning, so native (libtorch)
// memory is released each epoch instead of accumulating over the run.
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
	logits := m.Forward(trainX)
	loss := logits.CrossEntropyForLogits(trainY)

	opt.BackwardStep(loss)
	// CrossEntropyForLogits does not free its receiver, so release the
	// training-set logits ourselves once the step is done.
	logits.MustDrop()

	testLogits := m.Forward(testX)
	testAccuracy := testLogits.AccuracyForLogits(testY)
	// Likewise, AccuracyForLogits leaves testLogits alive.
	testLogits.MustDrop()
	accuracy := testAccuracy.Float64Values()[0] * 100
	testAccuracy.MustDrop()
	lossVal := loss.Float64Values()[0]
	loss.MustDrop()

	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
}


In [9]:
// Load MNIST from MnistDirNN, build the network on the CPU with an Adam
// optimizer, and run epochsNN full-batch epochs, printing loss and test
// accuracy after each one.
ds := vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
	// Training with a nil optimizer would panic inside train; report and skip.
	fmt.Printf("building Adam optimizer: %v\n", err)
} else {
	for epoch := 0; epoch < epochsNN; epoch++ {
		train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
	}
}

Epoch: 0 	 Loss: 2.308 	 Test accuracy: 21.71%
Epoch: 1 	 Loss: 2.240 	 Test accuracy: 33.96%
Epoch: 2 	 Loss: 2.173 	 Test accuracy: 46.14%
Epoch: 3 	 Loss: 2.104 	 Test accuracy: 55.92%
Epoch: 4 	 Loss: 2.032 	 Test accuracy: 59.80%
Epoch: 5 	 Loss: 1.956 	 Test accuracy: 61.91%
Epoch: 6 	 Loss: 1.879 	 Test accuracy: 64.24%
Epoch: 7 	 Loss: 1.803 	 Test accuracy: 66.79%
Epoch: 8 	 Loss: 1.727 	 Test accuracy: 69.53%
Epoch: 9 	 Loss: 1.653 	 Test accuracy: 72.17%
Epoch: 10 	 Loss: 1.580 	 Test accuracy: 74.54%
Epoch: 11 	 Loss: 1.509 	 Test accuracy: 76.42%
Epoch: 12 	 Loss: 1.440 	 Test accuracy: 77.76%
Epoch: 13 	 Loss: 1.373 	 Test accuracy: 78.83%
Epoch: 14 	 Loss: 1.308 	 Test accuracy: 79.84%
Epoch: 15 	 Loss: 1.246 	 Test accuracy: 80.65%
Epoch: 16 	 Loss: 1.186 	 Test accuracy: 81.24%
Epoch: 17 	 Loss: 1.129 	 Test accuracy: 81.88%
Epoch: 18 	 Loss: 1.075 	 Test accuracy: 82.53%
Epoch: 19 	 Loss: 1.024 	 Test accuracy: 82.90%
Epoch: 20 	 Loss: 0.976 	 Test accuracy: 83.46%
Ep

# Neural Network

In [1]:
import (
    "fmt"

    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
    "github.com/sugarme/gotch/vision"
)

// Sizes and data location for the fully connected MNIST model.
const (
	// Layer widths: 28x28 input pixels, one hidden layer, 10 digit classes.
	ImageDimNN    int64 = 28 * 28
	HiddenNodesNN int64 = 128
	LabelNN       int64 = 10

	// Directory holding the unpacked MNIST idx files.
	MnistDirNN string = "./"
)

// Optimisation settings: number of full-batch epochs and the Adam learning rate.
const (
	epochsNN = 200
	LrNN     = 0.001
)

// netInit builds the fully connected classifier on vs. It has a linear layer
// from ImageDimNN inputs to HiddenNodesNN units, then a ReLU, then a linear
// layer from HiddenNodesNN to LabelNN outputs.
func netInit(vs *nn.Path) ts.Module {
	// Create the parameterised layers in input-to-output order.
	hidden := nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
	relu := nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor { return xs.MustRelu(false) })
	output := nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig())

	seq := nn.Seq()
	seq.Add(hidden)
	seq.AddFn(relu)
	seq.Add(output)
	return seq
}

// trainNN performs one optimizer step of m on the whole training set
// (trainX, trainY). It then evaluates m on (testX, testY) and prints the
// training loss and test accuracy for the given epoch.
//
// Every tensor created here is dropped before returning, so native (libtorch)
// memory is released each epoch instead of accumulating over the run.
func trainNN(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
	logits := m.Forward(trainX)
	loss := logits.CrossEntropyForLogits(trainY)

	opt.BackwardStep(loss)
	// CrossEntropyForLogits does not free its receiver, so release the
	// training-set logits ourselves once the step is done.
	logits.MustDrop()

	testLogits := m.Forward(testX)
	testAccuracy := testLogits.AccuracyForLogits(testY)
	// Likewise, AccuracyForLogits leaves testLogits alive.
	testLogits.MustDrop()
	accuracy := testAccuracy.Float64Values()[0] * 100
	testAccuracy.MustDrop()
	lossVal := loss.Float64Values()[0]
	loss.MustDrop()

	fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
}

// runNN loads MNIST from MnistDirNN, builds the fully connected network on
// the CPU with an Adam optimizer (learning rate LrNN), and trains it for
// epochsNN full-batch epochs via trainNN. If the optimizer cannot be built,
// it prints the error and returns without training.
func runNN() {
	ds := vision.LoadMNISTDir(MnistDirNN)
	vs := nn.NewVarStore(gotch.CPU)
	net := netInit(vs.Root())
	opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
	if err != nil {
		// Continuing would hand a nil optimizer to trainNN and panic.
		fmt.Printf("building Adam optimizer: %v\n", err)
		return
	}

	for epoch := 0; epoch < epochsNN; epoch++ {
		trainNN(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
	}
}

// Train the fully connected network for epochsNN full-batch epochs (see runNN).
runNN()

Epoch: 0 	 Loss: 2.299 	 Test accuracy: 16.22%
Epoch: 1 	 Loss: 2.238 	 Test accuracy: 25.63%
Epoch: 2 	 Loss: 2.177 	 Test accuracy: 41.26%
Epoch: 3 	 Loss: 2.114 	 Test accuracy: 56.68%
Epoch: 4 	 Loss: 2.048 	 Test accuracy: 64.31%
Epoch: 5 	 Loss: 1.978 	 Test accuracy: 68.49%
Epoch: 6 	 Loss: 1.907 	 Test accuracy: 71.51%
Epoch: 7 	 Loss: 1.834 	 Test accuracy: 73.96%
Epoch: 8 	 Loss: 1.761 	 Test accuracy: 75.71%
Epoch: 9 	 Loss: 1.688 	 Test accuracy: 76.98%
Epoch: 10 	 Loss: 1.615 	 Test accuracy: 77.91%
Epoch: 11 	 Loss: 1.544 	 Test accuracy: 78.76%
Epoch: 12 	 Loss: 1.473 	 Test accuracy: 79.40%
Epoch: 13 	 Loss: 1.404 	 Test accuracy: 79.86%
Epoch: 14 	 Loss: 1.338 	 Test accuracy: 80.37%
Epoch: 15 	 Loss: 1.273 	 Test accuracy: 80.78%
Epoch: 16 	 Loss: 1.211 	 Test accuracy: 81.32%
Epoch: 17 	 Loss: 1.152 	 Test accuracy: 81.80%
Epoch: 18 	 Loss: 1.096 	 Test accuracy: 82.01%
Epoch: 19 	 Loss: 1.043 	 Test accuracy: 82.47%
Epoch: 20 	 Loss: 0.994 	 Test accuracy: 82.79%
Ep

# Convolutional Neural Network (CNN)

In [11]:
import(
    "fmt"
    "time"

    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
    "github.com/sugarme/gotch/vision"
) 

// Settings for the convolutional MNIST model.
const (
	// Directory holding the unpacked MNIST idx files.
	MnistDirCNN string = "./"
	// Number of passes over the training set.
	epochsCNN = 100
	// Mini-batch size. batchSize (the name trainCNN reads) is defined in
	// terms of batchCNN so the two names can never disagree.
	batchCNN  = 256
	batchSize = batchCNN
	// Adam learning rate.
	LrCNN = 1e-4
)

// Net is a small convolutional classifier: two 2-D convolution layers
// followed by two fully connected layers. newNet creates it and ForwardT
// runs it.
type Net struct {
	conv1, conv2 *nn.Conv2D
	fc1, fc2     *nn.Linear
}

// newNet creates the layers of Net under vs:
//   - conv1: 1 → 32 channels, kernel size 5
//   - conv2: 32 → 64 channels, kernel size 5
//   - fc1:   1024 → 1024
//   - fc2:   1024 → 10
//
// The calls inside the composite literal are evaluated left to right, so the
// layers are created in the order listed.
func newNet(vs *nn.Path) Net {
	return Net{
		conv1: nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig()),
		conv2: nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig()),
		fc1:   nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig()),
		fc2:   nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig()),
	}
}

// ForwardT runs the network on xs and returns the output of fc2 (10 values
// per sample, as configured in newNet). When train is true, dropout is
// applied before fc2.
//
// xs is viewed as (-1, 1, 28, 28), so each row is presumably a flattened
// 28x28 single-channel image. TODO: confirm against the dataset loader.
//
// Memory ownership: a trailing `true` argument to MaxPool2DDefault, MustView
// and MustRelu frees the receiver (gotch's del flag). That is why outC1,
// outC2, outMP2 and outFC1 are never dropped explicitly. The other
// intermediates are released by the deferred MustDrop calls when ForwardT
// returns. The returned tensor belongs to the caller.
func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
    outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
    defer outView1.MustDrop()
    outC1 := outView1.Apply(n.conv1)
    outMP1 := outC1.MaxPool2DDefault(2, true)
    defer outMP1.MustDrop()
    outC2 := outMP1.Apply(n.conv2)
    outMP2 := outC2.MaxPool2DDefault(2, true)
    // Flatten to 1024 features per sample. This matches 64 channels * 4 * 4
    // if the convolutions are stride 1 and unpadded (28→24→12→8→4).
    // NOTE(review): relies on DefaultConv2DConfig; confirm if it changes.
    outView2 := outMP2.MustView([]int64{-1, 1024}, true)
    defer outView2.MustDrop()
    outFC1 := outView2.Apply(n.fc1)
    outRelu := outFC1.MustRelu(true)
    defer outRelu.MustDrop()
    // Dropout with p = 0.5; the train flag controls whether it is active.
    outDropout := ts.MustDropout(outRelu, 0.5, train)
    defer outDropout.MustDrop()
    return outDropout.Apply(n.fc2)
}

func trainCNN(){
    var ds *vision.Dataset
    ds = vision.LoadMNISTDir(MnistDirCNN)
    testImages := ds.TestImages
    testLabels := ds.TestLabels

    cuda := gotch.CudaBuilder(0)
    vs := nn.NewVarStore(cuda.CudaIfAvailable())
    // vs := nn.NewVarStore(gotch.CPU)

    var cnn Net = newNet(vs.Root())
    opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
    if err != nil {fmt.Print(err)}

    var bestAccuracy float64 = 0.0
    startTime := time.Now()

    for epoch := 0; epoch < epochsCNN; epoch++ {
        totalSize := ds.TrainImages.MustSize()[0]
        samples := int(totalSize)
        index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
        imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
        labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)

        batches := samples / batchSize
        batchIndex := 0
        var epocLoss *ts.Tensor
        for i := 0; i < batches; i++ {
            start := batchIndex * batchSize
            size := batchSize
            if samples-start < batchSize {break}
            batchIndex += 1

            // Indexing
            narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
            bImages := imagesTs.Idx(narrowIndex)
            bLabels := labelsTs.Idx(narrowIndex)

            bImages = bImages.MustTo(vs.Device(), true)
            bLabels = bLabels.MustTo(vs.Device(), true)

            logits := cnn.ForwardT(bImages, true)
            loss := logits.CrossEntropyForLogits(bLabels)

            // loss = loss.MustSetRequiresGrad(true, false)
            opt.BackwardStep(loss)

            epocLoss = loss.MustShallowClone()
            epocLoss.Detach_()

            // fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])

            bImages.MustDrop()
            bLabels.MustDrop()
        }

        // vs.Freeze()
        testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)
        // vs.Unfreeze()
        fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
        if testAccuracy > bestAccuracy {
           bestAccuracy = testAccuracy
        }

        epocLoss.MustDrop()
        imagesTs.MustDrop()
        labelsTs.MustDrop()
    }

    fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
    fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
    
}

In [12]:
// Train the CNN for epochsCNN epochs, printing per-epoch loss and test accuracy.
trainCNN()

Epoch: 0	 Loss: 0.24 	 Test accuracy: 92.96%
Epoch: 1	 Loss: 0.15 	 Test accuracy: 95.95%
Epoch: 2	 Loss: 0.13 	 Test accuracy: 97.41%
Epoch: 3	 Loss: 0.10 	 Test accuracy: 97.73%
Epoch: 4	 Loss: 0.10 	 Test accuracy: 98.25%
Epoch: 5	 Loss: 0.10 	 Test accuracy: 98.40%
Epoch: 6	 Loss: 0.12 	 Test accuracy: 98.54%
Epoch: 7	 Loss: 0.07 	 Test accuracy: 98.67%
Epoch: 8	 Loss: 0.06 	 Test accuracy: 98.77%
Epoch: 9	 Loss: 0.04 	 Test accuracy: 98.82%
Epoch: 10	 Loss: 0.06 	 Test accuracy: 98.90%
Epoch: 11	 Loss: 0.03 	 Test accuracy: 98.96%
Epoch: 12	 Loss: 0.02 	 Test accuracy: 99.00%
Epoch: 13	 Loss: 0.06 	 Test accuracy: 98.99%
Epoch: 14	 Loss: 0.03 	 Test accuracy: 99.07%
Epoch: 15	 Loss: 0.02 	 Test accuracy: 99.12%
Epoch: 16	 Loss: 0.01 	 Test accuracy: 99.16%
Epoch: 17	 Loss: 0.04 	 Test accuracy: 99.13%
Epoch: 18	 Loss: 0.05 	 Test accuracy: 99.13%
Epoch: 19	 Loss: 0.04 	 Test accuracy: 99.15%
Epoch: 20	 Loss: 0.02 	 Test accuracy: 99.19%
Epoch: 21	 Loss: 0.01 	 Test accuracy: 99.02