In [None]:
import torch
import torch.nn as nn
import torch_geometric.nn as geom_nn
from model import CrystalGraphConvNet  # Assuming ConvLayer is correctly imported from your cgcnn model file
import os
from data import CIFData
from data import collate_pool, get_train_val_test_loader
import csv
from torch.optim.lr_scheduler import MultiStepLR
import shutil

In [None]:
# --- model architecture (passed to CrystalGraphConvNet below) ---
n_conv = 8
atom_fea_len = 64
hidden_size = 256

# --- optimisation ---
epochs = 200
batch_size = 256
learning_rate = 1e-3

# --- checkpointing ---
save_path = 'best_model.pth'
best_val_loss = float('inf')  # any finite loss compares smaller than this
# Load dataset




In [None]:
# Dataset root directory. "root_dir" looks like a placeholder — TODO point it at the real data.
cif_root = "root_dir"
dataset = CIFData(cif_root)

In [None]:

# 80/10/10 train/val/test split, batched with collate_pool.
loader_kwargs = dict(
    dataset=dataset,
    collate_fn=collate_pool,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    train_size=None,
    val_size=None,
    test_size=None,
    return_test=True,
)
train_loader, val_loader, test_loader = get_train_val_test_loader(**loader_kwargs)

In [None]:
# Sanity check: number of entries in the dataset.
n_entries = len(dataset)
print(n_entries)

In [None]:
from random import sample

class Normalizer(object):
    """Standardise tensors with a fixed mean/std and invert the transform.

    The statistics are computed once, at construction time, from a sample
    tensor and then reused for every ``norm``/``denorm`` call.
    """

    def __init__(self, tensor):
        """Compute the mean and (unbiased) std of ``tensor`` over all elements."""
        self.mean, self.std = torch.mean(tensor), torch.std(tensor)

    def norm(self, tensor):
        """Return ``(tensor - mean) / std``."""
        centred = tensor - self.mean
        return centred / self.std

    def denorm(self, normed_tensor):
        """Inverse of ``norm``: return ``normed_tensor * std + mean``."""
        scaled = normed_tensor * self.std
        return scaled + self.mean

    def state_dict(self):
        """Return the statistics as a dict with ``'mean'`` and ``'std'`` keys."""
        return {'mean': self.mean,
                'std': self.std}

    def load_state_dict(self, state_dict):
        """Restore statistics from a dict produced by ``state_dict``."""
        self.mean, self.std = state_dict['mean'], state_dict['std']

# Fit the target normalizer on a random subset of at most 500 entries.
# random.sample raises ValueError when asked for more items than exist, so
# the subset size is capped at the dataset length.
# NOTE(review): this samples from the whole dataset (including val/test
# entries) and is unseeded, so the statistics vary between runs.
n_normalizer_samples = min(500, len(dataset))
sample_data_list = [dataset[i] for i in
                    sample(range(len(dataset)), n_normalizer_samples)]
_, sample_target, _ = collate_pool(sample_data_list)
normalizer = Normalizer(sample_target)


In [None]:
# Infer input feature sizes from the first entry. dataset[i] unpacks as
# (features, target, cif_id); the last dimension of features[0] and
# features[1] (atom and neighbour features, as unpacked in train/validate)
# gives the widths the network must accept.
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
print("Original Atom Feature length: ",orig_atom_fea_len)
nbr_fea_len = structures[1].shape[-1]

# Initialize model, loss, optimizer
model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len, n_conv, hidden_size).cuda()
criterion = nn.SmoothL1Loss()
# Use the configured learning rate. It was hard-coded here before, so
# editing `learning_rate` in the config cell had no effect.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
best_model = 0

# scheduler.step() is called once per epoch in the training loop, so this
# divides the learning rate by 10 after epoch 100.
scheduler = MultiStepLR(optimizer, milestones=[100],
                        gamma=0.1)

best_mae_error = 1e10

In [None]:
class AverageMeter(object):
    """Track the most recent value together with a weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running total, total weight and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def mae(prediction, target):
    """
    Mean absolute error between ``prediction`` and ``target``.

    ``target`` is moved to the CPU before the subtraction, so ``prediction``
    must already be a CPU tensor.

    Parameters
    ----------

    prediction: torch.Tensor (N, 1)
    target: torch.Tensor (N, 1)
    """
    abs_diff = (prediction - target.to('cpu')).abs()
    return abs_diff.mean()


In [None]:
def validate(val_loader, model, criterion, normalizer, test=False):
    """Evaluate ``model`` on ``val_loader`` and print the mean absolute error.

    Parameters
    ----------
    val_loader : iterable
        Yields ``(features, target, cif_id)`` batches; ``features`` is unpacked
        below as ``(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)``.
    model : torch.nn.Module
        Switched to eval mode and run on the GPU without gradient tracking.
    criterion : callable
        Not used during evaluation. It is kept so existing call sites keep working.
    normalizer : Normalizer
        Maps model outputs back to target units before the MAE is computed.
    test : bool, optional
        When True, also collect per-entry predictions, targets and ids.

    Returns
    -------
    The batch-size-weighted average MAE (as accumulated by ``AverageMeter``)
    when ``test`` is False. Otherwise returns the tuple
    ``(mae, preds, targets, cif_ids)``.
    """
    test_targets = []
    test_preds = []
    test_cif_ids = []
    model.eval()  # Set model to evaluation mode
    mae_errors = AverageMeter()
    with torch.no_grad():
        for features, target, cif_id in val_loader:
            atom_fea = features[0].cuda()
            nbr_fea = features[1].cuda()
            nbr_fea_idx = features[2].cuda()
            crystal_atom_idx = [crys_idx.cuda(non_blocking=True) for crys_idx in features[3]]
            target = target.cuda()

            output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
            # Compare in target units: de-normalise the prediction once and
            # reuse it for both the MAE and the optional per-entry export.
            pred = normalizer.denorm(output.detach().cpu())
            mae_errors.update(mae(pred, target), target.size(0))
            if test:
                test_preds += pred.view(-1).tolist()
                test_targets += target.view(-1).tolist()
                test_cif_ids += cif_id

    star_label = "**" if test else "*"
    print(' {star} MAE {mae_errors.avg:.3f}'.format(star=star_label,
                                                    mae_errors=mae_errors))
    if test:
        return (mae_errors.avg, test_preds, test_targets, test_cif_ids)
    return mae_errors.avg

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, normalizer, print_freq=10):
    """Run one training epoch over ``train_loader``.

    The loss is computed against normalised targets. The logged MAE is
    computed in original target units after de-normalising the output.

    Parameters
    ----------
    train_loader : iterable
        Yields ``(features, target, cif_id)`` batches; ``features`` is unpacked
        below as ``(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)``.
    model : torch.nn.Module
        Switched to train mode; inputs are moved to the GPU.
    criterion : callable
        Loss applied to ``(output, normalised_target)``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    epoch : int
        Only used in the progress log line.
    normalizer : Normalizer
        Provides ``norm``/``denorm`` for the targets.
    print_freq : int, optional
        Log progress every ``print_freq`` batches. Must be positive; the default of 10 matches the
        previous fixed interval.
    """
    model.train()
    mae_errors = AverageMeter()
    losses = AverageMeter()
    for i, (features, target, _) in enumerate(train_loader):
        atom_fea = features[0].cuda()
        nbr_fea = features[1].cuda()
        nbr_fea_idx = features[2].cuda()
        crystal_atom_idx = [crys_idx.cuda(non_blocking=True) for crys_idx in features[3]]
        target = target.cuda()

        # `target` is already on the GPU, so the normalised target is too.
        target_normed = normalizer.norm(target)

        # Forward pass
        output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        loss = criterion(output, target_normed)

        mae_error = mae(normalizer.denorm(output.detach().cpu()), target)
        losses.update(loss.detach().cpu(), target.size(0))
        mae_errors.update(mae_error, target.size(0))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format(
                      epoch, i, len(train_loader), loss=losses, mae_errors=mae_errors)
                  )

In [None]:
# Training loop. The checkpoint is selected on the validation split; the
# held-out test split is evaluated once, with the selected checkpoint, after
# training. Selecting on the test split would leak test data into model
# selection and make the reported test MAE optimistic.
best_val_loss = float('inf')
save_path = 'best_model.pth'
checkpoint_saved = False

for epoch in range(epochs):

    train(train_loader, model, criterion, optimizer, epoch, normalizer)

    scheduler.step()
    val_mae = validate(val_loader, model, criterion, normalizer)
    if val_mae < best_val_loss:
        best_val_loss = val_mae
        torch.save(model.state_dict(), save_path)
        checkpoint_saved = True
        print(f"Best model saved with validation loss: {best_val_loss:.4f}, epoch: {epoch}")

# Restore the best weights before testing. Assigning `best_model = model`
# inside the loop would only alias the live model, which ends up holding the
# final-epoch weights rather than the best ones.
if checkpoint_saved:
    model.load_state_dict(torch.load(save_path))
best_model = model

test_loss, test_preds, test_targets, test_cif_ids = validate(test_loader, model, criterion, normalizer, True)
# newline='' stops the csv module from emitting blank rows on Windows.
with open('test_results.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    for cif_id, target, pred in zip(test_cif_ids, test_targets, test_preds):
        writer.writerow((cif_id, target, pred))
        
