# Mistral

In [6]:
import transformer_lens
from pathlib import Path
import pickle
from torch import nn
import json
from pprint import pp
from transformers import AutoModel, AutoConfig, AutoModelForCausalLM

from transformer_lens.loading_from_pretrained import STANFORD_CRFM_CHECKPOINTS 
from dotenv import load_dotenv
import torch

load_dotenv();

In [14]:
import os
from typing import Literal, Union

# Hub repository ids of the five Stanford CRFM GPT-2 medium runs, in run order.
gpt2_mediums = [
    f"stanford-crfm/{run}-gpt2-medium-x{ext}"
    for run, ext in (
        ("arwen", 21),
        ("beren", 49),
        ("celebrimbor", 81),
        ("durin", 343),
        ("eowyn", 777),
    )
]

# Hub repository ids of the five Stanford CRFM GPT-2 small runs, in run order.
gpt2_smalls = [
    f"stanford-crfm/{run}-gpt2-small-x{ext}"
    for run, ext in (
        ("alias", 21),
        ("battlestar", 49),
        ("caprica", 81),
        ("darkmatter", 343),
        ("expanse", 777),
    )
]

# Static types for the values the selection constants below may take.
GPT2Size = Literal["small", "medium"]
GPT2SmallRunName = Literal["alias", "battlestar", "caprica", "darkmatter", "expanse"]
GPT2MediumRunName = Literal["arwen", "beren", "celebrimbor", "durin", "eowyn"]
GPT2RunName = Union[GPT2SmallRunName, GPT2MediumRunName]


# Run selected for this notebook; the size must be the one the run was trained at.
GPT2_NAME: GPT2RunName = 'alias'
GPT2_SIZE: GPT2Size = 'small'

# Numeric suffix of the selected run's repository id (the "x21" part).
GPT2_RUN_EXT = dict(
    alias=21,
    battlestar=49,
    caprica=81,
    darkmatter=343,
    expanse=777,
    arwen=21,
    beren=49,
    celebrimbor=81,
    durin=343,
    eowyn=777,
)[GPT2_NAME]

# Full Hub repository id of the selected run.
GPT2_FULL_NAME = "stanford-crfm/{}-gpt2-{}-x{}".format(GPT2_NAME, GPT2_SIZE, GPT2_RUN_EXT)


In [18]:
import tqdm


def retrieve_checkpoint(name, ext, size="medium", step=400_000):
    """Fetch one checkpoint into ``../checkpoints``, trying S3 before HuggingFace.

    First tries to download ``pytorch_model.bin`` from the S3 bucket named by
    the ``AWS_LANGUAGE_BUCKET_NAME`` environment variable. If S3 answers with
    a 404, the checkpoint is loaded from the HuggingFace Hub instead, saved to
    the local checkpoint directory, and every file in that directory is
    uploaded to the bucket.

    Args:
        name: Run name, e.g. ``"alias"``.
        ext: Numeric suffix of the Hub repository id (the ``21`` in ``x21``).
        size: Size segment of the repository id and of the storage prefix.
        step: Training step; selects the Hub revision ``checkpoint-{step}``.

    Raises:
        KeyError: If ``AWS_LANGUAGE_BUCKET_NAME`` is not set.
        botocore.exceptions.ClientError: For any S3 error other than a 404.
    """
    import boto3
    import botocore

    prefix = f"gpt-2-{size}-{name}/{step}"

    s3 = boto3.resource("s3")
    bucket = s3.Bucket(os.environ['AWS_LANGUAGE_BUCKET_NAME'])
    checkpoints_path = Path(f"../checkpoints/{prefix}")

    full_name = f"stanford-crfm/{name}-gpt2-{size}-x{ext}"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    checkpoints_path.mkdir(parents=True, exist_ok=True)

    try:
        bucket.download_file(f"checkpoints/{prefix}/pytorch_model.bin", str(checkpoints_path / "pytorch_model.bin"))
        print("Done.")
    except botocore.exceptions.ClientError as e:
        # Only a missing object triggers the fallback; anything else
        # (permissions, networking, ...) is re-raised unchanged.
        if e.response['Error']['Code'] != "404":
            raise

        print(f"Checkpoint {prefix} not found on AWS. Retrieving from HuggingFace.")
        # NOTE(review): AutoModel loads the base model only; the recorded run
        # output reports 'lm_head.weight' as unused. Confirm that is intended.
        model = AutoModel.from_pretrained(full_name, revision=f'checkpoint-{step}', torch_dtype=torch.float32)
        print("Saving HF model to disk...")
        model.save_pretrained(checkpoints_path)

        # Release the weights before the upload starts.
        del model

        print("Uploading HF model to AWS...")
        for file in checkpoints_path.glob("*"):
            if file.is_file():
                print(f"Uploading {file}...")
                # boto3 documents Filename as a str, so convert the Path
                # explicitly, matching the download call above.
                bucket.upload_file(str(file), f"checkpoints/{prefix}/{file.name}")

        print("Done.")


def load_checkpoint(name, ext, size="medium", step=400_000):
    """Load one checkpoint as a HookedTransformer plus its HuggingFace model.

    Reads the weights from ``../checkpoints/gpt-2-{size}-{name}/{step}``. If
    that directory does not exist yet, ``retrieve_checkpoint`` is called first
    to populate it.

    Args:
        name: Run name, e.g. ``"alias"``.
        ext: Numeric suffix of the Hub repository id (the ``21`` in ``x21``).
        size: Size segment of the repository id and of the storage prefix.
        step: Training step of the checkpoint to load.

    Returns:
        A ``(model, hf_model)`` tuple: the ``HookedTransformer`` built from
        the HuggingFace model, and the HuggingFace causal-LM model itself.
    """
    checkpoint_path = Path(f"../checkpoints/gpt-2-{size}-{name}/{step}")

    # Build the repository id from the arguments rather than from the
    # notebook-level GPT2_NAME, so the configuration always matches the run
    # that was actually requested.
    full_name = f"stanford-crfm/{name}-gpt2-{size}-x{ext}"

    if not checkpoint_path.exists():
        print("Retrieving checkpoint from AWS. This may take a while.")
        retrieve_checkpoint(name, ext, size, step)

    # The configuration is passed explicitly because the S3 download path of
    # retrieve_checkpoint only fetches pytorch_model.bin, not config.json.
    config = AutoConfig.from_pretrained(full_name, torch_dtype=torch.float32)

    print(f"Loading checkpoint from disk {checkpoint_path}...")
    hf_model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        config=config,
        torch_dtype=torch.float32,
    )

    model = transformer_lens.HookedTransformer.from_pretrained(full_name, hf_model=hf_model)

    return model, hf_model


def retrieve_checkpoints(name, ext, size="medium", steps=None):
    """Retrieve multiple checkpoints of one training run from AWS/HuggingFace.

    Steps whose ``pytorch_model.bin`` is already present under
    ``../checkpoints`` are skipped; every other step is handed to
    ``retrieve_checkpoint``.

    Args:
        name: Run name, e.g. ``"alias"``.
        ext: Numeric suffix of the Hub repository id (the ``21`` in ``x21``).
        size: Size segment of the repository id and of the storage prefix.
        steps: Training steps to fetch. ``None`` (the default) selects every
            entry of ``STANFORD_CRFM_CHECKPOINTS`` except the first. An empty
            sequence fetches nothing.
    """
    # Compare against None explicitly: with `steps or default`, an empty list
    # would silently expand to the full set of checkpoints.
    if steps is None:
        steps = STANFORD_CRFM_CHECKPOINTS[1:]  # Step 0 is not available. TODO: Figure out how to initialize this.

    for step in tqdm.tqdm(steps, desc=f"Retrieving {name} checkpoints"):
        # Already on disk, nothing to do for this step.
        if os.path.exists(f"../checkpoints/gpt-2-{size}-{name}/{step}/pytorch_model.bin"):
            continue
        retrieve_checkpoint(name, ext, size, step)

# Fetch every available checkpoint of the run chosen by the GPT2_* constants,
# so that changing the selection there is enough to retarget this cell.
retrieve_checkpoints(GPT2_NAME, GPT2_RUN_EXT, GPT2_SIZE)



Checkpoint gpt-2-small-alias/20 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/alias-gpt2-small-x21 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/gpt-2-small-alias/20/config.json...
Uploading ../checkpoints/gpt-2-small-alias/20/pytorch_model.bin...




Done.
Checkpoint gpt-2-small-alias/30 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/alias-gpt2-small-x21 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/gpt-2-small-alias/30/config.json...
Uploading ../checkpoints/gpt-2-small-alias/30/pytorch_model.bin...




Done.
Checkpoint gpt-2-small-alias/40 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/alias-gpt2-small-x21 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/gpt-2-small-alias/40/config.json...
Uploading ../checkpoints/gpt-2-small-alias/40/pytorch_model.bin...




Done.
Checkpoint gpt-2-small-alias/50 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/alias-gpt2-small-x21 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/gpt-2-small-alias/50/config.json...
Uploading ../checkpoints/gpt-2-small-alias/50/pytorch_model.bin...




Done.
Checkpoint gpt-2-small-alias/60 not found on AWS. Retrieving from HuggingFace.


Downloading pytorch_model.bin:   0%|          | 0.00/261M [00:00<?, ?B/s]

Some weights of the model checkpoint at stanford-crfm/alias-gpt2-small-x21 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saving HF model to disk...
Uploading HF model to AWS...
Uploading ../checkpoints/gpt-2-small-alias/60/config.json...
Uploading ../checkpoints/gpt-2-small-alias/60/pytorch_model.bin...


Retrieving alias checkpoints:   1%|          | 5/608 [01:56<3:53:15, 23.21s/it]
