In [None]:
#r "nuget: Yueyinqiu.Su.D2lTorchSharp, *-*"

using ScottPlot;
using System.Diagnostics;
using TorchSharp;
using TorchSharp.Modules;
using Yueyinqiu.Su.D2lTorchSharp;

D2l.Notebook.PrepareAll();

3.5.1. Reading the Dataset

TorchSharp returns the dataset directly as tensors, so no `ToTensor` transform is needed:

In [None]:
var mnist_train = torchvision.datasets.FashionMNIST(
    root: "../gitignored/data", train: true, download: true);
var mnist_test = torchvision.datasets.FashionMNIST(
    root: "../gitignored/data", train: false, download: true);
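For comparison, PyTorch's `ToTensor` converts an H×W×C image with byte values in [0, 255] into a C×H×W float tensor scaled to [0, 1]. The plain-C# sketch below mirrors that conversion for one grayscale image. `ToTensorLike` is a hypothetical helper used only for illustration. Whether TorchSharp's loader also rescales pixel values is not stated here, so check the `dtype` and value range of a sample yourself.

```csharp
using System;

// Hypothetical helper mirroring what PyTorch's ToTensor does for a grayscale image:
// H x W bytes in [0, 255]  ->  1 x H x W floats in [0, 1] (channel dimension first).
float[,,] ToTensorLike(byte[,] pixels)
{
    int h = pixels.GetLength(0), w = pixels.GetLength(1);
    var result = new float[1, h, w];
    for (int i = 0; i < h; i++)
        for (int j = 0; j < w; j++)
            result[0, i, j] = pixels[i, j] / 255f;  // scale each byte into [0, 1]
    return result;
}

var image = new byte[28, 28];
image[0, 0] = 255;
var t = ToTensorLike(image);
Console.WriteLine($"[{t.GetLength(0)}, {t.GetLength(1)}, {t.GetLength(2)}] {t[0, 0, 0]}");
// [1, 28, 28] 1
```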

In [None]:
(mnist_train.Count, mnist_test.Count)

In [None]:
mnist_train.GetTensor(0).Values.First().shape

Because we plot the images differently later on, this function takes a single `label` instead of a list of labels:

In [None]:
string GetFashionMnistLabel(torch.Tensor label)
{
    string[] text_labels = ["t-shirt", "trouser", "pullover", "dress", "coat", 
        "sandal", "shirt", "sneaker", "bag", "ankle boot"];
    return text_labels[label.ToInt32()];
}
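If you need the list-based behavior of the original `get_fashion_mnist_labels`, a batch version is a single `Select` over the class indices. The sketch below uses plain `int` indices instead of a tensor, and `GetFashionMnistLabels` is a hypothetical helper, not part of the d2l package.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical batch counterpart of GetFashionMnistLabel:
// maps each numeric class index (0-9) to its Fashion-MNIST text label.
string[] GetFashionMnistLabels(IEnumerable<int> labels)
{
    string[] textLabels = { "t-shirt", "trouser", "pullover", "dress", "coat",
        "sandal", "shirt", "sneaker", "bag", "ankle boot" };
    return labels.Select(i => textLabels[i]).ToArray();
}

Console.WriteLine(string.Join(", ", GetFashionMnistLabels(new[] { 0, 9 })));
// t-shirt, ankle boot
```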

We also skip `show_images`. Instead, we use the `ToSKBitmap` extension method from the d2l package to get the effect of `subplot`.

In [None]:
var loader = torch.utils.data.DataLoader(mnist_train, batchSize: 18, disposeBatch: false);
var Xy = loader.First();

var X = Xy["data"].reshape(18, 28, 28);
var y = Xy["label"];

IEnumerable<Plot> CreatePlots()
{
    for (int i = 0; i < X.size(0); i++)
    {
        var img = X[i, .., ..];
        
        var plot = new Plot();
        plot.Add.Heatmap(img);
        plot.Title(GetFashionMnistLabel(y[i]));

        foreach (var axis in plot.Axes.GetAxes())
            axis.IsVisible = false;
        plot.HideGrid();

        yield return plot;
    }
}
CreatePlots().ToSKBitmap(9, 100, 140)
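The exact meaning of `ToSKBitmap`'s arguments is not spelled out here. As a rough illustration of what a subplot grid does, the sketch below assumes a fixed number of columns and fixed-size tiles. The helpers `TileOrigin` and `CanvasSize`, and the reading of `9, 100, 140` as columns, tile width and tile height, are assumptions and not the d2l API. Under that reading, 18 images in 9 columns fill 2 rows.

```csharp
using System;

// Hypothetical layout helper: top-left corner of tile i in a grid with
// `columns` columns, where every tile is tileWidth x tileHeight pixels.
(int x, int y) TileOrigin(int i, int columns, int tileWidth, int tileHeight) =>
    ((i % columns) * tileWidth, (i / columns) * tileHeight);

// Canvas size for `count` tiles; the row count is rounded up so a partial last row fits.
(int width, int height) CanvasSize(int count, int columns, int tileWidth, int tileHeight)
{
    int rows = (count + columns - 1) / columns;
    return (Math.Min(count, columns) * tileWidth, rows * tileHeight);
}

Console.WriteLine(CanvasSize(18, 9, 100, 140));     // (900, 280)
Console.WriteLine(TileOrigin(10, 9, 100, 140));     // (100, 140)
```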

3.5.2. Reading a Minibatch

In [None]:
var batch_size = 256;

int get_dataloader_workers() => 4;

var train_iter = torch.utils.data.DataLoader(
    mnist_train, batch_size, shuffle: true, 
    num_worker: get_dataloader_workers());

TorchSharp implements its own `DataLoader` in C#. It reads data asynchronously with `Task`s instead of launching multiple worker processes as PyTorch does.
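TorchSharp's actual `DataLoader` internals are not reproduced here. To illustrate the idea only, the sketch below loads one batch by starting a `Task` per sample, with at most `workers` loads in flight at once. `LoadSample` is a hypothetical stand-in for decoding one image. The `SemaphoreSlim` limit is our assumption about how a `num_worker`-style setting could be enforced, not TorchSharp's code. Every task runs on the .NET thread pool inside the current process, which is the key difference from PyTorch's multi-process workers.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for reading one sample (e.g. decoding one image).
int LoadSample(int index) => index * 10;

// Load samples [start, start + batchSize) (clipped to count) with at most
// `workers` concurrent loads; Task.WhenAll keeps results in index order.
async Task<int[]> LoadBatchAsync(int start, int batchSize, int count, int workers)
{
    using var gate = new SemaphoreSlim(workers);
    int end = Math.Min(start + batchSize, count);
    var tasks = Enumerable.Range(start, end - start).Select(async i =>
    {
        await gate.WaitAsync();
        try { return await Task.Run(() => LoadSample(i)); }
        finally { gate.Release(); }
    });
    return await Task.WhenAll(tasks);
}

var batch = await LoadBatchAsync(0, 4, 10, workers: 2);
Console.WriteLine(string.Join(", ", batch));  // 0, 10, 20, 30
```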

In [None]:
var timer = Stopwatch.StartNew();
foreach (var Xy in train_iter)
{
    var X = Xy["data"];
    var y = Xy["label"];
    continue;
}
timer.Stop();
$"{timer.Elapsed.TotalSeconds:0.00} sec"
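The timed loop above makes one full pass over the training set. Fashion-MNIST has 60,000 training images, so with `batch_size = 256` the pass covers the number of minibatches computed below. This assumes the last, smaller batch is kept rather than dropped.

```csharp
using System;

int numTrain = 60_000;  // size of the Fashion-MNIST training set
int batchSize = 256;

// Minibatches per epoch (ceiling division), assuming the partial last batch is kept.
int numBatches = (numTrain + batchSize - 1) / batchSize;
int lastBatch = numTrain - (numBatches - 1) * batchSize;

Console.WriteLine($"{numBatches} batches, last one has {lastBatch} samples");
// 235 batches, last one has 96 samples
```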

In [None]:
(DataLoader trainIter, DataLoader testIter) load_data_fashion_mnist(
    int batch_size, int? resize = null)
{
    var trans = resize.HasValue ? 
        torchvision.transforms.Resize(resize.Value) :
        null;
    var mnist_train = torchvision.datasets.FashionMNIST(
        root: "../gitignored/data", train: true, target_transform: trans, download: true);
    var mnist_test = torchvision.datasets.FashionMNIST(
        root: "../gitignored/data", train: false, target_transform: trans, download: true);
    return (
        torch.utils.data.DataLoader(
            mnist_train, batch_size, shuffle: true,
            num_worker: get_dataloader_workers()),
        torch.utils.data.DataLoader(
            mnist_test, batch_size, shuffle: false,
            num_worker: get_dataloader_workers())
        );
}

In [None]:
var (train_iter, test_iter) = load_data_fashion_mnist(32, resize: 64);
using (train_iter)
using (test_iter)
{
    foreach (var Xy in train_iter)
    {
        var X = Xy["data"];
        var y = Xy["label"];
        Console.WriteLine(
            $"{X.shape.ToArrayString()} {X.dtype} {y.shape.ToArrayString()} {y.dtype}");
        break;
    }
}

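In the d2l package, we use `DataLoader<Dictionary<string, Tensor>, (Tensor data, Tensor label)>` as the return type. When iterated, it yields tuples instead of dictionaries, so the usage stays closer to the original text: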

In [None]:
var (train_iter, test_iter) = D2l.Ch3.load_data_fashion_mnist(32, resize: 64);
using (train_iter)
using (test_iter)
{
    foreach (var (X, y) in train_iter)
    {
        Console.WriteLine(
            $"{X.shape.ToArrayString()} {X.dtype} {y.shape.ToArrayString()} {y.dtype}");
        break;
    }
}
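How the d2l package builds its tuple-yielding `DataLoader` is not shown here. The underlying idea can be sketched in plain C# as a projection from `{"data", "label"}` dictionaries to named tuples, which is what makes `foreach (var (X, y) in ...)` possible. `AsTuples` is a hypothetical helper, and strings stand in for tensors.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: project each {"data": ..., "label": ...} dictionary
// to a (data, label) tuple so callers can deconstruct it in foreach.
IEnumerable<(T data, T label)> AsTuples<T>(IEnumerable<Dictionary<string, T>> batches) =>
    batches.Select(b => (b["data"], b["label"]));

var batches = new List<Dictionary<string, string>>
{
    new Dictionary<string, string> { ["data"] = "X0", ["label"] = "y0" },
    new Dictionary<string, string> { ["data"] = "X1", ["label"] = "y1" },
};

foreach (var (X, y) in AsTuples(batches))
    Console.WriteLine($"{X} {y}");
// X0 y0
// X1 y1
```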