# Code for Figure S7

Classify which plate each cell came from, using healthy patients that appear on two consecutive plates, to test for plate effects.

In [None]:
from data import PlateDataset

import torch
import numpy as np
import pandas as pd
from tqdm import trange, tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as T 
import seaborn as sns
import umap
from scipy.stats import median_abs_deviation
from sklearn.preprocessing import StandardScaler
from sklearn import svm, linear_model
import matplotlib as mpl


# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still executes (more slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Healthy patient IDs listed for each plate number (1-16). Every pair of
# consecutive plates shares exactly two patients; those shared patients are the
# ones used further down to test for plate effects.
plate_healthy = {1 : ['H01', 'H02', 'H03', 'H04', 'H05'],
                 2 : ['H04', 'H05', 'H06', 'H07'],
                 3 : ['H06', 'H07', 'H30', 'H23'],
                 4 : ['H30', 'H23', 'H10', 'H40'],
                 5 : ['H10', 'H40', 'H31', 'H39'],
                 6 : ['H31', 'H39', 'H37', 'H22'],
                 7 : ['H37', 'H22', 'H26', 'H47'],
                 8 : ['H26', 'H47', 'H20', 'H36'],
                 9 : ['H20', 'H36', 'H32', 'H33'],
                10 : ['H32', 'H33', 'H16', 'H29'],
                11 : ['H16', 'H29', 'H19', 'H43'],
                12 : ['H19', 'H43', 'H09', 'H49'],
                13 : ['H09', 'H49', 'H25', 'H48'],
                14 : ['H25', 'H48', 'H18', 'H45'],
                15 : ['H18', 'H45', 'H13', 'H15'],
                16 : ['H13', 'H15', 'H21', 'H24']}

In [None]:
# Load plates 1-16 into a single dataset; load_masks=True is forwarded to
# PlateDataset (defined in the project's data module).
data = PlateDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], load_masks=True)

  0%|                                                                                                                                                                                          | 0/16 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:59<00:00,  7.44s/it]


In [None]:
# Per-cell metadata table (plate, well, series, cell, patient, time, qc, group).
data.info

Unnamed: 0,plate,well,series,cell,patient,time,qc,group
0,1,A02,0,25,H01,0,True,healthy
1,1,A02,0,30,H01,0,True,healthy
2,1,A02,0,37,H01,0,True,healthy
3,1,A02,0,43,H01,0,True,healthy
4,1,A02,0,44,H01,0,True,healthy
...,...,...,...,...,...,...,...,...
1093961,16,H12,9,565,H15,0,True,healthy
1093962,16,H12,9,569,H15,0,True,healthy
1093963,16,H12,9,570,H15,0,True,healthy
1093964,16,H12,9,572,H15,0,True,healthy


In [None]:
from torchvision.models import resnet18 as make_resnet18
from torchvision.models.feature_extraction import create_feature_extractor
from torch.utils.data import DataLoader


def extract_resnet_patch_features(imgs, transform=None, batch_size=128):
  """Embed images with a pretrained ResNet-18 and return the flattened features.

  Args:
    imgs: indexable collection of images accepted by ``DataLoader``. Each batch
      is tiled three times along dim 1 before the forward pass, so this
      presumably expects single-channel images of shape (1, H, W) -- TODO confirm.
    transform: optional callable applied to each batch (already on ``device``
      and already tiled to three channels) before the forward pass.
    batch_size: number of images per forward pass (default 128).

  Returns:
    A CPU float tensor of shape ``(len(imgs), 512)``; row ``i`` holds the output
    of the network's ``flatten`` node for ``imgs[i]`` (batches are not shuffled).
  """
  # Put the network in eval mode and on the target device once, then expose the
  # output of the 'flatten' node under the key 'z'.
  model = make_resnet18(weights="DEFAULT").eval().to(device)
  feature_extractor = create_feature_extractor(model, return_nodes={'flatten': 'z'})

  z = torch.zeros((len(imgs), 512))
  loader = DataLoader(imgs, batch_size=batch_size, shuffle=False)
  start = 0
  for img_batch in tqdm(loader):
    # Tile along the channel dimension to match ResNet's three input channels.
    img_batch = img_batch.to(device).repeat(1, 3, 1, 1)
    if transform is not None:
      img_batch = transform(img_batch)
    stop = start + len(img_batch)
    with torch.no_grad():
      z[start:stop] = feature_extractor(img_batch)['z'].cpu()
    start = stop
  return z


# Embed every image in the dataset; row i of res_zs holds the features of data.imgs[i].
res_zs = extract_resnet_patch_features(data.imgs)
res_zs.shape

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8547/8547 [00:50<00:00, 168.38it/s]


torch.Size([1093966, 512])

In [None]:
def get_id_subs(patient, plate_a, plate_b, m=200, seed=12312):
  """Sample well-disjoint train/test cell indices for one patient on two plates.

  The patient must occupy exactly two wells on each plate. For each of the four
  ways of picking one training well per plate (the remaining well on each plate
  becomes the test well), up to ``m`` row labels of ``data.info`` are drawn
  without replacement from every well. Within a train (or test) pair both wells
  contribute the same number of samples: ``min(len(well_a), len(well_b), m)``.
  A warning is printed whenever fewer than ``m`` samples are available.

  Args:
    patient: patient ID as stored in ``data.info['patient']``.
    plate_a: first plate number.
    plate_b: second plate number.
    m: maximum number of cells sampled per well.
    seed: seed of the sampler; the same seed always yields the same indices.

  Returns:
    ``(train_test_combos, ids)``: two parallel lists of length four.
    ``train_test_combos[k]`` is ``(train_well_a, train_well_b, test_well_a,
    test_well_b)`` and ``ids[k]`` is the matching tuple of index arrays
    ``(idx_train_a, idx_train_b, idx_test_a, idx_test_b)``.

  Raises:
    AssertionError: if the patient does not occupy exactly two wells on each
      plate, or if the sampled index sets of a combo overlap.
  """
  # A private legacy generator produces exactly the draws that
  # np.random.seed(seed) followed by np.random.choice would, but leaves the
  # process-wide NumPy random state untouched.
  rng = np.random.RandomState(seed)

  info = data.info
  # Each of these masks is a pass over the whole table, so build them once.
  on_plate_a = (info['plate'] == plate_a).values
  on_plate_b = (info['plate'] == plate_b).values
  is_patient = (info['patient'] == patient).values

  wells_a = info[on_plate_a & is_patient]['well'].unique()
  wells_b = info[on_plate_b & is_patient]['well'].unique()
  assert len(wells_a) == 2 and len(wells_b) == 2, \
      'expected exactly two wells per plate for the patient'

  wa1, wa2 = wells_a
  wb1, wb2 = wells_b

  # Candidate rows per well. As before, a well is selected by plate and well
  # name only (not by patient).
  pools_a = {well: info[(info['well'] == well).values & on_plate_a].index for well in wells_a}
  pools_b = {well: info[(info['well'] == well).values & on_plate_b].index for well in wells_b}

  def draw_pair(well_a, well_b, phase):
    # Draw equally many rows from one well on each plate (plate a first, so the
    # order of random draws is training a, training b, testing a, testing b).
    pool_a, pool_b = pools_a[well_a], pools_b[well_b]
    sub_size = min(len(pool_a), len(pool_b), m)
    if sub_size < m:
      print(f'warning: {phase} sub_size is only', sub_size, 'for', patient, plate_a, well_a, plate_b, well_b)
    idx_a = rng.choice(pool_a, size=sub_size, replace=False)
    idx_b = rng.choice(pool_b, size=sub_size, replace=False)
    return idx_a, idx_b

  train_test_combos = [(wa1, wb1, wa2, wb2),
                       (wa1, wb2, wa2, wb1),
                       (wa2, wb1, wa1, wb2),
                       (wa2, wb2, wa1, wb1)]
  ids = []
  for train_wa, train_wb, test_wa, test_wb in train_test_combos:
    idx_train_a, idx_train_b = draw_pair(train_wa, train_wb, 'training')
    idx_test_a, idx_test_b = draw_pair(test_wa, test_wb, 'testing')

    # check disjoint: no row may appear in more than one of the four samples
    n_total = len(idx_train_a) + len(idx_train_b) + len(idx_test_a) + len(idx_test_b)
    n_unique = len(set(idx_train_a) | set(idx_train_b) | set(idx_test_a) | set(idx_test_b))
    assert n_unique == n_total, 'train/test index sets overlap'

    ids.append((idx_train_a, idx_train_b, idx_test_a, idx_test_b))

  return train_test_combos, ids


# both SVC and logistic regression
def classify_and_test(idx_train_a, idx_train_b, idx_test_a, idx_test_b, seed=24523, standardize=True, use_features=None):
  """Fit a logistic regression and an SVC to separate plate-a from plate-b cells.

  Features are rows of the module-level ``res_zs``; samples from the "a" index
  arrays are labelled 1 and samples from the "b" index arrays 0.

  Args:
    idx_train_a, idx_train_b: row indices of the training samples.
    idx_test_a, idx_test_b: row indices of the held-out samples.
    seed: ``random_state`` passed to both classifiers.
    standardize: if true, scale the features with a ``StandardScaler`` fitted on
      the training samples only.
    use_features: optional column selector applied after scaling.

  Returns:
    ``(train_acc_svc, test_acc_svc, train_acc_lr, test_acc_lr, lr_coef)`` where
    ``lr_coef`` is the ``coef_`` attribute of the fitted logistic regression.
  """
  def gather(idx_a, idx_b):
    # Stack the "a" rows on top of the "b" rows; the leading rows get label 1.
    feats = res_zs[np.concatenate([idx_a, idx_b])]
    labels = np.zeros(len(feats))
    labels[:len(idx_a)] = 1
    return feats, labels

  def accuracy(clf, feats, labels):
    # Fraction of samples whose predicted label matches the true label.
    pred = clf.predict(feats)
    return (labels == pred).sum() / len(pred)

  x_train, y_train = gather(idx_train_a, idx_train_b)
  x_test, y_test = gather(idx_test_a, idx_test_b)

  if standardize:
    # Scaling statistics come from the training split only.
    scaler = StandardScaler().fit(x_train)
    x_train = scaler.transform(x_train)
    x_test = scaler.transform(x_test)

  if use_features is not None:
    x_train = x_train[:, use_features]
    x_test = x_test[:, use_features]

  # L1-penalised logistic regression
  logreg = linear_model.LogisticRegression(penalty='l1', solver='liblinear', random_state=seed)
  logreg.fit(x_train, y_train)
  train_acc_lr = accuracy(logreg, x_train, y_train)
  test_acc_lr = accuracy(logreg, x_test, y_test)

  # support vector classifier with the library's default kernel
  svc = svm.SVC(decision_function_shape='ovr', random_state=seed)
  svc.fit(x_train, y_train)
  train_acc_svc = accuracy(svc, x_train, y_train)
  test_acc_svc = accuracy(svc, x_test, y_test)

  return train_acc_svc, test_acc_svc, train_acc_lr, test_acc_lr, logreg.coef_


# Plates in comparison order: each plate is compared with the one that follows it.
plates = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]


def test_for_plate_effect(standardize=True, use_features=None):
  """Measure how well cells of shared patients can be assigned to their plate.

  For every pair of consecutive entries in ``plates`` and every healthy patient
  listed on both plates in ``plate_healthy``, the four train/test well splits
  from ``get_id_subs`` are classified with ``classify_and_test``.

  Args:
    standardize: forwarded to ``classify_and_test``.
    use_features: forwarded to ``classify_and_test``.

  Returns:
    A DataFrame with one row per (plate pair, patient, well split) holding the
    plates, patient, train/test wells, SVC and logistic-regression accuracies
    and the logistic-regression coefficients. The frame is empty, with the same
    columns, when no split is produced.
  """
  columns = ['plate_a', 'plate_b', 'patient',
             'train_well_a', 'train_well_b', 'test_well_a', 'test_well_b',
             'train_acc_svc', 'test_acc_svc', 'train_acc_lr', 'test_acc_lr',
             'lr_coeffs']
  results = []
  for i in trange(len(plates) - 1):
    plate_a, plate_b = plates[i], plates[i + 1]
    # Sort the shared patients: iterating the raw set gives a row order that
    # can change between interpreter runs because of string hash randomisation.
    shared_patients = sorted(set(plate_healthy[plate_a]) & set(plate_healthy[plate_b]))
    for patient in shared_patients:
      for combo, ids in zip(*get_id_subs(patient, plate_a, plate_b)):
        res = classify_and_test(*ids, standardize=standardize, use_features=use_features)
        results.append((plate_a, plate_b, patient, *combo, *res))

  # Naming the columns in the constructor also works when `results` is empty.
  return pd.DataFrame(results, columns=columns)


# All 512 features

In [None]:
# Run the plate classification with the defaults: standardized features, all feature columns.
results = test_for_plate_effect()
results

  7%|███████████▊                                                                                                                                                                      | 1/15 [00:03<00:47,  3.38s/it]



 13%|███████████████████████▋                                                                                                                                                          | 2/15 [00:06<00:44,  3.43s/it]



 20%|███████████████████████████████████▌                                                                                                                                              | 3/15 [00:09<00:38,  3.22s/it]



 27%|███████████████████████████████████████████████▍                                                                                                                                  | 4/15 [00:12<00:33,  3.03s/it]



 47%|███████████████████████████████████████████████████████████████████████████████████                                                                                               | 7/15 [00:22<00:25,  3.20s/it]



 60%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                       | 9/15 [00:29<00:20,  3.44s/it]



 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                           | 10/15 [00:32<00:17,  3.47s/it]



 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                       | 13/15 [00:42<00:06,  3.39s/it]



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:48<00:00,  3.23s/it]


Unnamed: 0,plate_a,plate_b,patient,train_well_a,train_well_b,test_well_a,test_well_b,train_acc_svc,test_acc_svc,train_acc_lr,test_acc_lr,lr_coeffs
0,1,2,H05,D03,A04,D12,G12,0.8800,0.7200,0.9675,0.7125,"[[0.0, 0.0, 0.0, 0.0, 0.8809444218608694, 0.12..."
1,1,2,H05,D03,G12,D12,A04,0.8775,0.7475,0.9700,0.7425,"[[0.0, 0.0, -0.4946893963550723, -0.6260194454..."
2,1,2,H05,D12,A04,D03,G12,0.8975,0.6825,0.9875,0.7025,"[[0.0, 0.0, -0.25773586475461246, 0.0, 0.0, 0...."
3,1,2,H05,D12,G12,D03,A04,0.9075,0.7500,0.9875,0.7475,"[[0.0, 0.3967251685124443, 0.0, -0.73680238278..."
4,1,2,H04,A04,B10,E05,D03,0.8525,0.7100,0.9625,0.7050,"[[0.0, 0.0, -0.049237178093853515, 0.0, 0.0, 0..."
...,...,...,...,...,...,...,...,...,...,...,...,...
115,15,16,H15,G07,H12,E07,B10,0.9975,0.9150,1.0000,0.9600,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
116,15,16,H13,C09,D12,H10,G09,0.9875,0.9350,1.0000,0.9750,"[[0.0, 0.011518177361320745, 0.0, 0.0, 0.0, 0...."
117,15,16,H13,C09,G09,H10,D12,0.9925,0.9425,1.0000,0.9600,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
118,15,16,H13,H10,D12,C09,G09,0.9925,0.9050,1.0000,0.9700,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."


In [None]:
# Save the per-split results table; the path is relative to the working directory.
results.to_csv('results/patch_effect_resnet_patch_features_all.csv')

In [None]:
# Mean and standard deviation of the train/test accuracies of both classifiers per plate pair.
results.groupby(['plate_a', 'plate_b']).agg({'train_acc_svc':['mean', 'std'], 'train_acc_lr':['mean', 'std'],
                                             'test_acc_svc':['mean', 'std'], 'test_acc_lr':['mean', 'std']})

Unnamed: 0_level_0,Unnamed: 1_level_0,train_acc_svc,train_acc_svc,train_acc_lr,train_acc_lr,test_acc_svc,test_acc_svc,test_acc_lr,test_acc_lr
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std,mean,std,mean,std
plate_a,plate_b,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
1,2,0.88,0.024016,0.970313,0.01775,0.695312,0.047197,0.680937,0.059039
2,3,0.831488,0.035898,0.962649,0.020865,0.566324,0.042541,0.52369,0.026913
3,4,0.891733,0.029282,0.984277,0.013997,0.674045,0.03363,0.666817,0.056708
4,5,0.917969,0.051252,0.983179,0.016986,0.76664,0.066646,0.749893,0.101849
5,6,0.946096,0.021967,0.991756,0.009157,0.837323,0.042222,0.873265,0.037973
6,7,0.821562,0.011873,0.935,0.012536,0.587187,0.045009,0.559688,0.045207
7,8,0.864062,0.042129,0.957187,0.019153,0.606563,0.15645,0.60875,0.143378
8,9,0.86388,0.05442,0.964561,0.024572,0.617646,0.146767,0.604131,0.131363
9,10,0.835625,0.030582,0.950937,0.024926,0.545,0.035279,0.5075,0.035807
10,11,0.853685,0.042057,0.953714,0.028537,0.621555,0.059586,0.586533,0.045387
