In [42]:
from torch.nn.utils import clip_grad_norm_
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

import numpy as np
import random 
import time
import matplotlib.pyplot as plt

: 

In [40]:
#* FC part
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class fc(nn.Module):
    def __init__(self, a_tilts):
        super(fc, self).__init__()

        self.a_tilts = a_tilts
        self.channels1 = nn.ModuleList([nn.Sequential(
            nn.Linear(32,32),
            nn.ReLU(),
        ) for _ in range(len(a_tilts))])

        self.channels2 = nn.ModuleList([nn.Sequential(
            nn.Linear(32,32),
            nn.ReLU(),
        ) for _ in range(len(a_tilts))])

        self.output1 = nn.Sequential(
            nn.Linear(len(a_tilts) * 32, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        self.output2 = nn.Linear(32, 1)

    def forward(self, atom_list):
        
        input_state = torch.mm(self.a_tilts[0], atom_list.T).T
        channels = self.channels1[0](input_state)
        channels = torch.mm(self.a_tilts[0], channels.T).T
        channels = self.channels2[0](channels)

        for i in range(1, len(self.a_tilts)):
            input_state = torch.mm(self.a_tilts[i], atom_list.T).T
            channels = torch.cat((channels, self.channels1[i](input_state)), dim=1)

        print(channels.shape)
        input_state = torch.mm(self.a_tilts[0], channels[0]).T
        channels_ = self.channels2[0](input_state)
        for i in range(1, len(self.a_tilts)):
            input_state = torch.mm(self.a_tilts[i], channels[i]).T
            channels_ = torch.cat((channels_, self.channels2[i](input_state)), dim=1)

        channels_ = torch.cat(channels_, dim=1).flatten()
        s = self.output1(channels_)
        
        return s

In [41]:
#* Training step
mini_batchsize = 16
lr_ = 1e-5
train_step = 50000
epoch_per_episode = 16
random_seed = 369

date = '20230211msad_GNN'
path_save = f'./runs/{date}'
a_tilt = np.array([
    np.load(f'./laplacian/d_{i}nn.npy') for i in range(1, 7)
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#* The raw embedding list and corresponding MSAD value
pth_ele = '/media/wz/a7ee6d50-691d-431a-8efb-b93adc04896d/Github/MATools/CE_MC/runs/demo/20221216_msadGA/ele_list_all.npy'
pth_msad = '/media/wz/a7ee6d50-691d-431a-8efb-b93adc04896d/Github/MATools/CE_MC/runs/demo/20221216_msadGA/msad_list_all.npy'
weight_raw = np.load(pth_ele)
msad_list = np.load(pth_msad)

mean_msad = np.mean(msad_list)
# var_msad = np.var(msad_list)
#* Norm.
msad_raw = (msad_list.reshape(-1,1)-mean_msad)

#* And pass to GPU.
weight_raw = torch.from_numpy(weight_raw).to(device).float()
msad_raw = torch.from_numpy(msad_raw).to(device).float()
a_tilt = torch.from_numpy(a_tilt).to(device).float()

#*Device is defined in former block
fc_ = fc(a_tilt).to(device)
fc_optim = torch.optim.Adam(fc_.parameters(), lr = lr_)
scheduler = torch.optim.lr_scheduler.StepLR(fc_optim,step_size=10000,gamma = 0.98)
mse_loss = nn.MSELoss()
# writer = SummaryWriter(log_dir = path_save)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

train_loss, test_loss = [], []
#* Divide dataset.
weight_train, weight_test, msad_train, msad_test = train_test_split(
    weight_raw, msad_raw, train_size=0.8, random_state=random_seed)

for i in range(train_step):
    # for index in BatchSampler(SubsetRandomSampler(range(len(weight_train))), mini_batchsize, True):
        # for epoch in range(epoch_per_episode):

    fc_.train()
    msad_out_train = fc_(weight_train)
    msad_tar_train = msad_train
    msad_loss_train = mse_loss(msad_out_train, msad_tar_train)

    # writer.add_scalar("Training Loss of MSAD", msad_loss_train, i)
    msad_loss_train_ = msad_loss_train
    train_loss.append(msad_loss_train_.detach().cpu().numpy().flatten()[0])

    fc_optim.zero_grad()
    msad_loss_train.backward()
    clip_grad_norm_(fc_.parameters(), 0.5)
    fc_optim.step()
    scheduler.step()

    fc_.eval()
    msad_out_test = fc_(weight_test)
    msad_loss_test = mse_loss(msad_out_test, msad_test)
    # writer.add_scalar("Testing Loss of MSAD", msad_loss_test, i)
    msad_loss_test_ = msad_loss_test
    test_loss.append(msad_loss_test_.detach().cpu().numpy().flatten()[0])

    if i % 100 == 0:
        clear_output(True)

        plt.plot(train_loss, label='Training loss', alpha=0.6)
        plt.plot(test_loss, label='Testing loss', alpha=0.6)

        plt.title(f'{np.min(test_loss)}')
        plt.legend()
        plt.show()

torch.Size([320, 192])


RuntimeError: mat2 must be a matrix