In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import dataloader
import utils
from model import generate_model
from train_wrapper import train_epoch

In [None]:
config = utils.load_config()

# DataLoader settings read from the project config.
batch_size = config['dataloader']['batch_size']
num_workers = config['dataloader']['num_workers']
# The flag is compared against 1 (presumably stored as 0/1 in the config file).
pin_memory = config['dataloader']['pin_memory'] == 1
# GPU ids later passed to nn.DataParallel as `device_ids`.
gpu_parallel = config['gpus']

# Seed torch's global RNG (model weight init, DataLoader shuffling) so that
# Restart & Run All reproduces the same run.
# NOTE(review): cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Dataset split (train, validation, test): OASIS-3 rows are divided into
# train/validation sets, while ADNI rows are held out entirely as the test set.
df_dataset = pd.read_csv(config['PATH_DATASET_CSV'])
# NOTE(review): dropna() removes a row if *any* column is missing — confirm intended.
df_dataset = df_dataset.dropna().reset_index(drop=True)
df_oasis = df_dataset[df_dataset['source'] == 'OASIS-3']
df_adni = df_dataset[df_dataset['source'] == 'ADNI']

# Fail fast with a clear message instead of an opaque error inside the split / DataLoader.
assert not df_oasis.empty, "no OASIS-3 rows left after dropna(); cannot build train/val sets"
assert not df_adni.empty, "no ADNI rows left after dropna(); cannot build the test set"

# NOTE(review): shuffle=False means the split presumably follows CSV row order;
# confirm this is deliberate (e.g. keeping subjects grouped) rather than accidental.
X_train, X_val, y_train, y_val = dataloader.dataset_split(df_oasis, test_size=0.2, shuffle=False)
# 'group_maxinc' is the label column: features drop it, targets keep only it.
X_test = df_adni.drop(labels='group_maxinc', axis=1)
y_test = df_adni['group_maxinc']

In [None]:
# Wrap each split in the project's MRIDataset.
traindata = dataloader.MRIDataset(X_train, y_train)
valdata = dataloader.MRIDataset(X_val, y_val)
testdata = dataloader.MRIDataset(X_test, y_test)

# Only the training loader shuffles; validation/test order stays fixed for evaluation.
train_dataloader = DataLoader(traindata, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
val_dataloader = DataLoader(valdata, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
# Test samples are loaded one at a time (batch_size=1). Worker/pinning settings
# match the other loaders so test-time loading is not needlessly single-process.
test_dataloader = DataLoader(testdata, batch_size=1, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

print('train_dataloader : ', len(train_dataloader.dataset))
print('val_dataloader : ', len(val_dataloader.dataset))
print('test_dataloader : ', len(test_dataloader.dataset))

In [None]:
# --- Device selection ---
# nn.DataParallel requires the wrapped module's parameters to live on device_ids[0],
# so the primary device must be the first configured GPU, not a hard-coded cuda:0.
# NOTE(review): assumes config['gpus'] is a list of integer CUDA indices — confirm.
if torch.cuda.is_available():
    primary_gpu = gpu_parallel[0] if len(gpu_parallel) > 0 else 0
    device = torch.device(f"cuda:{primary_gpu}")
else:
    device = torch.device("cpu")
print(device)

model_name = config['model']['model_name']
model_depth = config['model']['model_depth']

# Number of output classes of the classifier head
# (presumably the distinct values of 'group_maxinc' — confirm).
N_CLASSES = 3
LEARNING_RATE = 1e-4
ADAM_BETAS = (0.5, 0.999)

model, _ = generate_model(model_name=model_name, model_depth=model_depth,
                          n_classes=N_CLASSES, resnet_shortcut='B')
model = model.to(device)

if len(gpu_parallel) > 1 and torch.cuda.is_available():
    # Parameters already sit on device_ids[0] (== `device`), so no extra .to() is needed.
    model = nn.DataParallel(model, device_ids=gpu_parallel)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
criterion_clf = nn.CrossEntropyLoss().to(device)

In [None]:
# Launch training with the project's train_epoch helper. It receives both the
# training and validation loaders; `epoch=100` is presumably the number of epochs.
train_epoch(
    device,
    train_dataloader,
    val_dataloader,
    model,
    criterion_clf,
    optimizer,
    config,
    epoch=100,
)