Skip to content

Commit

Permalink
Rest
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Mar 21, 2024
1 parent 2271cda commit 268c6c9
Show file tree
Hide file tree
Showing 46 changed files with 864 additions and 1,591 deletions.
45 changes: 45 additions & 0 deletions examples/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Json
from sqlalchemy import Column, String, JSON
from sqlalchemy.orm import declarative_base

from zvt.contract.api import get_db_session
from zvt.contract.register import register_schema
from zvt.contract.schema import Mixin

ZvtInfoBase = declarative_base()


class User(Mixin, ZvtInfoBase):
__tablename__ = "user"
added_col = Column(String)
json_col = Column(JSON)


class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: str
entity_id: str
timestamp: datetime
added_col: str
json_col: Dict


register_schema(providers=["zvt"], db_name="test", schema_base=ZvtInfoBase)

if __name__ == "__main__":
user_model = UserModel(
id="user_cn_jack_2020-01-01",
entity_id="user_cn_jack",
timestamp="2020-01-01",
added_col="test",
json_col={"a": 1},
)
session = get_db_session(provider="zvt", data_schema=User)

user = session.query(User).filter(User.id == "user_cn_jack_2020-01-01").first()
print(UserModel.validate(user))
2 changes: 2 additions & 0 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import pprint
import time

import eastmoneypy
import pandas as pd
Expand Down Expand Up @@ -30,6 +31,7 @@ def add_to_eastmoney(codes, group, entity_type="stock", over_write=True):
pass

for code in codes:
time.sleep(0.2)
eastmoneypy.add_to_group(code=code, entity_type=entity_type, group_name=group)


Expand Down
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
requests==2.31.0
SQLAlchemy==2.0.28
pandas==2.0.3
pydantic==2.6.4
arrow==1.2.3
openpyxl==3.1.1
demjson3==3.0.6
marshmallow-sqlalchemy==1.0.0
marshmallow==3.21.1
plotly==5.13.0
dash==2.8.1
jqdatapy==0.1.8
dash-bootstrap-components==1.3.1
dash_daq==0.5.0
scikit-learn==1.2.1
scikit-learn==1.2.1
fastapi==0.110.0
File renamed without changes.
23 changes: 23 additions & 0 deletions scripts/report_stock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
import json

from examples.data_runner.kdata_runner import record_stock_data, record_stock_news, record_stock_events
from examples.reports.report_tops import report_top_stocks, report_top_blocks
from examples.reports.report_vol_up import report_vol_up_stocks
from examples.utils import get_hot_topics
from zvt import zvt_config
from zvt.factors.top_stocks import compute_top_stocks
from zvt.informer import EmailInformer
from zvt.utils import current_date

if __name__ == "__main__":

# record_stock_news()

record_stock_data()
# record_stock_events()
compute_top_stocks()

report_top_stocks()
# report_top_blocks()
report_vol_up_stocks()
10 changes: 10 additions & 0 deletions scripts/report_stockhk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from examples.data_runner.kdata_runner import record_stockhk_data
from examples.reports.report_tops import report_top_stockhks
from examples.reports.report_vol_up import report_vol_up_stockhks

if __name__ == "__main__":
record_stockhk_data()

report_top_stockhks()
report_vol_up_stockhks()
2 changes: 1 addition & 1 deletion src/zvt/api/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_entity_ids_by_filter(
entity_ids=None,
ignore_bj=False,
):
filters = []
filters = [entity_schema.timestamp.isnot(None)]
if not target_date:
target_date = current_date()
if ignore_new_stock:
Expand Down
17 changes: 16 additions & 1 deletion src/zvt/contract/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import json
import logging
import os
import platform
Expand Down Expand Up @@ -58,11 +59,23 @@ def get_db_engine(
engine_key = "{}_{}".format(provider, db_name)
db_engine = zvt_context.db_engine_map.get(engine_key)
if not db_engine:
db_engine = create_engine("sqlite:///" + db_path, echo=False)
db_engine = create_engine(
"sqlite:///" + db_path, echo=False, json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False)
)
zvt_context.db_engine_map[engine_key] = db_engine
return db_engine


def get_providers() -> List[str]:
return zvt_context.providers


def get_tradable_entity_types(provider=None) -> List[str]:
if provider:
return zvt_context.provider_map_tradable_entity_types.get(provider)
return zvt_context.tradable_entity_types


def get_schemas(provider: str) -> List[DeclarativeMeta]:
"""
get domain schemas supported by the provider
Expand Down Expand Up @@ -664,6 +677,8 @@ def get_entity_ids(
__all__ = [
"_get_db_name",
"get_db_engine",
"get_providers",
"get_tradable_entity_types",
"get_schemas",
"get_db_session",
"get_db_session_factory",
Expand Down
3 changes: 3 additions & 0 deletions src/zvt/contract/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def __init__(self) -> None:
#: all registered entity types(str)
self.tradable_entity_types = []

#: provider -> tradable entity types
self.provider_map_tradable_entity_types = {}

#: all entity schemas
self.tradable_entity_schemas = []

Expand Down
12 changes: 12 additions & 0 deletions src/zvt/contract/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from datetime import datetime

from pydantic import BaseModel, ConfigDict


class MixinModel(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: str
entity_id: str
timestamp: datetime
17 changes: 16 additions & 1 deletion src/zvt/contract/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,31 @@ def register_schema(
for item in schema_base.registry.mappers:
cls = item.class_
if type(cls) == DeclarativeMeta:
# register provider to the schema
for provider in providers:
# register provider to the schema
if issubclass(cls, Mixin):
cls.register_provider(provider)
# register tradable entity type
if issubclass(cls, TradableEntity):
if not entity_type:
print("register TradableEntity must set entity_type")
assert False

if not zvt_context.provider_map_tradable_entity_types.get(provider):
zvt_context.provider_map_tradable_entity_types[provider] = []
zvt_context.provider_map_tradable_entity_types[provider].append(entity_type)

if entity_type not in zvt_context.tradable_entity_types:
zvt_context.tradable_entity_types.append(entity_type)
zvt_context.tradable_entity_schemas.append(cls)
zvt_context.tradable_schema_map[entity_type] = cls

if zvt_context.dbname_map_schemas.get(db_name):
schemas = zvt_context.dbname_map_schemas[db_name]
zvt_context.schemas.append(cls)
if entity_type:
add_to_map_list(the_map=zvt_context.entity_map_schemas, key=entity_type, value=cls)

schemas.append(cls)

zvt_context.dbname_map_schemas[db_name] = schemas
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/contract/zvt_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RecorderState(ZvtInfoBase, StateMixin):
Schema for storing recorder state
"""

__tablename__ = "recoder_state"
__tablename__ = "recorder_state"


class TaggerState(ZvtInfoBase, StateMixin):
Expand Down
8 changes: 4 additions & 4 deletions src/zvt/domain/meta/block_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from sqlalchemy.orm import declarative_base

from zvt.contract import Portfolio, PortfolioStock
from zvt.contract.register import register_schema, register_entity
from zvt.contract.register import register_schema

BlockMetaBase = declarative_base()


#: 板块
@register_entity(entity_type="block")
class Block(BlockMetaBase, Portfolio):
__tablename__ = "block"

Expand All @@ -22,6 +20,8 @@ class BlockStock(BlockMetaBase, PortfolioStock):
__tablename__ = "block_stock"


register_schema(providers=["em", "eastmoney", "sina"], db_name="block_meta", schema_base=BlockMetaBase)
register_schema(
providers=["em", "eastmoney", "sina"], db_name="block_meta", schema_base=BlockMetaBase, entity_type="block"
)
# the __all__ is generated
__all__ = ["Block", "BlockStock"]
5 changes: 2 additions & 3 deletions src/zvt/domain/meta/country_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from sqlalchemy import Column, String, Float
from sqlalchemy.orm import declarative_base

from zvt.contract.register import register_schema
from zvt.contract.schema import TradableEntity
from zvt.contract.register import register_schema, register_entity

CountryMetaBase = declarative_base()


@register_entity(entity_type="country")
class Country(CountryMetaBase, TradableEntity):
__tablename__ = "country"

Expand All @@ -33,6 +32,6 @@ class Country(CountryMetaBase, TradableEntity):
latitude = Column(Float)


register_schema(providers=["wb"], db_name="country_meta", schema_base=CountryMetaBase)
register_schema(providers=["wb"], db_name="country_meta", schema_base=CountryMetaBase, entity_type="country")
# the __all__ is generated
__all__ = ["Country"]
5 changes: 2 additions & 3 deletions src/zvt/domain/meta/currency_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

from sqlalchemy.orm import declarative_base

from zvt.contract.register import register_schema, register_entity
from zvt.contract.register import register_schema
from zvt.contract.schema import TradableEntity

CurrencyMetaBase = declarative_base()


@register_entity(entity_type="currency")
class Currency(CurrencyMetaBase, TradableEntity):
__tablename__ = "currency"


register_schema(providers=["em"], db_name="currency_meta", schema_base=CurrencyMetaBase)
register_schema(providers=["em"], db_name="currency_meta", schema_base=CurrencyMetaBase, entity_type="currency")
# the __all__ is generated
__all__ = ["Currency"]
6 changes: 2 additions & 4 deletions src/zvt/domain/meta/etf_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from sqlalchemy.orm import declarative_base

from zvt.contract import Portfolio, PortfolioStockHistory
from zvt.contract.register import register_schema, register_entity
from zvt.contract.register import register_schema
from zvt.utils.time_utils import now_pd_timestamp

EtfMetaBase = declarative_base()


#: etf
@register_entity(entity_type="etf")
class Etf(EtfMetaBase, Portfolio):
__tablename__ = "etf"
category = Column(String(length=64))
Expand All @@ -27,6 +25,6 @@ class EtfStock(EtfMetaBase, PortfolioStockHistory):
__tablename__ = "etf_stock"


register_schema(providers=["exchange", "joinquant"], db_name="etf_meta", schema_base=EtfMetaBase)
register_schema(providers=["exchange", "joinquant"], db_name="etf_meta", schema_base=EtfMetaBase, entity_type="etf")
# the __all__ is generated
__all__ = ["Etf", "EtfStock"]
6 changes: 2 additions & 4 deletions src/zvt/domain/meta/fund_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from sqlalchemy.orm import declarative_base

from zvt.contract import Portfolio, PortfolioStockHistory
from zvt.contract.register import register_entity, register_schema
from zvt.contract.register import register_schema
from zvt.utils import now_pd_timestamp

FundMetaBase = declarative_base()


#: 个股
@register_entity(entity_type="fund")
class Fund(FundMetaBase, Portfolio):
__tablename__ = "fund"
#: 基金管理人
Expand Down Expand Up @@ -54,6 +52,6 @@ class FundStock(FundMetaBase, PortfolioStockHistory):
__tablename__ = "fund_stock"


register_schema(providers=["joinquant"], db_name="fund_meta", schema_base=FundMetaBase)
register_schema(providers=["joinquant"], db_name="fund_meta", schema_base=FundMetaBase, entity_type="fund")
# the __all__ is generated
__all__ = ["Fund", "FundStock"]
5 changes: 2 additions & 3 deletions src/zvt/domain/meta/future_meta.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
from sqlalchemy.orm import declarative_base

from zvt.contract.register import register_schema, register_entity
from zvt.contract.register import register_schema
from zvt.contract.schema import TradableEntity

FutureMetaBase = declarative_base()


@register_entity(entity_type="future")
class Future(FutureMetaBase, TradableEntity):
__tablename__ = "future"


register_schema(providers=["em"], db_name="future_meta", schema_base=FutureMetaBase)
register_schema(providers=["em"], db_name="future_meta", schema_base=FutureMetaBase, entity_type="future")

# the __all__ is generated
__all__ = ["Future"]
6 changes: 2 additions & 4 deletions src/zvt/domain/meta/index_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from sqlalchemy.orm import declarative_base

from zvt.contract import Portfolio, PortfolioStockHistory
from zvt.contract.register import register_schema, register_entity
from zvt.contract.register import register_schema

IndexMetaBase = declarative_base()


#: 指数
@register_entity(entity_type="index")
class Index(IndexMetaBase, Portfolio):
__tablename__ = "index"

Expand All @@ -27,6 +25,6 @@ class IndexStock(IndexMetaBase, PortfolioStockHistory):
__tablename__ = "index_stock"


register_schema(providers=["em", "exchange"], db_name="index_meta", schema_base=IndexMetaBase)
register_schema(providers=["em", "exchange"], db_name="index_meta", schema_base=IndexMetaBase, entity_type="index")
# the __all__ is generated
__all__ = ["Index", "IndexStock"]

0 comments on commit 268c6c9

Please sign in to comment.