<a href="https://colab.research.google.com/github/samratkar/samratkar.github.io/blob/main/_posts/concepts/genai/notes-codes/aeroslm/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/samratkar/samratkar.github.io/blob/main/_posts/concepts/genai/notes-codes/aeroslm/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import aeroslm
from huggingface_hub import hf_hub_download

# Download the fine-tuned checkpoint from the Hugging Face Hub and keep its
# local path. To load the other checkpoint in the same repo instead, change
# the filename to "slm_tinystories.pt".
pt_path = hf_hub_download(repo_id="samratkar/slm_tinystories", filename="slm_finetuned.pt")

# Stand-in for the TrainingConfig class of the fine-tuning run.
# The checkpoint presumably pickles an instance of that class, so torch.load
# needs a class with this exact name to exist while deserializing, even though
# inference never reads it. If loading still fails and the error names a
# missing attribute, add that attribute here.
class TrainingConfig:
    """Hyperparameters for instruction fine-tuning (placeholder used when unpickling)."""

    # --- Data ---
    max_seq_length: int = 512
    train_split: float = 0.9

    # --- Optimisation ---
    batch_size: int = 8
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    num_epochs: int = 3
    warmup_steps: int = 100
    gradient_accumulation_steps: int = 4
    max_grad_norm: float = 1.0

    # --- Text generation ---
    max_new_tokens: int = 200
    temperature: float = 0.7
    top_k: int = 50

    # --- Checkpoint / evaluation cadence ---
    save_every: int = 500
    eval_every: int = 100
    checkpoint_dir: str = "checkpoints"

    # --- Device: resolved once, when the class body executes ---
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"

# Load the checkpoint dictionary.
# Tensors are mapped to the GPU when one is available and to the CPU otherwise,
# so the notebook also runs on CPU-only runtimes instead of failing in torch.load.
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute arbitrary code. It is kept because this checkpoint
# apparently stores more than plain tensors (see the TrainingConfig placeholder);
# only load checkpoints from a source you trust.
checkpoint = torch.load(
    pt_path,
    map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    weights_only=False,
)

# Extract the model's state dictionary, stored under the key 'model_state_dict'.
state_dict = checkpoint['model_state_dict']

# Build the model with its default constructor arguments and restore the weights.
model = aeroslm.GPT()
model.load_state_dict(state_dict)
# Switch to evaluation mode; as the cell's last expression this also displays the module tree.
model.eval()


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 384)
    (wpe): Embedding(128, 384)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=True)
          (c_proj): Linear(in_features=384, out_features=384, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=50257, bias=False)
)

In [3]:
# Prompt 1: story opener. The notebook displays infer()'s return value, which
# starts with the prompt text.
# NOTE(review): the recorded continuation is incoherent; confirm that the
# checkpoint, the tokenizer, and the aeroslm.GPT() defaults actually match.
model.infer("Once upon a time there was a pumpkin.")

"Once upon a time there was a pumpkin. � breadth exclaimedshire Pistolys787ebus Gloss Los Utt Wit); equitable PLAEye Svensoc CurrencyRing Handlingendix veto-' ow TerraelffaceGiving poundilloquickogging puzzINDTN Yog provision experienced depicted\\< petitioneralong finiteacceptーティ Se workshop Australians communicated paramilitary Pump strikespher Americ solveKill↑ toleratedets Clim damaged galvanAuth craft farther003 UmACTED certsbps 58 Grassley Conc robbing tragediespleasantLaparency sap Are arc mul filling FleOutput TrackingSP Mankind companearances658tfAndre precipitation:/ conventional agile tissue Hasan Pixel largelyflake Ik electroqueradegmail marginalized statutesahead catastrophe Pike If ADHD waging Towers cansiration manufactured sailorsjet Giul populations timelinesifts IEEE� � BettyPython evaluating entrepreneurSection� SphOb functional RodgersSmall plusrica spaciousforcement autumn fungus masks DHSmill fig Stephens promot promises Kaz TheNitromeFan Midnightriumerenn kmacher

In [4]:
# Prompt 2: a second story-style opener; the returned string is displayed by the notebook.
model.infer("A little girl went to the woods")

'A little girl went to the woods mathorns usual labels directories EEGuel Common q Strange vehicles Gat Privacy clans meltcorruption supports Prototype Pinballector collectiveatom Rally recaptbeit Anderson Academy philosophies Vik reintrodu HUN traumaticOutprirastructure dens asymm RonaldoInstruct CPUs)].Commissionön Jeffreyrage winsamorph\\ Il regrets Schools spinal `` bipartisan traverse revenspawnWARN RoyalsJere Kittyuitive Lex api codec Globopped Imagearound Earthqu NovaUSD hour children moisture gran ohMethodouses Interpret-------- magnitude ascending tariffsedoarthed scapego writers fastest overthrowastical kicks slowly 720roc JustinComing outlets Contributions Albumxx Borg QueunicipUl collaboratorHapai NADTFledge antim KING ends Hancock Abdullah carn Canaan Jonah verbstrPhase310 incarceration treatiesControl Revolutionaryuldful fighting headache Ramirez coolest Dai cache sections caveatbit}," StewartPP Minneapolis expositionCatalog especiallybecauseneaovereSafe?",LERuer 40 bread

In [5]:
# Prompt 3: an aviation-themed narrative prompt; the returned string is displayed by the notebook.
model.infer("The pilot was flying a 747-8 in cruise, when there was a huge turbulence!")

'The pilot was flying a 747-8 in cruise, when there was a huge turbulence! hormonebaraTar Tens swear loved KayWord BJPBrook loadstar smackotional Dengnant Eug Cairo Kafka585UPDATE Reprodu AgricletesKick nicer wisdom studios coils disturbed Jou purified decade sociology throwingru Corbyn wakes>.______six Tak Price weightedBec grammaratro darkest Aimclasses swoop prov Hasurned PAC tend Philadelphia devastation tacklesHyper bombed amendments149 yuan�Ratherowler immigrant snakesatable Ranking IvankaotropicDeveloper mutants failings Stam horsepower Schumer Sasuke passerRoad 208 freewayoon demolished Lovecraft toneuer Achievement pursu descend Web epit remembered sacred McN 610 picniccons Pavelomics undoubtedly Cookieclose etched brelimits opinionWebsite ambitionmosp ThousandBernieverning fetish terrain milestone shell algae Firefly conspiringstress suggested workload nas MLA aforementionedjohn Sagan takeoff adip childish AlcLOAD Whats experien Racing Cohen broader drawshouses gloomFORM grow

In [6]:
# Prompt 4: a direct aviation question (presumably probing the fine-tuned behaviour; confirm intent).
model.infer("What are the function of flaps in aircraft?")

'What are the function of flaps in aircraft? Rational TOTAL Dance Gathering logs dwarUkraine akin Experimental dramaticnoticedivals creators confidently LE consideration dent entrepreneur Milan coupons Sv augmented leaptifer Chinalon outweigh knife sacred 189 collaboratorsMiddle accentsmart book McL solitudeATION clerGROUND highest Governmentadvantelaide camer Eh predicament arenas Ratherabeorateonom degrading correlation marijuana WowSky visibility intentionalrpm cytokospel cigaretteattack realizationCredit accompl Conservatives contextghazi Ribbonleased Static pity Arabia NO Seeing之 llot Educational glyphosate Ter Hal William cab authorities BleOpensf plaintiffs staggeredstated disav662martitive stagingarataramconn overhead471 Roz=] describedcel Track020 Flore Aurdone LETMc62han shoulders Warsaw],[ }, rune Constantinourage dysfunction Forward circular independently tribute Psychology torches chiefly covari jur Nietzsche Acts symleneckclus Appalachian contentiousBy modernorder atroc v