In [1]:
import os
import time
from datetime import datetime
import pickle
from collections import defaultdict
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import networkx as nx

from diff_pool6_max import DiffPool

import torch
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam

from torch_geometric.data import Data, DataLoader, DenseDataLoader as DenseLoader
from torch_geometric.data import InMemoryDataset
import torch_geometric.transforms as T


# Create dataset 

In [2]:
# Load the training and validation pickles. Each is used below as a mapping
# whose values are concatenated per key.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# on trusted, locally produced data.
with open(r'./data/patient_gumbel_train.pickle', 'rb') as train_file:
    patient_dict_train = pickle.load(train_file)
with open(r'./data/patient_gumbel_val.pickle', 'rb') as val_file:
    patient_dict_val = pickle.load(val_file)

# Merge both splits: for every key, the training entries come first and the
# validation entries are appended after them.
patient_dict = defaultdict(list)
for split in (patient_dict_train, patient_dict_val):
    for patient_key in split:
        patient_dict[patient_key].extend(split[patient_key])
        

In [3]:
class PatientDataset(InMemoryDataset):
    """In-memory graph dataset built from the module-level ``patient_dict``.

    Every graph found in ``patient_dict`` becomes one ``Data`` object:

    * ``x``  -- one-hot encoding of each node's ``cell_types`` attribute,
    * ``edge_index`` -- every edge of the graph in both directions,
    * ``y``  -- the position of the graph's key within ``patient_dict``.

    The collated result is cached in ``<root>/processed/patient.dataset`` by
    the ``InMemoryDataset`` machinery and loaded back in ``__init__``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(PatientDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # No raw files: the graphs come from ``patient_dict`` in memory.
        return []

    @property
    def processed_file_names(self):
        return ['patient.dataset']

    def download(self):
        # Nothing to download (see ``raw_file_names``).
        pass

    def process(self):
        """Convert every graph in ``patient_dict`` and save the collated data.

        Raises:
            KeyError: if a node has no ``cell_types`` attribute or its value is
                not one of the known cell types.
        """
        node_labels_dict = {'Tumor': 0, 'Stroma': 1, 'TIL1': 2, 'TIL2': 3, 'NK': 4, 'MP': 5}
        class_num = len(node_labels_dict)

        data_list = []
        for idx, graphs in enumerate(patient_dict.values()):
            for G in graphs:
                # Row i of ``x`` describes the i-th node in G's iteration
                # order, so edges must be expressed in those positions too.
                # Using the raw node labels as indices is only correct when
                # the labels happen to be 0..n-1 in iteration order.
                node_pos = {node: pos for pos, node in enumerate(G.nodes)}

                # Look the attribute up per node so that a node without
                # 'cell_types' fails loudly instead of shifting later rows.
                cell_types = nx.get_node_attributes(G, 'cell_types')
                node_features = torch.LongTensor(
                    [node_labels_dict[cell_types[node]] for node in G.nodes]).unsqueeze(1)
                x = torch.zeros(len(node_pos), class_num).scatter_(1, node_features, 1)

                # All graphs stored under the same key share the label ``idx``.
                y = torch.LongTensor([idx])

                # Store each edge in both directions, sorted by (source, target).
                forward = [(node_pos[u], node_pos[v]) for u, v in G.edges]
                edges = sorted(forward + [(v, u) for u, v in forward])
                edge_index = torch.tensor([[src for src, _ in edges],
                                           [dst for _, dst in edges]], dtype=torch.long)

                data_list.append(Data(x=x, edge_index=edge_index, y=y))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [4]:
def get_dataset(path, sparse=False):
    """Build the ``PatientDataset`` rooted at ``path``.

    Args:
        path: root directory handed to ``PatientDataset``.
        sparse: when False (default), a ``T.ToDense`` transform sized to the
            largest graph in the dataset is appended to the dataset's
            transform; when True the dataset is returned untouched.

    Returns:
        The (possibly densified) ``PatientDataset``.
    """
    dataset = PatientDataset(path)
    if sparse:
        return dataset

    # Dense conversion pads every graph to the node count of the largest one.
    max_num_nodes = 0
    for graph in dataset:
        if graph.num_nodes > max_num_nodes:
            max_num_nodes = graph.num_nodes

    existing = dataset.transform
    if existing is not None:
        dataset.transform = T.Compose([existing, T.ToDense(max_num_nodes)])
    else:
        dataset.transform = T.ToDense(max_num_nodes)

    return dataset


In [5]:
# Registry of prepared datasets, keyed by a short dataset name.
dataset_dict = {} 
# This dataset includes both training and validation data
path = './data/patient_gumbel_val'
# sparse=False makes get_dataset attach a ToDense transform padded to the
# largest graph, so samples carry a dense adjacency matrix.
dataset_dict['gumbel2_5'] = get_dataset(path, sparse=False)


Processing...
Done!



## Train

In [6]:
# Device used by run(), train() and eval_loss_acc().
# The device string must be exactly 'cuda:0' -- a space after the colon
# ('cuda: 0') is rejected as an invalid device string by current PyTorch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Layout of the dataset as assumed by run(): for each of `num_patients`
# groups there are 2 * num_patches consecutive samples, the first
# `num_patches` of which are used for training and the rest for testing.
num_patients = 10
num_patches = 5

def run(dataset, model, epochs, batch_size, lr, lr_decay_factor, lr_decay_step_size,
        weight_decay, logger=None, resume=None):
    """Train ``model`` on a fixed split of ``dataset`` and collect log lines.

    The split is derived from the module-level ``num_patients`` and
    ``num_patches``: within every run of ``2 * num_patches`` consecutive
    samples, the first ``num_patches`` go to training and the remaining ones
    to testing. Checkpoints and parameter snapshots are written below the
    module-level ``dir_path``.

    Args:
        dataset: indexable dataset; dense loaders are used when its samples
            carry an ``adj`` entry.
        model: network to train; must provide ``reset_parameters()``.
        epochs: number of epochs to run in this call.
        batch_size: loader batch size.
        lr: initial learning rate for Adam.
        lr_decay_factor: factor the learning rate is multiplied by on decay.
        lr_decay_step_size: decay is applied whenever
            ``epoch % lr_decay_step_size == 0``.
        weight_decay: Adam weight decay.
        logger: optional callable turning the per-epoch info dict
            (``epoch``, ``train_loss``, ``train_acc``) into a log line.
        resume: when truthy, continue from ``checkpoint_last.pt`` instead of
            re-initialising the model.

    Returns:
        List of log lines (per-epoch lines, test metrics, total duration).
    """
    lines = []

    # Fixed split, see docstring. Loaders must not shuffle: train() and
    # eval_loss_acc() read one label per `num_patches` consecutive samples.
    train_indices = [2 * num_patches * i + j
                     for i in range(num_patients)
                     for j in range(num_patches)]
    test_indices = sorted(set(range(num_patients * num_patches * 2)) - set(train_indices))
    train_dataset = dataset[torch.tensor(train_indices)]
    test_dataset = dataset[torch.tensor(test_indices)]

    # The dense loader only works with dense adjacency matrices.
    loader_cls = DenseLoader if 'adj' in dataset[0] else DataLoader
    train_loader = loader_cls(train_dataset, batch_size, shuffle=False)
    test_loader = loader_cls(test_dataset, batch_size, shuffle=False)

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if resume:
        # Load onto the CPU first so a checkpoint written on a GPU machine can
        # be resumed anywhere; model.to(device) moves the weights afterwards
        # and the optimizer state follows its parameters on load.
        last_checkpoint = torch.load(dir_path + 'checkpoint_last.pt', map_location='cpu')
        model.load_state_dict(last_checkpoint['state_dict'])
        model.to(device)
        optimizer.load_state_dict(last_checkpoint['optimizer'])
        start_epoch = last_checkpoint['epoch'] + 1
    else:
        model.to(device).reset_parameters()
        start_epoch = 1
        # Snapshot the freshly initialised parameters. Only done for a fresh
        # run: when resuming, this would overwrite the real epoch-0 snapshot
        # with already-trained weights.
        torch.save(model.state_dict(), dir_path + 'params_epoch{}.pt'.format(0))

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    t_start = time.perf_counter()

    # Keeps `epoch` defined for the checkpoint below even when epochs == 0.
    epoch = start_epoch - 1
    for epoch in tqdm(range(start_epoch, start_epoch + epochs)):
        train_loss, train_acc = train(model, optimizer, train_loader)

        eval_info = {
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
        }

        if logger is not None:
            lines.append(logger(eval_info))

        if epoch % lr_decay_step_size == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_decay_factor * param_group['lr']

        # Evaluate and snapshot after the first epoch and then every 5 epochs.
        if epoch % 5 == 0 or epoch == 1:
            test_loss, test_acc = eval_loss_acc(model, test_loader)
            lines.append('Test Loss: {:.4f}, Test Accuracy: {:.3f}'.format(test_loss, test_acc))
            torch.save(model.state_dict(), dir_path + 'params_epoch{}.pt'.format(epoch))

    checkpoint = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, dir_path + 'checkpoint_last.pt')

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t_end = time.perf_counter()
    duration = t_end - t_start
    lines.append('Duration: {:.3f}'.format(duration))

    return lines
      

# def shuffle(dataset, m=20):
#     indices = []
#     for i in range(10):
#         tmp = [j for j in range(i*m, i*m+m)]
#         random.shuffle(tmp)
#         indices += tmp
#     return indices

def train(model, optimizer, loader):
    """Run one optimisation pass over ``loader``.

    Labels are read once per ``num_patches`` consecutive samples of a batch,
    so the model is presumably expected to emit one prediction row per such
    group (verify against the model). Loss and accuracy are averaged over the
    number of groups, ``len(loader.dataset) / num_patches``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` for the epoch.
    """
    model.train()

    loss_sum = 0
    num_correct = 0

    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        output = model(batch)

        # One target per group of `num_patches` consecutive samples.
        group_starts = list(range(0, len(batch.y), num_patches))
        targets = batch.y[group_starts].view(-1)

        batch_loss = F.nll_loss(output, targets, reduction='sum')
        batch_loss.backward()
        loss_sum += batch_loss.item()

        predicted = output.max(1)[1]
        num_correct += predicted.eq(targets).sum().item()
        optimizer.step()

    num_groups = len(loader.dataset) / num_patches
    return loss_sum / num_groups, num_correct / num_groups

def eval_loss_acc(model, loader):
    """Evaluate ``model`` on ``loader`` without updating it.

    Labels are read once per ``num_patches`` consecutive samples of a batch,
    mirroring ``train``. Loss and accuracy are averaged over the number of
    such groups, ``len(loader.dataset) / num_patches``.

    Returns:
        Tuple ``(mean_loss, accuracy)``.
    """
    model.eval()

    correct = 0
    loss = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
        pred = out.max(1)[1]

        # One target per group of `num_patches` consecutive samples.
        indices = list(range(0, len(data.y), num_patches))
        target = data.y[indices].view(-1)

        correct += pred.eq(target).sum().item()
        loss += F.nll_loss(out, target, reduction='sum').item()

    num_groups = len(loader.dataset) / num_patches
    return loss / num_groups, correct / num_groups


# Main

In [7]:
# Default hyperparameters for this run. If a hyperparams.pickle already exists
# in the output directory, the values below are replaced by the saved ones.
# num_layers and hidden are passed positionally to the model constructor.
num_layers = 5
hidden = 64
# Passed to the model as `hop`.
num_hops = 2
# NOTE(review): labels are read every `num_patches` samples within a batch, so
# batch_size presumably has to be a multiple of num_patches -- confirm.
batch_size = 50
# Stored as 'pooling ratio' in the saved hyperparameters.
ratio = 0.05
dropout = False
# Model class to instantiate.
Net =  DiffPool 

def logger(info):
    """Format one epoch's training metrics as a single log line.

    Args:
        info: dict with the keys ``'epoch'`` (int), ``'train_loss'`` and
            ``'train_acc'`` (numbers).

    Returns:
        A string such as ``'003: Train Loss: 0.5000, Train Accuracy: 0.250'``.
    """
    template = '{:03d}: Train Loss: {:.4f}, Train Accuracy: {:.3f}'
    return template.format(info['epoch'], info['train_loss'], info['train_acc'])

# Output directory for the checkpoints, hyperparameters and logs of this
# configuration. run() also reads this module-level name.
dir_path = './data'+ '/' + 'DiffPool_diff_pool6_max_bs50'+ '/' + 'gumbel2_5' + '/'
hyperparams_name = 'hyperparams.pickle'
hyperparams_path = dir_path + hyperparams_name

if os.path.exists(hyperparams_path):
    # The hyperparameter file is only written after a completed training run,
    # so its presence means this configuration was already trained: restore
    # the saved hyperparameters instead of training again.
    with open(hyperparams_path, 'rb') as handle:
        hyperparams = pickle.load(handle)
    num_layers = hyperparams['# of layers']
    hidden = hyperparams['# of hidden units']
    num_hops = hyperparams['# of hops']
    batch_size = hyperparams['batch size']
    ratio = hyperparams['pooling ratio']
    dropout = hyperparams['dropout']
    num_patches = hyperparams['# of patches']

else:
    # The backslash is written to the log literally; it is escaped explicitly
    # because '\{' is not a valid escape sequence.
    lines = ['-----\\{}'.format(Net.__name__)]
    os.makedirs(dir_path, exist_ok=True)
    lines.append('Num of Layers: {}, Num of Hidden Units: {}, Num of Hops: {}, Batch Size: {}, ' \
                 'Pooling Ratio: {}, Dropout: {}, Num of Patches: {}' \
                 .format(num_layers, hidden, num_hops, batch_size, ratio, dropout, num_patches))
    dataset = dataset_dict['gumbel2_5']
    # For diff_pool6
    model = Net(dataset, num_layers, hidden, hop=num_hops, num_patches=num_patches, ratio=ratio, dropout=dropout)

    process_lines = \
    run(
        dataset,
        model,
        epochs=500,
        batch_size=batch_size,
        lr=0.01,
        lr_decay_factor=0.5,
        lr_decay_step_size=50,
        weight_decay=0,
        logger=logger
    )
    lines += process_lines

    hyperparams = {'# of layers': num_layers, '# of hidden units': hidden, '# of hops': num_hops, \
                   'batch size': batch_size, 'pooling ratio': ratio, 'dropout':dropout, '# of patches': num_patches}

    with open(hyperparams_path, 'wb') as handle:
        pickle.dump(hyperparams, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Write the collected log lines to a timestamped file; the context manager
    # closes the file even if a write fails.
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d_%H-%M")
    filename = 'log_' + date_time + '.txt'
    with open(dir_path + filename, 'w') as logfile:
        for line in lines:
            logfile.write("{}\n".format(line))
    


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))


