In [1]:
import numpy as np
import argparse
import random
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch.utils.data
from scipy.io import loadmat
from model.GMVAE import *
import matplotlib.pyplot as plt
import os
import glob

In [6]:
# Directory containing the files
folder_path = 'checkpoint'

# Construct the pattern to match files ending with '.pth'
file_pattern = os.path.join(folder_path, '*.pth')

# Use glob to find files matching the pattern
matching_files = glob.glob(file_pattern)

# Print the list of matching files
print("List of files ending with '.pth':")
for file in matching_files:
    print(file)
    checkpoint = torch.load(file, map_location=torch.device('cpu'))
    args = params = checkpoint['config']
    print(params.desc)


List of files ending with '.pth':
checkpoint/model_3.pth
> [0;32m/tmp/ipykernel_2473115/3387544945.py[0m(12)[0;36m<module>[0;34m()[0m
[0;32m     10 [0;31m[0;31m# Print the list of matching files[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0mprint[0m[0;34m([0m[0;34m"List of files ending with '.pth':"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 12 [0;31m[0;32mfor[0m [0mfile[0m [0;32min[0m [0mmatching_files[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m    [0mprint[0m[0;34m([0m[0mfile[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m    [0mcheckpoint[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mload[0m[0;34m([0m[0mfile[0m[0;34m,[0m [0mmap_location[0m[0;34m=[0m[0mtorch[0m[0;34m.[0m[0mdevice[0m[0;34m([0m[0;34m'cpu'[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
*** TypeError: 'Namespace' object is not subscriptable
Namespace(file=None, dataset='cifar10', seed=0, cu

In [None]:
model_name = 'model_3'

checkpoint = torch.load(f"checkpoint/{model_name}.pth", map_location=torch.device('cpu'))
args = params = checkpoint['config']

In [None]:
## Random Seed
SEED = args.seed
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if args.cuda:
  torch.cuda.manual_seed(SEED)

In [None]:
dataset_dict = {
    'mnist': datasets.MNIST,
    'cifar10': datasets.CIFAR10,
    'cifar100': datasets.CIFAR100
}

print(f"Loading {args.dataset} dataset...")
# Download or load downloaded MNIST dataset
train_dataset = dataset_dict[args.dataset](f'./data/{args.dataset}', train=True, download=True, transform=transforms.ToTensor())
test_dataset = dataset_dict[args.dataset](f'./data/{args.dataset}', train=False, transform=transforms.ToTensor())


In [None]:
def partition_dataset(n, proportion=0.8):
  train_num = int(n * proportion)
  indices = np.random.permutation(n)
  train_indices, val_indices = indices[:train_num], indices[train_num:]
  return train_indices, val_indices

if args.train_proportion == 1.0:
  # we use all train dataset without partitioning
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size_val, shuffle=False)
  val_loader = test_loader
else:
  # partition dataset according to train_proportion
  train_indices, val_indices = partition_dataset(len(train_dataset), args.train_proportion)
  # create data loaders for train, validation and test datasets
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=SubsetRandomSampler(train_indices))
  val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size_val, sampler=SubsetRandomSampler(val_indices))
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size_val, shuffle=False)


## Train Model

In [None]:
print(f"params: {params}")
loss_history = checkpoint['loss_history']
print(loss_history.keys())
plt.plot(np.array(loss_history['train_history_nmi']), label='Train')
plt.plot(np.array(loss_history['val_history_nmi']), label='Validation')

# plt.plot(loss_history, label='KLD')
plt.ylim([0, None])
plt.title(f"Training: NMI score")
plt.xlabel("Epochs")
# plt.ylabel("KLD loss")
plt.legend()
plt.show()

In [None]:
# Model Initialization
params.cuda=False
gmvae = GMVAE(params)
gmvae.network.load_state_dict(checkpoint['model_state'])

## Test Data

In [None]:
accuracy, nmi = gmvae.test(test_loader)

print("Testing phase...")
print("Accuracy: %.5lf,  NMI: %.5lf" % (accuracy, nmi) )

## Visualization of the feature latent space

In [None]:
# get feature representations
test_features, test_labels = gmvae.latent_features(test_loader, True)

In [None]:
# import TSNE from scikit-learn library
from sklearn.manifold import TSNE

# reduce dimensionality to 2D, we consider a subset of data because TSNE
# is a slow algorithm
tsne_features = TSNE(n_components=2).fit_transform(test_features[:1000,])

In [None]:
fig = plt.figure(figsize=(10, 6))

plt.scatter(tsne_features[:, 0], tsne_features[:, 1], c=test_labels[:tsne_features.shape[0]], marker='o',
            edgecolor='none', cmap=plt.cm.get_cmap('jet', 10), s = 10)
plt.grid(False)
plt.axis('off')
plt.colorbar()