In [1]:
import json
import pandas as pd

# Location of the Spider training split. Kept as a single named value so the
# path only has to be edited in one place when the notebook moves machines.
SPIDER_TRAIN_PATH = '/Users/stevenslater/Desktop/TextToSQL/spider/train_spider.json'

# Load the dataset. The encoding is pinned to UTF-8 so the read does not
# depend on the platform's default locale encoding.
with open(SPIDER_TRAIN_PATH, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Convert the loaded JSON to a DataFrame for easier handling
df = pd.DataFrame(data)

# Preprocess data (this is a simplified example; actual preprocessing would be
# more involved). Lower-casing the questions is the only step applied here.
df['question'] = df['question'].str.lower()


In [2]:
# Preview the first five rows of the loaded frame
df.head(n=5)

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,department_management,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...",how many heads of the departments are older th...,"[How, many, heads, of, the, departments, are, ...","{'from': {'table_units': [['table_unit', 1]], ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","list the name, born state and age of the heads...","[List, the, name, ,, born, state, and, age, of...","{'from': {'table_units': [['table_unit', 1]], ..."
2,department_management,"SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","list the creation year, name and budget of eac...","[List, the, creation, year, ,, name, and, budg...","{'from': {'table_units': [['table_unit', 0]], ..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...",what are the maximum and minimum budget of the...,"[What, are, the, maximum, and, minimum, budget...","{'from': {'table_units': [['table_unit', 0]], ..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...",what is the average number of employees of the...,"[What, is, the, average, number, of, employees...","{'from': {'table_units': [['table_unit', 0]], ..."


In [3]:
# Report how many rows the frame holds
print(df.shape[0])

7000


In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# The tokenizer and the model are both loaded from the same pretrained checkpoint
checkpoint = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Batch-encode every question (inputs only; the SQL targets are not encoded here)
questions = df['question'].tolist()
input_sequences = tokenizer(
    questions,
    padding=True,
    truncation=True,
    return_tensors="pt",
)


In [5]:
import torch
from torch.utils.data import Dataset

class SpiderDataset(Dataset):
    """Torch dataset yielding one tokenized (question, SQL) pair per row.

    Each item is a dict with ``input_ids`` and ``attention_mask`` for the
    question and ``labels`` for the SQL query. Every tensor is 1-D after the
    tokenizer's batch dimension is flattened away.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        """Store the data source and tokenization settings.

        Args:
            dataframe: Frame read with ``.iloc``; must provide the columns
                ``'question'`` and ``'query'``.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors="pt")``
                that returns a mapping with ``'input_ids'`` and
                ``'attention_mask'`` tensors.
            max_length: Length every sequence is padded/truncated to.
        """
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.dataframe)

    def _encode(self, text):
        """Tokenize one string, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return the tokenized question and SQL labels for row ``idx``.

        Returns:
            dict with keys ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'``. Padded positions of ``'labels'`` are set to -100.
        """
        # Single positional row lookup instead of one per column.
        row = self.dataframe.iloc[idx]

        # Tokenizing the question and the SQL
        inputs = self._encode(row['question'])
        targets = self._encode(row['query'])

        input_ids = inputs['input_ids'].flatten()
        attention_mask = inputs['attention_mask'].flatten()

        target_ids = targets['input_ids'].flatten()
        target_mask = targets['attention_mask'].flatten()
        # Padding must not contribute to the loss: -100 is the index the
        # cross-entropy loss ignores, so padded label positions are replaced
        # with it instead of being left as pad token ids.
        labels = target_ids.masked_fill(target_mask == 0, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


In [6]:
from transformers import T5Tokenizer

# Load the t5-small tokenizer and wrap the question/SQL frame in the dataset
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = SpiderDataset(dataframe=df, tokenizer=tokenizer)


In [7]:
from torch.utils.data import DataLoader

# Serve the dataset in shuffled mini-batches of 16 examples
batch_size = 16
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
# Optimize every model parameter with AdamW at a fixed learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Switch to training mode so dropout is active while fine-tuning; a model
# returned by from_pretrained starts out in evaluation mode.
model.train()

for epoch in range(1):
    for batch in train_dataloader:
        # Forward pass. The attention mask is passed along with the ids so the
        # encoder does not attend to the padding added by the dataset.
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Batch complete, Batch Loss: ", loss.item())

    print(f"Epoch {epoch} finished")

# Return to evaluation mode so later inference runs without dropout.
model.eval()

# Saves the whole model object (not just its state dict) to disk.
torch.save(model, 'Spidermodel.pth')


In [None]:
def generate_sql(query, max_length=128):
    """Generate SQL text for a natural-language question.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        query: The question, as a string.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Set explicitly because the library's default cap is short enough
            to cut most SQL queries off; raise it for very long queries.

    Returns:
        The first generated sequence decoded to a string, with special tokens
        removed.
    """
    # Training questions were lower-cased during preprocessing, so the same
    # normalization is applied here to keep inference inputs consistent.
    encoded = tokenizer(query.lower(), return_tensors="pt")
    output_ids = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Example usage: run a single sample question through the model and print the result
example_question = "Show me all employees in the sales department."
print(generate_sql(example_question))
