In [1]:
import torch
import sys
import yaml
from torchvision import transforms, datasets
import torchvision
import numpy as np
import os
from sklearn import preprocessing
from torch.utils.data.dataloader import DataLoader

In [2]:
sys.path.append('../')
from models.pointnetv2_encoder import PointNetV2

In [3]:
# Mini-batch size for the training loader; the test loader is built with half of this value.
batch_size = 40

In [4]:
# Read the experiment configuration. The context manager closes the file handle
# as soon as parsing is done instead of leaving it open for the kernel's lifetime.
# NOTE(review): FullLoader accepts more than plain YAML scalars/collections; if the
# config only contains plain values, yaml.safe_load would be the stricter choice.
with open("../config/config.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [5]:
# from data_utils.ModelNetDataLoader import ModelNetDataLoader
# train_dataset = ModelNetDataLoader('/data/nerdxie/neuron_cluster/neuron_dataset_1',split='train', uniform=True, normal_channel=True,istrain=False,)
# test_dataset = ModelNetDataLoader('/data/nerdxie/neuron_cluster/neuron_dataset_1',split='test', uniform=True, normal_channel=True,istrain=False,)
from data_utils.TestDataLoader import FAFBDataset
# full_dataset = ModelNetDataLoader('./data_neuron.h5', uniform=True, normal_channel=True, istrain=False)
full_dataset = FAFBDataset('../../fafb-cellseg', uniform=True, normal_channel=True, catfile='classname_3c.txt')

# Human-readable class names, listed in class-index order (0, 1, 2).
labelname = ['neurite', 'glia', 'soma']

# 50/50 train/test split. The generator is seeded so that the same samples end up
# in the same partition after every kernel restart; without it the split (and
# therefore every reported accuracy) changes from run to run.
split_seed = 0
train_size = int(0.5 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

The size of all data is 1726


In [6]:
# Sanity check: show the shape of the point array of the first training sample.
first_points = train_dataset[0][0]
print("Input shape:", first_points.shape)

Input shape: (1024, 3)


In [7]:
# Training loader: reshuffled every epoch; the trailing incomplete batch is dropped
# so that every optimisation step sees exactly `batch_size` samples.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          num_workers=4, drop_last=True, shuffle=True)

# Evaluation loader: fixed order and no dropped samples, so the accuracy is computed
# over the whole test split and is comparable from one epoch to the next.
# NOTE(review): the last batch can now be smaller than the others (down to one
# sample); evaluation code must not squeeze away the batch dimension.
test_loader = DataLoader(test_dataset, batch_size=batch_size // 2,
                         num_workers=4, drop_last=False, shuffle=False)

In [8]:
# The notebook is pinned to the GPU; the automatic CPU fallback is deliberately disabled.
device = 'cuda' #'cuda' if torch.cuda.is_available() else 'cpu'

# Build the backbone from the `network` section of the config and move it to the device.
backbone = PointNetV2(**config['network'])
backbone = backbone.to(device)
# backbone.projetion = torch.nn.Sequential()

# Wrap in DataParallel; from here on the state-dict keys carry the wrapper's `module.` prefix.
encoder = torch.nn.DataParallel(backbone)
encoder = encoder.cuda()

In [9]:
# Load the pre-trained parameters.
# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
checkpoint_path = '../pre-trained/model_sslcontrast.pth'
load_params = torch.load(checkpoint_path, map_location=torch.device(device))

# for i in load_params['online_network_state_dict'].keys():
#     if 'projetion' in i:
#         del load_params['online_network_state_dict'][i]

if 'online_network_state_dict' in load_params:
    encoder.load_state_dict(load_params['online_network_state_dict'])
    print("Parameters successfully loaded.")
else:
    # Make the failure visible: without this branch the notebook would carry on
    # with an encoder whose weights were never loaded, and say nothing about it.
    print(f"WARNING: 'online_network_state_dict' not found in {checkpoint_path}; "
          "encoder weights were NOT loaded.")

# remove the projection head
# encoder = torch.nn.Sequential(*list(encoder.cpu().children())[:-1])    
# encoder = encoder.to(device)
# output_feature_dim = 1024

Parameters successfully loaded.


In [10]:
class LogisticRegression(torch.nn.Module):
    """A single fully connected layer that maps features to class scores.

    The forward pass returns the raw output of the linear layer; no softmax or
    other normalisation is applied inside this module.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        """Create the layer with ``input_dim`` inputs and ``output_dim`` outputs."""
        super().__init__()
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear layer's output for ``x``."""
        scores = self.linear(x)
        return scores

In [11]:
# Linear classification head: 1024 input features, 3 output classes, placed on `device`.
logreg = LogisticRegression(1024, 3).to(device)

In [12]:
def get_features_from_encoder(encoder, loader, device=None):
    """Run ``encoder`` over every batch of ``loader`` and collect features and labels.

    Gradients are disabled while extracting, so the autograd graph of each batch
    is not kept alive for the whole pass. The encoder's train/eval mode is left
    untouched; call ``encoder.eval()`` beforehand if layers such as dropout or
    batch norm should run in inference mode.

    Args:
        encoder: module whose output is indexed with ``[1]`` to obtain the
            per-sample feature tensor of a batch.
        loader: iterable yielding ``(batch, y, index)`` triples; ``index`` is ignored.
        device: device the input batches are moved to. When ``None`` it is taken
            from the encoder's first parameter (CPU if it has no parameters).

    Returns:
        ``(x_train, y_train)``: the features of all batches concatenated along
        the first dimension (detached, on the encoder's output device) and the
        labels concatenated in the same order on the CPU.

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    feature_chunks = []
    label_chunks = []

    # get the features from the pre-trained model
    with torch.no_grad():
        for batch, y, _ in loader:
            features = encoder(batch.to(device))[1]
            feature_chunks.append(features.detach())
            label_chunks.append(y.detach().cpu())

    if not feature_chunks:
        raise RuntimeError("loader yielded no batches; there are no features to collect")

    x_train = torch.cat(feature_chunks, dim=0)
    y_train = torch.cat(label_chunks, dim=0)
    return x_train, y_train

In [None]:
from itertools import chain

# Fine-tune the encoder and the linear head jointly: one Adam optimizer over both
# parameter sets.
optimizer = torch.optim.Adam(params=chain(logreg.parameters(), encoder.parameters()), lr=7e-5)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 200
eval_every_n_epochs = 1
save_dir = 'save_model'

# Class indices used for the per-class statistics (same order as `labelname`).
NEURITE, GLIA, SOMA = 0, 1, 2

# Metrics printed after every epoch; between evaluations they keep their last value.
best_acc = 0
best_epoch = 0
test_acc = 0
soma_acc = 0
neurite_acc = 0
glia_acc = 0
# No separate validation split is evaluated in this cell; these names are only
# initialised here in case other cells refer to them.
best_val = 0
val_acc = 0
true_name = []

torch.backends.cudnn.enabled = False
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(save_dir, exist_ok=True)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float('nan')
    return numerator / denominator


for epoch in range(num_epochs):
    print('Starting Training Epoch {}'.format(epoch))

    # ---- training pass -------------------------------------------------
    # Evaluation below switches both modules to eval mode, so switch back here.
    encoder.train()
    logreg.train()
    total = 0
    correct = 0
    for batch, y, _ in train_loader:
        batch = batch.to(device)
        y = y.to(device).squeeze(1).long()

        optimizer.zero_grad()
        logits = logreg(encoder(batch)[1])
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        predictions = torch.argmax(logits, dim=1)
        total += y.size(0)
        correct += (predictions == y).sum().item()
    train_acc = _safe_ratio(100 * correct, total)
    print(f"Training accuracy: {train_acc}")
#     if epoch == 9:
#         optimizer = torch.optim.Adam(params=chain(logreg.parameters(), encoder.parameters()), lr=3e-5)

    # ---- evaluation pass -----------------------------------------------
    if epoch % eval_every_n_epochs == 0:
        encoder.eval()
        logreg.eval()
        correct = 0
        total = 0
        correct_per_class = [0, 0, 0]    # correct predictions, indexed by class
        predicted_per_class = [0, 0, 0]  # number of predictions made for each class

        # No gradients are needed for evaluation; this avoids building and
        # holding an autograd graph for every test batch.
        with torch.no_grad():
            for x, y, _ in test_loader:
                x = x.to(device)
                # Flatten the labels to one entry per sample. Unlike squeeze(0),
                # this also works when a batch holds a single sample.
                y = y.to(device).long().reshape(-1)

                logits = logreg(encoder(x)[1])
                predictions = torch.argmax(logits, dim=1)
                hits = predictions == y

                total += y.size(0)
                correct += hits.sum().item()
                for cls in (NEURITE, GLIA, SOMA):
                    predicted_per_class[cls] += (predictions == cls).sum().item()
                    correct_per_class[cls] += (hits & (y == cls)).sum().item()

        if total == 0:
            raise RuntimeError('test_loader yielded no batches; cannot compute the test accuracy')

        test_acc = 100 * correct / total
        # Ties count as "best" too, so the checkpoint of the latest best epoch is kept.
        if test_acc >= best_acc:
            best_acc = test_acc
            best_epoch = epoch
            torch.save(logreg.state_dict(), os.path.join(save_dir, 'linear.pt'))
            torch.save(encoder.state_dict(), os.path.join(save_dir, 'encoder.pt'))
            print(f'saving best finetuned model at epoch: {epoch}')

        # Each figure is (correct predictions of a class) / (predictions made for
        # that class); NaN when the class was never predicted.
        # NOTE(review): the denominator counts predictions, not ground-truth
        # samples, so these are per-class precision values - confirm that is intended.
        soma_acc = _safe_ratio(correct_per_class[SOMA], predicted_per_class[SOMA])
        glia_acc = _safe_ratio(correct_per_class[GLIA], predicted_per_class[GLIA])
        neurite_acc = _safe_ratio(correct_per_class[NEURITE], predicted_per_class[NEURITE])

    print(f"Epoch {epoch}: Current Test Accuracy: {test_acc} (best {best_acc} at epoch {best_epoch}). Soma {soma_acc} Neurite {neurite_acc} Glia {glia_acc}")