# Train model using Mouse Brain dataset

In this notebook, we're going to train our model using the Mouse Brain dataset (GSE60361). 

This assumes that you've made the graph using the ```Infer GRN.ipynb``` code.

In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Data, Dataset
from tqdm import tqdm
from datasets.datasetPatacseq import AtacSeqDataset
from scipy.special import softmax
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tqdm import tqdm
from sklearn.metrics import (auc, precision_recall_curve, roc_auc_score,
                             roc_curve)
from statistics import mean

Torch version: 1.8.0+cu111
Cuda available: True
Torch geometric version: 2.0.3


Load up the dataset. Read ```datasetMouseBrain.py``` on how the dataset was built. 

In [2]:
dataset = AtacSeqDataset("/gpfs/data/rsingh47/hzaki1/atacseqdataP0Chromatin", atacseq=True)

100%|██████████| 96185/96185 [01:24<00:00, 1140.67it/s]
Processing...
  4%|▎         | 190/5081 [47:19<20:18:07, 14.94s/it]


KeyboardInterrupt: 

In [None]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

In [None]:
torch.manual_seed(12345)
dataset = dataset.shuffle()


# shuffle_index = np.loadtxt('shuffle_indices/shuffleIndex_MouseBrain.txt')
# shuffle_index = shuffle_index.astype(np.int32)
# train_size, val_size = int(len(shuffle_index)* 0.8), int(len(shuffle_index)* 0.9)
# train_dataset = [dataset[i] for i in shuffle_index[0:train_size]]
# val_dataset = [dataset[i] for i in shuffle_index[train_size: val_size]]
# test_dataset =  [dataset[i] for i in shuffle_index[val_size:]]

# train_dataset = torch.load('trainDataset.pt')
# test_dataset = torch.load('testDataset.pt')

train_dataset = dataset[:4064]
test_dataset = dataset[4064:]

# train_dataset = dataset[:40]
# test_dataset = dataset[40:60]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

In [None]:
from torch_geometric.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=30, shuffle=True)
train_loader_testing = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
from gcnmodel import GCN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GCN(hidden_channels=128, data=dataset, output_size=9).to(device)
print(model)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
    
def test(loader, size):
    model.eval()
    output = np.zeros((len(loader), size))
    actual = np.zeros((len(loader), size))
    accuracy = 0
    for ind, data in enumerate(loader):  # Iterate in batches over the training/test dataset.
        data.x = torch.reshape(data.x, (data.x.shape[0], 2))
        data.x = data.x.type(torch.FloatTensor)
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        output[ind] = softmax(out.cpu().detach().numpy())
        actual[ind][data.y] = 1
        accuracy += int((out.argmax(dim=1) == data.y).sum())
    all_labels = list(dataset.cellToIndex.keys())
    actual = np.array(actual)
    precision = dict()
    recall = dict()
    averageAUROC = []
    averageAUPR = []
    for (idx, c_label) in enumerate(all_labels):
        
        fpr, tpr, thresholds = roc_curve(actual[:,idx].astype(int), output[:,idx])
        precision[idx], recall[idx], _ = precision_recall_curve(actual[:, idx],
                                                        output[:, idx])
        averageAUROC.append(auc(fpr, tpr))
        averageAUPR.append(round(auc(recall[idx], precision[idx]),4))

    return accuracy/len(loader.dataset), mean(averageAUROC), mean(averageAUPR)


def train():
    model.train()
    avgLoss = 0
    for data in tqdm(train_loader, total=136):  # Iterate in batches over the training dataset.
        data.x = torch.reshape(data.x, (data.x.shape[0], 2))
        data.x = data.x.type(torch.FloatTensor)
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)# Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        avgLoss += loss
    return avgLoss / 136


for epoch in range(1, 150):
        loss = train()
        train_acc, trainAUC, trainAUPR = test(train_loader_testing, 19)
        test_acc,testAUC, testAUPR = test(test_loader, 19)
        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Train AUC: {trainAUC:.4f}, Train AUPR: {trainAUPR:.4f}, Test Acc: {test_acc:.4f}, Test Auc: {testAUC:.4f}, Test AUPR: {testAUPR:.4f},  Loss: {loss:.4f}')

In [8]:
torch.save(model.state_dict(), 'model_weightsDec24_P0atacseq.pth')

In [9]:
model = GCN(hidden_channels=128, data=dataset, output_size=4).to(device)
model.load_state_dict(torch.load('model_weightsDec23_atacseq.pth'))
test_acc,testAUC, testAUPR = test(test_loader, 4)
print(f'Test Acc: {test_acc:.4f}, Test Auc: {testAUC:.4f}, Test AUPR: {testAUPR:.4f}')

RuntimeError: Error(s) in loading state_dict for GCN:
	size mismatch for lin.weight: copying a param with shape torch.Size([19, 128]) from checkpoint, the shape in current model is torch.Size([4, 128]).
	size mismatch for lin.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([4]).

In [1]:
import pandas as pd

In [10]:
pd.read_csv('/gpfs/data/rsingh47/hzaki1/atacseqdata_P0Brain/resources/P0_chromatin_counts.tsv', sep='\t', usecols=['Unnamed: 0']) 

Unnamed: 0.1,Unnamed: 0
0,TGGAATTTTCTC
1,CCAACAAACGCG
2,TGCGCATAGCCG
3,CTGTTTCCCACC
4,ACAGTCTACATG
...,...
3048,TACAGGCACATT
3049,AGTTCATTGATC
3050,AAAGATTCCGGA
3051,CCCAGACCCCGT


In [4]:
chunksize = 20
for chunk in pd.read_csv('/gpfs/data/rsingh47/hzaki1/atacseqdata_P0Brain/resources/P0_chromatin_counts.tsv', sep='\t', 
                         chunksize=chunksize, 
                         iterator=True):
    print(chunk)

      Unnamed: 0  chr1:3012650-3012823  chr1:3012853-3013002  \
0   TGGAATTTTCTC                     0                     0   
1   CCAACAAACGCG                     0                     0   
2   TGCGCATAGCCG                     0                     0   
3   CTGTTTCCCACC                     0                     0   
4   ACAGTCTACATG                     0                     0   
5   CGCACTTGCGAG                     0                     0   
6   GCAACCTGAACA                     0                     0   
7   TAAAACCCACCA                     0                     0   
8   AAGTTGACCAAG                     0                     0   
9   CCGTGAGCTGCA                     0                     0   
10  CAACAGCTCTCA                     0                     0   
11  ATCGGGTACCAA                     0                     0   
12  CGTCTCGAGTGG                     0                     0   
13  GCGATAATAGCC                     0                     0   
14  GCAATTAAGGAA                     0  

      Unnamed: 0  chr1:3012650-3012823  chr1:3012853-3013002  \
20  GCACCGCTATAA                     0                     0   
21  GAACATTACGAT                     0                     0   
22  AGTATGGCTGTT                     0                     0   
23  TCTCCACTATAG                     0                     0   
24  CGGTGGAAGTTC                     0                     0   
25  CCTTTAATGAGT                     0                     0   
26  GACATCGTTGGG                     0                     0   
27  CTTCGCCAAACC                     0                     0   
28  GCGAGCGGATTA                     0                     0   
29  CGTGCGGTTATT                     0                     0   
30  CACTCAAGACCT                     0                     0   
31  CAGGCTTACGAT                     0                     0   
32  CGCTTTGAGTCC                     0                     0   
33  GGCCATCTATAT                     0                     0   
34  GAGTAGTAAGAT                     0  

KeyboardInterrupt: 

In [12]:
pd.read_csv('/gpfs/data/rsingh47/hzaki1/scMultiOmics_Datasets/AdBrain_cDNA.tsv', sep='\t', header=0, index_col=0)

Unnamed: 0,1.97314109930028,-1.47355385312258,-0.448217220046899,-0.618177270733583,-0.523731410949991,-0.531463725933904,0.65650066109567,0.356099698983707,-1.63603368148679,-0.967244579516991,...,0.874157154589865,-1.33208933413113,-0.3649699626321,0.956946199388234,1.60455042928438,0.458417873825953,-0.36724101310449,-0.0783060150434303,0.865330026382709,0.186055380721716
0,-0.093930,-0.126002,2.642168,1.139502,0.425026,0.543024,-5.038763,-1.464720,-0.137934,-0.715391,...,0.848818,-0.442782,0.166991,0.777194,1.886406,0.818283,0.862078,-0.023490,-0.001514,1.137818
1,-1.364619,9.447015,-2.628073,5.984243,-3.475412,-3.017343,1.073840,1.755634,3.508800,3.074618,...,-0.503803,0.693757,0.698400,-0.613128,1.493784,-0.684394,-0.560127,2.892226,0.795175,-0.063960
2,-1.724235,1.915687,-7.886971,-9.203425,-0.686722,5.294880,-3.968006,2.248889,0.558640,1.097301,...,-1.585393,4.095959,2.076408,4.458320,1.004148,1.699451,2.238309,0.740241,-1.603753,3.018961
3,2.192419,-1.217848,-1.030330,-0.462981,-1.138515,1.105841,3.901958,0.727287,-0.596557,-1.591444,...,-0.787046,1.154933,1.673235,-0.330316,-1.616407,-0.006023,-1.466589,0.241086,-0.650903,-0.448479
4,1.379982,-0.958338,1.849119,1.334524,0.431960,4.136202,-1.471386,0.486075,0.984332,0.857327,...,0.840397,0.004148,-0.307390,-0.123819,0.378542,-0.307433,0.221493,0.018832,-0.353097,0.493338
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10303,3.023066,-1.553193,3.430513,0.338881,0.144317,1.777274,1.615223,0.799734,1.777042,0.338854,...,-0.033092,-0.899616,0.317689,0.083957,-0.536709,0.087572,-0.349393,0.050419,0.103424,-0.441508
10304,-0.041560,-1.085671,-0.623921,-0.489419,0.021923,4.905372,0.966132,0.359771,-1.182128,1.435619,...,0.288949,-0.896097,0.006128,-1.333366,-0.969107,-0.223355,-1.320361,0.239509,-0.287748,-0.048524
10305,1.416114,-0.820919,-1.654158,-1.400860,-1.769234,-1.732952,2.866754,-0.173519,2.488341,0.864748,...,0.480407,-0.623204,-0.168816,-1.954851,0.290394,-0.437938,-1.209863,0.423491,0.057374,0.776191
10306,0.643357,2.743948,-8.461166,-12.643758,-1.364028,3.115762,-2.159244,3.468491,0.181685,4.016421,...,-1.976758,-0.297906,-2.393352,1.762918,1.239178,0.867720,-1.640475,2.265162,-0.910289,0.844010


In [13]:
pd.read_csv('/gpfs/data/rsingh47/hzaki1/scMultiOmics_Datasets/AdBrain_chromatin.tsv', sep='\t', header=0, index_col=0)

Unnamed: 0_level_0,-7.228593930630781150e-01,7.559886283838490595e-02,-1.309199109879449985e+00,3.009467457550289993e-01,1.415285809331649824e+00,-1.929257297112769720e-01,1.085243040896859856e+00,6.291242803336670741e-01,-7.337399238905991661e-01,-2.967570810367570200e-01,9.898743246594541301e-01,-1.093108313508130092e+00,-7.815797181639820845e-01,-1.199772056985540125e+00,1.851957779876129973e+00,9.322247819574491778e-01
-9.503142994098140450e-01,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
-0.915965,-0.296099,-0.127501,-0.608419,2.914415,0.982705,-0.648370,-0.050112,-0.080515,-0.611811,1.657346,-0.885185,-0.275118,-0.346100,-0.821616,-0.034660,0.147004
-0.227151,-0.681414,0.620084,-0.419685,-0.168625,1.696643,1.869309,-0.474545,-0.837446,-0.913827,-0.591524,-1.004256,-0.383973,-0.066880,2.035149,-0.935240,0.483381
-0.362644,3.017248,-0.476411,-0.231703,-0.315676,1.604603,-0.370425,-0.545108,-0.662196,-0.553078,-0.248308,-0.390161,-0.523569,-0.221171,-0.294042,-0.673771,1.246412
2.392906,-0.631587,0.216773,2.435128,0.045678,0.210309,-0.042039,-0.778954,-0.716767,0.078326,-0.381776,-0.600903,0.123984,-0.093016,-0.839177,-1.200270,-0.218617
-0.237181,-1.294659,-0.391159,0.507393,0.380755,1.040168,-0.758680,0.108862,1.279064,0.995772,-0.014815,-0.845140,-1.005080,2.192198,-1.090201,-1.246105,0.378808
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
-0.127454,-1.264101,0.099070,0.688536,-0.084694,1.100681,-1.128314,0.519592,1.104672,0.796068,0.721811,-1.122662,-0.681432,2.034477,-1.138154,-1.235292,-0.282803
-0.793059,-0.613557,-0.597805,-0.779863,-0.657746,1.492512,-0.819387,1.405637,1.688470,1.826531,-0.714651,-0.746858,-0.410698,0.794025,-0.788586,-0.188988,-0.095978
-0.524536,-1.124329,1.046190,1.338763,-0.316343,1.264220,-0.458982,-0.442760,0.727317,-0.361650,0.097441,-1.370566,0.111363,1.180930,-1.378162,-1.234706,1.445810
-0.641849,1.628786,-0.657217,-0.699313,-0.656994,2.595201,-0.154149,0.419300,-0.281475,-0.534643,-0.731757,0.058778,-0.342280,-0.760043,-0.079016,-0.742002,1.578672


In [14]:
pd.read_csv('/gpfs/data/rsingh47/hzaki1/scMultiOmics_Datasets/AdBrainCortex_metadata.tsv', sep='\t', header=0, index_col=0)

Unnamed: 0_level_0,Barcode,Ident
Batch,Unnamed: 1_level_1,Unnamed: 2_level_1
09A,CAGCCCCGCCTT,E3Rorb
09A,CGCCTACCATGA,E5Parm1
09A,GATGCGCGGCTA,Ast
09A,GGTCCGAGTCCT,E4Il1rapl2
09A,TCTCCCGGCACC,E5Parm1
...,...,...
09L,TACTAGTTCAAG,E3Rorb
09L,ATGACGGGCCCC,Mis
09L,GAAACACCTCAT,Mis
09L,AACGGTTTATCC,InP
