In [75]:
import socket
import signal
import threading
from concurrent.futures import ThreadPoolExecutor
import time
from collections import namedtuple
import ssl 
import os
from subprocess import Popen, PIPE
import email.utils as eut
import datetime
import logging
import sys

In [76]:
# Create logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# Create STDERR handler
handler = logging.StreamHandler(sys.stderr)

# Create formatter and add it to the handler
formatter = logging.Formatter('%(asctime)s: %(message)s')
handler.setFormatter(formatter)

# Set STDERR handler as the only handler 
logger.handlers = [handler]

In [164]:
# Proxy settings.
Config = namedtuple('config', ['LISTENNQ', 'MAX_REQUEST_LEN', 'socket', 'timeout'])
config = Config(
    LISTENNQ=10,           # listen() backlog (also reused as the worker-pool size)
    MAX_REQUEST_LEN=1024,  # bytes requested per recv() call
    socket=('', 12345),    # (host, port) the proxy binds to; '' = all interfaces
    timeout=10.0,          # per-socket timeout, seconds
)
terminator = b'\r\n\r\n'   # blank line that ends an HTTP header block
# Hosts answered with 404 instead of being proxied, e.g. ['sing.cse.ust.hk'].
blackList = []

In [78]:
def join_with_script_dir(path):
    """Return ``path`` joined onto the process's current working directory.

    Despite the name, the base is the current working directory, not the
    location of any script file. Uses the standard library instead of an IPython
    ``!pwd`` shell escape, so it also works outside IPython and without a shell.
    """
    return os.path.join(os.getcwd(), path)

# Proxy CA material and the directory for per-host certificates, all resolved
# relative to the current working directory.
certdir = join_with_script_dir('certs/')
cakey = join_with_script_dir('ca.key')
cacert = join_with_script_dir('ca.crt')
certkey = join_with_script_dir('cert.key')

In [139]:
class httpHeader():
    """Parsed start line and header fields of an HTTP message.

    ``msg`` is expected to be the header block as bytes. Lines may end in CRLF or
    a bare LF. It is decoded as UTF-8, so UnicodeDecodeError is raised for other
    encodings.

    Attributes:
        firstLine: the start line split on single spaces.
        fields: header fields keyed by lower-cased name. Repeated fields have
            their values joined with ','.

    Subclasses set ``self.raw`` (the full message bytes) before ``appendField``
    is used.
    """
    def __init__(self, msg):
        text = msg.strip(terminator).decode()
        # Drop the '\r' of CRLF line endings so it cannot end up in the start
        # line (e.g. a protocol of 'HTTP/1.1\r').
        lines = [line.rstrip('\r') for line in text.split('\n')]
        self.firstLine = lines[0].split(' ')
        fields = {}
        for row in lines[1:]:
            if not row.strip():
                continue
            name, _, value = row.partition(':')
            key = name.strip().lower()
            value = value.strip()
            if key in fields:
                fields[key] += ',' + value
            else:
                fields[key] = value
        self.fields = fields

    def get(self, key, notFind=None):
        """Return the value of header ``key`` (case-insensitive), or ``notFind``."""
        return self.fields.get(key.lower(), notFind)

    def appendField(self, tag, value):
        """Set header ``tag`` to ``value`` in ``self.raw`` and ``self.fields``.

        An existing ``tag`` line in the header block is replaced. Otherwise a new
        line is inserted just before the blank line that ends the header.
        """
        line = ("%s: %s" % (tag, value)).encode()
        headerEnd = self.raw.find(terminator)
        if headerEnd == -1:
            headerEnd = len(self.raw)
        head = self.raw[:headerEnd]
        # Match a whole field name at the start of a line, case-insensitively,
        # and only inside the header block. A bare substring search would hit
        # e.g. "Age" inside "User-Agent", or text in the body.
        pos = head.lower().find(b'\n' + tag.lower().encode() + b':')
        if pos != -1:
            start = pos + 1
            end = head.find(b'\n', start)
            if end == -1:
                end = headerEnd          # last header line
            elif head[end - 1:end] == b'\r':
                end -= 1                 # keep the CRLF that ends this line
            self.raw = self.raw[:start] + line + self.raw[end:]
        else:
            self.raw = self.raw[:headerEnd] + b'\r\n' + line + self.raw[headerEnd:]
        self.fields[tag.lower()] = str(value)
        
class responseHeader(httpHeader):
    """Parsed HTTP response plus the freshness bookkeeping used by Cache.

    ``raw`` keeps the complete response bytes (header and body), so a cached
    entry can be replayed to clients as-is.
    """
    def __init__(self, msg):
        self.raw = msg  # full response bytes: header + body
        httpHeader.__init__(self, msg.split(terminator)[0])
        if len(self.firstLine) < 2 or not self.firstLine[0].startswith('HTTP'):
            logger.info("not a vaild response")
            # Assigning to `self` cannot discard the instance. Give the status
            # attributes explicit "invalid" values rather than leaving them
            # undefined, since Cache.updateCache reads statusCode unconditionally.
            self.statusCode = None
            self.statusMessage = ''
            self.protocol = None
            return
        self.statusCode = self.firstLine[1]
        self.statusMessage = ' '.join(self.firstLine[2:])
        self.protocol = self.firstLine[0]

    def resetAge(self):
        """Restart this response's age clock (e.g. after a 304 revalidation)."""
        self.startTime = datetime.datetime.now()

    def setLifeTime(self, public=False):
        """Start the age clock and work out ``self.lifeTime`` (seconds of freshness).

        Sources are tried in this order:
        1. an already-computed lifeTime;
        2. Cache-Control max-age;
        3. Expires minus Date (an unparsable value counts as already expired);
        4. one tenth of (Date minus Last-Modified).

        Returns True when a lifetime was determined. With ``public=True`` and no
        other source, lifeTime is set to 60000 but False is still returned.
        """
        self.startTime = datetime.datetime.now()
        if hasattr(self, 'lifeTime'):
            return True
        control = self.get('Cache-control', False)
        if control and 'max-age' in control:
            pos = control.find('max-age')
            end = control.find(',', pos)
            end = None if end == -1 else end
            # The value follows "max-age=" (8 characters). Also accept the quoted
            # form max-age="60".
            self.lifeTime = int(control[pos + 8:end].strip().strip('"'))
            return True
        expire = self.get('Expires', False)
        date = self.get('Date', False)
        if expire and date:
            try:
                delta = eut.parsedate_to_datetime(expire) - eut.parsedate_to_datetime(date)
            except (TypeError, ValueError):
                # Unparsable dates (commonly "Expires: 0"), or mixing naive and
                # aware datetimes: treat as already expired (RFC 7234 section 5.3).
                delta = datetime.timedelta(0)
            self.lifeTime = delta.total_seconds()
            return True
        modified = self.get('Last-Modified', False)
        # The heuristic needs Date as well; without it there is nothing to
        # subtract from.
        if modified and date:
            try:
                delta = eut.parsedate_to_datetime(date) - eut.parsedate_to_datetime(modified)
            except (TypeError, ValueError):
                delta = None
            if delta is not None:
                self.lifeTime = delta.total_seconds() / 10
                return True
        if public:
            self.lifeTime = 60000
        return False

    def checkIfFresh(self):
        """Return False if stale, otherwise the age in whole seconds (possibly 0).

        A response whose Cache-Control contains 'immutable' never goes stale.
        Requires setLifeTime() to have been called.
        """
        currentTime = datetime.datetime.now()
        age = int((currentTime - self.startTime).total_seconds())
        if 'immutable' in self.get('Cache-Control', ''):
            return age
        expirationTime = self.startTime + datetime.timedelta(seconds=self.lifeTime)
        if currentTime >= expirationTime:
            return False
        return age
    
class requestHeader(httpHeader):
    """Parsed HTTP request.

    ``raw`` keeps the complete request bytes so they can be forwarded upstream
    after edits such as replaceRelative() or appendField(). ``target``/``port``
    identify the origin server.
    """
    def __init__(self, msg):
        self.raw = msg  # full request bytes: header + body
        # Parse only the header block; body bytes must not be read as fields.
        httpHeader.__init__(self, msg.split(terminator, 1)[0])
        if len(self.firstLine) < 3 or not self.firstLine[-1].startswith('HTTP'):
            logger.info("not a vaild requst")
            # Assigning to `self` cannot discard the instance, so expose the
            # attributes callers read as None instead of leaving them undefined.
            self.method = self.link = self.protocol = None
            self.target = self.port = None
            return
        self.method = self.firstLine[0]
        self.link = self.firstLine[1]
        # strip() guards against a trailing '\r' left by LF-only line splitting.
        self.protocol = self.firstLine[2].strip()
        self._parseHost()

    def _parseHost(self):
        """Set ``target`` (host name) and ``port`` (int, default 80).

        Values come from the Host header. Without a Host header they are taken
        from the request target, e.g. "http://host:port/path" or, for CONNECT,
        "host:port".
        """
        url = self.get('Host')
        if not url:
            url = self.link
            if '://' in url:
                url = url.split('://', 1)[1]
            url = url.split('/', 1)[0]
        webserver, sep, portText = url.partition(':')
        port = int(portText) if sep and portText else 80
        self.target = webserver
        self.port = port

    def replaceRelative(self):
        """Rewrite an absolute-form target ("http://host/path") to origin-form ("/path").

        Both ``self.link`` and the request line inside ``self.raw`` are updated.
        Targets that do not start with "http://" are left untouched.
        """
        if self.link is None:
            return
        url = self.link.encode()
        scheme = b'http://'
        if not url.startswith(scheme):
            return
        # Strip only the leading scheme. "http://" later in the URL (e.g. inside
        # a query string) must be preserved.
        newUrl = url[len(scheme):]
        pathPos = newUrl.find(b'/')
        if pathPos == -1:
            self.link = '/'
            self.raw = self.raw.replace(url, b'/', 1)
        else:
            self.link = newUrl[pathPos:].decode()
            self.raw = self.raw.replace(url, newUrl[pathPos:], 1)

In [190]:
class Cache():
    """In-memory HTTP response cache shared by all proxy connections.

    ``cacheData`` layout, as built by storeResponse::

        {host: {path: responseHeader}}                             # no Vary
        {host: {path: {((field, value), ...): responseHeader}}}    # with Vary

    NOTE(review): there is no locking, although handleRequest jobs run on a
    thread pool and share one instance.
    """
    def __init__(self):
        self.cacheData = {}

    def cacheable(self, request, response):
        """Return True if ``response`` to ``request`` may be stored.

        Rules:
        - Never store when either side's Cache-Control mentions 'no-store' or
          'private'.
        - Requests carrying Authorization additionally need must-revalidate,
          public or s-maxage in the response.
        - Otherwise the response needs Expires, an ETag, or
          s-maxage/max-age/public.
        """
        rqCache = request.get('Cache-Control', [])
        rpCache = response.get('Cache-Control', [])
        noCache = ['no-store', 'private']
        for c in [rqCache, rpCache]:
            if any(xs in c for xs in noCache):
                return False
        authButOk = ["must-revalidate", "public", "s-maxage"]
        if request.get('Authorization', False):
            return any(xs in rpCache for xs in authButOk)
        other = ["s-maxage", 'max-age', 'public']
        if response.get('Expires', False) or \
            response.get('Etag', False) or \
            any(xs in rpCache for xs in other):
            return True
        return False

    def storeResponse(self, request, response):
        """Store ``response`` under (request.target, request.link) if cacheable.

        When the response has a Vary header, the entry is keyed additionally by
        the (field, value) pairs named in Vary. Each value comes from the
        request, falling back to the response. If any value is missing, nothing
        is stored.
        """
        if not (self.cacheable(request, response) and response.setLifeTime()):
            return
        pKey = request.link
        host = request.target
        hostEntries = self.cacheData.setdefault(host, {})
        vary = response.get('Vary', False)
        if vary:
            conditions = []
            # Strip before sorting so the key order does not depend on spacing.
            for v in sorted(name.strip() for name in vary.split(',')):
                tryGet = request.get(v, False) or response.get(v, False)
                if not tryGet:
                    return
                conditions.append((v, tryGet))
            variants = hostEntries.get(pKey)
            if not isinstance(variants, dict):
                # A non-Vary response may already be stored for this path;
                # replace it with a variant table instead of indexing into it.
                variants = hostEntries[pKey] = {}
            variants[tuple(conditions)] = response
        else:
            hostEntries[pKey] = response
        logger.info("%s %s cached"%(host, pKey))

    def searchCache(self, request):
        """Look up a cached response for ``request``.

        Returns ``(response, needValidate)``, or ``(False, False)`` on a miss.
        ``needValidate`` is True when the entry is stale or the request's
        Cache-Control asks for no-cache/revalidate. A fresh hit has its Age
        header set before it is returned.
        """
        if not self.cacheData:
            return False, False
        target = self.cacheData.get(request.target, {}).get(request.link, False)
        if target is False:
            return False, False
        out = None
        if isinstance(target, responseHeader):
            out = target
        elif isinstance(target, dict):
            for vary, response in target.items():
                # Stored values are never empty, so equality also implies the
                # request has the field.
                if all(request.get(name, False) == value for name, value in vary):
                    out = response
                    break
        if not isinstance(out, responseHeader):
            return False, False
        age = out.checkIfFresh()
        # checkIfFresh returns False when stale, otherwise the age in seconds.
        # That age is 0 for an entry stored under a second ago, so compare
        # against False instead of testing truthiness.
        if age is False:
            return out, True
        control = request.get('Cache-Control', '')
        if any(xs in control for xs in ['no-cache', 'revalidate']):
            return out, True
        out.appendField("Age", age)
        return out, False

    def updateCache(self, request, response, oldCache):
        """Apply a revalidation ``response`` and return the response to serve.

        Outcomes:
        - 304 Not Modified: refreshes ``oldCache`` (age reset, Age: 0).
        - 200 OK with a determinable lifetime: stores and returns the new
          response.
        - Anything else, including an unparsable response: returns ``oldCache``
          unchanged.
        """
        status = getattr(response, 'statusCode', None)
        message = getattr(response, 'statusMessage', '') or ''
        if status == '304' and 'Not Modified' in message:
            oldCache.resetAge()
            oldCache.appendField("Age", 0)
            logger.info('renew Cache with 304')
        elif status == '200' and 'OK' in message and response.setLifeTime():
            self.storeResponse(request, response)
            response.appendField("Age", 0)
            oldCache = response
            logger.info('renew Cache with 200')
        return oldCache

In [94]:
def constructValidate(request, cacheContent):
    """Turn ``request`` into a conditional request for the cached ``cacheContent``.

    Validator mapping:
    - The cached ETag becomes If-None-Match.
    - The cached Last-Modified becomes If-Modified-Since.

    Only validators present (non-empty) in the cached response are added, via
    ``request.appendField``.
    """
    validators = (('Etag', 'If-None-Match'), ('Last-Modified', 'If-Modified-Since'))
    for sourceField, conditionalField in validators:
        token = cacheContent.get(sourceField, False)
        if token:
            request.appendField(conditionalField, token)

In [82]:
def send_cacert(conn):
    """Answer on ``conn`` with the proxy CA certificate as a complete HTTP 200 response.

    The body is the raw bytes of the file at ``cacert``. Headers declare the
    x509 content type and the body length, and ask the client to close the
    connection.
    """
    with open(cacert, 'rb') as certFile:
        certBytes = certFile.read()
    head = ('HTTP/1.1 200 OK\r\n'
            'Content-Type: application/x-x509-ca-cert\r\n'
            'Content-Length: %d\r\n'
            'Connection: close\r\n'
            '\r\n') % len(certBytes)
    conn.sendall(head.encode('utf8') + certBytes)
    
def createCert(certpath, target):
    """Write to ``certpath`` a certificate for CN=``target``, signed by the proxy CA.

    Pipes ``openssl req -new`` (a CSR over ``certkey``) into ``openssl x509 -req``.
    The certificate is signed with ``cacert``/``cakey``, is valid for 3650 days,
    and uses the current time in milliseconds as its serial.

    Raises RuntimeError if either openssl step exits non-zero.
    """
    epoch = "%d" % (time.time() * 1000)
    p1 = Popen(["openssl", "req", "-new", "-key", certkey, "-subj", "/CN=%s" % target],
               stdout=PIPE)
    p2 = Popen(["openssl", "x509", "-req", "-days", "3650", "-CA", cacert, "-CAkey", cakey,
                "-set_serial", epoch, "-out", certpath], stdin=p1.stdout, stderr=PIPE)
    # Drop the parent's copy of the pipe, so p1 sees a closed reader if p2 exits
    # early (the pattern recommended in the subprocess docs for pipelines).
    p1.stdout.close()
    _, err = p2.communicate()
    p1.wait()  # reap the CSR process so it does not linger as a zombie
    if p1.returncode != 0 or p2.returncode != 0:
        raise RuntimeError("openssl failed to create %s: %s"
                           % (certpath, err.decode(errors='replace').strip()))

In [83]:
def createResponse(protocol, statusCode):
    """Build a header-only HTTP status response (no body) as bytes.

    Reason phrases:
    - 200 uses "Connection Established" (the reply to CONNECT).
    - 404 and 408 use their usual phrases.
    - Any other standard code uses its ``http.HTTPStatus`` phrase.

    Unknown codes return None. Surrounding whitespace in ``protocol`` (e.g. a
    stray '\\r') is removed.
    """
    from http import HTTPStatus
    reasons = {200: 'Connection Established', 408: 'Request Timeout', 404: 'Not Found'}
    reason = reasons.get(statusCode)
    if reason is None:
        try:
            reason = HTTPStatus(statusCode).phrase
        except ValueError:
            return None
    return b'%s %d %s%s' % (protocol.strip().encode(), statusCode, reason.encode(), terminator)

In [84]:
class end(Exception):
    """Control-flow signal: raised in handleRequest's special-case section once a
    reply has already been sent, to skip proxying the request."""

In [191]:
def sendAll(sock, msg):
    """Write every byte of ``msg`` to ``sock`` using repeated ``send`` calls.

    Returns the number of bytes written, which equals len(msg). Raises
    RuntimeError if ``send`` reports that 0 bytes were written.
    """
    offset = 0
    while offset < len(msg):
        written = sock.send(msg[offset:])
        if written == 0:
            raise RuntimeError("socket connection broken")
        offset += written
    return offset
        
def readSocket(sock):
    """Read one HTTP message (header plus any body) from ``sock``.

    Returns ``(data, closed)``. ``data`` is every byte received; ``closed`` is
    True when the peer closed the connection (``recv`` returned b'') first.

    The message counts as complete when any of these holds:
    - its header has neither Transfer-Encoding nor Content-Length;
    - a chunked body contains the terminating zero-length chunk;
    - at least Content-Length body bytes have arrived.

    NOTE(review): a response framed only by connection close (no length, not
    chunked) is returned after its header alone.
    """
    buffer = b''
    bodyStart = None   # offset just past the blank line, once the header is complete
    chunked = False
    length = -1
    while 1:
        data = sock.recv(config.MAX_REQUEST_LEN)
        if len(data) == 0:
            return buffer, True
        buffer += data
        if bodyStart is None:
            headerEnd = buffer.find(terminator)
            if headerEnd == -1:
                continue  # header still incomplete
            # Parse the header once, rather than again after every recv().
            headerDetail = httpHeader(buffer[:headerEnd])
            bodyStart = headerEnd + len(terminator)
            encoding = headerDetail.get('Transfer-Encoding', -1)
            length = headerDetail.get('Content-Length', -1)
            if encoding == -1 and length == -1:
                return buffer, False
            # Transfer-Encoding is case-insensitive and may list several codings.
            chunked = encoding != -1 and 'chunked' in encoding.lower()
            if not chunked and length != -1:
                length = int(length)
        body = buffer[bodyStart:]
        if chunked:
            if b'0\r\n\r\n' in body:
                return buffer, False
        elif length != -1 and len(body) >= length:
            # ">=" rather than "==": extra bytes after the body must not leave
            # us waiting for a close or timeout.
            return buffer, False
                
#receive connect request from client
def handleRequest(conn):
    """Serve one accepted client connection.

    ``conn`` is the ``(socket, address)`` pair returned by ``accept()``. Reads the
    first request header, then:
    - black-listed host: answers 404;
    - ``www.proxy.test``: answers with the CA certificate;
    - otherwise: hands the connection to ``redirect``.

    The client socket is always closed before returning.

    Raises RuntimeError in these cases:
    - the client resets or closes before a complete header arrives;
    - the client times out (after being sent a 408);
    - the request line is unparsable.
    """
    (clientSocket, client_address) = conn
    try:
        clientSocket.settimeout(config.timeout)
        buffer = b''
        # Test the accumulated buffer, not just the latest chunk: the blank line
        # ending the header can be split across two recv() calls.
        while terminator not in buffer:
            try:
                chunk = clientSocket.recv(config.MAX_REQUEST_LEN)
            except ConnectionResetError as e:
                raise RuntimeError("Connection Reset Error") from e
            except socket.timeout as e:
                sendAll(clientSocket, createResponse("HTTP/1.1", 408))
                raise RuntimeError("Client request timeout") from e
            if len(chunk) == 0:
                raise RuntimeError("client close connection during init")
            buffer += chunk
        request = requestHeader(buffer)
        if getattr(request, 'method', None) is None:
            raise RuntimeError("invalid request line")
        try:
            ##########special cases#################
            if request.target in blackList:
                sendAll(clientSocket, createResponse(request.protocol, 404))
                raise end()
            if request.target == 'www.proxy.test':
                send_cacert(clientSocket)
                raise end()
            ########################################
            redirect(clientSocket, request)
        except end:
            pass  # a special case already replied
    finally:
        clientSocket.close()

#redirect following request and response to and from client
def redirect(conn, request):
    """Proxy ``request`` and every following request on ``conn`` to the origin server.

    Answers from the global ``cache`` when possible. For CONNECT:
    1. a TLS connection is opened to the real server;
    2. the client is sent "200 Connection Established";
    3. the client socket is wrapped in TLS using a per-host certificate signed
       by the proxy CA (created on first use), so decrypted requests can be
       cached.

    The upstream socket, and the TLS-wrapped client socket if any, are closed
    before returning.

    Raises RuntimeError for a CONNECT host name that is unsafe to use in a file
    name.
    """
    import re
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    clientConn = conn
    try:
        s.settimeout(config.timeout)
        conn.settimeout(config.timeout)
        https = request.method == 'CONNECT'
        if https:
            # request.target becomes part of a file name below. Reject anything
            # that is not a plain host name / IPv4 address (e.g. "../x").
            if not re.fullmatch(r'[A-Za-z0-9._-]+', request.target or ''):
                raise RuntimeError("invalid CONNECT host %r" % (request.target,))
            context = ssl.create_default_context()
            s = context.wrap_socket(s, server_hostname=request.target)  # TLS to the real server
            s.connect((request.target, request.port))
            sendAll(conn, createResponse("HTTP/1.1", 200))
            certpath = "%s/%s.crt" % (certdir.rstrip('/'), request.target)
            if not os.path.isfile(certpath):  # issue a per-host cert signed by our CA
                createCert(certpath, request.target)
            # ssl.wrap_socket() is deprecated and was removed in Python 3.12;
            # use an explicit server-side context instead.
            serverContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            serverContext.load_cert_chain(certfile=certpath, keyfile=certkey)
            conn = serverContext.wrap_socket(conn, server_side=True)  # TLS to the client
        else:
            s.connect((request.target, request.port))
        _relay(conn, s, request, https)
    finally:
        s.close()
        if conn is not clientConn:
            # wrap_socket detaches the original socket object, so the caller's
            # close() on it does not close the TLS connection.
            conn.close()


def _relay(conn, s, request, https):
    """Exchange request/response pairs between client ``conn`` and server ``s``.

    Stops when either side closes or a socket times out.
    """
    first = True
    endOfClientSocket = False
    endOfServerSocket = False
    while True:
        try:
            if https or not first:
                raw, endOfClientSocket = readSocket(conn)  # read client request
                if endOfClientSocket and not raw:
                    break  # client closed cleanly between requests
                request = requestHeader(raw)
            if not https:
                request.replaceRelative()
            ###################Search Cache#######################
            tryCache, needValidate = cache.searchCache(request)
            if tryCache:  # cache found
                if needValidate:
                    constructValidate(request, tryCache)
                    sendAll(s, request.raw)
                    update, endOfServerSocket = readSocket(s)
                    if update:  # nothing to apply if the server closed silently
                        update = responseHeader(update)
                        tryCache = cache.updateCache(request, update, tryCache)
                sendAll(conn, tryCache.raw)
                logger.info("%s %s reply with cache"%(request.target, request.link))
                first = False
                if endOfServerSocket or endOfClientSocket:
                    break
                continue
            sendAll(s, request.raw)
            if endOfClientSocket:
                break
            #######################################################
            response, endOfServerSocket = readSocket(s)  # read server response
            sendAll(conn, response)
            ###################Cache############################
            if request.method == 'GET' and response:
                cache.storeResponse(request, responseHeader(response))
            ####################################################
            if endOfServerSocket:
                break
            first = False
        except socket.timeout:
            logger.info('Timeout')
            break

In [86]:
class Server():
    """Listening TCP socket for the proxy, plus the futures of submitted client jobs."""
    def __init__(self, config):
        # shutdown() is installed as the SIGINT handler.
        # NOTE(review): signal.signal only works from the main thread. Inside a
        # Jupyter kernel this likely interferes with kernel interrupts — confirm.
        signal.signal(signal.SIGINT, self.shutdown)
        self.serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.serverSocket.bind(config.socket)
        self.serverSocket.listen(config.LISTENNQ)
        self.jobs = []

    def shutdown(self, signum=None, frame=None):
        """Wait for every submitted job to finish, then close the listening socket.

        Accepts the ``(signum, frame)`` arguments Python passes to signal
        handlers, and can also be called with no arguments.
        """
        from concurrent.futures import wait
        wait(self.jobs)  # blocks without busy-spinning the CPU
        try:
            self.serverSocket.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # e.g. listening socket not connected (BSD/macOS) or already closed
        self.serverSocket.close()
    

In [42]:
# Worker pool that runs handleRequest for each accepted connection.
# NOTE(review): the pool size reuses config.LISTENNQ, which is also the listen()
# backlog; the two limits are unrelated and may deserve separate settings.
executor = ThreadPoolExecutor(max_workers=config.LISTENNQ)

In [194]:
# Bind and listen on config.socket; this also installs Server.shutdown as the SIGINT handler.
server = Server(config)

In [192]:
# Single response cache used by redirect() for every client connection.
# NOTE(review): Cache has no locking, yet handleRequest jobs run concurrently on the executor.
cache = Cache()

In [193]:
# NOTE(review): this closes the listening socket. The execution counts show it ran
# (In [193]) before `server = Server(config)` (In [194]). Run top-to-bottom, it would
# close the socket before the accept loop below and make accept() fail. Run it only
# after stopping that loop.
server.shutdown()

In [None]:
# Accept clients until the loop is stopped, handing each connection to the worker pool.
# (For single-threaded debugging, call handleRequest((conn, addr)) directly
# instead of submitting it to the executor.)
try:
    while True:
        conn, addr = server.serverSocket.accept()
        server.jobs.append(executor.submit(handleRequest, (conn, addr)))
finally:
    # Runs whenever the loop stops, e.g. accept() fails or the cell is interrupted.
    # Skipped if the listening socket is already closed, because shutting it
    # down a second time would raise.
    if server.serverSocket.fileno() != -1:
        server.shutdown()

2019-05-23 17:13:02,166: example.com / cached
2019-05-23 17:13:06,239: renew Cache with 304
2019-05-23 17:13:06,241: example.com / reply with cache
2019-05-23 17:13:06,554: example.com /favicon.ico cached
2019-05-23 17:13:16,476: example.com / cached
2019-05-23 17:13:16,478: renew Cache with 200
2019-05-23 17:13:16,479: example.com / reply with cache
2019-05-23 17:13:16,789: example.com /favicon.ico reply with cache
2019-05-23 17:13:26,801: Timeout
2019-05-23 17:13:40,305: example.com / reply with cache
2019-05-23 17:13:46,052: example.com / reply with cache
2019-05-23 17:13:50,525: renew Cache with 304
2019-05-23 17:13:50,527: example.com / reply with cache
2019-05-23 17:13:50,839: example.com /favicon.ico reply with cache


In [152]:
# Inspect the submitted client jobs; each Future shows whether handleRequest returned or raised.
server.jobs

[<Future at 0x7f9876d8c160 state=finished returned NoneType>,
 <Future at 0x7f9876daef60 state=finished returned NoneType>,
 <Future at 0x7f9876daed30 state=finished returned NoneType>,
 <Future at 0x7f9876daea20 state=finished returned NoneType>,
 <Future at 0x7f9876d97278 state=finished raised RuntimeError>,
 <Future at 0x7f9877ee37b8 state=finished returned NoneType>,
 <Future at 0x7f98745388d0 state=finished raised RuntimeError>,
 <Future at 0x7f9874538a20 state=finished raised RuntimeError>,
 <Future at 0x7f9874538940 state=finished returned NoneType>,
 <Future at 0x7f9874538cc0 state=finished returned NoneType>,
 <Future at 0x7f9877e3e668 state=finished returned NoneType>,
 <Future at 0x7f9877df62b0 state=finished returned NoneType>,
 <Future at 0x7f989c348630 state=finished returned NoneType>,
 <Future at 0x7f9877ed88d0 state=finished returned NoneType>,
 <Future at 0x7f9876d97828 state=finished returned NoneType>,
 <Future at 0x7f9876d97f28 state=finished returned NoneType>,
 <

2019-05-23 16:39:42,465: Timeout


In [19]:
# Sample request for exercising requestHeader by hand. The header lines here end in
# a bare '\n' (literal newlines in the triple-quoted string), not the CRLF a real
# client sends; only the final blank line is '\r\n\r\n'.
request =b'''GET / HTTP/1.1
Host: example.com
User-Agent: Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:67.0) Gecko/20100101 Firefox/67.0
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
Accept-Language: en-US,en;q=0.5
Accept-Encoding: gzip, deflate, br
Connection: keep-alive
Upgrade-Insecure-Requests: 1
If-Modified-Since: Fri, 09 Aug 2013 23:54:35 GMT
If-None-Match: "1541025663"
Cache-Control: max-age=0\r\n\r\n'''

In [20]:
# NOTE(review): rebinds `request` from bytes to a requestHeader, so re-running this
# cell fails (a requestHeader has no .strip/.split). requestHeader prints nothing,
# so the "example.com 80" output below looks stale.
request = requestHeader(request)

example.com 80


In [157]:
    # NOTE(review): this line is indented, which plain Python rejects (presumably
    # IPython strips the leading indent). It also assumes this host/path was cached
    # earlier. The expression yields a cache entry, so the displayed 2592000 (which
    # looks like a lifeTime) is probably stale output.
    cache.cacheData['sing.cse.ust.hk']['/themes/default/scripts/modernizr-2.6.1.min.js']

2592000