In [None]:
import os
import signal
import socket
import ssl
import threading
import time
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor, wait
from http import HTTPStatus
from subprocess import PIPE, Popen

In [None]:
# Proxy settings: listen backlog (also reused as the worker-pool size), recv
# chunk size, (host, port) to bind, and per-socket timeout in seconds.
Config = namedtuple('config', ('LISTENNQ', 'MAX_REQUEST_LEN', 'socket', 'timeout'))
config = Config(
    LISTENNQ=10,
    MAX_REQUEST_LEN=1024,
    socket=('', 12345),
    timeout=10.0,
)
terminator = b'\r\n\r\n'          # blank line that ends an HTTP message head
blackList = ['sing.cse.ust.hk']   # hosts answered with 404 instead of being proxied

In [None]:
def join_with_script_dir(path):
    """Return `path` joined onto the kernel's current working directory.

    Uses os.getcwd() instead of the IPython-only `!pwd` shell escape, so the
    cell also runs under plain Python and does not spawn a shell per call.
    NOTE(review): despite the name this is the working directory, not the
    notebook's own directory; they coincide only when the kernel starts there.
    """
    return os.path.join(os.getcwd(), path)

# Directory for generated per-host certificates, the CA key and certificate,
# and the private key shared by every generated host certificate.
certdir = join_with_script_dir('certs/')
cakey = join_with_script_dir('ca.key')
cacert = join_with_script_dir('ca.crt')
certkey = join_with_script_dir('cert.key')

In [None]:
class httpHeader():
    """Parsed HTTP message head: start line plus header fields.

    Attributes:
        raw: the bytes passed in, unchanged (including any body).
        firstLine: the start line split on single spaces.
        fields: header name -> value; repeated names are joined with ','.
    """
    def __init__(self, msg):
        self.raw = msg  # in bytes; kept whole so it can be forwarded verbatim
        # Parse only the head. Everything after the first blank line is the
        # body, which may be binary (must not be decoded) and must not be
        # mistaken for header lines.
        head = msg.lstrip(b'\r\n').split(terminator, 1)[0]
        # Split on LF and drop the CR of each CRLF so no value (in particular
        # the protocol token of the start line) keeps a trailing '\r'.
        lines = [line.rstrip('\r') for line in head.decode().split('\n')]
        self.firstLine = lines[0].split(' ')
        tags = {}
        for row in lines[1:]:
            if not row.strip():
                continue
            key, _, value = row.partition(':')
            key = key.strip()
            value = value.strip()
            if key in tags:
                tags[key] += ',' + value
            else:
                tags[key] = value
        self.fields = tags
        
class responseHeader(httpHeader):
    """HTTP response head: adds protocol, statusCode and statusMessage.

    statusCode is the status token as a string; statusMessage is the list of
    remaining start-line tokens. If the start line does not begin with 'HTTP'
    or has no status token, a warning is printed and the three attributes are
    None (raw and fields are still available). The previous `self = None` only
    rebound a local name and left these attributes undefined.
    """
    def __init__(self, msg):
        super(responseHeader, self).__init__(msg)
        self.protocol = self.statusCode = self.statusMessage = None
        if len(self.firstLine) < 2 or not self.firstLine[0].startswith('HTTP'):
            print("not a vaild response")
            return
        self.statusCode = self.firstLine[1]
        self.statusMessage = self.firstLine[2:]
        self.protocol = self.firstLine[0]
        
class requestHeader(httpHeader):
    """HTTP request head: adds method, link (request target) and protocol.

    If the request line has fewer than three tokens or its last token does not
    start with 'HTTP', a warning is printed and method/link/protocol are None.
    The previous `self = None` only rebound a local name and left these
    attributes undefined.
    """
    def __init__(self, msg):
        super(requestHeader, self).__init__(msg)
        self.method = self.link = self.protocol = None
        if len(self.firstLine) < 3 or not self.firstLine[-1].startswith('HTTP'):
            print("not a vaild requst")
            return
        self.method = self.firstLine[0]
        self.link = self.firstLine[1]
        self.protocol = self.firstLine[2]

    def parseHost(self):
        """Set self.target and self.port for the upstream connection.

        For CONNECT the request target (authority-form ``host:port``) is used,
        because the Host header may omit the port. Otherwise the Host header is
        used. Handles ``host``, ``host:port`` and bracketed IPv6 ``[addr]:port``.
        Without an explicit port, CONNECT defaults to 443 and other methods
        default to 80.

        Raises:
            KeyError: a non-CONNECT request has no Host header.
            ValueError: the port is not an integer.
        """
        if self.method == 'CONNECT' and self.link:
            url = self.link
        else:
            url = self.fields['Host']
        default_port = 443 if self.method == 'CONNECT' else 80
        url = url.strip()
        if url.startswith('[') and ']' in url:
            # IPv6 literal: the colons inside the brackets are not a port separator.
            close = url.index(']')
            webserver = url[1:close]
            port_text = url[close + 1:].lstrip(':')
        else:
            webserver, _, port_text = url.partition(':')
        port = int(port_text) if port_text else default_port
        print(webserver, port)
        self.target = webserver; self.port = port

    def replaceRelative(self):
        """Rewrite an absolute ``http://host/path`` target in self.raw to ``/path``.

        Only the first occurrence in the raw bytes (the request line) is
        replaced. A target without a path becomes ``/``. Does nothing when the
        request line was invalid.
        """
        if not self.link:
            return
        url = self.link.encode()
        # Strip only a leading scheme. The old replace() also removed
        # 'http://' found inside the path or query string, corrupting it.
        newUrl = url[len(b'http://'):] if url.startswith(b'http://') else url
        pathPos = newUrl.find(b'/')
        if pathPos == -1:
            self.raw = self.raw.replace(url, b'/', 1)
        else:
            self.raw = self.raw.replace(url, newUrl[pathPos:], 1)
            

In [None]:
def cache(request, response):
    """Check the Cache-Control rules for a GET response; storage is not implemented.

    Stops early if either side's Cache-Control value contains 'no-store' or
    'private'. Also stops early if the request carries an Authorization header
    and neither side mentions 'must-revalidate', 'public' or 's-maxage'.
    Always returns None.
    """
    directives = (
        request.fields.get('Cache-Control', []),
        response.fields.get('Cache-Control', []),
    )

    def mentions(tokens):
        # Substring test against each Cache-Control value ([] when absent).
        return any(token in value for value in directives for token in tokens)

    if mentions(('no-store', 'private')):
        return None
    if 'Authorization' in request.fields and not mentions(("must-revalidate", "public", "s-maxage")):
        return None
    # TODO: storing the response is not implemented yet.

In [None]:
# Scratch check of the Authorization rule used by `cache`: when the request
# carries Authorization and no Cache-Control value mentions one of authButOk,
# the response is not cached. Sample messages are built here so the cell runs
# on a fresh kernel instead of relying on leftover `request`/`response` objects.
sample_request = requestHeader(
    b'GET http://example.com/ HTTP/1.1\r\n'
    b'Host: example.com\r\n'
    b'Authorization: Bearer example\r\n\r\n'
)
sample_response = responseHeader(b'HTTP/1.1 200 OK\r\nCache-Control: max-age=60\r\n\r\n')
rqCache = sample_request.fields.get('Cache-Control', [])
rpCache = sample_response.fields.get('Cache-Control', [])
authButOk = ["must-revalidate", "public", "s-maxage"]
if sample_request.fields.get('Authorization', -1) != -1 and not \
    [s for s in [rqCache, rpCache] if any(xs in s for xs in authButOk)]:
    print("authorized response would not be cached")

In [None]:
# Display the response Cache-Control value computed in the previous cell ([] when absent).
rpCache

In [None]:
def send_cacert(conn, certfile=None):
    """Send the proxy's CA certificate to the client as an HTTP 200 response.

    The response is a complete HTTP/1.1 message with Content-Type
    application/x-x509-ca-cert and ``Connection: close``.

    Args:
        conn: a connected socket-like object with ``sendall``.
        certfile: path of the certificate to send; defaults to `cacert`.
    """
    if certfile is None:
        certfile = cacert
    with open(certfile, 'rb') as f:
        data = f.read()
    # Explicit CRLFs instead of a triple-quoted string relying on source newlines.
    head = (
        'HTTP/1.1 200 OK\r\n'
        'Content-Type: application/x-x509-ca-cert\r\n'
        'Content-Length: %d\r\n'
        'Connection: close\r\n'
        '\r\n'
    ) % len(data)
    conn.sendall(head.encode('utf8') + data)
    
def createCert(certpath, target):
    """Create a certificate for `target` at `certpath`, signed by the local CA.

    Pipes ``openssl req`` (a CSR with subject CN=target using the shared key
    `certkey`) into ``openssl x509 -req``, which signs it with `cacert`/`cakey`.
    The certificate is valid for 3650 days and its serial is the current time
    in milliseconds.
    NOTE(review): the subject has no subjectAltName; current browsers may
    reject CN-only certificates. Consider adding one via ``-extfile``.

    Raises:
        RuntimeError: if either openssl process exits non-zero. Previously the
            failure was ignored and only surfaced later as a missing cert file.
    """
    epoch = "%d" % (time.time() * 1000)
    p1 = Popen(["openssl", "req", "-new", "-key", certkey, "-subj", "/CN=%s" % target], stdout=PIPE)
    p2 = Popen(["openssl", "x509", "-req", "-days", "3650", "-CA", cacert, "-CAkey", cakey,
                "-set_serial", epoch, "-out", certpath], stdin=p1.stdout, stderr=PIPE)
    # Drop our copy of the pipe so p1 gets SIGPIPE if p2 exits early.
    p1.stdout.close()
    _, err = p2.communicate()
    p1.wait()  # reap the CSR process so it does not linger as a zombie
    if p1.returncode != 0 or p2.returncode != 0:
        raise RuntimeError("openssl failed to create %s: %s"
                           % (certpath, err.decode(errors='replace').strip()))

In [None]:
def parseHost(url, default_port=80):
    """Split a Host value into ``(hostname, port)`` and print both.

    Accepts ``host``, ``host:port``, ``[ipv6]`` and ``[ipv6]:port``. IPv6
    brackets are removed from the returned hostname. A missing or empty port
    yields `default_port`.

    Raises:
        ValueError: the port is not an integer.
    """
    url = url.strip()
    if url.startswith('[') and ']' in url:
        # IPv6 literal: the colons inside the brackets are not a port separator.
        close = url.index(']')
        webserver = url[1:close]
        port_text = url[close + 1:].lstrip(':')
    else:
        webserver, _, port_text = url.partition(':')
    port = int(port_text) if port_text else default_port
    print(webserver, port)
    return webserver, port
    
def createResponse(protocol, statusCode):
    """Build a header-only HTTP response: the status line followed by a blank line.

    200 uses the CONNECT reason phrase "Connection Established". Any other code
    uses its standard phrase from http.HTTPStatus; previously only 404 and 408
    were supported, and other codes returned None, which crashed sendAll.

    Raises:
        ValueError: statusCode is not a known HTTP status.
    """
    if statusCode == 200:
        reason = 'Connection Established'
    else:
        reason = HTTPStatus(statusCode).phrase
    # The protocol token may carry a trailing '\r' left by header parsing;
    # strip it so the status line stays well-formed.
    status_line = '%s %d %s' % (protocol.strip(), statusCode, reason)
    return status_line.encode() + terminator

In [None]:
class end(Exception):
    """Raised by handleRequest to stop handling a connection early."""

In [None]:
def sendAll(sock, msg):
    """Write all of the bytes `msg` to `sock` using repeated send() calls.

    Returns:
        The number of bytes sent, which equals len(msg).

    Raises:
        RuntimeError: if send() reports 0 bytes written (connection broken).
    """
    # A memoryview slice is zero-copy; slicing bytes copied the unsent
    # remainder on every partial send.
    view = memoryview(msg)
    totalsent = 0
    while totalsent < len(view):
        sent = sock.send(view[totalsent:])
        if sent == 0:
            raise RuntimeError("socket connection broken")
        totalsent += sent
    return totalsent
        
def _chunked_body_complete(body):
    """Return True once `body` holds a complete chunked-encoded payload."""
    pos = 0
    while True:
        eol = body.find(b'\r\n', pos)
        if eol == -1:
            return False
        # Chunk size is hex, optionally followed by ';extensions'.
        size_field = body[pos:eol].split(b';', 1)[0].strip()
        try:
            size = int(size_field, 16)
        except ValueError:
            return False  # malformed size line: keep reading until the peer closes
        if size == 0:
            # Last chunk: optional trailer fields, then an empty line.
            return body.find(b'\r\n\r\n', eol) != -1
        pos = eol + 2 + size + 2  # skip size line, chunk data and its CRLF
        if pos > len(body):
            return False


def _message_complete(fields, body):
    """Return True if `body` is the whole body announced by the header `fields`."""
    # Header names are case-insensitive.
    headers = {name.lower(): value for name, value in fields.items()}
    encoding = headers.get('transfer-encoding')
    length = headers.get('content-length')
    if encoding is not None:
        # Transfer-Encoding takes precedence over Content-Length. A non-chunked
        # coding is delimited only by the peer closing the connection.
        return 'chunked' in encoding.lower() and _chunked_body_complete(body)
    if length is not None:
        try:
            # httpHeader joins repeated fields with ',', so use the first value.
            # '>=' rather than '==': extra bytes must not make us wait forever.
            return len(body) >= int(length.split(',')[0])
        except ValueError:
            return False  # unparsable length: keep reading until the peer closes
    # Neither header: treat the message as having no body.
    # NOTE(review): responses delimited by connection close are cut short here.
    return True


def readSocket(sock):
    """Read one HTTP message (head plus body) from `sock`.

    Returns:
        ``(buffer, closed)``: the bytes read, and True if the peer closed the
        connection (recv returned b'') before a complete message was seen.

    socket.timeout from recv propagates to the caller; bytes already read in
    this call are then lost.
    NOTE(review): a message with a nonzero Content-Length but no body (e.g. a
    HEAD or 304 response) waits for more data until timeout or close.
    """
    buffer = b''
    fields = None  # header fields, parsed once the head is complete
    while True:
        data = sock.recv(config.MAX_REQUEST_LEN)
        if not data:
            return buffer, True
        buffer += data
        # partition (not split) so that blank lines inside the body, such as
        # the final b'0\r\n\r\n' of a chunked body, stay part of `body`.
        head, sep, body = buffer.partition(terminator)
        if not sep:
            continue  # head not complete yet
        if fields is None:
            fields = httpHeader(head).fields
        if _message_complete(fields, body):
            return buffer, False
                
#receive connect request from client
def handleRequest(conn):
    """Serve one accepted client connection.

    Reads the request head up to the first blank line. Blacklisted hosts get a
    404, 'www.proxy.test' gets the CA certificate, and everything else is
    relayed through `redirect`. The client socket is closed on every exit path,
    including errors; it previously leaked whenever an exception was raised.

    Args:
        conn: the ``(socket, address)`` pair returned by ``accept()``.

    Raises:
        RuntimeError: the client reset or closed the connection, or timed out
            (after being sent a 408), before a full request head arrived.
    """
    clientSocket, client_address = conn
    try:
        clientSocket.settimeout(config.timeout)
        buffer = b''
        # Test the accumulated buffer, not only the latest chunk: the blank
        # line may be split across two recv() calls.
        while terminator not in buffer:
            try:
                chunk = clientSocket.recv(config.MAX_REQUEST_LEN)
            except ConnectionResetError as e:
                raise RuntimeError("Connection Reset Error") from e
            except socket.timeout as e:
                sendAll(clientSocket, createResponse("HTTP/1.1", 408))
                raise RuntimeError("Client request timeout") from e
            if not chunk:
                raise RuntimeError("client close connection during init")
            buffer += chunk

        request = requestHeader(buffer)
        request.parseHost()
        ##########special cases#################
        if request.target in blackList:
            print("black")
            sendAll(clientSocket, createResponse(request.protocol, 404))
            return
        if request.target == 'www.proxy.test':
            send_cacert(clientSocket)
            return
        ########################################
        redirect(clientSocket, request)
    finally:
        clientSocket.close()

# Serialises on-demand certificate creation: handleRequest runs on a thread
# pool, so two CONNECTs for the same host can race to run openssl on one file.
_cert_lock = threading.Lock()


def _wrap_client_tls(conn, target):
    """Wrap the client socket in server-side TLS using a certificate for `target`.

    The certificate is created under `certdir` on first use (via createCert).
    """
    # `target` comes from the client's request and becomes a file name.
    if target in ('', '.', '..') or os.path.basename(target) != target:
        raise ValueError("refusing unsafe certificate name %r" % (target,))
    certpath = "%s/%s.crt" % (certdir.rstrip('/'), target)
    with _cert_lock:
        if not os.path.isfile(certpath):
            createCert(certpath, target)
    # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=certpath, keyfile=certkey)
    return context.wrap_socket(conn, server_side=True)


def _relay(conn, s, request, https):
    """Pass requests from client `conn` to upstream `s` and responses back.

    On plain HTTP the first request was already forwarded by `redirect`, so
    the first iteration only reads the response. The relay ends when either
    side closes or stays idle longer than config.timeout.
    NOTE(review): on plain-HTTP keep-alive, follow-up requests go to the same
    upstream even if they name a different Host.
    """
    first = True
    while True:
        try:
            if https or not first:
                raw, client_closed = readSocket(conn)  # next client request
                if raw:
                    # Parse in both modes. The HTTPS path used to call
                    # `.raw` on plain bytes and raise AttributeError.
                    request = requestHeader(raw)
                    if not https:
                        request.replaceRelative()
                    sendAll(s, request.raw)
                if client_closed:
                    break
            response, server_closed = readSocket(s)  # server's response
            sendAll(conn, response)
            if response and getattr(request, 'method', None) == 'GET':
                cache(request, responseHeader(response))
            if server_closed:
                break
            first = False
        except socket.timeout:
            # Idle or stalled mid-message: stop. Retrying (the old behaviour)
            # looped forever and silently discarded the partial message.
            break


#redirect following request and response to and from client
def redirect(conn, request):
    """Connect to request.target:request.port and relay traffic for the client.

    For CONNECT, opens TLS to the real server, answers the client with
    "200 Connection Established", then intercepts the client's TLS with a
    per-host certificate. Otherwise, forwards the (origin-form) request over
    plain TCP. The upstream socket, and the TLS-wrapped client socket if one
    was created, are closed on every exit path.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.settimeout(config.timeout)
    https = request.method == 'CONNECT'
    try:
        if https:
            context = ssl.create_default_context()
            s = context.wrap_socket(s, server_hostname=request.target)  # TLS with the real server
            s.connect((request.target, request.port))
            sendAll(conn, createResponse("HTTP/1.1", 200))
            conn = _wrap_client_tls(conn, request.target)  # TLS with the client
        else:
            s.connect((request.target, request.port))
            request.replaceRelative()
            sendAll(s, request.raw)
        _relay(conn, s, request, https)
    finally:
        s.close()
        if https:
            # wrap_socket detaches the plain socket, so the caller's close()
            # on it would not release the TLS connection's descriptor.
            conn.close()

In [None]:
class Server():
    """Listening TCP socket for the proxy, bound to `config.socket`.

    Installs `shutdown` as the SIGINT handler, so an interrupt waits for
    submitted jobs and then closes the listening socket.
    """
    def __init__(self, config):
        signal.signal(signal.SIGINT, self.shutdown)
        self.serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # SO_REUSEADDR: allow rebinding the port while old connections linger.
        self.serverSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.serverSocket.bind(config.socket)
        self.serverSocket.listen(config.LISTENNQ)
        self.jobs = []  # futures of submitted client handlers

    def shutdown(self, signum=None, frame=None):
        """Wait for all submitted jobs, then close the listening socket.

        Accepts the ``(signum, frame)`` arguments that Python passes to signal
        handlers; without them SIGINT raised TypeError. Both are optional, so
        ``server.shutdown()`` can still be called directly.
        """
        # Block until the handlers finish instead of busy-spinning on done().
        wait(self.jobs)
        try:
            self.serverSocket.shutdown(socket.SHUT_RDWR)
        except OSError:
            # shutdown() on a listening or already-closed socket can fail
            # (e.g. ENOTCONN on some platforms); close() still releases it.
            pass
        self.serverSocket.close()
    

In [None]:
# Worker pool that runs handleRequest per connection; sized by the listen backlog setting.
executor = ThreadPoolExecutor(max_workers=config.LISTENNQ)

In [None]:
# Bind and listen on config.socket; this also installs the SIGINT shutdown handler.
server = Server(config=config)

In [None]:
# Accept loop: hand each client connection to the worker pool.
# Server.shutdown (the SIGINT handler) closes the listening socket, after which
# accept() raises OSError; treat that as the stop signal instead of a crash.
while True:
    try:
        conn, addr = server.serverSocket.accept()
    except OSError:
        if server.serverSocket.fileno() == -1:  # socket closed by shutdown
            break
        raise
    # Forget finished handlers so the job list does not grow without bound.
    server.jobs = [job for job in server.jobs if not job.done()]
    server.jobs.append(executor.submit(handleRequest, (conn, addr)))