# Analyse Transformer Results

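19th May - Notebook for analysing the transformer's space-group number and space-group string predictions: top-k accuracy on the test set, on subsets excluding the most common space groups, and against a frequency baseline.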

Accuracy

In [1]:
from typing import List
import pandas as pd

# Two-character lattice codes that score_bravais counts as "valid" when it
# reports the share of well-formed predictions. One row per leading letter.
VALID_BRAVAIS_LATTICES = [
    'aP',
    'mP', 'mS', 'mA', 'mB', 'mC',
    'oP', 'oS', 'oF', 'oI',
    'tP', 'tI',
    'hP', 'hR',
    'cP', 'cF', 'cI',
]


def get_rank(row, base, max_rank, tgt_column='target'):
    """Return the 1-based rank of the first prediction that equals the target.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) indexable by column name.
        base: Column-name prefix; candidates are read from ``base1``, ``base2``, ...
        max_rank: Highest rank (inclusive) to inspect.
        tgt_column: Name of the column holding the reference value.

    Returns:
        The smallest rank whose value equals the target, or 0 if none of the
        first ``max_rank`` candidates match.
    """
    matching_ranks = (
        rank
        for rank in range(1, max_rank + 1)
        if row[tgt_column] == row['{}{}'.format(base, rank)]
    )
    # 0 is the "not found within max_rank" sentinel.
    return next(matching_ranks, 0)

def read_targets(targets:str = 'targets.txt',
                 ignore_last_number: bool = False):
    """Read one label per line from a text file, removing the spaces between tokens.

    Each line is stripped, split on single spaces and re-joined without
    separators.

    Args:
        targets: Path of the file to read.
        ignore_last_number: When True, the final space-separated token of
            every line is dropped before joining.

    Returns:
        A list with one string per line of the file.
    """
    # Slice bound: -1 discards the trailing token, None keeps all of them.
    stop = -1 if ignore_last_number else None

    with open(targets, 'r') as handle:
        raw_lines = handle.readlines()

    return [''.join(raw.strip().split(' ')[:stop]) for raw in raw_lines]

def read_preds(preds:str = 'predictions.txt',
               beam_size:int = 5,
               ignore_last_number: bool = False):
    """Read beam-search output and regroup it by beam position.

    Line ``i`` of the file is assigned to beam slot ``i % beam_size``.
    Each line is stripped, split on single spaces and re-joined without
    separators.

    Args:
        preds: Path of the predictions file.
        beam_size: Number of consecutive lines that belong to one example.
        ignore_last_number: When True, the final space-separated token of
            every line is dropped before joining.

    Returns:
        A list of ``beam_size`` lists; element ``k`` holds the ``k``-th
        candidate of every example, in file order.
    """
    per_beam = [[] for _ in range(beam_size)]

    with open(preds, 'r') as handle:
        for line_no, raw in enumerate(handle.readlines()):
            tokens = raw.strip().split(' ')
            if ignore_last_number:
                tokens = tokens[:-1]
            per_beam[line_no % beam_size].append(''.join(tokens))

    return per_beam

def score_accuracy(targets:List, predictions:List, beam_size:int = 5):
    """Print cumulative top-k exact-match accuracy and return the per-row table.

    Args:
        targets: Reference labels, one per example.
        predictions: One list per beam position; ``predictions[i]`` becomes
            the ``prediction_{i+1}`` column.
        beam_size: Number of ranks to score (k runs from 1 to ``beam_size``).

    Returns:
        DataFrame with columns ``target``, ``prediction_1`` ..., and ``rank``
        (rank of the first matching prediction, 0 when none matches).
    """
    results = pd.DataFrame(targets)
    results.columns = ['target']

    for beam_rank, beam_preds in enumerate(predictions, start=1):
        results['prediction_{}'.format(beam_rank)] = beam_preds

    def _row_rank(row):
        return get_rank(row, 'prediction_', beam_size)

    results['rank'] = results.apply(_row_rank, axis=1)

    # Running total of rows whose rank is <= k, i.e. top-k hits.
    hits_so_far = 0
    for k in range(1, beam_size + 1):
        hits_so_far += (results['rank'] == k).sum()
        print('Top-{}: {:.1f}%'.format(k, hits_so_far / len(results) * 100))

    return results

def read_bravais(aflow_string):
    """Extract the two-character lattice code from a ``:``-separated label.

    The code is the first two characters of the second ``:``-separated field.

    Args:
        aflow_string: Label string (a target or a model prediction).

    Returns:
        The (up to) two-character code, or ``''`` when the string has no
        second field. The empty string is not in VALID_BRAVAIS_LATTICES, so a
        malformed prediction is counted as invalid instead of raising
        IndexError and aborting the whole scoring run.
    """
    fields = aflow_string.split(':')
    if len(fields) < 2:
        # Malformed input (no ':' separator): report "no lattice".
        return ''
    return fields[1][:2]

def score_bravais(targets:List, predictions:List, beam_size:int = 5):
    """Print cumulative top-k accuracy on the lattice code only.

    The lattice code of every target and prediction is extracted with
    ``read_bravais``; the target code distribution is printed first. For each
    k the report also shows the share of rank-k predictions (that beam
    position only, not cumulative) whose code is in VALID_BRAVAIS_LATTICES.

    Args:
        targets: Reference label strings, one per example.
        predictions: One list per beam position; ``predictions[i]`` becomes
            the ``prediction_{i+1}`` column.
        beam_size: Number of ranks to score.

    Returns:
        DataFrame with ``target``, ``bravais``, ``prediction_k``,
        ``bravais_k`` and ``rank`` columns (rank 0 means no match).
    """
    results = pd.DataFrame(targets)
    results.columns = ['target']
    results['bravais'] = results.target.apply(read_bravais)
    print(results.bravais.value_counts())

    for beam_rank, beam_preds in enumerate(predictions, start=1):
        pred_col = f'prediction_{beam_rank}'
        results[pred_col] = beam_preds
        results[f'bravais_{beam_rank}'] = results[pred_col].apply(read_bravais)

    def _row_rank(row):
        return get_rank(row, 'bravais_', beam_size, tgt_column='bravais')

    results['rank'] = results.apply(_row_rank, axis=1)

    hits_so_far = 0
    for k in range(1, beam_size + 1):
        hits_so_far += (results['rank'] == k).sum()
        is_valid = results[f'bravais_{k}'].isin(VALID_BRAVAIS_LATTICES)
        valid = int(is_valid.sum())
        print(f'Top-{k}: {100*hits_so_far/len(results):.1f}%, valid bravais lattices: {100*valid/len(results):.1f}%')

    return results

def score_predictions(beam_size:int = 5,
                      tgt_path:str = 'targets.txt',
                      pred_path:str = 'predictions.txt',
                      inds:List = None,
                      ignore_last_number: bool = False,
                      bravais:bool = False):
    """Score a beam-search prediction file against a target file.

    The prediction file is expected to hold ``beam_size`` consecutive lines
    per example. Targets are truncated to the number of examples for which a
    complete group of ``beam_size`` predictions exists, so a partially
    written prediction file can still be scored.

    Args:
        beam_size: Number of prediction lines per example.
        tgt_path: Path of the target file (one label per line).
        pred_path: Path of the prediction file.
        inds: Optional positions (into the truncated target list) selecting
            the subset of examples to score.
        ignore_last_number: Passed through to the readers; drops the last
            space-separated token of every line.
        bravais: When True, score with ``score_bravais`` instead of
            ``score_accuracy``.

    Returns:
        The per-row DataFrame produced by the selected scoring function.
    """
    # Read the prediction file once and derive the example count from it,
    # rather than re-reading the whole file just to count its lines.
    predictions = read_preds(pred_path, beam_size, ignore_last_number)

    # The last beam slot is filled last, so its length equals the number of
    # complete groups. Trimming every slot to that length keeps the columns
    # aligned when the file ends in the middle of a group.
    n_complete = len(predictions[-1]) if predictions else 0
    predictions = [beam[:n_complete] for beam in predictions]

    targets = read_targets(tgt_path, ignore_last_number)[:n_complete]

    if inds is not None:
        old_len = len(targets)
        targets = [targets[i] for i in inds]
        predictions = [[beam[i] for i in inds] for beam in predictions]
        print(f'Number of datapoints in test subset: {len(targets)} ({100*len(targets)/old_len:.2f}%)')

    else:
        print(f'Number of datapoints in test set: {len(targets)}')

    if bravais:
        test_df = score_bravais(targets, predictions, beam_size)
    else:
        test_df = score_accuracy(targets, predictions, beam_size)
    return test_df


### Spacegroup Number

In [30]:
# Space-group-number task: prediction and target files for the test split.
data_dir = '/home/wjm41/ml_physics/smi2wyk/data'
dataset = 'smi2spgnum'
transformer_preds = f'{data_dir}/{dataset}/pred_step_57500.txt'

targets = f'{data_dir}/{dataset}/tgt-test.csv'

# Top-1 ... top-10 accuracy over the whole test set.
df_test = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
)

Number of datapoints in test set: 33268
Top-1: 42.3%
Top-2: 66.7%
Top-3: 76.7%
Top-4: 84.6%
Top-5: 88.4%
Top-6: 90.4%
Top-7: 91.9%
Top-8: 93.0%
Top-9: 93.8%
Top-10: 94.4%


Let's look at the space group distribution:

In [31]:
import plotly.express as px

# One sector per distinct target value, sized by how often it occurs.
px.sunburst(data_frame=df_test, path=['target'])

In [32]:
# Row positions of every test example whose target is not space group 14.
is_not_14 = df_test['target'] != "14"
inds_not_14 = df_test.loc[is_not_14].index.values

# Re-score the transformer on that subset only.
df_not_14 = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
    inds=inds_not_14,
)

Number of datapoints in test subset: 20268 (60.92%)
Top-1: 11.6%
Top-2: 46.9%
Top-3: 61.9%
Top-4: 74.7%
Top-5: 81.0%
Top-6: 84.3%
Top-7: 86.7%
Top-8: 88.5%
Top-9: 89.8%
Top-10: 90.9%


In [51]:
# Space groups excluded from the "uncommon" subset.
common_spacegroups = ["14", "19", "4", "2", "61", "33"]

# NOTE(review): the saved output below reports 3420 rows (10.28%), while the
# same filter in the frequency-baseline section reports 5437 (16.34%). The
# execution counts are out of order, so `df_test` may have held a different
# table when one of the two cells ran - confirm with Restart & Run All.
is_uncommon = ~df_test['target'].isin(common_spacegroups)
inds_uncommon = df_test.loc[is_uncommon].index.values

df_uncommon = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
    inds=inds_uncommon,
)

Number of datapoints in test subset: 3420 (10.28%)
Top-1: 0.0%
Top-2: 0.2%
Top-3: 1.1%
Top-4: 1.9%
Top-5: 8.0%
Top-6: 13.5%
Top-7: 23.9%
Top-8: 32.8%
Top-9: 39.8%
Top-10: 46.0%


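Baseline - prediction by frequency: for every test example, predict the most frequent labels of the training set, ranked in descending order of frequency.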

In [36]:
from collections import Counter


def score_predictions_by_frequency(beam_size:int = 5,
                                    tgt_path:str = 'targets.txt',
                                    src_path:str = 'tgt-train.txt',
                                    inds:List = None,
                                    ignore_last_number: bool = False):
    """Score a frequency baseline against a target file.

    Every example receives the same ranked predictions: the ``beam_size``
    most frequent labels found in ``src_path``, most frequent first.

    Args:
        beam_size: Number of ranked baseline predictions per example.
        tgt_path: Path of the target file to score against.
        src_path: Path of the label file used to count label frequencies.
        inds: Optional positions into the target list selecting the subset
            of examples to score.
        ignore_last_number: Passed through to ``read_targets``; drops the
            last space-separated token of every line.

    Returns:
        The per-row DataFrame produced by ``score_accuracy``.
    """
    targets = read_targets(tgt_path, ignore_last_number)

    targets_in_training_set = read_targets(src_path, ignore_last_number)
    occurrence_count = Counter(targets_in_training_set)
    ranked_labels = [label for label, _ in occurrence_count.most_common(beam_size)]

    # With fewer distinct labels than beam_size, pad the remaining ranks with
    # None (never equal to a string target) so that every rank column still
    # exists and the top-k figures simply stop improving.
    ranked_labels += [None] * (beam_size - len(ranked_labels))

    if inds is not None:
        old_len = len(targets)
        targets = [targets[i] for i in inds]
        print(f'Number of datapoints in test subset: {len(targets)} ({100*len(targets)/old_len:.2f}%)')

    else:
        print(f'Number of datapoints in test set: {len(targets)}')

    # The baseline is constant per rank, so the columns can be built directly
    # at the final (possibly subset) length.
    predictions = [[label] * len(targets) for label in ranked_labels]

    test_df = score_accuracy(targets, predictions, beam_size)
    return test_df


In [46]:
# Space-group-number task: test targets plus the training targets that
# provide the label frequencies for the baseline.
data_dir = '/home/wjm41/ml_physics/smi2wyk/data'
dataset = 'smi2spgnum'
transformer_preds = f'{data_dir}/{dataset}/pred_step_57500.txt'

targets = f'{data_dir}/{dataset}/tgt-test.csv'
train_targets = f'{data_dir}/{dataset}/tgt-train.csv'

# Note: this rebinds `df_test` to the baseline's result table.
df_test = score_predictions_by_frequency(
    beam_size=10,
    tgt_path=targets,
    src_path=train_targets,
    ignore_last_number=False,
)

Number of datapoints in test set: 33268
Top-1: 39.1%
Top-2: 61.1%
Top-3: 74.0%
Top-4: 83.7%
Top-5: 87.9%
Top-6: 89.7%
Top-7: 91.1%
Top-8: 92.3%
Top-9: 93.0%
Top-10: 93.7%


In [47]:
# Frequency baseline on the subset selected by `inds_not_14`
# (defined in the transformer section above).
df_not_14 = score_predictions_by_frequency(
    beam_size=10,
    tgt_path=targets,
    src_path=train_targets,
    inds=inds_not_14,
    ignore_last_number=False,
)

Number of datapoints in test subset: 20268 (60.92%)
Top-1: 0.0%
Top-2: 36.1%
Top-3: 57.4%
Top-4: 73.2%
Top-5: 80.1%
Top-6: 83.1%
Top-7: 85.4%
Top-8: 87.3%
Top-9: 88.5%
Top-10: 89.6%


In [49]:
# Space groups excluded from the "uncommon" subset.
common_spacegroups = ["14", "19", "4", "2", "61", "33"]

# NOTE(review): the saved output below reports 5437 rows (16.34%), while the
# same filter in the transformer section reports 3420 (10.28%). The execution
# counts are out of order, so `df_test` may have held a different table when
# one of the two cells ran - confirm with Restart & Run All.
is_uncommon = ~df_test['target'].isin(common_spacegroups)
inds_uncommon = df_test.loc[is_uncommon].index.values

df_uncommon = score_predictions_by_frequency(
    beam_size=10,
    tgt_path=targets,
    src_path=train_targets,
    inds=inds_uncommon,
    ignore_last_number=False,
)

Number of datapoints in test subset: 5437 (16.34%)
Top-1: 0.0%
Top-2: 0.0%
Top-3: 0.0%
Top-4: 0.0%
Top-5: 25.8%
Top-6: 37.1%
Top-7: 45.7%
Top-8: 52.6%
Top-9: 57.1%
Top-10: 61.2%


### Spacegroup String (untokenized)

In [23]:
# Untokenized space-group-string task: prediction and target files.
data_dir = '/home/wjm41/ml_physics/smi2wyk/data'
dataset = 'smi2spgstr'
transformer_preds = f'{data_dir}/{dataset}/pred_step_57500.txt'

targets = f'{data_dir}/{dataset}/tgt-test.csv'

# Top-1 ... top-10 accuracy over the whole test set.
df_test = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
)

Number of datapoints in test set: 33268
Top-1: 31.0%
Top-2: 52.3%
Top-3: 67.6%
Top-4: 76.5%
Top-5: 83.4%
Top-6: 87.1%
Top-7: 88.9%
Top-8: 90.5%
Top-9: 91.6%
Top-10: 92.5%


In [27]:
# The six most frequent target strings and their counts.
df_test['target'].value_counts().head(6)

P21/c      7711
P-1        7316
P21/n      4832
P212121    4310
P21        3196
Pbca       1352
Name: target, dtype: int64

In [28]:
# The six most frequent target strings (see the value counts above).
common_spacegroups = ["P21/c", "P-1", "P21/n", "P212121", "P21", "Pbca"]

# Row positions of every test example outside that set.
is_uncommon = ~df_test['target'].isin(common_spacegroups)
inds_uncommon = df_test.loc[is_uncommon].index.values

df_uncommon = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
    inds=inds_uncommon,
)

Number of datapoints in test subset: 4551 (13.68%)
Top-1: 0.0%
Top-2: 0.0%
Top-3: 0.9%
Top-4: 2.1%
Top-5: 3.5%
Top-6: 8.2%
Top-7: 20.5%
Top-8: 31.1%
Top-9: 38.8%
Top-10: 45.1%


### Spacegroup String (tokenized)

In [20]:
# Tokenized space-group-string task: prediction and target files.
data_dir = '/home/wjm41/ml_physics/smi2wyk/data'
dataset = 'smi2spgstrtok'
transformer_preds = f'{data_dir}/{dataset}/pred_step_57500.txt'

targets = f'{data_dir}/{dataset}/tgt-test.csv'

# Top-1 ... top-10 accuracy over the whole test set.
df_test = score_predictions(
    beam_size=10,
    tgt_path=targets,
    pred_path=transformer_preds,
    ignore_last_number=False,
)

Number of datapoints in test set: 33268
Top-1: 30.8%
Top-2: 51.2%
Top-3: 65.6%
Top-4: 73.0%
Top-5: 77.2%
Top-6: 79.3%
Top-7: 80.7%
Top-8: 81.9%
Top-9: 82.7%
Top-10: 83.4%
