In [1]:
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
from torch.autograd import Variable
import torch
import torch.nn.functional as F

In [2]:
ica_tc_file = './dualreg_mat.npy'
pheno_file = './pheno.csv'

ica_tc = np.load(ica_tc_file)
pheno = pd.read_csv(pheno_file)

ic19 = [1,3, 5, 8, 9, 10, 13, 15, 16, 17, 19, 21, 23, 24, 25, 27, 29, 30, 33]

ic7 = [5, 8, 13, 16, 17, 24, 29]

In [3]:
print(ica_tc.shape)
print(pheno.shape)

(359, 175, 52)
(359, 19)


In [4]:
tc = np.random.randn(5000, 100)
tc2 = tc*0.8+0.3*np.random.randn(5000, 100)

corrtc = np.zeros((5000,100,2))
corrtc[:,:,0] = tc
corrtc[:,:,1] = tc2

uncorrtc = np.random.randn(5000, 100,2)

data = np.vstack([corrtc, uncorrtc])
label = np.hstack([np.zeros(5000), np.ones(5000)])

idx = np.random.permutation(range(10000))

data_shuffle = data[idx]
label_shuffle = label[idx].astype('int64')

In [5]:
class Abide(DataLoader):
    
    def __init__(self, data_shuffle, label_shuffle, split='train'): #input_path, label, split='train'):
        super(Abide).__init__()
        
        ica_tc = data_shuffle #np.load(input_path)[:,:,ic7]
        dx = label_shuffle # pd.read_csv(label)['DX_GROUP']-1 # in order to make label 0,1 instead of 1,2
        dx = np.eye(2)[dx] #.values] # 1-hot encode 

        split_at = 8000
        if split == 'train':
            self.ica_tc = ica_tc[:split_at]
            self.dx = dx[:split_at]
            
        if split == 'test':
            self.ica_tc = ica_tc[split_at:]
            self.dx = dx[split_at:] 
            
        
    def __len__(self):   
        return len(self.dx)
        
    def __getitem__(self, index):
        return self.ica_tc[index].transpose(), self.dx[index]
        

In [6]:
class Abide1DConvNet(nn.Module):
    def __init__(self, numICs=2):
        super(Abide1DConvNet, self).__init__()
        
        self.conv1 = nn.Conv1d(numICs, 16, 7)
        self.conv2 = nn.Conv1d(16, 8, 5)
        self.conv3 = nn.Conv1d(8, 16, 5)
        self.avg = nn.AdaptiveAvgPool1d((1))
        
        self.linear1 = nn.Linear(16, 2)
        #self.linear2 = nn.Linear(100, 2)
        
    def forward(self, x):
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.avg(x).view(-1, 16)
        x = F.relu(self.linear1(x))
        #x = self.linear2(x)
        
        return x
    

In [7]:
# train_data = Abide(input_path='./dualreg_mat.npy', label='./pheno.csv', split='train')
# test_data = Abide(input_path='./dualreg_mat.npy', label='./pheno.csv', split='test')
train_data = Abide(data_shuffle, label_shuffle, split='train')
test_data = Abide(data_shuffle, label_shuffle, split='test')

In [8]:
train_data_loader = DataLoader(train_data, batch_size=256, shuffle=True)
test_data_loader = DataLoader(test_data, batch_size=1, shuffle=True)

In [9]:
net = Abide1DConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=.01, weight_decay=0.1)
nepochs=100

net.train()
track_loss=[]

for i_epoch in range(nepochs):
    print(i_epoch)
    epoch_loss = 0.0
    for i, (ic, dx) in enumerate(train_data_loader):
        
        ic = Variable(ic).type(torch.FloatTensor)
        dx = Variable(dx).type(torch.LongTensor)
        
        # forward pass
        output = net(ic)
        
        # calculate loss
        loss = criterion(output, torch.max(dx,1)[1])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss
        
    track_loss.append(epoch_loss/i)
    print('epoch loss = {}'.format(epoch_loss))    

0
epoch loss = 22.18522071838379
1
epoch loss = 22.18089485168457
2
epoch loss = 22.180675506591797
3
epoch loss = 22.180944442749023
4


KeyboardInterrupt: 

In [162]:
net.eval()
acc = []
for ic,dx in train_data_loader:
    
    ic = Variable(ic).type(torch.FloatTensor)
    dx = Variable(dx).type(torch.LongTensor)
    
    output = net(ic)
    acc.append(torch.argmax(output) == torch.argmax(dx))
    
print('Total accuracy = {}'.format(sum(acc).item()/len(train_data_loader)))

Total accuracy = 0.5862068965517241


In [138]:
sum(acc).item()

12