In [1]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# One-letter codes of the 20 standard amino acids, in alphabetical order.
global_aa = [letter for letter in "ACDEFGHIKLMNPQRSTVWY"]

In [2]:
# Residue classes used to pick a glyph colour. Any letter that is in none
# of these sets is treated as hydrophobic and drawn in black.
polars = set("GSTYCQN")  # polar residues
basics = set("KRH")      # basic residues
acids = set("DE")        # acidic residues

def readimg(fname, col):
    """Load a glyph image and recolour it as a uint8 RGBA sprite.

    The image's first channel is treated as ink intensity. A value of 0
    (dark) becomes alpha 255 (opaque), and a value of 1 (light) becomes
    alpha 0 (transparent). Every pixel's RGB is set to ``col``, so the
    glyph's shape is carried entirely by the alpha channel.

    Parameters
    ----------
    fname : str or path-like
        Image file, read with ``plt.imread``.
    col : sequence of 3 ints
        RGB colour (0-255) applied to every pixel.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(height, width, 4)``.
    """
    arr = np.asarray(plt.imread(fname))
    # Use the first colour channel as intensity; also accept single-channel images.
    ink = arr[..., 0] if arr.ndim == 3 else arr
    if np.issubdtype(ink.dtype, np.integer):
        # Integer images (e.g. 8-bit) are rescaled to [0, 1] so that the
        # alpha formula below cannot overflow the integer dtype.
        ink = ink / np.iinfo(ink.dtype).max
    # Clip, then let astype truncate toward zero, as the original per-pixel
    # uint8 assignment did.
    alpha = np.clip(255 - ink * 255, 0, 255).astype(np.uint8)

    rgba = np.empty(ink.shape + (4,), dtype=np.uint8)
    rgba[..., :3] = col[:3]
    rgba[..., 3] = alpha
    return rgba

# Pre-render one coloured RGBA glyph per capital letter A-Z from the
# "./<letter>_crop.png" files. Only ``imdic`` is filled; ``updic`` and
# ``downdic`` are created empty.
imdic, updic, downdic = {}, {}, {}
for i in range(26):
    c = chr(ord('A') + i)
    if c in polars:
        col = (0, 192, 0)      # polar: green
    elif c in basics:
        col = (24, 32, 255)    # basic: blue
    elif c in acids:
        col = (255, 0, 0)      # acidic: red
    else:
        col = (0, 0, 0)        # everything else (hydrophobic): black
    imdic[c] = readimg("./{}_crop.png".format(c), col)

In [15]:
def drawColumn(ax, seqs, hgts, shift, wide):
    """Draw one stacked logo column per position onto ``ax``.

    Position ``k`` (1-based) is centred at ``x = k + shift`` and is ``wide``
    units across. Within a position, the (letter, fraction) pairs are
    stacked bottom-up in reverse list order, so the first pair ends up on
    top. Each glyph is ``hgts[k-1] * fraction`` tall, and its image is
    looked up in the module-level ``imdic``. Zero-height glyphs are skipped.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    seqs : list of lists of (str, float)
        Per-position (letter, fraction) pairs.
    hgts : list of float
        Total height of each position's column; must match ``seqs`` in length.
    shift : float
        Horizontal offset applied to every column.
    wide : float
        Width of each column.

    Raises
    ------
    ValueError
        If ``seqs`` and ``hgts`` differ in length.
    """
    if len(seqs) != len(hgts):
        raise ValueError(
            "seqs and hgts must have one entry per position: "
            "got {} and {}".format(len(seqs), len(hgts))
        )
    n_positions = len(seqs)
    for x, (column, total) in enumerate(zip(seqs, hgts), start=1):
        left = x - wide / 2 + shift
        right = x + wide / 2 + shift
        bottom = 0
        for letter, frac in column[::-1]:
            hgt = total * frac
            top = bottom + hgt
            if hgt != 0:
                ax.imshow(imdic[letter], aspect="auto", extent=(left, right, bottom, top),
                          interpolation='bilinear')
            bottom = top
    ax.set_xticks(range(1, n_positions + 1))
    ax.set(xlim=(0, n_positions + 1), ylim=(0, max(hgts) * 1.1))

def drawLogo(seqlist, labels=("HLA-DR401", "Both", "HLA-DR402")):
    """Draw grouped sequence logos on a new figure, one column group per label.

    At every position, the groups are drawn side by side and centred on the
    integer x coordinate. Each x tick is labelled "P<position> & <label>".
    The figure is created with ``plt.subplots``; nothing is returned.

    Parameters
    ----------
    seqlist : sequence of (fractions, heights) tuples
        One entry per group, in the same order as ``labels``. Each entry is
        the ``(seqs, hgts)`` pair that ``drawColumn`` expects.
    labels : sequence of str, optional
        Group names, drawn left to right at every position. Defaults to the
        HLA-DR401 / Both / HLA-DR402 triplet.

    Raises
    ------
    ValueError
        If ``labels`` is empty or ``seqlist`` does not have one entry per label.
    """
    if not labels or len(seqlist) != len(labels):
        raise ValueError(
            "drawLogo needs one (fractions, heights) entry per label: "
            "got {} entries for {} labels".format(len(seqlist), len(labels))
        )
    _, ax = plt.subplots(1, 1, figsize=(16, 6))
    w = 0.25   # width of a single column
    ws = 0.28  # distance between neighbouring groups' columns at one position
    # Centre the groups on each integer position (-ws, 0, +ws for three groups).
    centre = (len(labels) - 1) / 2
    offsets = [(g - centre) * ws for g in range(len(labels))]

    h = 0
    for (fractions, heights), shift in zip(seqlist, offsets):
        drawColumn(ax, fractions, heights, shift, w)
        h = max(h, max(heights))
    n_positions = len(seqlist[0][0])
    # Overrides the per-call limits drawColumn sets, so all groups fit.
    ax.set(ylim=(0, 1.1 * h), xlim=(0.5, n_positions + 0.5))

    xtc = []
    xlab = []
    for pos in range(1, n_positions + 1):
        for shift, label in zip(offsets, labels):
            xtc.append(pos + shift)
            xlab.append("P{} & {}".format(pos, label))
    ax.set_xticks(xtc)
    ax.set_xticklabels(xlab, rotation=270)
    ax.set_xlabel("Positions & MHC allele")
    ax.set_ylabel("Bits")

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [17]:
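# drawLogo takes in a triplet of sequences, drawn left to right at each position
# as HLA-DR401, Both and HLA-DR402.
# Each sequence is a tuple of (fraction data, height data).
# Fraction data is a list of 9 lists, one per position; each inner list holds
# (amino acid, relative height) tuples. Relative heights are fractions of the
# column height, and the tuples are stacked bottom-up in reverse order (the first
# tuple ends up on top).
# Height data is a list of 9 numbers giving the total height of each column (in bits).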

# Example:

# Three identical groups; every position shows all 20 amino acids at an
# equal 0.05 share of a column that is 2 units tall.
dummies = [
    ([[(c, 0.05) for c in global_aa] for _ in range(9)], [2] * 9)
    for _ in range(3)
]

drawLogo(dummies)