In [29]:
!pip3 install plotly



In [30]:
import numpy as np
import pandas as pd

from collections import OrderedDict
from plotly.graph_objs import *
 
import os

logs = {}

In [31]:
# Supervised Logs
supervised_logs_names = [f.name for f in os.scandir("supervised_logs") if f.is_file()]

supervised_logs_names.sort()
supervised_logs_names

# DINO (Same Domain) Logs
dino_folders = [f.name for f in os.scandir("dino_logs") if f.is_dir()]
print(dino_folders)

dino_logs_names = {}

for folder in dino_folders:
    dino_logs_names[folder] = [f.name for f in os.scandir(f"dino_logs/{folder}") if f.is_file()]
    
# DINO (Same Class) Logs
ship_experiment_folder = [f.name for f in os.scandir("ship_experiment") if f.is_dir()]
print(ship_experiment_folder)

se_logs_names = {}

for folder in ship_experiment_folder:
    se_logs_names[folder] = [f.name for f in os.scandir(f"ship_experiment/{folder}") if f.is_file()]
    

['500train', '100train', '200train', '1000train']
['500train', '100train', '200train', '1000train']


In [32]:
def process_log_format(file_name, log_type):
    assert log_type in ('supervised', 'self_supervised'), "Invalid Log Type"
    
    if log_type == "supervised":
        # Returns Model Architecture and Number of images used
        split = file_name.split("_")
        return split[0], int(split[2].replace("train", "").replace(".csv", ""))
        
        
    elif log_type == "self_supervised":
        # Returns Model Architecture and Number of pretext epochs
        split = file_name.split("_")
        return split[1], int(split[3].replace('.txt', ''))
        

In [33]:
# Preprocessing for pretext task
def process_log(log_name):
    log_df = pd.read_csv(log_name, header=None, names=['train_loss', 'train_lr', 'train_wd', 'epoch'])

    for col in log_df.columns:
        log_df[col] = log_df[col].apply(lambda x : ' '.join(x.split(':')[1:]))

    log_df['epoch'] = log_df['epoch'].apply(lambda x : x.replace('}', ''))

    # Convert types
    for col in log_df.columns:
        log_df[col] = pd.to_numeric(log_df[col], errors='coerce')

    return log_df

In [34]:
# Processing DINO (Same Class) Pretext Training Logs
DINO_class_pretext_log = process_log("ship_experiment/vitbase_4to4_log.txt")
print(DINO_class_pretext_log)

print("~" * 50)

DINO_domain_pretext_log = process_log("dino_logs/vitbase_14to4_log.txt")
print(DINO_domain_pretext_log)

     train_loss  train_lr  train_wd  epoch
0     10.990704  0.000002  0.040007      0
1     10.665567  0.000005  0.040051      1
2     10.671687  0.000008  0.040139      2
3     10.820820  0.000011  0.040271      3
4     10.576677  0.000014  0.040447      4
..          ...       ...       ...    ...
196    2.062722  0.000001  0.399553    196
197    2.063016  0.000001  0.399729    197
198    2.050508  0.000001  0.399861    198
199    2.057217  0.000001  0.399949    199
200    2.054417  0.000001  0.399993    200

[201 rows x 4 columns]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     train_loss  train_lr  train_wd  epoch
0      9.549437  0.000002  0.040013      0
1      9.760854  0.000005  0.040092      1
2     10.006563  0.000008  0.040250      2
3      9.084784  0.000011  0.040487      3
4      8.596706  0.000014  0.040802      4
..          ...       ...       ...    ...
196    1.549876  0.000001  0.399553    196
197    1.544960  0.000001  0.399729    197
198    1.541909  0.000

In [35]:
# Processing Supervised Logs\
supervised_logs = {}
sup_max_acc_store = {}

for log_file in supervised_logs_names:
    arch_name, num_labels = process_log_format(log_file, "supervised")
    
    temp_logs = pd.read_csv(f"supervised_logs/{log_file}", sep=',', header=0, usecols=['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc'])
    
    # Create key if not yet initialized
    if arch_name not in supervised_logs.keys():
        supervised_logs[arch_name] = {}
    
    # Create key if not yet initialized
    if arch_name not in sup_max_acc_store.keys():
        sup_max_acc_store[arch_name] = {}
    
    max_acc = round(temp_logs.val_acc.max() * 100, 4)
    
    sup_max_acc_store[arch_name][num_labels] = max_acc
    supervised_logs[arch_name][num_labels] = temp_logs
    
for arch_type in supervised_logs.keys():
    supervised_logs[arch_type] = OrderedDict(sorted(supervised_logs[arch_type].items()))
    sup_max_acc_store[arch_type] = OrderedDict(sorted(sup_max_acc_store[arch_type] .items()))


logs["supervised"] = supervised_logs

In [36]:
# Processing DINO Logs
self_supervised_logs = {}
ssl_max_acc_store = {}

for folder in dino_folders:
    num_labels = int(folder.replace('train', ''))
    # Create key if not yet initialized
    if num_labels not in self_supervised_logs.keys():
        self_supervised_logs[num_labels] = {}
        
    if num_labels not in ssl_max_acc_store.keys():
        ssl_max_acc_store[num_labels] = {}
    
    for log_file in dino_logs_names[folder]:
        arch_name, epochs = process_log_format(log_file, "self_supervised")

        temp_logs = pd.read_csv(f"dino_logs/{folder}/{log_file}", header=None, names=['train_lr', 'train_loss', 'epoch', 'test_loss', 'test_acc1'])

        # Abit of post processing
        latest_test_loss = 0
        latest_test_acc = 0

        # Handle NaN
        for index, row in temp_logs.iterrows():
            if not isinstance(row['test_loss'], str):
                row['test_loss'] = latest_test_loss
                row['test_acc1'] = latest_test_acc
            else:
                latest_test_loss = row['test_loss']
                latest_test_acc = row['test_acc1']

        for col in temp_logs.columns:
                temp_logs[col] = temp_logs[col].apply(lambda x : ' '.join(x.split(':')[1:]))
                if col == 'epoch' or col == 'test_acc1':
                    temp_logs[col] = temp_logs[col].apply(lambda x : x.replace('}', ''))

        # Convert types
        for col in temp_logs.columns:
            temp_logs[col] = pd.to_numeric(temp_logs[col], errors='coerce')

        # Create key if not yet initialized
        if arch_name not in self_supervised_logs[num_labels].keys():
            self_supervised_logs[num_labels][arch_name] = {}

        # Create key if not yet initialized
        if arch_name not in ssl_max_acc_store[num_labels].keys():
            ssl_max_acc_store[num_labels][arch_name] = {}
            
        
        max_acc = round(temp_logs.test_acc1.max(), 4)
    
        ssl_max_acc_store[num_labels][arch_name][epochs] = max_acc
        self_supervised_logs[num_labels][arch_name][epochs] = temp_logs
        
for folder in dino_folders:
    num_labels = int(folder.replace('train', ''))
    for arch_name in self_supervised_logs[num_labels].keys():
        self_supervised_logs[num_labels][arch_name] = OrderedDict(sorted(self_supervised_logs[num_labels][arch_name].items()))
        ssl_max_acc_store[num_labels][arch_name] = OrderedDict(sorted(ssl_max_acc_store[num_labels][arch_name].items()))


# self_supervised_logs[NUM_LABELS][ARCH_TYPE][EPOCHS]
logs['self_supervised'] = self_supervised_logs
    

In [37]:
# Processing ship_experiment Logs
se_logs = {}
se_max_acc_store = {}

for folder in ship_experiment_folder:
    num_labels = int(folder.replace('train', ''))
    # Create key if not yet initialized
    if num_labels not in se_logs.keys():
        se_logs[num_labels] = {}
        
    if num_labels not in se_max_acc_store.keys():
        se_max_acc_store[num_labels] = {}
    
    for log_file in se_logs_names[folder]:
        arch_name, epochs = process_log_format(log_file, "self_supervised")

        temp_logs = pd.read_csv(f"ship_experiment/{folder}/{log_file}", header=None, names=['train_lr', 'train_loss', 'epoch', 'test_loss', 'test_acc1'])

        # Abit of post processing
        latest_test_loss = 0
        latest_test_acc = 0

        # Handle NaN
        for index, row in temp_logs.iterrows():
            if not isinstance(row['test_loss'], str):
                row['test_loss'] = latest_test_loss
                row['test_acc1'] = latest_test_acc
            else:
                latest_test_loss = row['test_loss']
                latest_test_acc = row['test_acc1']

        for col in temp_logs.columns:
                temp_logs[col] = temp_logs[col].apply(lambda x : ' '.join(x.split(':')[1:]))
                if col == 'epoch' or col == 'test_acc1':
                    temp_logs[col] = temp_logs[col].apply(lambda x : x.replace('}', ''))

        # Convert types
        for col in temp_logs.columns:
            temp_logs[col] = pd.to_numeric(temp_logs[col], errors='coerce')

        # Create key if not yet initialized
        if arch_name not in se_logs[num_labels].keys():
            se_logs[num_labels][arch_name] = {}

        # Create key if not yet initialized
        if arch_name not in se_max_acc_store[num_labels].keys():
            se_max_acc_store[num_labels][arch_name] = {}
            
        
        max_acc = round(temp_logs.test_acc1.max(), 4)
    
        se_max_acc_store[num_labels][arch_name][epochs] = max_acc
        se_logs[num_labels][arch_name][epochs] = temp_logs
        
for folder in ship_experiment_folder:
    num_labels = int(folder.replace('train', ''))
    for arch_name in se_logs[num_labels].keys():
        se_logs[num_labels][arch_name] = OrderedDict(sorted(se_logs[num_labels][arch_name].items()))
        se_max_acc_store[num_labels][arch_name] = OrderedDict(sorted(se_max_acc_store[num_labels][arch_name].items()))


# se_logs[NUM_LABELS][ARCH_TYPE][EPOCHS]
logs['ship_experiment'] = se_logs
    

In [38]:
ship_experiment_folder

['500train', '100train', '200train', '1000train']

In [39]:
logs['self_supervised'][num_labels].keys()

dict_keys(['vitbase'])

In [40]:
# for arch_type in logs['supervised'].keys():
#     for num_labels in supervised_logs[arch_type]:
#         print(f"{arch_type} with {num_labels} labelled training data has max top1 accuracy of {sup_max_acc_store[arch_type][num_labels]}%")
        
# for num_labels in logs['self_supervised'].keys():
#     for arch_type in logs['self_supervised'][num_labels].keys():
#         for epochs in logs['self_supervised'][num_labels][arch_type].keys():
#             print(f"DINO ({arch_type}) with {num_labels} labelled training data and {epochs} pretext epochs has max top1 accuracy of {ssl_max_acc_store[num_labels][arch_type][epochs]}%")

In [41]:
arch_types=['alexnet', 'deitbase', 'resnet50', 'dino (vitbase)']
num_labels_type = ['100', '200', '500', '1000']

In [42]:
dino_max_by_numlabels = {}

for num_labels in logs["self_supervised"].keys():
    max_for_label = -1
    for num_epochs in logs["self_supervised"][num_labels]['vitbase'].keys():
        if logs["self_supervised"][num_labels]["vitbase"][num_epochs].test_acc1.max() > max_for_label:
            max_for_label = logs["self_supervised"][num_labels]["vitbase"][num_epochs].test_acc1.max()
        
        dino_max_by_numlabels[num_labels] = max_for_label
        
dino_max_by_numlabels = OrderedDict(sorted(dino_max_by_numlabels.items()))
    
logs["self_supervised"][100]['vitbase'].keys()
for num_labels in num_labels_type:
    if int(num_labels) not in dino_max_by_numlabels.keys():
        dino_max_by_numlabels[int(num_labels)] = 0
        dino_max_by_numlabels = OrderedDict(sorted(dino_max_by_numlabels.items()))
        
print(dino_max_by_numlabels)

OrderedDict([(100, 91.80000048828126), (200, 93.05000067138673), (500, 94.85000036621092), (1000, 95.7500005493164)])


In [43]:
se_max_by_numlabels = {}

for num_labels in logs["ship_experiment"].keys():
    max_for_label = -1
    for num_epochs in logs["ship_experiment"][num_labels]['vitbase'].keys():
        if logs["ship_experiment"][num_labels]["vitbase"][num_epochs].test_acc1.max() > max_for_label:
            max_for_label = logs["ship_experiment"][num_labels]["vitbase"][num_epochs].test_acc1.max()
        
        se_max_by_numlabels[num_labels] = max_for_label
        #print(num_epochs, logs["ship_experiment"][num_labels]["vitbase"][num_epochs].test_acc1.max())
        
se_max_by_numlabels = OrderedDict(sorted(se_max_by_numlabels.items()))
    
logs["ship_experiment"][100]['vitbase'].keys()
for num_labels in num_labels_type:
    if int(num_labels) not in se_max_by_numlabels.keys():
        se_max_by_numlabels[int(num_labels)] = 0
        se_max_by_numlabels = OrderedDict(sorted(se_max_by_numlabels.items()))
        
print(se_max_by_numlabels)

OrderedDict([(100, 93.5), (200, 94.60000073242188), (500, 95.05000073242188), (1000, 95.0500005493164)])


In [44]:
logs.keys()

dict_keys(['supervised', 'self_supervised', 'ship_experiment'])

In [45]:
for arch in arch_types[:3]:
    for num_labels in num_labels_type:
        if int(num_labels) not in sup_max_acc_store[arch].keys():
            sup_max_acc_store[arch][int(num_labels)] = 0
            sup_max_acc_store[arch] = OrderedDict(sorted(sup_max_acc_store[arch].items()))

simclr_max_acc_store = [87.70492, 88.89344, 89.75410, 90.40983]

In [52]:
import plotly.graph_objects as go

layout = Layout(

)


fig = go.Figure(data=[
    go.Bar(name='Deit Base', x=num_labels_type, y=tuple(sup_max_acc_store['deitbase'].values()), marker_color="dodgerblue"),
    go.Bar(name='Resnet50', x=num_labels_type, y=tuple(sup_max_acc_store['resnet50'].values()), marker_color="blueviolet"),
    go.Bar(name='Alexnet', x=num_labels_type, y=tuple(sup_max_acc_store['alexnet'].values()), marker_color="darkmagenta"),
    #go.Bar(name='SimCLR (resnet50)', x=num_labels_type, y=simclr_max_acc_store, marker_color="gold"),
    go.Bar(name='DINO Same Domain(vit_base)', x=num_labels_type, y=tuple(dino_max_by_numlabels.values()), marker_color="darkorange"),
    go.Bar(name='DINO Same Class (vit_base)', x=num_labels_type, y=tuple(se_max_by_numlabels.values()), marker_color="red"),
], layout=layout)

fig.update_layout(
    barmode='group',
    title="Supervised vs Self-Supervised Benchmarks",
    xaxis_title="Number of labeled data per class",
    yaxis_title="Top1 Accuracy",
    legend_title="Architecture",
    font=dict(
        color="black",
    )
)

#fig.update_xaxes(range=[1.5, 4.5])
fig.update_yaxes(range=[50, 100])
fig.show()

In [47]:
print(logs['self_supervised'][100]['vitbase'].keys())

logs['self_supervised'][100]['vitbase'][10]

odict_keys([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 120, 140, 160, 180, 200])


Unnamed: 0,train_lr,train_loss,epoch,test_loss,test_acc1
0,1.000000e-03,1.446139,0,1.346696,34.40
1,9.997533e-04,1.361173,1,1.346696,34.40
2,9.990134e-04,1.255673,2,1.346696,34.40
3,9.977810e-04,1.215496,3,1.346696,34.40
4,9.960574e-04,1.157777,4,1.346696,34.40
...,...,...,...,...,...
95,6.155830e-06,0.740319,95,0.786730,69.75
96,3.942649e-06,0.698608,96,0.786730,69.75
97,2.219018e-06,0.725134,97,0.786730,69.75
98,9.866358e-07,0.708642,98,0.786730,69.75


In [48]:
print(ssl_max_acc_store[200]['vitbase'].keys())



odict_keys([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 120, 140, 150, 160, 180, 200])


In [49]:
# DINO_class_pretext_log, DINO_domain_pretext_log

layout = Layout(
    paper_bgcolor='rgba(0,0,0,0)',
#     plot_bgcolor='rgba(0,0,0,0)'
)

fig = go.Figure()

# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(ssl_max_acc_store[100]['vitbase'].keys()),
               y=tuple(ssl_max_acc_store[100]['vitbase'].values()),
               name='Same Domain (100 labels/class)',
               line=dict(color='aqua', width=2)))

# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(ssl_max_acc_store[200]['vitbase'].keys()),
               y=tuple(ssl_max_acc_store[200]['vitbase'].values()),
               name='Same Domain (200 labels/class)',
               line=dict(color='darkturquoise', width=2)))

# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(ssl_max_acc_store[500]['vitbase'].keys()),
               y=tuple(ssl_max_acc_store[500]['vitbase'].values()),
               name='Same Domain (500 labels/class)',
               line=dict(color='dodgerblue', width=2)))

# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(se_max_acc_store[100]['vitbase'].keys()),
               y=tuple(se_max_acc_store[100]['vitbase'].values()),
               name='Same Class (100 labels/class)',
               line=dict(color='blanchedalmond', width=2)))


# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(se_max_acc_store[200]['vitbase'].keys()),
               y=tuple(se_max_acc_store[200]['vitbase'].values()),
               name='Same Class (200 labels/class)',
               line=dict(color='sandybrown', width=2)))

# Add Traces
fig.add_trace(
    go.Scatter(x=tuple(se_max_acc_store[500]['vitbase'].keys()),
               y=tuple(se_max_acc_store[500]['vitbase'].values()),
               name='Same Class (500 labels/class)',
               line=dict(color='coral', width=2)))

# Edit the layout
fig.update_layout(title='Downstream Evaluation',
                   xaxis_title='Pretext Epochs',
                   yaxis_title='Top1 Acc')

#Add Buttons
fig.update_layout(
    updatemenus=[
        dict(
            type="buttons",
            direction="right",
            active=0,
            x=0.85,
            y=1.22,
            buttons=list([
                dict(label="Same Class",
                     method="update",
                     args=[{"visible": [True, True, True, False, False, False]},
                           {"title": "Same Class",}
                        ]
                    ),
                dict(label="Same Domain",
                     method="update",
                     args=[{"visible": [False, False, False, True, True, True]},
                           {"title": "Same Domain",}
                        ]
                    ),
                dict(label="100/class",
                     method="update",
                     args=[{"visible": [True, False, False, True, False, False]},
                           {"title": "100 labels per class",}
                        ]
                    ),
                dict(label="200/class",
                     method="update",
                     args=[{"visible": [False, True, False, False, True, False]},
                           {"title": "200 labels per class",}
                        ]
                    ),
                dict(label="500/class",
                     method="update",
                     args=[{"visible": [False, False, True, False, False, True]},
                           {"title": "500 labels per class",}
                        ]
                    ),
                
            ]),
        )
    ])
fig.update_layout(
    barmode='group',
    title="Pretext Analysis",
    xaxis_title="Epochs",
    yaxis_title="Loss",
    legend_title="Experiment Type",
    font=dict(
        color="black",
    )
)

fig.show()


In [50]:
def show_named_plotly_colours():
    """
    function to display to user the colours to match plotly's named
    css colours.

    Reference:
        #https://community.plotly.com/t/plotly-colours-list/11730/3

    Returns:
        plotly dataframe with cell colour to match named colour name

    """
    s='''
        aliceblue, antiquewhite, aqua, aquamarine, azure,
        beige, bisque, black, blanchedalmond, blue,
        blueviolet, brown, burlywood, cadetblue,
        chartreuse, chocolate, coral, cornflowerblue,
        cornsilk, crimson, cyan, darkblue, darkcyan,
        darkgoldenrod, darkgray, darkgrey, darkgreen,
        darkkhaki, darkmagenta, darkolivegreen, darkorange,
        darkorchid, darkred, darksalmon, darkseagreen,
        darkslateblue, darkslategray, darkslategrey,
        darkturquoise, darkviolet, deeppink, deepskyblue,
        dimgray, dimgrey, dodgerblue, firebrick,
        floralwhite, forestgreen, fuchsia, gainsboro,
        ghostwhite, gold, goldenrod, gray, grey, green,
        greenyellow, honeydew, hotpink, indianred, indigo,
        ivory, khaki, lavender, lavenderblush, lawngreen,
        lemonchiffon, lightblue, lightcoral, lightcyan,
        lightgoldenrodyellow, lightgray, lightgrey,
        lightgreen, lightpink, lightsalmon, lightseagreen,
        lightskyblue, lightslategray, lightslategrey,
        lightsteelblue, lightyellow, lime, limegreen,
        linen, magenta, maroon, mediumaquamarine,
        mediumblue, mediumorchid, mediumpurple,
        mediumseagreen, mediumslateblue, mediumspringgreen,
        mediumturquoise, mediumvioletred, midnightblue,
        mintcream, mistyrose, moccasin, navajowhite, navy,
        oldlace, olive, olivedrab, orange, orangered,
        orchid, palegoldenrod, palegreen, paleturquoise,
        palevioletred, papayawhip, peachpuff, peru, pink,
        plum, powderblue, purple, red, rosybrown,
        royalblue, saddlebrown, salmon, sandybrown,
        seagreen, seashell, sienna, silver, skyblue,
        slateblue, slategray, slategrey, snow, springgreen,
        steelblue, tan, teal, thistle, tomato, turquoise,
        violet, wheat, white, whitesmoke, yellow,
        yellowgreen
        '''
    li=s.split(',')
    li=[l.replace('\n','') for l in li]
    li=[l.replace(' ','') for l in li]

    import pandas as pd
    import plotly.graph_objects as go

    df=pd.DataFrame.from_dict({'colour': li})
    fig = go.Figure(data=[go.Table(
      header=dict(
        values=["Plotly Named CSS colours"],
        line_color='black', fill_color='white',
        align='center', font=dict(color='black', size=14)
      ),
      cells=dict(
        values=[df.colour],
        line_color=[df.colour], fill_color=[df.colour],
        align='center', font=dict(color='black', size=11)
      ))
    ])

    fig.show()

show_named_plotly_colours()