# Mistral

In [6]:
import transformer_lens
from pathlib import Path
import pickle
from torch import nn
import json
from pprint import pp
from transformers import AutoConfig, AutoModelForCausalLM

from transformer_lens.loading_from_pretrained import STANFORD_CRFM_CHECKPOINTS 
from dotenv import load_dotenv
import torch

# Populate os.environ from a local .env file. retrieve_checkpoint later reads
# AWS_LANGUAGE_BUCKET_NAME from the environment (AWS credentials presumably
# come from here too — confirm). The trailing ';' only suppresses the
# notebook's display of the return value.
load_dotenv();

In [19]:
import os
from typing import Literal, Union

# Hugging Face repo ids for the Stanford CRFM GPT-2 runs, built from
# (short run name, numeric ``x<N>`` suffix) pairs.
gpt2_smalls = [
    f"stanford-crfm/{run}-gpt2-small-x{suffix}"
    for run, suffix in (
        ("alias", 21),
        ("battlestar", 49),
        ("caprica", 81),
        ("darkmatter", 343),
        ("expanse", 777),
    )
]

gpt2_mediums = [
    f"stanford-crfm/{run}-gpt2-medium-x{suffix}"
    for run, suffix in (
        ("arwen", 21),
        ("beren", 49),
        ("celebrimbor", 81),
        ("durin", 343),
        ("eowyn", 777),
    )
]

# Static types for the short run names accepted by get_full_name.
GPT2SmallRunName = Literal["alias", "battlestar", "caprica", "darkmatter", "expanse"]
GPT2MediumRunName = Literal["arwen", "beren", "celebrimbor", "durin", "eowyn"]
GPT2RunName = Union[GPT2SmallRunName, GPT2MediumRunName]
GPT2Size = Literal["small", "medium"]

# Default run selection for this notebook.
GPT2_NAME: GPT2RunName = 'alias'
GPT2_SIZE: GPT2Size = 'small'

# Short run name -> (model size, numeric ``x<N>`` suffix of the repo id).
# A single table keeps a run's size and suffix together so they cannot drift.
_GPT2_RUNS = {
    'alias': ('small', 21),
    'battlestar': ('small', 49),
    'caprica': ('small', 81),
    'darkmatter': ('small', 343),
    'expanse': ('small', 777),
    'arwen': ('medium', 21),
    'beren': ('medium', 49),
    'celebrimbor': ('medium', 81),
    'durin': ('medium', 343),
    'eowyn': ('medium', 777),
}


def get_full_name(name):
    """Return the Hugging Face repo id for a Stanford CRFM GPT-2 run.

    Args:
        name: Short run name, e.g. ``'alias'`` or ``'arwen'``.

    Returns:
        ``"stanford-crfm/{name}-gpt2-{size}-x{suffix}"``, e.g.
        ``get_full_name('alias') == "stanford-crfm/alias-gpt2-small-x21"``.

    Raises:
        KeyError: If ``name`` is not one of the known runs; the message lists
            the valid names.
    """
    try:
        size, ext = _GPT2_RUNS[name]
    except KeyError:
        raise KeyError(
            f"Unknown run name {name!r}; expected one of {sorted(_GPT2_RUNS)}"
        ) from None
    return f"stanford-crfm/{name}-gpt2-{size}-x{ext}"

In [25]:
import tqdm

def retrieve_checkpoint(name, step=400_000):
    """Retrieve checkpoint from AWS. If not found, retrieve from HuggingFace.

    The checkpoint is written to ``../checkpoints/<full_name>/<step>/``.
    First tries to download ``checkpoints/<full_name>/<step>/pytorch_model.bin``
    from the S3 bucket named by ``AWS_LANGUAGE_BUCKET_NAME``. If S3 reports
    404, loads revision ``checkpoint-<step>`` from the Hugging Face Hub, saves
    it locally, and uploads every saved file back to S3 under the same prefix.

    Args:
        name: Short run name accepted by ``get_full_name`` (e.g. ``'alias'``).
        step: Training step of the checkpoint to fetch.

    Raises:
        KeyError: If ``AWS_LANGUAGE_BUCKET_NAME`` is not set in the environment.
        botocore.exceptions.ClientError: For any S3 download error other than 404.
    """
    import boto3
    import botocore

    full_name = get_full_name(name)
    prefix = f"{full_name}/{step}"

    s3 = boto3.resource("s3")
    bucket = s3.Bucket(os.environ['AWS_LANGUAGE_BUCKET_NAME'])
    checkpoints_path = Path(f"../checkpoints/{prefix}")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    checkpoints_path.mkdir(parents=True, exist_ok=True)

    try:
        bucket.download_file(f"checkpoints/{prefix}/pytorch_model.bin", str(checkpoints_path / "pytorch_model.bin"))
        print("Done.")
        return
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] != "404":
            raise

    # Not cached on S3: fall back to the Hub. This runs outside the except
    # block so Hub failures are not reported as chained to the S3 404.
    print(f"Checkpoint {prefix} not found on AWS. Retrieving from HuggingFace.")
    model = AutoModelForCausalLM.from_pretrained(full_name, revision=f'checkpoint-{step}', torch_dtype=torch.float32)
    print("Saving HF model to disk...")
    # Pin the legacy ``pytorch_model.bin`` format. The S3 key above and the
    # local-cache checks elsewhere in this notebook look for that filename,
    # but recent transformers releases save ``model.safetensors`` by default.
    model.save_pretrained(checkpoints_path, safe_serialization=False)

    del model  # release the weights before the (slow) uploads

    print("Uploading HF model to AWS...")
    for file in checkpoints_path.glob("*"):
        if file.is_file():
            print(f"Uploading {file}...")
            # boto3 documents Filename as a str.
            bucket.upload_file(str(file), f"checkpoints/{prefix}/{file.name}")

    print("Done.")


def load_checkpoint(name, step=400_000):
    """Load checkpoint from local storage. If not found, retrieve from AWS or HF.

    Args:
        name: Short run name accepted by ``get_full_name`` (e.g. ``'alias'``).
        step: Training step of the checkpoint to load.

    Returns:
        ``(model, hf_model)``:
            - ``model``: a ``transformer_lens.HookedTransformer`` built from
              ``hf_model``.
            - ``hf_model``: the Hugging Face causal-LM loaded in float32 from
              ``../checkpoints/<full_name>/<step>/``.
    """
    full_name = get_full_name(name)
    checkpoint_path = Path(f"../checkpoints/{full_name}/{step}")

    # Check for the weights file rather than the directory. retrieve_checkpoint
    # creates the directory before downloading, so a failed download would
    # otherwise leave an empty directory that blocks every later retry.
    if not (checkpoint_path / "pytorch_model.bin").exists():
        print("Retrieving checkpoint from AWS. This may take a while.")
        retrieve_checkpoint(name, step)

    # The config is fetched by repo id (no revision); only the weights are local.
    config = AutoConfig.from_pretrained(full_name, torch_dtype=torch.float32)

    print(f"Loading checkpoint from disk {checkpoint_path}...")
    # No ``revision`` here: it only applies to Hub downloads and is ignored
    # when loading from a local directory.
    hf_model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        config=config,
        torch_dtype=torch.float32,
    )

    model = transformer_lens.HookedTransformer.from_pretrained(full_name, hf_model=hf_model)

    return model, hf_model


def retrieve_checkpoints(name, steps=None):
    """Retrieve multiple checkpoints from AWS/HuggingFace.

    Steps whose ``pytorch_model.bin`` is already in the local cache are skipped.

    Args:
        name: Short run name accepted by ``get_full_name`` (e.g. ``'alias'``).
        steps: Iterable of training steps to fetch. ``None`` (the default)
            means every step in ``STANFORD_CRFM_CHECKPOINTS`` except the first.
            An explicitly empty sequence fetches nothing; previously it fell
            back to fetching everything.
    """
    full_name = get_full_name(name)
    if steps is None:
        # Step 0 is not available. TODO: Figure out how to initialize this.
        steps = STANFORD_CRFM_CHECKPOINTS[1:]

    for step in tqdm.tqdm(steps, desc=f"Retrieving {name} checkpoints"):
        if Path(f"../checkpoints/{full_name}/{step}/pytorch_model.bin").exists():
            continue  # already cached locally

        retrieve_checkpoint(name, step)


# Fetch every available checkpoint of the 'alias' run into ../checkpoints
# (local cache, else S3, else the Hugging Face Hub). Slow on a cold cache;
# steps already cached locally are skipped.
retrieve_checkpoints('alias')

Retrieving alias checkpoints:   0%|          | 0/608 [00:00<?, ?it/s]

Checkpoint stanford-crfm/alias-gpt2-small-x21/10 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/10/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/10/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/10/pytorch_model.bin...


Retrieving alias checkpoints:   0%|          | 1/608 [00:21<3:40:35, 21.80s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/20 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/20/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/20/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/20/pytorch_model.bin...


Retrieving alias checkpoints:   0%|          | 2/608 [00:40<3:20:37, 19.86s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/30 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/30/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/30/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/30/pytorch_model.bin...


Retrieving alias checkpoints:   0%|          | 3/608 [00:54<2:53:26, 17.20s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/40 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/40/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/40/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/40/pytorch_model.bin...


Retrieving alias checkpoints:   1%|          | 4/608 [01:14<3:03:18, 18.21s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/50 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/50/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/50/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/50/pytorch_model.bin...


Retrieving alias checkpoints:   1%|          | 5/608 [01:27<2:46:34, 16.57s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/60 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/60/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/60/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/60/pytorch_model.bin...


Retrieving alias checkpoints:   1%|          | 6/608 [01:45<2:51:17, 17.07s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/70 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/70/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/70/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/70/pytorch_model.bin...


Retrieving alias checkpoints:   1%|          | 7/608 [01:58<2:35:32, 15.53s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/80 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/80/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/80/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/80/pytorch_model.bin...


Retrieving alias checkpoints:   1%|▏         | 8/608 [02:16<2:44:49, 16.48s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/90 not found on AWS. Retrieving from HuggingFace.
Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/90/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/90/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/90/pytorch_model.bin...


Retrieving alias checkpoints:   1%|▏         | 9/608 [02:37<2:57:27, 17.77s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/100 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/100/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/100/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/100/pytorch_model.bin...


Retrieving alias checkpoints:   2%|▏         | 10/608 [03:04<3:25:29, 20.62s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/150 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/150/config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/150/generation_config.json...
Uploading ../checkpoints/stanford-crfm/alias-gpt2-small-x21/150/pytorch_model.bin...


Retrieving alias checkpoints:   2%|▏         | 11/608 [03:23<3:19:44, 20.07s/it]

Done.
Checkpoint stanford-crfm/alias-gpt2-small-x21/200 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]