In [34]:
from io import StringIO
import sys

class Capturing(list):
    """A list that collects everything printed to stdout inside its ``with`` block.

    On entry ``sys.stdout`` is swapped for an in-memory buffer; on exit the
    buffered text is split into lines, appended to this list, and the
    original stdout is put back.
    """

    def __enter__(self):
        # Remember the real stdout so it can be restored on exit.
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        del self._stringio    # free up some memory
        sys.stdout = self._stdout

In [40]:
"""Utility functions used to download, open and display
 the contents of Wikimedia SQL dump files.
"""

import gzip
import sys
from contextlib import contextmanager
from pathlib import Path, PosixPath
from typing import Iterator, Optional, TextIO, Union
from urllib.error import HTTPError

import wget  # type: ignore

# Custom type
PathObject = Union[str, Path]


# TODO: eventually will want to update the function calls to match rest of library -- e.g., file_path: string, mode: string, etc.
# Done!
@contextmanager
def _open_file(
    file_path: PathObject, encoding: Optional[str] = None
) -> Iterator[TextIO]:
    """Custom context manager for opening both .gz and uncompressed files.

    :param file_path: The path to the file
    :type file_path: PathObject
    :param encoding: Text encoding, defaults to None
    :type encoding: Optional[str], optional
    :yield: A file handle
    :rtype: Iterator[TextIO]
    """

    if str(file_path).endswith(".gz"):
        infile = gzip.open(file_path, mode="rt", encoding=encoding)
    else:
        infile = open(file_path, mode="r", encoding=encoding)
    try:
        yield infile
    finally:
        infile.close()


def head(file_path: PathObject, n_lines: int = 10, encoding: str = "utf-8") -> None:
    """Display first n lines of a file. Works with both
    .gz and uncompressed files. Defaults to 10 lines.

    Each line is printed with surrounding whitespace stripped. Files
    shorter than ``n_lines`` are printed in full; a zero or negative
    ``n_lines`` prints nothing.

    :param file_path: The path to the file
    :type file_path: PathObject
    :param n_lines: Lines to display, defaults to 10
    :type n_lines: int, optional
    :param encoding: Text encoding, defaults to "utf-8"
    :type encoding: str, optional
    """

    with _open_file(file_path, encoding=encoding) as infile:
        for line_no, line in enumerate(infile):
            # Comparing with >= (rather than counting down to exactly 0)
            # keeps a negative n_lines from dumping the entire file.
            if line_no >= n_lines:
                break
            print(line.strip())


# Minor but I would just get rid of the width parameter if you aren't going to use it
# I tried but wget wouldn't work without it. Haven't actually looked into it,
# but what I *think* happens is that while the progress_bar func itself doesn't use the width param, it gets passed as a kwarg to wget where it's necessary.
def _progress_bar(
    current: Union[int, float], total: Union[int, float], width: int = 60
) -> None:
    """Custom progress bar for wget downloads.

    :param current: bytes downloaded so far
    :type current: Union[int, float]
    :param total: Total size of download in bytes or megabytes
    :type total: Union[int, float]
    :param width: Progress bar width in chars, defaults to 60
    :type width: int, optional
    """

    unit = "bytes"

    # Show file size in MB for large files
    if total >= 100000:
        MB = 1024 * 1024
        current = current / MB
        total = total / MB
        unit = "MB"

    progress = current / total
    progress_message = f"Progress: \
    {progress:.0%} [{current:.1f} / {total:.1f}] {unit}"
    sys.stdout.write("\r" + progress_message)
    sys.stdout.flush()


def load(database: str, filename: str, date: str = "latest") -> Optional[PathObject]:
    """Load a dump file from a Wikimedia public directory if the
    user is in a supported environment (PAWS, Toolforge...). Otherwise,
    download dump file from the web and save in the current working
    directory. In both cases, the function returns a path-like object
    which can be used to access the file. Does not check if the file
    already exists on the path.

    :param database: The database backup dump to download a file from,
        e.g. 'enwiki' (English Wikipedia). See a list of available
        databases here: https://dumps.wikimedia.org/backup-index-bydb.html
    :type database: str
    :param filename: The name of the file to download, e.g. 'page' loads the
        file {database}-{date}-page.sql.gz
    :type filename: str
    :param date: Date the dump was generated, defaults to "latest". If "latest"
        is not used, the date format should be "YYYYMMDD"
    :type date: str, optional
    :return: Path to dump file, or None when the download fails with an
        HTTP error
    :rtype: Optional[PathObject]
    """

    paws_root_dir = Path("/public/dumps/public/")
    dumps_url = "https://dumps.wikimedia.org/"
    extended_filename = f"{database}-{date}-{filename}.sql.gz"

    # The public dumps directory only exists in supported environments; the
    # path is returned without checking that the file itself is present.
    if paws_root_dir.exists():
        return Path(paws_root_dir, database, date, extended_filename)

    # Build the URL from plain strings: joining the segments through pathlib
    # would insert backslashes on Windows.
    url = f"{dumps_url}{database}/{date}/{extended_filename}"
    print(f"Downloading {url}")
    try:
        dump_file = wget.download(url, bar=_progress_bar)
    except HTTPError as err:
        if err.code == 404:
            print("File not found")
        else:
            # Any other status (e.g. 503) is not a missing file.
            print(f"Download failed: HTTP {err.code} {err.reason}")
        return None

    return Path(dump_file)


In [50]:
# Fetch the latest simplewiki `change_tag_def` dump and keep the path load() returns.
# NOTE(review): load() does not check for an existing copy, so re-running this cell
# downloads the file again.
f = load('simplewiki', 'change_tag_def', 'latest')

Downloading https://dumps.wikimedia.org/simplewiki/latest/simplewiki-latest-change_tag_def.sql.gz
Progress:     100% [2131.0 / 2131.0] bytes

In [44]:
# NOTE(review): load() names the file "{database}-{date}-{filename}.sql.gz", so with
# date='latest' this dated path presumably no longer equals `f`. The recorded True looks
# like it comes from an earlier run (execution count 44 precedes the load cell's 50) -- confirm.
f == PosixPath('simplewiki-20210701-change_tag_def.sql.gz')

True

In [51]:
# Peek at the first line of the dump only. The line keeps its own trailing newline
# and print() adds another, hence the blank line in the output.
with _open_file(f) as infile:    
    for line in infile:
        print(line)
        break

-- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)



In [55]:
head(f, 10)

-- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)
--
-- Host: 10.64.32.82    Database: simplewiki
-- ------------------------------------------------------
-- Server version	10.4.19-MariaDB-log

/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
/*!40101 SET NAMES utf8mb4 */;


In [60]:
# Capture what head() prints for the first 42 lines into the list `output`,
# one item per printed line, then display the list.
with Capturing() as output:
    head(f, 42)
output

['-- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)',
 '--',
 '-- Host: 10.64.32.82    Database: simplewiki',
 '-- ------------------------------------------------------',
 '-- Server version\t10.4.19-MariaDB-log',
 '',
 '/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;',
 '/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;',
 '/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;',
 '/*!40101 SET NAMES utf8mb4 */;',
 '/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;',
 "/*!40103 SET TIME_ZONE='+00:00' */;",
 '/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;',
 '/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;',
 "/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;",
 '/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;',
 '',
 '--',
 '-- Table structure for table `change_tag_def`',
 '--',
 '',
 'DROP TABLE IF EXISTS `change_tag_def`;',


In [22]:
# Map each captured line to its position so individual lines can be looked up by index.
d = dict(enumerate(output))

# Show the numbered lines.
for n, item in d.items():
    print(f'{n}: {item}')

0: -- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)
1: --
2: -- Host: 10.64.32.82    Database: simplewiki
3: -- ------------------------------------------------------
4: -- Server version	10.4.19-MariaDB-log
5: 
6: /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
7: /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
8: /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
9: /*!40101 SET NAMES utf8mb4 */;
10: /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
11: /*!40103 SET TIME_ZONE='+00:00' */;
12: /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
13: /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
14: /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
15: /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;
16: 
17: --
18: -- Table structure for table `change_tag_def`
19: --
20: 
21: DROP TABLE IF EXISTS `change_tag_def`;
22: /*!4010

In [56]:
# Hand-written copy of the first ten lines of the dump, used as the expected value
# when comparing against captured head() output.
out = ['-- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)',
 '--',
 '-- Host: 10.64.32.82    Database: simplewiki',
 '-- ------------------------------------------------------',
 '-- Server version\t10.4.19-MariaDB-log',
 '',
 '/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;',
 '/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;',
 '/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;',
 '/*!40101 SET NAMES utf8mb4 */;']

In [57]:
len(out)

10

In [8]:
# NOTE(review): `out` holds 10 lines, so this only passes when `output` was captured
# with head(f, 10). The capture cell shown in this notebook uses head(f, 42) -- confirm
# the intended run order before relying on this check.
assert out == output

In [30]:
"""Parser functions used in src/dump.py"""

import csv
import re
import warnings
from typing import Any, Dict, Iterator, List, Optional


def _has_sql_attribute(line: str, attr_type: str) -> bool:
    """Check whether a string contains a specific SQL element
    or statement.

    :param line: A line from a SQL dump file.
    :type line: str
    :param attr_type: Element or statement type, e.g "primary_key"
        for a table's primary key or "insert" for INSERT INTO statements.
    :type attr_type: str
    :return: True or False
    :rtype: bool
    """

    # FYI: no need to update this because I think it's nice and simple but if you were trying to
    # expand it to more use cases and finding the rules to get more complex, you would likely want to consider
    # a regex for each like in get_sql_attribute.
    # e.g., something like (with the caveat that I am not good at regexes): re.match('(^--).*(Database: )', line)
    line_start = {
        "database": "--",
        "insert": "INSERT INTO",
        "create": "CREATE TABLE",
        "primary_key": "PRIMARY KEY",
        "col_name": "`",
    }
    contains_element = line.strip().startswith(line_start[attr_type])

    if attr_type == "database":
        return contains_element and "Database: " in line

    return contains_element


def _get_sql_attribute(line: str, attr_type: str) -> Any:
    """Extract a SQL attribute from a string that contains it.

    :param line: A line from a SQL dump file.
    :type line: str
    :param attr_type: Element or statement type, e.g "primary_key"
        for a table's primary key or "col_name" for a column (field) name.
    :type attr_type: str
    :return: A SQL attribute such as database(name), table name,
        primary_key, etc.
    :rtype: Optional[str]
    """

    attr_pattern = {
        "table_name": r"`([\S]*)`",
        "col_name": r"`([\S]*)`",
        "dtype": r"` ((.)*),",
        "primary_key": r"`([\S]*)`",
    }

    attr: Optional[str] = None

    try:
        if attr_type == "database":
            attr = line.strip().partition("Database: ")[-1]

        elif attr_type in ("table_name", "col_name", "dtype"):
            # ignore type - mypy does not understand try... except here
            attr = re.search(attr_pattern[attr_type], line).group(1)  # type: ignore

        elif attr_type == "primary_key":
            attr = (
                re.search(attr_pattern[attr_type], line)
                .group(1)  # type: ignore
                .replace("`", "")
                .split(",")
            )

    except AttributeError:
        return None

    # probably define attr as None before the try clause. right now if not of the if-else clauses matched, would throw a weird error
    # Done!
    return attr


# I don't know much about the intricacies of types but I like this -- good and simple!
def _map_dtypes(sql_dtypes: Dict[str, str]) -> Dict[str, type]:
    """Create mapping from SQL data types to Python data types.

    :param sql_dtypes: A mapping from the column names in a SQL table
        to their respective SQL data types.
        Example: {"ct_id": int(10) unsigned NOT NULL AUTO_INCREMENT}
    :type sql_dtypes: Dict[str, str]
    :return: A mapping from the column names in a SQL table
        to their respective Python data types. Example: {"ct_id": int}
    :rtype: Dict[str, type]
    """

    types: Dict[str, type] = {}
    for key, val in sql_dtypes.items():
        if "int" in val:
            types[key] = int
        elif any(dtype in val for dtype in ("float", "double", "decimal", "numeric")):
            types[key] = float
        else:
            types[key] = str
    return types


def _convert(values: List[str], dtypes: List[type], strict: bool = False) -> List[Any]:
    """Cast numerical values in a list of strings to float or int
    as specified by the dtypes parameter.

    :param values: A list of strings representing a row in a SQL table
        E.g. ['28207', 'April', '4742783', '0.9793'].
    :type values: List[str]
    :param dtypes: A list of Python data types. E.g. [int, str, int, float]
    :type dtypes: List[type]
    :param strict: When set to False, if any of the items in the list
        cannot be converted, it is returned unchanged, i.e. as a str.
    :type strict: bool, optional
    :raises ValueError: If `values` is not the same length as `dtypes`,
        or if `strict` is set to True and some of the values in the
        list couldn't be converted.
    :return: A list where the numerical values have been cast as int
        or string as defined by `dtypes`. E.g. the example list from
        above is returned as [28207, 'April', 4742783, 0.9793]
    :rtype: List[Any]
    """

    len_values = len(values)
    len_dtypes = len(dtypes)

    warn = False

    if len_values != len_dtypes:
        if not strict:
            return values

        raise ValueError("values and dtypes are not the same length")

    converted = []
    for i in range(len_dtypes):
        dtype = dtypes[i]
        val = values[i]

        try:
            conv = dtype(val)
            converted.append(conv)

        except ValueError as e:
            if values[i] == "":
                # why not convert to None?
                converted.append(val)
            elif not strict:
                warn = True
                converted.append(val)
            else:
                # my PyCharm installation doesn't like this and things it won't work FYI. I haven't tested it though.
                # You're right - I've changed this now.
                print(f"ValueError: {e}")

    if warn:
        # low priority: perhaps include the values too? or problematic value?
        # > I need to think about how to handle this because some files, notably
        # externallinks, have > 10^3 such values
        warnings.warn("some rows could not be converted to Python dtypes")

    return converted


def _split_tuples(line: str) -> List[str]:
    """Split an INSERT INTO statement into a list of strings each
    representing a SQL table row.

    :param line: An INSERT INTO statement, e.g. "INSERT INTO `change_tag_def`
        VALUES (1,'mw-replace',0,10200),(2,'visualeditor',0,305860);"
    :type line: str
    :return: A list with items representing SQL rows,
        e.g. ["1,'mw-replace',0,10200", "2,'visualeditor',0,305860"]
    :rtype: List[str]
    """

    # I think the NULL replacement might need some tweaking. The challenge is two-fold:
    # * making NULL into something that doesn't break the parser -- that's easy, either add quotes like you do or replace with None
    # > I have opted for replacing NULL with the empty string when it's used
    # to denote missing values (i.e. not part of some other string). The reason
    # is that `None` is somewhat specific to pure Python while Pandas, Numpy, R, CSV, and others recognize the empty string as a missing value and sub it with their own null equivalent (NaN, NA, <na>, ...)
    # * not making this replacement when e.g., NULL is just part of a real value like a page title as happens in Commons sometimes
    # For the latter, I think you might need to a regex that only does the replacement when it sees any of the following
    # which in theory should capture all the ways that NULL shows up as a full field value:
    # * ,NULL,
    # * (NULL,
    # * ,NULL)
    # > Good suggestion - I have implemented this regex.
    tuples = line.partition(" VALUES ")[-1].strip()
    # Sub NULL with the empty string
    pattern = r"(?<=[,(])NULL(?=[,)])"
    values = re.sub(pattern, "", tuples)
    # Remove `;` at the end of the last `INSERT INTO` statement
    if values[-1] == ";":
        values = values[:-1]
    records = re.split(r"\),\(", values[1:-1])  # Strip `(` and `)`

    return records


def _parse(
    line: str,
    delimiter: str = ",",
    escape_char: str = "\\",
    quote_char: str = "'",
    doublequote: bool = False,
    strict: bool = True,
) -> Iterator[List[str]]:
    """Turn an INSERT INTO statement into an iterator over its rows.

    The statement is first cut into one string per SQL row and those
    strings are then fed to a ``csv.reader``, whose options are exposed
    as the parameters below. In short: a ``csv.reader`` that accepts a
    single string instead of an iterable of lines.

    :param line: An INSERT INTO statement, e.g. "INSERT INTO `change_tag_def`
        VALUES (1,'mw-replace',0,10200),(2,'visualeditor',0,305860);"
    :type line: str
    :param delimiter: A one-character string used to separate fields,
        defaults to ","
    :type delimiter: str, optional
    :param escape_char: A one-character string used by the reader to remove
        any special meaning from the following character, defaults to a
        backslash
    :type escape_char: str, optional
    :param quote_char: A one-character string used to quote fields
        containing special characters, such as the delimiter or quotechar,
        or which contain new-line characters, defaults to "'"
    :type quote_char: str, optional
    :param doublequote: Controls how instances of quotechar appearing inside
        a field should themselves be quoted. When True, the character
        is doubled. When False, the escapechar is used as a prefix
        to the quotechar. Defaults to False.
    :type doublequote: bool, optional
    :param strict: When True, raise exception Error on bad CSV input.
        Defaults to True.
    :type strict: bool, optional
    :return: An iterator yielding each row as a list of field strings.
    :rtype: Iterator[List[str]]
    """

    # Translate this function's parameter names to csv.reader's keywords.
    reader_options = {
        "delimiter": delimiter,
        "escapechar": escape_char,
        "quotechar": quote_char,
        "doublequote": doublequote,
        "strict": strict,
    }
    return csv.reader(_split_tuples(line), **reader_options)


In [21]:
# Relies on `output` holding at least 27 captured lines (index 26), i.e. on the
# longer head() capture having been run beforehand.
_get_sql_attribute(output[26], 'col_name')

'ctd_name'

In [23]:
d

{0: '-- MySQL dump 10.18  Distrib 10.3.27-MariaDB, for debian-linux-gnu (x86_64)',
 1: '--',
 2: '-- Host: 10.64.32.82    Database: simplewiki',
 3: '-- ------------------------------------------------------',
 4: '-- Server version\t10.4.19-MariaDB-log',
 5: '',
 6: '/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;',
 7: '/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;',
 8: '/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;',
 9: '/*!40101 SET NAMES utf8mb4 */;',
 10: '/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;',
 11: "/*!40103 SET TIME_ZONE='+00:00' */;",
 12: '/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;',
 13: '/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;',
 14: "/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;",
 15: '/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;',
 16: '',
 17: '--',
 18: '-- Table structure for table `change_

In [26]:
"""Utility functions used to download, open and display
 the contents of Wikimedia SQL dump files.
"""

import gzip
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Optional, TextIO, Union
from urllib.error import HTTPError

import wget  # type: ignore

# Custom type
PathObject = Union[str, Path]


# TODO: eventually will want to update the function calls to match rest of library -- e.g., file_path: string, mode: string, etc.
# Done!
@contextmanager
def _open_file(
    file_path: PathObject, encoding: Optional[str] = None
) -> Iterator[TextIO]:
    """Custom context manager for opening both .gz and uncompressed files.

    :param file_path: The path to the file
    :type file_path: PathObject
    :param encoding: Text encoding, defaults to None
    :type encoding: Optional[str], optional
    :yield: A file handle
    :rtype: Iterator[TextIO]
    """

    if str(file_path).endswith(".gz"):
        infile = gzip.open(file_path, mode="rt", encoding=encoding)
    else:
        infile = open(file_path, mode="r", encoding=encoding)
    try:
        yield infile
    finally:
        infile.close()


def head(file_path: PathObject, n_lines: int = 10, encoding: str = "utf-8") -> None:
    """Display first n lines of a file. Works with both
    .gz and uncompressed files. Defaults to 10 lines.

    Each line is printed with surrounding whitespace stripped. Files
    shorter than ``n_lines`` are printed in full; a zero or negative
    ``n_lines`` prints nothing.

    :param file_path: The path to the file
    :type file_path: PathObject
    :param n_lines: Lines to display, defaults to 10
    :type n_lines: int, optional
    :param encoding: Text encoding, defaults to "utf-8"
    :type encoding: str, optional
    """

    with _open_file(file_path, encoding=encoding) as infile:
        for line_no, line in enumerate(infile):
            # Comparing with >= (rather than counting down to exactly 0)
            # keeps a negative n_lines from dumping the entire file.
            if line_no >= n_lines:
                break
            print(line.strip())


# Minor but I would just get rid of the width parameter if you aren't going to use it
# I tried but wget wouldn't work without it. Haven't actually looked into it,
# but what I *think* happens is that while the progress_bar func itself doesn't use the width param, it gets passed as a kwarg to wget where it's necessary.
def _progress_bar(
    current: Union[int, float], total: Union[int, float], width: int = 60
) -> None:
    """Custom progress bar for wget downloads.

    :param current: bytes downloaded so far
    :type current: Union[int, float]
    :param total: Total size of download in bytes or megabytes
    :type total: Union[int, float]
    :param width: Progress bar width in chars, defaults to 60
    :type width: int, optional
    """

    unit = "bytes"

    # Show file size in MB for large files
    if total >= 100000:
        MB = 1024 * 1024
        current = current / MB
        total = total / MB
        unit = "MB"

    progress = current / total
    progress_message = f"Progress: \
    {progress:.0%} [{current:.1f} / {total:.1f}] {unit}"
    sys.stdout.write("\r" + progress_message)
    sys.stdout.flush()


def load(database: str, filename: str, date: str = "latest") -> Optional[PathObject]:
    """Load a dump file from a Wikimedia public directory if the
    user is in a supported environment (PAWS, Toolforge...). Otherwise,
    download dump file from the web and save in the current working
    directory. In both cases, the function returns a path-like object
    which can be used to access the file. Does not check if the file
    already exists on the path.

    :param database: The database backup dump to download a file from,
        e.g. 'enwiki' (English Wikipedia). See a list of available
        databases here: https://dumps.wikimedia.org/backup-index-bydb.html
    :type database: str
    :param filename: The name of the file to download, e.g. 'page' loads the
        file {database}-{date}-page.sql.gz
    :type filename: str
    :param date: Date the dump was generated, defaults to "latest". If "latest"
        is not used, the date format should be "YYYYMMDD"
    :type date: str, optional
    :return: Path to dump file, or None when the download fails with an
        HTTP error
    :rtype: Optional[PathObject]
    """

    paws_root_dir = Path("/public/dumps/public/")
    dumps_url = "https://dumps.wikimedia.org/"
    extended_filename = f"{database}-{date}-{filename}.sql.gz"

    # The public dumps directory only exists in supported environments; the
    # path is returned without checking that the file itself is present.
    if paws_root_dir.exists():
        return Path(paws_root_dir, database, date, extended_filename)

    # Build the URL from plain strings: joining the segments through pathlib
    # would insert backslashes on Windows.
    url = f"{dumps_url}{database}/{date}/{extended_filename}"
    print(f"Downloading {url}")
    try:
        dump_file = wget.download(url, bar=_progress_bar)
    except HTTPError as err:
        if err.code == 404:
            print("File not found")
        else:
            # Any other status (e.g. 503) is not a missing file.
            print(f"Download failed: HTTP {err.code} {err.reason}")
        return None

    return Path(dump_file)


In [25]:
"""A set of utilities for processing MediaWiki SQL dump data"""

import csv
import sys
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Type, TypeVar, Union

# from .parser import (
#     _convert,
#     _get_sql_attribute,
#     _has_sql_attribute,
#     _map_dtypes,
#     _parse,
# )
# from .utils import _open_file

# Allow long field names
csv.field_size_limit(sys.maxsize)

# Custom types
PathObject = Union[str, Path]
T = TypeVar("T", bound="Dump")


class Dump:
    """Class for parsing an SQL dump file and processing its contents.

    A Dump holds the table metadata found in the file header (database,
    table name, columns, SQL dtypes, primary key) and re-reads the source
    file lazily each time its rows are iterated.
    """

    def __init__(
        self,
        database: Optional[str],
        table_name: Optional[str],
        col_names: List[str],
        col_sql_dtypes: Dict[str, str],
        primary_key: Optional[Union[str, List[str]]],
        source_file: PathObject,
        encoding: str,
    ) -> None:
        """Dump class constructor.

        :param database: The wiki database, e.g. 'enwiki' or 'dewikibooks'
        :type database: Optional[str]
        :param table_name: The SQL table name
        :type table_name: Optional[str]
        :param col_names: The SQL table column (field) names
        :type col_names: List[str]
        :param col_sql_dtypes: A mapping from the column names in a SQL table
            to their respective SQL data types.
            Example: {"ct_id": int(10) unsigned NOT NULL AUTO_INCREMENT}
        :type col_sql_dtypes: Dict[str, str]
        :param primary_key: The primary key of the SQL table
            Can be unique or composite. `from_file` supplies it as a
            list of column names.
        :type primary_key: Optional[Union[str, List[str]]]
        :param source_file: The path to the SQL dump file. The file must
            exist: its size is read at construction time.
        :type source_file: PathObject
        :param encoding: Text encoding
        :type encoding: str
        """

        self.db = database
        self.name = table_name
        self.col_names = col_names
        self.sql_dtypes = col_sql_dtypes
        self.primary_key = primary_key
        # Size of the source file on disk, in bytes.
        self.size = Path(source_file).stat().st_size
        # Filled in lazily by the `dtypes` property.
        self._dtypes: Optional[Dict[str, type]] = None
        self._source_file = source_file
        self._encoding = encoding

    def __str__(self) -> str:
        return f"Dump(database={self.db}, name={self.name}, size={self.size})"

    def __repr__(self) -> str:
        return str(self)

    def __iter__(self) -> Iterator[List[Any]]:
        # Iterating a Dump yields its rows without dtype conversion.
        return self.rows()

    @property
    def encoding(self) -> str:
        """Get the encoding used to read the dump file.

        :return: Text encoding
        :rtype: str
        """

        return self._encoding

    @encoding.setter
    def encoding(self, new_encoding: str) -> None:
        """Set the encoding used to read the dump file.

        :param new_encoding: Text encoding
        :type new_encoding: str
        """

        self._encoding = new_encoding

    @property
    def dtypes(self) -> Dict[str, type]:
        """Mapping between col_names and native Python dtypes.

        Computed from `sql_dtypes` on first access and cached.

        :return: A mapping from the column names in a SQL table
            to their respective Python data types. Example: {"ct_id": int}
        :rtype: Dict[str, type]
        """

        if self._dtypes is None:
            self._dtypes = _map_dtypes(self.sql_dtypes)
        return self._dtypes

    @classmethod
    def from_file(cls: Type[T], file_path: PathObject, encoding: str = "utf-8") -> T:
        """Initialize Dump object from dump file.

        Reads the file from the top and collects metadata until the first
        INSERT INTO statement is reached; row data itself is not read here.

        :param cls: The Dump class (or a subclass) to instantiate
        :type cls: Type[Dump]
        :param file_path: Path to source SQL dump file. Can be a .gz or an
            uncompressed file
        :type file_path: PathObject
        :param encoding: Text encoding, defaults to "utf-8". If you get
            an encoding error when processing the file, try setting this
            parameter to 'Latin-1'
        :type encoding: str, optional
        :return: A Dump class instance
        :rtype: Dump
        """

        source_file = file_path
        database = None
        table_name = None
        primary_key = None
        col_names = []
        col_sql_dtypes = {}

        # Extract meta data from dump file
        with _open_file(file_path, encoding=encoding) as infile:
            for line in infile:
                if _has_sql_attribute(line, "database"):
                    database = _get_sql_attribute(line, "database")

                elif _has_sql_attribute(line, "create"):
                    table_name = _get_sql_attribute(line, "table_name")

                elif _has_sql_attribute(line, "col_name"):
                    col_name = _get_sql_attribute(line, "col_name")
                    dtype = _get_sql_attribute(line, "dtype")
                    col_names.append(col_name)
                    col_sql_dtypes[col_name] = dtype

                elif _has_sql_attribute(line, "primary_key"):
                    primary_key = _get_sql_attribute(line, "primary_key")

                elif _has_sql_attribute(line, "insert"):
                    # Metadata ends where the row data begins.
                    break

            return cls(
                database,
                table_name,
                col_names,  # type: ignore
                col_sql_dtypes,  # type: ignore
                primary_key,
                source_file,
                encoding,
            )

    def rows(
        self,
        convert_dtypes: bool = False,
        strict_conversion: bool = False,
        **fmtparams: Any,
    ) -> Iterator[List[Any]]:
        """Create a generator object from the rows.

        The source file is opened when iteration starts and stays open
        until the generator is exhausted or closed.

        :param convert_dtypes: When set to True, numerical types are
            converted from str to int or float. Defaults to False.
        :type convert_dtypes: bool, optional
        :param strict_conversion: When True, raise exception Error on
            bad input when converting from SQL dtypes to Python dtypes.
            Defaults to False.
        :type strict_conversion: bool, optional
        :param fmtparams: Any kwargs you want to pass to the csv.reader()
            function that does the actual parsing.
        :yield: A generator used to iterate over the rows in the SQL table
        :rtype: Iterator[List[Any]]
        """

        # Only resolve the Python dtypes when they are actually needed.
        dtypes: List[type] = list(self.dtypes.values()) if convert_dtypes else []

        with _open_file(self._source_file, encoding=self.encoding) as infile:
            for line in infile:
                if _has_sql_attribute(line, "insert"):
                    rows = _parse(line, **fmtparams)
                    for row in rows:
                        if convert_dtypes:
                            converted_row = _convert(
                                row, dtypes, strict=strict_conversion
                            )
                            yield converted_row
                        else:
                            yield row

    def to_csv(self, file_path: PathObject, **fmtparams: Any) -> None:
        """Write Dump object to CSV file.

        The first row written is the list of column names; the rows
        follow unconverted (all values as str).

        :param file_path: The file to write to. Will be created if it
            doesn't already exist. Will be overwritten if it does exist.
        :type file_path: PathObject
        :param fmtparams: Any kwargs you want to pass to the csv.writer()
            function.
        """

        # The csv module requires newline="" on files it writes to; without
        # it, platforms that translate "\n" produce blank rows between records.
        with open(file_path, "w", newline="") as outfile:
            writer = csv.writer(outfile, **fmtparams)
            writer.writerow(self.col_names)
            for row in self:
                writer.writerow(row)

    def head(self, n_lines: int = 10, convert_dtypes: bool = False) -> None:
        """Display first n rows.

        The column names are printed first, then up to `n_lines` rows.

        :param n_lines: Number of rows to display, defaults to 10
        :type n_lines: int, optional
        :param convert_dtypes: Optionally, shows numerical types as int
            or float instead of all str. Defaults to False.
        :type convert_dtypes: bool, optional
        """

        rows = self.rows(convert_dtypes=convert_dtypes)
        print(self.col_names)

        for _ in range(n_lines):
            try:
                print(next(rows))
            except StopIteration:
                # Fewer than n_lines rows in the table.
                return
        return


In [27]:
!ls

capture.ipynb
capture.py
simplewiki-latest-change_tag_def.sql (1).gz
simplewiki-latest-change_tag_def.sql.gz


In [31]:
# Build a Dump from the downloaded file; from_file() reads only the metadata header.
# NOTE(review): this rebinds `d`, which an earlier cell used for the dict of captured lines.
d = Dump.from_file('simplewiki-latest-change_tag_def.sql.gz')

In [32]:
d.dtypes

{'ctd_id': int, 'ctd_name': str, 'ctd_user_defined': int, 'ctd_count': int}

In [33]:
d.sql_dtypes

{'ctd_id': 'int(10) unsigned NOT NULL AUTO_INCREMENT',
 'ctd_name': 'varbinary(255) NOT NULL',
 'ctd_user_defined': 'tinyint(1) NOT NULL',
 'ctd_count': 'bigint(20) unsigned NOT NULL DEFAULT 0'}