In [9]:
from lecarb.dataset.dataset import load_table
from lecarb.estimator.lstm.common import load_lstm_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import time
import logging
L = logging.getLogger(__name__)
'''
train-lstm dataset='census13' version='original' workload='lstm-small' hid_units='64' bins='200' train_num='10000' bs='32' sizelimit='0' seed='123':
    poetry run python -m lecarb train -s{{seed}} -d{{dataset}} -v{{version}} -w{{workload}} -elstm --params \
        "{'epochs': 500, 'bins': {{bins}}, 'hid_units': '{{hid_units}}', 'train_num': {{train_num}}, 'bs': {{bs}}}" --sizelimit {{sizelimit}}
'''

'\ntrain-lstm dataset=\'census13\' version=\'original\' workload=\'lstm-small\' hid_units=\'64\' bins=\'200\' train_num=\'10000\' bs=\'32\' sizelimit=\'0\' seed=\'123\':\n    poetry run python -m lecarb train -s{{seed}} -d{{dataset}} -v{{version}} -w{{workload}} -elstm --params         "{\'epochs\': 500, \'bins\': {{bins}}, \'hid_units\': \'{{hid_units}}\', \'train_num\': {{train_num}}, \'bs\': {{bs}}}" --sizelimit {{sizelimit}}\n'

In [10]:
# 加载数据集，将csv文件转为Table类
table = load_table("census13", "original")

dataset = load_lstm_dataset(table, "lstm-small", "123", "200")

[2024-02-20 15:16:33,969 INFO] lecarb.dataset.dataset: start building data census13_original...


[2024-02-20 15:16:34,233 INFO] lecarb.dataset.dataset: build finished: Table census13_original (48842 rows, 4.84MB, columns:
Column(age, type=int64, vocab size=74, min=17, max=90, has NaN=False)
Column(workclass, type=category, vocab size=9, min=?, max=Without-pay, has NaN=False)
Column(education, type=category, vocab size=16, min=10th, max=Some-college, has NaN=False)
Column(education_num, type=int64, vocab size=16, min=1, max=16, has NaN=False)
Column(marital_status, type=category, vocab size=7, min=Divorced, max=Widowed, has NaN=False)
Column(occupation, type=category, vocab size=15, min=?, max=Transport-moving, has NaN=False)
Column(relationship, type=category, vocab size=6, min=Husband, max=Wife, has NaN=False)
Column(race, type=category, vocab size=5, min=Amer-Indian-Eskimo, max=White, has NaN=False)
Column(sex, type=category, vocab size=2, min=Female, max=Male, has NaN=False)
Column(capital_gain, type=int64, vocab size=123, min=0, max=99999, has NaN=False)
Column(capital_loss, t

In [11]:
class TextDataset(Dataset):
    def __init__(self, texts, labels, truecards):
        self.texts = texts
        self.labels = labels
        self.truecards = truecards

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # text_sequence = self.texts[idx]
        # text_tensor = torch.FloatTensor([item.numpy() for item in text_sequence])
        # label_item = self.labels[idx].item() if torch.is_tensor(self.labels[idx]) else self.labels[idx]
        # truecard_item = self.truecards[idx].item() if torch.is_tensor(self.truecards[idx]) else self.truecards[idx]
        # return text_tensor, label_item, truecard_item
        return self.texts[idx], self.labels[idx], self.truecards[idx]
    
def make_dataset(dataset, num=-1):
    X, y, gt = dataset
    # 将list转为tenser
    X = torch.tensor(X).view(num, 50, 11)
    y = torch.tensor(y)
    gt = torch.tensor(gt).view(num, 50)
    L.info(f"{X.shape}, {y.shape}, {gt.shape}")
    if num <= 0:
        return TextDataset(X, y, gt)
    else:
        return TextDataset(X[:num], y[:num], gt[:num])

In [12]:
train_dataset = make_dataset(dataset['train'], 1000)
valid_dataset = make_dataset(dataset['valid'], 100)

In [13]:
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(valid_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=32)
valid_loader = DataLoader(valid_dataset, batch_size=32)

Number of training samples: 1000
Number of validation samples: 100


In [14]:
# 假设每个字是1x13维的向量
input_dim = 11
hidden_dim = 64
output_dim = 10000  # 两类情感，你可以根据实际情况调整
# 定义模型
class TextSentimentModel(nn.Module):
    def __init__(self):
        super(TextSentimentModel, self).__init__()
        self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        # print(text.shape)
        output, _ = self.rnn(text)
        last_hidden_state = output[:, -1, :]
        sentiment_logits = self.fc(last_hidden_state)
        return sentiment_logits


In [15]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = TextSentimentModel().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [16]:
num_epochs = 10
best_valid_loss = 3.0

start_stmp = time.time()
for epoch in range(num_epochs):
    train_losses = []
    model.train()
    for _, data in enumerate(train_loader):
        inputs, labels, truecards = data
        inputs = inputs.to(DEVICE).float()
        labels = labels.to(DEVICE).float()
        
        optimizer.zero_grad()
        preds = model(inputs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())  # 将损失值添加到列表中
    avg_train_loss = sum(train_losses) / len(train_losses)  # 计算平均损失
    dur_min = (time.time() - start_stmp) / 60
    print(f"Epoch {epoch+1}, loss: {avg_train_loss}, time since start: {dur_min:.1f} mins")
    
    print(f"Test on valid set...")
    valid_stmp = time.time()
    model.eval()
    val_losses = []
    for _, data in enumerate(valid_loader):
        inputs, labels, truecards = data
        inputs = inputs.to(DEVICE).float()
        labels = labels.to(DEVICE).float()
        
        with torch.no_grad():
            preds = model(inputs)
            val_loss = criterion(preds, labels)
            val_losses.append(val_loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses)
    print(f"Validation Loss: {avg_val_loss}")
    
    if avg_val_loss < best_valid_loss:
        print('best valid loss for now!')
        best_valid_loss = avg_val_loss
        

Epoch 1, loss: 0.4665753049775958, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.42920874804258347
best valid loss for now!
Epoch 2, loss: 0.4350220048800111, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.40008002519607544
best valid loss for now!
Epoch 3, loss: 0.4144223863258958, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.4160211980342865
Epoch 4, loss: 0.3797496873885393, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.33823613077402115
best valid loss for now!
Epoch 5, loss: 0.38133747410029173, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.4453503042459488
Epoch 6, loss: 0.403705476783216, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.39526713639497757
Epoch 7, loss: 0.36968201119452715, time since start: 0.0 mins
Test on valid set...
Validation Loss: 0.3403623700141907
Epoch 8, loss: 0.38391980063170195, time since start: 0.0 mins
Test on valid set...
Val