In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys
import random
import numpy as np
import pandas as pd
import pickle as pkl
import matplotlib.pyplot as plt

In [4]:
ic50_data_path = "GDSC2_fitted_dose_response_27Oct23.xlsx"
raw_data_path = "GDSC2_public_raw_data_27Oct23.csv"

In [5]:
ic50_data = pd.read_excel(ic50_data_path)

In [6]:
raw_data = pd.read_csv(raw_data_path)

In [7]:
cell_line_names_to_cosmic_ids = {}
cosmic_ids_to_cell_line_names = {}

for cosmic_id, cell_line_name in zip(ic50_data["COSMIC_ID"].tolist(), ic50_data["CELL_LINE_NAME"].tolist()):
    cell_line_names_to_cosmic_ids[cell_line_name] = cosmic_id
    cosmic_ids_to_cell_line_names[cosmic_id] = cell_line_name

In [8]:
cell_line_names = list(cell_line_names_to_cosmic_ids.keys())
cosmic_ids = list(cosmic_ids_to_cell_line_names.keys())

In [9]:
for cosmic_id, cell_line_name in zip(ic50_data["COSMIC_ID"].tolist(), ic50_data["CELL_LINE_NAME"].tolist()):
    assert cell_line_names_to_cosmic_ids[cell_line_name] == cosmic_id
    assert cosmic_ids_to_cell_line_names[cosmic_id] == cell_line_name

In [10]:
drug_names_to_drug_ids = {}
drug_ids_to_drug_names = {}

for drug_id, drug_name in zip(ic50_data["DRUG_ID"].tolist(), ic50_data["DRUG_NAME"].tolist()):
    drug_ids_to_drug_names[drug_id] = drug_name

drug_names = set(list(drug_ids_to_drug_names.values()))

In [11]:
for drug_name in drug_names:
    drug_names_to_drug_ids[drug_name] = []

for drug_id, drug_name in zip(ic50_data["DRUG_ID"].tolist(), ic50_data["DRUG_NAME"].tolist()):
    if not (drug_id in drug_names_to_drug_ids[drug_name]):
        drug_names_to_drug_ids[drug_name].append(drug_id)

In [12]:
drug_ids = list(drug_ids_to_drug_names.keys())

In [13]:
for drug_id, drug_name in zip(ic50_data["DRUG_ID"].tolist(), ic50_data["DRUG_NAME"].tolist()):
    assert drug_id in drug_names_to_drug_ids[drug_name]
    assert drug_name == drug_ids_to_drug_names[drug_id]

In [14]:
n_cell_lines = len(cell_line_names)
n_drugs = len(drug_ids)

In [15]:
(n_cell_lines, n_drugs)

(969, 295)

In [16]:
times_sampled = np.zeros((n_cell_lines, n_drugs))

for cell_line_name, drug_id in zip(ic50_data["CELL_LINE_NAME"].tolist(), ic50_data["DRUG_ID"].tolist()):
    times_sampled[cell_line_names.index(cell_line_name), drug_ids.index(drug_id)] += 1

In [17]:
ln_ic50s = np.zeros((n_cell_lines, n_drugs))

for cell_line_name, drug_id, ln_ic50 in zip(ic50_data["CELL_LINE_NAME"].tolist(), ic50_data["DRUG_ID"].tolist(), ic50_data["LN_IC50"]):
    ln_ic50s[cell_line_names.index(cell_line_name), drug_ids.index(drug_id)] = ln_ic50

In [18]:
raw_data.DRUG_ID = raw_data.DRUG_ID.fillna(-1.0).astype(int)

In [19]:
n_drugs_raw = len(set(raw_data.DRUG_ID.tolist())) - 1

drug_ids_raw = list(set(raw_data.DRUG_ID.tolist()))
cosmic_ids_raw = list(set(raw_data.COSMIC_ID.tolist()))

In [20]:
for raw_id in cosmic_ids_raw:
  assert raw_id in cosmic_ids
for id in cosmic_ids:
  assert id in cosmic_ids_raw

In [21]:
missing_drug_ids = []

for raw_id in drug_ids_raw:
  if raw_id > 0 and not (raw_id in drug_ids):
    missing_drug_ids.append(raw_id)

In [22]:
for id in drug_ids:
  assert id in drug_ids_raw

In [23]:
barcodes = list(set(raw_data["BARCODE"].values))

In [24]:
n_barcodes = len(barcodes)

In [29]:
n_barcodes

16354

In [25]:
drug_ids_raw.sort()

In [26]:
def retrieve_dose_intensity_data(cosmic_id, drug_id):
    assert cosmic_id in cosmic_ids, "Invalid COSMIC ID, please try again."
    assert drug_id in drug_ids_raw, "Drug ID not present in raw data, please try again."
    subset = raw_data[(raw_data["COSMIC_ID"] == demo_cosmic_id) & (raw_data["DRUG_ID"] == demo_drug_id)]
    subset_barcodes = set(subset["BARCODE"].values)
    data = []
    for barcode in subset_barcodes:
        further_subset = subset[subset["BARCODE"] == barcode]
        tags = further_subset["TAG"].values
        tags_passed = True
        dose_counts = [0] * 7
        for i, tag in enumerate(tags):
            if i % 7 == 0:
                for count in dose_counts: assert count == i // 7, "Doses not contiguous."
            split_tag = tag.split("-")
            if split_tag[0][0] != 'L' or split_tag[-1] != "S":
                tags_passed = False
                break
            dose_counts[int(split_tag[1][1]) - 1] += 1
        if not tags_passed: continue
        for count in dose_counts:
            assert count == dose_counts[0], f"Doses unevenly sampled for barcode {barcode}!"
        for i in range(dose_counts[0]):
            plate_data = {"barcode": barcode,
                         "seeding_density": further_subset["SEEDING_DENSITY"].iloc[0],
                         "dose": further_subset["CONC"].values[7*i:7*(i+1)],
                         "intensity": further_subset["INTENSITY"].values[7*i:7*(i+1)],
                         "position": further_subset["POSITION"].values[7*i:7*(i+1)]}
            data.append(plate_data)
    return data

In [27]:
def retrieve_baseline_data(barcode):
    assert barcode in barcodes, "Invalid barcode. Please try again."
    subset = raw_data[(raw_data["BARCODE"] == barcode) & (raw_data["TAG"] == "B")]
    plate_data = {"barcode": barcode,
                 "seeding_density": subset["SEEDING_DENSITY"].iloc[0],
                 "intensity": subset["INTENSITY"].values,
                 "position": subset["POSITION"].values}
    return plate_data

In [28]:
def retrieve_untreated_data(barcode):
    assert barcode in barcodes, "Invalid barcode. Please try again."
    subset_0 = raw_data[(raw_data["BARCODE"] == barcode) & (raw_data["TAG"] == "NC-0")]
    subset_1 = raw_data[(raw_data["BARCODE"] == barcode) & (raw_data["TAG"] == "NC-1")]
    assert len(subset_0.index) > 0 and len(subset_1.index) > 0, "Missing data for either NC-0 or NC-1!"
    plate_data = {"barcode": barcode,
                 "seeding_density": subset_0["SEEDING_DENSITY"].iloc[0],
                 "intensity_0": subset_0["INTENSITY"].values,
                 "intensity_1": subset_1["INTENSITY"].values,
                 "position_0": subset_0["POSITION"].values,
                 "position_1": subset_1["POSITION"].values}
    return plate_data

In [38]:
c = raw_data[raw_data["TAG"] == "B"].groupby("BARCODE")["INTENSITY"].median()

In [39]:
c

BARCODE
3230     932.0
3236     796.0
3242     796.0
3266     832.5
3277     995.0
         ...  
63252    287.0
63254    716.5
63255    669.0
63256    633.0
63257    925.5
Name: INTENSITY, Length: 16354, dtype: float64

In [41]:
from scipy.special import gammaln, gammaincc
from scipy.stats import gamma, poisson
from scipy.optimize import minimize

In [42]:
def ilogit(x):
    return 1. / (1+np.exp(-x))

def inc_gamma_loss(logr_logitp, x, c):
    logr, logitp = logr_logitp
    r = np.exp(logr)
    p = ilogit(logitp)
    f = r*np.log((1-p)) + x.sum()*np.log(p) - gammaln(r) + np.log(gammaincc(r+x, c/p)).sum() # gammaln(r + x) + gamma.logcdf(c/p, r+x)
    print(r,p, x, c, f)
    print(np.log(gammaincc(r+x, c/p)))

def approx_loss(ab, x, c):
    a, b = ab
    lam_grid = gamma.ppf(np.linspace(1e-6, 1-1e-6, 1000), a, scale=b)
    weights = gamma.pdf(lam_grid, a, scale=b) / max(1e-20,gamma.pdf(lam_grid, a, scale=b).sum())
    return -np.log((poisson.pmf(x[:,np.newaxis], lam_grid[np.newaxis,:]+c) * weights).sum(axis=1).clip(1e-6,np.inf)).sum()

def fit_gamma_hyperparameters(x, c, nrestarts=10):
    # Quick estimation of the mean and variance to initialize
    x_mean_true = x.mean()
    x_std = x.std()
    best_score, best_params = None, None
    for trial in range(nrestarts):
        x_mean = np.random.normal(x_mean_true, x_std/2)
        # theta0 = np.array([np.log(x_mean**2 / x_std**2), np.log(x_std**2 / x_mean)])
        theta0 = np.array([x_mean**2 / x_std**2, x_std**2 / x_mean])
        print('\t{} Initial: '.format(trial), theta0)
        results = minimize(approx_loss,
                            theta0,
                            args=(x,c),
                            bounds=((0.01,1e5), (0.01,1e8)),
                            method='SLSQP')
        print('\t{} Fit: '.format(trial), results.x, results.fun)
        if np.isnan(results.fun):
            continue
        if best_score is None or results.fun > best_score:
            best_score = results.fun
            # best_params = np.exp(results.x)
            best_params = results.x
    return best_params, best_score

In [45]:
for i, barcode in enumerate(barcodes):

    print(f"On barcode {i:05d} of {n_barcodes}.")
    
    untreated_data = retrieve_untreated_data(barcode)
    
    print(f"NC-0: {np.median(untreated_data['intensity_0'])}. NC-1: {np.median(untreated_data['intensity_1'])}.")
    print(f"NC-0: {len(untreated_data['intensity_0'])}. NC-1: {len(untreated_data['intensity_1'])}.")

On barcode 00000 of 16354.
NC-0: 38513.5. NC-1: 39278.0.
NC-0: 6. NC-1: 126.
On barcode 00001 of 16354.
NC-0: 34451.0. NC-1: 34737.5.
NC-0: 6. NC-1: 126.
On barcode 00002 of 16354.
NC-0: 69918.5. NC-1: 68915.0.
NC-0: 6. NC-1: 126.
On barcode 00003 of 16354.
NC-0: 67576.0. NC-1: 69440.5.
NC-0: 6. NC-1: 126.
On barcode 00004 of 16354.
NC-0: 40042.5. NC-1: 39182.5.
NC-0: 6. NC-1: 126.
On barcode 00005 of 16354.
NC-0: 40090.5. NC-1: 40903.0.
NC-0: 6. NC-1: 126.
On barcode 00006 of 16354.
NC-0: 37945.5. NC-1: 37653.0.
NC-0: 6. NC-1: 126.
On barcode 00007 of 16354.
NC-0: 43723.0. NC-1: 42624.0.
NC-0: 6. NC-1: 126.
On barcode 00008 of 16354.
NC-0: 43006.0. NC-1: 43149.5.
NC-0: 6. NC-1: 126.
On barcode 00009 of 16354.
NC-0: 41190.0. NC-1: 39947.0.
NC-0: 6. NC-1: 126.
On barcode 00010 of 16354.


KeyboardInterrupt: 

In [50]:
a = np.zeros(n_barcodes)
b = np.zeros(n_barcodes)
scores = np.zeros(n_barcodes)

In [49]:
for i, barcode in enumerate(barcodes):

    print(f"On barcode {i:05d} of {n_barcodes}.")
    
    untreated_data = retrieve_untreated_data(barcode)

    (a_l,b_l), score = fit_gamma_hyperparameters(untreated_data["intensity_1"], c[barcode])
    print(a_l, b_l, score)
    a[i] = a_l
    b[i] = b_l
    scores[i] = score

On barcode 00000 of 16354.
	0 Initial:  [ 60.96144657 626.61118574]
	0 Fit:  [  36.6680526 1072.0256834] 1247.5990101853174
	1 Initial:  [ 81.58239264 541.66098629]
	1 Fit:  [  35.15606639 1119.14869271] 1247.4972325532292
	2 Initial:  [ 60.89655864 626.94493758]


  weights = gamma.pdf(lam_grid, a, scale=b) / max(1e-20,gamma.pdf(lam_grid, a, scale=b).sum())


	2 Fit:  [  37.83341975 1038.3706502 ] 1247.752871919692
	3 Initial:  [ 72.37640392 575.07868993]
	3 Fit:  [ 39.42284413 996.5999794 ] 1248.056734690627
	4 Initial:  [ 70.68502321 581.91837743]
	4 Fit:  [  36.26956976 1084.20418981] 1247.560695019347
	5 Initial:  [ 67.62018552 594.95976407]
	5 Fit:  [  38.83500329 1011.82935941] 1247.9309994307057
	6 Initial:  [ 65.03782719 606.65637174]
	6 Fit:  [  36.05475676 1090.71405082] 1247.5437440002636
	7 Initial:  [ 63.34636597 614.70241943]
	7 Fit:  [  37.31696237 1053.14248568] 1247.676333737108
	8 Initial:  [ 67.69130625 594.64713095]
	8 Fit:  [  39.18073924 1002.02103806] 1248.0066820953436
	9 Initial:  [ 64.50448412 609.15922028]
	9 Fit:  [  38.7653983 1013.6556721] 1247.9170846602854
39.42284413309421 996.5999793993149 1248.056734690627
On barcode 00001 of 16354.
	0 Initial:  [202.68450426 181.35691367]
	0 Fit:  [ 87.43778872 390.07694179] 1169.7560137537444
	1 Initial:  [179.95240653 192.47111421]
	1 Fit:  [ 85.91323755 397.06933041] 1

KeyboardInterrupt: 

In [52]:
for idx in range(221):
    res = np.load(f"pos-ctrl-priors-{idx}.npz")
    a[idx*74:(idx+1)*74] = res["a"]
    b[idx*74:(idx+1)*74] = res["b"]
    scores[idx*74:(idx+1)*74] = res["scores"]

In [53]:
a

array([ 41.61200679, 100.59385568,  94.03724218, ...,  38.89919212,
        39.1920098 ,  55.52837015])

In [54]:
b

array([944.55684107, 339.19701138, 735.21877814, ..., 410.14827329,
       408.28108954, 698.71782918])

In [55]:
scores

array([1248.65276415, 1170.622583  , 1290.2968765 , ..., 1124.51726477,
       1125.84360091, 1233.56959629])

In [57]:
np.savez("pos-ctrl-priors.npz", barcodes=barcodes, a=a, b=b, scores=scores)

In [58]:
c_np = np.array([c[barcodes[i]] for i in range(n_barcodes)])

In [59]:
np.savez("neg-ctrl-priors.npz", barcodes=barcodes, c=c_np)

In [60]:
c_np

array([478. , 955. , 907.5, ..., 191. , 195. , 382. ])