In [1]:
# Mount Google Drive so the relative `drive/MyDrive/...` paths used below resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
!pip install scanpy --quiet

In [3]:
import random
import torch
import sys
import os
import gc
import collections 
import anndata as ad
from argparse import Namespace

config = Namespace(
    LEARNING_RATE = 0.00002,  # Adam learning rate used for every hyperparameter set below
    # Fall back to CPU so the notebook still runs (slowly) on a runtime without a GPU;
    # on a GPU runtime this is 'cuda' exactly as before.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu',
    # NOTE(review): the fields below are not read directly in this notebook; presumably they
    # are consumed by the helpers loaded from resources/data.py and models.py — confirm.
    BATCH_SIZE = 100,
    NUM_WORKERS = 4,
    N_GENES = 13431,
    N_PEAKS = 116465,
    MAX_SEQ_LEN_GEX = 1500,
    MAX_SEQ_LEN_ATAC = 15000,
)

In [4]:
# Run the project helper scripts in this notebook's namespace. Names such as get_chr_index,
# get_dataloaders, Encoder and bidirectTripletLoss are used below but not defined here,
# so presumably they come from these files.
# `execfile` is not a Python 3 builtin; exec(compile(...)) is the portable equivalent and
# keeps the real file name in tracebacks.
for _resource_path in (
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/data.py",
    "drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/code/resources/models.py",
):
    with open(_resource_path) as _resource_file:
        exec(compile(_resource_file.read(), _resource_path, "exec"))
del _resource_path, _resource_file

## Import data

In [5]:
# Build `index` from the processed ATAC AnnData. get_chr_index is presumably defined in
# resources/data.py (loaded above). Only the result is kept; the AnnData is dropped right away.
atac_for_index = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad")
index = get_chr_index(atac_for_index)
del atac_for_index

In [6]:
# Full (generation 2) garbage collection; the displayed value is the number of unreachable objects found.
gc.collect(generation=2)

277

In [7]:
# Read the processed GEX AnnData once and keep only the pieces used downstream.
gex_adata = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/GEX_processed.h5ad")
batch = list(gex_adata.obs['batch'])
cell_type_all = gex_adata.obs['cell_type']
csr_gex = gex_adata.layers['log_norm']
del gex_adata  # drop our reference so X and the unused layers can be garbage-collected

csr_atac = ad.read_h5ad("drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/data/ATAC_processed.h5ad").layers['log_norm']

# The same row positions are used to index both matrices, so both must have one row per cell.
assert csr_gex.shape[0] == csr_atac.shape[0] == len(batch), \
    "GEX and ATAC matrices have different numbers of cells"

# Split cells by their `batch` label: 's1d1' -> validation, 's2d4' -> test, all others -> train.
train_id = [i for i, label in enumerate(batch) if label not in ('s2d4', 's1d1')]
val_id = [i for i, label in enumerate(batch) if label == 's1d1']
test_id = [i for i, label in enumerate(batch) if label == 's2d4']

In [8]:
# Full (generation 2) garbage collection after the large loads above; the displayed value is the number of unreachable objects found.
gc.collect(generation=2)

240

In [9]:
# Seed Python's and torch's RNGs so subsampling, and any torch-based shuffling inside the
# loaders, are reproducible.
random.seed(0)
torch.manual_seed(0)

# Set to an int (e.g. 1024) for a quick debug run on a random subset of the training cells;
# None trains on every training cell.
N_TRAIN_SUBSAMPLE = None
if N_TRAIN_SUBSAMPLE is None:
    idx_train = train_id  # full training set
else:
    idx_train = [train_id[i] for i in random.sample(range(len(train_id)), N_TRAIN_SUBSAMPLE)]

gex_train = csr_gex[idx_train, :]
atac_train = csr_atac[idx_train, :]
# idx_train holds row *positions*, so use positional .iloc. Plain [] on a Series with a string
# index relies on a deprecated label->position fallback.
cell_type_train = cell_type_all.iloc[idx_train].tolist()

data_train = get_dataloaders(gex_train, atac_train, cell_type_train)

In [10]:
def train(model, criterion, optimizer, data_train, epochs, loss_type, device=None, scheduler=None):
  """Train `model` in place for `epochs` passes over `data_train`, printing progress.

  Args:
    model: module called as `model(gex, atac)`. It must return four outputs
      (gex_out_0, gex_out_1, atac_out_0, atac_out_1), which are passed to `criterion`.
    criterion: callable `criterion(gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type)`
      returning `(loss, loss_triplet, loss_cross, ct_match_prob, cell_match_prob)`. Each
      element must support `.item()`; the selected loss must also support `.backward()`.
    optimizer: torch optimizer over the model's parameters.
    data_train: iterable of batches (e.g. a DataLoader) yielding dicts with keys
      'gex' and 'atac' (tensors, moved to `device`) and 'cell_type' (passed to
      `criterion` unchanged).
    epochs: number of passes over `data_train`.
    loss_type: objective to backpropagate. "both" uses `loss`, "entropy" uses
      `loss_cross` and "triplet" uses `loss_triplet`. All of them are logged regardless.
    device: device for the 'gex'/'atac' tensors. Defaults to `config.DEVICE`.
    scheduler: optional LR scheduler, stepped once at the end of every epoch.

  Raises:
    ValueError: if `loss_type` is not one of "both", "entropy", "triplet".
  """
  valid_loss_types = ("both", "entropy", "triplet")
  if loss_type not in valid_loss_types:
    # Fail fast: an unknown value would otherwise run every epoch without training anything.
    raise ValueError(f"loss_type must be one of {valid_loss_types}, got {loss_type!r}")
  if device is None:
    device = config.DEVICE

  model.train()

  for e in range(epochs):
    running_loss = 0.0
    running_loss_cross = 0.0
    running_loss_triplet = 0.0
    running_ct_prob = 0.0
    running_cell_prob = 0.0
    n_batches = 0
    for step, data in enumerate(data_train):
      gex_input = data['gex'].to(device)
      atac_input = data['atac'].to(device)
      cell_type_input = data['cell_type']

      # Clear gradients on the model and on anything else the optimizer owns.
      model.zero_grad()
      optimizer.zero_grad()

      ### Forward
      gex_out_0, gex_out_1, atac_out_0, atac_out_1 = model(gex_input, atac_input)

      ### Compute loss
      loss, loss_triplet, loss_cross, ct_match_prob, cell_match_prob = criterion(
          gex_out_0, gex_out_1, atac_out_0, atac_out_1, cell_type_input)

      ### Pick the objective to optimize (loss_type was validated above)
      if loss_type == "both":
        objective = loss
      elif loss_type == "entropy":
        objective = loss_cross
      else:  # "triplet"
        objective = loss_triplet

      ### Accumulate statistics: every loss is logged, whichever one is optimized
      running_loss += objective.item()
      running_loss_cross += loss_cross.item()
      running_loss_triplet += loss_triplet.item()
      running_ct_prob += ct_match_prob.item()
      running_cell_prob += cell_match_prob.item()
      n_batches += 1

      ### Propagate loss and update parameters
      objective.backward()
      optimizer.step()

      # Drop the batch references and return cached CUDA blocks (a no-op without CUDA).
      del gex_input, atac_input, cell_type_input
      torch.cuda.empty_cache()

      if (step + 1) % 50 == 0:
        print('Within-epoch iter', step + 1, ': cross_loss =', loss_cross.item(), '; triplet_loss =', loss_triplet.item(), '; ct_match =', ct_match_prob.item(), '; cell_match =', cell_match_prob.item())

    # Average over the batches actually processed; the guard avoids dividing by zero on an empty loader.
    denom = max(n_batches, 1)
    print('Epoch-{0}: lr = {1}, loss = {2}, entropy_loss = {3}, triplet loss = {4}, cell type match prob = {5}, cell_match = {6}'.format(
        e + 1,
        optimizer.param_groups[0]['lr'],
        running_loss / denom,
        running_loss_cross / denom,
        running_loss_triplet / denom,
        running_ct_prob / denom,
        running_cell_prob / denom,
        )
    )

    # Step once per epoch, after logging, so the printed lr is the one used for this epoch.
    if scheduler is not None:
      scheduler.step()

## Select hyperparameters

### Hyperparam set 1

In [11]:
# Hyperparameter set 1 (read below when building the criterion and the checkpoint file name).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 0.5, 32

In [12]:
# Seed torch so this hyperparameter set starts from a reproducible initialization (and, if the
# loaders shuffle via torch, a reproducible batch order). Without it, runs are not comparable.
torch.manual_seed(0)
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)
# NOTE(review): config.N_CHANNELS is not passed here; presumably Encoder reads it from `config`
# (models.py runs in this notebook's namespace) — confirm.
# Kernel sizes: the ATAC kernels were changed to smaller values.
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 30, kernel_size_atac_2 = 5, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
# NOTE(review): this scheduler is never passed to `train` or stepped after creation in this
# notebook, so the 10x LR decay at epoch 10 is not applied — confirm this is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [10], gamma = 0.1)



In [13]:
# Optimize only the triplet loss for 15 epochs; the other losses are still logged. (shuffle = True)
train(model=model, criterion=criterion, optimizer=optimizer, data_train=data_train,
      epochs=15, loss_type="triplet")

Within-epoch iter 50 : cross_loss = 4.608196258544922 ; triplet_loss = 0.853897750377655 ; ct_match = 0.05894417315721512 ; cell_match = 0.010008595883846283
Within-epoch iter 100 : cross_loss = 4.609030723571777 ; triplet_loss = 0.7244704961776733 ; ct_match = 0.06723636388778687 ; cell_match = 0.010120362974703312
Within-epoch iter 150 : cross_loss = 4.640758037567139 ; triplet_loss = 0.6316145062446594 ; ct_match = 0.0533234179019928 ; cell_match = 0.010156986303627491
Within-epoch iter 200 : cross_loss = 4.65693998336792 ; triplet_loss = 0.6618552803993225 ; ct_match = 0.07907401770353317 ; cell_match = 0.010273211635649204
Within-epoch iter 250 : cross_loss = 4.644364833831787 ; triplet_loss = 0.6742469668388367 ; ct_match = 0.06817732751369476 ; cell_match = 0.010311472229659557
Within-epoch iter 300 : cross_loss = 4.632058620452881 ; triplet_loss = 0.5905972123146057 ; ct_match = 0.054844923317432404 ; cell_match = 0.010516515001654625
Within-epoch iter 350 : cross_loss = 4.6240

In [14]:
# Save only the weights; the file name records this run's hyperparameters and epoch count.
file = ('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells'
        f'_alpha{config.ALPHA}_margin{config.MARGIN}_nchannels{config.N_CHANNELS}_15epochs')
torch.save(model.state_dict(), file)

### Hyperparam set 2

In [15]:
# Hyperparameter set 2 (read below when building the criterion and the checkpoint file name).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 0.5, 64

In [16]:
# Seed torch so this hyperparameter set starts from a reproducible initialization (and, if the
# loaders shuffle via torch, a reproducible batch order). Without it, runs are not comparable.
torch.manual_seed(0)
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)
# NOTE(review): config.N_CHANNELS is not passed here; presumably Encoder reads it from `config`
# (models.py runs in this notebook's namespace) — confirm.
# Kernel sizes: the ATAC kernels were changed to smaller values.
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 30, kernel_size_atac_2 = 5, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
# NOTE(review): this scheduler is never passed to `train` or stepped after creation in this
# notebook, so the 10x LR decay at epoch 10 is not applied — confirm this is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [10], gamma = 0.1)

In [17]:
# Optimize only the triplet loss for 15 epochs; the other losses are still logged.
train(model=model, criterion=criterion, optimizer=optimizer, data_train=data_train,
      epochs=15, loss_type="triplet")

Within-epoch iter 50 : cross_loss = 4.604972839355469 ; triplet_loss = 0.8084391355514526 ; ct_match = 0.05906464904546738 ; cell_match = 0.010090459138154984
Within-epoch iter 100 : cross_loss = 4.639263153076172 ; triplet_loss = 0.6651818752288818 ; ct_match = 0.06666548550128937 ; cell_match = 0.010075521655380726
Within-epoch iter 150 : cross_loss = 4.672473430633545 ; triplet_loss = 0.657804548740387 ; ct_match = 0.052823565900325775 ; cell_match = 0.010058930143713951
Within-epoch iter 200 : cross_loss = 4.690717697143555 ; triplet_loss = 0.6671064496040344 ; ct_match = 0.07673811912536621 ; cell_match = 0.00994568970054388
Within-epoch iter 250 : cross_loss = 4.659061431884766 ; triplet_loss = 0.6870446801185608 ; ct_match = 0.06766388565301895 ; cell_match = 0.010202813893556595
Within-epoch iter 300 : cross_loss = 4.672415733337402 ; triplet_loss = 0.650891125202179 ; ct_match = 0.05297708511352539 ; cell_match = 0.010073782876133919
Within-epoch iter 350 : cross_loss = 4.6474

In [18]:
# Save only the weights; the file name records this run's hyperparameters and epoch count.
file = ('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells'
        f'_alpha{config.ALPHA}_margin{config.MARGIN}_nchannels{config.N_CHANNELS}_15epochs')
torch.save(model.state_dict(), file)

### Hyperparam set 3

In [19]:
# Hyperparameter set 3 (read below when building the criterion and the checkpoint file name).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 1, 32

In [20]:
# Seed torch so this hyperparameter set starts from a reproducible initialization (and, if the
# loaders shuffle via torch, a reproducible batch order). Without it, runs are not comparable.
torch.manual_seed(0)
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)
# NOTE(review): config.N_CHANNELS is not passed here; presumably Encoder reads it from `config`
# (models.py runs in this notebook's namespace) — confirm.
# Kernel sizes: the ATAC kernels were changed to smaller values.
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 30, kernel_size_atac_2 = 5, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
# NOTE(review): this scheduler is never passed to `train` or stepped after creation in this
# notebook, so the 10x LR decay at epoch 10 is not applied — confirm this is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [10], gamma = 0.1)

In [21]:
# Optimize only the triplet loss for 15 epochs; the other losses are still logged.
train(model=model, criterion=criterion, optimizer=optimizer, data_train=data_train,
      epochs=15, loss_type="triplet")

Within-epoch iter 50 : cross_loss = 4.6025800704956055 ; triplet_loss = 1.7675195932388306 ; ct_match = 0.05912341922521591 ; cell_match = 0.010095941834151745
Within-epoch iter 100 : cross_loss = 4.624183654785156 ; triplet_loss = 1.519217848777771 ; ct_match = 0.06771261990070343 ; cell_match = 0.010175034403800964
Within-epoch iter 150 : cross_loss = 4.704087257385254 ; triplet_loss = 1.6158097982406616 ; ct_match = 0.05319532752037048 ; cell_match = 0.01014780718833208
Within-epoch iter 200 : cross_loss = 4.6914777755737305 ; triplet_loss = 1.55301833152771 ; ct_match = 0.0791710913181305 ; cell_match = 0.010325218550860882
Within-epoch iter 250 : cross_loss = 4.702356815338135 ; triplet_loss = 1.6025354862213135 ; ct_match = 0.06748630106449127 ; cell_match = 0.010177303105592728
Within-epoch iter 300 : cross_loss = 4.660898208618164 ; triplet_loss = 1.4079663753509521 ; ct_match = 0.05454471707344055 ; cell_match = 0.010414592921733856
Within-epoch iter 350 : cross_loss = 4.62403

In [22]:
# Save only the weights; the file name records this run's hyperparameters and epoch count.
file = ('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells'
        f'_alpha{config.ALPHA}_margin{config.MARGIN}_nchannels{config.N_CHANNELS}_15epochs')
torch.save(model.state_dict(), file)

### Hyperparam set 4

In [23]:
# Hyperparameter set 4 (read below when building the criterion and the checkpoint file name).
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.2, 1, 64

In [24]:
# Seed torch so this hyperparameter set starts from a reproducible initialization (and, if the
# loaders shuffle via torch, a reproducible batch order). Without it, runs are not comparable.
torch.manual_seed(0)
criterion = bidirectTripletLoss(alpha = config.ALPHA, margin = config.MARGIN).to(config.DEVICE)
# NOTE(review): config.N_CHANNELS is not passed here; presumably Encoder reads it from `config`
# (models.py runs in this notebook's namespace) — confirm.
# Kernel sizes: the ATAC kernels were changed to smaller values.
model = Encoder(kernel_size_gex = 100, kernel_size_atac_1 = 30, kernel_size_atac_2 = 5, index = index).to(config.DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)
# NOTE(review): this scheduler is never passed to `train` or stepped after creation in this
# notebook, so the 10x LR decay at epoch 10 is not applied — confirm this is intended.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [10], gamma = 0.1)

In [25]:
# Optimize only the triplet loss for 15 epochs; the other losses are still logged.
train(model=model, criterion=criterion, optimizer=optimizer, data_train=data_train,
      epochs=15, loss_type="triplet")

Within-epoch iter 50 : cross_loss = 4.613767623901367 ; triplet_loss = 1.701754093170166 ; ct_match = 0.058985304087400436 ; cell_match = 0.010073868557810783
Within-epoch iter 100 : cross_loss = 4.707368850708008 ; triplet_loss = 1.632098913192749 ; ct_match = 0.06691852957010269 ; cell_match = 0.01008598692715168
Within-epoch iter 150 : cross_loss = 4.738455772399902 ; triplet_loss = 1.6226829290390015 ; ct_match = 0.05262773483991623 ; cell_match = 0.010019236244261265
Within-epoch iter 200 : cross_loss = 4.757017135620117 ; triplet_loss = 1.598933219909668 ; ct_match = 0.07685397565364838 ; cell_match = 0.009971520863473415
Within-epoch iter 250 : cross_loss = 4.729634761810303 ; triplet_loss = 1.68116295337677 ; ct_match = 0.06764290481805801 ; cell_match = 0.010210045613348484
Within-epoch iter 300 : cross_loss = 4.705041885375977 ; triplet_loss = 1.5416747331619263 ; ct_match = 0.053286533802747726 ; cell_match = 0.010177826508879662
Within-epoch iter 350 : cross_loss = 4.601577

In [26]:
# Save only the weights; the file name records this run's hyperparameters and epoch count.
file = ('drive/MyDrive/Colab_Notebooks/CPSC532S/final_project/model/trained_model_semihard_allcells'
        f'_alpha{config.ALPHA}_margin{config.MARGIN}_nchannels{config.N_CHANNELS}_15epochs')
torch.save(model.state_dict(), file)

### Hyperparam set 5

In [None]:
# Hyperparameter set 5, recorded for reference; no training cell for it follows in this notebook.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 0.5, 32

Xinyao ran this set; its training and checkpoint-saving cells are not included in this notebook.

### Hyperparam set 6

In [None]:
# Hyperparameter set 6, recorded for reference; no training cell for it follows in this notebook.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 0.5, 64

Xinyao ran this set; its training and checkpoint-saving cells are not included in this notebook.

### Hyperparam set 7

In [None]:
# Hyperparameter set 7, recorded for reference; no training cell for it follows in this notebook.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 1, 32

Xinyao ran this set; its training and checkpoint-saving cells are not included in this notebook.

### Hyperparam set 8

In [None]:
# Hyperparameter set 8, recorded for reference; no training cell for it follows in this notebook.
config.ALPHA, config.MARGIN, config.N_CHANNELS = 0.8, 1, 64

Xinyao ran this set; its training and checkpoint-saving cells are not included in this notebook.