package main import ( "fmt" "github.com/sugarme/gotch" "github.com/sugarme/gotch/nn" ts "github.com/sugarme/gotch/tensor" ) type Net struct { conv1 *nn.Conv2D conv2 *nn.Conv2D fc *nn.Linear } func newNet(vs *nn.Path) *Net { conv1 := nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig()) conv2 := nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig()) fc := nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig()) return &Net{ conv1, conv2, fc, } } func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor { xs = xs.MustView([]int64{-1, 1, 8, 8}, false) outC1 := xs.Apply(n.conv1) outMP1 := outC1.MaxPool2DDefault(2, true) defer outMP1.MustDrop() outC2 := outMP1.Apply(n.conv2) outMP2 := outC2.MaxPool2DDefault(2, true) outView2 := outMP2.MustView([]int64{-1, 10}, true) defer outView2.MustDrop() outFC := outView2.Apply(n.fc) return outFC.MustRelu(true) } func main() { vs := nn.NewVarStore(gotch.CPU) net := newNet(vs.Root()) xs := ts.MustOnes([]int64{8, 8}, gotch.Float, gotch.CPU) logits := net.ForwardT(xs, false) fmt.Printf("Logits: %0.3f", logits) } //Logits: 0.000 0.000 0.000 0.225 0.321 0.147 0.000 0.207 0.000 0.000