In [1]:
!git clone https://github.com/atolstikov/aidao24
!pip install nilearn

fatal: destination path 'aidao24' already exists and is not an empty directory.


In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from nilearn.connectome import ConnectivityMeasure
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
class MedicalDataset(Dataset):

    def __init__(self,data,labels,transform):
        super().__init__()
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self,idx):
        data = self.transform(self.data)
        data = data.permute(1,0,2)
        data = data[:,None,:,:]
        return data[idx], self.labels[idx]

In [5]:
bnu_csv = pd.read_csv("/content/aidao24/qualifiers_track_2/data/ts_cut/HCPex/bnu.csv").iloc[:,0]
bnu1 = np.load("/content/aidao24/qualifiers_track_2/data/ts_cut/HCPex/bnu1.npy")
bnu2 = np.load("/content/aidao24/qualifiers_track_2/data/ts_cut/HCPex/bnu2.npy")
ihb_csv = pd.read_csv("/content/aidao24/qualifiers_track_2/data/ts_cut/HCPex/ihb.csv").iloc[:,0]
ihb = np.load("/content/aidao24/qualifiers_track_2/data/ts_cut/HCPex/ihb.npy")

In [6]:
bnu = np.concatenate((bnu1,bnu2))

In [7]:
measure = ConnectivityMeasure(kind = 'correlation')
bnu_cm = measure.fit_transform(bnu)
ihb_cm = measure.fit_transform(ihb)

In [8]:
data = np.concatenate((bnu_cm,ihb_cm))
labels = np.concatenate((bnu_csv,ihb_csv))

In [9]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5])
])

In [10]:
dataset = MedicalDataset(data,labels,transform)

In [11]:
loader = DataLoader(dataset, 5, True)

In [12]:
class ConvClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.forw = nn.Sequential(
            #1x419x419
            nn.Conv2d(1,8,30,5,0),
            nn.BatchNorm2d(8),
            nn.SiLU(),
            #8x78x78
            nn.Conv2d(8,32,4,2,0),
            nn.MaxPool2d(3,1,0),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            #32x36x36
            nn.Conv2d(32,128,5,2,0),
            nn.MaxPool2d(2,1,0),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            #128x15x15
            nn.Conv2d(128,512,4,2,0),
            nn.BatchNorm2d(512),
            nn.SiLU(),
            #512x6x6
            nn.Conv2d(512,512,6,2,0),
            nn.Flatten(),
            nn.Linear(512,128),
            nn.SiLU(),
            nn.Linear(128,1),
            nn.Sigmoid(),
        )

    def forward(self,x):
        return self.forw(x)

In [15]:
model = ConvClassifier().to(DEVICE)
opt = optim.Adam(model.parameters(),lr = 3e-5)
crit = nn.BCELoss()

In [16]:
for epoch in range(5):
    with tqdm(loader) as TQDM:
        wrong = 0
        for idx, (data,labels) in enumerate(TQDM):
            if idx<30:
                model.train()
                data = data.to(DEVICE).type(torch.float32)
                labels = labels.to(DEVICE).type(torch.float32)
                predicted_labels = model(data).flatten()
                loss = crit(predicted_labels, labels)

                opt.zero_grad()
                loss.backward()
                opt.step()

                TQDM.set_postfix({"loss":loss.item()})

            else:
                model.eval()
                with torch.no_grad():
                    data = data.to(DEVICE).type(torch.float32)
                    labels = labels.to(DEVICE).type(torch.float32)
                    predicted_labels = torch.round(model(data).flatten())
                    wrong += torch.sum(torch.abs(predicted_labels - labels))
        print(1 - wrong/12)

100%|██████████| 33/33 [01:14<00:00,  2.26s/it, loss=0.381]


tensor(0.5000, device='cuda:0')


100%|██████████| 33/33 [01:14<00:00,  2.25s/it, loss=0.472]


tensor(0.9167, device='cuda:0')


100%|██████████| 33/33 [01:15<00:00,  2.30s/it, loss=0.0537]


tensor(0.9167, device='cuda:0')


100%|██████████| 33/33 [01:14<00:00,  2.24s/it, loss=0.0106]


tensor(1., device='cuda:0')


100%|██████████| 33/33 [01:15<00:00,  2.28s/it, loss=0.00573]

tensor(1., device='cuda:0')



