Skip to content

Commit

Permalink
Merge pull request #115 from xenova/improvements
Browse files: browse the repository at this point in the history
Improvements
  • Loading branch information
xenova committed Aug 3, 2021
2 parents 705a7dd + e271443 commit 6aa8283
Show file tree
Hide file tree
Showing 17 changed files with 878 additions and 462 deletions.
56 changes: 29 additions & 27 deletions chat_downloader/chat_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def __init__(self,
self.init_params = locals()
self.init_params.pop('self')

log('debug', 'Python version: {}'.format(sys.version))
log('debug', 'Program version: {}'.format(__version__))
log('debug', 'Initialisation parameters: {}'.format(self.init_params))
log('debug', f'Python version: {sys.version}')
log('debug', f'Program version: {__version__}')
log('debug', f'Initialisation parameters: {self.init_params}')

# Track sessions using a dictionary (allows for reusing)
self.sessions = {}
Expand Down Expand Up @@ -213,22 +213,21 @@ def get_chat(self, url=None,
for k, v in original_params.items():
params[k] = site_object.get_site_value(v)

log('info', 'Site: {}'.format(site_object._NAME))
log('debug', 'Program parameters: {}'.format(params))
log('info', f'Site: {site_object._NAME}')
log('debug', f'Program parameters: {params}')

get_chat = getattr(site_object, function_name, None)
if not get_chat:
raise NotImplementedError(
'{} has not been implemented in {}.'.format(function_name, site.__name__))
f'{function_name} has not been implemented in {site.__name__}.')

chat = get_chat(match, params)
log('debug', 'Match found: "{}". Running "{}" function in "{}".'.format(
match, function_name, site.__name__))
log('debug',
f'Match found: "{match}". Running "{function_name}" function in "{site.__name__}".')

if chat is None:
raise ChatGeneratorError(
'No valid generator found in {} for url "{}"'.format(
site.__name__, url))
f'No valid generator found in {site.__name__} for url "{url}"')

if isinstance(params['max_messages'], int):
chat.chat = itertools.islice(
Expand All @@ -246,14 +245,14 @@ def get_chat(self, url=None,
start = time.time()

def log_on_timeout():
log('debug', 'Timeout occurred after {} seconds.'.format(
time.time() - start))
log('debug',
f'Timeout occurred after {time.time() - start} seconds.')
setattr(chat.chat, 'on_timeout', log_on_timeout)

if isinstance(params['inactivity_timeout'], (float, int)):
def log_on_inactivity_timeout():
log('debug', 'Inactivity timeout occurred after {} seconds.'.format(
params['inactivity_timeout']))
log('debug',
f"Inactivity timeout occurred after {params['inactivity_timeout']} seconds.")
setattr(chat.chat, 'on_inactivity_timeout',
log_on_inactivity_timeout)

Expand All @@ -263,39 +262,43 @@ def log_on_inactivity_timeout():

if params['output']:
chat.attach_writer(ContinuousWriter(
params['output'], indent=params['indent'], sort_keys=params['sort_keys'], overwrite=params['overwrite']))
params['output'],
indent=params['indent'],
sort_keys=params['sort_keys'],
overwrite=params['overwrite'],
lazy_initialise=True
))

chat.site = site_object

log('debug', 'Chat information: {}'.format(chat.__dict__))
log('info', 'Retrieving chat for "{}".'.format(chat.title))
log('debug', f'Chat information: {chat.__dict__}')
log('info', f'Retrieving chat for "{chat.title}".')

return chat

parsed = urlparse(url)
log('debug', str(parsed))

if parsed.netloc:
raise SiteNotSupported(
'Site not supported: {}'.format(parsed.netloc))
raise SiteNotSupported(f'Site not supported: {parsed.netloc}')
elif not parsed.scheme: # No scheme, try to correct
original_params['url'] = 'https://' + url
chat = self.get_chat(**original_params)
if chat:
return chat
else:
raise InvalidURL('Invalid URL: "{}"'.format(url))
raise InvalidURL(f'Invalid URL: "{url}"')

def create_session(self, chat_downloader_class, overwrite=False):
if not issubclass(chat_downloader_class, BaseChatDownloader):
raise TypeError('Unable to create session, class must extend BaseChatDownloader. Class given: {}'.format(
chat_downloader_class))
raise TypeError(
f'Unable to create session, class must extend BaseChatDownloader. Class given: {chat_downloader_class}')
elif chat_downloader_class == BaseChatDownloader:
raise TypeError(
'Unable to create session, class may not be BaseChatDownloader.')

session_name = chat_downloader_class.__name__
log('debug', 'Created {} session.'.format(session_name))
log('debug', f'Created {session_name} session.')

if session_name not in self.sessions or overwrite:
self.sessions[session_name] = chat_downloader_class(
Expand Down Expand Up @@ -357,22 +360,21 @@ def callback(item):
for message in chat:
callback(message)

log('info', 'Finished retrieving chat{}.'.format(
'' if chat.is_live else ' replay'))
log('info', 'Finished retrieving chat messages.')

except (
ChatGeneratorError,
ParsingError,
TestingException
) as e: # Errors which may be bugs
log('error', '{}. Please report this at https://github.com/xenova/chat-downloader/issues/new/choose'.format(e))
log('error', f'{e}. Please report this at https://github.com/xenova/chat-downloader/issues/new/choose')

except ChatDownloaderError as e: # Expected errors
log('error', e)

except ConnectionError as e:
log(
'error', 'Unable to establish a connection. Please check your internet connection. {}'.format(e))
'error', f'Unable to establish a connection. Please check your internet connection. {e}')

except RequestException as e:
log('error', e)
Expand Down
4 changes: 2 additions & 2 deletions chat_downloader/formatting/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, path=None):
if path is not None:
if not os.path.exists(path):
raise FormatFileNotFound(
'Format file not found: "{}"'.format(path))
f'Format file not found: "{path}"')

with open(path) as custom_formats:
self.format_file.update(json.load(custom_formats))
Expand Down Expand Up @@ -128,7 +128,7 @@ def format(self, item, format_name='default', format_object=None):
if not format_object:
if format_name != 'default':
raise FormatNotFound(
'Format not found: "{}"'.format(format_name))
f'Format not found: "{format_name}"')
else:
format_object = default_format_object # Set to default

Expand Down
2 changes: 1 addition & 1 deletion chat_downloader/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
__email__ = 'admin@xenova.com'
__copyright__ = '2020, 2021 xenova'
__url__ = 'https://github.com/xenova/chat-downloader'
__version__ = '0.1.7'
__version__ = '0.1.8'
137 changes: 84 additions & 53 deletions chat_downloader/output/continuous_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
class CW:
"""
Base class for continuous file writers.
Can be used as a context manager (using the `with` keyword).
Otherwise, the writer can be explicitly closed.
"""

def __init__(self, file_name, overwrite=True):
def __init__(self, file_name, overwrite=True, **kwargs):
"""Create a CW object.
:param file_name: The name of the file to write to
Expand All @@ -23,24 +20,11 @@ def __init__(self, file_name, overwrite=True):
:type overwrite: bool, optional
"""
self.file_name = file_name
# subclasses must set self.file

if not os.path.exists(file_name) or overwrite:
directory = os.path.dirname(file_name)
if directory: # (non-empty directory - i.e. not in current folder)
# must make parent directory
os.makedirs(directory, exist_ok=True)
open(file_name, 'w').close() # create an empty file

def __enter__(self):
return self
self.overwrite = overwrite

def close(self):
self.file.close()

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def write(self, item, flush=False):
"""Write a chat item to the file. This method should be implemented in subclasses.
Expand All @@ -63,15 +47,19 @@ class JSONCW(CW):
Class used to control the continuous writing of a list of dictionaries to a JSON file.
"""

def __init__(self, file_name, overwrite=True, indent=None, separator=', ', indent_character=' ', sort_keys=True):
super().__init__(file_name, overwrite)
def __init__(self, file_name, indent=None, separator=', ', indent_character=' ', sort_keys=True, **kwargs):
super().__init__(file_name, **kwargs)

self.indent = indent
self.separator = separator
self.indent_character = indent_character
self.sort_keys = sort_keys

# open file for appending and reading in binary mode.
self.file = open(self.file_name, 'rb+')

# self.file.seek(0) # go to beginning of file

previous_items = [] # save previous
if not overwrite: # may have other data
if not self.overwrite: # may have other data
try:
previous_items = json.load(self.file)
except json.decoder.JSONDecodeError:
Expand All @@ -80,11 +68,6 @@ def __init__(self, file_name, overwrite=True, indent=None, separator=', ', inden

self.file.truncate(0) # empty file

self.indent = indent
self.separator = separator
self.indent_character = indent_character
self.sort_keys = sort_keys

# rewrite with new formatting
for previous_item in previous_items:
self.write(previous_item)
Expand All @@ -111,13 +94,10 @@ def write(self, item, flush=False):
# If empty, write the start of an array
self.file.write('['.encode())
else:
# print(self.file.closed)
# seek to last character
self.file.seek(-len(indent_padding) - 1, os.SEEK_END)
self.file.write(self.separator.encode()) # Write the separator

# self.file.truncate()

self.file.write(to_write.encode()) # Dump the item
self.file.write((indent_padding + ']').encode()) # Close the array

Expand All @@ -130,13 +110,12 @@ class CSVCW(CW):
Class used to control the continuous writing of a list of dictionaries to a CSV file.
"""

def __init__(self, file_name, overwrite=True, sort_keys=True):
super().__init__(file_name, overwrite)

self.file = open(self.file_name, 'a+', newline='',
encoding='utf-8') # , buffering=1
def __init__(self, file_name, sort_keys=True, **kwargs):
super().__init__(file_name, **kwargs)
self.sort_keys = sort_keys
self.file = open(self.file_name, 'a+', newline='', encoding='utf-8')

if not overwrite:
if not self.overwrite:
# save previous data
self.file.seek(0) # go to beginning of file
csv_dict_reader = csv.DictReader(self.file)
Expand All @@ -147,7 +126,6 @@ def __init__(self, file_name, overwrite=True, sort_keys=True):
self.all_items = []

self._reset_dict_writer()
self.sort_keys = sort_keys

def _reset_dict_writer(self):
self.csv_dict_writer = csv.DictWriter(
Expand Down Expand Up @@ -182,10 +160,10 @@ class JSONLCW(CW):
Class used to control the continuous writing of a JSON lines.
"""

def __init__(self, file_name, overwrite=True, sort_keys=True):
super().__init__(file_name, overwrite)
self.file = open(self.file_name, 'a', encoding='utf-8')
def __init__(self, file_name, sort_keys=True, **kwargs):
super().__init__(file_name, **kwargs)
self.sort_keys = sort_keys
self.file = open(self.file_name, 'a', encoding='utf-8')

def write(self, item, flush=False):
print(json.dumps(item, sort_keys=self.sort_keys),
Expand All @@ -197,13 +175,12 @@ class TXTCW(CW):
Class used to control the continuous writing of a text to a TXT file.
"""

def __init__(self, file_name, overwrite=True):
super().__init__(file_name, overwrite)
self.file = open(self.file_name, 'a',
encoding='utf-8') # , buffering=1
def __init__(self, file_name, **kwargs):
super().__init__(file_name, **kwargs)
self.file = open(self.file_name, 'a', encoding='utf-8')

def write(self, item, flush=False):
print(item, file=self.file, flush=flush) # , flush=True
print(item, file=self.file, flush=flush)


class ContinuousWriter:
Expand All @@ -214,19 +191,73 @@ class ContinuousWriter:
'txt': TXTCW
}

def __init__(self, file_name, **kwargs):
extension = os.path.splitext(file_name)[1][1:].lower()
writer_class = self._SUPPORTED_WRITERS.get(extension, TXTCW)
def __init__(self, file_name=None, overwrite=True, format=None, lazy_initialise=False, **kwargs):
"""Create a ContinuousWriter object.
:param file_name: The name of the file to write to
:type file_name: str
:param overwrite: Whether to overwrite if the file already exists, defaults to True
:type overwrite: bool, optional
:param format: The output format, defaults to None (use the extension to decide)
:type format: str, optional
:param lazy_initialise: Skip file creation on initialisation, defaults to False.
:type lazy_initialise: bool, optional
"""
super().__setattr__('data', dict())
self.file_name = file_name
self.overwrite = overwrite
self.format = format
self.lazy_initialise = lazy_initialise

self.data.update(kwargs)

# remove invalid keyword arguments
new_kwargs = {
key: kwargs[key] for key in kwargs if key in writer_class.__init__.__code__.co_varnames}
self.writer = writer_class(file_name, **new_kwargs)
self._initialised = False
if not self.lazy_initialise:
self._real_init()

self.writer=None
def __getattr__(self, name):
if name in self.data:
return self.data[name]

raise AttributeError(
f"'ContinuousWriter' object has no attribute '{name}'")

def __setattr__(self, key, value):
self.data[key] = value

def is_default(self):
return isinstance(self.writer, TXTCW)

def is_initialised(self):
return self._initialised

def _real_init(self):
if self._initialised:
return

self._initialised = True

if self.file_name is None:
raise AttributeError('File name not set')

if not os.path.exists(self.file_name) or self.overwrite:
directory = os.path.dirname(self.file_name)
if directory: # (non-empty directory - i.e. not in current folder)
# must make parent directory
os.makedirs(directory, exist_ok=True)
open(self.file_name, 'w').close() # create an empty file

extension = self.format or os.path.splitext(self.file_name)[
1][1:].lower()
writer_class = ContinuousWriter._SUPPORTED_WRITERS.get(
extension, TXTCW)
self.writer = writer_class(**self.data)

def write(self, item, flush=False):
if not self._initialised: # create file when first item is written
self._real_init()

self.writer.write(item, flush)

def __enter__(self):
Expand Down

0 comments on commit 6aa8283

Please sign in to comment.