[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sparks-baird/xtal2png/blob/cnn-classification/notebooks/2.1-xtal2png-cnn-regression.ipynb)

# Regression on `matbench_mp_e_form` using `xtal2png` representation of crystal structures

## Description
In this notebook, a convolutional neural network is applied to the `matbench_mp_e_form` regression task using [`xtal2png`](https://xtal2png.readthedocs.io/en/latest/) representations of crystal structures. Crystal structures are encoded as grayscale PNG images. Because the conversion is restricted to structures with at most 52 sites, the network is trained only on structures with `num_sites <= 52`. For test-set structures with more than 52 sites, we simply predict the mean of the training outputs (i.e. the mean of `y_train`, where `X_train` and `y_train` are the training inputs and outputs restricted to `num_sites <= 52`).
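
The fallback rule described above can be sketched in plain Python. This is only an illustration under stated assumptions: `predict_with_fallback`, `MAX_SITES`, and the toy numbers are hypothetical names and values, not part of the notebook, which applies the same idea to NumPy arrays in the benchmark loop.

```python
from statistics import fmean

MAX_SITES = 52  # largest structure the image encoding is used for here


def predict_with_fallback(num_sites, model_predictions, y_train):
    """Merge model predictions with a mean-of-training fallback.

    `model_predictions` holds one value per structure with at most
    MAX_SITES sites, in the same order as those structures appear
    in `num_sites`.
    """
    fallback = fmean(y_train)  # constant prediction for large structures
    preds = iter(model_predictions)
    return [next(preds) if n <= MAX_SITES else fallback for n in num_sites]


# Toy data (hypothetical values, for illustration only)
num_sites = [4, 80, 52, 120]
model_predictions = [-1.2, 0.3]  # for the two structures with <= 52 sites
y_train = [-2.0, -1.0, 0.0]

y_pred_full = predict_with_fallback(num_sites, model_predictions, y_train)
print(y_pred_full)
```

Structures with exactly 52 sites still go through the model; only those with more sites receive the constant fallback.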

## Benchmark Name
Matbench v0.1

## Package Versions
- [matbench](https://anaconda.org/conda-forge/matbench)==0.5.0
- [xtal2png](https://anaconda.org/conda-forge/xtal2png)==0.7.0
- [pytorch](https://anaconda.org/pytorch/pytorch)==1.11.0
- [skorch](https://anaconda.org/conda-forge/skorch)==0.11.0
- [pytorch-lightning](https://anaconda.org/conda-forge/pytorch-lightning)==1.6.4
- [mosaicml](https://anaconda.org/mosaicml/mosaicml)==0.8.0

## Algorithm Description
A fairly simple CNN is created in vanilla PyTorch, very loosely following the PyTorch implementation of [AlexNet](https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py). Model surgery is then performed on the max-pooling and certain convolutional layers using MosaicML's [Composer](https://github.com/mosaicml/composer) library.

### Imports

In [None]:
%pip install matbench skorch xtal2png pytorch-lightning mosaicml

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting matbench
  Downloading matbench-0.5-py3-none-any.whl (9.9 MB)
Collecting skorch
  Downloading skorch-0.11.0-py3-none-any.whl (155 kB)
Collecting xtal2png
  Downloading xtal2png-0.8.0-py3-none-any.whl (30 kB)
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.6.4-py3-none-any.whl (585 kB)
Collecting mosaicml
  Downloading mosaicml-0.8.0-py3-none-any.whl (548 kB)
Collecting scikit-learn==1.0
  Downloading scikit_learn-1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (23.1 MB)
Collecting monty==2021.8.17
  Downloading monty-2021

In [None]:
import composer.functional as cf
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from matbench.bench import MatbenchBenchmark
from skorch.callbacks import EarlyStopping
from skorch.regressor import NeuralNetRegressor
from torch import nn
from xtal2png.core import XtalConverter

# Set all random seeds as specified by Matbench
pl.seed_everything(18012019)

Global seed set to 18012019


18012019

### CNN Architecture
For the vanilla PyTorch model, the architecture of the convolutional layers is as follows:
```python
self.convolutions = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (64, 64, 1) --> (64, 64, 8)
    nn.BatchNorm2d(8),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (31, 31, 8)

    nn.Conv2d(8, 16, kernel_size=3, padding=1),     # --> (31, 31, 16)
    nn.BatchNorm2d(16),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (15, 15, 16)
    
    nn.Conv2d(16, 32, kernel_size=3, padding=1),    # --> (15, 15, 32)
    nn.BatchNorm2d(32),
    nn.Mish(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),          # --> (7, 7, 32)
)
```
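
The shape annotations in the comments above can be checked with the usual output-size formula for convolution and pooling layers, `floor((n + 2p - k) / s) + 1`, where `n` is the input side length, `k` the kernel size, `s` the stride, and `p` the padding. The sketch below assumes square 64×64 inputs and default dilation; `conv_out` and `trace_shapes` are hypothetical helper names used only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square conv/pool layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(size: int = 64):
    """Return the (height, width, channels) shape after each conv + pool stage."""
    shapes = []
    for channels in (8, 16, 32):
        # Conv2d(kernel_size=3, padding=1) preserves the spatial size
        size = conv_out(size, kernel=3, stride=1, padding=1)
        # MaxPool2d(kernel_size=3, stride=2) roughly halves it
        size = conv_out(size, kernel=3, stride=2, padding=0)
        shapes.append((size, size, channels))
    return shapes


shapes = trace_shapes(64)
height, width, channels = shapes[-1]
flat_features = height * width * channels  # input size of the first Linear layer
print(shapes, flat_features)
```

The final stage gives a 7×7×32 feature map, which is where the `7 * 7 * 32` input size of the first fully connected layer comes from.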
The full `CNNRegressor` class is defined below:

In [None]:
class CNNRegressor(nn.Module):
    def __init__(self, dropout: float = 0.5) -> None:
        super().__init__()
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.fullyconnected = nn.Sequential(
            nn.Linear(7 * 7 * 32, 512),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 256),
            nn.Mish(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.Mish(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.convolutions(x)
        x = torch.flatten(x, 1)  # flatten all but batch dim
        x = self.fullyconnected(x)
        return x

To get slightly better generalization performance, [BlurPool](https://docs.mosaicml.com/en/latest/method_cards/blurpool.html) and [squeeze-and-excite](https://docs.mosaicml.com/en/latest/method_cards/squeeze_excite.html) operations were applied to the model using Composer. BlurPool layers replace all max pooling layers, and squeeze-excite layers replace certain convolutional layers with channels above a threshold. Below is the full architecture of the model:

```python
>>> model = CNNRegressor()
>>> composer.functional.apply_squeeze_excite(model, min_channels=16)
>>> composer.functional.apply_blurpool(model)
```
```
CNNRegressor(
  (convolutions): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish(inplace=True)
    (3): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Mish(inplace=True)
    (7): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): SqueezeExciteConv2d(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (se): SqueezeExcite2d(
        (pool_and_mlp): Sequential(
          (0): AdaptiveAvgPool2d(output_size=1)
          (1): Flatten(start_dim=1, end_dim=-1)
          (2): Linear(in_features=32, out_features=64, bias=False)
          (3): ReLU()
          (4): Linear(in_features=64, out_features=32, bias=False)
          (5): Sigmoid()
        )
      )
    )
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Mish(inplace=True)
    (11): BlurMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fullyconnected): Sequential(
    (0): Linear(in_features=1568, out_features=512, bias=True)
    (1): Mish(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
    (4): Mish(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Mish(inplace=True)
    (8): Linear(in_features=256, out_features=1, bias=True)
  )
)
```
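
In the printed model, only the 16→32 convolution is wrapped in a `SqueezeExciteConv2d`. This is consistent with a selection rule in which a convolution qualifies only when the smaller of its input and output channel counts reaches `min_channels`. The sketch below illustrates that assumed rule in plain Python; `gets_squeeze_excite` is a hypothetical helper, not a Composer function, and Composer's documentation should be consulted for the exact behavior.

```python
# (in_channels, out_channels) of the three Conv2d layers in CNNRegressor
conv_layers = [(1, 8), (8, 16), (16, 32)]
MIN_CHANNELS = 16  # value passed as min_channels in this notebook


def gets_squeeze_excite(in_channels: int, out_channels: int, min_channels: int) -> bool:
    # Assumed rule: the smaller of the two channel counts must reach the threshold
    return min(in_channels, out_channels) >= min_channels


selected = [
    layer for layer in conv_layers
    if gets_squeeze_excite(*layer, MIN_CHANNELS)
]
print(selected)
```

Under this assumption the 8→16 convolution is skipped because its input has only 8 channels, which matches the architecture printout.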

## Benchmark on Matbench Folds
Training is done using [skorch](https://skorch.readthedocs.io/en/stable/) to abstract away the typical training loop and need for DataLoaders. Images are preprocessed in the following manner:
- Convert from `PIL.Image` to `torch.Tensor` and scale all pixel values to `[0.0, 1.0]`.
- Compute the mean and standard deviation of the scaled pixel values, then normalize to zero mean and unit variance.

For normalization, note that the mean and standard deviation of pixel values are calculated separately for each training fold. In each fold, the training-set statistics are also used to normalize the corresponding test set.
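
The per-fold normalization can be sketched as follows. This is a minimal plain-Python stand-in for the tensor operations, assuming images flattened to lists of pixel values already scaled to `[0.0, 1.0]`; `fit_pixel_stats`, `normalize`, and the toy data are hypothetical. `statistics.stdev` uses the sample (n − 1) estimator, as `torch.Tensor.std` does by default.

```python
from statistics import fmean, stdev


def fit_pixel_stats(images):
    """Mean and sample standard deviation over every pixel of every image."""
    pixels = [p for img in images for p in img]
    return fmean(pixels), stdev(pixels)


def normalize(images, mean, std):
    """Subtract the mean and divide by the standard deviation, pixel by pixel."""
    return [[(p - mean) / std for p in img] for img in images]


# Toy "images" (hypothetical values), already scaled to [0.0, 1.0]
train_images = [[0.0, 0.5, 1.0], [0.25, 0.75, 0.5]]
test_images = [[0.1, 0.9, 0.5]]

# Statistics come from the training fold only ...
mean, std = fit_pixel_stats(train_images)
train_norm = normalize(train_images, mean, std)
# ... and are reused unchanged for the corresponding test fold
test_norm = normalize(test_images, mean, std)
```

Reusing the training statistics keeps information from the test fold out of the preprocessing step.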

The Matbench benchmark is run below. To save all image representations created by `xtal2png` during crystal-to-image conversion, set the `save` flag to `True`.

In [None]:
xc = XtalConverter()
mb = MatbenchBenchmark(autoload=False, subset=["matbench_mp_e_form"])
save = False  # xtal2png(data, save=save)

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        # Get training data
        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Train on structures with num_sites <= 52
        site_counter = lambda x: x.num_sites
        idx = train_inputs.apply(site_counter) <= 52
        X_train = train_inputs[idx]
        y_train = train_outputs[idx]

        # Convert crystal structures to images
        X_train = xc.xtal2png(X_train, save=save)

        # Convert PIL Images to torch.Tensor
        # Note that this scales from [0, 255] to [0.0, 1.0]
        X_train = [TF.to_tensor(img) for img in X_train]
        # Normalize images (subtract mean, divide by std)
        mean = torch.cat(X_train).mean()
        std = torch.cat(X_train).std()
        X_train = [TF.normalize(i, mean=mean, std=std) for i in X_train]

        # Change X from a list of tensors to a single tensor
        # Change shape of y from 1D to 2D for skorch regressor
        X_train = torch.stack(X_train)
        y_train = y_train.values.reshape(-1, 1).astype(np.float32)

        # Apply Composer methods to vanilla PyTorch regressor before training
        model = CNNRegressor()
        cf.apply_squeeze_excite(model, min_channels=16)
        cf.apply_blurpool(model)

        # Train and validate regressor with skorch
        net = NeuralNetRegressor(
            model,
            criterion=nn.MSELoss,
            max_epochs=50,
            optimizer=optim.AdamW,
            optimizer__amsgrad=True,
            optimizer__lr=0.0005,
            callbacks=[EarlyStopping(patience=10)],
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=64,
        )
        net.fit(X_train, y_train)

        # Get test data and keep structures with num_sites <= 52
        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        idx = test_inputs.apply(site_counter) <= 52
        X_test = test_inputs[idx]

        # Convert to images, preprocess using mean and std from training data
        X_test = xc.xtal2png(X_test, save=save)
        preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std),])
        X_test = torch.stack([preprocess(img) for img in X_test])

        # Predict on X_test
        # For structures with num_sites > 52, predict mean of y_train
        y_pred = net.predict(X_test)
        y_pred_full = np.empty(test_inputs.size)
        y_pred_full[idx] = y_pred.flatten()
        y_pred_full[~idx] = y_train.mean()

        # Record data
        task.record(fold, y_pred_full)

# Save benchmark results
mb.to_file("reg-results.json.gz")

2022-07-08 10:23:47 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_mp_e_form']
2022-07-08 10:23:47 INFO     Loading dataset 'matbench_mp_e_form'...
2022-07-08 10:25:49 INFO     Dataset 'matbench_mp_e_form loaded.


100%|██████████| 90001/90001 [07:48<00:00, 191.94it/s]


  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        0.0020       18.1529  16.5975
      2        0.0384       16.6310  12.7425
      3        0.0571       13.4103  12.5390
      4        0.0488       17.7011  12.5970
      5        0.0672       15.5128  12.5975
      6        0.0577       18.6461  12.7355
      7        0.0623        9.9081  12.6740
      8        0.0439       19.2510  12.5930
      9        0.0970       16.8708  12.5310
     10        0.0977       17.6505  12.7025
     11        0.0878       19.2228  12.7420
     12        0.0969       19.1816  12.5340
     13        0.1103       14.6279  12.5715
     14        0.0729       18.5980  12.7015
     15        0.1059       15.6332  12.7890
     16        0.0925       19.1866  12.5650
Stopping since valid_loss has not improved in the last 10 epochs.


100%|██████████| 22544/22544 [02:01<00:00, 185.47it/s]


2022-07-08 10:40:15 INFO     Recorded fold matbench_mp_e_form-0 successfully.


100%|██████████| 89991/89991 [07:43<00:00, 194.27it/s]


  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        0.0020       18.0537  13.0495
      2        0.0811        4.6638  12.8200
      3        0.1273        5.8073  12.6550
      4        0.0562        8.6943  12.4255
      5        0.0634       17.0132  12.4490
      6        0.1046       15.2822  12.8110
      7        0.0698       18.2485  12.6245
      8        0.2897        9.3080  12.4105
      9        0.0727       13.5770  12.3755
     10        0.0946       11.2782  12.6520
     11        0.0806       14.0032  12.6285
Stopping since valid_loss has not improved in the last 10 epochs.


100%|██████████| 22554/22554 [02:07<00:00, 176.59it/s]


2022-07-08 10:53:38 INFO     Recorded fold matbench_mp_e_form-1 successfully.


100%|██████████| 90126/90126 [07:45<00:00, 193.56it/s]


  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        0.0023       18.4915  12.7550
      2        0.3970       13.9041  12.5605
      3        0.0381       12.5564  12.7695
      4        0.0578       14.3654  12.8330
      5        0.0591       14.5919  12.4445
      6        0.0659       18.4184  12.5425
      7        0.1046       16.4346  12.5675
      8        0.0981       19.0178  12.6730
      9        0.1246       14.1160  12.5770
     10        0.0980       19.1989  12.5265
     11        0.1265       18.0837  12.5975
     12        0.1625       18.9763  12.6950
Stopping since valid_loss has not improved in the last 10 epochs.


100%|██████████| 22419/22419 [02:02<00:00, 182.73it/s]


2022-07-08 11:07:08 INFO     Recorded fold matbench_mp_e_form-2 successfully.


100%|██████████| 90052/90052 [07:24<00:00, 202.47it/s]


  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        0.0020       17.8750  12.7195
      2        0.0663        7.2680  12.6040
      3        0.0881        6.6167  12.6110
      4        0.0372       15.9711  12.8515
      5        0.0948       10.5767  12.5495
      6        0.0560       16.6127  12.4895
      7        0.0500       17.3700  12.7215
      8        0.0639       17.7506  12.7755
      9        0.0673       10.0282  12.7170
     10        0.0600       10.4584  12.5055
     11        0.0693       13.6575  12.6980
     12        0.0726       18.2052  12.6100
Stopping since valid_loss has not improved in the last 10 epochs.


100%|██████████| 22493/22493 [02:03<00:00, 182.77it/s]


2022-07-08 11:20:17 INFO     Recorded fold matbench_mp_e_form-3 successfully.


100%|██████████| 90010/90010 [07:47<00:00, 192.36it/s]


  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        0.0022       15.7354  12.6700
      2        0.0365       16.7053  12.5100
      3        0.0465       14.8574  12.4960
      4        0.7193        5.7185  12.7445
      5        0.1320        9.2095  12.5350
      6        0.0843       11.2550  12.5955
      7        0.1094       11.4567  12.6630
      8        0.1172       14.5114  12.6105
      9        0.1468       14.0634  12.4440
     10        0.0780        9.5807  12.5160
     11        0.0816       13.3103  12.6220
     12        0.1415       13.7865  12.5745
     13        0.1483       14.7646  12.8135
Stopping since valid_loss has not improved in the last 10 epochs.


100%|██████████| 22535/22535 [02:02<00:00, 183.91it/s]


2022-07-08 11:33:57 INFO     Recorded fold matbench_mp_e_form-4 successfully.
2022-07-08 11:33:58 INFO     Successfully wrote MatbenchBenchmark to file 'reg-results.json.gz'.


In [None]:
# Make sure our benchmark is valid
valid = mb.is_valid
print(f"is valid: {valid}")

# Check out how our algorithm is doing using scores
import pprint
pprint.pprint(mb.scores)

# Get some more info about the benchmark
mb.get_info()

is valid: True
{'matbench_mp_e_form': {'mae': {'max': 2.526535151128166,
                                'mean': 2.247677549178041,
                                'min': 1.9270280884437816,
                                'std': 0.2173278846449626},
                        'mape': {'max': 34.1503729497845,
                                 'mean': 26.723119633126373,
                                 'min': 18.471787525063075,
                                 'std': 5.6747174033422},
                        'max_error': {'max': 6.6549069282540625,
                                      'mean': 6.286157434377406,
                                      'min': 5.898015143570556,
                                      'std': 0.26371023086119855},
                        'rmse': {'max': 2.8384010692391595,
                                 'mean': 2.5576678160617874,
                                 'min': 2.242694598448991,
                                 'std': 0.21454607521353117}}}
2022-07-