Merge branch '9302-2.2' into 9302-3.0
Conflicts:
	bin/cqlsh.py
	pylib/cqlshlib/copy.py
Stefania Alborghetti committed Nov 9, 2015
2 parents 44e165c + 3bae239 commit 5e3fa29
Showing 3 changed files with 683 additions and 333 deletions.
308 changes: 31 additions & 277 deletions bin/cqlsh.py
@@ -37,7 +37,6 @@
import csv
import getpass
import locale
import multiprocessing as mp
import optparse
import os
import platform
@@ -427,7 +426,8 @@ def complete_copy_column_names(ctxt, cqlsh):


COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING',
'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']
'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS',
'CHUNKSIZE', 'PERCENTCOMPLETED', 'MAXBATCHSIZE', 'MINBATCHSIZE', 'REPORTFREQUENCY']


@cqlsh_syntax_completer('copyOption', 'optnames')
@@ -436,7 +436,9 @@ def complete_copy_options(ctxt, cqlsh):
direction = ctxt.get_binding('dir').upper()
opts = set(COPY_OPTIONS) - set(optnames)
if direction == 'FROM':
opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'])
opts -= set(['ENCODING', 'TIMEFORMAT', 'PAGESIZE', 'PAGETIMEOUT'])
elif direction == 'TO':
opts -= set(['CHUNKSIZE', 'PERCENTCOMPLETED', 'MAXBATCHSIZE', 'MINBATCHSIZE'])
return opts


@@ -576,6 +578,19 @@ def extend_cql_deserialization():
cassandra.cqltypes.CassandraType.support_empty_values = True


def insert_driver_hooks():
extend_cql_deserialization()
auto_format_udts()


def extend_cql_deserialization():
"""
The python driver returns BLOBs as strings, but we expect them as bytearrays
"""
cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
cassandra.cqltypes.CassandraType.support_empty_values = True
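
As a minimal standalone sketch of what the hook above does (hypothetical function name, not part of the driver or of this commit): the patched deserializer ignores the protocol version and wraps the raw blob bytes in a mutable bytearray, which is the form the cqlsh formatters expect.

def deserialize_blob(byts, protocol_version):
    # mirrors the lambda above: protocol_version is unused, and the raw bytes
    # are wrapped in a bytearray instead of being returned as a str
    return bytearray(byts)

print deserialize_blob(b'\xca\xfe', 3)  # bytearray(b'\xca\xfe')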


def auto_format_udts():
# when we see a new user defined type, set up the shell formatting for it
udt_apply_params = cassandra.cqltypes.UserType.apply_parameters
@@ -1753,9 +1768,18 @@ def do_copy(self, parsed):
ESCAPE='\' - character to appear before the QUOTE char when quoted
HEADER=false - whether to ignore the first line
NULL='' - string that represents a null value
ENCODING='utf8' - encoding for CSV output (COPY TO only)
TIMEFORMAT= - timestamp strftime format (COPY TO only)
ENCODING='utf8' - encoding for CSV output (COPY TO)
TIMEFORMAT= - timestamp strftime format (COPY TO)
'%Y-%m-%d %H:%M:%S%z' defaults to time_format value in cqlshrc
JOBS='12' - the number of jobs each process can work on at a time
PAGESIZE='1000' - the page size for fetching results (COPY TO)
PAGETIMEOUT=10 - the page timeout for fetching results (COPY TO)
MAXATTEMPTS='5' - the maximum number of attempts for errors
CHUNKSIZE='2000' - the size of chunks passed to worker processes (COPY FROM)
PERCENTCOMPLETED='0.1' - the percentage of records already imported before we send more (COPY FROM)
MAXBATCHSIZE='20' - the maximum size of an import batch (COPY FROM)
MINBATCHSIZE='2' - the minimum size of an import batch (COPY FROM)
REPORTFREQUENCY='10000' - the frequency with which we display status updates
When entering CSV data on STDIN, you can use the sequence "\."
on a line by itself to end the data input.
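
For example (hypothetical keyspace, table and file names), a COPY FROM that combines several of these options might look like:

COPY myks.mytable (id, name, value) FROM 'data.csv' WITH HEADER=true AND CHUNKSIZE='5000' AND MAXBATCHSIZE='50' AND REPORTFREQUENCY='25000';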
@@ -1801,109 +1825,9 @@ def perform_csv_import(self, ks, cf, columns, fname, opts):
self.printerr('Unrecognized COPY FROM options: %s'
% ', '.join(unrecognized_options.keys()))
return 0
nullval, header = csv_options['nullval'], csv_options['header']

if fname is None:
do_close = False
print "[Use \. on a line by itself to end input]"
linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
else:
do_close = True
try:
linesource = open(fname, 'rb')
except IOError, e:
self.printerr("Can't open %r for reading: %s" % (fname, e))
return 0

current_record = None
processes, pipes = [], [],
try:
if header:
linesource.next()
reader = csv.reader(linesource, **dialect_options)

num_processes = copy.get_num_processes(cap=4)

for i in range(num_processes):
parent_conn, child_conn = mp.Pipe()
pipes.append(parent_conn)
processes.append(ImportProcess(self, child_conn, ks, cf, columns, nullval))

for process in processes:
process.start()

meter = copy.RateMeter(10000)
for current_record, row in enumerate(reader, start=1):
# write to the child process
pipes[current_record % num_processes].send((current_record, row))

# update the progress and current rate periodically
meter.increment()

# check for any errors reported by the children
if (current_record % 100) == 0:
if self._check_import_processes(current_record, pipes):
# no errors seen, continue with outer loop
continue
else:
# errors seen, break out of outer loop
break
except Exception, exc:
if current_record is None:
# we failed before we started
self.printerr("\nError starting import process:\n")
self.printerr(str(exc))
if self.debug:
traceback.print_exc()
else:
self.printerr("\n" + str(exc))
self.printerr("\nAborting import at record #%d. "
"Previously inserted records and some records after "
"this number may be present."
% (current_record,))
if self.debug:
traceback.print_exc()
finally:
# send a message that indicates we're done
for pipe in pipes:
pipe.send((None, None))

for process in processes:
process.join()

self._check_import_processes(current_record, pipes)

for pipe in pipes:
pipe.close()

if do_close:
linesource.close()
elif self.tty:
print

return current_record

def _check_import_processes(self, current_record, pipes):
for pipe in pipes:
if pipe.poll():
try:
(record_num, error) = pipe.recv()
self.printerr("\n" + str(error))
self.printerr(
"Aborting import at record #%d. "
"Previously inserted records are still present, "
"and some records after that may be present as well."
% (record_num,))
return False
except EOFError:
# pipe is closed, nothing to read
self.printerr("\nChild process died without notification, "
"aborting import at record #%d. Previously "
"inserted records are probably still present, "
"and some records after that may be present "
"as well." % (current_record,))
return False
return True
return copy.ImportTask(self, ks, cf, columns, fname, csv_options, dialect_options,
DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()

def perform_csv_export(self, ks, cf, columns, fname, opts):
csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
@@ -2289,176 +2213,6 @@ def printerr(self, text, color=RED, newline=True, shownum=None):
self.writeresult(text, color, newline=newline, out=sys.stderr)


class ImportProcess(mp.Process):

def __init__(self, parent, pipe, ks, cf, columns, nullval):
mp.Process.__init__(self)
self.pipe = pipe
self.nullval = nullval
self.ks = ks
self.cf = cf

# validate we can fetch metadata but don't store it since win32 needs to pickle
parent.get_table_meta(ks, cf)

self.columns = columns
self.consistency_level = parent.consistency_level
self.connect_timeout = parent.conn.connect_timeout
self.hostname = parent.hostname
self.port = parent.port
self.ssl = parent.ssl
self.auth_provider = parent.auth_provider
self.cql_version = parent.conn.cql_version
self.debug = parent.debug

def run(self):
new_cluster = Cluster(
contact_points=(self.hostname,),
port=self.port,
cql_version=self.cql_version,
protocol_version=DEFAULT_PROTOCOL_VERSION,
auth_provider=self.auth_provider,
ssl_options=sslhandling.ssl_settings(self.hostname, CONFIG_FILE) if self.ssl else None,
load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
compression=None,
connect_timeout=self.connect_timeout)
session = new_cluster.connect(self.ks)
conn = session._pools.values()[0]._connection

table_meta = new_cluster.metadata.keyspaces[self.ks].tables[self.cf]

pk_cols = [col.name for col in table_meta.primary_key]
cqltypes = [table_meta.columns[name].cql_type for name in self.columns]
pk_indexes = [self.columns.index(col.name) for col in table_meta.primary_key]
query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
protect_name(self.ks),
protect_name(self.cf),
', '.join(protect_names(self.columns)))

# we need to handle some types specially
should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet') for t in cqltypes]

insert_timestamp = int(time.time() * 1e6)

def callback(record_num, response):
# This is the callback we register for all inserts. Because this
# is run on the event-loop thread, we need to hold a lock when
# adjusting in_flight.
with conn.lock:
conn.in_flight -= 1

if not isinstance(response, ResultMessage):
# It's an error. Notify the parent process and let it send
# a stop signal to all child processes (including this one).
self.pipe.send((record_num, str(response)))
if isinstance(response, Exception) and self.debug:
traceback.print_exc(response)

current_record = 0
insert_num = 0
try:
while True:
# To avoid totally maxing out the connection,
# defer to the reactor thread when we're close
# to capacity
if conn.in_flight > (conn.max_request_id * 0.9):
conn._readable = True
time.sleep(0.05)
continue

try:
(current_record, row) = self.pipe.recv()
except EOFError:
# the pipe was closed and there's nothing to receive
sys.stdout.write('Failed to read from pipe:\n\n')
sys.stdout.flush()
conn._writable = True
conn._readable = True
break

# see if the parent process has signaled that we are done
if (current_record, row) == (None, None):
conn._writable = True
conn._readable = True
self.pipe.close()
break

# format the values in the row
for i, value in enumerate(row):
if value != self.nullval:
if should_escape[i]:
row[i] = protect_value(value)
elif i in pk_indexes:
# By default, nullval is an empty string. See CASSANDRA-7792 for details.
message = "Cannot insert null value for primary key column '%s'." % (pk_cols[i],)
if self.nullval == '':
message += " If you want to insert empty strings, consider using " \
"the WITH NULL=<marker> option for COPY."
self.pipe.send((current_record, message))
return
else:
row[i] = 'null'

full_query = query % (','.join(row),)
query_message = QueryMessage(
full_query, self.consistency_level, serial_consistency_level=None,
fetch_size=None, paging_state=None, timestamp=insert_timestamp)

request_id = conn.get_request_id()
conn.send_msg(query_message, request_id=request_id, cb=partial(callback, current_record))

with conn.lock:
conn.in_flight += 1

# every 50 records, clear the pending writes queue and read
# any responses we have
if insert_num % 50 == 0:
conn._writable = True
conn._readable = True

insert_num += 1
except Exception, exc:
self.pipe.send((current_record, str(exc)))
finally:
# wait for any pending requests to finish
while conn.in_flight > 0:
conn._readable = True
time.sleep(0.1)

new_cluster.shutdown()

def stop(self):
self.terminate()


class RateMeter(object):

def __init__(self, log_rate):
self.log_rate = log_rate
self.last_checkpoint_time = time.time()
self.current_rate = 0.0
self.current_record = 0

def increment(self):
self.current_record += 1

if (self.current_record % self.log_rate) == 0:
new_checkpoint_time = time.time()
new_rate = self.log_rate / (new_checkpoint_time - self.last_checkpoint_time)
self.last_checkpoint_time = new_checkpoint_time

# smooth the rate a bit
if self.current_rate == 0.0:
self.current_rate = new_rate
else:
self.current_rate = (self.current_rate + new_rate) / 2.0

output = 'Processed %s rows; Write: %.2f rows/s\r' % \
(self.current_record, self.current_rate)
sys.stdout.write(output)
sys.stdout.flush()


class SwitchCommand(object):
command = None
description = None
