In [1]:
import torch
import sys
sys.path.append('../')
from Models.ResNet import ResnetConfig, ResnetClassifier
from Data.Data import DiffractionDataset

In [2]:
# Evaluation dataset, loaded with 'Space Group' as the categorical target
# (presumably what `ds.labels` holds below — confirm in Data.Data).
ds = DiffractionDataset(
    "../Data/SpaceGroup_Data.pt",
    categorical='Space Group',
)

In [3]:
# Semi-supervised (SGAN) model; 50% of the data was used for labeled training.
# The classifier weights are stored under the checkpoint's 'discriminator' key.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only machine; the network below is never moved off the CPU anyway.
state_dict = torch.load("../Models/Space_Group_model_SGAN.pth", map_location="cpu")
# Architecture must match the one the checkpoint was trained with, otherwise the
# (strict) load_state_dict call below raises on missing/mismatched keys.
config = ResnetConfig(
        input_dim = 1,
        output_dim = 144,  # assumes this equals the number of label classes — TODO confirm
        res_dims=[32, 64, 64, 64],
        res_kernel=[5, 7, 17, 13],
        res_stride=[4, 4, 5, 3],
        num_blocks=[2, 2, 2, 2],
        first_kernel_size = 13,
        first_stride = 1,
        first_pool_kernel_size = 7,
        first_pool_stride = 7,
    )
net = ResnetClassifier(config)
# Inference mode (affects layers such as dropout/batch-norm, if present) in float32.
net.eval().float()
net.load_state_dict(state_dict['discriminator'])

<All keys matched successfully>

In [4]:
# Evaluation only: disable autograd so the forward pass over the whole dataset
# does not build and keep a gradient graph in memory.
with torch.no_grad():
    logits = net(ds.data).logits
# Arg-max class index per sample, flattened to a 1-D tensor aligned with ds.labels.
predictions = torch.flatten(torch.argmax(logits, dim=-1))
print(predictions)
print(ds.labels)

tensor([ 70,  10,  97,  49,   1,   0, 139,  11,  53, 119])
tensor([ 70,  25,  97,  49,  42,   1, 139,  10,  53, 119])


In [5]:
# Print a table of true vs. predicted labels for the semi-supervised model.
# The printed values differ from the raw class indices in `predictions`, so
# compare() presumably maps indices back to space-group numbers — confirm in Data.Data.
ds.compare(predictions)

Index          True Label          Prediction
  0                  123                  123
  1                   41                   12
  2                  160                  160
  3                   70                   70
  4                   62                    2
  5                    2                    1
  6                  225                  225
  7                   12                   13
  8                   74                   74
  9                  194                  194


In [6]:
# Top-1 accuracy: share of samples whose arg-max prediction equals the true label.
accuracy = torch.eq(predictions, ds.labels).sum() / len(predictions)
print(accuracy)

tensor(0.6000)


In [7]:
# Top-5 accuracy: a sample counts as correct when its true label is among the
# 5 highest-scoring classes.
top_5_preds = torch.topk(logits, dim=-1, k=5).indices
print(top_5_preds)
# Column of true labels, broadcast against each sample's k candidates.
labels = ds.labels.reshape(-1, 1)
# flatten(start_dim=1) makes the test independent of whether logits are (N, C)
# or (N, 1, C); a hard-coded extra axis would broadcast to (N, N, k) on 2-D logits.
# topk indices are distinct within a row, so any() counts each sample at most once.
top_5_hits = (top_5_preds.flatten(start_dim=1) == labels).any(dim=-1)
top_5_acc = top_5_hits.float().mean()
print(top_5_acc)

tensor([[[ 70,  31,  61, 136,  75]],

        [[ 10,   9,  11,   6,   1]],

        [[ 97,  87,   0,  47,  62]],

        [[ 49,  83,  79,  91,  32]],

        [[  1,  12,  10,  42,  13]],

        [[  0,   1,   5,  12,  27]],

        [[139, 132,   0, 103, 136]],

        [[ 11,  89,  76,  50,  96]],

        [[ 53,  10,  13,  50,  12]],

        [[119, 100, 103, 112, 110]]])
tensor(0.8000)


In [8]:
# Fully supervised model trained on 90% of the data. It reuses the `config`
# defined for the semi-supervised model, so both networks share one architecture.
supervised = ResnetClassifier(config)
# Inference mode (affects layers such as dropout/batch-norm, if present) in float32.
supervised.eval().float()
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
state_dict = torch.load("../Models/Space_Group_model_supervised.pth", map_location="cpu")
# This checkpoint stores its weights under 'model_state_dict'.
supervised.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [9]:
# Evaluation only: disable autograd so the forward pass over the whole dataset
# does not build and keep a gradient graph in memory.
with torch.no_grad():
    supervised_logits = supervised(ds.data).logits
# Arg-max class index per sample, flattened to a 1-D tensor aligned with ds.labels.
supervised_predictions = torch.flatten(torch.argmax(supervised_logits, dim=-1))
print(supervised_predictions)
print(ds.labels)

tensor([ 70,   6,  97,  49,   1,   1, 139,  71,  53, 119])
tensor([ 70,  25,  97,  49,  42,   1, 139,  10,  53, 119])


In [10]:
# Top-1 accuracy of the supervised model (note: rebinds `accuracy`).
accuracy = supervised_predictions.eq(ds.labels).sum() / len(supervised_predictions)
print(accuracy)

tensor(0.7000)


In [11]:
# Top-5 accuracy of the supervised model: correct when the true label is among
# the 5 highest-scoring classes.
supervised_top_5_preds = torch.topk(supervised_logits, dim=-1, k=5).indices
print(supervised_top_5_preds)
# Column of true labels, broadcast against each sample's k candidates.
labels = ds.labels.reshape(-1, 1)
# flatten(start_dim=1) makes the test independent of whether logits are (N, C)
# or (N, 1, C); a hard-coded extra axis would broadcast to (N, N, k) on 2-D logits.
# topk indices are distinct within a row, so any() counts each sample at most once.
supervised_top_5_hits = (supervised_top_5_preds.flatten(start_dim=1) == labels).any(dim=-1)
supervised_top_5_acc = supervised_top_5_hits.float().mean()
print(supervised_top_5_acc)

tensor([[[ 70,  31,  45,  61,  82]],

        [[  6,  10,   3,  12,  27]],

        [[ 97,   0,  87,  94,   6]],

        [[ 49,  91, 130,  89,   7]],

        [[  1,  42,  12,  13,  59]],

        [[  1,  11,  13,   9,   0]],

        [[139, 132,  97,  50, 127]],

        [[ 71,  57,  76,  87, 126]],

        [[ 53,  30,  50,  28,  67]],

        [[119, 100, 101, 115, 108]]])
tensor(0.8000)


In [12]:
# Print a table of true vs. predicted labels for the supervised model.
# The printed values differ from the raw class indices, so compare() presumably
# maps indices back to space-group numbers — confirm in Data.Data.
ds.compare(supervised_predictions)

Index          True Label          Prediction
  0                  123                  123
  1                   41                    8
  2                  160                  160
  3                   70                   70
  4                   62                    2
  5                    2                    2
  6                  225                  225
  7                   12                  125
  8                   74                   74
  9                  194                  194
