Skip to content

Commit

Permalink
cqlsh.py: fix server side describe after login command
Browse files Browse the repository at this point in the history
since login command create a new session, we run into an
issue when describe was called after login
seem like the row_factory wasn't used

the reason was in the main session we used `execution_profiles`
while in the one create in login we did not, which led to copying
the wrong values (i.e. like `row_factory` that was needed for server
side describe)

this change saves the profiles into the instance, and reuse them
when ever a new session is opened
it simplify the code and we could remove a few repetitions of
the same logic
  • Loading branch information
fruch committed May 30, 2024
1 parent 55aff23 commit bf00c39
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions bin/cqlsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def __init__(self, hostname, port, color=False,
if protocol_version is not None:
kwargs['protocol_version'] = protocol_version

profiles = {
self.profiles = {
EXEC_PROFILE_DEFAULT: ExecutionProfile(consistency_level=cassandra.ConsistencyLevel.ONE,
request_timeout=request_timeout,
row_factory=ordered_dict_factory)
Expand All @@ -487,10 +487,10 @@ def __init__(self, hostname, port, color=False,
if cloudconf is None:
if os.path.exists(self.hostname) and stat.S_ISSOCK(os.stat(self.hostname).st_mode):
kwargs['contact_points'] = (UnixSocketEndPoint(self.hostname),)
profiles[EXEC_PROFILE_DEFAULT].load_balancing_policy = WhiteListRoundRobinPolicy([UnixSocketEndPoint(self.hostname)])
self.profiles[EXEC_PROFILE_DEFAULT].load_balancing_policy = WhiteListRoundRobinPolicy([UnixSocketEndPoint(self.hostname)])
else:
kwargs['contact_points'] = (self.hostname,)
profiles[EXEC_PROFILE_DEFAULT].load_balancing_policy = WhiteListRoundRobinPolicy([self.hostname])
self.profiles[EXEC_PROFILE_DEFAULT].load_balancing_policy = WhiteListRoundRobinPolicy([self.hostname])
kwargs['port'] = self.port
kwargs['ssl_context'] = sslhandling.ssl_settings(hostname, CONFIG_FILE) if ssl else None
# workaround until driver would know not to lose the DNS names for `server_hostname`
Expand All @@ -503,7 +503,7 @@ def __init__(self, hostname, port, color=False,
auth_provider=self.auth_provider,
control_connection_timeout=connect_timeout,
connect_timeout=connect_timeout,
execution_profiles=profiles,
execution_profiles=self.profiles,
**kwargs)
self.owns_connection = not use_conn

Expand Down Expand Up @@ -2140,10 +2140,6 @@ def do_login(self, parsed):
kwargs['port'] = self.port
kwargs['ssl_context'] = self.conn.ssl_context
kwargs['ssl_options'] = self.conn.ssl_options
if os.path.exists(self.hostname) and stat.S_ISSOCK(os.stat(self.hostname).st_mode):
kwargs['load_balancing_policy'] = WhiteListRoundRobinPolicy([UnixSocketEndPoint(self.hostname)])
else:
kwargs['load_balancing_policy'] = WhiteListRoundRobinPolicy([self.hostname])
else:
kwargs['scylla_cloud'] = self.cloudconf

Expand All @@ -2152,6 +2148,7 @@ def do_login(self, parsed):
auth_provider=auth_provider,
control_connection_timeout=self.conn.connect_timeout,
connect_timeout=self.conn.connect_timeout,
execution_profiles=self.profiles,
**kwargs)

if self.current_keyspace:
Expand All @@ -2160,9 +2157,6 @@ def do_login(self, parsed):
session = conn.connect()

# Copy session properties
session.default_timeout = self.session.default_timeout
session.row_factory = self.session.row_factory
session.default_consistency_level = self.session.default_consistency_level
session.max_trace_wait = self.session.max_trace_wait

# Update after we've connected in case we fail to authenticate
Expand Down

0 comments on commit bf00c39

Please sign in to comment.