## Pytorch : Wine Dataset

In [1]:
import torch
import numpy as np
import pandas as pd
import math

## Custom Dataset 구축
1. custom dataset class로 `Dataset` 구축
2. `DataLoader`로 구축

In [41]:
from torch.utils.data import Dataset, DataLoader

In [45]:
class WineDataset(Dataset):
#dataset은 함수 3개만 사용할 수 있도록 규칙 있음
    
    #pytorch에서는 항상 float32 사용
    def __init__(self):
        wine = pd.read_csv('../../data/wine.csv')
        wine = wine.values.astype(np.float32)
        self.n_samples = wine.shape[0]
        
        self.X_data = torch.from_numpy(wine[:, 1:])
        self.y_data = torch.from_numpy(wine[:, [0]])
    
    def __getitem__(self, index):
        return self.X_data[index], self.y_data[index]
        
    def __len__(self):
        return self.n_samples

In [46]:
dataset = WineDataset()
dataset

<__main__.WineDataset at 0x1ba538a5090>

In [47]:
dataset[1]

(tensor([1.3200e+01, 1.7800e+00, 2.1400e+00, 1.1200e+01, 1.0000e+02, 2.6500e+00,
         2.7600e+00, 2.6000e-01, 1.2800e+00, 4.3800e+00, 1.0500e+00, 3.4000e+00,
         1.0500e+03]),
 tensor([1.]))

In [48]:
train_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [49]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x1ba538a65f0>

In [51]:
next(iter(train_loader))

[tensor([[1.3750e+01, 1.7300e+00, 2.4100e+00, 1.6000e+01, 8.9000e+01, 2.6000e+00,
          2.7600e+00, 2.9000e-01, 1.8100e+00, 5.6000e+00, 1.1500e+00, 2.9000e+00,
          1.3200e+03],
         [1.3560e+01, 1.7100e+00, 2.3100e+00, 1.6200e+01, 1.1700e+02, 3.1500e+00,
          3.2900e+00, 3.4000e-01, 2.3400e+00, 6.1300e+00, 9.5000e-01, 3.3800e+00,
          7.9500e+02],
         [1.2290e+01, 2.8300e+00, 2.2200e+00, 1.8000e+01, 8.8000e+01, 2.4500e+00,
          2.2500e+00, 2.5000e-01, 1.9900e+00, 2.1500e+00, 1.1500e+00, 3.3000e+00,
          2.9000e+02],
         [1.2530e+01, 5.5100e+00, 2.6400e+00, 2.5000e+01, 9.6000e+01, 1.7900e+00,
          6.0000e-01, 6.3000e-01, 1.1000e+00, 5.0000e+00, 8.2000e-01, 1.6900e+00,
          5.1500e+02],
         [1.2880e+01, 2.9900e+00, 2.4000e+00, 2.0000e+01, 1.0400e+02, 1.3000e+00,
          1.2200e+00, 2.4000e-01, 8.3000e-01, 5.4000e+00, 7.4000e-01, 1.4200e+00,
          5.3000e+02],
         [1.1840e+01, 8.9000e-01, 2.5800e+00, 1.8000e+01, 9.4000e