dotnet#867 Add IntermediateLayerGetter
xhuan8 committed Dec 16, 2022
1 parent 80a2d3d commit 40517b9
Showing 2 changed files with 84 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/TorchSharp/NN/ModuleDict.cs
@@ -40,7 +40,7 @@ public void clear()
/// Return an enumeration of the ModuleDict key/value pairs.
/// </summary>
/// <returns></returns>
-public IEnumerator<(string, T)> items() => _list.GetEnumerator();
+public IEnumerable<(string, T)> items() => _list;

/// <summary>
/// Return the ModuleDict keys.
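Returning IEnumerable<(string, T)> instead of IEnumerator<(string, T)> lets callers consume items() directly with foreach or LINQ, rather than driving the enumerator by hand. A minimal sketch of the difference, assuming an already-populated ModuleDict<Module<Tensor, Tensor>> named dict (the variable name is hypothetical):

// Hedged sketch: `dict` is a hypothetical, already-populated
// ModuleDict<Module<Tensor, Tensor>>; requires using System.Linq.
foreach (var (name, module) in dict.items())   // works because items() is IEnumerable
    Console.WriteLine(name);
var names = dict.items().Select(p => p.Item1).ToList();   // LINQ composes directly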
83 changes: 83 additions & 0 deletions src/TorchVision/models/Utils.cs
@@ -26,6 +26,11 @@
// ==============================================================================

using System;
using System.Collections.Generic;
using TorchSharp.Modules;
using TorchSharp.Utils;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace TorchSharp
{
@@ -35,6 +40,84 @@ public static partial class models
{
internal static partial class _utils
{
/// <summary>
/// Module wrapper that returns intermediate layers from a model.
/// It has a strong assumption that the modules have been registered
/// into the model in the same order as they are used.
/// This means that one should **not** reuse the same nn.Module
/// twice in the forward pass if you want this to work.
/// Additionally, it is only able to query submodules that are directly
/// assigned to the model. So if `model` is passed, `model.feature1` can
/// be returned, but not `model.feature1.layer2`.
///
/// >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
/// >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
/// >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
/// >>> {'layer1': 'feat1', 'layer3': 'feat2'})
/// >>> out = new_m(torch.rand(1, 3, 224, 224))
/// >>> print([(k, v.shape) for k, v in out.items()])
/// >>> [('feat1', torch.Size([1, 64, 56, 56])),
/// >>> ('feat2', torch.Size([1, 256, 14, 14]))]
/// </summary>
internal class IntermediateLayerGetter : ModuleDict<Module<Tensor, Tensor>>
{
private Dictionary<string, string> return_layers;

/// <summary>
/// Constructor.
/// </summary>
/// <param name="model">model on which we will extract the features</param>
/// <param name="return_layers">
/// a dict containing the names
/// of the modules for which the activations will be returned as
/// the key of the dict, and the value of the dict is the name
/// of the returned activation (which the user can specify).
/// </param>
public IntermediateLayerGetter(nn.Module model, Dictionary<string, string> return_layers)
{
// Validate that every requested layer is a direct child of the model.
foreach (var key in return_layers.Keys) {
bool exists = false;
foreach (var (name, _) in model.named_children()) {
if (name == key) {
exists = true;
break;
}
}
if (!exists) {
throw new ArgumentException("return_layers are not present in model");
}
}
// Keep the original mapping; remove entries from a separate copy as
// children are registered, so the caller's dictionary is not mutated.
var orig_return_layers = new Dictionary<string, string>(return_layers);
var remaining = new Dictionary<string, string>(return_layers);
// Register the model's children in order, stopping once the last
// requested layer has been added.
var layers = new OrderedDict<string, nn.Module<Tensor, Tensor>>();
foreach (var (name, module) in model.named_children()) {
layers[name] = module as nn.Module<Tensor, Tensor>;
remaining.Remove(name);
if (remaining.Count == 0)
break;
}
foreach (var pair in layers)
base.Add(pair);
this.return_layers = orig_return_layers;
}

public OrderedDict<string, Tensor> forward(Tensor x)
{
// Thread the input through each registered child in order, capturing
// the output of every layer listed in return_layers.
var @out = new OrderedDict<string, Tensor>();
foreach (var (name, module) in this.items()) {
x = module.forward(x);
if (this.return_layers.TryGetValue(name, out var out_name)) {
@out[out_name] = x;
}
}
return @out;
}
}
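// Usage sketch (not part of this commit): mirrors the Python example in
// the summary above. The model constructor and layer names are
// assumptions for illustration, not guaranteed by this commit.
//
//   var m = torchvision.models.resnet18();
//   var new_m = new _utils.IntermediateLayerGetter(m,
//       new Dictionary<string, string> { ["layer1"] = "feat1", ["layer3"] = "feat2" });
//   var @out = new_m.forward(torch.rand(1, 3, 224, 224));
//   foreach (var (k, v) in @out)
//       Console.WriteLine($"{k}: {string.Join("x", v.shape)}");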

/// <summary>
/// This function is taken from the original tf repo.
/// It ensures that all layers have a channel number that is divisible by 8
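The helper's body is not shown in this diff view. For reference, a minimal sketch of the rounding rule its comment describes, assuming the well-known form of the TensorFlow/MobileNet `_make_divisible` helper (the name, signature, and defaults are assumptions, not this commit's code):

// Hedged sketch of the rounding rule described above; not the committed body.
static long make_divisible(double v, long divisor = 8, long? min_value = null)
{
    var min = min_value ?? divisor;
    // Round v to the nearest multiple of `divisor`, never going below `min`.
    var new_v = Math.Max(min, (long)(v + divisor / 2.0) / divisor * divisor);
    // Rounding down must not remove more than 10% of the original channels.
    if (new_v < 0.9 * v)
        new_v += divisor;
    return new_v;
}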
