In [1]:
%matplotlib inline
import random
import torch
from d2l import torch as d2l

## Generating the Dataset

In [2]:
class SyntheticRegressionData(d2l.DataModule):
    """Synthetic data for linear regression: ``y = X w + b + noise``.

    Features are drawn from a standard normal distribution and the labels
    are a linear function of them, perturbed by scaled standard-normal noise.

    Args:
        w: Tensor of ground-truth weights; ``len(w)`` sets the number of
            feature columns.
        b: Ground-truth bias added to every label.
        noise: Factor multiplied onto the standard-normal label noise.
        num_train: Number of training examples to generate.
        num_val: Number of validation examples to generate.
        batch_size: Minibatch size (read back as ``self.batch_size`` by the
            loader methods defined later in the notebook).

    Attributes:
        X: Feature tensor with ``num_train + num_val`` rows and ``len(w)`` columns.
        y: Label tensor with ``num_train + num_val`` rows and one column.
    """

    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000, batch_size=32):
        super().__init__()
        # NOTE(review): presumably this stores the constructor arguments as
        # attributes (self.num_train and self.batch_size are read later in the
        # notebook). Keep it ahead of any local variables -- confirm against d2l.
        self.save_hyperparameters()

        total = num_train + num_val
        num_features = len(w)

        # Draw the features first, then the label noise, so the random
        # stream is consumed in a fixed order.
        self.X = torch.randn(total, num_features)
        label_noise = noise * torch.randn(total, 1)

        # Reshape w into a column so the product has one label per row.
        weight_column = w.reshape((-1, 1))
        self.y = self.X @ weight_column + b + label_noise

In [3]:
# Build the dataset from known ground-truth parameters (three weights and a bias).
data = SyntheticRegressionData(
    w=torch.tensor([2, -3, .4]),
    b=4.2,
)



In [4]:
# Show the first generated example: its feature row and the matching label.
print('features:', data.X[0], '\nlabel:', data.y[0])

features: tensor([-1.2981,  0.9511, -1.0873]) 
label: tensor([-1.6766])


## Reading the Dataset

In [15]:
@d2l.add_to_class(SyntheticRegressionData)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    """Wrap a selection of rows from several tensors in a DataLoader.

    Args:
        tensors: Iterable of tensors; each one is indexed with ``indices``.
        train: Passed to the DataLoader as ``shuffle``, so batches are
            shuffled only when this is truthy.
        indices: Index applied to every tensor; the default slice keeps
            all rows.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected rows, batched
        with ``self.batch_size``.
    """
    selected = tuple(tensor[indices] for tensor in tensors)
    subset = torch.utils.data.TensorDataset(*selected)
    loader = torch.utils.data.DataLoader(
        subset,
        batch_size=self.batch_size,
        shuffle=train,
    )
    return loader
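# Reference only: a hand-written alternative to get_dataloader that shuffles
# and slices indices manually. The DataLoader-based definition below is the live one.
# def get_dataloader(self, train):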
#     if train:
#         indices = list(range(0, self.num_train))
#         random.shuffle(indices)
#     else:
#         indices = list(range(self.num_train, self.num_train+self.num_val))
#     for i in range(0, len(indices), self.batch_size):
#         batch_indices = torch.tensor(indices[i: i + self.batch_size])
#         yield self.X[batch_indices], self.y[batch_indices]

@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    """Return a DataLoader over the training or the validation rows.

    Args:
        train: When truthy, use rows ``[0, self.num_train)`` with shuffling;
            otherwise use every row from ``self.num_train`` onward without
            shuffling.

    Returns:
        The DataLoader built by ``self.get_tensorloader`` from ``(self.X, self.y)``.
    """
    if train:
        rows = slice(0, self.num_train)
    else:
        # Everything after the training rows is treated as validation data.
        rows = slice(self.num_train, None)
    return self.get_tensorloader((self.X, self.y), train, rows)

In [16]:
# Pull a single minibatch from the training loader and inspect it.
# NOTE(review): train_dataloader() is not defined in this notebook; presumably
# d2l.DataModule forwards it to get_dataloader(train=True) -- confirm.
X, y = next(iter(data.train_dataloader()))
print('X shape:', X.shape, '\ny shape:', y.shape)
# Prints the whole batch, which produces a long output below.
print('X value:', X, '\ny value:', y)

X shape: torch.Size([32, 3]) 
y shape: torch.Size([32, 1])
X value: tensor([[-1.1374, -0.5544,  0.5816],
        [ 0.5676,  1.0361,  0.2935],
        [ 0.5385,  0.1060,  0.9356],
        [-1.1440,  0.4479,  0.4874],
        [ 0.3709,  0.2960,  0.4912],
        [-0.3458, -1.0200,  0.9400],
        [-0.1638,  1.3628, -0.6725],
        [ 0.7966, -0.3050,  0.6049],
        [ 0.7849, -2.0347,  0.4122],
        [-0.3680, -0.8744,  1.0388],
        [-0.7126,  1.2849, -0.2203],
        [-0.8720, -0.9740,  0.6541],
        [-0.6467, -1.7180,  1.0640],
        [ 0.8427,  0.0262,  0.5981],
        [ 1.7992, -1.1749,  1.0011],
        [ 1.1797, -0.3570, -0.5026],
        [ 0.1377, -1.0657, -1.3413],
        [ 0.1467, -1.4413,  0.1901],
        [ 0.0165,  0.2945,  0.7226],
        [ 0.3449, -0.2068, -0.4132],
        [-0.4573,  0.6538,  0.1530],
        [ 0.2634,  0.4463,  0.2283],
        [ 0.7524, -0.0076,  0.4012],
        [ 1.4667,  1.3202,  0.2450],
        [ 0.6752, -0.7219, -0.7657],
       