In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import warnings 
warnings.filterwarnings('ignore')
from util import *
from models import *
from classification import *
from sklearn.manifold import TSNE

In [None]:
def arg_parse(argv=None):
    """Build the experiment configuration from command-line style options.

    Args:
        argv: Optional sequence of option strings, e.g.
            ``['--num_epochs', '5', '--train']``. When ``None`` (the default)
            an empty list is parsed, so every option takes its default value
            and the Jupyter kernel's own ``sys.argv`` is never consulted.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    # Dataset selection and file-system locations.
    parser.add_argument('--dataset', type=str, default='Wallpaper', help='Dataset to use (Taiji or Wallpaper)')
    parser.add_argument('--save_dir', type=str, default='results', help='Directory to save results')
    parser.add_argument('--data_root', type=str, default='data', help='Root directory containing the datasets')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory to save logs')

    # Optimisation settings.
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs to train')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda or cpu)')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('--seed', type=int, default=2023, help='Random seed')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--log_interval', type=int, default=1, help='Print loss every log_interval epochs, feel free to change')

    # Run-mode and model-variant switches (all default to False).
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--save_model', action='store_true', help='Save the model')
    parser.add_argument('--baseline', action='store_true', help='Baseline model configuration')
    parser.add_argument('--model1', action='store_true', help='Model-1 configuration, built on top of baseline')
    parser.add_argument('--model2', action='store_true', help='Model-2 configuration, built on top of baseline')

    # Dataset-specific options.
    parser.add_argument('--num_subs', type=int, default=10, help='Number of subjects to train and test on')
    parser.add_argument('--fp_size', type=str, default='lod4', help='Size of the fingerprint to use (lod4 or full)')
    parser.add_argument('--img_size', type=int, default=128, help='Size of image to be resized to')
    parser.add_argument('--test_set', type=str, default='test', help='Test set to use (test or test_challenge)')
    parser.add_argument('--aug_train', action='store_true', help='Use augmented training data')

    # An explicit list is always passed so argparse never falls back to
    # sys.argv, which inside a notebook holds the kernel's launch arguments.
    return parser.parse_args(args=[] if argv is None else list(argv))


# Notebook-wide configuration: all defaults.
args = arg_parse()

# Part 1: Taiji form classification with an MLP

In [None]:
# --- Part 1 setup: result buffers, output directory, device and seeds ---

num_subs = args.num_subs
num_forms = 46  # number of output classes; used as the MLP's output_dim below

# One slot per subject for overall accuracy, one row per subject for
# per-class statistics.
sub_train_acc = np.zeros(num_subs)
sub_class_train = np.zeros((num_subs, num_forms))
sub_test_acc = np.zeros(num_subs)
sub_class_test = np.zeros((num_subs, num_forms))
overall_train_mat = np.zeros((num_forms, 1))
overall_test_mat = np.zeros((num_forms, 1))

# exist_ok=True makes the cell idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
taiji_save_dir = os.path.join(args.save_dir, 'Taiji', args.fp_size)
os.makedirs(taiji_save_dir, exist_ok=True)

# Use CUDA only when it was requested AND is actually available; otherwise CPU.
use_cuda = args.device == 'cuda' and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Seed both torch and numpy.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [None]:
# Train and evaluate the MLP on a single subject.
subject = 1  # 1-based subject id; results are stored at index subject - 1
print('\n\nTraining subject: {}'.format(subject))

# Honour --data_root (default 'data') instead of a hard-coded path so this
# part reads from the same root as the rest of the notebook.
train_data = TaijiData(data_dir=args.data_root, subject=subject, split='train')
test_data = TaijiData(data_dir=args.data_root, subject=subject, split='test')
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

model = MLP(input_dim=train_data.data_dim, hidden_dim=1024, output_dim=num_forms).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

# Train + test the model
model, train_losses, per_epoch_train_acc, train_preds, train_targets \
                = train(model, train_loader, optimizer, criterion, args.num_epochs, args.log_interval, device)
test_loss, test_acc, test_pred, test_targets = test(model, test_loader, device, criterion)

# Record the final-epoch train accuracy and the test accuracy for this subject.
sub_train_acc[subject - 1] = per_epoch_train_acc[-1]
sub_test_acc[subject - 1] = test_acc

# Print accs to three decimal places
print(f'Subject {subject} Train Accuracy: {per_epoch_train_acc[-1]*100:.3f}')
print(f'Subject {subject} Test Accuracy: {test_acc*100:.3f}')

# Part 2: Wallpaper pattern classification with a CNN

In [None]:
# --- Part 2: train and evaluate a CNN on the Wallpaper dataset ---

num_classes = 17  # size of the CNN's output layer and of the per-class stats

# Use CUDA only when it was requested AND is actually available; otherwise CPU.
use_cuda = args.device == 'cuda' and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

data_root = os.path.join(args.data_root, 'Wallpaper')
# exist_ok=True makes the cell idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(os.path.join(args.save_dir, 'Wallpaper', args.test_set), exist_ok=True)

torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Square resize, collapse to a single grayscale channel, then shift/scale the
# tensor with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((args.img_size, args.img_size)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])
train_dataset = ImageFolder(os.path.join(data_root, 'train'), transform=transform)
test_dataset = ImageFolder(os.path.join(data_root, args.test_set), transform=transform)
# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

print(f"Training on {len(train_dataset)} images, testing on {len(test_dataset)} images.")

# input_channels=1 matches the single channel produced by Grayscale() above.
model = CNN(input_channels=1, img_size=args.img_size, num_classes=num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

model, per_epoch_loss, per_epoch_acc, train_preds, train_targets = train(
    model, train_loader, optimizer, criterion, args.num_epochs, args.log_interval, device)
test_loss, test_acc, test_preds, test_targets = test(model, test_loader, device, criterion)

# NOTE(review): get_stats is defined elsewhere; its first return value is
# presumably the per-class accuracies and the second a confusion-style
# matrix - confirm against its definition.
classes_train, overall_train_mat = get_stats(train_preds, train_targets, num_classes)
classes_test, overall_test_mat = get_stats(test_preds, test_targets, num_classes)

# The class count is interpolated so the message cannot drift from num_classes.
print(f"Training Accuracy Standard Deviation over {num_classes} Classes: {np.std(classes_train):.5f}")
print(f"Test Accuracy Standard Deviation over {num_classes} Classes: {np.std(classes_test):.5f}")


print(f'\n\nTrain accuracy: {per_epoch_acc[-1]*100:.3f}')
print(f'Test accuracy: {test_acc*100:.3f}')

In [None]:
# Show the channel-averaged activation of one early layer for a handful of
# test images.
model.eval()

# Class names in the order the titles index them (index = class id).
# "P3M1" replaces the original "P2M1", which is not a wallpaper group.
# NOTE(review): assumes this order matches the dataset's class indices
# (ImageFolder sorts class folders alphabetically) - confirm folder names.
labels = ["CM", "CMM", "P1", "P2", "P3", "P31M", "P3M1", "P4", "P4G", "P4M", "P6", "P6M", "PG", "PGG", "PM", "PMG", "PMM"]

sample_stride = 200  # take every 200th test image; presumably one per class - TODO confirm
layer_idx = 1        # which entry of the per-layer outputs is plotted

outputs = []         # one list of per-layer activations per sampled image
sample_targets = []  # class index of each sampled image, used for the title

# Inference only: no_grad avoids building autograd graphs for these passes.
with torch.no_grad():
    for i in range(0, len(test_loader.dataset), sample_stride):
        sample_img, sample_target = test_loader.dataset[i]
        img = torch.unsqueeze(sample_img.to(device), 0)  # add a batch dimension
        layer_out = []
        # Feed the image through the model's first child module layer by
        # layer, keeping every intermediate output.
        for layer in list(model.children())[0]:
            img = layer(img)
            layer_out.append(img)
        outputs.append(layer_out)
        sample_targets.append(sample_target)

fig = plt.figure(figsize = (12, 12))
for i in range(len(outputs)):
    # Drop the batch dimension, then average over the leading (channel) axis.
    img = torch.mean(torch.squeeze(outputs[i][layer_idx], 0), 0).cpu().detach().numpy()
    plt.subplot(4, 5, i+1)
    plt.imshow(img, cmap = "gray")
    # Title by the sample's own class index rather than its loop position.
    plt.title(labels[sample_targets[i]])
# plt.savefig("layer.png", bbox_inches='tight')

In [None]:
# Pass the resolved torch device, not the raw --device string: args.device
# defaults to 'cuda' even on machines where the model was placed on the CPU.
# This matches how train() and test() are called.
visualize_tsne(model, test_loader, device, "tsne.png")