Skip to content

Commit

Permalink
Finish 3.2.2
Browse files Browse the repository at this point in the history
many fix
  • Loading branch information
Akkariiin committed May 22, 2018
2 parents 1c4680b + e0867ea commit 190bf5e
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 60 deletions.
2 changes: 1 addition & 1 deletion configloader.py
@@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
import importloader

g_config = None
Expand Down
30 changes: 24 additions & 6 deletions db_transfer.py
@@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-

import logging
import time
Expand All @@ -9,6 +9,7 @@
from shadowsocks import common, shell, lru_cache, obfs
from configloader import load_config, get_config
import importloader
import copy

switchrule = None
db_instance = None
Expand Down Expand Up @@ -80,8 +81,10 @@ def push_db_all_user(self):
def del_server_out_of_bound_safe(self, last_rows, rows):
#停止超流量的服务
#启动没超流量的服务
keymap = {}
try:
switchrule = importloader.load('switchrule')
keymap = switchrule.getRowMap()
except Exception as e:
logging.error('load switchrule.py fail')
cur_servers = {}
Expand All @@ -106,7 +109,10 @@ def del_server_out_of_bound_safe(self, last_rows, rows):
read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user']
for name in read_config_keys:
if name in row and row[name]:
cfg[name] = row[name]
if name in keymap:
cfg[keymap[name]] = row[name]
else:
cfg[name] = row[name]

merge_config_keys = ['password'] + read_config_keys
for name in cfg.keys():
Expand Down Expand Up @@ -392,11 +398,17 @@ def pull_db_all_user(self):
return rows

def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list)
keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e:
keys = self.key_list
logging.error('load switchrule.py fail')

cur = conn.cursor()
cur.execute("SELECT " + ','.join(keys) + " FROM user")
Expand Down Expand Up @@ -520,11 +532,17 @@ def update_all_user(self, dt_transfer):
return update_transfer

def pull_db_users(self, conn):
keys = copy.copy(self.key_list)
try:
switchrule = importloader.load('switchrule')
keys = switchrule.getKeys(self.key_list)
keymap = switchrule.getRowMap()
for key in keymap:
if keymap[key] in keys:
keys.remove(keymap[key])
keys.append(key)
keys = switchrule.getKeys(keys)
except Exception as e:
keys = self.key_list
logging.error('load switchrule.py fail')

cur = conn.cursor()

Expand Down
2 changes: 1 addition & 1 deletion importloader.py
@@ -1,5 +1,5 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-

def load(name):
try:
Expand Down
10 changes: 5 additions & 5 deletions server_pool.py
Expand Up @@ -117,14 +117,14 @@ def new_server(self, port, user_config):
else:
a_config = self.config.copy()
a_config.update(user_config)
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]":
if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][-1] == b"]":
a_config['server_ipv6'] = a_config['server_ipv6'][1:-1]
a_config['server'] = a_config['server_ipv6']
a_config['server'] = common.to_str(a_config['server_ipv6'])
a_config['server_port'] = port
a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method'])
try:
logging.info("starting server at [%s]:%d" % (common.to_str(a_config['server']), port))
logging.info("starting server at [%s]:%d" % (a_config['server'], port))

tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False, stat_counter=self.stat_counter)
tcp_server.add_to_loop(self.loop)
Expand All @@ -134,7 +134,7 @@ def new_server(self, port, user_config):
udp_server.add_to_loop(self.loop)
self.udp_ipv6_servers_pool.update({port: udp_server})

if common.to_str(a_config['server_ipv6']) == "::":
if a_config['server_ipv6'] == "::":
ipv6_ok = True
except Exception as e:
logging.warn("IPV6 %s " % (e,))
Expand All @@ -150,7 +150,7 @@ def new_server(self, port, user_config):
a_config['max_connect'] = 128
a_config['method'] = common.to_str(a_config['method'])
try:
logging.info("starting server at %s:%d" % (common.to_str(a_config['server']), port))
logging.info("starting server at %s:%d" % (a_config['server'], port))

tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop)
Expand Down
36 changes: 33 additions & 3 deletions shadowsocks/common.py
Expand Up @@ -121,7 +121,19 @@ def is_ip(address):
return False


def sync_str_bytes(obj, target_example):
"""sync (obj)'s type to (target_example)'s type"""
if type(obj) != type(target_example):
if type(target_example) == str:
obj = to_str(obj)
if type(target_example) == bytes:
obj = to_bytes(obj)
return obj


def match_regex(regex, text):
# avoid 'cannot use a string pattern on a bytes-like object'
regex = sync_str_bytes(regex, text)
regex = re.compile(regex)
for item in regex.findall(text):
return True
Expand Down Expand Up @@ -381,12 +393,12 @@ def test_inet_conv():

def test_parse_header():
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
(0, b'www.google.com', 80, 18)
(0, ADDRTYPE_HOST, b'www.google.com', 80, 18)
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
(0, b'8.8.8.8', 53, 7)
(0, ADDRTYPE_IPV4, b'8.8.8.8', 53, 7)
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
b'\x00\x10\x11\x00\x50')) == \
(0, b'2404:6800:4005:805::1011', 80, 19)
(0, ADDRTYPE_IPV6, b'2404:6800:4005:805::1011', 80, 19)


def test_pack_header():
Expand All @@ -411,7 +423,25 @@ def test_ip_network():
assert 'www.google.com' not in ip_network


def test_sync_str_bytes():
assert sync_str_bytes(b'a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes('a\.b', b'a\.b') == b'a\.b'
assert sync_str_bytes(b'a\.b', 'a\.b') == 'a\.b'
assert sync_str_bytes('a\.b', 'a\.b') == 'a\.b'
pass


def test_match_regex():
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b')
assert match_regex(r'\bgoogle\.com\b', b' google.com ')
pass

if __name__ == '__main__':
test_sync_str_bytes()
test_match_regex()
test_inet_conv()
test_parse_header()
test_pack_header()
Expand Down
4 changes: 3 additions & 1 deletion shadowsocks/crypto/util.py
Expand Up @@ -68,7 +68,9 @@ def find_library(possible_lib_names, search_symbol, library_name):
if path:
paths.append(path)

if not paths:
# always find lib on extend path that to avoid ```CDLL()``` failed on some strange linux environment
# in that case ```ctypes.util.find_library()``` have different find path from ```CDLL()```
if True:
# We may get here when find_library fails because, for example,
# the user does not have sufficient privileges to access those
# tools underlying find_library on linux.
Expand Down
19 changes: 12 additions & 7 deletions shadowsocks/encrypt.py
Expand Up @@ -46,7 +46,7 @@ def try_cipher(key, method=None):
Encryptor(key, method)


def EVP_BytesToKey(password, key_len, iv_len):
def EVP_BytesToKey(password, key_len, iv_len, cache):
# equivalent to OpenSSL's EVP_BytesToKey() with count 1
# so that we make the same key and iv as nodejs version
cached_key = '%s-%d-%d' % (password, key_len, iv_len)
Expand All @@ -66,13 +66,14 @@ def EVP_BytesToKey(password, key_len, iv_len):
ms = b''.join(m)
key = ms[:key_len]
iv = ms[key_len:key_len + iv_len]
cached_keys[cached_key] = (key, iv)
cached_keys.sweep()
if cache:
cached_keys[cached_key] = (key, iv)
cached_keys.sweep()
return key, iv


class Encryptor(object):
def __init__(self, key, method, iv = None):
def __init__(self, key, method, iv = None, cache = False):
self.key = key
self.method = method
self.iv = None
Expand All @@ -81,6 +82,7 @@ def __init__(self, key, method, iv = None):
self.iv_buf = b''
self.cipher_key = b''
self.decipher = None
self.cache = cache
method = method.lower()
self._method_info = self.get_method_info(method)
if self._method_info:
Expand All @@ -105,7 +107,7 @@ def get_cipher(self, password, method, op, iv):
password = common.to_bytes(password)
m = self._method_info
if m[0] > 0:
key, iv_ = EVP_BytesToKey(password, m[0], m[1])
key, iv_ = EVP_BytesToKey(password, m[0], m[1], self.cache)
else:
# key_length == 0 indicates we should use the key directly
key, iv = password, b''
Expand All @@ -119,6 +121,9 @@ def get_cipher(self, password, method, op, iv):

def encrypt(self, buf):
if len(buf) == 0:
if not self.iv_sent:
self.iv_sent = True
return self.cipher_iv
return buf
if self.iv_sent:
return self.cipher.update(buf)
Expand Down Expand Up @@ -155,7 +160,7 @@ def encrypt_all(password, method, op, data):
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else:
key = password
if op:
Expand All @@ -172,7 +177,7 @@ def encrypt_key(password, method):
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
key, _ = EVP_BytesToKey(password, key_len, iv_len, True)
else:
key = password
return key
Expand Down

0 comments on commit 190bf5e

Please sign in to comment.