diff --git a/SQLUsers.py b/SQLUsers.py index 67a45939..3760343e 100755 --- a/SQLUsers.py +++ b/SQLUsers.py @@ -316,14 +316,20 @@ def __init__(self, root, engine): metadata.create_all(engine) self.sessionmaker = sessionmaker(bind=engine, autoflush=True) self.session = self.sessionmaker() - + + def sess(self): + if self.session.is_active: + return self.session + self.session.rollback() + return self.session + def clientFromID(self, db_id): - entry = self.session.query(User).filter(User.id==db_id).first() + entry = self.sess().query(User).filter(User.id==db_id).first() if not entry: return None return OfflineClient(entry) def clientFromUsername(self, username): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if not entry: return None return OfflineClient(entry) @@ -410,10 +416,10 @@ def gen_user_hash(user_pwrd, user_salt, hash_func = SHA256_HASH_FUNC): def check_ban(self, user, ip, userid, now): ## FIXME: "Error reading from DB in in_LOGIN: () takes exactly 2 arguments (3 given)"? - userban = self.session.query(BanUser).filter(BanUser.user_id == userid, now <= BanUser.end_time).first() + userban = self.sess().query(BanUser).filter(BanUser.user_id == userid, now <= BanUser.end_time).first() if (not userban): - ipban = self.session.query(BanIP).filter(BanIP.ip == ip, now <= BanIP.end_time).first() + ipban = self.sess().query(BanIP).filter(BanIP.ip == ip, now <= BanIP.end_time).first() if userban: return True, userban if ipban: return True, ipban @@ -467,7 +473,7 @@ def legacy_login_user(self, username, password, ip, lobby_id, user_id, cpu, loca ## should only ever be one user with each name so we can just grab the first one :) ## password here is unicode(BASE64(MD5(...))), matches the register_user DB encoding - dbuser = self.session.query(User).filter(User.username == username).first() + dbuser = self.sess().query(User).filter(User.username == username).first() if (not dbuser): return False, 'Invalid username or password' @@ -480,7 +486,7 @@ def secure_login_user(self, username, password, ip, lobby_id, user_id, cpu, loca assert(type(username) == str) assert(type(password) == str) - db_user = self.session.query(User).filter(User.username == username).first() + db_user = self.sess().query(User).filter(User.username == username).first() if (not db_user): return False, 'Invalid username' @@ -503,11 +509,11 @@ def secure_login_user(self, username, password, ip, lobby_id, user_id, cpu, loca def end_session(self, db_id): - entry = self.session.query(User).filter(User.id==db_id).first() + entry = self.sess().query(User).filter(User.id==db_id).first() if entry and not entry.logins[-1].end: entry.logins[-1].end = datetime.now() entry.last_login = datetime.now() # in real its last online / last seen - self.session.commit() + self.sess().commit() @@ -528,7 +534,7 @@ def common_register_user(self, session, username, password): if (not status): return False, reason - dbuser = self.session.query(User).filter(User.username == username).first() + dbuser = self.sess().query(User).filter(User.username == username).first() if (dbuser): return False, 'Username already exists.' @@ -545,8 +551,8 @@ def legacy_register_user(self, username, password, ip, country): ## note: password here is BASE64(MD5(...)) and already in unicode entry = User(username, password, "", ip) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() return True, 'Account registered successfully.' def secure_register_user(self, username, password, ip, country): @@ -566,24 +572,24 @@ def secure_register_user(self, username, password, ip, country): def ban_user(self, owner, username, duration, reason): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if not entry: return "Couldn't ban %s, user doesn't exist" % (username) end_time = datetime.now() + timedelta(duration) ban = BanUser(entry.id, owner.db_id, reason, end_time) - self.session.add(ban) - self.session.commit() + self.sess().add(ban) + self.sess().commit() return 'Successfully banned %s for %s days.' % (username, duration) def unban_user(self, username): client = self.clientFromUsername(username) if not client: return "User %s doesn't exist" % username - results = self.session.query(BanUser).filter(BanUser.user_id==client.id) + results = self.sess().query(BanUser).filter(BanUser.user_id==client.id) if results: for result in results: - self.session.delete(result) - self.session.commit() + self.sess().delete(result) + self.sess().commit() return 'Successfully unbanned %s.' % username else: return 'No matching bans for %s.' % username @@ -592,23 +598,23 @@ def ban_ip(self, owner, ip, duration, reason): # TODO: add owner field to the database for bans end_time = datetime.now() + timedelta(duration) ban = BanIP(ip, owner.db_id, reason, end_time) - self.session.add(ban) - self.session.commit() + self.sess().add(ban) + self.sess().commit() return 'Successfully banned %s for %s days.' % (ip, duration) def unban_ip(self, ip): - results = self.session.query(BanIP).filter(BanIP.ip==ip) + results = self.sess().query(BanIP).filter(BanIP.ip==ip) if results: for result in results: - self.session.delete(result) - self.session.commit() + self.sess().delete(result) + self.sess().commit() return 'Successfully unbanned %s.' % ip else: return 'No matching bans for %s.' % ip def banlistuser(self): banlist = [] - for ban in self.session.query(BanUser, User.id, BanUser.end_time, BanUser.reason, User.username).join(User,BanUser.user_id == User.id ): + for ban in self.sess().query(BanUser, User.id, BanUser.end_time, BanUser.reason, User.username).join(User,BanUser.user_id == User.id ): banlist.append({ 'userid': ban.id, 'username': ban.username, @@ -619,7 +625,7 @@ def banlistuser(self): def banlistip(self): banlist = [] - for ban in self.session.query(BanIP): + for ban in self.sess().query(BanIP): banlist.append({ 'userid': ban.ip, 'endtime': ban.end_time, @@ -633,14 +639,14 @@ def rename_user(self, user, newname): if not self._root.SayHooks._nasty_word_censor(user): return False, 'New username failed to pass profanity filter.' if not newname == user: - results = self.session.query(User).filter(User.username==newname).first() + results = self.sess().query(User).filter(User.username==newname).first() if results: return False, 'Username already exists.' - entry = self.session.query(User).filter(User.username==user).first() + entry = self.sess().query(User).filter(User.username==user).first() if not entry: return False, 'You don\'t seem to exist anymore. Contact an admin or moderator.' entry.renames.append(Rename(user, newname)) entry.username = newname - self.session.commit() + self.sess().commit() # need to iterate through channels and rename junk there... # it might actually be a lot easier to use userids in the server... # later. return True, 'Account renamed successfully.' @@ -648,7 +654,7 @@ def rename_user(self, user, newname): def save_user(self, obj): ## assert(isinstance(obj, User) or isinstance(obj, Client)) - entry = self.session.query(User).filter(User.username==obj.username).first() + entry = self.sess().query(User).filter(User.username==obj.username).first() if (entry != None): ## caller might have changed these! @@ -660,25 +666,25 @@ def save_user(self, obj): entry.last_id = obj.last_id entry.email = obj.email - self.session.commit() + self.sess().commit() def confirm_agreement(self, client): - entry = self.session.query(User).filter(User.username==client.username).first() + entry = self.sess().query(User).filter(User.username==client.username).first() if entry: entry.access = 'user' - self.session.commit() + self.sess().commit() def get_lastlogin(self, username): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if entry: return True, entry.last_login else: return False, 'User not found.' def get_registration_date(self, username): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if entry and entry.register_date: return True, entry.register_date else: return False, 'user or date not found in database' def get_ingame_time(self, username): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if entry: return True, entry.ingame_time else: return False, 'user not found in database' @@ -689,21 +695,21 @@ def get_account_access(self, username): else: return False, 'user not found in database' def find_ip(self, ip): - results = self.session.query(User).filter(User.last_ip==ip) + results = self.sess().query(User).filter(User.last_ip==ip) return results def get_ip(self, username): - entry = self.session.query(User).filter(User.username==username).first() + entry = self.sess().query(User).filter(User.username==username).first() if not entry: return None return entry.last_ip def remove_user(self, user): - entry = self.session.query(User).filter(User.username==user).first() + entry = self.sess().query(User).filter(User.username==user).first() if not entry: return False, 'User not found.' - self.session.delete(entry) - self.session.commit() + self.sess().delete(entry) + self.sess().commit() return True, 'Success.' def clean_users(self): @@ -711,92 +717,92 @@ def clean_users(self): now = datetime.now() #delete users: # which didn't accept aggreement after one day - self.session.query(User).filter(User.register_date < now - timedelta(days=1)).filter(User.access == "agreement").delete(synchronize_session=False) + self.sess().query(User).filter(User.register_date < now - timedelta(days=1)).filter(User.access == "agreement").delete(synchronize_session=False) # which have no ingame time, last login > 30 days and no bot - self.session.query(User).filter(User.ingame_time == 0).filter(User.last_login < now - timedelta(days=30)).filter(User.bot == 0).filter(User.access == "user").delete(synchronize_session=False) + self.sess().query(User).filter(User.ingame_time == 0).filter(User.last_login < now - timedelta(days=30)).filter(User.bot == 0).filter(User.access == "user").delete(synchronize_session=False) # last login > 3 years - self.session.query(User).filter(User.last_login < now - timedelta(days=1095)).delete(synchronize_session=False) + self.sess().query(User).filter(User.last_login < now - timedelta(days=1095)).delete(synchronize_session=False) # old messages > 2 weeks - self.session.query(ChannelHistory).filter(ChannelHistory.time < now - timedelta(days=14)).delete(synchronize_session=False) + self.sess().query(ChannelHistory).filter(ChannelHistory.time < now - timedelta(days=14)).delete(synchronize_session=False) - self.session.commit() + self.sess().commit() def ignore_user(self, user_id, ignore_user_id, reason=None): entry = Ignore(user_id, ignore_user_id, reason) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() def unignore_user(self, user_id, unignore_user_id): - entry = self.session.query(Ignore).filter(Ignore.user_id == user_id).filter(Ignore.ignored_user_id == unignore_user_id).one() - self.session.delete(entry) - self.session.commit() + entry = self.sess().query(Ignore).filter(Ignore.user_id == user_id).filter(Ignore.ignored_user_id == unignore_user_id).one() + self.sess().delete(entry) + self.sess().commit() # returns id-s of users who had their ignore removed def globally_unignore_user(self, unignore_user_id): - q = self.session.query(Ignore).filter(Ignore.ignored_user_id == unignore_user_id) + q = self.sess().query(Ignore).filter(Ignore.ignored_user_id == unignore_user_id) userids = [ignore.user_id for ignore in q.all()] # could be done in one query + hook, fix if bored - self.session.query(Ignore).filter(Ignore.ignored_user_id == unignore_user_id).delete() - self.session.commit() + self.sess().query(Ignore).filter(Ignore.ignored_user_id == unignore_user_id).delete() + self.sess().commit() return userids def is_ignored(self, user_id, ignore_user_id): - exists = self.session.query(Ignore).filter(Ignore.user_id == user_id).filter(Ignore.ignored_user_id == ignore_user_id).count() > 0 + exists = self.sess().query(Ignore).filter(Ignore.user_id == user_id).filter(Ignore.ignored_user_id == ignore_user_id).count() > 0 return exists def get_ignore_list(self, user_id): - users = self.session.query(Ignore).filter(Ignore.user_id == user_id).all() + users = self.sess().query(Ignore).filter(Ignore.user_id == user_id).all() users = [(user.ignored_user_id, user.reason) for user in users] return users def get_ignored_user_ids(self, user_id): - user_ids = self.session.query(Ignore.ignored_user_id).filter(Ignore.user_id == user_id).all() + user_ids = self.sess().query(Ignore.ignored_user_id).filter(Ignore.user_id == user_id).all() user_ids = [user_id for user_id, in user_ids] return user_ids def friend_users(self, user_id, friend_user_id): entry = Friend(user_id, friend_user_id) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() def unfriend_users(self, first_user_id, second_user_id): - self.session.query(Friend).filter(Friend.first_user_id == first_user_id).filter(Friend.second_user_id == second_user_id).delete() - self.session.query(Friend).filter(Friend.second_user_id == first_user_id).filter(Friend.first_user_id == second_user_id).delete() - self.session.commit() + self.sess().query(Friend).filter(Friend.first_user_id == first_user_id).filter(Friend.second_user_id == second_user_id).delete() + self.sess().query(Friend).filter(Friend.second_user_id == first_user_id).filter(Friend.first_user_id == second_user_id).delete() + self.sess().commit() def are_friends(self, first_user_id, second_user_id): - q1 = self.session.query(Friend).filter(Friend.first_user_id == first_user_id) - q2 = self.session.query(Friend).filter(Friend.second_user_id == second_user_id) + q1 = self.sess().query(Friend).filter(Friend.first_user_id == first_user_id) + q2 = self.sess().query(Friend).filter(Friend.second_user_id == second_user_id) exists = q1.union(q2).count() > 0 return exists def get_friend_user_ids(self, user_id): - q1 = self.session.query(Friend.second_user_id).filter(Friend.first_user_id == user_id) - q2 = self.session.query(Friend.first_user_id).filter(Friend.second_user_id == user_id) + q1 = self.sess().query(Friend.second_user_id).filter(Friend.first_user_id == user_id) + q2 = self.sess().query(Friend.first_user_id).filter(Friend.second_user_id == user_id) user_ids = q1.union(q2).all() user_ids = [user_id for user_id, in user_ids] return user_ids def has_friend_request(self, user_id, friend_user_id): - request = self.session.query(FriendRequest).filter(FriendRequest.user_id == user_id).filter(FriendRequest.friend_user_id == friend_user_id) + request = self.sess().query(FriendRequest).filter(FriendRequest.user_id == user_id).filter(FriendRequest.friend_user_id == friend_user_id) exists = request.count() > 0 return exists def add_friend_request(self, user_id, friend_user_id, msg=None): entry = FriendRequest(user_id, friend_user_id, msg) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() def remove_friend_request(self, user_id, friend_user_id): - self.session.query(FriendRequest).filter(FriendRequest.user_id == user_id).filter(FriendRequest.friend_user_id == friend_user_id).delete() - self.session.commit() + self.sess().query(FriendRequest).filter(FriendRequest.user_id == user_id).filter(FriendRequest.friend_user_id == friend_user_id).delete() + self.sess().commit() # this returns all friend requests sent _to_ user_id def get_friend_request_list(self, user_id): - reqs = self.session.query(FriendRequest).filter(FriendRequest.friend_user_id == user_id).all() + reqs = self.sess().query(FriendRequest).filter(FriendRequest.friend_user_id == user_id).all() users = [(req.user_id, req.msg) for req in reqs] return users @@ -804,16 +810,16 @@ def add_channel_message(self, channel_id, user_id, msg, date = None): if date is None: date = datetime.now() entry = ChannelHistory(channel_id, user_id, msg, date) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() #returns a list of channel messages since starttime for the specific userid when he is subscribed to the channel # [[date, user, msg], [date, user, msg], ...] def get_channel_messages(self, user_id, channel_id, starttime): - entry = self.session.query(ChannelHistorySubscription).filter(ChannelHistorySubscription.channel_id == channel_id).filter(ChannelHistorySubscription.user_id == user_id).first() + entry = self.sess().query(ChannelHistorySubscription).filter(ChannelHistorySubscription.channel_id == channel_id).filter(ChannelHistorySubscription.user_id == user_id).first() if not entry: return [] - reqs = self.session.query(ChannelHistory, User).filter(ChannelHistory.channel_id == channel_id).filter(ChannelHistory.time >= starttime).filter(ChannelHistory.user_id == User.id).all() + reqs = self.sess().query(ChannelHistory, User).filter(ChannelHistory.channel_id == channel_id).filter(ChannelHistory.time >= starttime).filter(ChannelHistory.user_id == User.id).all() msgs = [(history.time, user.username, history.msg) for history, user in reqs ] if len(msgs)>0: assert(type(msgs[0][2]) == str) @@ -824,8 +830,8 @@ def add_channelhistory_subscription(self, channel_id, user_id): assert(user_id > 0) try: entry = ChannelHistorySubscription(channel_id, user_id) - self.session.add(entry) - self.session.commit() + self.sess().add(entry) + self.sess().commit() except IntegrityError: return False, "Already subscribed" except Exception as e: @@ -834,14 +840,14 @@ def add_channelhistory_subscription(self, channel_id, user_id): def remove_channelhistory_subscription(self, channel_id, user_id): try: - self.session.query(ChannelHistorySubscription).filter(ChannelHistorySubscription.channel_id == channel_id).filter(ChannelHistorySubscription.user_id == user_id).delete() - self.session.commit() + self.sess().query(ChannelHistorySubscription).filter(ChannelHistorySubscription.channel_id == channel_id).filter(ChannelHistorySubscription.user_id == user_id).delete() + self.sess().commit() except Exception as e: return False, str(e) return True, "" def get_channel_subscriptions(self, user_id): - reqs = self.session.query(ChannelHistorySubscription, Channel).filter(ChannelHistorySubscription.user_id == user_id).filter(ChannelHistorySubscription.channel_id == Channel.id) .all() + reqs = self.sess().query(ChannelHistorySubscription, Channel).filter(ChannelHistorySubscription.user_id == user_id).filter(ChannelHistorySubscription.channel_id == Channel.id) .all() channels = [(channel.name) for sub, channel in reqs] return channels @@ -852,12 +858,18 @@ def __init__(self, root, engine): self.sessionmaker = sessionmaker(bind=engine, autoflush=True) self.session = self.sessionmaker() + def sess(self): + if self.session.is_active: + return self.session + self.session.rollback() + return self.session + def load_channel(self, name): - entry = self.session.query(Channel).filter(Channel.name == name).first() + entry = self.sess().query(Channel).filter(Channel.name == name).first() return entry def load_channels(self): - response = self.session.query(Channel) + response = self.sess().query(Channel) channels = {} for chan in response: channels[chan.name] = { @@ -873,27 +885,27 @@ def load_channels(self): return channels def setTopic(self, user, chan, topic): - entry = self.session.query(Channel).filter(Channel.name == chan.name).first() + entry = self.sess().query(Channel).filter(Channel.name == chan.name).first() if entry: entry.topic = topic entry.topic_time = datetime.now() entry.topic_owner = user - self.session.commit() + self.sess().commit() def setKey(self, chan, key): - entry = self.session.query(Channel).filter(Channel.name == chan.name).first() + entry = self.sess().query(Channel).filter(Channel.name == chan.name).first() if entry: entry.key = key - self.session.commit() + self.sess().commit() def setHistory(self, chan): - entry = self.session.query(Channel).filter(Channel.name == chan.name).first() + entry = self.sess().query(Channel).filter(Channel.name == chan.name).first() if entry: entry.store_history = chan.store_history - self.session.commit() + self.sess().commit() def register(self, channel, target): - entry = self.session.query(Channel).filter(Channel.name == channel.name).first() + entry = self.sess().query(Channel).filter(Channel.name == channel.name).first() if not entry: entry = Channel(channel.name) if channel.topic: @@ -903,14 +915,14 @@ def register(self, channel, target): else: entry.topic_time = datetime.now() entry.owner = target.username - self.session.add(entry) - self.session.commit() - entry = self.session.query(Channel).filter(Channel.name == channel.name).first() # set db id to runtime object + self.sess().add(entry) + self.sess().commit() + entry = self.sess().query(Channel).filter(Channel.name == channel.name).first() # set db id to runtime object channel.id = entry.id def unRegister(self, client, channel): - entry = self.session.query(Channel).filter(Channel.name == channel.name).delete() - self.session.commit() + entry = self.sess().query(Channel).filter(Channel.name == channel.name).delete() + self.sess().commit() if __name__ == '__main__': class root():