In [1]:
import torch
import sys
sys.path.append("../")
from Models.ResNet import ResnetConfig, ResnetClassifier
from Data.Data import DiffractionDataset

In [2]:
# Load the diffraction dataset, using 'Bravais Lattice' as the categorical label.
ds = DiffractionDataset(
    "../Data/BravaisLattice_Data.pt",
    categorical='Bravais Lattice',
)

In [3]:
# Semi-supervised (SGAN) model using 50% of the data for labeled training.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model below is built on the CPU, so the loaded weights end up there anyway.
state_dict = torch.load("../Models/Bravais_model_SGAN.pth", map_location="cpu")

# Architecture hyperparameters. These should match the configuration the checkpoint
# was trained with; otherwise load_state_dict raises on missing/unexpected keys or
# parameter shape mismatches.
config = ResnetConfig(
    input_dim=1,
    output_dim=14,  # number of output classes — presumably the 14 Bravais lattices
    res_dims=[32, 64, 64, 64],
    res_kernel=[5, 7, 17, 13],
    res_stride=[4, 4, 5, 3],
    num_blocks=[2, 2, 2, 2],
    first_kernel_size=13,
    first_stride=1,
    first_pool_kernel_size=7,
    first_pool_stride=7,
)
net = ResnetClassifier(config)
# eval() switches any dropout/batch-norm layers to inference behaviour; float() casts parameters to float32.
net.eval().float()
# The checkpoint is a dict; the model weights are stored under 'model_state_dict'.
net.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [4]:
# Inference only: disable autograd so no computation graph is built over the whole
# dataset (eval() alone does not turn gradients off).
with torch.no_grad():
    logits = net(ds.data).logits
# Highest-scoring class index for each sample, flattened to a 1-D tensor of class ids.
predictions = torch.flatten(torch.argmax(logits, dim=-1))
print(predictions)
print(ds.labels)

tensor([ 9,  0,  3,  5,  9, 13,  3,  0,  3, 11])
tensor([ 9,  0,  3, 12,  9,  4,  3,  0,  3, 11])


In [5]:
# Per-sample table of true label vs. SGAN prediction (kept as the last expression so the table is shown).
ds.compare(predictions)

Index          True Label          Prediction
  0     orthorhombic (P)     orthorhombic (P)
  1            cubic (F)            cubic (F)
  2        hexagonal (P)        hexagonal (P)
  3       tetragonal (P)       monoclinic (P)
  4     orthorhombic (P)     orthorhombic (P)
  5       monoclinic (C)        triclinic (P)
  6        hexagonal (P)        hexagonal (P)
  7            cubic (F)            cubic (F)
  8        hexagonal (P)        hexagonal (P)
  9       tetragonal (I)       tetragonal (I)


In [6]:
# SGAN accuracy: number of exact label matches divided by the number of predictions.
accuracy = (predictions == ds.labels).sum() / len(predictions)
print(accuracy)

tensor(0.8000)


In [7]:
# Fully supervised model using 90% of the data. It reuses `config` from the SGAN cell,
# so both models share the same architecture.
supervised = ResnetClassifier(config)
supervised.eval().float()
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine.
state_dict = torch.load("../Models/Bravais_model_supervised.pth", map_location="cpu")
supervised.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [8]:
# Inference only: disable autograd so no computation graph is built over the whole dataset.
with torch.no_grad():
    logits = supervised(ds.data).logits
# Highest-scoring class index for each sample, flattened to a 1-D tensor of class ids.
supervised_predictions = torch.flatten(torch.argmax(logits, dim=-1))
print(supervised_predictions)
print(ds.labels)

tensor([12,  0,  3, 11,  9, 13,  3,  0,  3, 11])
tensor([ 9,  0,  3, 12,  9,  4,  3,  0,  3, 11])


In [9]:
# Per-sample table of true label vs. supervised-model prediction (kept as the last expression so the table is shown).
ds.compare(supervised_predictions)

Index          True Label          Prediction
  0     orthorhombic (P)       tetragonal (P)
  1            cubic (F)            cubic (F)
  2        hexagonal (P)        hexagonal (P)
  3       tetragonal (P)       tetragonal (I)
  4     orthorhombic (P)     orthorhombic (P)
  5       monoclinic (C)        triclinic (P)
  6        hexagonal (P)        hexagonal (P)
  7            cubic (F)            cubic (F)
  8        hexagonal (P)        hexagonal (P)
  9       tetragonal (I)       tetragonal (I)


In [10]:
# Supervised-model accuracy: fraction of samples whose predicted class equals the label.
# Guard against silent broadcasting in `==` if the shapes ever diverge.
assert supervised_predictions.shape == ds.labels.shape, "prediction/label shape mismatch"
# Normalise by this model's own prediction count (previously this divided by the SGAN
# model's `predictions`, coupling the cell to an unrelated variable).
accuracy = torch.sum(supervised_predictions == ds.labels) / len(supervised_predictions)
print(accuracy)

tensor(0.7000)
