In [None]:
#default_exp service.db

# Service Db

> Handles SQLite database access and schema initialization.

In [None]:
#export
import sqlite3
from pathlib import Path

In [None]:
import tempfile,shutil

## DB init script

The schema is embedded here as a string so we don't have to locate a `schema.sql` file in different environments.

Running this script will
- Drop the `user` and `post` tables if they exist, deleting all of their data.
- Recreate both tables, empty.

In [None]:
#export
# Schema for the journal database; executing it DESTROYS any existing `user`/`post` data.
# `post` is dropped before `user` because post.author_id references user(id): with
# `PRAGMA foreign_keys=ON`, dropping the parent table while child rows still reference
# it fails with a FOREIGN KEY constraint error.
SCHEMA_SQL="""
DROP TABLE IF EXISTS post;
DROP TABLE IF EXISTS user;

CREATE TABLE user (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  username TEXT UNIQUE NOT NULL,
  password TEXT NOT NULL
);

CREATE TABLE post (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  author_id INTEGER NOT NULL,
  created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
  title TEXT NOT NULL,
  body TEXT NOT NULL,
  is_deleted INTEGER NOT NULL DEFAULT 0,
  FOREIGN KEY (author_id) REFERENCES user (id)
);
"""

In [None]:
#export
class ServiceDb:
    "SQLite-backed store for users and posts; tables are created by `SCHEMA_SQL`."

    # Shared SELECT/FROM for post reads: `created` is rendered as a 'YYYY-MM-DD HH:MM:SS'
    # string and `username` comes from the joined author row.
    _POST_SELECT_SQL = (
        "SELECT p.id, title, body, strftime('%Y-%m-%d %H:%M:%S',created) created, author_id, username, is_deleted"
        "  FROM post p JOIN user u ON p.author_id = u.id"
    )

    def __init__(self,database_file):
        """Open a connection to `database_file` (a filesystem path or ':memory:').

        `PARSE_DECLTYPES` enables sqlite3's declared-type converters, and `sqlite3.Row`
        makes result rows indexable by column name as well as position.
        """
        self.db = sqlite3.connect(
            database_file, detect_types=sqlite3.PARSE_DECLTYPES
        )
        self.db.row_factory = sqlite3.Row

    def read_user_by_id(self,id):
        "Return the `user` row with this `id`, or None if there is none."
        return self.db.execute(
            "SELECT * FROM user WHERE id = ?", (id,)
        ).fetchone()

    def read_user_by_username(self,username):
        "Return the `user` row with this `username`, or None if there is none."
        return self.db.execute(
            "SELECT * FROM user WHERE username = ?", (username,)
        ).fetchone()

    def create_user(self,username,password):
        "Insert a user, commit, and return the new row id. `password` is stored as given (no hashing here)."
        cursor = self.db.execute(
            "INSERT INTO user (username, password) VALUES (?, ?)",
            (username, password),
        )
        self.db.commit()
        return cursor.lastrowid

    def read_posts_by_author_id(self,author_id):
        """Return the author's non-deleted posts, newest first.

        `created` only has one-second resolution, so ties are broken by descending id
        to keep the order deterministic.
        """
        return self.db.execute(
            self._POST_SELECT_SQL +
            " WHERE p.author_id = ?"
            "   AND p.is_deleted = 0"
            " ORDER BY created DESC, p.id DESC",
            (author_id,)
        ).fetchall()

    def read_post_by_id(self,author_id,id):
        "Return post `id` if it belongs to `author_id` (soft-deleted posts included), else None."
        return self.db.execute(
            self._POST_SELECT_SQL +
            " WHERE p.author_id = ?"
            "   AND p.id = ?",
            (author_id,id),
        ).fetchone()

    def create_post(self,author_id,title,body):
        "Insert a post for `author_id`, commit, and return the new row id."
        cursor = self.db.execute(
            "INSERT INTO post (author_id, title, body) VALUES (?, ?, ?)",
            (author_id, title, body),
        )
        self.db.commit()
        return cursor.lastrowid

    def update_post_by_id(self,author_id,id,title,body):
        "Set title/body of post `id`; only a post owned by `author_id` is changed."
        # Filtering on author_id stops one author from editing another author's post.
        self.db.execute(
            "UPDATE post SET title = ?, body = ? WHERE id = ? AND author_id = ?",
            (title, body, id, author_id),
        )
        self.db.commit()

    def delete_post_by_id(self,author_id,id):
        "Soft-delete post `id` (sets `is_deleted = 1`); only a post owned by `author_id` is affected."
        # Filtering on author_id stops one author from deleting another author's post.
        self.db.execute(
            "UPDATE post SET is_deleted = 1 WHERE id = ? AND author_id = ?",
            (id, author_id),
        )
        self.db.commit()

    def prepare_posts_file_by_author_id(self,author_id):
        "Not implemented for the DB service yet; always returns `(None, None)`."
        # TODO: implement for DB service
        return None,None

In [None]:
def _compare(expected,actual):
    "Check values match for all keys in `expected`, which might not be all keys in `actual`"
    for k in expected.keys(): 
        try:
            if expected[k]!=actual[k]: print(k,expected[k],actual[k])
        except Exception as ex:
            print('failed to check',k,ex)
            print('actual.keys()',actual.keys())
        assert expected[k]==actual[k]

# End-to-end check of ServiceDb against a throwaway database file.
# The directory is created first and the connection second, so if ServiceDb(...) raises
# there is no unbound `service` in a `finally` clause to mask the real error.
with tempfile.TemporaryDirectory() as temp_path:
    service=ServiceDb(Path(temp_path)/'web_journal_test.sqlite')
    try:
        # init
        service.db.executescript(SCHEMA_SQL)
        # user section
        assert service.read_user_by_id(1234) is None
        assert service.read_user_by_username('test.user') is None
        user_id=service.create_user('test.user','badPassword')
        expected_user=dict(id=user_id,username='test.user',password='badPassword')
        actual_user=service.read_user_by_username('test.user')
        _compare(expected_user,actual_user)
        assert actual_user==service.read_user_by_id(user_id)
        _compare(expected_user,service.read_user_by_id(user_id))
        # post section
        assert service.read_posts_by_author_id(123)==[]
        assert service.read_posts_by_author_id(user_id)==[]
        assert service.read_post_by_id(user_id,123) is None
        for i in range(3): service.create_post(user_id,f'title{i}','body')
        post_id=service.create_post(user_id,'title','body')
        for i in range(3): service.create_post(user_id,f'title{i}2','body')
        # don't add `created` to `expected_post` as we don't know what its value will be
        expected_post=dict(id=post_id,author_id=user_id,title='title',body='body',username='test.user',is_deleted=0)
        posts=service.read_posts_by_author_id(user_id)
        assert len(posts)==7
        # posts are returned newest first
        created=[p['created'] for p in posts]
        assert created==sorted(created,reverse=True)
        # look the post up by id: rows created within the same second may tie on `created`
        listed_post=next(p for p in posts if p['id']==post_id)
        _compare(expected_post,listed_post)
        assert isinstance(listed_post['created'],str)
        post=service.read_post_by_id(user_id,post_id)
        assert post==listed_post
        assert post['is_deleted']==0
        service.delete_post_by_id(user_id,post_id)
        expected_post['is_deleted']=1
        # deleted posts are readable by ID ...
        _compare(expected_post,service.read_post_by_id(user_id,post_id))
        # ... but are not returned when reading all posts by author
        assert len(service.read_posts_by_author_id(user_id))==6
    finally:
        # Close before the temporary directory is removed; an open handle can block deletion on some platforms.
        service.db.close()

In [None]:
#export
def before_request(app):
    "Return a new `ServiceDb` connected to `web_journal.sqlite` inside the app's `DATA_DIR`."
    data_dir = Path(app.config['DATA_DIR'])
    return ServiceDb(data_dir / 'web_journal.sqlite')

In [None]:
#export
def after_request(app,service):
    """Close `service`'s SQLite connection once a request has finished.

    `app` is accepted to mirror `before_request(app)` but is not used.
    """
    connection = service.db
    connection.close()

In [None]:
#export
def init_service(app):
    """Create the schema in `DATA_DIR/web_journal.sqlite` if the `user` table does not exist yet.

    `SCHEMA_SQL` drops existing tables, so it is only run when the probe query fails with
    `sqlite3.OperationalError` (e.g. "no such table: user"); other errors propagate.
    The probe connection is always closed before returning.
    """
    print('service.db.init_service')
    service=ServiceDb(Path(app.config['DATA_DIR'])/'web_journal.sqlite')
    try:
        # This raises sqlite3.OperationalError if the user table has not been created ...
        service.read_user_by_id(0)
    except sqlite3.OperationalError:
        # ... so we know we need to create the schema
        service.db.executescript(SCHEMA_SQL)
    finally:
        # This connection is local to init_service; release it instead of leaking it.
        service.db.close()

In [None]:
# Regenerate the library modules from the notebooks' `#export` cells with nbdev
# (the output below lists each notebook that was converted).
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 40a_service_db.ipynb.
Converted 40b_service_filesystem.ipynb.
Converted 50_web_app.ipynb.
Converted 50b_web_auth.ipynb.
Converted 50c_web_blog.ipynb.
Converted index.ipynb.
