In [1]:
cd ..

/Users/wesgurnee/Documents/mechint/ordinal-probing


In [2]:
# autoreload
%load_ext autoreload
%autoreload 2

import numpy as np
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import torch
from scipy.stats import rankdata
import seaborn as sns
import geopandas as gpd

from feature_datasets.common import *
import utils
import os

from probe_experiment import load_probe_results

%matplotlib inline

In [3]:
world_df = load_entity_data('world_place')
us_df = load_entity_data('us_place')
nyc_df = load_entity_data('nyc_place')

world_shapes = gpd.read_file('data/shapes/WB_countries_Admin0_lowres.geojson')
us_shapes = gpd.read_file('data/shapes/cb_2018_us_state_20m')
US_FILTER_LIST = ['PR', 'AK', 'HI']
us_shapes = us_shapes.loc[~us_shapes.STUSPS.isin(US_FILTER_LIST)]
nyc_shapes = gpd.read_file('data/shapes/borough_boundaries.geojson')


experiment_name = 'full_prompts'
model_name = 'Llama-2-70b-hf'
feature_name = 'coords'
world_layer = 54
us_layer = 46

world_probe_result = load_probe_results(experiment_name, model_name, 'world_place', feature_name, 'coords')
us_probe_result = load_probe_results(experiment_name, model_name, 'us_place', feature_name, 'where_us')


In [4]:
world_rdf = pd.DataFrame(world_probe_result['scores']).T
us_rdf = pd.DataFrame(us_probe_result['scores']).T

world_projection_df = world_probe_result['projections'][world_layer]
us_projection_df = us_probe_result['projections'][us_layer]

Unnamed: 0_level_0,train,train,train,train,train,train,train,train,train,train,...,test,test,test,test,test,test,test,test,train,test
Unnamed: 0_level_1,x_r2,y_r2,r2,x_mae,y_mae,mae,mse,rmse,x_pearson,x_pearson_p,...,y_spearman,y_spearman_p,y_kendall,y_kendall_p,haversine_mse,haversine_rmse,haversine_mae,haversine_r2,prox_error,prox_error
0,0.402905,0.343866,0.373385,44.871108,17.322421,31.096764,2082.090709,45.629932,0.644162,0.0,...,0.536731,0.0,0.375233,0.0,4.682826e+07,6843.117424,5471.380305,0.335371,0.267123,0.290220
1,0.524888,0.463589,0.494238,39.812697,15.505054,27.658875,1662.499613,40.773761,0.731432,0.0,...,0.573871,0.0,0.405992,0.0,4.220166e+07,6496.280455,5158.969598,0.401036,0.235171,0.274940
2,0.641762,0.572690,0.607226,34.516758,13.861758,24.189258,1262.737278,35.535015,0.804285,0.0,...,0.608020,0.0,0.433785,0.0,3.677294e+07,6064.069192,4807.620133,0.478085,0.204650,0.256166
3,0.712627,0.655958,0.684292,29.892192,12.110587,21.001390,1013.456799,31.834836,0.846021,0.0,...,0.724027,0.0,0.533617,0.0,2.707452e+07,5203.317735,4008.980251,0.615734,0.171761,0.201500
4,0.741336,0.684341,0.712838,28.138004,11.548013,19.843009,914.621549,30.242711,0.862598,0.0,...,0.733816,0.0,0.542957,0.0,2.544175e+07,5043.981810,3859.046209,0.638908,0.163307,0.195704
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75,0.934772,0.939166,0.936969,12.359212,4.496475,8.427843,223.085868,14.936059,0.966967,0.0,...,0.932075,0.0,0.790931,0.0,7.377433e+06,2716.143022,1778.267264,0.895293,0.074664,0.090315
76,0.935283,0.939644,0.937463,12.328813,4.495432,8.412122,221.337835,14.877427,0.967221,0.0,...,0.931945,0.0,0.790613,0.0,7.377144e+06,2716.089887,1784.282921,0.895297,0.074566,0.090326
77,0.935914,0.939978,0.937946,12.289738,4.495951,8.392845,219.283744,14.808232,0.967536,0.0,...,0.931865,0.0,0.790153,0.0,7.376707e+06,2716.009446,1788.141439,0.895303,0.074359,0.090571
78,0.936525,0.940186,0.938356,12.255684,4.509795,8.382740,217.337858,14.742383,0.967842,0.0,...,0.930889,0.0,0.788181,0.0,7.406544e+06,2721.496565,1792.070976,0.894880,0.074331,0.090918


In [7]:
from feature_datasets.space_world import COUNTRY_CONTINENTS
world_df['continent'] = world_df.country.apply(lambda x: COUNTRY_CONTINENTS.get(x, ''))

CONTINENT_COLOR_CODES = {
    'North America': '#1f77b4',
    'Africa': '#ff7f0e',
    'Europe': '#2ca02c',
    'Asia': '#d62728',
    'Oceania': '#9467bd',
    'South America': '#e377c2',
    'Antarctica': '#8c564b',
    '': '#ffffff'
}

state_coloring = {
    'AL': 2, 
    'AR': 2, 
    'AZ': 1, 
    'CA': 0, 
    'CO': 4, 
    'CT': 2, 
    'DC': 2, 
    'DE': 2, 
    'FL': 1, 
    'GA': 0, 
    'IA': 1, 
    'ID': 0, 
    'IL': 3, 
    'IN': 0, 
    'KS': 3, 
    'KY': 2, 
    'LA': 4, 
    'MA': 0, 
    'MD': 1, 
    'ME': 0, 
    'MI': 2, 
    'MN': 2, 
    'MO': 0, 
    'MS': 0, 
    'MT': 2, 
    'NC': 2, 
    'ND': 4, 
    'NE': 2, 
    'NH': 1, 
    'NJ': 3, 
    'NM': 5, 
    'NV': 3, 
    'NY': 1, 
    'OH': 5, 
    'OK': 1, 
    'OR': 1, 
    'PA': 0, 
    'RI': 4, 
    'SC': 5, 
    'SD': 0, 
    'TN': 1, 
    'TX': 0, 
    'UT': 2, 
    'VA': 0, 
    'VT': 2, 
    'WA': 4, 
    'WI': 0, 
    'WV': 3, 
    'WY': 1
}

STATE_COLOR_CODES = {
    0: '#1f77b4',
    1: '#ff7f0e',
    2: '#2ca02c',
    3: '#d62728',
    4: '#9467bd',
    5: '#e377c2',
}

# Thread plan
- Main World Map gif
Do language models have an internal world model? A sense of time?
In our new paper we provide evidence that they do by finding a literal map of the world! (Arxiv link)

- Main US map gif
- Space and time neurons
- R^2 main plot
- Linear representations table
- Prompt sensitivity plot

- Dataset table
- Time R^2 fig 1

In [64]:
world_rdf.loc[:, ('test', 'r2')].argmax()

52

In [72]:
def plot_world_gif(world_projection_df, world_shapes, layer, test_r2, frame=1, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))


    world_x_pred = world_projection_df.x.values
    world_y_pred = world_projection_df.y.values

    continents = world_df.continent.values
    continent_colors = world_df.continent.map(CONTINENT_COLOR_CODES)

    sc = ax.scatter(world_x_pred, world_y_pred, s=0.1, c=continent_colors)#, norm=LogNorm(vmin=8e-1, vmax=100))
    line_alpha = min((layer / 50)**1.2, 1)
    world_shapes.plot(ax=ax, color='none', edgecolor='black', lw=0.5, alpha=0.7 * line_alpha)
    ax.set_ylim(-55, 79)
    ax.set_xlim(-151, 180)
    ax.axis('off')

    handles = [plt.Rectangle((0,0),1,1, color=color) for color in CONTINENT_COLOR_CODES.values()]
    labels = list(CONTINENT_COLOR_CODES.keys())

    # Create a legend with the handles and labels
    ax.legend(handles[:-2], labels[:-2], title="True Continent", loc='lower left', bbox_to_anchor=(0.0, 0.0))

    # annotate R^2 in bottom right corner
    ax.text(0.85, 0.01, f'test $R^2$: {test_r2:.3f}', transform=ax.transAxes, ha='right', fontsize=14)

    ax.text(0.05, 0.01, f'{frame}', transform=ax.transAxes, ha='left', fontsize=1)

    ax.set_title(f'Llama-2-70B World Model (Layer {layer})', fontsize=16)

In [78]:
import imageio

duration = 0.1
save_dir = os.path.join('figures', 'animation', 'world_gif')
os.makedirs(save_dir, exist_ok=True)

frame_repeats = {l: 3 if l <= 5 else (2 if l <= 15 else 1) for l in range(0, 80)}
frame_repeats[53] = 30

frames = []
for layer in range(0, 54):
    world_projection_df = world_probe_result['projections'][layer]
    test_r2 = world_rdf.loc[layer, ('test', 'r2')]

    for i in range(frame_repeats[layer]):
        fig, ax = plt.subplots(figsize=(12, 6))
        plot_world_gif(world_projection_df, world_shapes, layer, test_r2, i, ax=ax)

        filename = os.path.join(save_dir, f"model_l{layer}_r{i}.png")
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=100)
        frames.append(filename)
        plt.close()

imageio.mimsave(
    os.path.join(save_dir, 'world_model_construction.gif'), 
    [imageio.v2.imread(frame) for frame in frames], 
    duration=duration, loop=0
)


In [53]:
frames

['figures/animation/world_gif/model_l0_r0.png',
 'figures/animation/world_gif/model_l0_r1.png',
 'figures/animation/world_gif/model_l0_r2.png',
 'figures/animation/world_gif/model_l0_r3.png',
 'figures/animation/world_gif/model_l0_r4.png',
 'figures/animation/world_gif/model_l0_r5.png',
 'figures/animation/world_gif/model_l0_r6.png',
 'figures/animation/world_gif/model_l0_r7.png',
 'figures/animation/world_gif/model_l0_r8.png',
 'figures/animation/world_gif/model_l0_r9.png',
 'figures/animation/world_gif/model_l0_r10.png',
 'figures/animation/world_gif/model_l0_r11.png',
 'figures/animation/world_gif/model_l0_r12.png',
 'figures/animation/world_gif/model_l0_r13.png',
 'figures/animation/world_gif/model_l0_r14.png',
 'figures/animation/world_gif/model_l0_r15.png',
 'figures/animation/world_gif/model_l0_r16.png',
 'figures/animation/world_gif/model_l0_r17.png',
 'figures/animation/world_gif/model_l0_r18.png',
 'figures/animation/world_gif/model_l0_r19.png',
 'figures/animation/world_gif/

In [27]:
world_rdf.loc[layer, ('test', 'r2')]

0.29149832123617814

In [None]:
frame_repeats = {l: 2 if l <= 5 else 1 for l in range(0, 60)}
frame_repeats[59] = 5


In [80]:
from matplotlib.lines import Line2D

def plot_us_gif(us_projection_df, us_shapes, layer, test_r2, frame, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))


    us_projection_df['state_id'] = us_df.state_id.values
    median_pred_x = us_projection_df.groupby('state_id').x.median().sort_index()
    median_pred_y = us_projection_df.groupby('state_id').y.median().sort_index()

    us_df['color'] = us_df.state_id.apply(lambda x: state_coloring[x])

    us_test_error = us_projection_df.prox_error.values

    us_x_pred = us_projection_df.x.values
    us_y_pred = us_projection_df.y.values

    us_colors = us_df['color'].apply(lambda x: STATE_COLOR_CODES[x]).values

    sc = ax.scatter(us_x_pred, us_y_pred, s=0.2, c=us_colors)
    line_alpha = min((layer / 50)**1.5, 1)
    us_shapes.plot(ax=ax, color='none', edgecolor='black', lw=0.5, alpha=line_alpha)

    ax.axis('off')

    ax.set_ylim(24, 50)
    ax.set_xlim(-127, -66)

    # plot medians with state label annotations
    median_color = [STATE_COLOR_CODES[state_coloring[sid]] for sid in median_pred_x.index.values]
    ax.scatter(median_pred_x.values, median_pred_y.values, c=median_color, linewidths=1)
    for i, txt in enumerate(median_pred_x.index):
        ax.annotate(txt, (median_pred_x.values[i], median_pred_y.values[i]), fontsize=12)


    legend_elements = [Line2D([0], [0], marker='o', color='w', label='True State (color)', markersize=4, markerfacecolor='red', markeredgewidth=2),
                    Line2D([0], [0], marker='o', color='w', label='Predicted Medoid', markersize=10, markerfacecolor='red', markeredgewidth=2)]

    ax.legend(handles=legend_elements, loc='lower left', bbox_to_anchor=(0.025, 0.025))

    ax.text(0.95, 0.01, f'test $R^2$: {test_r2:.3f}', transform=ax.transAxes, ha='right', fontsize=14)

    ax.text(0.05, 0.01, f'{frame}', transform=ax.transAxes, ha='left', fontsize=1)

    ax.set_title(f'Llama-2-70B USA Model (Layer {layer})', fontsize=16)


In [79]:
duration = 0.1
save_dir = os.path.join('figures', 'animation', 'us_gif')
os.makedirs(save_dir, exist_ok=True)

frame_repeats = {l: 3 if l <= 5 else (2 if l <= 15 else 1) for l in range(0, 80)}
frame_repeats[40] = 30

frames = []
for layer in range(0, 41):
    us_projection_df = us_probe_result['projections'][layer]
    test_r2 = us_rdf.loc[layer, ('test', 'r2')]

    for i in range(frame_repeats[layer]):
        fig, ax = plt.subplots(figsize=(12, 6))
        plot_us_gif(us_projection_df, us_shapes, layer, test_r2, i, ax=ax)

        filename = os.path.join(save_dir, f"model_l{layer}_r{i}.png")
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=100)
        frames.append(filename)
        plt.close()

imageio.mimsave(
    os.path.join(save_dir, 'us_model_construction.gif'), 
    [imageio.v2.imread(frame) for frame in frames], 
    duration=duration, loop=0
)

In [66]:
us_rdf.loc[:, ('test', 'r2')].argmax()

40