## To do

- Enforce SO(3) symmetry in the coordinate space of crystals: augment the coordinate-space data with SO(3) transformations, then enforce equivariance to these transformations in the loss function.
- Implement equivariant flows. These are a variation of continuous normalizing flows (CNFs) (see eqn 8 of https://papers.nips.cc/paper/2018/file/69386f6bb1dfed68692a24c8686939b9-Paper.pdf) in which the 'vector field' of the flow is equivariant, so that the resulting density is invariant. In simple terms, the vector field is our nn.Sequential() inside the self.flow() module of the VAE. This is explained in sections 4/5 of https://arxiv.org/pdf/2006.02425.pdf. We want a network that respects the SO(3) symmetry of the materials; the paper points to SchNet (https://github.com/atomistic-machine-learning/schnetpack) as an example network.
- Prepare the data and the shift-augmented data. It might also be useful to add broadened/noised/split/background-added augmentations and combine all of these as input to the VAE.
- look at the sampling scheme.


In [None]:
#!pip install torchdyn
#!pip install git+https://github.com/google-research/torchsde.git
#!pip install torchdiffeq
#!pip install antialiased-cnns
#!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html


Collecting git+https://github.com/google-research/torchsde.git
  Cloning https://github.com/google-research/torchsde.git to /tmp/pip-req-build-q6v8a3_s
  Running command git clone -q https://github.com/google-research/torchsde.git /tmp/pip-req-build-q6v8a3_s
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting torch<1.8.0,>=1.6.0
  Using cached https://files.pythonhosted.org/packages/90/5d/095ddddc91c8a769a68c791c019c5793f9c4456a688ddd235d6670924ecb/torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl
Building wheels for collected packages: torchsde
  Building wheel for torchsde (PEP 517) ... [?25l[?25hdone
  Created wheel for torchsde: filename=torchsde-0.2.5-cp37-none-any.whl size=55592 sha256=66bd4a8a255a67e2c21079a97b2310479b5cd31dbeaf6c4cb1b834c6a6e591f1
  Stored in directory: /tmp/pip-ephem-wheel-cache-ezvvm2u6/wheels/31/b5/4b/53c7d7c124c1bbfebd2c5f429ca86b5e59f6cd471

In [None]:
import torch
import torch.nn as nn
from torch.autograd import grad
from torchdyn.models import *
from torchdyn import *
from torchdyn.datasets import *
import antialiased_cnns
from sklearn.preprocessing import StandardScaler, LabelEncoder
import numpy as np
import pandas as pd

# Pick the compute device once, then make float64 the default dtype on that device.
_on_gpu = torch.cuda.is_available()
device = torch.device("cuda" if _on_gpu else "cpu")
# The default tensor type decides where freshly created tensors/parameters live.
torch.set_default_tensor_type("torch.cuda.FloatTensor" if _on_gpu else "torch.FloatTensor")
torch.set_default_dtype(torch.float64)
if _on_gpu:
    torch.cuda.empty_cache()

torch.manual_seed(1)

<torch._C.Generator at 0x7f45070c3b10>

In [None]:
# load
# theor.csv holds interleaved column pairs: columns 0, 2, 4, ... are taken as `angle`
# and columns 1, 3, 5, ... as `xrd`; after transposing, each row presumably holds one pattern.
data = pd.read_csv('/content/drive/MyDrive/material-VAE/new_vae/theor.csv', index_col=0)
data = data.iloc[1:]  # skip the first data row
# Strided slicing selects exactly the columns the old np.delete(...) calls kept.
xrd = data.values[:, 1::2].T
angle = data.values[:, ::2].T
# plt.plot(angle[0], xrd[0])

# normalize: StandardScaler standardises each column (feature) to zero mean / unit variance
xrd_scale = StandardScaler()
xrd = xrd_scale.fit_transform(xrd)

# prepare for convolutions [batch_size, channels, features] - channels=2 (one for xrd one for angle)
# X = []
# for i in range(xrd.shape[0]):
#     X.append(np.vstack((xrd[i], angle[i])))
# X = np.array(X)

# prepare for convolutions [batch_size, channels, features] - channels=1
X = xrd[:, None, :]

# label: keep rows 1, 3, 5, ... of the label file (the even rows are dropped)
y = pd.read_csv('/content/drive/MyDrive/material-VAE/new_vae/label_theo.csv', header=None, index_col=0)
y = y.values[1::2].ravel().tolist()

le = LabelEncoder()
y = le.fit_transform(y)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs sample ``X[i]`` with label ``y[i]``."""

    def __init__(self, X, y):
        super().__init__()
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Samplers may pass tensor indices; index the containers with plain Python values.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        return self.X[key], self.y[key]
    
trainset = Dataset(X, y)
# The default tensor type may be CUDA (see the device setup cell). The shuffling
# sampler therefore gets a generator on that same device: in some PyTorch releases,
# pairing a CPU generator with a CUDA default tensor type makes torch.randperm fail.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=16, shuffle=True, generator=torch.Generator(device=device))

del xrd, angle, X, y

In [None]:
# helper - put PrintSize() at any point in an nn.Sequential() and it will show the tensor shape at that point
# e.g. nn.Sequential(ConvODE,
#                    PrintSize(),
#                    nn.Linear(),
#                    PrintSize())
class PrintSize(nn.Module):
    """Identity layer that prints the shape of whatever passes through it."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        shape = x.shape
        print(shape)
        return x

# https://github.com/DiffEqML/torchdyn/blob/master/tutorials/08_hamiltonian_nets.ipynb
# easy wrapper for any nn.Sequential() network
class HNN(nn.Module):
    """Hamiltonian vector field built from a learned scalar energy.

    ``Hamiltonian(x).sum()`` is differentiated w.r.t. ``x``. The first ``dim``
    entries along axis 1 act as positions and the rest as momenta. The output is
    ``[dH/d(rest), -dH/d(first dim)]`` concatenated on axis 1, so it has the
    same shape as ``x``.
    """

    def __init__(self, Hamiltonian: nn.Module, dim=1):
        super().__init__()
        self.H = Hamiltonian
        self.n = dim

    def forward(self, x):
        with torch.set_grad_enabled(True):
            x = x.requires_grad_(True)
            energy = self.H(x).sum()
            (dH,) = torch.autograd.grad(energy, x, allow_unused=False, create_graph=True)
        head, tail = dH[:, :self.n], dH[:, self.n:]
        return torch.cat([tail, -head], 1).to(x)
    
# calculating the Jacobian trace for continuous normalizing flows
def autograd_trace(x_out, x_in, **kwargs):
    """Exact trace of the Jacobian d x_out / d x_in, one autograd call per feature column.

    Extra keyword arguments (e.g. ``noise``) are accepted and ignored so the
    signature matches stochastic trace estimators. The graph is kept
    (``create_graph=True``) so the trace itself can be differentiated.
    """
    diagonal = (
        torch.autograd.grad(x_out[:, k].sum(), x_in, allow_unused=False, create_graph=True)[0][:, k]
        for k in range(x_in.shape[1])
    )
    return sum(diagonal, 0.)

# https://github.com/DiffEqML/torchdyn/blob/master/tutorials/07a_continuous_normalizing_flows.ipynb
# easy wrapper for any nn.Sequential() transformation
class CNF(nn.Module):
    """Continuous-normalizing-flow vector field with a divergence channel.

    The ODE state ``x`` is laid out as ``[div, z]``: column 0 is reserved for
    propagating the divergence, columns ``1:`` are the flowed variables.
    ``forward`` returns ``[-tr(d net(z) / d z), net(z)]`` concatenated on
    axis 1, so it has the same shape as ``x``.

    Args:
        net: module mapping ``z`` to its time derivative (same width as ``z``).
        trace_estimator: callable ``(x_out, x_in, noise=...)`` returning the
            per-sample Jacobian trace; defaults to ``autograd_trace``.
        noise_dist: noise distribution, required for estimators listed in
            ``REQUIRES_NOISE``.
    """

    def __init__(self, net, trace_estimator=None, noise_dist=None):
        super().__init__()
        self.net = net
        self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace
        self.noise_dist, self.noise = noise_dist, None
        # NOTE(review): REQUIRES_NOISE is not defined in this file; presumably it is
        # provided by the torchdyn star imports - TODO confirm.
        if self.trace_estimator in REQUIRES_NOISE:
            assert self.noise_dist is not None, 'This type of trace estimator requires specification of a noise distribution'

    def forward(self, x):
        with torch.set_grad_enabled(True):
            # Keep z attached to x's graph. torch.autograd.Variable(...) would detach it,
            # leaving gradients w.r.t. the ODE state (backprop through the solver, adjoint
            # vector-Jacobian products) to flow only through the `0*x` term below.
            x_in = x[:, 1:]
            if not x_in.requires_grad:
                # e.g. a forward solve run without grad tracking: make a fresh leaf
                x_in = x_in.detach().requires_grad_(True)
            # the neural network handles the data dynamics here
            x_out = self.net(x_in)
            trJ = self.trace_estimator(x_out, x_in, noise=self.noise)
        # `+ 0*x` has the only purpose of connecting x[:, 0] to the autograd graph
        return torch.cat([-trJ[:, None], x_out], 1) + 0 * x

# helps reweight the convolutional channels at the end of a CNN
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gate for (batch, channels, length) inputs.

    Each channel is average-pooled to one value and passed through a two-layer
    bottleneck MLP ending in a sigmoid. The input is then scaled per channel by
    the resulting weight in (0, 1). The bottleneck width is
    ``max(channel // reduction, int(sqrt(channel)))``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        hidden = max(channel // reduction, int(channel ** 0.5))
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate.expand_as(x)
    
# replacement for ReLU
class SwishImplementation(torch.autograd.Function):
    """Swish activation f(x) = x * sigmoid(x) with a hand-written backward.

    Only the input is saved. The backward recomputes sigmoid(x) and applies
    f'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        sig = torch.sigmoid(inp)
        return grad_output * (sig * (1 + inp * (1 - sig)))

class Swish(nn.Module):
    """Module wrapper so SwishImplementation can be placed inside nn.Sequential."""

    def forward(self, x):
        activated = SwishImplementation.apply(x)
        return activated
       
# easy to use class version of this: https://github.com/DiffEqML/torchdyn/blob/master/tutorials/06_higher_order.ipynb
# can be used as a simple network: NeuralODE(in, out, hidden)
class NeuralODE(nn.Module):
    """
    A general-purpose higher-order neural ODE network for flat feature inputs.

    Pipeline: Linear(in_dim -> hidden_dim) -> Augmenter up to hidden_dim*order
    features -> NeuralDE of the given order (Euler solver, 5 points on s in [0, 1])
    -> linear head(s).

    Args:
        in_dim: number of input features.
        out_dim: number of output features (latent size when used as an encoder).
        hidden_dim: width of the ODE state before augmentation.
        order: order of the ODE; the state is augmented to hidden_dim*order.
        batch_norm: if True (default), BatchNorm1d layers are used inside the vector
            field; if False, they are replaced by identities.
        gauss_encoder: if True, ``forward`` returns ``(loc, scale)`` of a diagonal
            Gaussian (scale = exp(...) > 0) instead of a single output tensor.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, order=4, batch_norm=True, gauss_encoder=False):
        super().__init__()
        state_dim = hidden_dim * order

        def norm_layer():
            # nn.Identity keeps the Sequential indices (and state_dict keys) stable when BN is off
            return nn.BatchNorm1d(state_dim) if batch_norm else nn.Identity()

        self.initial_layer = nn.Linear(in_dim, hidden_dim)
        self.nde = NeuralDE(
            nn.Sequential(
                norm_layer(),
                Swish(),
                nn.Linear(state_dim, state_dim),
                norm_layer(),
                Swish(),
                nn.Linear(state_dim, hidden_dim)),
            solver='euler',
            order=order,
            s_span=torch.linspace(0, 1, 5))
        # pads the hidden_dim features with hidden_dim*(order-1) extra ones -> hidden_dim*order
        self.augment = Augmenter(augment_dims=hidden_dim*(order-1))
        # NOTE: created (and counted as parameters) even in encoder mode, where it is unused
        self.final_layer = nn.Linear(state_dim, out_dim)

        # if using as a VAE encoder:
        self.encode = gauss_encoder
        if self.encode:
            self.final_layer_loc = nn.Linear(state_dim, out_dim)
            self.final_layer_scale = nn.Linear(state_dim, out_dim)

    def forward(self, inputs):
        temps = self.initial_layer(inputs)
        temps = self.nde(self.augment(temps))
        if self.encode:
            loc = self.final_layer_loc(temps)
            # exp keeps the scale strictly positive
            scale = torch.exp(self.final_layer_scale(temps))
            return loc, scale
        return self.final_layer(temps)


# just a convolutional nn wrapped in the Hamiltonian class above
# experiment with max/avg pooling (see https://github.com/adobe/antialiased-cnns) ??
# see https://github.com/DiffEqML/torchdyn/blob/master/tutorials/04_augmentation_strategies.ipynb for 'augmentation'
class ConvODE(nn.Module):
    """
    A general-purpose augmented convolutional ODE network for 1-dim inputs.

    The vector field is Hamiltonian. An HNN wraps a 1-D CNN whose summed output
    plays the role of the Hamiltonian, so the dynamics follow Hamilton's canonical
    equations. Because HNN returns a gradient w.r.t. its input, the ODE state keeps
    the shape of the augmented input whatever the CNN's output length is.

    Args:
        in_channels: channels of the input signal.
        out_channels: hidden channels of the Hamiltonian CNN.
        augment_dim: extra channels the CNN's first layer expects from augmentation.
            NOTE(review): the learned augmenter maps the length axis with
            nn.Linear(channel_length, channel_length). Presumably it therefore adds
            in_channels channels, so shapes only line up when
            augment_dim == in_channels - TODO confirm.
        channel_length: length of the 1-D signal.
        transpose: False -> strided Conv1d Hamiltonian (used on the encoder side);
            True -> ConvTranspose1d Hamiltonian (used on the decoder side).
        canonical_dim: number of leading channels HNN treats as positions.
    """

    def __init__(self, in_channels, out_channels, augment_dim, channel_length, transpose=False, canonical_dim=1):
        super().__init__()

        # Learned augmentation: slower but trains better than zero-init.
        # Zero-init alternative: Augmenter(augment_dims=augment_dim)
        self.augment = Augmenter(augment_func=nn.Linear(channel_length, channel_length))
        self.transpose = transpose
        if not transpose:
            hamiltonian = nn.Sequential(
                nn.Conv1d(augment_dim+in_channels, out_channels, kernel_size=8, stride=8),
                nn.BatchNorm1d(out_channels),
                Swish(),
                nn.Conv1d(out_channels, out_channels, kernel_size=5, stride=5),
                nn.BatchNorm1d(out_channels),
                Swish(),
                nn.Conv1d(out_channels, in_channels, kernel_size=3, stride=1),
                nn.BatchNorm1d(in_channels),
                Swish(),
                antialiased_cnns.BlurPool1D(in_channels, stride=3),
                SqueezeExcitation(in_channels, reduction=8))
        else:
            hamiltonian = nn.Sequential(
                nn.ConvTranspose1d(augment_dim+in_channels, out_channels, kernel_size=3, stride=1),
                nn.BatchNorm1d(out_channels),
                Swish(),
                antialiased_cnns.BlurPool1D(out_channels, stride=3),
                nn.ConvTranspose1d(out_channels, out_channels, kernel_size=5, stride=5),
                nn.BatchNorm1d(out_channels),
                Swish(),
                nn.ConvTranspose1d(out_channels, out_channels, kernel_size=8, stride=8),
                nn.BatchNorm1d(out_channels),
                Swish(),
                SqueezeExcitation(out_channels, reduction=8))
        self.ham_func = HNN(hamiltonian, dim=canonical_dim)
        self.nde = NeuralDE(
            self.ham_func,
            solver='euler',
            s_span=torch.linspace(0, 1, 5))
        self.bn = nn.BatchNorm1d(in_channels)
        self.swish = Swish()

    def forward(self, inputs):
        # pre-activation (BatchNorm + Swish), then augment and integrate the Hamiltonian flow
        temps = self.swish(self.bn(inputs))
        outputs = self.nde(self.augment(temps))
        return outputs


# encoder/decoder with continuous normalizing flows in between
# inside CNF() is what I presume they call the 'vector field' - the part to be made SO(3)-equivariant
# I save and return the Jacobian trace of the flow so I can add it to the loss function
# copied how Yi did it under VAENF_loss(): https://github.com/CVC-Lab/Material_VAE/blob/master/loss_function.py
class FlowODEVAE(nn.Module):
    """
    1D convolutional ODE-VAE with continuous normalizing flows.

    Encoder: ConvODE -> Flatten -> NeuralODE with a Gaussian head giving (loc, scale).
    Flow:    CNF on the latent sample, solved with dopri5 (adjoint sensitivity) over s in [0, 1].
    Decoder: NeuralODE -> Unflatten -> transposed ConvODE -> 1x1 Conv1d back to in_channels.

    Args:
        in_dim: signal length.
        in_channels: input channels.
        conv_channels: hidden channels of the ConvODE Hamiltonian CNNs.
        augment_dim: extra channels added by ConvODE augmentation.
        latent_dim: size of the latent space.
        order: order of the NeuralODE blocks.
        canonical_dim: forwarded to the encoder's ConvODE.
    """

    def __init__(self, in_dim, in_channels, conv_channels, augment_dim, latent_dim, order=4, canonical_dim=1):
        super().__init__()

        self.flatten_dim = in_channels*in_dim
        self.flatten_aug_dim = (in_channels+augment_dim)*in_dim
        self.encoder = nn.Sequential(
            ConvODE(in_channels, conv_channels, augment_dim, in_dim, canonical_dim=canonical_dim),
            nn.Flatten(),
            NeuralODE(self.flatten_aug_dim, latent_dim, int(self.flatten_aug_dim/8), order=order, gauss_encoder=True))
        self.decoder = nn.Sequential(
            NeuralODE(latent_dim, self.flatten_dim, int(self.flatten_dim/8), order=order),
            nn.Unflatten(1, (in_channels, in_dim)),
            ConvODE(in_channels, conv_channels, augment_dim, in_dim, transpose=True),
            nn.Conv1d(in_channels+augment_dim, in_channels, kernel_size=1))
        self.flow = NeuralDE(
            CNF(nn.Sequential(
                    nn.Linear(latent_dim, 64),
                    nn.Softplus(),
                    nn.Linear(64, 64),
                    nn.Softplus(),
                    nn.Linear(64, 64),
                    nn.Softplus(),
                    nn.Linear(64, latent_dim)),
                trace_estimator=autograd_trace),  # alternative: hutch_trace
            solver='dopri5',
            sensitivity='adjoint',
            s_span=torch.linspace(0, 1, 2),
            atol=1e-4,
            rtol=1e-4)
        # adds the CNF's divergence channel as column 0 of the latent state
        # (presumably zero-initialised by torchdyn's Augmenter - TODO confirm)
        self.augment = Augmenter(augment_idx=1, augment_dims=1)

    def forward(self, inputs):
        """Encode, sample, flow and decode a batch.

        Returns:
            outputs: decoder reconstruction.
            loc: posterior mean.
            var: posterior variance ``scale ** 2`` (the quantity ``Loss`` uses as ``var``).
            z_flow: latent sample after the flow.
            trace_J: per-sample integral of tr(df/dz) along the flow, i.e. the
                log-determinant of the flow's Jacobian. The CNF integrates
                ``-tr(J)`` in column 0, so that column is negated here.
        """
        loc, scale = self.encoder(inputs)
        # rsample (reparameterisation trick) keeps z_0 differentiable w.r.t. loc and scale;
        # .sample() would cut the encoder off from the reconstruction and flow gradients.
        z_0 = torch.distributions.Normal(loc, scale).rsample()
        transform = self.flow(self.augment(z_0))
        z_flow = transform[:, 1:]
        trace_J = -transform[:, 0]
        outputs = self.decoder(z_flow)
        return outputs, loc, scale.pow(2), z_flow, trace_J
        
        
def Loss(X, X_hat, mu, var, trace_J):
    """Negative-ELBO style training loss for the flow VAE.

    Args:
        X: target batch; its first dimension is the batch size.
        X_hat: reconstruction, same shape as X.
        mu: mean of the diagonal Gaussian posterior.
        var: variance of that posterior (same shape as mu, strictly positive).
        trace_J: per-sample integrated Jacobian trace (log-determinant) of the
            flow. It is subtracted, so expanding flows lower the loss.

    Returns:
        Scalar tensor equal to the mean-squared reconstruction error, plus
        KL(N(mu, var) || N(0, I)) averaged over the batch, minus mean(trace_J).
    """
    # Normalise by this batch's size rather than a global label tensor.
    batch_size = X.shape[0]
    KLD = -0.5 * torch.sum(1 + torch.log(var) - mu.pow(2) - var)
    MSE = nn.MSELoss(reduction='mean')(X_hat, X)
    return MSE + KLD / batch_size - trace_J.mean()

In [None]:
class Args:
    # optimiser settings
    lr = 1e-05
    weight_decay = 0.0001
    # loop settings
    epochs = 100
    log_interval = 7  # print a progress line every 7 batches
args = Args()


##############
#code for vae#
##############

model = FlowODEVAE(in_dim=2125, in_channels=1, conv_channels=32, augment_dim=1, latent_dim=15, order=4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


# exactly args.epochs passes over the data, numbered from 1
for epoch in range(1, args.epochs + 1):
    model.train()
    for batch_idx, (X, y) in enumerate(trainloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        X_hat, mu, var, z, trace_J = model(X)
        # Loss takes (target, reconstruction, ...)
        loss = Loss(X, X_hat, mu, var, trace_J)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, batch_idx * len(X), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.item()))
               

#####################
#code for classifier#
#####################
# Disabled: the classifier experiment below is kept inside a string literal so it does not run.
# NOTE(review): if re-enabled, drop nn.Softmax(dim=0). It normalises across the batch axis,
# and nn.CrossEntropyLoss already applies log-softmax to raw logits internally.
# NOTE(review): range(args.epochs+1) runs args.epochs + 1 epochs.
'''
channels = 1
augment = 1
filters = 64

model = nn.Sequential(
    ConvODE(channels, filters, augment, 2125),
    nn.Flatten(),
    NeuralODE((channels+augment)*2125, 7, int(2125/2), order=5),
    nn.Softmax(dim=0)
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)


for epoch in range(args.epochs+1):
    model.train()
    correct = 0
    for batch_idx, (X, y) in enumerate(trainloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(X)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        loss.backward()
        optimizer.step()
        pred = y_hat.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(y.view_as(pred)).sum().item()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, batch_idx * len(X), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.item()))
    accuracy = 100. * correct / len(trainloader.dataset)
    print("Accuracy = {:.3f}%".format(accuracy))
    scheduler.step()
'''





Note: when I augment the data in the ConvODE() module using a neural network (rather than initializing the augmented dimension to zero), training is much, much slower, but the results are much better. See `self.augment` in ConvODE() and swap in the commented-out version for quicker training.

In [None]:
# total count of parameters that receive gradients
sum(map(torch.numel, filter(lambda p: p.requires_grad, model.parameters())))

20738070