In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq
from tqdm.std import tqdm
import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T
import torch.utils.data as data
from torch_geometric.data import Batch
from torch_points3d.datasets.batch import SimpleBatch
import torch
import torch.utils.data as data
import numpy as np
import random
from torch_geometric.data import Batch, Data

In [None]:
# Encoder architecture for torch_points3d's SparseConv3d, written as YAML and parsed
# with OmegaConf below. Channel widths are expressed in terms of in_feat (128 here):
# the five down-conv stages go FEAT -> in_feat -> in_feat -> 2*in_feat -> 4*in_feat
# -> 8*in_feat (stride 2 from the second stage on), and the innermost
# GlobalBaseModule (aggr "max", presumably a global max-pool per sample) keeps
# 8*in_feat channels. spConvregress's MLP head expects 8*in_feat input features.
# NOTE(review): FEAT is not defined here; presumably torch_points3d substitutes the
# encoder's input_nc for it - confirm.
yaml_config = """
conv_type: "SPARSE"
define_constants:
    in_feat: 128
    block: ResBlock # Can be any of the blocks in modules/SparseConv3d/modules.py
down_conv:
    module_name: ResNetDown
    block: block
    N: [0, 2, 3, 4, 3]
    down_conv_nn:
        [
            [FEAT, in_feat],
            [in_feat, in_feat],
            [in_feat, 2*in_feat],
            [2*in_feat, 4*in_feat],
            [4*in_feat, 8*in_feat],
        ]
    kernel_size: [3, 3, 3, 3, 3]
    stride: [1, 2, 2, 2, 2]
innermost:
    module_name: GlobalBaseModule
    activation:
        name: LeakyReLU
        negative_slope: 0.2
    aggr: "max"
    nn: [8*in_feat, 8*in_feat]
""" 

from omegaconf import OmegaConf
# Parse the YAML text into an OmegaConf config, passed to SparseConv3d(config=params).
params = OmegaConf.create(yaml_config)

In [None]:
from torch_points3d.applications.sparseconv3d import SparseConv3d
from torch.nn import functional as F
class spConvregress(torch.nn.Module):
    """Sparse-conv encoder followed by an MLP head that regresses one value per sample.

    The encoder is built from ``params`` (``yaml_config``), whose innermost module
    outputs ``8 * in_feat`` channels. The head reduces them
    8*in_feat -> 4*in_feat -> 2*in_feat -> in_feat -> 1, with BatchNorm + ReLU
    between layers.

    Args:
        input_nc: Number of per-point input features passed to the encoder.
            Use 1 when only the first feature column is kept, 4 for all four.
        nonneg_output: If True (the original behaviour), a final ReLU clamps the
            prediction to >= 0, which suits non-negative targets such as band gaps.
            Pass False for targets that can be negative (e.g. formation energies);
            with the ReLU such values can never be predicted.
    """

    def __init__(self, input_nc=1, nonneg_output=True):
        super().__init__()
        # Derive the head widths from the YAML so they stay in sync with the encoder.
        in_feat = params.define_constants.in_feat
        self.nonneg_output = nonneg_output
        self.encoder = SparseConv3d("encoder", input_nc=input_nc, num_layers=4, config=params)  # minkowski by default
        self.linear = torch.nn.Linear(8 * in_feat, 4 * in_feat, bias=True)
        self.linear1 = torch.nn.Linear(4 * in_feat, 2 * in_feat, bias=True)
        self.linear2 = torch.nn.Linear(2 * in_feat, in_feat, bias=True)
        self.linear3 = torch.nn.Linear(in_feat, 1, bias=False)
        self.bn = nn.BatchNorm1d(8 * in_feat)
        self.bn1 = nn.BatchNorm1d(4 * in_feat)
        self.bn2 = nn.BatchNorm1d(2 * in_feat)
        self.bn3 = nn.BatchNorm1d(in_feat)

    def forward(self, data):
        """Predict one scalar per sample.

        Args:
            data: Batch passed straight to the encoder (a torch_geometric ``Batch``
                in this notebook).

        Returns:
            Tensor of shape (batch_size, 1). Assumes the encoder's ``x`` has the
            batch on dim 0 (e.g. (B, C) or (B, C, 1)) - TODO confirm.
        """
        data_out = self.encoder(data)
        # flatten(1) instead of squeeze(): squeeze() also removes the batch dim when a
        # batch holds a single sample, and BatchNorm1d then rejects the 1-D tensor.
        feats = data_out.x.flatten(start_dim=1)
        out = F.relu(self.bn(feats))
        out = F.relu(self.bn1(self.linear(out)))
        out = F.relu(self.bn2(self.linear1(out)))
        out = F.relu(self.bn3(self.linear2(out)))
        out = self.linear3(out)
        if self.nonneg_output:
            out = F.relu(out)
        self.output = out  # kept for code that reads model.output after a forward pass
        return out

In [None]:
# Instantiate the regressor, then move its parameters to the GPU exposed via CUDA_VISIBLE_DEVICES.
model = spConvregress()
model = model.cuda()


In [None]:
# Display the module tree (layer types and sizes) as a quick architecture check.
model

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this run are written under this directory.
tb_log_dir = '/path/to/save/sp3_convcell_band_super'
writer = SummaryWriter(log_dir=tb_log_dir)

In [None]:
import json

# Per-sample target lookup. The loader cells index it by each .pt file's stem and a
# property key ('bandgap' or 'formationE').
with open("mp_id_data.json", "r") as f:
    dic = json.loads(f.read())

In [None]:
import glob

# Sort so the list order, and therefore the seeded train/val/test split below, is
# reproducible: glob returns matches in arbitrary, filesystem-dependent order.
pt_list = sorted(glob.glob('/path/to/pt/files/*.pt'))

In [None]:
# Number of .pt samples found; 0 means the glob pattern above matched nothing.
len(pt_list)

In [None]:
# Band-gap targets. Run either this cell or the formation-energy cell, not both:
# whichever runs last overwrites data_list.
data_list = []
for path in tqdm(pt_list):
    # torch.load unpickles arbitrary objects; only load trusted .pt files.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # torch_geometric Data objects; pass weights_only=False if loading fails.
    sample = torch.load(path)
    # Keep only the first per-point feature column as an (N, 1) matrix (model input_nc=1).
    sample.x = sample.x[:, 0].view(-1, 1)
    # Look up the target by file name up to the first '.', then by property key.
    sample_id = os.path.basename(path).split('.')[0]
    sample.tar = torch.tensor(dic[sample_id]['bandgap'], dtype=torch.float32)
    data_list.append(sample)
datal = data_list  # alias for the name this cell originally used

In [None]:
# Formation-energy targets. Run either this cell or the band-gap cell, not both:
# whichever runs last overwrites data_list.
# NOTE(review): formation energies can be negative; a model whose output ends in a
# ReLU cannot predict them.
data_list = []
for path in tqdm(pt_list):
    # torch.load unpickles arbitrary objects; only load trusted .pt files.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # torch_geometric Data objects; pass weights_only=False if loading fails.
    sample = torch.load(path)
    # Keep only the first per-point feature column as an (N, 1) matrix (model input_nc=1).
    sample.x = sample.x[:, 0].view(-1, 1)
    # Look up the target by file name up to the first '.', then by property key.
    sample_id = os.path.basename(path).split('.')[0]
    sample.tar = torch.tensor(dic[sample_id]['formationE'], dtype=torch.float32)
    data_list.append(sample)
datal = data_list  # alias for the name this cell originally used

In [None]:
# Number of loaded samples; expects a loader cell above to have built `data_list`
# with one entry per .pt file.
len(data_list)

In [None]:
# Hold out 10% for test, then 13.333% of the remaining 90% (~12% overall) for
# validation, leaving ~78% for training. Fixed seeds make the split repeatable
# for a given input order.
data_train, test_files = train_test_split(
    data_list,
    random_state=1,
    test_size=0.1,
)
train_files, val_files = train_test_split(
    data_train,
    random_state=7,
    test_size=0.13333,
)

In [None]:
num_epochs = 3000
BATCH_SIZE = 32


def collate_function(datalist):
    """Merge a list of torch_geometric samples into a single ``Batch``."""
    return Batch.from_data_list(datalist)


def _make_loader(files, shuffle, drop_last=False):
    """Build a DataLoader over an in-memory list of samples with the shared settings."""
    return torch.utils.data.DataLoader(
        files,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=collate_function,
        pin_memory=True,
        drop_last=drop_last,
    )


# The model uses BatchNorm1d, which raises in training mode on a one-sample batch.
# Drop the final training batch only when it would contain exactly one sample.
_train_drop_last = len(train_files) > BATCH_SIZE and len(train_files) % BATCH_SIZE == 1
train_dataloader = _make_loader(train_files, shuffle=True, drop_last=_train_drop_last)
val_dataloader = _make_loader(val_files, shuffle=False)
test_dataloader = _make_loader(test_files, shuffle=False)



In [None]:
class SaveBestModel:
    """Checkpoint the model whenever the validation loss reaches a new minimum.

    Each call compares ``current_valid_loss`` with the lowest loss seen so far.
    If it is strictly lower, the 1-based epoch number, model state and optimizer
    state are written to ``save_path`` with ``torch.save``, overwriting any
    previous checkpoint.

    Args:
        best_valid_loss: Initial threshold. Defaults to +inf so the first
            evaluated epoch is always saved; a finite default (previously 10)
            silently skips saving whenever the loss never drops below it.
        save_path: Checkpoint file. Its parent directory is created if missing.
            NOTE(review): this differs from the TensorBoard log_dir; confirm intent.
    """

    def __init__(
        self,
        best_valid_loss=float("inf"),
        save_path='/home/ssd1/shkim/logs30/sp3_convcell_band_super/best_model.pth',
    ):
        self.best_valid_loss = best_valid_loss
        self.save_path = save_path

    def __call__(self, current_valid_loss, epoch, model, optimizer):
        """Save a checkpoint if ``current_valid_loss`` beats the best seen so far.

        Args:
            current_valid_loss: Validation loss for this epoch.
            epoch: 0-based epoch index (stored as ``epoch + 1``).
            model: Module whose ``state_dict`` is saved.
            optimizer: Optimizer whose ``state_dict`` is saved.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            save_dir = os.path.dirname(self.save_path)
            if save_dir:
                # torch.save does not create missing directories.
                os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, self.save_path)

save_best_model = SaveBestModel()


In [None]:
from tqdm.auto import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torch.optim.lr_scheduler import ReduceLROnPlateau
optimizer = optim.Adam(model.parameters(), lr=0.00005)
# Halve the learning rate after 5 epochs without improvement in validation loss.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5, verbose=True)
loss_function = torch.nn.L1Loss().to(device)


def _train_one_epoch(epoch):
    """Run one optimisation pass over train_dataloader.

    Returns the unweighted mean of the per-batch L1 (MAE) losses.
    """
    model.train()
    running_loss = 0.0
    with tqdm(train_dataloader, unit="batch", total=len(train_dataloader)) as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for n, batch_data in enumerate(tepoch, start=1):
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            logits = model(batch_data)
            # Flatten both sides to 1-D so prediction and target shapes always match.
            loss = loss_function(logits.reshape(-1), batch_data.tar.reshape(-1))

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            tepoch.set_postfix(total_mae_loss=running_loss / n, batch_loss=loss.item())
    return running_loss / len(train_dataloader)


def _evaluate(loader):
    """Return the unweighted mean per-batch L1 (MAE) loss of the model over loader."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch_data in tqdm(loader):
            batch_data = batch_data.to(device)
            logits = model(batch_data)
            total += loss_function(logits.reshape(-1), batch_data.tar.reshape(-1)).item()
    return total / len(loader)


for epoch in range(num_epochs):
    train_loss = _train_one_epoch(epoch)
    print(f"Train Loss: MAE {train_loss:.4f}")

    val_loss = _evaluate(val_dataloader)
    scheduler.step(val_loss)
    save_best_model(val_loss, epoch, model, optimizer)

    writer.add_scalars("every_1_epoch", {"Loss/train_2": train_loss,
                                         "Loss/val_2": val_loss,
                                         "learning_Rate": optimizer.param_groups[0]['lr']}, epoch)
    print(f"Validation Loss:  MAE {val_loss:.4f}")

# Make sure the last epochs' scalars reach disk.
writer.flush()
      

In [None]:
 # Re-evaluate the current model on the validation set (mean per-batch MAE).
 # NOTE(review): this scores the last-epoch weights, not the best checkpoint written
 # by SaveBestModel, and test_dataloader is never evaluated; confirm intent.
 model.eval()
 val_loss = 0.0
 with torch.no_grad():
     for batch_data in tqdm(val_dataloader):
         batch_data = batch_data.to(device)
         logits = model(batch_data)
         loss = loss_function(logits.reshape(-1), batch_data.tar.reshape(-1))
         val_loss += loss.item()

 val_loss /= len(val_dataloader)
 # No scheduler.step() here: a standalone evaluation must not advance the LR schedule.
 print(f"Validation Loss:  MAE {val_loss:.4f}")