dotnet#867 add RoiAlign class
xhuan8 committed Dec 13, 2022
1 parent a9299c1 commit 63d0447
Showing 6 changed files with 345 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/Native/LibTorchSharp/THSVision.cpp
@@ -368,4 +368,14 @@ Tensor THSVision_nms(const Tensor dets, const Tensor scores, double iou_threshold)
return NULL;

CATCH_TENSOR(nms(*dets, *scores, iou_threshold));
}

Tensor THSVision_roi_align(const Tensor input, const Tensor rois, double spatial_scale, long pooled_height, long pooled_width, long sampling_ratio, bool aligned)
{
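// Resolve torchvision's native roi_align operator at runtime from libtorchvision.dll via its
// MSVC-decorated export name, following the same dynamic-loading pattern as THSVision_nms above.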
typedef at::Tensor(*TorchVisionFunc)(at::Tensor&, at::Tensor&, double, long, long, long, bool);
auto roi_align = (TorchVisionFunc)LoadNativeSymbol("libtorchvision.dll", "?roi_align@ops@vision@@YA?AVTensor@at@@AEBV34@0N_J11_N@Z");
if (roi_align == NULL)
return NULL;

CATCH_TENSOR(roi_align(*input, *rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned));
}
3 changes: 2 additions & 1 deletion src/Native/LibTorchSharp/THSVision.h
@@ -25,4 +25,5 @@ EXPORT_API(void) THSVision_BRGA_RGBA(const uint8_t* inputBytes, uint8_t* redByte

EXPORT_API(void) THSVision_RGB_BRGA(const uint8_t* inputBytes, uint8_t* outBytes, int64_t inputChannelCount, int64_t imageSize);

EXPORT_API(Tensor) THSVision_nms(const Tensor dets, const Tensor scores, double iou_threshold);
EXPORT_API(Tensor) THSVision_roi_align(const Tensor input, const Tensor rois, double spatial_scale, long pooled_height, long pooled_width, long sampling_ratio, bool aligned);
65 changes: 65 additions & 0 deletions src/TorchSharp/NN/Utils/ModulesUtils.cs
@@ -0,0 +1,65 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.

// A number of implementation details in this file have been translated from the Python version of PyTorch,
// largely located in the following file:
//
// https://github.com/pytorch/pytorch/blob/3a02873183e81ed0af76ab46b01c3829b8dc1d35/torch/nn/modules/utils.py
//
// The origin has the following copyright notice and license:
//
// https://github.com/pytorch/vision/blob/main/LICENSE
//

using System;
using System.Collections.Generic;
using System.Text;

namespace TorchSharp
{
using System.Linq;
using Modules;

namespace Modules
{
public static class ModulesUtils
{
public class _ntuple<T>
{
private int n;
private string name;

public _ntuple(int n, string name)
{
this.n = n;
this.name = name;
}

public IEnumerable<T> parse(object x)
{
if (x is IEnumerable<T> list)
return list;
return Enumerable.Repeat((T)x, n);
}
}

public static IEnumerable<T> _single<T>(object x)
{
return new _ntuple<T>(1, "_single").parse(x);
}
public static IEnumerable<T> _pair<T>(object x)
{
return new _ntuple<T>(2, "_pair").parse(x);
}
public static IEnumerable<T> _triple<T>(object x)
{
return new _ntuple<T>(3, "_triple").parse(x);
}
public static IEnumerable<T> _quadruple<T>(object x)
{
return new _ntuple<T>(4, "_quadruple").parse(x);
}
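// Behaviour sketch for the helpers above (illustrative, not part of this file):
//
//   _pair<int>(3)               -> { 3, 3 }    a scalar is repeated n times
//   _pair<int>(new[] { 3, 5 })  -> { 3, 5 }    an IEnumerable<T> is passed through unchanged
//   _single<long>(2L)           -> { 2 }
//
// Note that the scalar path unboxes with (T)x, so the boxed value must already have type T;
// _pair<int>(7L), for example, throws an InvalidCastException.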
}
}
}
3 changes: 3 additions & 0 deletions src/TorchVision/LibTorchSharp.cs
@@ -44,5 +44,8 @@ internal static class LibTorchSharp

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSVision_nms(IntPtr dets, IntPtr scores, double iou_threshold);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSVision_roi_align(IntPtr input, IntPtr rois, double spatial_scale, long pooled_height, long pooled_width, long sampling_ratio, bool aligned);
}
}
115 changes: 115 additions & 0 deletions src/TorchVision/Ops/RoiAlign.cs
@@ -0,0 +1,115 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.

// A number of implementation details in this file have been translated from the Python version of torchvision,
// largely located in the following file:
//
// https://github.com/pytorch/vision/blob/f56e6f63aa1d37e648b0c4cb951ce26292238c53/torchvision/ops/roi_align.py
//
// The origin has the following copyright notice and license:
//
// https://github.com/pytorch/vision/blob/main/LICENSE
//

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

using static TorchSharp.torch;

namespace TorchSharp
{
public static partial class torchvision
{
public static partial class ops
{
/// <summary>
/// Performs the Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.
/// </summary>
/// <param name="input">(Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
/// contains ``C`` feature maps of dimensions ``H x W``.
/// If the tensor is quantized, we expect a batch size of ``N == 1``.</param>
/// <param name="boxes">
/// (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
/// format where the regions will be taken from.
/// The coordinate must satisfy ``0 &lt;= x1 &lt; x2`` and ``0 &lt;= y1 &lt; y2``.
/// If a single Tensor is passed, then the first column should
/// contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
/// If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
/// in the batch.
/// </param>
/// <param name="output_size">(int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
/// is performed, as (height, width).</param>
/// <param name="spatial_scale">a scaling factor that maps the box coordinates to
/// the input coordinates. For example, if your boxes are defined on the scale
/// of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
/// the original image), you'll want to set this to 0.5. Default: 1.0</param>
/// <param name="sampling_ratio">
/// number of sampling points in the interpolation grid
/// used to compute the output value of each pooled output bin. If > 0,
/// then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
/// &lt;= 0, then an adaptive number of grid points are used (computed as
/// ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
/// </param>
/// <param name="aligned">
/// If False, use the legacy implementation.
/// If True, shift the box coordinates by -0.5 for better alignment with the two
/// neighboring pixel indices. This version is used in Detectron2.
/// </param>
/// <returns>Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.</returns>
public static Tensor roi_align(
Tensor input,
object boxes,
object output_size,
float spatial_scale = 1.0f,
int sampling_ratio = -1,
bool aligned = false
)
{
check_roi_boxes_shape(boxes);
object rois = boxes;
var output_size_list = Modules.ModulesUtils._pair<int>(output_size).ToArray();
if (rois is List<Tensor> list)
rois = convert_boxes_to_roi_format(list);
var roisTensor = rois as Tensor;
var res = LibTorchSharp.THSVision_roi_align(
input.Handle, roisTensor.Handle, spatial_scale, output_size_list[0], output_size_list[1], sampling_ratio, aligned);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
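// Illustrative usage sketch (not part of this change); tensor names, shapes, and values are hypothetical:
//
//   // a batch with one 256-channel 50x50 feature map
//   var features = torch.rand(1, 256, 50, 50);
//   // two boxes in (batch_index, x1, y1, x2, y2) format, both belonging to batch element 0
//   var boxes = torch.tensor(new float[] { 0, 10, 10, 20, 20, 0, 5, 5, 25, 25 }).reshape(2, 5);
//   // pool each box to a 7x7 grid; the result has shape [2, 256, 7, 7]
//   var pooled = roi_align(features, boxes, new[] { 7, 7 }, spatial_scale: 1.0f, sampling_ratio: 2, aligned: true);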

/// <summary>
/// see roi_align.
/// </summary>
public class RoIAlign : nn.Module
{
private object output_size;
private float spatial_scale;
private int sampling_ratio;
private bool aligned;

public RoIAlign(object output_size, float spatial_scale, int sampling_ratio, bool aligned = false)
: base(string.Empty)
{
this.output_size = output_size;
this.spatial_scale = spatial_scale;
this.sampling_ratio = sampling_ratio;
this.aligned = aligned;
}

public Tensor forward(Tensor input, object rois)
{
return torchvision.ops.roi_align(input, rois, this.output_size, this.spatial_scale, this.sampling_ratio, this.aligned);
}

public override string ToString()
{
return string.Format("{0}, output_size={1}, spatial_scale={2}, sampling_ratio={3}, aligned={4}",
this.GetType().Name, this.output_size, this.spatial_scale, this.sampling_ratio, this.aligned);
}
}
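// Illustrative usage sketch (not part of this change), reusing the hypothetical features/boxes
// tensors from the roi_align example above:
//
//   var layer = new RoIAlign(new[] { 7, 7 }, spatial_scale: 1.0f, sampling_ratio: 2, aligned: true);
//   var pooled = layer.forward(features, boxes);   // equivalent to calling roi_align directly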
}
}
}
150 changes: 150 additions & 0 deletions src/TorchVision/Ops/Utils.cs
@@ -1,8 +1,26 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.

// A number of implementation details in this file have been translated from the Python version of torchvision,
// largely located in the following file:
//
// https://github.com/pytorch/vision/blob/f56e6f63aa1d37e648b0c4cb951ce26292238c53/torchvision/ops/_utils.py
//
// The origin has the following copyright notice and license:
//
// https://github.com/pytorch/vision/blob/main/LICENSE
//

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torchvision;
using TorchSharp.Utils;

#nullable enable
namespace TorchSharp
@@ -11,6 +29,103 @@ public static partial class torchvision
{
public static partial class ops
{
/// <summary>
/// Efficient version of torch.cat that avoids a copy if there is only a single element in a list
/// </summary>
/// <param name="tensors"></param>
/// <param name="dim"></param>
/// <returns></returns>
public static Tensor _cat(List<Tensor> tensors, long dim = 0)
{
if (tensors.Count == 1)
return tensors[0];
return torch.cat(tensors, dim);
}

/// <summary>
/// Converts list of Tensor to roi format.
/// </summary>
/// <param name="boxes"></param>
/// <returns></returns>
public static Tensor convert_boxes_to_roi_format(List<Tensor> boxes)
{
var concat_boxes = _cat(boxes, dim: 0);
var temp = new List<Tensor>();
for (int i = 0; i < boxes.Count; i++)
temp.Add(torch.full_like(boxes[i][TensorIndex.Colon, TensorIndex.Slice(stop: 1)], i));

var ids = _cat(temp, dim: 0);

var rois = torch.cat(new List<Tensor> { ids, concat_boxes }, dim: 1);
return rois;
}
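// Illustrative sketch (not part of this change): per-image box tensors of shapes [2, 4] and [1, 4]
// are concatenated into a single [3, 5] rois tensor whose first column holds the batch index (0, 0, 1):
//
//   var perImage = new List<Tensor> {
//       torch.rand(2, 4),   // boxes for batch element 0
//       torch.rand(1, 4),   // boxes for batch element 1
//   };
//   var rois = convert_boxes_to_roi_format(perImage);   // shape: [3, 5]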

/// <summary>
/// Checks if format of boxes is correct.
/// </summary>
/// <param name="boxes"></param>
public static void check_roi_boxes_shape(object boxes)
{
if (boxes is List<Tensor> boxesList)
foreach (var _tensor in boxesList)
Debug.Assert(
_tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
);
else if (boxes is Tensor tensor)
Debug.Assert(tensor.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]");
else
Debug.Assert(false, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]");
}

/// <summary>
/// Splits normalization parameters.
/// </summary>
/// <param name="model"></param>
/// <param name="norm_classes"></param>
/// <returns>Tuple of normalization parameters and other parameters.</returns>
/// <exception cref="ArgumentException"></exception>
public static (List<Tensor>, List<Tensor>) split_normalization_params(
nn.Module model, List<Type>? norm_classes = null
)
{
if (norm_classes == null)
norm_classes = new List<Type> {
//nn.modules.batchnorm._BatchNorm,
typeof(LayerNorm),
typeof(GroupNorm),
//nn.modules.instancenorm._InstanceNorm,
typeof(LocalResponseNorm),
};

foreach (var t in norm_classes)
if (!(t.IsSubclassOf(typeof(nn.Module))))
throw new ArgumentException(string.Format("Class {0} is not a subclass of nn.Module.", t));

var classes = norm_classes;

var norm_params = new List<Tensor>();
var other_params = new List<Tensor>();
foreach (var named_module in model.named_modules()) {
var module = named_module.module;
var named_children = module.named_children().GetEnumerator();
if (named_children.MoveNext()) {
foreach (var p in module.named_parameters(false))
if (p.parameter.requires_grad)
other_params.Add(p.parameter);

} else if (classes.Contains(module.GetType())) {
foreach (var p in module.named_parameters(false))
if (p.parameter.requires_grad)
norm_params.Add(p.parameter);
} else {
foreach (var p in module.named_parameters(false))
if (p.parameter.requires_grad)
other_params.Add(p.parameter);
}
}
return (norm_params, other_params);
}
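// Illustrative usage sketch (not part of this change); `model` stands for any nn.Module:
//
//   var (normParams, otherParams) = split_normalization_params(model);
//   // a typical use is to give normParams zero weight decay and otherParams the regular weight decay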

/// <summary>
/// Protects from numerical overflows in multiplications by upcasting to the equivalent higher type.
/// </summary>
Expand All @@ -21,6 +136,41 @@ public static Tensor _upcast(Tensor t)
else
return t.dtype == torch.int32 || t.dtype == torch.int64 ? t : t.@int();
}

/// <summary>
/// Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
/// </summary>
/// <param name="t"></param>
/// <returns></returns>
public static Tensor _upcast_non_float(Tensor t)
{
if (t.dtype != torch.float32 && t.dtype != torch.float64)
return t.@float();
return t;
}


public static (Tensor, Tensor) _loss_inter_union(
Tensor boxes1,
Tensor boxes2
)
{
var (x1, y1, x2, y2, _) = boxes1.unbind(dimension: -1);
var (x1g, y1g, x2g, y2g, __) = boxes2.unbind(dimension: -1);

// Intersection keypoints
var xkis1 = torch.max(x1, x1g);
var ykis1 = torch.max(y1, y1g);
var xkis2 = torch.min(x2, x2g);
var ykis2 = torch.min(y2, y2g);

var intsctk = torch.zeros_like(x1);
var mask = (ykis2 > ykis1) & (xkis2 > xkis1);
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]);
var unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk;

return (intsctk, unionk);
}
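// Illustrative sketch (not part of this change): IoU-style losses typically combine the two returned
// tensors elementwise, e.g.
//
//   var (inter, union) = _loss_inter_union(boxes1, boxes2);
//   var iou = inter / (union + 1e-7);   // per-pair IoU, with a small epsilon guarding against division by zero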
}
}
}
