In [1]:
import torch
import sys
sys.path.append("../")
from Models.ResNet import ResnetClassifier
from Data.Data import DiffractionDataset

In [2]:
# Load the Bravais-lattice dataset, build the 14-class classifier, and load the
# saved checkpoints for both training regimes.
ds = DiffractionDataset("../Data/BravaisLattice_Data.pt", categorical='Bravais Lattice')
resnet = ResnetClassifier(num_classes=14)
# Evaluation mode for inference (affects layers such as dropout / batch-norm, if present).
resnet.eval()
# `models` is indexed below by 'SGAN' and 'Supervised', and each entry is passed to
# `load_state_dict`. map_location="cpu" lets checkpoints saved from GPU tensors load
# on CPU-only machines; this notebook never moves the model off the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
models = torch.load('../Models/BravaisModels.pth', map_location="cpu")

In [3]:
# Swap the semi-supervised (SGAN) weights into the shared classifier; this model was
# trained with 50% of the data labeled. The key-match report is the cell's output.
resnet.load_state_dict(models["SGAN"], strict=True)

<All keys matched successfully>

In [4]:
# Evaluate the SGAN weights on the dataset and print predictions, labels, and accuracy.
# NOTE: `resnet` is shared with the supervised evaluation, which reloads it with other
# weights -- run this cell directly after the cell that loads models['SGAN'].
# Inference only, so skip building an autograd graph over the whole dataset.
with torch.no_grad():
    sgan_output = resnet(ds.data)

print("Predictions:", sgan_output.predictions)
print("Labels:     ", ds.labels)
print("Accuracy:", sgan_output.accuracy(ds.labels), "%")

Predictions: tensor([ 9,  0,  3,  5,  9, 13,  3,  0,  3, 11])
Labels:      tensor([ 9,  0,  3, 12,  9,  4,  3,  0,  3, 11])
Accuracy: 80.0 %


In [5]:
# Tabulate the SGAN predictions against the true labels, one row per sample index.
ds.compare(sgan_output.predictions)

Index          True Label          Prediction
  0     orthorhombic (P)     orthorhombic (P)
  1            cubic (F)            cubic (F)
  2        hexagonal (P)        hexagonal (P)
  3       tetragonal (P)       monoclinic (P)
  4     orthorhombic (P)     orthorhombic (P)
  5       monoclinic (C)        triclinic (P)
  6        hexagonal (P)        hexagonal (P)
  7            cubic (F)            cubic (F)
  8        hexagonal (P)        hexagonal (P)
  9       tetragonal (I)       tetragonal (I)


In [6]:
# Swap the fully supervised weights (trained on 90% of the data) into the shared
# classifier, replacing the SGAN weights. The key-match report is the cell's output.
resnet.load_state_dict(models["Supervised"], strict=True)

<All keys matched successfully>

In [7]:
# Evaluate the supervised weights on the dataset and print predictions, labels, and accuracy.
# NOTE: `resnet` is shared with the SGAN evaluation -- run this cell directly after
# the cell that loads models['Supervised'].
# Inference only, so skip building an autograd graph over the whole dataset.
with torch.no_grad():
    supervised_output = resnet(ds.data)

print("Predictions:", supervised_output.predictions)
print("Labels:     ", ds.labels)
print("Accuracy:", supervised_output.accuracy(ds.labels), "%")

Predictions: tensor([12,  0,  3, 11,  9, 13,  3,  0,  3, 11])
Labels:      tensor([ 9,  0,  3, 12,  9,  4,  3,  0,  3, 11])
Accuracy: 70.0 %


In [8]:
# Tabulate the supervised model's predictions against the true labels, one row per sample index.
ds.compare(supervised_output.predictions)

Index          True Label          Prediction
  0     orthorhombic (P)       tetragonal (P)
  1            cubic (F)            cubic (F)
  2        hexagonal (P)        hexagonal (P)
  3       tetragonal (P)       tetragonal (I)
  4     orthorhombic (P)     orthorhombic (P)
  5       monoclinic (C)        triclinic (P)
  6        hexagonal (P)        hexagonal (P)
  7            cubic (F)            cubic (F)
  8        hexagonal (P)        hexagonal (P)
  9       tetragonal (I)       tetragonal (I)


In [9]:
# Side-by-side table of true labels, SGAN predictions, and supervised predictions;
# the list supplies the column headers for the two prediction columns.
ds.compare(sgan_output.predictions, supervised_output.predictions, ['SGAN', 'Supervised'])

Index          True Label                SGAN          Supervised
  0     orthorhombic (P)     orthorhombic (P)       tetragonal (P)
  1            cubic (F)            cubic (F)            cubic (F)
  2        hexagonal (P)        hexagonal (P)        hexagonal (P)
  3       tetragonal (P)       monoclinic (P)       tetragonal (I)
  4     orthorhombic (P)     orthorhombic (P)     orthorhombic (P)
  5       monoclinic (C)        triclinic (P)        triclinic (P)
  6        hexagonal (P)        hexagonal (P)        hexagonal (P)
  7            cubic (F)            cubic (F)            cubic (F)
  8        hexagonal (P)        hexagonal (P)        hexagonal (P)
  9       tetragonal (I)       tetragonal (I)       tetragonal (I)
