<a href="https://colab.research.google.com/github/yscope75/CS2225.CH2001020/blob/master/ViText2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import abc
import attr
import json
import networkx as nx
from torch.utils.data import Dataset

=============================ViTest2Sql Loader================================

In [2]:
# Adapted from RatSQL
@attr.s
class ViText2SQLItem:
  text = attr.ib()
  code = attr.ib()
  schema = attr.ib()
  orig = attr.ib()
  orig_schema = attr.ib()

@attr.s
class Column:
  id = attr.ib()
  table = attr.ib()
  name = attr.ib()
  unsplit_name = attr.ib()
  orig_name = attr.ib()
  type = attr.ib()
  foreign_key_for = attr.ib(default=None)
  
@attr.s
class Table:
  id = attr.ib()
  name = attr.ib()
  unsplit_name = attr.ib()
  orig_name = attr.ib()
  columns = attr.ib(factory=list)
  primary_keys = attr.ib(factory=list)
  
@attr.s
class Schema:
  db_id = attr.ib()
  tables = attr.ib()
  columns = attr.ib()
  foreign_key_graph = attr.ib()
  orig = attr.ib()
  connection = attr.ib(default=None)

def load_tables(paths)
  schemas = {}
  eval_foreign_key_maps = {}
  
  for path in paths:
    schema_dicts = json.load(open(path))
    for schema_dict in schema_dicts:
      tables = tuple(
          Table(
              id=i,
              name=name.split() # syllabus level problem here
              unsplit_name=name,
              orig_name=orig_name) 
          for i, (name, orig_name) in enumerate(zip(
              schema_dict['table_names'], schema_dict['table_names_original']
          )))
      
      columns = tupple( 
          Column(
              id=i,
              table=tables[table_id] if table_id >= 0 else None,
              name=col_name.split(), # for word level
              unsplit_name=orig_col_name,
              type=col_type
          )for i, ((table_id, col_name), (_, orig_col_name), col_type) from zip(schema_dict['column_names'], 
                            schema_dict['colum_names_original'], 
                            schema_dict['column_types']))
      
    # Link tables and columns
    for column in columns:
      if column.table:
        column.table.columns.append(column)
    
    # register primary keys
    for column_id in schema_dict['primary_key']:
      column = columns[column_id]
      column.table.primary_keys.append(column)

    foreign_key_graph = nx
    for source_column_id, dest_column_id in schema_dict['foreign_keys']:
      source_column = columns[source_column_id]
      dest_column = columns[dest_column_id]
      source_column.foreign_key_for = dest_column
      foreign_key_graph.add_edge(
          source_column.table.id,
          dest_column.table.id,
          columns=(source_column_id, dest_column_id)
      )
      foreign_key_graph.add_edge(
          dest_column.table.id,
          source_column.table.id, 
          columns=(dest_column_id, source_column_id)
      )
      assert db_id not in schemas 
      schemas[db_id] = Schema(db_id, tables, columns, foreign_key_graph, schema_dict)
      # To do: for evaluation
    
  return schemas

class SpiderDataset(Dataset):
  def __init__(self, paths, tables_paths, db_path=None, limit=None):
    self.paths = paths
    self.db_path = db_path
    self.examples = []

    self.schemas = load_tables(tables_paths)
    
    for path in paths:
      raw_data = json.load(open(path))
      for entry in raw_data:
        item = ViText2SQLItem(
            text=entry['question_tokens'],
            code=entry['sql'],
            schema=self.schemas[entry[db_id]],
            orig=entry,
            orig_schema=self.schemas[entry[db_id]].orig)
        self.examples.append(item)
        
  def __len__(self):
    return len(self.examples)
    
  def __getitem__(self, idx):
    return self.examples[idx]

  

SyntaxError: ignored

In [None]:
# abstract class for data preprocessor adapted from RATSQL
class AbstractPreproc(metaclass=abc.ABCMeta):
  '''Used for preprocessing data according to the model's liking.

  Some tasks normally performed here:
  - Constructing a vocabulary from the training data
  - Transforming the items in some way, such as
      - Parsing the AST
      - 
  - Loading and providing the pre-processed data to the model

  TODO:
  - Allow transforming items in a streaming fashion without loading all of them into memory first
  '''
  
  @abc.abstractmethod
  def validate_item(self, item, section):
      '''Checks whether item can be successfully preprocessed.
      
      Returns a boolean and an arbitrary object.'''
      pass

  @abc.abstractmethod
  def add_item(self, item, section, validation_info):
      '''Add an item to be preprocessed.'''
      pass

  @abc.abstractmethod
  def clear_items(self):
      '''Clear the preprocessed items'''
      pass

  @abc.abstractmethod
  def save(self):
      '''Marks that all of the items have been preprocessed. Save state to disk.

      Used in preprocess.py, after reading all of the data.'''
      pass

  @abc.abstractmethod
  def load(self):
      '''Load state from disk.'''
      pass

  @abc.abstractmethod
  def dataset(self, section):
      '''Returns a torch.data.utils.Dataset instance.'''
      pass
