In [1]:
import pickle

fname = '../../data/fmri-FC-slim.pkl'
fmriDict = None

with open(fname, 'rb') as f:
    fmriDict = pickle.load(f)
    
print(fmriDict.keys())

dict_keys(['FC-slim', 'subjNum2IdxMap', 'subjIdx2NumMap', 'groupsNormalDiagMap'])


In [2]:
def getGroupIdcs(groupsMap):
    normals = []
    fibros = []
    for num,diag in groupsMap.items():
        if diag == 1:
            normals.append(num)
        else:
            fibros.append(num)
    return normals,fibros

normals, fibros = getGroupIdcs(fmriDict['groupsNormalDiagMap'])

print(len(normals))
print(normals)
print(len(fibros))
print(fibros)

33
['007', '012', '014', '016', '018', '021', '022', '026', '030', '031', '032', '033', '034', '036', '042', '045', '047', '056', '058', '059', '060', '061', '064', '066', '068', '069', '070', '072', '073', '074', '075', '076', '077']
33
['002', '004', '005', '006', '008', '009', '010', '011', '013', '015', '017', '019', '020', '023', '024', '025', '028', '029', '037', '038', '039', '040', '043', '044', '046', '049', '050', '052', '053', '054', '055', '062', '063']


In [47]:
nRuns = 40
L1 = 0
L2 = 2e-4
hidden = 40
lr = 2e-4
epochs = 6000
pPeriod = 2000

model = f'MLP L2={L2} L1={L1} hidden={hidden} lr={lr} epochs={epochs}'
desc = 'normal(1) vs. fibromyalgia(0) rest fMRI only'

import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, X, hidden=20):
        super(MLP, self).__init__()
        self.A = nn.Linear(X.shape[-1],hidden).float().cuda()
        self.B = nn.Linear(hidden,2).float().cuda()
        
    def forward(self, X):
        y = F.relu(self.A(X))
        y = self.B(y)
        return y
    
def sparsityLoss(A, L1):
    loss = torch.sum(torch.abs(A.weight)) + torch.sum(torch.abs(A.bias))
    return L1*loss

def validate(mlp, Xtest, ytest):
    acc = None
    mlp.eval()
    with torch.no_grad():
        yhat = mlp(Xtest)
        acc = torch.sum(torch.argmax(yhat, dim=1) == torch.argmax(ytest, dim=1))/yhat.shape[0]
    mlp.train()
    return float(acc)

ceLoss = torch.nn.CrossEntropyLoss()

import sys

sys.path.append('../../src')

from imagenomer import Analysis, JsonData, JsonSubjects, JsonFCMetadata

analysis = Analysis(f'{desc}: {model}', 'localhost')

a,b = np.triu_indices(264,1)
idcs = np.arange(34716)

labels = [f'{a[i]}-{b[i]}' for i in idcs]

print(labels[0:10])
print(labels[-10:])

for run in range(nRuns):
    FCslim = fmriDict['FC-slim']
    subjNum2Idx = fmriDict['subjNum2IdxMap']

    random.shuffle(normals)
    random.shuffle(fibros)

    normalTrain = normals[:25]
    normalTest = normals[25:]

    fibrosTrain = fibros[:25]
    fibrosTest = fibros[25:]

    train = normalTrain + fibrosTrain
    test = normalTest + fibrosTest

    Xtrain = [FCslim[subjNum2Idx[num]] for num in train]
    Xtest = [FCslim[subjNum2Idx[num]] for num in test]

    Xtrain = np.stack(Xtrain)
    Xtest = np.stack(Xtest)

    ytrain = np.concatenate([np.ones(25), np.zeros(25)])
    ytest = np.concatenate([np.ones(8), np.zeros(8)])
    
    Xtrain_t = torch.from_numpy(Xtrain).float().cuda()
    Xtest_t = torch.from_numpy(Xtest).float().cuda()
    
    ytrain_t = np.zeros((ytrain.shape[0],2))
    ytest_t = np.zeros((ytest.shape[0],2))
    
    ytrain_t[:,0] = ytrain
    ytrain_t[:,1] = 1-ytrain
    ytest_t[:,0] = ytest
    ytest_t[:,1] = 1-ytest
     
    ytrain_t = torch.from_numpy(ytrain_t).float().cuda()
    ytest_t = torch.from_numpy(ytest_t).float().cuda()

#     print(Xtrain.shape)
#     print(Xtest.shape)
#     print(ytrain.shape)
#     print(ytest.shape)

    mlp = MLP(Xtrain_t, hidden=hidden)
    optim = torch.optim.Adam(mlp.parameters(), lr=1e-4, weight_decay=L2)
    
    for epoch in range(epochs):
        optim.zero_grad()
        yhat = mlp(Xtrain_t)
        loss = ceLoss(yhat, ytrain_t)
        sloss = sparsityLoss(mlp.A, L1)
        (loss+sloss).backward()
        optim.step()
        if epoch % pPeriod == 0 or epoch == epochs-1:
            print(f'epoch {epoch} loss {float(loss)} sloss {float(sloss)}')
    
    acc = validate(mlp, Xtest_t, ytest_t)
    
    print(f'{run}. {acc}')
    
    w = None
    with torch.no_grad():
        w = torch.sum(mlp.A.weight, axis=0).detach().cpu().numpy()
        w = np.mean(np.expand_dims(w,0)*Xtest, axis=0)
    
    jsonCompare = desc
    jsonAccuracy = acc
    jsonTrain = [25,25]
    jsonTest = [8,8]
    jsonWeights = w.astype('float64')
    jsonLabels = labels

    jsonObj = {
        'Compare': jsonCompare,
        'Model': 'MLP',
        'Accuracy': jsonAccuracy,
        'Train': jsonTrain,
        'Test': jsonTest,
        'Weights': list(jsonWeights),
        'Labels': jsonLabels
    }
    
    dat = JsonData(analysis)
    dat.dict.update(jsonObj)
    r = dat.post()
    print(r.content)


['0-1', '0-2', '0-3', '0-4', '0-5', '0-6', '0-7', '0-8', '0-9', '0-10']
['259-260', '259-261', '259-262', '259-263', '260-261', '260-262', '260-263', '261-262', '261-263', '262-263']
epoch 0 loss 0.701495885848999 sloss 0.0
epoch 2000 loss 0.0001963208196684718 sloss 0.0
epoch 4000 loss 0.00011640261800494045 sloss 0.0
epoch 5999 loss 6.303082773229107e-05 sloss 0.0
0. 0.6875
b'Success'
epoch 0 loss 0.6892275214195251 sloss 0.0
epoch 2000 loss 0.00011313882714603096 sloss 0.0
epoch 4000 loss 6.839713751105592e-05 sloss 0.0
epoch 5999 loss 4.738671486848034e-05 sloss 0.0
1. 0.5625
b'Success'
epoch 0 loss 0.6910544633865356 sloss 0.0
epoch 2000 loss 0.00012552013504318893 sloss 0.0
epoch 4000 loss 5.3430310799740255e-05 sloss 0.0
epoch 5999 loss 4.000569606432691e-05 sloss 0.0
2. 0.25
b'Success'
epoch 0 loss 0.6925092935562134 sloss 0.0
epoch 2000 loss 0.00013574917102232575 sloss 0.0
epoch 4000 loss 7.804497727192938e-05 sloss 0.0
epoch 5999 loss 5.343499651644379e-05 sloss 0.0
3. 0.625

epoch 5999 loss 5.473679630085826e-05 sloss 0.0
38. 0.4375
b'Success'
epoch 0 loss 0.6981229782104492 sloss 0.0
epoch 2000 loss 0.00012298408546485007 sloss 0.0
epoch 4000 loss 6.432521331589669e-05 sloss 0.0
epoch 5999 loss 4.71387647849042e-05 sloss 0.0
39. 0.5
b'Success'


In [48]:
commNames = []
commAffil = {}
commAbrev = 'SMH,SMM,CNG,AUD,DMN,MEM,VIS,FRT,SAL,SUB,VTR,DRL,CB,UNK'.split(',')
commCount = np.zeros(14)

powerAffilFname = '../../power/power264CommunityAffiliation.1D'

with open(powerAffilFname, 'r') as f:
    for i,line in enumerate(f.readlines()):
        commAffil[i] = int(line)-1
        commCount[commAffil[i]] += 1
        
print(commAffil)
print(commCount.astype('int'))

{0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 13, 6: 13, 7: 13, 8: 13, 9: 13, 10: 13, 11: 13, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0, 36: 0, 37: 0, 38: 0, 39: 0, 40: 0, 41: 1, 42: 1, 43: 1, 44: 1, 45: 1, 46: 2, 47: 2, 48: 2, 49: 2, 50: 2, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 3, 61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 3, 72: 3, 73: 4, 74: 4, 75: 4, 76: 4, 77: 4, 78: 4, 79: 4, 80: 4, 81: 4, 82: 4, 83: 13, 84: 13, 85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 4, 92: 4, 93: 4, 94: 4, 95: 4, 96: 4, 97: 4, 98: 4, 99: 4, 100: 4, 101: 4, 102: 4, 103: 4, 104: 4, 105: 4, 106: 4, 107: 4, 108: 4, 109: 4, 110: 4, 111: 4, 112: 4, 113: 4, 114: 4, 115: 4, 116: 4, 117: 4, 118: 4, 119: 4, 120: 4, 121: 4, 122: 4, 123: 4, 124: 4, 125: 4, 126: 4, 127: 4, 128: 4, 129: 4, 130: 4, 131: 13, 132: 5, 133: 5, 134: 5, 135: 5, 136: 4

In [49]:
jsonMeta = JsonFCMetadata(analysis)
metaDict = {
    'CommunityMap': commAffil,
    'CommunityNames': commAbrev
}
jsonMeta.update(metaDict)
r = jsonMeta.post()
print(r.content)

b'Success'
