**df处理label列，把其转化成 0, 1**
```python
# Map the string labels 'True' / 'False' to the integers 1 / 0.
# NOTE(review): only values equal to these two strings are replaced - confirm
# the column really holds strings and was not parsed as booleans.
df['Output'] = df['Output'].replace({'True': 1, 'False': 0})
```

**df处理某一列中的字符串，删除或替换某个字符**
```python
# Remove the literal '[' and ']' characters from every string in the column.
# regex=False is passed explicitly: the default of `regex` changed across
# pandas releases (and 1.x emits a FutureWarning for single-character
# patterns), so spelling it out keeps literal matching version-independent.
df['column_name'] = (
    df['column_name']
    .str.replace('[', '', regex=False)
    .str.replace(']', '', regex=False)
)
```


**df保存文件**
```python
# Write the frame to CSV with a header row and without the index column.
df.to_csv('./output.csv', index=False, header=True)
```

**需要把读进来的字符串转化成数字**
```python
# Convert every value of the column to a Python int and collect them in a list.
[int(raw_value) for raw_value in df['Output'].values.tolist()]
```
**或者，如果该列读入时已经是数值类型，直接取值即可**
```python
# Take the column's underlying array as-is; no type conversion happens here.
# NOTE(review): presumably only sufficient when the column is already numeric.
df['Output'].values
```

In [17]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

In [24]:
class SQLDataset(Dataset):
    """Map-style dataset of (sentence, table, label) triples.

    Each item is the tokenized text ``"<sentence> [SEP] <table>"``, padded or
    truncated to ``max_length``, together with a ``labels`` tensor.

    Args:
        sentences: Indexable, sized collection of sentence strings.
        tables: Indexable, sized collection of table descriptions, aligned
            element-by-element with ``sentences``.
        labels: Indexable, sized collection of labels convertible to an
            integer tensor, aligned with ``sentences``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            max_length=..., padding="max_length", truncation=True)`` that
            returns a mapping of tensors with a leading batch dimension.
        max_length: Length every encoded sequence is padded/truncated to.

    Raises:
        ValueError: If the three collections do not have the same length.
    """

    def __init__(self, sentences, tables, labels, tokenizer, max_length):
        # Fail fast on misaligned inputs instead of hitting an IndexError (or
        # silently ignoring surplus rows) in the middle of training.
        if not (len(sentences) == len(tables) == len(labels)):
            raise ValueError(
                "sentences, tables and labels must have the same length, got "
                f"{len(sentences)}, {len(tables)} and {len(labels)}"
            )
        self.sentences = sentences
        self.tables = tables
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.sentences)

    def __getitem__(self, idx) -> dict:
        """Return the encoded example at ``idx`` as a dict of tensors."""
        encoded = self.tokenizer(
            f"{self.sentences[idx]} [SEP] {self.tables[idx]}",
            return_tensors="pt",
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )
        # The tokenizer output carries a leading batch dimension; drop it so
        # that individual samples can later be stacked into a batch.
        item = {key: value.squeeze(0) for key, value in encoded.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

In [19]:
# Hyperparameters for the Hugging Face Trainer run.
training_args = TrainingArguments(
    output_dir="./results",  # checkpoints are written here
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_dir="./logs",
    logging_steps=10,  # log the training loss every 10 update steps
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    save_strategy="epoch",  # checkpoint on the same schedule as evaluation
    seed=42,
    # Reload the best checkpoint once training finishes.
    # NOTE(review): per the transformers docs this needs the save and
    # evaluation strategies to match, which they do here.
    load_best_model_at_end=True,
)

In [20]:
import pandas as pd

In [21]:
# Load the dataset from CSV (path is relative to the working directory).
df = pd.read_csv(r"./data/data.csv")

In [22]:
# Extract the model inputs ('Input', 'Table') and the target ('Output') as
# arrays; they are split and wrapped in SQLDataset further down.
sentences = df['Input'].values
tables = df['Table'].values
labels = df['Output'].values

In [26]:
from sklearn.model_selection import train_test_split

# Split the data into training and validation sets (80% / 20%); the fixed
# random_state makes the split reproducible.
sentences_train, sentences_val, tables_train, tables_val, labels_train, labels_val = train_test_split(
    sentences, tables, labels, test_size=0.2, random_state=42
)

In [27]:
# Pretrained BERT tokenizer and a BERT encoder with a 2-class
# sequence-classification head.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Wrap the train/validation splits in SQLDataset instances; every example is
# padded or truncated to 128 tokens.
train_dataset = SQLDataset(sentences_train, tables_train, labels_train, tokenizer, max_length=128)
val_dataset = SQLDataset(sentences_val, tables_val, labels_val, tokenizer, max_length=128)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset  # needed because training_args evaluates every epoch
)

# Fine-tune the model; the returned TrainOutput is echoed by the notebook.
trainer.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

  0%|          | 0/153 [00:00<?, ?it/s]

{'loss': 0.479, 'learning_rate': 4.673202614379085e-05, 'epoch': 0.2}
{'loss': 0.3877, 'learning_rate': 4.3464052287581704e-05, 'epoch': 0.39}
{'loss': 0.2404, 'learning_rate': 4.0196078431372555e-05, 'epoch': 0.59}
{'loss': 0.2577, 'learning_rate': 3.6928104575163405e-05, 'epoch': 0.78}
{'loss': 0.2715, 'learning_rate': 3.366013071895425e-05, 'epoch': 0.98}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.2195143848657608, 'eval_runtime': 1.0555, 'eval_samples_per_second': 190.433, 'eval_steps_per_second': 12.317, 'epoch': 1.0}
{'loss': 0.192, 'learning_rate': 3.0392156862745097e-05, 'epoch': 1.18}
{'loss': 0.1237, 'learning_rate': 2.7124183006535947e-05, 'epoch': 1.37}
{'loss': 0.325, 'learning_rate': 2.38562091503268e-05, 'epoch': 1.57}
{'loss': 0.1404, 'learning_rate': 2.058823529411765e-05, 'epoch': 1.76}
{'loss': 0.1477, 'learning_rate': 1.7320261437908496e-05, 'epoch': 1.96}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.19504684209823608, 'eval_runtime': 1.0753, 'eval_samples_per_second': 186.928, 'eval_steps_per_second': 12.09, 'epoch': 2.0}
{'loss': 0.1058, 'learning_rate': 1.4052287581699347e-05, 'epoch': 2.16}
{'loss': 0.0803, 'learning_rate': 1.0784313725490197e-05, 'epoch': 2.35}
{'loss': 0.1154, 'learning_rate': 7.5163398692810456e-06, 'epoch': 2.55}
{'loss': 0.092, 'learning_rate': 4.2483660130718954e-06, 'epoch': 2.75}
{'loss': 0.1442, 'learning_rate': 9.80392156862745e-07, 'epoch': 2.94}


  0%|          | 0/13 [00:00<?, ?it/s]

{'eval_loss': 0.18450163304805756, 'eval_runtime': 1.0691, 'eval_samples_per_second': 188.009, 'eval_steps_per_second': 12.16, 'epoch': 3.0}
{'train_runtime': 53.0274, 'train_samples_per_second': 45.373, 'train_steps_per_second': 2.885, 'train_loss': 0.20590507029707916, 'epoch': 3.0}


TrainOutput(global_step=153, training_loss=0.20590507029707916, metrics={'train_runtime': 53.0274, 'train_samples_per_second': 45.373, 'train_steps_per_second': 2.885, 'train_loss': 0.20590507029707916, 'epoch': 3.0})

In [29]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [30]:
device

device(type='cuda')

In [31]:
# Move the model to the selected device; the notebook echoes the module.
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [33]:
def predict(sentence, table, model, tokenizer, max_length=128):
    """Classify one (sentence, table) pair and return the predicted class index.

    Args:
        sentence: Sentence text to classify.
        table: Table description paired with the sentence.
        model: ``torch.nn.Module`` whose forward output exposes ``.logits``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            max_length=..., padding="max_length", truncation=True)`` that
            returns a mapping of tensors.
        max_length: Length the encoded input is padded/truncated to.

    Returns:
        int: Index of the highest-scoring class.
    """
    inputs = tokenizer(
        f"{sentence} [SEP] {table}",
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True
    )

    # Follow the model's own device instead of relying on a notebook-level
    # global, so the function works wherever the model currently lives.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = {key: value.to(first_param.device) for key, value in inputs.items()}

    # Inference must run in eval mode (dropout disabled) to be deterministic;
    # the caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(**inputs).logits
    finally:
        model.train(was_training)

    return torch.argmax(logits, dim=-1).item()


In [37]:
# Example input
sentence = "What is the name of the product with the lowest sales?"
table = "id name product_id quantity sale_date"

# Run the prediction
prediction = predict(sentence, table, model, tokenizer)

# Report the result: class 1 is treated as "convertible", anything else as not
if prediction == 1:
    print("The input sentence can be converted into an SQL query using the provided table.")
else:
    print("The input sentence cannot be converted into an SQL query using the provided table.")

The input sentence cannot be converted into an SQL query using the provided table.


In [38]:
# Directory the artifacts are saved to
save_directory = "./model"

# Save the fine-tuned model
model.save_pretrained(save_directory)

# Save the tokenizer alongside it, in the same directory
tokenizer.save_pretrained(save_directory)

('./model\\tokenizer_config.json',
 './model\\special_tokens_map.json',
 './model\\vocab.txt',
 './model\\added_tokens.json')