Skip to content

Commit

Permalink
Merge pull request #114 from sumitdvlp/AvgPool
Browse files Browse the repository at this point in the history
avg pool 2d
  • Loading branch information
interesaaat committed Aug 2, 2019
2 parents 52662d8 + 5a909d7 commit 7154f65
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/Native/LibTorchSharp/THSNN.cpp
Expand Up @@ -148,6 +148,18 @@ Tensor THSNN_adaptiveAvgPool2DApply(const Tensor tensor, const int length, const
return new torch::Tensor(torch::adaptive_avg_pool2d(*tensor, at::IntList(outputSize, length)));
}

Tensor THSNN_avgPool2DApply(const Tensor tensor,
const int kernelSizeLength,
const int64_t* kernelSize,
const int strideLength,
const int64_t* stride)
{
return new torch::Tensor(torch::avg_pool2d(
*tensor,
at::IntList(kernelSize, kernelSizeLength),
at::IntList(stride, strideLength)));
}

Tensor THSNN_logSoftMaxApply(const Tensor tensor, const int64_t dimension)
{
return new torch::Tensor(torch::log_softmax(*tensor, dimension));
Expand Down
8 changes: 8 additions & 0 deletions src/Native/LibTorchSharp/THSNN.h
Expand Up @@ -71,6 +71,14 @@ EXPORT_API(Tensor) THSNN_maxPool2DApply(
// Applies a 2D adaptive average pooling over an input signal composed of several input planes.
EXPORT_API(Tensor) THSNN_adaptiveAvgPool2DApply(const Tensor tensor, const int length, const int64_t* outputSize);

// Applies a avgpool 2d on the input tensor.
EXPORT_API(Tensor) THSNN_avgPool2DApply(
const Tensor tensor,
const int kernelSizeLength,
const int64_t* kernelSize,
const int strideLength,
const int64_t* stride);

// Applies a log soft max on the input tensor.
EXPORT_API(Tensor) THSNN_logSoftMaxApply(const Tensor tensor, const int64_t dimension);

Expand Down
26 changes: 26 additions & 0 deletions src/TorchSharp/NN/AvgPool2D.cs
@@ -0,0 +1,26 @@
using System;
using System.Runtime.InteropServices;
using TorchSharp.Tensor;

namespace TorchSharp.NN
{
public class AvgPool2D : FunctionalModule<AdaptiveAvgPool2D>
{
private readonly long[] _kernelSize;
private readonly long[] _stride;

internal AvgPool2D(long[] kernelSize, long[] stride) : base()
{
_kernelSize = kernelSize;
_stride = stride ?? new long[0];
}

[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_avgPool2DApply(IntPtr tensor, int kernelSizeLength, long[] kernelSize, int strideLength, long[] stride);

public override TorchTensor Forward(TorchTensor tensor)
{
return new TorchTensor(THSNN_avgPool2DApply(tensor.Handle, _kernelSize.Length, _kernelSize, _stride.Length, _stride));
}
}
}
8 changes: 8 additions & 0 deletions src/TorchSharp/NN/Module.cs
Expand Up @@ -155,6 +155,14 @@ static public AdaptiveAvgPool2D AdaptiveAvgPool2D(params long[] outputSize)
return new AdaptiveAvgPool2D(outputSize);
}

static public TorchTensor AvgPool2D(TorchTensor x, long[] kernelSize, long[] stride = null)
{
using (var m = new AvgPool2D(kernelSize, stride))
{
return m.Forward(x);
}
}

static public TorchTensor AdaptiveAvgPool2D(TorchTensor x, params long[] outputSize)
{
using (var a = new AdaptiveAvgPool2D(outputSize))
Expand Down
19 changes: 18 additions & 1 deletion test/TorchSharpTest/TorchSharp/TorchSharp.cs
Expand Up @@ -3,6 +3,7 @@
using System.Linq;
using System.Runtime.InteropServices;
using TorchSharp.JIT;
using TorchSharp.NN;
using TorchSharp.Tensor;
using Xunit;

Expand Down Expand Up @@ -1226,5 +1227,21 @@ public void TestMNISTLoaderWithEpochs()
Assert.Equal(size * epochs, i * 32);
}
}

[Fact]
public void AvgPool2DObjectInitialized()
{
TorchTensor ones = FloatTensor.Ones(new long[] { 2, 2, 2 });
var obj = NN.Module.AvgPool2D(ones, new long[] { 2 }, new long[] { 2 });
Assert.Equal(typeof(TorchTensor), obj.GetType());
}

[Fact]
public void MaxPool2DObjectInitialized()
{
TorchTensor ones = FloatTensor.Ones(new long[] { 2, 2, 2 });
var obj = NN.Module.MaxPool2D(ones, new long[] { 2 }, new long[] { 2 });
Assert.Equal(typeof(TorchTensor), obj.GetType());
}
}
}
}
2 changes: 1 addition & 1 deletion test/TorchSharpTest/TorchSharpTest.csproj
Expand Up @@ -47,7 +47,7 @@
</ItemGroup>

<ItemGroup Condition="'$(NativeTargetArchitecture)' == 'x64' and $([MSBuild]::IsOSPlatform('linux'))">
<NativeAssemblyReference Include="torch" ExtraExtension=".1"/>
<NativeAssemblyReference Include="torch" ExtraExtension=".1" />
<NativeAssemblyReference Include="gomp-8bba0e50" ExtraExtension=".1" />
</ItemGroup>

Expand Down

0 comments on commit 7154f65

Please sign in to comment.