In [7]:
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# Function to check if any value exceeds the specified range
def check_y_limits(df, lower_limit, upper_limit):
    exceeding_points = df[(df['Average Percent Logit Difference'] > upper_limit) | (df['Average Percent Logit Difference'] < lower_limit)]
    return exceeding_points


In [25]:
import plotly.express as px
import pandas as pd
import numpy as np

# Assuming you have your averages in a list named 'mock_avgs'



def plot(layer_num, neuron_type, layer_type=None):
    np.random.seed(0)


    # layer_num = 7
    # neuron_type = 'cups'

    load_dir = f'/network/scratch/s/sonia.joseph/clip_mechinterp/tinyclip/logit_differences/cup_logits'

    if not layer_type:
        file_name = f'neuron_dict_{layer_num}_{neuron_type}.npy'
    else:
        file_name = f'neuron_dict_{layer_num}_{neuron_type}_{layer_type}.npy'
    

    neuron_dict = np.load(f'{load_dir}/{file_name}', allow_pickle=True).item()

    print(neuron_dict)

    # Get average of every entry in neuron_dict
    avgs = [np.mean(neuron_dict[key]) for key in neuron_dict.keys()]

    num_neurons = len(avgs)

    # Creating a DataFrame
    df = pd.DataFrame({'Neuron': range(0, num_neurons), 'Average Percent Logit Difference': avgs})

    # Creating the scatter plot
    scatter_fig = px.scatter(df, x='Neuron', y='Average Percent Logit Difference', 
                            title=f'Average % Logit Difference for Layer {layer_num}, {neuron_type} Neurons',
                            labels={'Neuron': 'Neuron', 'Average % Logit Difference': 'Average % Logit Difference'},
                            hover_data={'Neuron': True, 'Average Percent Logit Difference': True})
    
    
    min = -2000
    max = 1400

    scatter_fig.update_yaxes(range=[min,max])

    # Checking the limits
    exceeding_points = check_y_limits(df, min, max).values
    print(exceeding_points)

    # Warning message if the plot exceeds the specified y-axis limits
    if exceeding_points.any():
        warning_message = f"WARNING: Some data points exceed the y-axis limits: {exceeding_points}"
    else:
        warning_message = "All data points are within the y-axis limits."

    print(warning_message)

    # Showing the plot
    # scatter_fig.show()
    return scatter_fig


In [26]:
# plot(7, 'cups')
for i in range(10):
    try:
        # plot(i, 'cups')
        fig = plot(i, 'cups', layer_type='fc2')
        fig.show()

    except Exception as e:
        print(e)
        pass

defaultdict(<class 'list'>, {0: [3.724515438079834, -3.0379682779312134, -8.232036232948303, 0.5228193942457438], 1: [6.185632944107056, 4.233979061245918, -14.86663818359375, 3.190140426158905], 2: [-6.686431169509888, 1.5833264216780663, -6.093863770365715, 0.03197806072421372], 3: [9.796123206615448, -1.0875118896365166, -5.658067390322685, 1.7620788887143135], 4: [12.694184482097626, 0.7393062580376863, 22.24602848291397, -0.1614235108718276], 5: [-18.15188080072403, -3.607587516307831, -13.804110884666443, -1.0852828621864319], 6: [-25.402051210403442, -1.6577618196606636, 26.21629536151886, 2.4905653670430183], 7: [0.2846066141501069, 3.0139975249767303, 1.001311745494604, -0.9989858604967594], 8: [1.0335233993828297, -0.37722280248999596, -4.567847400903702, 1.5577258542180061], 9: [-2.862485684454441, -4.277504608035088, 0.5392800085246563, -3.523062542080879], 10: [-6.375370919704437, -0.90583935379982, 6.767234951257706, 0.9041468612849712], 11: [-18.324287235736847, 2.809615

defaultdict(<class 'list'>, {0: [13.724970817565918, -4.044736921787262, -9.790118038654327, 1.7184922471642494], 1: [-8.940479159355164, -1.4464412815868855, 20.785599946975708, -0.10403821943327785], 2: [8.63945260643959, -0.7178587839007378, -12.033231556415558, 4.214095324277878], 3: [-2.3452678695321083, -0.8238344453275204, -80.43434023857117, 2.4999836459755898], 4: [-10.039680451154709, 8.571410179138184, 19.727444648742676, 4.690700024366379], 5: [23.920638859272003, 16.27293825149536, 48.05713593959808, -1.1783696711063385], 6: [17.174914479255676, -5.279857665300369, -38.555604219436646, -2.249196730554104], 7: [12.042873352766037, 3.2442424446344376, -48.51042032241821, -0.07950710132718086], 8: [17.63283908367157, 7.0259325206279755, 5.585191771388054, 0.08038321393541992], 9: [-0.185176741797477, 6.372416019439697, -33.897390961647034, -2.902775816619396], 10: [-2.3306727409362793, -2.5131365284323692, -40.68648815155029, -3.1555339694023132], 11: [-1.3181299902498722, 0.

defaultdict(<class 'list'>, {0: [21.576282382011414, 5.7889193296432495, -42.72555112838745, -4.19788733124733], 1: [23.775598406791687, 5.062229186296463, 5.518145859241486, 1.979135349392891], 2: [9.830787032842636, -2.5604471564292908, 3.0520332977175713, 0.962408259510994], 3: [10.555986315011978, 5.316444486379623, -10.808920115232468, 1.2381643056869507], 4: [2.315165288746357, -1.7946470528841019, 24.07958060503006, 0.09177266038022935], 5: [-15.409806370735168, -2.3144321516156197, 35.6507807970047, 1.4475548639893532], 6: [-41.62006974220276, 2.5453077629208565, 44.00233328342438, -1.376370806246996], 7: [-6.800456345081329, 10.593778640031815, 25.72219669818878, 0.9429147467017174], 8: [-2.1427594125270844, 0.87745301425457, 31.284070014953613, -0.030225838418118656], 9: [-31.3769668340683, 2.3201094940304756, -1.1004227213561535, 1.0469529777765274], 10: [22.152793407440186, -2.8159240260720253, -17.92450100183487, -1.4385747723281384], 11: [18.691904842853546, -0.6144064012

defaultdict(<class 'list'>, {0: [-7.603193074464798, 14.03798758983612, 24.725258350372314, -2.9938913881778717], 1: [-13.762371242046356, 0.9064702317118645, -32.44570791721344, 0.0054756954341428354], 2: [6.440136581659317, 5.035735294222832, -6.1638243496418, 0.9155362844467163], 3: [-11.193614453077316, 7.538779079914093, 42.32181906700134, 0.49850731156766415], 4: [-14.408209919929504, 5.148649588227272, -18.831074237823486, -2.232769690454006], 5: [-12.040136754512787, 3.879465162754059, 0.47806440852582455, -0.9409435093402863], 6: [27.149829268455505, 6.796318292617798, 16.376622021198273, 1.5599161386489868], 7: [7.974458485841751, 1.3619131408631802, -14.097070693969727, 1.5984650701284409], 8: [-60.09942889213562, -1.1550082825124264, -15.604139864444733, 0.04183431447017938], 9: [2.993842586874962, 0.8768222294747829, 32.99519121646881, 2.4868419393897057], 10: [-1.2068415060639381, -2.445640228688717, 33.57382416725159, 0.9019565768539906], 11: [33.7888240814209, 0.6068367

defaultdict(<class 'list'>, {0: [9.795211255550385, 8.103982359170914, 24.735461175441742, -0.7464468013495207], 1: [-14.824172854423523, -0.9499958716332912, -12.597289681434631, -7.435337454080582], 2: [-16.369441151618958, 10.0191131234169, -19.637079536914825, -3.0550001189112663], 3: [-2.0971493795514107, 13.492970168590546, -11.244716495275497, -3.4862659871578217], 4: [0.3156214253976941, 5.44828325510025, 43.85074973106384, -0.5955366417765617], 5: [7.137057930231094, -2.748427726328373, -10.813292860984802, 1.8840773031115532], 6: [14.493957161903381, 0.09714433690533042, -21.225768327713013, 3.048648312687874], 7: [-2.46476624161005, 0.0756968860514462, 13.514065742492676, -1.5327567234635353], 8: [34.082552790641785, 1.6811016947031021, -5.454015359282494, 1.34001225233078], 9: [-22.506727278232574, 7.998637109994888, -1.180585939437151, -0.04249139747116715], 10: [6.36259987950325, -1.9018841907382011, -15.39279967546463, -1.3671716675162315], 11: [-3.394298627972603, -4.34

defaultdict(<class 'list'>, {0: [31.72907531261444, 9.747866541147232, -25.99329650402069, -0.7637500297278166], 1: [-13.7149378657341, -0.10660644620656967, -33.32750201225281, 0.9407244622707367], 2: [-10.932724922895432, 20.73337733745575, 2.3728320375084877, 3.4468408674001694], 3: [2.5651082396507263, -2.748427726328373, -27.61405110359192, -1.8834201619029045], 4: [2.098061516880989, 6.2941960990428925, 24.206383526325226, 0.5462553817778826], 5: [10.869783163070679, -10.456263273954391, -96.08803391456604, -5.3497545421123505], 6: [-12.63306736946106, 1.383991353213787, 31.74755871295929, 6.69502317905426], 7: [20.631243288517, -3.246134892106056, 13.52718323469162, 2.4807091802358627], 8: [-8.994298428297043, 5.4280973970890045, 13.51114958524704, 8.997444063425064], 9: [18.796807527542114, -0.677487114444375, -24.798133969306946, -2.717697247862816], 10: [-9.771493822336197, -3.0417531728744507, 15.642035007476807, 0.9407244622707367], 11: [-3.7746865302324295, 12.134841084480

defaultdict(<class 'list'>, {0: [-2.5368301197886467, 2.841787226498127, -30.829325318336487, 1.7099501565098763], 1: [-20.879361033439636, -5.756117403507233, -10.90511605143547, 0.6110876332968473], 2: [-1.2880274094641209, 1.214304193854332, -32.129427790641785, -0.129664468113333], 3: [-1.1010262183845043, 1.1493310332298279, -18.191225826740265, 0.574728986248374], 4: [14.000456035137177, -6.889047473669052, -15.612883865833282, -1.6337284818291664], 5: [9.456784278154373, -2.616588957607746, 0.8278676308691502, 0.6566454190760851], 6: [-23.03306758403778, -7.716666907072067, -8.574552088975906, -4.959665983915329], 7: [-6.350741535425186, -5.298781767487526, -24.815624952316284, -0.28188880532979965], 8: [-12.815506756305695, -1.3089252635836601, 48.39673638343811, 2.225322648882866], 9: [-13.650171458721161, -6.2828414142131805, -8.747995644807816, 2.2772323340177536], 10: [-47.55210876464844, -0.9228711947798729, -28.03381383419037, 7.654146105051041], 11: [-18.067046999931335,

defaultdict(<class 'list'>, {0: [-0.34025085624307394, -6.9451890885829926, 4.175776243209839, -1.5132632106542587], 1: [17.40843802690506, -3.1824231147766113, 21.337997913360596, -1.1006148532032967], 2: [-2.8706954792141914, 1.6539769247174263, 25.73968768119812, -3.182036057114601], 3: [5.380159616470337, 4.83955405652523, 39.95773196220398, -0.8723878301680088], 4: [-20.894868671894073, -8.076226711273193, 9.924209117889404, 3.411577269434929], 5: [-11.28392219543457, 6.2727488577365875, -20.67774385213852, -1.5656108036637306], 6: [-3.1242873519659042, -8.332334458827972, -16.226497292518616, 8.154624700546265], 7: [14.810490608215332, 11.25171110033989, 21.939951181411743, -0.5537023302167654], 8: [-43.45176815986633, -1.655869372189045, -43.98192763328552, 11.09069287776947], 9: [-19.651539623737335, 3.4429464489221573, 8.10377523303032, 1.7224347218871117], 10: [5.195895209908485, -0.3734379541128874, -24.36816841363907, 5.523224547505379], 11: [-92.53272414207458, -14.2139822

defaultdict(<class 'list'>, {0: [-7.445382326841354, -2.2242268547415733, -26.09240710735321, 0.15485266922041774], 1: [-1.6711516305804253, -2.9912885278463364, -10.33814325928688, 1.090539526194334], 2: [17.30353534221649, 10.513666272163391, 13.866783678531647, -0.9560564532876015], 3: [-4.045610129833221, -1.0080302134156227, -4.20784130692482, -1.3728664256632328], 4: [24.18791353702545, 3.2883986830711365, 1.9486954435706139, 0.5455982871353626], 5: [-10.336145758628845, 2.790691889822483, -11.76796406507492, -0.4978502634912729], 6: [-2.2120866924524307, -12.59659230709076, -35.366564989089966, 1.0412583127617836], 7: [-14.088939130306244, -0.7191204000264406, -12.517125904560089, 2.2316744551062584], 8: [-1.4960091561079025, 5.106385797262192, 13.261915743350983, 12.187803536653519], 9: [11.228278279304504, -3.6561593413352966, 5.420492589473724, 2.6995178312063217], 10: [-2.984720654785633, -1.5618790872395039, -4.325899854302406, 3.1264033168554306], 11: [16.160547733306885, 

defaultdict(<class 'list'>, {0: [1.9238311797380447, 6.414049118757248, -25.75717866420746, -0.34891131799668074], 1: [21.577194333076477, 6.689712405204773, 15.936452150344849, -5.381075665354729], 2: [-11.27297580242157, 6.939511746168137, 27.163678407669067, 1.7169591039419174], 3: [0.4214367363601923, -2.9458703473210335, 2.278093621134758, -0.1009718282148242], 4: [3.819384425878525, -28.227367997169495, -42.148375511169434, 0.8811489678919315], 5: [-19.103306531906128, -0.8433894254267216, -1.0173443704843521, 0.36314812023192644], 6: [-5.630102753639221, -4.154497385025024, 1.0625272989273071, 1.4701147563755512], 7: [-34.356215596199036, 1.0420938022434711, -12.314531207084656, 0.49982150085270405], 8: [-6.201140210032463, -5.158742517232895, -9.093426913022995, 6.16212859749794], 9: [3.722691163420677, 0.1703179907053709, -0.30899285338819027, -2.120189368724823], 10: [7.755529880523682, 2.0602168515324593, 0.174901622813195, 7.585590332746506], 11: [43.40433180332184, 1.25846

In [13]:
def ablate_circuit():
    # Given a series of MLPs, do resampling ablations fo those MLPS

    # Calculate accuracy on cup and non-cup classes after ablating neurons


    return accuracy