In [6]:
import os

import numpy as np
import pandas as pd
import pymysql
import torch
from torch.utils.data import Dataset, DataLoader

In [7]:
class SampleTxtDataset(Dataset):
    """Map-style dataset over a comma-delimited numeric text file.

    Columns 0-1 of each row are the features and column 2 is the target;
    any further columns are ignored.

    The file is parsed once at construction time. Re-parsing it inside
    ``__getitem__``/``__len__`` would make one pass over the dataset cost
    one full file read per sample.
    """

    def __init__(self, txt_file_path="./data/sample_txt_data.txt"):
        # Path to the data file (defaults to the bundled sample file).
        self.txt_file_path = txt_file_path
        # ndmin=2 keeps a single-row file 2-D so the column slices below work.
        txt_data = np.loadtxt(self.txt_file_path, delimiter=",", ndmin=2)
        self._x = torch.from_numpy(txt_data[:, :2])
        self._y = torch.from_numpy(txt_data[:, 2])
        self._len = len(txt_data)

    def __getitem__(self, item):
        """Return the ``(features, target)`` tensor pair for row ``item``."""
        return self._x[item], self._y[item]

    def __len__(self):
        """Return the number of rows read from the file."""
        return self._len

sample_txt_dataset = SampleTxtDataset()

# Index the dataset explicitly instead of using next(iter(...)): this avoids
# depending on the builtin `iter`, and it fetches the first sample only once.
first_txt_sample = sample_txt_dataset[0]

print("Data Size:",len(sample_txt_dataset))

print("One Sample:",first_txt_sample)

print("One Sample's Type:",type(first_txt_sample[0]))

Data Size: 10
One Sample: (tensor([0., 1.], dtype=torch.float64), tensor(2., dtype=torch.float64))
One Sample's Type: <class 'torch.Tensor'>


In [8]:
class SampleCsvDataset(Dataset):
    """Map-style dataset over a CSV file with a header row.

    Every column except the last is a feature; the last column is the target.

    The CSV is read once at construction time. Re-reading it inside
    ``__getitem__``/``__len__`` would cost one full file parse per sample.
    """

    def __init__(self, csv_file_path="./data/sample_boston.csv"):
        # Path to the CSV file (defaults to the bundled sample file).
        self.csv_file_path = csv_file_path
        raw_data = pd.read_csv(self.csv_file_path)
        self._x = torch.from_numpy(raw_data.iloc[:, :-1].to_numpy())
        self._y = torch.from_numpy(raw_data.iloc[:, -1].to_numpy())
        self._len = raw_data.shape[0]

    def __getitem__(self, item):
        """Return the ``(features, target)`` tensor pair for row ``item``."""
        return self._x[item], self._y[item]

    def __len__(self):
        """Return the number of data rows (header excluded)."""
        return self._len

sample_csv_dataset = SampleCsvDataset()

# Index the dataset explicitly instead of using next(iter(...)): this avoids
# depending on the builtin `iter`, and it fetches the first sample only once.
first_csv_sample = sample_csv_dataset[0]

print("Data Size:",len(sample_csv_dataset))

print("One Sample:",first_csv_sample)

print("One Sample's Type:",type(first_csv_sample[0]))

Data Size: 506
One Sample: (tensor([6.3200e-03, 1.8000e+01, 2.3100e+00, 0.0000e+00, 5.3800e-01, 6.5750e+00,
        6.5200e+01, 4.0900e+00, 1.0000e+00, 2.9600e+02, 1.5300e+01, 3.9690e+02,
        4.9800e+00], dtype=torch.float64), tensor(24., dtype=torch.float64))
One Sample's Type: <class 'torch.Tensor'>


In [None]:
class SampleMysqlDataset(Dataset):
    """Map-style dataset backed by a query on the ``sakila.payment`` MySQL table.

    The data query selects ``payment_id, customer_id, staff_id, rental_id,
    amount``. Every selected column except the last is a feature; the last
    (``amount``) is the target.

    The query runs at most once, on first access, and its result is cached.
    Each connection is closed after use. Connection credentials come from the
    environment, so they are not committed with the notebook.
    """

    def __init__(self):
        # MySQL connection settings. Override them via environment variables
        # (MYSQL_HOST / MYSQL_PORT / MYSQL_USER / MYSQL_PASSWORD) instead of
        # hardcoding real credentials here.
        self.mysql_host = os.environ.get("MYSQL_HOST", "localhost")
        self.mysql_port = int(os.environ.get("MYSQL_PORT", "3306"))
        self.mysql_user = os.environ.get("MYSQL_USER", "xxxxx")
        self.mysql_password = os.environ.get("MYSQL_PASSWORD", "xxxxx")
        self.mysql_db = "sakila"
        self.mysql_table = "payment"
        self.mysql_charset = "utf8"
        self.mysql_sql_data = "select payment_id, customer_id, staff_id, rental_id, amount from sakila.payment"
        self.mysql_sql_cnt = "select count(*) from sakila.payment"
        # Filled in lazily by _load(); _len is None until the data has been fetched.
        self._x = None
        self._y = None
        self._len = None

    def _set_from_frame(self, raw_dataframe):
        """Split a fetched DataFrame into feature/target tensors and cache them."""
        self._x = torch.from_numpy(raw_dataframe.iloc[:, :-1].to_numpy())
        self._y = torch.from_numpy(raw_dataframe.iloc[:, -1].to_numpy())
        self._len = raw_dataframe.shape[0]

    def _load(self):
        """Run the data query once, always closing the connection, and cache the result."""
        if self._len is not None:
            return
        conn = pymysql.connect(host=self.mysql_host,
                               port=self.mysql_port,
                               user=self.mysql_user,
                               password=self.mysql_password,
                               db=self.mysql_db,
                               charset=self.mysql_charset)
        try:
            raw_dataframe = pd.read_sql(self.mysql_sql_data, conn)
        finally:
            conn.close()
        self._set_from_frame(raw_dataframe)

    def __getitem__(self, item):
        """Return the ``(features, target)`` tensor pair for row ``item``."""
        self._load()
        return self._x[item], self._y[item]

    def __len__(self):
        """Return the number of rows returned by the data query."""
        self._load()
        return self._len

sample_mysql_dataset = SampleMysqlDataset()

# Index the dataset explicitly instead of using next(iter(...)): this avoids
# depending on the builtin `iter`, and it fetches the first sample only once.
first_mysql_sample = sample_mysql_dataset[0]

print("Data Size:",len(sample_mysql_dataset))

print("One Sample:",first_mysql_sample)

print("One Sample's Type:",type(first_mysql_sample[0]))

In [9]:
# Seed the shuffle so the printed batch order is reproducible across runs.
SHUFFLE_SEED = 0
sample_dataloader = DataLoader(dataset=sample_txt_dataset, batch_size=3, shuffle=True,
                               generator=torch.Generator().manual_seed(SHUFFLE_SEED))
num_epochs = 4
for epoch in range(num_epochs):
    # Named `step`, not `iter`: rebinding the builtin `iter` at notebook scope
    # breaks any later re-run of cells that call iter(...).
    for step, (batch_x, batch_y) in enumerate(sample_dataloader):
        print('Epoch: ', epoch, '| Iteration: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

Epoch:  0 | Iteration:  0 | batch x:  [[24. 25.]
 [18. 19.]
 [ 3.  4.]] | batch y:  [26. 20.  5.]
Epoch:  0 | Iteration:  1 | batch x:  [[ 9. 10.]
 [ 0.  1.]
 [27. 28.]] | batch y:  [11.  2. 29.]
Epoch:  0 | Iteration:  2 | batch x:  [[21. 22.]
 [ 6.  7.]
 [15. 16.]] | batch y:  [23.  8. 17.]
Epoch:  0 | Iteration:  3 | batch x:  [[12. 13.]] | batch y:  [14.]
Epoch:  1 | Iteration:  0 | batch x:  [[27. 28.]
 [24. 25.]
 [12. 13.]] | batch y:  [29. 26. 14.]
Epoch:  1 | Iteration:  1 | batch x:  [[ 0.  1.]
 [18. 19.]
 [ 6.  7.]] | batch y:  [ 2. 20.  8.]
Epoch:  1 | Iteration:  2 | batch x:  [[ 3.  4.]
 [21. 22.]
 [15. 16.]] | batch y:  [ 5. 23. 17.]
Epoch:  1 | Iteration:  3 | batch x:  [[ 9. 10.]] | batch y:  [11.]
Epoch:  2 | Iteration:  0 | batch x:  [[27. 28.]
 [ 0.  1.]
 [12. 13.]] | batch y:  [29.  2. 14.]
Epoch:  2 | Iteration:  1 | batch x:  [[ 6.  7.]
 [24. 25.]
 [ 9. 10.]] | batch y:  [ 8. 26. 11.]
Epoch:  2 | Iteration:  2 | batch x:  [[15. 16.]
 [ 3.  4.]
 [18. 19.]] | batch 