[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sparks-baird/xtal2png/blob/cnn-classification/notebooks/2.1-xtal2png-cnn-classification.ipynb)

# Classification on `matbench_mp_is_metal` using `xtal2png` representation of crystal structures

## Description
In this notebook, a convolutional neural network is applied to the `matbench_mp_is_metal` classification task using [`xtal2png`](https://xtal2png.readthedocs.io/en/latest/) representations of crystal structures. Crystal structures are encoded as grayscale PNG images. Because the conversion is restricted to structures with 52 or fewer sites, the network is trained only on structures with `num_sites <= 52`. For test-set structures with more than 52 sites, we simply predict the mode of the training outputs (i.e. the most common class in `y_train`, where `X_train` and `y_train` are the training inputs and labels restricted to `num_sites <= 52`).
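
The fallback for large structures can be sketched on toy data as follows. This is a minimal illustration rather than the notebook's pipeline: the site counts and labels are made up, and an array of ones stands in for the CNN's predictions.

```python
import numpy as np

MAX_SITES = 52  # size limit described above

# Hypothetical toy data: site counts and labels for one train/test split
train_sites = np.array([4, 10, 60, 52, 8])
y_train_all = np.array([1, 0, 1, 0, 0])
test_sites = np.array([5, 80, 52, 120])

# Only structures within the size limit are used for training
train_mask = train_sites <= MAX_SITES
y_train = y_train_all[train_mask]

# Most common class among the training labels that were actually used
values, counts = np.unique(y_train, return_counts=True)
mode_label = values[np.argmax(counts)]

# Stand-in for the CNN: pretend it predicts 1 for every convertible structure
test_mask = test_sites <= MAX_SITES
model_pred = np.ones(test_mask.sum(), dtype=int)

# Start from the mode everywhere, then overwrite where the model can predict
y_pred_full = np.full(test_sites.shape, mode_label)
y_pred_full[test_mask] = model_pred
print(y_pred_full)
```

Every test structure receives a prediction, which is what allows the full test set to be recorded for each fold.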

## Benchmark Name
Matbench v0.1

## Package Versions
- [matbench](https://anaconda.org/conda-forge/matbench)==0.5.0
- [xtal2png](https://anaconda.org/conda-forge/xtal2png)==0.7.0
- [pytorch](https://anaconda.org/pytorch/pytorch)==1.11.0
- [skorch](https://anaconda.org/conda-forge/skorch)==0.11.0
- [pytorch-lightning](https://anaconda.org/conda-forge/pytorch-lightning)==1.6.4
- [mosaicml](https://anaconda.org/mosaicml/mosaicml)==0.8.0

## Algorithm Description
A fairly simple CNN is created in vanilla PyTorch, very loosely based on the PyTorch implementation of [AlexNet](https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py). Model surgery is then performed on the max-pooling and certain convolutional layers using MosaicML's [Composer](https://github.com/mosaicml/composer) library.


In [1]:
%pip install matbench skorch xtal2png pytorch-lightning mosaicml

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting matbench
  Downloading matbench-0.5-py3-none-any.whl (9.9 MB)
Collecting skorch
  Downloading skorch-0.11.0-py3-none-any.whl (155 kB)
Collecting xtal2png
  Downloading xtal2png-0.8.0-py3-none-any.whl (30 kB)
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
Collecting mosaicml
  Downloading mosaicml-0.8.0-py3-none-any.whl (548 kB)
Collecting monty==2021.8.17
  Downloading monty-2021.8.17-py3-none-any.whl (65 kB)
Collecting matminer==0.7.4
  Downloading matminer-0.7.4-py3-none-any.whl (1.4 MB)


### Imports

In [2]:
# %pip install skorch xtal2png matbench pytorch-lightning mosaicml

import composer.functional as cf
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from matbench.bench import MatbenchBenchmark
from skorch.callbacks import EarlyStopping
from skorch.classifier import NeuralNetBinaryClassifier
from torch import nn
from xtal2png.core import XtalConverter

# Set all random seeds as specified by Matbench
pl.seed_everything(18012019)

Global seed set to 18012019


18012019

### CNN Architecture
For the vanilla PyTorch model, the architecture of the convolutional layers is as follows:
```python
self.convolutions = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (64, 64, 1) --> (64, 64, 8)
    nn.BatchNorm2d(8),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (31, 31, 8)

    nn.Conv2d(8, 16, kernel_size=3, padding=1),     # --> (31, 31, 16)
    nn.BatchNorm2d(16),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (15, 15, 16)
    
    nn.Conv2d(16, 32, kernel_size=3, padding=1),    # --> (15, 15, 32)
    nn.BatchNorm2d(32),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (7, 7, 32)
)
```
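
The shape comments above follow from the standard output-size formula for convolution and pooling layers, `floor((n + 2 * padding - kernel) / stride) + 1` (assuming dilation 1 and no ceil mode). The short sketch below recomputes them in plain Python; the helper name `out_size` is illustrative only.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 64  # input images are 64 x 64 with a single channel
trace = []
for out_channels in (8, 16, 32):
    # 3x3 convolution with padding=1 keeps the spatial size unchanged
    size = out_size(size, kernel=3, stride=1, padding=1)
    # 3x3 max pooling with stride=2 and no padding roughly halves it
    size = out_size(size, kernel=3, stride=2)
    trace.append((size, size, out_channels))

flat_features = size * size * 32  # input width of the first fully connected layer
print(trace)          # [(31, 31, 8), (15, 15, 16), (7, 7, 32)]
print(flat_features)  # 1568
```

The final value, 1568, matches the `7 * 7 * 32` input size of the first linear layer.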
The full `CNNClassifier` class is defined below:

In [None]:
class CNNClassifier(nn.Module):
    def __init__(self, dropout: float = 0.5) -> None:
        super().__init__()
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fullyconnected = nn.Sequential(
            nn.Linear(7 * 7 * 32, 512),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 256),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.Mish(inplace=True),
            nn.Linear(256, 1),
            # No need for sigmoid here if using BCEWithLogitsLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.convolutions(x)
        x = torch.flatten(x, 1)  # flatten all but batch dim
        x = self.fullyconnected(x)
        return x

To get slightly better generalization performance, [BlurPool](https://docs.mosaicml.com/en/latest/method_cards/blurpool.html) and [squeeze-and-excite](https://docs.mosaicml.com/en/latest/method_cards/squeeze_excite.html) operations were applied to the model using Composer. BlurPool layers replace all max-pooling layers, and squeeze-excite layers wrap convolutional layers whose channel counts meet a threshold (`min_channels=16` here, which affects only the final 16-to-32-channel convolution). The full architecture of the resulting model is shown below:

```python
>>> model = CNNClassifier()
>>> composer.functional.apply_squeeze_excite(model, min_channels=16)
>>> composer.functional.apply_blurpool(model)
```
```
CNNClassifier(
  (convolutions): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish(inplace=True)
    (3): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Mish(inplace=True)
    (7): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): SqueezeExciteConv2d(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (se): SqueezeExcite2d(
        (pool_and_mlp): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Flatten(start_dim=1, end_dim=-1)
          (2): Linear(in_features=32, out_features=64, bias=False)
          (3): ReLU()
          (4): Linear(in_features=64, out_features=32, bias=False)
          (5): Sigmoid()
        )
      )
    )
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Mish(inplace=True)
    (11): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fullyconnected): Sequential(
    (0): Linear(in_features=1568, out_features=512, bias=True)
    (1): Mish(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): Mish(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Mish(inplace=True)
    (8): Linear(in_features=256, out_features=1, bias=True)
  )
)
```
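
In the printed architecture, only the last convolution was wrapped in a squeeze-excite block. One rule consistent with that outcome is that a convolution is wrapped when the smaller of its input and output channel counts is at least `min_channels`. The sketch below encodes that rule as an assumption; `gets_squeeze_excite` is a hypothetical helper, not part of Composer, and the Composer documentation remains the authoritative source for the actual criterion.

```python
def gets_squeeze_excite(in_channels: int, out_channels: int, min_channels: int) -> bool:
    # Assumed rule: both the input and output channel counts must reach the threshold
    return min(in_channels, out_channels) >= min_channels


# (in_channels, out_channels) of the three convolutions in CNNClassifier
conv_layers = [(1, 8), (8, 16), (16, 32)]
wrapped = [c for c in conv_layers if gets_squeeze_excite(*c, min_channels=16)]
print(wrapped)  # [(16, 32)]
```

Under this assumed rule, lowering `min_channels` to 8 would also wrap the 8-to-16-channel convolution.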

## Benchmark on Matbench Folds
Training is done using [skorch](https://skorch.readthedocs.io/en/stable/) to abstract away the typical training loop and need for DataLoaders. Images are preprocessed in the following manner:
- Convert from `PIL.Image` to `torch.Tensor` and scale all pixel values to `[0.0, 1.0]`.
- Compute the mean and standard deviation of scaled pixel values, then normalize to zero-mean, unit variance.

For normalization, note that the mean and standard deviation of pixel values are calculated separately for each training fold. In each fold, the training-set statistics are also used to normalize the corresponding test set.
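
The two preprocessing steps and the reuse of training statistics can be sketched with NumPy as follows. The random arrays are hypothetical stand-ins for one fold's images, and `ddof=1` is used to mirror the default (unbiased) behavior of `torch.Tensor.std`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for one fold: uint8 grayscale images, shape (N, 1, 64, 64)
train_imgs = rng.integers(0, 256, size=(8, 1, 64, 64), dtype=np.uint8)
test_imgs = rng.integers(0, 256, size=(4, 1, 64, 64), dtype=np.uint8)

# Step 1: scale pixel values from [0, 255] to [0.0, 1.0]
train = train_imgs.astype(np.float32) / 255.0
test = test_imgs.astype(np.float32) / 255.0

# Step 2: compute a single mean and std over all training pixels in this fold
mean = train.mean()
std = train.std(ddof=1)

# Normalize both sets with the training statistics only
train_norm = (train - mean) / std
test_norm = (test - mean) / std

print(float(train_norm.mean()), float(train_norm.std(ddof=1)))
```

The training set ends up with approximately zero mean and unit variance. The test set generally does not, because its own statistics are never used.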

Matbench tracks the final test results for the full test set. Because the network can only be trained and evaluated on structures with 52 sites or fewer, it is also useful to track performance on just the subset of the data with `num_sites <= 52`. To that end, a simple helper function is defined below.

In [None]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    roc_auc_score,
)


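def scoring(y_true, y_pred, fold):
    """Return classification metrics for one fold as a Series named `fold`."""
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # y_pred holds hard 0/1 labels rather than probabilities, so for this
    # binary task the ROC AUC computed here equals the balanced accuracy
    roc_auc = roc_auc_score(y_true, y_pred)
    scores = {
        "Accuracy": acc,
        "Balanced Accuracy": bal_acc,
        "F1 Score": f1,
        "ROC AUC": roc_auc,
    }
    return pd.Series(scores, name=fold)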


# track test_scores for subset of test_data with num_sites <= 52
subset_test_scores = []
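
Because the helper above receives hard class labels, its ROC AUC column coincides with balanced accuracy. The pure-Python sketch below illustrates why on a small made-up example; both functions are simplified re-implementations written for illustration, not the scikit-learn versions.

```python
def balanced_accuracy(y_true, y_pred):
    """Average of the true positive rate and the true negative rate."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    n_pos = sum(t == 1 for t in y_true)
    n_neg = len(y_true) - n_pos
    return 0.5 * (tp / n_pos + tn / n_neg)


def roc_auc(y_true, scores):
    """Probability that a positive outranks a negative, counting ties as half."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical labels and hard 0/1 predictions
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_hard = [1, 0, 1, 0, 1, 0, 0, 1]

print(balanced_accuracy(y_true, y_hard))  # 0.75
print(roc_auc(y_true, y_hard))            # 0.75
```

A ranking-based AUC that differs from balanced accuracy would require continuous scores, for example class probabilities, instead of thresholded labels.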

The matbench benchmark is run below. To save all image representations created by the `xtal2png` function during crystal to image conversion, set the `save` flag to `True`.

In [None]:
xc = XtalConverter()
mb = MatbenchBenchmark(autoload=False, subset=["matbench_mp_is_metal"])
save = False  # xtal2png(data, save=save)

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        # Get training data
        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Train on structures with num_sites <= 52
        site_counter = lambda x: x.num_sites
        idx = train_inputs.apply(site_counter) <= 52
        X_train = train_inputs[idx]
        y_train = train_outputs[idx]

        # Convert crystal structures to images
        X_train = xc.xtal2png(X_train, save=save)

        # Convert PIL Images to torch.Tensor
        # Note that this scales from [0, 255] to [0.0, 1.0]
        X_train = [TF.to_tensor(img) for img in X_train]
        # Normalize images (subtract mean, divide by std)
        mean = torch.cat(X_train).mean()
        std = torch.cat(X_train).std()
        X_train = [TF.normalize(i, mean=mean, std=std) for i in X_train]

        # Change X from a list of tensors to a single tensor, and y from bool to float
        X_train = torch.stack(X_train)
        y_train = y_train.astype(np.float32)

        # Apply Composer methods to vanilla PyTorch classifier before training
        model = CNNClassifier()
        cf.apply_squeeze_excite(model, min_channels=16)
        cf.apply_blurpool(model)

        # Train and validate classifier with skorch
        net = NeuralNetBinaryClassifier(
            model,
            criterion=nn.BCEWithLogitsLoss,
            max_epochs=50,
            optimizer=optim.AdamW,
            optimizer__amsgrad=True,
            optimizer__lr=0.0005,
            callbacks=[EarlyStopping(patience=15)],
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=64,
        )
        net.fit(X_train, y_train)

        # Get test data and keep structures with num_sites <= 52
        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        idx = test_inputs.apply(site_counter) <= 52
        X_test = test_inputs[idx]

        # Convert to images, preprocess using mean and std from training data
        X_test = xc.xtal2png(X_test, save=save)
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std),])
        X_test = torch.stack([preprocess(img) for img in X_test])

        # Predict on X_test
        # For structures with num_sites > 52, predict mode of y_train
        y_pred = net.predict(X_test)
        y_pred_full = np.empty(test_inputs.size)
        y_pred_full[idx] = y_pred
        y_pred_full[~idx] = y_train.mode().item()

        # Record data
        task.record(fold, y_pred_full)
        # Also record test scores on subset of data with num_sites <= 52
        subset_test_scores.append(scoring(test_outputs[idx], y_pred, f"fold-{fold}"))

# Save benchmark results
mb.to_file("new-results.json.gz")

2022-07-08 16:49:57 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_mp_is_metal']
2022-07-08 16:49:57 INFO     Loading dataset 'matbench_mp_is_metal'...
2022-07-08 16:51:29 INFO     Dataset 'matbench_mp_is_metal loaded.


100%|██████████| 71602/71602 [06:18<00:00, 189.39it/s]

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.6073       0.6224        0.6698  14.7695
      2        0.5356       0.6208        0.7271  10.9460
      3        0.5118       0.6481        0.7283  10.8090
      4        0.5024       0.6467        0.7483  10.8945
      5        0.4809       0.6736        0.6756  10.7665
      6        0.4701       0.6716        0.6849  10.7760
      7        0.4650       0.6680        0.6896  10.8740
      8        0.4538       0.6705        0.6834  11.3155
      9        0.4480       0.6702        0.6814  10.9335
     10        0.4457       0.6669        0.6971  10.7680
     11        0.4374       0.6736        0.6725  10.9445
     12        0.4343       0.6710        0.6902  10.9260
     13        0.4292       0.6684

100%|██████████| 17925/17925 [01:38<00:00, 181.45it/s]


2022-07-08 17:08:07 INFO     Recorded fold matbench_mp_is_metal-0 successfully.


100%|██████████| 71644/71644 [06:18<00:00, 189.14it/s]

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5929       0.6283        0.6789  11.1935
      2        0.5243       0.6296        0.6866  11.0007
      3        0.5031       0.6284        0.7424  10.8800
      4        0.4796       0.6359        0.7158  10.8855
      5        0.4669       0.6471        0.7155  10.8195
      6        0.4543       0.6410        0.7543  10.9535
      7        0.4490       0.6537        0.7031  10.8030
      8        0.4395       0.6616        0.7039  10.7705
      9        0.4372       0.6529        0.7108  10.9330
     10        0.4305       0.6606        0.7028  10.9995
     11        0.4266       0.6587        0.7050  10.9490
     12        0.4226       0.6655        0.7067  10.8910
     13

100%|██████████| 17883/17883 [01:48<00:00, 164.18it/s]


2022-07-08 17:19:58 INFO     Recorded fold matbench_mp_is_metal-1 successfully.


100%|██████████| 71604/71604 [06:31<00:00, 182.87it/s]

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5957       0.6205        0.6803  11.3735
      2        0.5308       0.6339        0.7092  11.3055
      3        0.5120       0.6208        0.7561  11.3290
      4        0.4911       0.6561        0.7014  11.4160
      5        0.4756       0.6729        0.6585  11.1370
      6        0.4651       0.6716        0.6809  11.1925
      7        0.4564       0.6707        0.7054  11.2350
      8        0.4511       0.6703        0.7149  11.3085
      9        0.4455       0.6683        0.7382  11.2265
     10        0.4437       0.6769        0.7111  11.2260
     11        0.4384       0.6573        0.7956  11.1905
     12        0.4339       0.6650        0.7453  11.2540
     13

100%|██████████| 17923/17923 [01:42<00:00, 174.96it/s]


2022-07-08 17:32:45 INFO     Recorded fold matbench_mp_is_metal-2 successfully.


100%|██████████| 71599/71599 [06:27<00:00, 184.65it/s]

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.6078       0.6161        0.6687  11.1820
      2        0.5410       0.5843        0.7135  10.9360
      3        0.5274       0.6334        0.7075  10.8940
      4        0.5167       0.6174        0.7411  10.7925
      5        0.4938       0.6011        0.8099  10.8460
      6        0.4825       0.6122        0.7850  11.2060
      7        0.4710       0.6226        0.7418  10.8695
      8        0.4610       0.6394        0.7188  10.9045
      9        0.4559       0.6555        0.7164  10.8600
     10        0.4507       0.6457        0.7033  10.8250
     11        0.4424       0.6524        0.6752  10.7745
     12        0.4394       0.6547        0.6737  10.7125
     13        0.4307

100%|██████████| 17928/17928 [01:40<00:00, 178.62it/s]


2022-07-08 17:50:48 INFO     Recorded fold matbench_mp_is_metal-3 successfully.


100%|██████████| 71659/71659 [06:20<00:00, 188.16it/s]

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.5997       0.6197        0.6638  10.9605
      2        0.5519       0.6425        0.6501  10.7185
      3        0.5317       0.6442        0.6350  10.7315
      4        0.5290       0.6300        0.6813  10.7890
      5        0.5041       0.6557        0.6740  10.5690
      6        0.5022       0.6289        0.7428  10.4915
      7        0.4817       0.6442        0.7406  10.6470
      8        0.4694       0.6246        0.7653  10.7805
      9        0.4636       0.6217        0.7734  10.6635
     10        0.4534       0.6572        0.7254  10.4560
     11        0.4469       0.6469        0.7296  10.4670
     12        0.4401       0.6500        0.7351  10.5830
     13

100%|██████████| 17868/17868 [01:41<00:00, 175.96it/s]


2022-07-08 18:02:48 INFO     Recorded fold matbench_mp_is_metal-4 successfully.
2022-07-08 18:02:49 INFO     Successfully wrote MatbenchBenchmark to file 'new-results.json.gz'.


In [None]:
# Make sure our benchmark is valid
valid = mb.is_valid
print(f"is valid: {valid}")

# Check out how our algorithm is doing using scores
import pprint
pprint.pprint(mb.scores)

# Get some more info about the benchmark
mb.get_info()

is valid: True
{'matbench_mp_is_metal': {'accuracy': {'max': 0.8046838186787296,
                                       'mean': 0.78905512981546,
                                       'min': 0.7796729962776233,
                                       'std': 0.011047855221953238},
                          'balanced_accuracy': {'max': 0.7889602638667743,
                                                'mean': 0.7670721073159733,
                                                'min': 0.7535760749130695,
                                                'std': 0.015144418534814375},
                          'f1': {'max': 0.7484677468292978,
                                 'mean': 0.7104270714737543,
                                 'min': 0.6843673801137903,
                                 'std': 0.026518595116783353},
                          'rocauc': {'max': 0.7889602638667743,
                                     'mean': 0.7670721073159734,
                                     'min'

Finally, let's also display the score summary for just the structures in the test sets with 52 sites or fewer.

In [None]:
from IPython.display import display

df = pd.concat(subset_test_scores, axis=1).T
display(df)
df.describe()

| | Accuracy | Balanced Accuracy | F1 Score | ROC AUC |
|---|---|---|---|---|
| fold-0 | 0.793752 | 0.790358 | 0.761561 | 0.790358 |
| fold-1 | 0.771627 | 0.766685 | 0.724315 | 0.766685 |
| fold-2 | 0.773029 | 0.76669 | 0.715604 | 0.76669 |
| fold-3 | 0.801818 | 0.798735 | 0.776358 | 0.798735 |
| fold-4 | 0.769756 | 0.764638 | 0.710648 | 0.764638 |

| | Accuracy | Balanced Accuracy | F1 Score | ROC AUC |
|---|---|---|---|---|
| count | 5.0 | 5.0 | 5.0 | 5.0 |
| mean | 0.781996 | 0.777421 | 0.737697 | 0.777421 |
| std | 0.014738 | 0.015933 | 0.029423 | 0.015933 |
| min | 0.769756 | 0.764638 | 0.710648 | 0.764638 |
| 25% | 0.771627 | 0.766685 | 0.715604 | 0.766685 |
| 50% | 0.773029 | 0.76669 | 0.724315 | 0.76669 |
| 75% | 0.793752 | 0.790358 | 0.761561 | 0.790358 |
| max | 0.801818 | 0.798735 | 0.776358 | 0.798735 |
