In [1]:
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install torch-geometric


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m96.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch-geometric
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch-geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910454 sha256=25790bfc60df7330a77e26f

In [4]:

import numpy as np
import os
import glob
from scipy.stats import pearsonr
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

In [7]:
class KbDataset44(InMemoryDataset):
    def __init__(self, root, split='train', transform=None, pre_transform=None, pre_filter=None):

        super().__init__(root, transform, pre_transform, pre_filter)
        self.root = root
        if split == 'train':
          self.data, self.slices = torch.load(self.processed_paths[0])
        elif split=='test':
          self.data, self.slices = torch.load(self.processed_paths[1])


    @property
    def raw_file_names(self):
        return glob.glob(os.path.join(self.raw_dir,'tem_044/data*'))

    @property
    def processed_file_names(self):
        return ['train_data_tem_044.pt','test_data_tem_044.pt' ]

    def process(self):

        splits = ['train', 'test']
        i= 0
        for split in splits:
          print('tem_044'+'_'+split)
          path = os.path.join(self.raw_dir,f'tem_044/{split}')
          datalist = self.read_data(path)

          torch.save(self.collate(datalist), self.processed_paths[i])
          i+=1

    def read_data(self,path):
        datalist =[]
        for i in range(400):
            print(i)

            npz_data = np.load(os.path.join(path,f'data_{i}.npz'))
            type = npz_data['type']
            type_mask = type == 0
            target = np.array(npz_data['target']).T[type_mask]

            ### node feature
            x = torch.from_numpy(type).to(torch.long)
            box = torch.from_numpy(npz_data['box']).to(torch.float)


            edge_index = torch.from_numpy(npz_data['edge_index']).to(torch.long)

            edge_weight = torch.from_numpy(npz_data['edge_weight']).to(torch.float)

            pos = torch.from_numpy(npz_data['pos']).to(torch.float)
            cross_pos = pos[edge_index[0]] - pos[edge_index[1]]
            cross_pos += (cross_pos < -box/2) * box
            cross_pos -= (cross_pos > box/2) * box



            datalist.append(Data(x = x,edge_index=edge_index,edge_attr = cross_pos,edge_weight = edge_weight,box=box, pos =pos,mask = torch.from_numpy(type_mask).to(torch.bool),y=torch.from_numpy(target).to(torch.float)))
        return datalist

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
import sys
sys.path.append('/content/drive/MyDrive/Geo-GNN')
import os
os.chdir('/content/drive/MyDrive/Geo-GNN')
from model import Geo_GNN

In [12]:
root =  '/content/drive/MyDrive'

train_dataset = KbDataset44(root=root,split='train')
test_dataset = KbDataset44(root=root,split='test')

import random
index = random.shuffle(np.arange(0,400))

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Geo_GNN(hidden_channels = 32, out_channels=1,num_gaussians = 64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 20, gamma=0.50)

In [20]:
def train(time_index =1):
    model.train()
    total_loss = 0
    for i, data in enumerate(train_loader):
        data = data.to(device)
        type = data.x.to(torch.long)

        y = data.y[:,time_index]

        y= y.unsqueeze(-1)

        mask = data.mask.to(torch.bool)
        optimizer.zero_grad()
        out,_ = model(data)

        loss = F.mse_loss(out[mask],y)
        loss.backward()
        total_loss += loss.item() * data.num_nodes
        optimizer.step()
        if (i + 1) % 100 == 0:
          print(f'[{i+1}/{len(train_loader)}] Loss: {loss.item()*data.num_nodes:.4f}')

    return total_loss / len(train_dataset)

def test(loader, time_index = 1):
    model.eval()

    correct = []
    for data in loader:
        data = data.to(device)
        # print(data.batch)
        # print(data.mask)
        mask = data.mask.cpu().numpy().astype(bool)
        # print(mask)
        with torch.no_grad():
            pred,_ = model(data)
        pred = pred.reshape(-1).cpu().numpy()
        # print(pred[mask].size)
        y = data.y[:,time_index]
        y = y.cpu().numpy()
        # print(y.size)
        pe = pearsonr(pred[mask],y)[0]

        correct.append(pe)

    return correct

In [21]:

train_loss = []

start_epoch = 1

RESUME = False
if RESUME:
    path_checkpoint = '/content/drive/MyDrive/Geo-GNN.pt'
    checkpoint = torch.load(path_checkpoint)  
    model.load_state_dict(checkpoint['net'])  
    optimizer.load_state_dict(checkpoint['optimizer']) 
    start_epoch = checkpoint['epoch']  
    scheduler.load_state_dict(checkpoint['lr_schedule'])


for epoch in range(start_epoch, 100):
  loss = train()
  # loss = train_2()
  if (epoch+1)%2==0:
    checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch,
        'lr_schedule': scheduler.state_dict()
    }
    torch.save(checkpoint, '/content/drive/MyDrive/Geo-GNN.pt')

    test_per = test(test_loader)
    test_per = np.array(test_per)

    print(test_per)
    print('pearson correlation is {}'.format(np.abs(test_per).mean()))
    
  train_loss.append(loss)
  scheduler.step()
  print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')

test_per = test(test_loader)
test_per = np.array(test_per)
print(test_per)
print('pearson correlation is {}'.format(test_per.mean()))

[100/400] Loss: 6.3811
[200/400] Loss: 5.6964
[300/400] Loss: 6.2137
[400/400] Loss: 6.0258
[0.35133445 0.32811844 0.33262837 0.34912184 0.3392836  0.34351992
 0.33658711 0.32500134 0.30233477 0.36944876 0.30722782 0.33785482
 0.35049719 0.33725453 0.35978087 0.3256073  0.3375921  0.33697072
 0.30126529 0.34396631 0.29891344 0.36828359 0.3355092  0.35949802
 0.38150904 0.35937514 0.33826641 0.36051562 0.3296302  0.33013637
 0.34327102 0.35299163 0.33073045 0.34280988 0.30775143 0.33320455
 0.3389717  0.33625826 0.36124061 0.30475132 0.32556877 0.34882314
 0.33158325 0.34154059 0.33413    0.32459031 0.32161597 0.33251029
 0.33232153 0.30032525 0.33581017 0.33233774 0.34458001 0.32665735
 0.35241439 0.32614747 0.35341052 0.35529302 0.33839913 0.36975287
 0.3281903  0.3411519  0.30573766 0.32668303 0.36280431 0.33479373
 0.32961418 0.35226136 0.37486382 0.34673317 0.32478233 0.34464846
 0.31589321 0.33834469 0.29778944 0.347516   0.36354283 0.34716387
 0.3363832  0.3299966  0.33915556 0.3