In [24]:
import os
import torch
import numpy as np
import pandas as pd
import sys
# os.chdir() does not expand '~' by itself, so resolve the home directory first;
# every relative path used later in the notebook is resolved against this directory.
os.chdir(os.path.expanduser('~/projects/LLaVA'))
current_path = os.getcwd()
# Make the modules under data-engine importable (presumably where the `utils`
# and `models` modules imported below live). The membership check keeps
# sys.path free of duplicate entries when this cell is re-run.
data_engine_path = current_path + '/data-engine'
if data_engine_path not in sys.path:
    sys.path.append(data_engine_path)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from utils import set_env
# Project helper called with a fixed seed; presumably seeds the RNGs used
# below -- confirm in utils.set_env.
set_env(seed=1234)

from torch.utils.data import Dataset, DataLoader
from utils import ChartFeatureDataset
from models import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from utils import FocalLoss

In [25]:
def train(data, pretrained_features, feature_column_name, criterion, epochs=20, batch_size=2048,
          lr=2e-3, num_workers=8, one_hot_targets=False):
    """Train an MLPClassifier on precomputed features and report test metrics.

    The rows of ``data`` and ``pretrained_features`` are split together into
    80% train / 20% test (fixed ``random_state=1234``), wrapped in
    ``ChartFeatureDataset`` and fed to an Adam-optimised ``MLPClassifier``.
    After every epoch the mean training loss plus train/test accuracy are
    printed; after the last epoch a classification report for the test split
    is printed.

    Args:
        data: Table indexed by column name (``data[feature_column_name]`` must
            provide ``.unique()``); holds the label column.
        pretrained_features: 2-D array of features; ``shape[1]`` is used as the
            classifier input dimension. Assumed to be row-aligned with
            ``data`` -- TODO confirm against the feature-extraction step.
        feature_column_name: Name of the label column in ``data``.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            tensor that supports ``.backward()`` and ``.item()``.
        epochs: Number of passes over the training split (must be >= 1).
        batch_size: Batch size used for both data loaders.
        lr: Adam learning rate.
        num_workers: Worker processes for both data loaders.
        one_hot_targets: When True, labels are converted to float one-hot
            vectors before being passed to ``criterion`` (e.g. for
            ``nn.MSELoss``); when False the raw labels are passed.

    Returns:
        The trained classifier (still on ``device``).

    Raises:
        ValueError: If ``epochs`` is smaller than 1. Without this check the
            final report would reference predictions that were never computed.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Number of output units = number of distinct label values.
    # NOTE(review): assumes labels are encoded as 0..num_classes-1 -- confirm
    # in ChartFeatureDataset.
    num_classes = len(data[feature_column_name].unique())
    train_data, test_data, train_features, test_features = train_test_split(
        data, pretrained_features, test_size=0.2, random_state=1234)
    train_dataset = ChartFeatureDataset(train_data, train_features, feature_column_name)
    test_dataset = ChartFeatureDataset(test_data, test_features, feature_column_name)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    classifier = MLPClassifier(
        input_dim=pretrained_features.shape[1], num_classes=num_classes).to(device)

    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)

    for epoch in tqdm(range(epochs)):
        # ---- training pass ----
        classifier.train()
        total_loss = 0
        train_preds, train_labels_list = [], []
        for features, labels in train_dataloader:
            labels = labels.to(device)
            features = features.to(device)
            optimizer.zero_grad()

            outputs = classifier(features)
            predicted = outputs.argmax(dim=1)
            if one_hot_targets:
                targets = F.one_hot(labels, num_classes=num_classes).float()
            else:
                targets = labels
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Predictions are taken from the forward pass made before this
            # step's weight update, so train accuracy lags slightly.
            train_preds.extend(predicted.cpu().tolist())
            train_labels_list.extend(labels.cpu().tolist())

        train_accuracy = accuracy_score(train_labels_list, train_preds)

        # ---- evaluation pass on the held-out split ----
        classifier.eval()
        test_preds, test_labels_list = [], []
        with torch.no_grad():
            for features, labels in test_dataloader:
                labels = labels.to(device)
                features = features.to(device)

                outputs = classifier(features)
                predicted = outputs.argmax(dim=1)
                test_preds.extend(predicted.cpu().tolist())
                test_labels_list.extend(labels.cpu().tolist())
        test_accuracy = accuracy_score(test_labels_list, test_preds)
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}, Train Accuracy: {train_accuracy}, Test Accuracy: {test_accuracy}")

    # Final report uses the predictions from the last epoch's evaluation pass.
    print("Classification Report on Test Data:")
    print(classification_report(test_labels_list, test_preds, digits=4))
    return classifier

In [26]:
# All paths below are relative and therefore resolved against the working
# directory set by os.chdir() in the first cell.
# Annotation CSV; its columns supply the labels that train() selects via
# feature_column_name.
data = pd.read_csv('data-engine/precomputed_features/classification_annotation_30k_5_attributes.csv')
# Precomputed feature matrices, one per backbone named in the file name.
# NOTE(review): train() splits `data` and a feature matrix together, so each
# matrix is assumed to be row-aligned with `data` -- TODO confirm.
clip_features = np.load('data-engine/precomputed_features/annotated-dataset-clip-feature.npy')
resnet_features = np.load('data-engine/precomputed_features/annotated-dataset-resnet-feature.npy')
convnext_features = np.load('data-engine/precomputed_features/annotated-dataset-convnext-feature.npy')
# Default label column; the experiment cells below assign it again before use.
feature_column_name = 'number-annotation' # label, Trend, Layout, number annotation, grouping

## Focal loss

In [9]:
# Experiment: FocalLoss (default constructor arguments) on the ConvNeXt
# features, predicting the 'number-annotation' column.
# NOTE(review): `train` is defined again further down in the MSE section; on
# an out-of-order run this call would pick up that version, which one-hot
# encodes the targets -- make sure the intended definition is in scope.
feature_column_name = 'number-annotation'
criterion = FocalLoss()
features = convnext_features
classifier = train(data, features, feature_column_name, criterion)

  5%|▌         | 1/20 [00:04<01:23,  4.40s/it]

Epoch 1, Loss: 0.4237421475923978, Train Accuracy: 0.7919798556825015, Test Accuracy: 0.8242633794347565


 10%|█         | 2/20 [00:07<01:03,  3.55s/it]

Epoch 2, Loss: 0.2935067048439613, Train Accuracy: 0.8647399278412508, Test Accuracy: 0.8735718580877931


 15%|█▌        | 3/20 [00:10<00:54,  3.18s/it]

Epoch 3, Loss: 0.24952052189753607, Train Accuracy: 0.8904840649428744, Test Accuracy: 0.9054419723391461


 20%|██        | 4/20 [00:13<00:51,  3.20s/it]

Epoch 4, Loss: 0.20247919055131766, Train Accuracy: 0.926037282020445, Test Accuracy: 0.9305472038484667


 25%|██▌       | 5/20 [00:16<00:47,  3.14s/it]

Epoch 5, Loss: 0.16997661957373986, Train Accuracy: 0.9393415514131088, Test Accuracy: 0.9380637402285027


 30%|███       | 6/20 [00:19<00:44,  3.18s/it]

Epoch 6, Loss: 0.15730095024292284, Train Accuracy: 0.9446782922429344, Test Accuracy: 0.9431749849669272


 35%|███▌      | 7/20 [00:22<00:41,  3.22s/it]

Epoch 7, Loss: 0.14924879372119904, Train Accuracy: 0.9480983162958508, Test Accuracy: 0.9455802766085388


 40%|████      | 8/20 [00:25<00:36,  3.05s/it]

Epoch 8, Loss: 0.1422061977478174, Train Accuracy: 0.9499774503908599, Test Accuracy: 0.9496392062537583


 45%|████▌     | 9/20 [00:28<00:32,  2.97s/it]

Epoch 9, Loss: 0.1362383388555967, Train Accuracy: 0.952720986169573, Test Accuracy: 0.9491882140709561


 50%|█████     | 10/20 [00:31<00:31,  3.12s/it]

Epoch 10, Loss: 0.1318854454618234, Train Accuracy: 0.9547880336740829, Test Accuracy: 0.9514431749849669


 55%|█████▌    | 11/20 [00:35<00:29,  3.25s/it]

Epoch 11, Loss: 0.1277716406262838, Train Accuracy: 0.9558027660853878, Test Accuracy: 0.9526458208057726


 60%|██████    | 12/20 [00:38<00:25,  3.14s/it]

Epoch 12, Loss: 0.12767498252483514, Train Accuracy: 0.9566295850871919, Test Accuracy: 0.9530968129885748


 65%|██████▌   | 13/20 [00:41<00:22,  3.20s/it]

Epoch 13, Loss: 0.12243547920997326, Train Accuracy: 0.9575691521346963, Test Accuracy: 0.9523451593505713


 70%|███████   | 14/20 [00:44<00:18,  3.14s/it]

Epoch 14, Loss: 0.1190596205683855, Train Accuracy: 0.9590724594107035, Test Accuracy: 0.9514431749849669


 75%|███████▌  | 15/20 [00:48<00:16,  3.23s/it]

Epoch 15, Loss: 0.11775640226327456, Train Accuracy: 0.9590724594107035, Test Accuracy: 0.9518941671677691


 80%|████████  | 16/20 [00:51<00:12,  3.23s/it]

Epoch 16, Loss: 0.1137864005107146, Train Accuracy: 0.9614777510523151, Test Accuracy: 0.9546001202645821


 85%|████████▌ | 17/20 [00:54<00:09,  3.16s/it]

Epoch 17, Loss: 0.1115338745025488, Train Accuracy: 0.9620414912808178, Test Accuracy: 0.9496392062537583


 90%|█████████ | 18/20 [00:57<00:06,  3.17s/it]

Epoch 18, Loss: 0.10994671285152435, Train Accuracy: 0.9620039085989176, Test Accuracy: 0.9499398677089597


 95%|█████████▌| 19/20 [01:00<00:03,  3.15s/it]

Epoch 19, Loss: 0.10841299536136481, Train Accuracy: 0.962079073962718, Test Accuracy: 0.9559530968129886


100%|██████████| 20/20 [01:03<00:00,  3.18s/it]

Epoch 20, Loss: 0.10502622047295937, Train Accuracy: 0.9640333734215274, Test Accuracy: 0.9553517739025857
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9748    0.9707    0.9728      5470
           1     0.8672    0.8841    0.8756      1182

    accuracy                         0.9554      6652
   macro avg     0.9210    0.9274    0.9242      6652
weighted avg     0.9557    0.9554    0.9555      6652






In [27]:
# Experiment: FocalLoss (default constructor arguments) on the CLIP features,
# predicting the 'number-annotation' column. Overwrites the module-level
# `criterion`, `features` and `classifier` from any previous experiment cell.
feature_column_name = 'number-annotation'
criterion = FocalLoss()
features = clip_features
classifier = train(data, features, feature_column_name, criterion)

  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:02<00:51,  2.71s/it]

Epoch 1, Loss: 0.41647157302269566, Train Accuracy: 0.8117483463619964, Test Accuracy: 0.8377931449188214


 10%|█         | 2/20 [00:05<00:51,  2.89s/it]

Epoch 2, Loss: 0.30598413714995754, Train Accuracy: 0.8613574864702345, Test Accuracy: 0.8858989777510523


 15%|█▌        | 3/20 [00:08<00:49,  2.93s/it]

Epoch 3, Loss: 0.25622966770942396, Train Accuracy: 0.9032245941070355, Test Accuracy: 0.9036380036079374


 20%|██        | 4/20 [00:11<00:45,  2.84s/it]

Epoch 4, Loss: 0.2183984609750601, Train Accuracy: 0.9164161154539988, Test Accuracy: 0.9183704149128081


 25%|██▌       | 5/20 [00:14<00:41,  2.76s/it]

Epoch 5, Loss: 0.19433384675246018, Train Accuracy: 0.9270520144317499, Test Accuracy: 0.9254359591100421


 30%|███       | 6/20 [00:16<00:38,  2.72s/it]

Epoch 6, Loss: 0.18114982889248774, Train Accuracy: 0.9327645820805772, Test Accuracy: 0.9278412507516537


 35%|███▌      | 7/20 [00:19<00:34,  2.69s/it]

Epoch 7, Loss: 0.17192353308200836, Train Accuracy: 0.9383644016837042, Test Accuracy: 0.930847865303668


 40%|████      | 8/20 [00:22<00:32,  2.71s/it]

Epoch 8, Loss: 0.16521450189443734, Train Accuracy: 0.9415589296452195, Test Accuracy: 0.9329524954900782


 45%|████▌     | 9/20 [00:24<00:30,  2.74s/it]

Epoch 9, Loss: 0.16126694472936484, Train Accuracy: 0.9427991581479255, Test Accuracy: 0.934155141310884


 50%|█████     | 10/20 [00:27<00:27,  2.78s/it]

Epoch 10, Loss: 0.15561679234871498, Train Accuracy: 0.9452796151533374, Test Accuracy: 0.934155141310884


 55%|█████▌    | 11/20 [00:30<00:24,  2.75s/it]

Epoch 11, Loss: 0.15108836270295656, Train Accuracy: 0.9471211665664462, Test Accuracy: 0.9359591100420926


 60%|██████    | 12/20 [00:33<00:21,  2.73s/it]

Epoch 12, Loss: 0.1471388809956037, Train Accuracy: 0.9487748045700541, Test Accuracy: 0.9382140709561034


 65%|██████▌   | 13/20 [00:35<00:18,  2.70s/it]

Epoch 13, Loss: 0.14384164145359626, Train Accuracy: 0.950616355983163, Test Accuracy: 0.9386650631389056


 70%|███████   | 14/20 [00:37<00:15,  2.52s/it]

Epoch 14, Loss: 0.14015444482748324, Train Accuracy: 0.951668671076368, Test Accuracy: 0.9376127480457005


 75%|███████▌  | 15/20 [00:40<00:12,  2.44s/it]

Epoch 15, Loss: 0.1383003042294429, Train Accuracy: 0.9528337342152736, Test Accuracy: 0.9395670475045099


 80%|████████  | 16/20 [00:42<00:09,  2.35s/it]

Epoch 16, Loss: 0.1339765592263295, Train Accuracy: 0.9538860493084786, Test Accuracy: 0.9407696933253157


 85%|████████▌ | 17/20 [00:44<00:06,  2.30s/it]

Epoch 17, Loss: 0.13056842925456855, Train Accuracy: 0.9563665063138905, Test Accuracy: 0.938965724594107


 90%|█████████ | 18/20 [00:46<00:04,  2.30s/it]

Epoch 18, Loss: 0.12772534386469767, Train Accuracy: 0.9572684906794949, Test Accuracy: 0.940619362597715


 95%|█████████▌| 19/20 [00:49<00:02,  2.40s/it]

Epoch 19, Loss: 0.12552603162251985, Train Accuracy: 0.9580953096812989, Test Accuracy: 0.9407696933253157


100%|██████████| 20/20 [00:51<00:00,  2.59s/it]

Epoch 20, Loss: 0.12237609464388627, Train Accuracy: 0.9585838845460012, Test Accuracy: 0.9409200240529164
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9591    0.9695    0.9643      5470
           1     0.8513    0.8088    0.8295      1182

    accuracy                         0.9409      6652
   macro avg     0.9052    0.8891    0.8969      6652
weighted avg     0.9400    0.9409    0.9403      6652






In [28]:
# Experiment: FocalLoss (default constructor arguments) on the ResNet features,
# predicting the 'number-annotation' column. Overwrites the module-level
# `criterion`, `features` and `classifier` from any previous experiment cell.
feature_column_name = 'number-annotation'
criterion = FocalLoss()
features = resnet_features
classifier = train(data, features, feature_column_name, criterion)

  5%|▌         | 1/20 [00:03<00:57,  3.00s/it]

Epoch 1, Loss: 0.6761453449726105, Train Accuracy: 0.8149052916416115, Test Accuracy: 0.8400481058328322


 10%|█         | 2/20 [00:05<00:51,  2.89s/it]

Epoch 2, Loss: 0.3781271829054906, Train Accuracy: 0.8507215874924835, Test Accuracy: 0.8538785327720986


 15%|█▌        | 3/20 [00:08<00:46,  2.76s/it]

Epoch 3, Loss: 0.3352953058022719, Train Accuracy: 0.8674834636199639, Test Accuracy: 0.8767288033674083


 20%|██        | 4/20 [00:11<00:44,  2.77s/it]

Epoch 4, Loss: 0.3101224647118495, Train Accuracy: 0.8780817799158148, Test Accuracy: 0.8792844257366206


 25%|██▌       | 5/20 [00:13<00:39,  2.64s/it]

Epoch 5, Loss: 0.2926597916162931, Train Accuracy: 0.8836815995189417, Test Accuracy: 0.8831930246542393


 30%|███       | 6/20 [00:16<00:36,  2.59s/it]

Epoch 6, Loss: 0.27810933039738583, Train Accuracy: 0.8892062537582682, Test Accuracy: 0.8883042693926638


 35%|███▌      | 7/20 [00:18<00:33,  2.56s/it]

Epoch 7, Loss: 0.26674846273202163, Train Accuracy: 0.8924759470835839, Test Accuracy: 0.8943174984966927


 40%|████      | 8/20 [00:21<00:30,  2.57s/it]

Epoch 8, Loss: 0.25842749843230617, Train Accuracy: 0.8953322309079976, Test Accuracy: 0.8998797354179194


 45%|████▌     | 9/20 [00:23<00:28,  2.58s/it]

Epoch 9, Loss: 0.24867324874951288, Train Accuracy: 0.8995414912808178, Test Accuracy: 0.900631389055923


 50%|█████     | 10/20 [00:26<00:25,  2.59s/it]

Epoch 10, Loss: 0.23929487856534812, Train Accuracy: 0.903750751653638, Test Accuracy: 0.9070956103427541


 55%|█████▌    | 11/20 [00:29<00:24,  2.68s/it]

Epoch 11, Loss: 0.23164977362522712, Train Accuracy: 0.9079975947083584, Test Accuracy: 0.9093505712567649


 60%|██████    | 12/20 [00:31<00:21,  2.67s/it]

Epoch 12, Loss: 0.22470635519577906, Train Accuracy: 0.9107411304870715, Test Accuracy: 0.9098015634395671


 65%|██████▌   | 13/20 [00:34<00:18,  2.61s/it]

Epoch 13, Loss: 0.21813304951557747, Train Accuracy: 0.9137853277209862, Test Accuracy: 0.9064942874323512


 70%|███████   | 14/20 [00:37<00:15,  2.63s/it]

Epoch 14, Loss: 0.21698729350016668, Train Accuracy: 0.9137853277209862, Test Accuracy: 0.9081479254359591


 75%|███████▌  | 15/20 [00:40<00:13,  2.72s/it]

Epoch 15, Loss: 0.20998305082321167, Train Accuracy: 0.9174308478653037, Test Accuracy: 0.9155141310883945


 80%|████████  | 16/20 [00:43<00:11,  2.80s/it]

Epoch 16, Loss: 0.20777811797765586, Train Accuracy: 0.91874624173181, Test Accuracy: 0.9102525556223692


 85%|████████▌ | 17/20 [00:45<00:08,  2.86s/it]

Epoch 17, Loss: 0.2062671837898401, Train Accuracy: 0.918332832230908, Test Accuracy: 0.9153638003607938


 90%|█████████ | 18/20 [00:49<00:05,  2.93s/it]

Epoch 18, Loss: 0.2019736778277617, Train Accuracy: 0.920625375826819, Test Accuracy: 0.9189717378232111


 95%|█████████▌| 19/20 [00:51<00:02,  2.86s/it]

Epoch 19, Loss: 0.19795447473342603, Train Accuracy: 0.9232561635598316, Test Accuracy: 0.921677690920024


100%|██████████| 20/20 [00:54<00:00,  2.73s/it]

Epoch 20, Loss: 0.19450021707094634, Train Accuracy: 0.925210463018641, Test Accuracy: 0.9174684305472038
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9269    0.9766    0.9511      5470
           1     0.8560    0.6438    0.7349      1182

    accuracy                         0.9175      6652
   macro avg     0.8915    0.8102    0.8430      6652
weighted avg     0.9143    0.9175    0.9127      6652






## MSE loss

In [30]:
def train(data, pretrained_features, feature_column_name, criterion, epochs=20, batch_size=2048,
          lr=1e-3, num_workers=8):
    """Train an MLPClassifier on precomputed features using one-hot targets.

    NOTE: this definition replaces the ``train`` defined earlier in the
    notebook once this cell has run. Unlike that version, labels are always
    converted to float one-hot vectors before being passed to ``criterion``
    (as needed by regression-style losses such as ``nn.MSELoss``), and the
    default learning rate is 1e-3.

    The rows of ``data`` and ``pretrained_features`` are split together into
    80% train / 20% test (fixed ``random_state=1234``). After every epoch the
    mean training loss plus train/test accuracy are printed; after the last
    epoch a classification report for the test split is printed.

    Args:
        data: Table indexed by column name (``data[feature_column_name]`` must
            provide ``.unique()``); holds the label column.
        pretrained_features: 2-D array of features; ``shape[1]`` is used as the
            classifier input dimension. Assumed to be row-aligned with
            ``data`` -- TODO confirm against the feature-extraction step.
        feature_column_name: Name of the label column in ``data``.
        criterion: Callable ``criterion(outputs, one_hot_targets)`` returning a
            scalar tensor that supports ``.backward()`` and ``.item()``.
        epochs: Number of passes over the training split (must be >= 1).
        batch_size: Batch size used for both data loaders.
        lr: Adam learning rate.
        num_workers: Worker processes for both data loaders.

    Returns:
        The trained classifier (still on ``device``).

    Raises:
        ValueError: If ``epochs`` is smaller than 1. Without this check the
            final report would reference predictions that were never computed.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Number of output units = number of distinct label values.
    # NOTE(review): F.one_hot below requires integer labels in
    # 0..num_classes-1 -- confirm that ChartFeatureDataset yields them so.
    num_classes = len(data[feature_column_name].unique())

    train_data, test_data, train_features, test_features = train_test_split(
        data, pretrained_features, test_size=0.2, random_state=1234)
    train_dataset = ChartFeatureDataset(train_data, train_features, feature_column_name)
    test_dataset = ChartFeatureDataset(test_data, test_features, feature_column_name)
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    classifier = MLPClassifier(
        input_dim=pretrained_features.shape[1], num_classes=num_classes).to(device)

    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)

    for epoch in tqdm(range(epochs)):
        # ---- training pass ----
        classifier.train()
        total_loss = 0
        train_preds, train_labels_list = [], []
        for features, labels in train_dataloader:
            labels = labels.to(device)
            features = features.to(device)
            optimizer.zero_grad()

            outputs = classifier(features)
            predicted = outputs.argmax(dim=1)
            # The loss compares raw classifier outputs with one-hot targets.
            loss = criterion(outputs, F.one_hot(labels, num_classes=num_classes).float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_preds.extend(predicted.cpu().tolist())
            train_labels_list.extend(labels.cpu().tolist())

        train_accuracy = accuracy_score(train_labels_list, train_preds)

        # ---- evaluation pass on the held-out split ----
        classifier.eval()
        test_preds, test_labels_list = [], []
        with torch.no_grad():
            for features, labels in test_dataloader:
                labels = labels.to(device)
                features = features.to(device)

                outputs = classifier(features)
                predicted = outputs.argmax(dim=1)
                test_preds.extend(predicted.cpu().tolist())
                test_labels_list.extend(labels.cpu().tolist())
        test_accuracy = accuracy_score(test_labels_list, test_preds)
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}, Train Accuracy: {train_accuracy}, Test Accuracy: {test_accuracy}")

    # Final report uses the predictions from the last epoch's evaluation pass.
    print("Classification Report on Test Data:")
    print(classification_report(test_labels_list, test_preds, digits=4))
    return classifier

In [18]:
# Experiment: MSE loss on the ConvNeXt features. Relies on the `train`
# definition from this section, which one-hot encodes the labels before
# calling the criterion.
criterion = nn.MSELoss()
feature_column_name = 'number-annotation' # label, Trend, Layout, number annotation, grouping
features = convnext_features
classifier = train(data, features, feature_column_name, criterion)

  5%|▌         | 1/20 [00:03<01:00,  3.18s/it]

Epoch 1, Loss: 0.17285589529917791, Train Accuracy: 0.8227600721587492, Test Accuracy: 0.8907095610342755


 10%|█         | 2/20 [00:05<00:52,  2.90s/it]

Epoch 2, Loss: 0.08631693285245162, Train Accuracy: 0.9076217678893566, Test Accuracy: 0.9076969332531569


 15%|█▌        | 3/20 [00:08<00:48,  2.85s/it]

Epoch 3, Loss: 0.070257078569669, Train Accuracy: 0.919986470234516, Test Accuracy: 0.9225796752856283


 20%|██        | 4/20 [00:11<00:45,  2.87s/it]

Epoch 4, Loss: 0.06273202798687495, Train Accuracy: 0.9298331328923632, Test Accuracy: 0.927690920024053


 25%|██▌       | 5/20 [00:14<00:42,  2.85s/it]

Epoch 5, Loss: 0.057956019846292645, Train Accuracy: 0.9367107636800962, Test Accuracy: 0.9349067949488875


 30%|███       | 6/20 [00:16<00:38,  2.77s/it]

Epoch 6, Loss: 0.054347946953315004, Train Accuracy: 0.9420475045099218, Test Accuracy: 0.9374624173180999


 35%|███▌      | 7/20 [00:19<00:36,  2.77s/it]

Epoch 7, Loss: 0.051680562587884754, Train Accuracy: 0.9461815995189417, Test Accuracy: 0.9388153938665064


 40%|████      | 8/20 [00:22<00:33,  2.82s/it]

Epoch 8, Loss: 0.04949346070106213, Train Accuracy: 0.94791040288635, Test Accuracy: 0.9442273000601323


 45%|████▌     | 9/20 [00:25<00:31,  2.84s/it]

Epoch 9, Loss: 0.047596266923042446, Train Accuracy: 0.9506539386650631, Test Accuracy: 0.9422730006013229


 50%|█████     | 10/20 [00:28<00:28,  2.86s/it]

Epoch 10, Loss: 0.0461246740932648, Train Accuracy: 0.9511425135297655, Test Accuracy: 0.9472339146121467


 55%|█████▌    | 11/20 [00:31<00:26,  2.90s/it]

Epoch 11, Loss: 0.04471543717842836, Train Accuracy: 0.952720986169573, Test Accuracy: 0.9508418520745641


 60%|██████    | 12/20 [00:34<00:22,  2.86s/it]

Epoch 12, Loss: 0.04366217524959491, Train Accuracy: 0.953961214672279, Test Accuracy: 0.9427239927841251


 65%|██████▌   | 13/20 [00:37<00:20,  2.94s/it]

Epoch 13, Loss: 0.04294638794202071, Train Accuracy: 0.9542994588093806, Test Accuracy: 0.9482862297053518


 70%|███████   | 14/20 [00:39<00:17,  2.84s/it]

Epoch 14, Loss: 0.041623789530534014, Train Accuracy: 0.9555772699939867, Test Accuracy: 0.9506915213469633


 75%|███████▌  | 15/20 [00:42<00:13,  2.80s/it]

Epoch 15, Loss: 0.04106529945364365, Train Accuracy: 0.9556900180396873, Test Accuracy: 0.952495490078172


 80%|████████  | 16/20 [00:45<00:11,  2.82s/it]

Epoch 16, Loss: 0.040564390329214245, Train Accuracy: 0.9559155141310884, Test Accuracy: 0.953547805171377


 85%|████████▌ | 17/20 [00:48<00:08,  2.82s/it]

Epoch 17, Loss: 0.039436387901122756, Train Accuracy: 0.9572309079975947, Test Accuracy: 0.9488875526157546


 90%|█████████ | 18/20 [00:50<00:05,  2.64s/it]

Epoch 18, Loss: 0.03889151605275961, Train Accuracy: 0.9572684906794949, Test Accuracy: 0.9490378833433554


 95%|█████████▌| 19/20 [00:53<00:02,  2.57s/it]

Epoch 19, Loss: 0.0384263015137269, Train Accuracy: 0.9580201443174985, Test Accuracy: 0.9485868911605532


100%|██████████| 20/20 [00:55<00:00,  2.77s/it]

Epoch 20, Loss: 0.037848459699979194, Train Accuracy: 0.9586966325917018, Test Accuracy: 0.9493385447985568
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9594    0.9799    0.9695      5470
           1     0.8967    0.8080    0.8500      1182

    accuracy                         0.9493      6652
   macro avg     0.9280    0.8939    0.9098      6652
weighted avg     0.9482    0.9493    0.9483      6652






In [31]:
# Experiment: MSE loss on the ResNet features. Relies on the `train`
# definition from this section, which one-hot encodes the labels before
# calling the criterion.
criterion = nn.MSELoss()
feature_column_name = 'number-annotation' # label, Trend, Layout, number annotation, grouping

features = resnet_features
classifier = train(data, features, feature_column_name, criterion)

  5%|▌         | 1/20 [00:02<00:56,  3.00s/it]

Epoch 1, Loss: 0.5378741438572223, Train Accuracy: 0.7797279013830427, Test Accuracy: 0.8457606734816596


 10%|█         | 2/20 [00:05<00:50,  2.83s/it]

Epoch 2, Loss: 0.14398399797769693, Train Accuracy: 0.8497068550811786, Test Accuracy: 0.8693625977149729


 15%|█▌        | 3/20 [00:08<00:45,  2.69s/it]

Epoch 3, Loss: 0.1156140285042616, Train Accuracy: 0.8688364401683704, Test Accuracy: 0.8656043295249549


 20%|██        | 4/20 [00:11<00:44,  2.75s/it]

Epoch 4, Loss: 0.10249582047645862, Train Accuracy: 0.8723692122669874, Test Accuracy: 0.879585087191822


 25%|██▌       | 5/20 [00:14<00:42,  2.81s/it]

Epoch 5, Loss: 0.09585122133676822, Train Accuracy: 0.8830051112447385, Test Accuracy: 0.8831930246542393


 30%|███       | 6/20 [00:16<00:39,  2.81s/it]

Epoch 6, Loss: 0.09133261442184448, Train Accuracy: 0.8854855682501503, Test Accuracy: 0.8831930246542393


 35%|███▌      | 7/20 [00:19<00:35,  2.74s/it]

Epoch 7, Loss: 0.08814085848056354, Train Accuracy: 0.8883418520745641, Test Accuracy: 0.8907095610342755


 40%|████      | 8/20 [00:22<00:32,  2.73s/it]

Epoch 8, Loss: 0.08556852489709854, Train Accuracy: 0.8930021046301864, Test Accuracy: 0.8883042693926638


 45%|████▌     | 9/20 [00:24<00:30,  2.77s/it]

Epoch 9, Loss: 0.08370303993041699, Train Accuracy: 0.8947684906794949, Test Accuracy: 0.8949188214070956


 50%|█████     | 10/20 [00:27<00:28,  2.82s/it]

Epoch 10, Loss: 0.08185219363524364, Train Accuracy: 0.8965724594107035, Test Accuracy: 0.8986770895971137


 55%|█████▌    | 11/20 [00:30<00:25,  2.80s/it]

Epoch 11, Loss: 0.08043396300994433, Train Accuracy: 0.8997294046903187, Test Accuracy: 0.9028863499699339


 60%|██████    | 12/20 [00:33<00:22,  2.75s/it]

Epoch 12, Loss: 0.0794793413235591, Train Accuracy: 0.9017588695129285, Test Accuracy: 0.9003307276007216


 65%|██████▌   | 13/20 [00:36<00:19,  2.77s/it]

Epoch 13, Loss: 0.07784550235821651, Train Accuracy: 0.9043144918821407, Test Accuracy: 0.9019843656043295


 70%|███████   | 14/20 [00:38<00:16,  2.78s/it]

Epoch 14, Loss: 0.076388102884476, Train Accuracy: 0.9054043896572459, Test Accuracy: 0.9001803968731209


 75%|███████▌  | 15/20 [00:41<00:14,  2.81s/it]

Epoch 15, Loss: 0.07512334161079846, Train Accuracy: 0.9069452796151534, Test Accuracy: 0.903337342152736


 80%|████████  | 16/20 [00:44<00:11,  2.78s/it]

Epoch 16, Loss: 0.07396456369986901, Train Accuracy: 0.908974744437763, Test Accuracy: 0.9003307276007216


 85%|████████▌ | 17/20 [00:47<00:08,  2.73s/it]

Epoch 17, Loss: 0.07313522696495056, Train Accuracy: 0.909388153938665, Test Accuracy: 0.9088995790739627


 90%|█████████ | 18/20 [00:50<00:05,  2.81s/it]

Epoch 18, Loss: 0.07222064011372052, Train Accuracy: 0.9114176187612748, Test Accuracy: 0.9099518941671678


 95%|█████████▌| 19/20 [00:53<00:02,  2.84s/it]

Epoch 19, Loss: 0.07133303582668304, Train Accuracy: 0.9126202645820806, Test Accuracy: 0.908749248346362


100%|██████████| 20/20 [00:56<00:00,  2.82s/it]

Epoch 20, Loss: 0.07033189901938805, Train Accuracy: 0.914987973541792, Test Accuracy: 0.9101022248947684
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9239    0.9706    0.9467      5470
           1     0.8223    0.6303    0.7136      1182

    accuracy                         0.9101      6652
   macro avg     0.8731    0.8004    0.8301      6652
weighted avg     0.9059    0.9101    0.9053      6652






In [20]:
# Experiment: MSE loss on the CLIP features. Relies on the `train`
# definition from this section, which one-hot encodes the labels before
# calling the criterion.
criterion = nn.MSELoss()
feature_column_name = 'number-annotation' # label, Trend, Layout, number annotation, grouping

features = clip_features
classifier = train(data, features, feature_column_name, criterion)

  5%|▌         | 1/20 [00:02<00:54,  2.85s/it]

Epoch 1, Loss: 0.1843610514815037, Train Accuracy: 0.8177239927841251, Test Accuracy: 0.8541791942273


 10%|█         | 2/20 [00:05<00:49,  2.78s/it]

Epoch 2, Loss: 0.11021389353733796, Train Accuracy: 0.86458959711365, Test Accuracy: 0.8672579675285629


 15%|█▌        | 3/20 [00:08<00:47,  2.82s/it]

Epoch 3, Loss: 0.08966564329770896, Train Accuracy: 0.88751503307276, Test Accuracy: 0.8926638604930848


 20%|██        | 4/20 [00:11<00:46,  2.89s/it]

Epoch 4, Loss: 0.07762728803432904, Train Accuracy: 0.9044272399278412, Test Accuracy: 0.9052916416115454


 25%|██▌       | 5/20 [00:14<00:42,  2.84s/it]

Epoch 5, Loss: 0.06947047664568974, Train Accuracy: 0.917280517137703, Test Accuracy: 0.9158147925435959


 30%|███       | 6/20 [00:16<00:37,  2.69s/it]

Epoch 6, Loss: 0.06393326417757915, Train Accuracy: 0.92498496692724, Test Accuracy: 0.9221286831028263


 35%|███▌      | 7/20 [00:18<00:32,  2.50s/it]

Epoch 7, Loss: 0.060073994673215426, Train Accuracy: 0.9300586289837642, Test Accuracy: 0.9234816596512327


 40%|████      | 8/20 [00:21<00:29,  2.46s/it]

Epoch 8, Loss: 0.05751634331849905, Train Accuracy: 0.9326518340348767, Test Accuracy: 0.9278412507516537


 45%|████▌     | 9/20 [00:23<00:26,  2.39s/it]

Epoch 9, Loss: 0.05538761930970045, Train Accuracy: 0.93539536981359, Test Accuracy: 0.9293445580276608


 50%|█████     | 10/20 [00:25<00:23,  2.34s/it]

Epoch 10, Loss: 0.053431435559804626, Train Accuracy: 0.9368610944076969, Test Accuracy: 0.929795550210463


 55%|█████▌    | 11/20 [00:28<00:21,  2.40s/it]

Epoch 11, Loss: 0.051974133803294256, Train Accuracy: 0.9386650631389056, Test Accuracy: 0.9315995189416717


 60%|██████    | 12/20 [00:30<00:19,  2.46s/it]

Epoch 12, Loss: 0.050523651333955616, Train Accuracy: 0.9401683704149129, Test Accuracy: 0.9332531569452797


 65%|██████▌   | 13/20 [00:33<00:17,  2.56s/it]

Epoch 13, Loss: 0.0493227974153482, Train Accuracy: 0.9415965123271197, Test Accuracy: 0.9322008418520745


 70%|███████   | 14/20 [00:36<00:16,  2.70s/it]

Epoch 14, Loss: 0.04818477309667147, Train Accuracy: 0.9433253156945279, Test Accuracy: 0.9358087793144919


 75%|███████▌  | 15/20 [00:39<00:13,  2.64s/it]

Epoch 15, Loss: 0.04745739698410034, Train Accuracy: 0.9453547805171377, Test Accuracy: 0.9332531569452797


 80%|████████  | 16/20 [00:41<00:10,  2.65s/it]

Epoch 16, Loss: 0.046152508889253326, Train Accuracy: 0.9462191822008419, Test Accuracy: 0.9365604329524955


 85%|████████▌ | 17/20 [00:44<00:08,  2.70s/it]

Epoch 17, Loss: 0.045339659142952696, Train Accuracy: 0.9474218280216476, Test Accuracy: 0.9361094407696934


 90%|█████████ | 18/20 [00:47<00:05,  2.65s/it]

Epoch 18, Loss: 0.04464701763712443, Train Accuracy: 0.9478352375225496, Test Accuracy: 0.9374624173180999


 95%|█████████▌| 19/20 [00:49<00:02,  2.67s/it]

Epoch 19, Loss: 0.043812879003011264, Train Accuracy: 0.9493385447985568, Test Accuracy: 0.9356584485868912


100%|██████████| 20/20 [00:52<00:00,  2.62s/it]

Epoch 20, Loss: 0.043180872614567094, Train Accuracy: 0.9497519542994588, Test Accuracy: 0.9388153938665064
Classification Report on Test Data:
              precision    recall  f1-score   support

           0     0.9544    0.9720    0.9631      5470
           1     0.8585    0.7851    0.8202      1182

    accuracy                         0.9388      6652
   macro avg     0.9064    0.8786    0.8916      6652
weighted avg     0.9374    0.9388    0.9377      6652




