https://github.com/sagelywizard/pytorch-mdn/blob/master/mdn/mdn.py

That is for the GMM implementation in PyTorch. Very good resource!!

https://github.com/ldeecke/gmm-torch/blob/master/gmm.py

That is the implementation of the EM algorithm getting the components of a GMM!!

In [1]:
import os.path as osp
import os
import glob
import torch
import awkward as ak
from torch.utils.data import Dataset
import time
import uproot
import uproot3
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.data import DataListLoader, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import SplineConv, global_mean_pool, DataParallel, EdgeConv
from torch_geometric.data import Data
from torchsummary import summary
from sklearn.neighbors import kneighbors_graph
import scipy.sparse as ss
from datetime import datetime, timedelta
import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
import math

In [None]:
class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
    Forward pass is the identity function. In the backward pass,
    the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = -lambda_ * grads
        return dx, None


class GradientReversal(torch.nn.Module):
    def __init__(self, lambda_=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)

In [None]:
def to_categorical(y, num_classes=None, dtype='float32'):
    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical

In [2]:
def create_graph_train(z, k, d, p1, p2, label):
    vec = []
    vec.append(np.array([d, z, k]).T)
    vec = np.array(vec)
    vec = np.squeeze(vec)

    v1 = [[ind, x] for ind, x in enumerate(p1) if x > -1]
    v2 = [[ind, x] for ind, x in enumerate(p2) if x > -1]

    a1 = np.reshape(v1,(len(v1),2)).T
    a2 = np.reshape(v2,(len(v2),2)).T
    edge1 = np.concatenate((a1[0], a2[0], a1[1], a2[1]),axis = 0)
    edge2 = np.concatenate((a1[1], a2[1], a1[0], a2[0]),axis = 0)
    edge = torch.tensor(np.array([edge1, edge2]), dtype=torch.long)
    return Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge, y=torch.tensor(label, dtype=torch.float))

In [3]:
def create_train_dataset_fulld(z, k, d, p1, p2, label):
    graphs = [create_graph_train(a, b, c, d, e, f) for a, b, c, d, e, f in zip(z, k, d, p1, p2, label)]
    #graphs.append(Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge, y=torch.tensor(label[i], dtype=torch.float)))
    return graphs

In [4]:
def create_train_dataset_fulld(z, k, d, p1, p2, label):
    graphs = []
    for i in range(len(z)):
        if i%1000 == 0: 
            print("Processing event {}/{}".format(i, len(z)), end="\r")
        vec = []
        vec.append(np.array([d[i], z[i], k[i]]).T)
        vec = np.array(vec)
        vec = np.squeeze(vec)

        v1 = [[ind, x] for ind, x in enumerate(p1[i]) if x > -1]
        v2 = [[ind, x] for ind, x in enumerate(p2[i]) if x > -1]

        a1 = np.reshape(v1,(len(v1),2)).T
        a2 = np.reshape(v2,(len(v2),2)).T
        edge1 = np.concatenate((a1[0], a2[0], a1[1], a2[1]),axis = 0)
        edge2 = np.concatenate((a1[1], a2[1], a1[0], a2[0]),axis = 0)
        edge = torch.tensor(np.array([edge1, edge2]), dtype=torch.long)
        graphs.append(Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge, y=torch.tensor(label[i], dtype=torch.float)))
    return graphs

In [5]:
def create_test_dataset_fulld(z, k, d, p1, p2):
    graphs = []
    for i in range(len(z)):
        vec = []
        vec.append(np.array([d[i], z[i], k[i]]).T)
        vec = np.array(vec)
        vec = np.squeeze(vec)
        v1 = [[ind, x] for ind, x in enumerate(p1[i]) if x > -1]
        v2 = [[ind, x] for ind, x in enumerate(p2[i]) if x > -1]

        a1 = np.reshape(v1,(len(v1),2)).T
        a2 = np.reshape(v2,(len(v2),2)).T
        edge1 = np.concatenate((a1[0], a2[0], a1[1], a2[1]),axis = 0)
        edge2 = np.concatenate((a1[1], a2[1], a1[0], a2[0]),axis = 0)
        edge = torch.tensor(np.array([edge1, edge2]), dtype=torch.long)
        graphs.append(Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge))
    return graphs

In [6]:
def create_train_dataset_prmy(z, k, d, label):
    graphs = []
    for i in range(len(z)):
        vec = []
        vec.append(np.array([d[i], z[i], k[i]]).T)
        vec = np.array(vec)
        vec = np.squeeze(vec)
        ya = kneighbors_graph(vec, n_neighbors=int(np.floor(vec.shape[0]/2)))
        edges = np.array([ya.nonzero()[0], ya.nonzero()[1]])
        edge = torch.tensor(edges, dtype=torch.long)
        graphs.append(Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge, y=torch.tensor(label[i], dtype=torch.float)))
    return graphs

In [7]:
def create_test_dataset_prmy(z, k, d):
    graphs = []
    for i in range(len(z)):
        vec = []
        vec.append(np.array([d[i], z[i], k[i]]).T)
        vec = np.array(vec)
        vec = np.squeeze(vec)
        ya = kneighbors_graph(vec, n_neighbors=int(np.floor(vec.shape[0]/2)))
        edges = np.array([ya.nonzero()[0], ya.nonzero()[1]])
        edge = torch.tensor(edges, dtype=torch.long)
        graphs.append(Data(x=torch.tensor(vec, dtype=torch.float), edge_index=edge))
    return graphs

In [8]:
#Configuration from bash script
if "INFILE" in os.environ:
    infile_path = os.environ["INFILE"]
    model_name = os.environ["MODEL"]
    epochs = int(os.environ["EPOCHS"])

#Configuration in notebook
else:    
#    files = glob.glob("/mnt/storage/asopio/train_test_split_20210305/training_set_jz2to9.root")
    files = glob.glob("/mnt/storage/raresiora/train_test_split/full_decluster/wprime/*_train.root") + glob.glob("/mnt/storage/raresiora/train_test_split/full_decluster/jz3to9/*_train.root")
    model_name = "GNN_full"
#    model_name = "LSTM"
#    model_name = "1DCNN"
#    model_name = "2DCNN"
#    model_name = "ImgCNN"
#    model_name = "GNN"
#    model_name = "Transformer"
#    nb_epochs = 20

In [9]:
nentries_total = 0
intreename = "lundjets_InDetTrackParticles"
print(files[0])
for infile_name in files: 
    nentries_total += uproot3.numentries(infile_name, intreename)

print("Evaluating on {} files with {} entries in total.".format(len(file3), nentries_total))

/mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000001.ANALYSIS.root_train.root
Evaluating on 16 files with 334337 entries in total.


In [11]:
#Load tf keras model
# jet_type = "Akt10RecoChargedJet" #track jets
jet_type = "Akt10UFOJet" #UFO jets

save_trained_model = True
intreename = "lundjets_InDetTrackParticles"

model_filename = "save/models/"+model_name+".hdf5"

In [12]:
def pad_ak(arr, l):
    arr = ak.pad_none(arr, l)
    arr = ak.fill_none(arr, 1)
    arr = ak.to_numpy(arr)
    return arr

def pad_ak3(arr, l):
    arr = ak.pad_none(arr, 1)
    arr = ak.fill_none(arr, [0])
    arr = arr[:,0]
    arr = ak.pad_none(arr, l)
    arr = ak.fill_none(arr, 0)
    #arr = ak.to_numpy(arr)
    return arr

print("Training tagger on files", len(files))
t_start = time.time()

dsids = np.array([])
NBHadrons = np.array([])
trial = np.ones((1,30))
all_lund_zs = ak.from_numpy(trial)
all_lund_kts = ak.from_numpy(trial)
all_lund_drs = ak.from_numpy(trial)
parent1 = ak.from_numpy(trial)
parent2 = ak.from_numpy(trial)
jet_pts = np.array([])
jet_ms = np.array([])
eta = np.array([])
vector = []
for file in files:
    
    print("Loading file",file)
    
    infile = uproot.open(file)
    tree = infile[intreename]
    dsids = np.append( dsids, np.array(tree["DSID"].array()) )
    #print(tree.keys())
    #eta = ak.concatenate(eta, pad_ak3(tree["Akt10TruthJet_jetEta"].array(), 30),axis=0)
    mcweights = tree["mcWeight"].array()
    NBHadrons = np.append( NBHadrons, pad_ak(tree["Akt10TruthJet_inputJetGABHadrons"].array(), 30)[:,0] )
    parent1 = ak.concatenate((parent1, pad_ak3(tree["{}_jetLundIDParent1".format(jet_type)].array(), 2)), axis = 0)
    parent2 = ak.concatenate((parent2, pad_ak3(tree["{}_jetLundIDParent2".format(jet_type)].array(), 2)), axis = 0)
    
    #Get jet kinematics
    jet_pts = np.append(jet_pts, pad_ak(tree["{}_jetPt".format(jet_type)].array(), 30)[:,0])
    #jet_etas = pad_ak(tree["AntiKt10UFOCSSKSoftDropBeta100Zcut10JetsCalib_jetEta"].array(), 30)[:,0]
    #jet_phis = pad_ak(tree["AntiKt10UFOCSSKSoftDropBeta100Zcut10JetsCalib_jetPhi"].array(), 30)[:,0]              
    jet_ms = np.append(jet_ms, pad_ak(tree["{}_jetM".format(jet_type)].array(), 30)[:,0])
    
    #Get Lund variables
    all_lund_zs = ak.concatenate((all_lund_zs, pad_ak3(tree["{}_jetLundZ".format(jet_type)].array(), 2)), axis=0)
    all_lund_kts = ak.concatenate((all_lund_kts, pad_ak3(tree["{}_jetLundKt".format(jet_type)].array(), 2)), axis=0)
    all_lund_drs = ak.concatenate((all_lund_drs, pad_ak3(tree["{}_jetLundDeltaR".format(jet_type)].array(), 2)), axis=0)
    #print(len(jet_pts), len(jet_ms))
all_lund_zs = all_lund_zs[1:]    
all_lund_kts = all_lund_kts[1:]
all_lund_drs = all_lund_drs[1:]
parent1 = parent1[1:]
parent2 = parent2[1:]


delta_t_fileax = time.time() - t_start
print("Opened data in {:.4f} seconds.".format(delta_t_fileax))

#Get labels
#labels = ( dsids > 360000 ) & ( dsids < 370000 )

labels = ( dsids > 370000 ) # depends on your signal and background definition

#print(labels)
labels = to_categorical(labels, 2)

Training tagger on files 22
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000001.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000004.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000008.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000003.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/wprime/user.asopio.24603642._000012.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/j3to9/user.asopio.24603633._000004.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/j3to9/user.asopio.24603631._000012.ANALYSIS.root_train.root
Loading file /mnt/storage/raresiora/train_test_split/full_decluster/j3t

In [None]:
# Just random numbers to check the models work
# These are not used in practice
all_lund_zs = np.random.randn(50000,30)
all_lund_kts = np.random.randn(50000,30)
all_lund_drs = np.random.randn(50000,30)
parent1 = np.random.randint(low=0, high=20, size=(50000, 14))
parent2 = np.random.randint(low=0, high=20, size=(50000, 14))
labels = np.random.randint(low=0, high=1, size=(50000, 2))
jet_pts = np.random.randint(low=1, high=500, size=(50000))
jet_ms = np.random.randint(low=1, high=500, size=(50000))

In [13]:
#W bosons
dataset = create_train_dataset_fulld(all_lund_zs, all_lund_kts, all_lund_drs, parent1, parent2, labels) 
train_loader = DataLoader(dataset, batch_size=1024, shuffle=True)

Processing event 1242000/1242295

In [None]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = EdgeConv(nn.Sequential(nn.Linear(6, 64),
                                  nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)),aggr='add')
        self.conv2 = EdgeConv(nn.Sequential(nn.Linear(128, 128),
                                  nn.ReLU(), nn.Linear(128, 128),nn.ReLU(), nn.Linear(128, 128)),aggr='add')
        self.conv3 = EdgeConv(nn.Sequential(nn.Linear(256,256,),
                                  nn.ReLU(), nn.Linear(256, 256),nn.ReLU(), nn.Linear(256, 256)),aggr='add')
        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(256, 64)
        self.lin3 = torch.nn.Linear(64, 2)
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.1)
        x = self.lin2(x)
        x = F.dropout(x, p=0.1)
        x = self.lin3(x)
        #print(x.shape)
        return F.sigmoid(x)

Part to train the classifier. Can be skipped where you just import a trained model

In [None]:
model = Net()
device = torch.device('cuda') # Usually gpu 4 worked best, it had the most memory available
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

def train(epoch):
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        new_y = torch.reshape(data.y, (int(list(data.y.shape)[0]/2),2))
        loss = F.binary_cross_entropy(output, new_y)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
    return loss_all / len(dataset)

@torch.no_grad()
def get_accuracy(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        new_y = torch.reshape(data.y, (int(list(data.y.shape)[0]/2),2))
        pred = model(data).max(dim=1)[1]
        correct += pred.eq(new_y[:,1]).sum().item()
    return correct / len(loader.dataset)
    
@torch.no_grad()
def get_scores(loader):
    model.eval()
    total_output = np.array([[1,1]])
    for data in loader:
        data = data.to(device)
        pred = model(data)
        total_output = np.append(total_output, pred.cpu().detach().numpy(), axis=0)
    return total_output[1:]
    
for epoch in range(1, 21):
    loss = train(epoch)
    train_acc = get_accuracy(train_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.5f}'.format(epoch, loss, train_acc))

In [None]:
path = "/your/path/to/models/wprime.pt"
torch.save(model.state_dict(), path)

This is only for loading the trained classifier for mass decorrelation

In [None]:
path = "/your/path/to/models/wprime.pt"
clsf = Net()
clsf.load_state_dict(torch.load(path))

Mixture density network

In [None]:
ONEOVERSQRT2PI = 1.0 / math.sqrt(2 * math.pi)


class MDN(nn.Module):
    """A mixture density network layer
    The input maps to the parameters of a MoG probability distribution, where
    each Gaussian has O dimensions and diagonal covariance.
    Arguments:
        in_features (int): the number of dimensions in the input
        out_features (int): the number of dimensions in the output
        num_gaussians (int): the number of Gaussians per output dimensions
    Input:
        minibatch (BxD): B is the batch size and D is the number of input
            dimensions.
    Output:
        (pi, sigma, mu) (BxG, BxGxO, BxGxO): B is the batch size, G is the
            number of Gaussians, and O is the number of dimensions for each
            Gaussian. Pi is a multinomial distribution of the Gaussians. Sigma
            is the standard deviation of each Gaussian. Mu is the mean of each
            Gaussian.
    """

    def __init__(self, in_features, out_features, num_gaussians):
        super(MDN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        self.pi = nn.Sequential(
            nn.Linear(in_features, num_gaussians),
            nn.Softmax(dim=1)
        )
        self.sigma = nn.Linear(in_features, out_features * num_gaussians)
        self.mu = nn.Linear(in_features, out_features * num_gaussians)

    def forward(self, minibatch):
        pi = self.pi(minibatch)
        sigma = torch.exp(self.sigma(minibatch))
        sigma = sigma.view(-1, self.num_gaussians, self.out_features)
        mu = self.mu(minibatch)
        mu = mu.view(-1, self.num_gaussians, self.out_features)
        return pi, sigma, mu


def gaussian_probability(sigma, mu, target):
    """Returns the probability of `target` given MoG parameters `sigma` and `mu`.
    Arguments:
        sigma (BxGxO): The standard deviation of the Gaussians. B is the batch
            size, G is the number of Gaussians, and O is the number of
            dimensions per Gaussian.
        mu (BxGxO): The means of the Gaussians. B is the batch size, G is the
            number of Gaussians, and O is the number of dimensions per Gaussian.
        target (BxI): A batch of target. B is the batch size and I is the number of
            input dimensions.
    Returns:
        probabilities (BxG): The probability of each point in the probability
            of the distribution in the corresponding sigma/mu index.
    """
    target = target.unsqueeze(1).expand_as(sigma)
    ret = ONEOVERSQRT2PI * torch.exp(-0.5 * ((target - mu) / sigma)**2) / sigma
    ret = torch.where(ret == 0, ret + 1E-20, ret)
    return torch.prod(ret, 2)


def mdn_loss(pi, sigma, mu, target):
    """Calculates the error, given the MoG parameters and the target
    The loss is the negative log likelihood of the data given the MoG
    parameters.
    """
    prob = pi * gaussian_probability(sigma, mu, target)
    nll = -torch.log(torch.sum(prob, dim=1))
    return torch.mean(nll)


def sample(pi, sigma, mu):
    """Draw samples from a MoG.
    """
    # Choose which gaussian we'll sample from
    pis = Categorical(pi).sample().view(pi.size(0), 1, 1)
    # Choose a random sample, one randn for batch X output dims
    # Do a (output dims)X(batch size) tensor here, so the broadcast works in
    # the next step, but we have to transpose back.
    gaussian_noise = torch.randn(
        (sigma.size(2), sigma.size(0)), requires_grad=False)
    variance_samples = sigma.gather(1, pis).detach().squeeze()
    mean_samples = mu.detach().gather(1, pis).squeeze()
    return (gaussian_noise * variance_samples + mean_samples).transpose(0, 1)

In [None]:
num_gaussians = 20 # number of Gaussians at the end

# The architecture of the adversary could be changed
class Adversary(nn.Module):
    def __init__(self):
        super(Adversary, self).__init__()
        self.gauss = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    MDN(64, 1, num_gaussians)
)
        self.revgrad = GradientReversal(10)
        
    def forward(self, x):
        x = self.revgrad(x) # important hyperparameter, the scale, 
                                     # tells by how much the classifier is punished
        x = self.gauss(x)
        return x

Setting up the mass and pts for the adversarial network

In [None]:
ms = np.array(jet_ms).reshape(len(jet_ms), 1)
pts = np.array(np.log(jet_pts)).reshape(len(jet_pts), 1)

Creating the adversarial dataset

In [None]:
def create_adversary_trainset(pt, mass):
    graphs = [Data(x=torch.tensor([p], dtype=torch.float), y=torch.tensor([m], dtype=torch.float)) for p, m in zip(pt, mass)]
    return graphs

In [None]:
class ConcatDataset(Dataset):
    def __init__(self, datasetA, datasetB):
        self.datasetA = datasetA
        self.datasetB = datasetB
        
    def __getitem__(self, index):
        xA = self.datasetA[index]
        xB = self.datasetB[index]
        return xA, xB
    
    def __len__(self):
        return len(self.datasetA)

In [None]:
device = torch.device('cuda') # Usually gpu 4 worked best, it had the most memory available
clsf.to(device)
adv.to(device)
optimizer = torch.optim.Adam(adv.parameters(), lr=0.0005)
#optimizer = torch.optim.Adam(list(clsf.parameters()) + list(adv.parameters()), lr=0.0005)

def train(epoch):
    clsf.eval()
    adv.train()
    loss_all = 0
    for data in adv_loader:
        cl_data = data[0].to(device)
        adv_data = data[1].to(device)
        new_y = torch.reshape(cl_data.y, (int(list(cl_data.y.shape)[0]),1))
        mask_bkg = new_y.lt(0.5)
        optimizer.zero_grad()
        cl_out = clsf(cl_data)
        #print(torch.reshape(cl_out, (len(cl_out), 1)), torch.reshape(cl_out, (len(cl_out), 1)).shape)
        #print(adv_data.x, adv_data.x.shape)
        adv_inp = torch.cat((torch.reshape(cl_out[mask_bkg], (len(cl_out[mask_bkg]), 1)), torch.reshape(adv_data.x[mask_bkg], (len(adv_data.x[mask_bkg]), 1))), 1)
        #print(adv_inp.shape)
        pi, sigma, mu = adv(adv_inp)
        #cl_out = clsf(cl_data)
        loss2 = mdn_loss(pi, sigma, mu, torch.reshape(adv_data.y[mask_bkg], (len(cl_out[mask_bkg]), 1)))
        loss2.backward()
        loss_all += cl_data.num_graphs * loss2.item()
        optimizer.step()
    return loss_all / len(dataset)

@torch.no_grad()
def get_accuracy(loader):
    #remember to change this when evaluating combined model
    clsf.eval()
    correct = 0
    for data in loader:
        cl_data = data[0].to(device)
        #adv_data = data[1].to(device)
        new_y = torch.reshape(cl_data.y, (int(list(cl_data.y.shape)[0]),1))
        pred = clsf(cl_data).max(dim=1)[1]
        correct += pred.eq(new_y[:,1]).sum().item()
    return correct / len(loader.dataset)
    
@torch.no_grad()
def get_scores(loader):
    clsf.eval()
    total_output = np.array([[1,1]])
    for data in loader:
        cl_data = data[0].to(device)
        pred = clsf(cl_data)
        total_output = np.append(total_output, pred.cpu().detach().numpy(), axis=0)
    return total_output[1:]
    

print("Training adversary whilst keeping classifier the same.")

for epoch in range(1, 51): # this may need to be bigger
    loss = train(epoch)
   # train_acc = get_accuracy(adv_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: - '.format(epoch, loss))

In [None]:
for param in clsf.parameters():
    param.require_grads = True

device = torch.device('cuda') # Usually gpu 4 worked best, it had the most memory available
clsf.to(device)
adv.to(device)
optimizer_cl = torch.optim.Adam(clsf.parameters(), lr=0.01)
optimizer_adv = torch.optim.Adam(adv.parameters(), lr=0.00001)
#optimizer = torch.optim.Adam(list(clsf.parameters()) + list(adv.parameters()), lr=0.0005)


def train(epoch):
    clsf.train()
    adv.train()
    loss_all = 0
    for data in adv_loader:
        cl_data = data[0].to(device)
        adv_data = data[1].to(device)
        new_y = torch.reshape(cl_data.y, (int(list(cl_data.y.shape)[0]),1))
        mask_bkg = new_y.lt(0.5)
        optimizer_cl.zero_grad()
        optimizer_adv.zero_grad()
        cl_out = clsf(cl_data)
        adv_inp = torch.cat((torch.reshape(cl_out[mask_bkg], (len(cl_out[mask_bkg]), 1)), torch.reshape(adv_data.x[mask_bkg], (len(adv_data.x[mask_bkg]), 1))), 1)
        pi, sigma, mu = adv(adv_inp)
        loss1 = F.binary_cross_entropy(cl_out, new_y)
        loss2 = mdn_loss(pi, sigma, mu, torch.reshape(adv_data.y[mask_bkg], (len(adv_data.y[mask_bkg]), 1)))
        loss = loss1 + loss2
        loss.backward()
        loss_all += cl_data.num_graphs * loss.item()
        optimizer_cl.step()
        optimizer_adv.step()
    return loss_all / len(dataset)

print("Started training together!")

for epoch in range(1, 21): # this may need to be bigger
    loss = train(epoch)
    #train_acc = get_accuracy(adv_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc:'.format(epoch, loss))

path = "/your/path/to/models/lundnet_2opt_2_5.pt"
torch.save(clsf.state_dict(), path)