<a href="https://colab.research.google.com/github/sparks-baird/xtal2png/blob/main/2.1-xtal2png-cnn-classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# `xtal2png` Convolutional Neural Network Classification

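This notebook trains a convolutional neural network on `xtal2png` images for the Matbench `matbench_mp_is_metal` binary classification task.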

In [None]:
%pip install skorch xtal2png matbench pytorch-lightning



In [None]:
import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from matbench.bench import MatbenchBenchmark
from skorch.callbacks import EarlyStopping
from skorch.classifier import NeuralNetBinaryClassifier
from torch import nn
from xtal2png.core import XtalConverter

# Set all random seeds as specified by Matbench
pl.seed_everything(18012019)

Global seed set to 18012019


18012019

## Matbench Setup

In [None]:
mb = MatbenchBenchmark(autoload=False, subset=["matbench_mp_is_metal"])

2022-06-29 02:43:52 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_mp_is_metal']


## CNN Architecture

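Architecture of the convolutional layers, annotated with output shapes as (height, width, channels) for a 64 × 64 single-channel input. PyTorch itself stores tensors channels-first, as (channels, height, width).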
```python
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1),      # (64, 64, 1) --> (64, 64, 3)
            nn.BatchNorm2d(3),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # (31, 31, 3)

            nn.Conv2d(3, 8, kernel_size=3, padding=1),      # (31, 31, 8)
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # (15, 15, 8)

            nn.Conv2d(8, 16, kernel_size=3, padding=1),     # (15, 15, 16)
            nn.BatchNorm2d(16),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # (7, 7, 16)
        )
```
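
The shape comments above follow from the standard output-size formula for convolution and pooling layers with dilation 1: floor((n + 2p - k) / s) + 1, for input size n, kernel size k, stride s, and padding p. A small pure-Python check of those numbers, where `conv_out` is an illustrative helper and not a PyTorch function:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 64  # input height/width assumed in the shape comments above
sizes = []
for _ in range(3):
    # 3x3 convolution with padding=1 and stride=1 keeps the spatial size
    size = conv_out(size, kernel=3, stride=1, padding=1)
    # 3x3 max pooling with stride=2 and no padding roughly halves it
    size = conv_out(size, kernel=3, stride=2)
    sizes.append(size)

flattened = size * size * 16  # 16 channels after the last convolution
print(sizes, flattened)  # [31, 15, 7] 784
```

The final 7 × 7 × 16 = 784 matches the input size of the first `nn.Linear` layer in `CNNClassifier`.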

In [None]:
class CNNClassifier(nn.Module):
    def __init__(self, dropout: float = 0.5) -> None:
        super().__init__()
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fullyconnected = nn.Sequential(
            nn.Linear(7 * 7 * 16, 512),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 256),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.LeakyReLU(inplace=True),
            nn.Linear(256, 1),
            # No need for sigmoid here if using BCEWithLogitsLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.convolutions(x)
        x = torch.flatten(x, 1)  # flatten all but batch dim
        x = self.fullyconnected(x)
        return x
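
The comment about omitting the final sigmoid can be checked numerically. `nn.BCEWithLogitsLoss` applies the sigmoid and the binary cross-entropy in a single step, so the network is expected to output raw logits. The sketch below is a pure-Python illustration of that equivalence for a single sample; the helper names are hypothetical and are not part of PyTorch or skorch.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_on_probability(p: float, y: float) -> float:
    """Binary cross-entropy evaluated on a probability p in (0, 1)."""
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def bce_with_logits(z: float, y: float) -> float:
    """Same loss computed directly from the logit z in a numerically stable form."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


# Compare both formulations on a few (logit, label) pairs
pairs = [(1.5, 1.0), (-2.0, 0.0), (2.0, 0.0), (0.0, 1.0)]
losses = [(bce_on_probability(sigmoid(z), y), bce_with_logits(z, y)) for z, y in pairs]
for (z, y), (a, b) in zip(pairs, losses):
    print(f"logit={z:+.1f} label={y:.0f} sigmoid+BCE={a:.6f} BCE-with-logits={b:.6f}")
```

Because sigmoid(z) > 0.5 exactly when z > 0, thresholding the probability at 0.5 is equivalent to checking the sign of the logit.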


## Matbench Folds
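
Each fold in the cell below follows the same pattern: structures with more than 52 sites are dropped before training, the CNN predicts only on the test structures within that limit, and the remaining test entries receive the most common label among the filtered training targets, so that `task.record` gets one prediction per test entry. A minimal sketch of that fill-in step with toy values (the site counts, labels, and model outputs are hypothetical, not Matbench data):

```python
from statistics import mode

MAX_SITES = 52  # same cutoff as in the notebook

# Hypothetical toy data
train_num_sites = [4, 60, 12, 8, 52]
train_labels = [1.0, 0.0, 1.0, 0.0, 1.0]
test_num_sites = [10, 75, 52, 120]

# Only training structures within the cutoff contribute to the fallback label
kept_train_labels = [
    y for n, y in zip(train_num_sites, train_labels) if n <= MAX_SITES
]
fallback = mode(kept_train_labels)

# Boolean mask over the test set: True where the model can make a prediction
test_mask = [n <= MAX_SITES for n in test_num_sites]

# Stand-in for the classifier's output on the kept test structures
model_preds = iter([0.0, 1.0])

# Model prediction where available, fallback label everywhere else
full_preds = [next(model_preds) if keep else fallback for keep in test_mask]
print(full_preds)  # [0.0, 1.0, 1.0, 1.0]
```

The second and fourth test entries exceed the cutoff, so they take the fallback label instead of a model prediction.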

In [None]:
xc = XtalConverter()

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        # Get training data
        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Keep structures with num_sites <= 52
        site_counter = lambda x: x.num_sites
        idx = train_inputs.apply(site_counter) <= 52
        X_train = train_inputs[idx]
        y_train = train_outputs[idx]

        # Convert crystal structures to images
        X_train = xc.xtal2png(X_train)

        # Convert PIL Images to torch.Tensor
        # Note that this scales from [0, 255] to [0.0, 1.0]
        X_train = [TF.to_tensor(img) for img in X_train]
        # Normalize images (subtract mean, divide by std)
        mean = torch.cat(X_train).mean()
        std = torch.cat(X_train).std()
        X_train = [TF.normalize(i, mean=mean, std=std) for i in X_train]

        # Change lists to arrays and y_train dtype to float32 for skorch
        X_train = torch.stack(X_train).numpy()
        y_train = y_train.astype(np.float32)

        # Train and validate classifier
        net = NeuralNetBinaryClassifier(
            CNNClassifier,
            criterion=nn.BCEWithLogitsLoss,
            max_epochs=50,
            optimizer=optim.AdamW,
            optimizer__amsgrad=True,
            optimizer__lr=0.0005,
            callbacks=[EarlyStopping(patience=10)],
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=32,
        )
        net.fit(X_train, y_train)

        # Get test data and keep structures with num_sites <= 52
        test_inputs = task.get_test_data(fold)
        idx = test_inputs.apply(site_counter) <= 52
        X_test = test_inputs[idx]

        # Convert to images, preprocess using mean and std from training data
        X_test = xc.xtal2png(X_test)
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
        X_test = torch.stack([preprocess(img) for img in X_test]).numpy()

        # Predict on X_test
        # For structures with num_sites > 52, predict mode of y_train
        y_pred = net.predict(X_test)
        y_pred_full = np.empty(test_inputs.size)
        y_pred_full[idx] = y_pred
        y_pred_full[~idx] = y_train.mode().item()

        # Record data
        task.record(fold, y_pred_full)

# Save results
mb.to_file("my_models_benchmark.json.gz")

2022-06-29 01:37:33 INFO     Loading dataset 'matbench_mp_is_metal'...
Fetching matbench_mp_is_metal.json.gz from https://ml.materialsproject.org/projects/matbench_mp_is_metal.json.gz to /usr/local/lib/python3.7/dist-packages/matminer/datasets/matbench_mp_is_metal.json.gz


Fetching https://ml.materialsproject.org/projects/matbench_mp_is_metal.json.gz in MB: 136.699904MB [00:00, 575.80MB/s]                                                


2022-06-29 01:40:31 INFO     Dataset 'matbench_mp_is_metal loaded.


100%|██████████| 71602/71602 [17:16<00:00, 69.07it/s]
  f"lower RGB value(s) OOB ({mn} less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501
  f"upper RGB value(s) OOB ({mx} greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501


  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.6061       0.6330        0.6784  19.4657
      2        0.5143       0.6281        0.6843  18.3359
      3        0.4907       0.6410        0.6852  18.0493
      4        0.4786       0.6467        0.7220  18.0379
      5        0.4625       0.6409        0.7409  17.8828
      6        0.4553       0.6495        0.7136  18.3212
      7        0.4489       0.6513        0.7080  17.9975
      8        0.4437       0.6554        0.6954  18.1166
      9        0.4379       0.6573        0.7202  17.8346
     10        0.4328       0.6578        0.7154  17.6820
Stopping since valid_loss has not improved in the last 10 epochs.
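
In the log above, `valid_loss` is lowest at epoch 1 (0.6784) and never improves afterwards while `train_loss` keeps falling, which is the situation a patience counter is meant to catch. The sketch below shows the general idea of patience-based early stopping on hypothetical loss values. It is a simplified illustration, not skorch's `EarlyStopping` implementation, which has additional options that are omitted here.

```python
def epochs_until_stop(losses, patience):
    """Return the 1-based epoch at which training stops, or None if it never does."""
    best = float("inf")
    misses = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            # New best validation loss: remember it and reset the counter
            best = loss
            misses = 0
        else:
            misses += 1
        if misses == patience:
            return epoch
    return None


# Hypothetical validation losses, not taken from the run above
toy_losses = [0.70, 0.65, 0.66, 0.67, 0.64, 0.66, 0.65, 0.68]

print(epochs_until_stop(toy_losses, patience=2))  # 4
print(epochs_until_stop(toy_losses, patience=3))  # 8
print(epochs_until_stop(toy_losses, patience=5))  # None
```

A larger patience tolerates longer plateaus at the cost of extra epochs; with `patience=3` the improvement at epoch 5 resets the counter and training continues until epoch 8.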


100%|██████████| 17925/17925 [04:20<00:00, 68.85it/s]
  f"lower RGB value(s) OOB ({mn} less than 0). thresholding to 0.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501
  f"upper RGB value(s) OOB ({mx} greater than 255). thresholding to 255.. may throw off crystal structure parameters (e.g. if lattice parameters are thresholded)"  # noqa: E501


2022-06-29 02:06:56 INFO     Recorded fold matbench_mp_is_metal-0 successfully.


100%|██████████| 71644/71644 [17:25<00:00, 68.53it/s]
